From 9271f12b6a18b97f3474224fb914eec7c1288870 Mon Sep 17 00:00:00 2001 From: HardMax71 Date: Tue, 27 Jan 2026 01:39:36 +0100 Subject: [PATCH 1/2] moved init stuff to DI --- backend/app/api/routes/sse.py | 29 +- backend/app/core/dishka_lifespan.py | 33 +- backend/app/core/lifecycle.py | 62 -- backend/app/core/providers.py | 452 ++++---- backend/app/db/repositories/__init__.py | 12 + .../execution_queue_repository.py | 234 +++++ .../execution_state_repository.py | 65 ++ .../db/repositories/pod_state_repository.py | 180 ++++ .../db/repositories/resource_repository.py | 300 ++++++ backend/app/dlq/manager.py | 370 +++---- backend/app/events/core/__init__.py | 4 - backend/app/events/core/consumer.py | 277 +---- backend/app/events/core/producer.py | 120 +-- backend/app/events/event_store_consumer.py | 190 ---- backend/app/services/coordinator/__init__.py | 6 - .../app/services/coordinator/coordinator.py | 554 ++++------ .../app/services/coordinator/queue_manager.py | 271 ----- .../services/coordinator/resource_manager.py | 324 ------ backend/app/services/event_bus.py | 379 ++----- .../app/services/idempotency/middleware.py | 68 +- backend/app/services/k8s_worker/worker.py | 496 +++------ backend/app/services/kafka_event_service.py | 60 +- backend/app/services/notification_service.py | 976 +++++++----------- backend/app/services/pod_monitor/monitor.py | 441 ++------ .../services/result_processor/processor.py | 173 +--- backend/app/services/saga/__init__.py | 3 +- .../app/services/saga/saga_orchestrator.py | 435 +++----- .../app/services/sse/kafka_redis_bridge.py | 207 ++-- backend/app/services/sse/sse_service.py | 28 +- .../app/services/sse/sse_shutdown_manager.py | 40 - backend/app/services/user_settings_service.py | 40 +- backend/di_lifecycle_refactor_plan.md | 64 ++ .../tests/e2e/core/test_dishka_lifespan.py | 10 +- backend/tests/e2e/dlq/test_dlq_manager.py | 43 +- .../e2e/events/test_consume_roundtrip.py | 28 +- .../e2e/events/test_consumer_lifecycle.py | 44 +- .../tests/e2e/events/test_event_dispatcher.py | 33 +- .../e2e/events/test_producer_roundtrip.py | 44 +- .../idempotency/test_consumer_idempotent.py | 36 +- .../result_processor/test_result_processor.py | 134 --- .../coordinator/test_execution_coordinator.py | 148 +-- .../e2e/services/events/test_event_bus.py | 31 +- .../sse/test_partitioned_event_router.py | 81 -- .../tests/e2e/test_k8s_worker_create_pod.py | 41 +- .../coordinator/test_queue_manager.py | 41 - .../coordinator/test_resource_manager.py | 61 -- .../unit/services/pod_monitor/test_monitor.py | 784 +++----------- .../result_processor/test_processor.py | 30 +- .../saga/test_saga_orchestrator_unit.py | 37 +- .../services/sse/test_kafka_redis_bridge.py | 49 +- .../services/sse/test_shutdown_manager.py | 41 +- .../unit/services/sse/test_sse_service.py | 14 +- .../services/sse/test_sse_shutdown_manager.py | 32 +- backend/workers/dlq_processor.py | 4 +- backend/workers/run_coordinator.py | 44 +- backend/workers/run_event_replay.py | 46 +- backend/workers/run_k8s_worker.py | 43 +- backend/workers/run_pod_monitor.py | 45 +- backend/workers/run_result_processor.py | 86 +- backend/workers/run_saga_orchestrator.py | 37 +- 60 files changed, 3027 insertions(+), 5933 deletions(-) delete mode 100644 backend/app/core/lifecycle.py create mode 100644 backend/app/db/repositories/execution_queue_repository.py create mode 100644 backend/app/db/repositories/execution_state_repository.py create mode 100644 backend/app/db/repositories/pod_state_repository.py create mode 100644 
backend/app/db/repositories/resource_repository.py delete mode 100644 backend/app/events/event_store_consumer.py delete mode 100644 backend/app/services/coordinator/queue_manager.py delete mode 100644 backend/app/services/coordinator/resource_manager.py create mode 100644 backend/di_lifecycle_refactor_plan.md delete mode 100644 backend/tests/e2e/result_processor/test_result_processor.py delete mode 100644 backend/tests/e2e/services/sse/test_partitioned_event_router.py delete mode 100644 backend/tests/unit/services/coordinator/test_queue_manager.py delete mode 100644 backend/tests/unit/services/coordinator/test_resource_manager.py diff --git a/backend/app/api/routes/sse.py b/backend/app/api/routes/sse.py index 6b1b406f..ae8d1367 100644 --- a/backend/app/api/routes/sse.py +++ b/backend/app/api/routes/sse.py @@ -3,13 +3,7 @@ from fastapi import APIRouter, Request from sse_starlette.sse import EventSourceResponse -from app.domain.sse import SSEHealthDomain -from app.schemas_pydantic.sse import ( - ShutdownStatusResponse, - SSEExecutionEventData, - SSEHealthResponse, - SSENotificationEventData, -) +from app.schemas_pydantic.sse import SSEExecutionEventData, SSENotificationEventData from app.services.auth_service import AuthService from app.services.sse.sse_service import SSEService @@ -38,24 +32,3 @@ async def execution_events( return EventSourceResponse( sse_service.create_execution_stream(execution_id=execution_id, user_id=current_user.user_id) ) - - -@router.get("/health", response_model=SSEHealthResponse) -async def sse_health( - request: Request, - sse_service: FromDishka[SSEService], - auth_service: FromDishka[AuthService], -) -> SSEHealthResponse: - """Get SSE service health status.""" - _ = await auth_service.get_current_user(request) - domain: SSEHealthDomain = await sse_service.get_health_status() - return SSEHealthResponse( - status=domain.status, - kafka_enabled=domain.kafka_enabled, - active_connections=domain.active_connections, - active_executions=domain.active_executions, - active_consumers=domain.active_consumers, - max_connections_per_user=domain.max_connections_per_user, - shutdown=ShutdownStatusResponse(**vars(domain.shutdown)), - timestamp=domain.timestamp, - ) diff --git a/backend/app/core/dishka_lifespan.py b/backend/app/core/dishka_lifespan.py index 857222be..f0177a94 100644 --- a/backend/app/core/dishka_lifespan.py +++ b/backend/app/core/dishka_lifespan.py @@ -2,7 +2,7 @@ import asyncio import logging -from contextlib import AsyncExitStack, asynccontextmanager +from contextlib import asynccontextmanager from typing import AsyncGenerator import redis.asyncio as redis @@ -15,7 +15,6 @@ from app.core.startup import initialize_rate_limits from app.core.tracing import init_tracing from app.db.docs import ALL_DOCUMENTS -from app.events.event_store_consumer import EventStoreConsumer from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas from app.services.notification_service import NotificationService from app.services.sse.kafka_redis_bridge import SSEKafkaRedisBridge @@ -76,26 +75,17 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: extra={"testing": settings.TESTING, "enable_tracing": settings.ENABLE_TRACING}, ) - # Phase 1: Resolve all DI dependencies in parallel - ( - schema_registry, - database, - redis_client, - rate_limit_metrics, - sse_bridge, - event_store_consumer, - notification_service, - ) = await asyncio.gather( + # Resolve DI dependencies in parallel (fail fast on config issues) + schema_registry, database, 
redis_client, rate_limit_metrics, _, _ = await asyncio.gather( container.get(SchemaRegistryManager), container.get(Database), container.get(redis.Redis), container.get(RateLimitMetrics), container.get(SSEKafkaRedisBridge), - container.get(EventStoreConsumer), container.get(NotificationService), ) - # Phase 2: Initialize infrastructure in parallel (independent subsystems) + # Initialize infrastructure in parallel await asyncio.gather( initialize_event_schemas(schema_registry), init_beanie(database=database, document_models=ALL_DOCUMENTS), @@ -103,16 +93,5 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: ) logger.info("Infrastructure initialized (schemas, beanie, rate limits)") - # Phase 3: Start Kafka consumers in parallel (providers already started them via async with, - # but __aenter__ is idempotent so this is safe and explicit) - async with AsyncExitStack() as stack: - stack.push_async_callback(sse_bridge.aclose) - stack.push_async_callback(event_store_consumer.aclose) - stack.push_async_callback(notification_service.aclose) - await asyncio.gather( - sse_bridge.__aenter__(), - event_store_consumer.__aenter__(), - notification_service.__aenter__(), - ) - logger.info("SSE bridge, EventStoreConsumer, and NotificationService started") - yield + yield + # Container close handles all cleanup automatically diff --git a/backend/app/core/lifecycle.py b/backend/app/core/lifecycle.py deleted file mode 100644 index 2e0d8f85..00000000 --- a/backend/app/core/lifecycle.py +++ /dev/null @@ -1,62 +0,0 @@ -from __future__ import annotations - -from types import TracebackType -from typing import Self - - -class LifecycleEnabled: - """Base class for services with async lifecycle management. - - Usage: - async with MyService() as service: - # service is running - # service is stopped - - Subclasses override _on_start() and _on_stop() for their logic. - Base class handles idempotency and context manager protocol. - - For internal component cleanup, use aclose() which follows Python's - standard async cleanup pattern (like aiofiles, aiohttp). - """ - - def __init__(self) -> None: - self._lifecycle_started: bool = False - - async def _on_start(self) -> None: - """Override with startup logic. Called once on enter.""" - pass - - async def _on_stop(self) -> None: - """Override with cleanup logic. Called once on exit.""" - pass - - async def aclose(self) -> None: - """Close the service. For internal component cleanup. - - Mirrors Python's standard aclose() pattern (like aiofiles, aiohttp). - Idempotent - safe to call multiple times. 
- """ - if not self._lifecycle_started: - return - self._lifecycle_started = False - await self._on_stop() - - @property - def is_running(self) -> bool: - """Check if service is currently running.""" - return self._lifecycle_started - - async def __aenter__(self) -> Self: - if self._lifecycle_started: - return self # Already started, idempotent - await self._on_start() - self._lifecycle_started = True - return self - - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc: BaseException | None, - tb: TracebackType | None, - ) -> None: - await self.aclose() diff --git a/backend/app/core/providers.py b/backend/app/core/providers.py index 7dc457ec..d2465123 100644 --- a/backend/app/core/providers.py +++ b/backend/app/core/providers.py @@ -1,9 +1,8 @@ -from __future__ import annotations - import logging from typing import AsyncIterator import redis.asyncio as redis +from aiokafka import AIOKafkaProducer from dishka import Provider, Scope, from_context, provide from pymongo.asynchronous.mongo_client import AsyncMongoClient @@ -39,20 +38,22 @@ from app.db.repositories.admin.admin_settings_repository import AdminSettingsRepository from app.db.repositories.admin.admin_user_repository import AdminUserRepository from app.db.repositories.dlq_repository import DLQRepository +from app.db.repositories.execution_queue_repository import ExecutionQueueRepository +from app.db.repositories.execution_state_repository import ExecutionStateRepository +from app.db.repositories.pod_state_repository import PodStateRepository from app.db.repositories.replay_repository import ReplayRepository from app.db.repositories.resource_allocation_repository import ResourceAllocationRepository +from app.db.repositories.resource_repository import ResourceRepository from app.db.repositories.user_settings_repository import UserSettingsRepository -from app.dlq.manager import DLQManager, create_dlq_manager +from app.dlq.manager import DLQManager from app.domain.saga.models import SagaConfig -from app.events.core import UnifiedProducer +from app.events.core import ProducerMetrics, UnifiedProducer from app.events.event_store import EventStore, create_event_store -from app.events.event_store_consumer import EventStoreConsumer, create_event_store_consumer from app.events.schema.schema_registry import SchemaRegistryManager -from app.infrastructure.kafka.topics import get_all_topics from app.services.admin import AdminEventsService, AdminSettingsService, AdminUserService from app.services.auth_service import AuthService from app.services.coordinator.coordinator import ExecutionCoordinator -from app.services.event_bus import EventBusManager +from app.services.event_bus import EventBus from app.services.event_replay.replay_service import EventReplayService from app.services.event_service import EventService from app.services.execution_service import ExecutionService @@ -70,13 +71,13 @@ from app.services.rate_limit_service import RateLimitService from app.services.replay_service import ReplayService from app.services.result_processor.resource_cleaner import ResourceCleaner -from app.services.saga import SagaOrchestrator, create_saga_orchestrator +from app.services.saga import SagaOrchestrator from app.services.saga.saga_service import SagaService from app.services.saved_script_service import SavedScriptService -from app.services.sse.kafka_redis_bridge import SSEKafkaRedisBridge, create_sse_kafka_redis_bridge +from app.services.sse.kafka_redis_bridge import SSEKafkaRedisBridge from app.services.sse.redis_bus import 
SSERedisBus from app.services.sse.sse_service import SSEService -from app.services.sse.sse_shutdown_manager import SSEShutdownManager, create_sse_shutdown_manager +from app.services.sse.sse_shutdown_manager import SSEShutdownManager from app.services.user_settings_service import UserSettingsService from app.settings import Settings @@ -113,12 +114,12 @@ async def get_redis_client(self, settings: Settings, logger: logging.Logger) -> socket_timeout=5, ) # Test connection - await client.ping() # type: ignore[misc] # redis-py returns Awaitable[bool] | bool + await client.ping() # type: ignore[misc] # redis-py dual sync/async return type logger.info(f"Redis connected: {settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB}") try: yield client finally: - await client.close() + await client.aclose() @provide def get_rate_limit_service( @@ -127,6 +128,48 @@ def get_rate_limit_service( return RateLimitService(redis_client, settings, rate_limit_metrics) +class RedisRepositoryProvider(Provider): + """Provides Redis-backed state repositories for stateless services.""" + + scope = Scope.APP + + @provide + def get_execution_state_repository( + self, redis_client: redis.Redis, logger: logging.Logger + ) -> ExecutionStateRepository: + return ExecutionStateRepository(redis_client, logger) + + @provide + def get_execution_queue_repository( + self, redis_client: redis.Redis, logger: logging.Logger, settings: Settings + ) -> ExecutionQueueRepository: + return ExecutionQueueRepository( + redis_client, + logger, + max_queue_size=10000, + max_executions_per_user=100, + ) + + @provide + async def get_resource_repository( + self, redis_client: redis.Redis, logger: logging.Logger, settings: Settings + ) -> ResourceRepository: + repo = ResourceRepository( + redis_client, + logger, + total_cpu_cores=32.0, + total_memory_mb=65536, + ) + await repo.initialize() + return repo + + @provide + def get_pod_state_repository( + self, redis_client: redis.Redis, logger: logging.Logger + ) -> PodStateRepository: + return PodStateRepository(redis_client, logger) + + class DatabaseProvider(Provider): scope = Scope.APP @@ -155,27 +198,69 @@ def get_tracer_manager(self, settings: Settings) -> TracerManager: return TracerManager(tracer_name=settings.TRACING_SERVICE_NAME) -class MessagingProvider(Provider): +class KafkaProvider(Provider): + """Provides Kafka producer - low-level AIOKafkaProducer for DI.""" + scope = Scope.APP @provide - async def get_kafka_producer( - self, settings: Settings, schema_registry: SchemaRegistryManager, logger: logging.Logger, - event_metrics: EventMetrics - ) -> AsyncIterator[UnifiedProducer]: - async with UnifiedProducer(schema_registry, logger, settings, event_metrics) as producer: + async def get_aiokafka_producer( + self, settings: Settings, logger: logging.Logger + ) -> AsyncIterator[AIOKafkaProducer]: + producer = AIOKafkaProducer( + bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS, + acks="all", + enable_idempotence=True, + max_request_size=10 * 1024 * 1024, # 10MB + ) + await producer.start() + logger.info(f"Kafka producer started: {settings.KAFKA_BOOTSTRAP_SERVERS}") + try: yield producer + finally: + await producer.stop() + + +class MessagingProvider(Provider): + scope = Scope.APP @provide - async def get_dlq_manager( + def get_producer_metrics(self) -> ProducerMetrics: + return ProducerMetrics() + + @provide + def get_unified_producer( + self, + aiokafka_producer: AIOKafkaProducer, + schema_registry: SchemaRegistryManager, + logger: logging.Logger, + event_metrics: EventMetrics, + 
producer_metrics: ProducerMetrics, + ) -> UnifiedProducer: + return UnifiedProducer( + producer=aiokafka_producer, + schema_registry_manager=schema_registry, + logger=logger, + event_metrics=event_metrics, + producer_metrics=producer_metrics, + ) + + @provide + def get_dlq_manager( self, settings: Settings, + aiokafka_producer: AIOKafkaProducer, schema_registry: SchemaRegistryManager, logger: logging.Logger, dlq_metrics: DLQMetrics, - ) -> AsyncIterator[DLQManager]: - async with create_dlq_manager(settings, schema_registry, logger, dlq_metrics) as manager: - yield manager + ) -> DLQManager: + return DLQManager( + settings=settings, + producer=aiokafka_producer, + schema_registry=schema_registry, + logger=logger, + dlq_metrics=dlq_metrics, + ) @provide def get_idempotency_repository(self, redis_client: redis.Redis) -> RedisIdempotencyRepository: @@ -203,7 +288,7 @@ def get_schema_registry(self, settings: Settings, logger: logging.Logger) -> Sch return SchemaRegistryManager(settings, logger) @provide - async def get_event_store( + def get_event_store( self, schema_registry: SchemaRegistryManager, logger: logging.Logger, event_metrics: EventMetrics ) -> EventStore: return create_event_store( @@ -211,36 +296,19 @@ async def get_event_store( ) @provide - async def get_event_store_consumer( + def get_event_bus( self, - event_store: EventStore, - schema_registry: SchemaRegistryManager, + aiokafka_producer: AIOKafkaProducer, settings: Settings, - kafka_producer: UnifiedProducer, logger: logging.Logger, - event_metrics: EventMetrics, - ) -> AsyncIterator[EventStoreConsumer]: - topics = get_all_topics() - async with create_event_store_consumer( - event_store=event_store, - topics=list(topics), - schema_registry_manager=schema_registry, - settings=settings, - producer=kafka_producer, - logger=logger, - event_metrics=event_metrics, - ) as consumer: - yield consumer - - @provide - async def get_event_bus_manager( - self, settings: Settings, logger: logging.Logger, connection_metrics: ConnectionMetrics - ) -> AsyncIterator[EventBusManager]: - manager = EventBusManager(settings, logger, connection_metrics) - try: - yield manager - finally: - await manager.close() + connection_metrics: ConnectionMetrics, + ) -> EventBus: + return EventBus( + producer=aiokafka_producer, + settings=settings, + logger=logger, + connection_metrics=connection_metrics, + ) class KubernetesProvider(Provider): @@ -385,35 +453,25 @@ class SSEProvider(Provider): scope = Scope.APP @provide - async def get_sse_redis_bus( - self, redis_client: redis.Redis, logger: logging.Logger - ) -> AsyncIterator[SSERedisBus]: - bus = SSERedisBus(redis_client, logger) - yield bus + def get_sse_redis_bus(self, redis_client: redis.Redis, logger: logging.Logger) -> SSERedisBus: + return SSERedisBus(redis_client, logger) @provide - async def get_sse_kafka_redis_bridge( + def get_sse_kafka_redis_bridge( self, - schema_registry: SchemaRegistryManager, - settings: Settings, - event_metrics: EventMetrics, sse_redis_bus: SSERedisBus, logger: logging.Logger, - ) -> AsyncIterator[SSEKafkaRedisBridge]: - async with create_sse_kafka_redis_bridge( - schema_registry=schema_registry, - settings=settings, - event_metrics=event_metrics, - sse_bus=sse_redis_bus, - logger=logger, - ) as bridge: - yield bridge + ) -> SSEKafkaRedisBridge: + return SSEKafkaRedisBridge( + sse_bus=sse_redis_bus, + logger=logger, + ) @provide(scope=Scope.REQUEST) def get_sse_shutdown_manager( self, logger: logging.Logger, connection_metrics: ConnectionMetrics ) -> SSEShutdownManager: - 
return create_sse_shutdown_manager(logger=logger, connection_metrics=connection_metrics) + return SSEShutdownManager(logger=logger, connection_metrics=connection_metrics) @provide(scope=Scope.REQUEST) def get_sse_service( @@ -426,7 +484,6 @@ def get_sse_service( logger: logging.Logger, connection_metrics: ConnectionMetrics, ) -> SSEService: - shutdown_manager.set_router(router) return SSEService( repository=sse_repository, router=router, @@ -483,12 +540,11 @@ async def get_user_settings_service( self, repository: UserSettingsRepository, kafka_event_service: KafkaEventService, - event_bus_manager: EventBusManager, - settings: Settings, + event_bus: EventBus, logger: logging.Logger, ) -> UserSettingsService: - service = UserSettingsService(repository, kafka_event_service, settings, logger) - await service.initialize(event_bus_manager) + service = UserSettingsService(repository, kafka_event_service, event_bus, logger) + await service.setup_event_subscription() return service @@ -513,31 +569,23 @@ def get_admin_settings_service( return AdminSettingsService(admin_settings_repository, logger) @provide - async def get_notification_service( + def get_notification_service( self, notification_repository: NotificationRepository, - kafka_event_service: KafkaEventService, - event_bus_manager: EventBusManager, - schema_registry: SchemaRegistryManager, + event_bus: EventBus, sse_redis_bus: SSERedisBus, settings: Settings, logger: logging.Logger, notification_metrics: NotificationMetrics, - event_metrics: EventMetrics, - ) -> AsyncIterator[NotificationService]: - service = NotificationService( + ) -> NotificationService: + return NotificationService( notification_repository=notification_repository, - event_service=kafka_event_service, - event_bus_manager=event_bus_manager, - schema_registry_manager=schema_registry, + event_bus=event_bus, sse_bus=sse_redis_bus, settings=settings, logger=logger, notification_metrics=notification_metrics, - event_metrics=event_metrics, ) - async with service: - yield service @provide def get_grafana_alert_processor( @@ -566,68 +614,120 @@ def _create_default_saga_config() -> SagaConfig: ) -# Standalone factory functions for lifecycle-managed services (eliminates duplication) -async def _provide_saga_orchestrator( - saga_repository: SagaRepository, - kafka_producer: UnifiedProducer, - schema_registry: SchemaRegistryManager, - settings: Settings, - event_store: EventStore, - idempotency_manager: IdempotencyManager, - resource_allocation_repository: ResourceAllocationRepository, - logger: logging.Logger, - event_metrics: EventMetrics, -) -> AsyncIterator[SagaOrchestrator]: - """Shared factory for SagaOrchestrator with lifecycle management.""" - async with create_saga_orchestrator( - saga_repository=saga_repository, +class CoordinatorProvider(Provider): + scope = Scope.APP + + @provide + def get_execution_coordinator( + self, + kafka_producer: UnifiedProducer, + execution_repository: ExecutionRepository, + state_repo: ExecutionStateRepository, + queue_repo: ExecutionQueueRepository, + resource_repo: ResourceRepository, + logger: logging.Logger, + coordinator_metrics: CoordinatorMetrics, + event_metrics: EventMetrics, + ) -> ExecutionCoordinator: + return ExecutionCoordinator( producer=kafka_producer, - schema_registry_manager=schema_registry, - settings=settings, - event_store=event_store, - idempotency_manager=idempotency_manager, - resource_allocation_repository=resource_allocation_repository, - config=_create_default_saga_config(), + execution_repository=execution_repository, 
+ state_repo=state_repo, + queue_repo=queue_repo, + resource_repo=resource_repo, logger=logger, + coordinator_metrics=coordinator_metrics, event_metrics=event_metrics, - ) as orchestrator: - yield orchestrator - - -async def _provide_execution_coordinator( - kafka_producer: UnifiedProducer, - schema_registry: SchemaRegistryManager, - settings: Settings, - event_store: EventStore, - execution_repository: ExecutionRepository, - idempotency_manager: IdempotencyManager, - logger: logging.Logger, - coordinator_metrics: CoordinatorMetrics, - event_metrics: EventMetrics, -) -> AsyncIterator[ExecutionCoordinator]: - """Shared factory for ExecutionCoordinator with lifecycle management.""" - async with ExecutionCoordinator( + ) + + +class K8sWorkerProvider(Provider): + scope = Scope.APP + + @provide + def get_kubernetes_worker( + self, + kafka_producer: UnifiedProducer, + pod_state_repo: PodStateRepository, + k8s_clients: K8sClients, + logger: logging.Logger, + kubernetes_metrics: KubernetesMetrics, + execution_metrics: ExecutionMetrics, + event_metrics: EventMetrics, + ) -> KubernetesWorker: + config = K8sWorkerConfig() + return KubernetesWorker( + config=config, producer=kafka_producer, - schema_registry_manager=schema_registry, - settings=settings, - event_store=event_store, - execution_repository=execution_repository, - idempotency_manager=idempotency_manager, + pod_state_repo=pod_state_repo, + v1_client=k8s_clients.v1, + networking_v1_client=k8s_clients.networking_v1, + apps_v1_client=k8s_clients.apps_v1, + logger=logger, + kubernetes_metrics=kubernetes_metrics, + execution_metrics=execution_metrics, + event_metrics=event_metrics, + ) + + +class PodMonitorProvider(Provider): + scope = Scope.APP + + @provide + def get_event_mapper( + self, + logger: logging.Logger, + k8s_clients: K8sClients, + ) -> PodEventMapper: + return PodEventMapper(logger=logger, k8s_api=k8s_clients.v1) + + @provide + def get_pod_monitor( + self, + kafka_producer: UnifiedProducer, + pod_state_repo: PodStateRepository, + k8s_clients: K8sClients, + logger: logging.Logger, + event_mapper: PodEventMapper, + kubernetes_metrics: KubernetesMetrics, + ) -> PodMonitor: + config = PodMonitorConfig() + return PodMonitor( + config=config, + producer=kafka_producer, + pod_state_repo=pod_state_repo, + v1_client=k8s_clients.v1, + event_mapper=event_mapper, + logger=logger, + kubernetes_metrics=kubernetes_metrics, + ) + + +class SagaOrchestratorProvider(Provider): + scope = Scope.APP + + @provide + def get_saga_orchestrator( + self, + saga_repository: SagaRepository, + kafka_producer: UnifiedProducer, + resource_allocation_repository: ResourceAllocationRepository, + logger: logging.Logger, + event_metrics: EventMetrics, + ) -> SagaOrchestrator: + return SagaOrchestrator( + config=_create_default_saga_config(), + saga_repository=saga_repository, + producer=kafka_producer, + resource_allocation_repository=resource_allocation_repository, logger=logger, - coordinator_metrics=coordinator_metrics, event_metrics=event_metrics, - ) as coordinator: - yield coordinator + ) class BusinessServicesProvider(Provider): scope = Scope.REQUEST - def __init__(self) -> None: - super().__init__() - # Register shared factory functions on instance (avoids warning about missing self) - self.provide(_provide_execution_coordinator) - @provide def get_saga_service( self, @@ -697,82 +797,6 @@ def get_admin_user_service( ) -class CoordinatorProvider(Provider): - scope = Scope.APP - - def __init__(self) -> None: - super().__init__() - 
self.provide(_provide_execution_coordinator) - - -class K8sWorkerProvider(Provider): - scope = Scope.APP - - @provide - async def get_kubernetes_worker( - self, - kafka_producer: UnifiedProducer, - schema_registry: SchemaRegistryManager, - settings: Settings, - event_store: EventStore, - idempotency_manager: IdempotencyManager, - logger: logging.Logger, - event_metrics: EventMetrics, - ) -> AsyncIterator[KubernetesWorker]: - config = K8sWorkerConfig() - async with KubernetesWorker( - config=config, - producer=kafka_producer, - schema_registry_manager=schema_registry, - settings=settings, - event_store=event_store, - idempotency_manager=idempotency_manager, - logger=logger, - event_metrics=event_metrics, - ) as worker: - yield worker - - -class PodMonitorProvider(Provider): - scope = Scope.APP - - @provide - def get_event_mapper( - self, - logger: logging.Logger, - k8s_clients: K8sClients, - ) -> PodEventMapper: - return PodEventMapper(logger=logger, k8s_api=k8s_clients.v1) - - @provide - async def get_pod_monitor( - self, - kafka_event_service: KafkaEventService, - k8s_clients: K8sClients, - logger: logging.Logger, - event_mapper: PodEventMapper, - kubernetes_metrics: KubernetesMetrics, - ) -> AsyncIterator[PodMonitor]: - config = PodMonitorConfig() - async with PodMonitor( - config=config, - kafka_event_service=kafka_event_service, - logger=logger, - k8s_clients=k8s_clients, - event_mapper=event_mapper, - kubernetes_metrics=kubernetes_metrics, - ) as monitor: - yield monitor - - -class SagaOrchestratorProvider(Provider): - scope = Scope.APP - - def __init__(self) -> None: - super().__init__() - self.provide(_provide_saga_orchestrator) - - class EventReplayProvider(Provider): scope = Scope.APP diff --git a/backend/app/db/repositories/__init__.py b/backend/app/db/repositories/__init__.py index 1e985797..c5e0199c 100644 --- a/backend/app/db/repositories/__init__.py +++ b/backend/app/db/repositories/__init__.py @@ -1,9 +1,13 @@ from app.db.repositories.admin.admin_settings_repository import AdminSettingsRepository from app.db.repositories.admin.admin_user_repository import AdminUserRepository from app.db.repositories.event_repository import EventRepository +from app.db.repositories.execution_queue_repository import ExecutionQueueRepository, QueuePriority, QueueStats from app.db.repositories.execution_repository import ExecutionRepository +from app.db.repositories.execution_state_repository import ExecutionStateRepository from app.db.repositories.notification_repository import NotificationRepository +from app.db.repositories.pod_state_repository import PodStateRepository from app.db.repositories.replay_repository import ReplayRepository +from app.db.repositories.resource_repository import ResourceAllocation, ResourceRepository, ResourceStats from app.db.repositories.saga_repository import SagaRepository from app.db.repositories.saved_script_repository import SavedScriptRepository from app.db.repositories.sse_repository import SSERepository @@ -15,8 +19,16 @@ "AdminUserRepository", "EventRepository", "ExecutionRepository", + "ExecutionQueueRepository", + "ExecutionStateRepository", "NotificationRepository", + "PodStateRepository", + "QueuePriority", + "QueueStats", "ReplayRepository", + "ResourceAllocation", + "ResourceRepository", + "ResourceStats", "SagaRepository", "SavedScriptRepository", "SSERepository", diff --git a/backend/app/db/repositories/execution_queue_repository.py b/backend/app/db/repositories/execution_queue_repository.py new file mode 100644 index 00000000..d24af7bf --- 
/dev/null +++ b/backend/app/db/repositories/execution_queue_repository.py @@ -0,0 +1,234 @@ +"""Redis-backed execution queue repository. + +Replaces in-memory priority queue (QueueManager) with Redis sorted sets +for stateless, horizontally-scalable services. +""" + +from __future__ import annotations + +import json +import logging +import time +from dataclasses import dataclass +from enum import IntEnum + +import redis.asyncio as redis + + +class QueuePriority(IntEnum): + """Execution queue priorities. Lower value = higher priority.""" + + CRITICAL = 0 + HIGH = 1 + NORMAL = 5 + LOW = 8 + BACKGROUND = 10 + + +@dataclass +class QueueStats: + """Queue statistics.""" + + total_size: int + priority_distribution: dict[str, int] + max_queue_size: int + utilization_percent: float + + +class ExecutionQueueRepository: + """Redis-backed priority queue for executions. + + Uses Redis sorted sets for O(log N) priority queue operations. + Stores event data in hash maps for retrieval. + """ + + QUEUE_KEY = "exec:queue" + DATA_KEY_PREFIX = "exec:queue:data" + USER_COUNT_KEY = "exec:queue:user_count" + + def __init__( + self, + redis_client: redis.Redis, + logger: logging.Logger, + max_queue_size: int = 10000, + max_executions_per_user: int = 100, + stale_timeout_seconds: int = 3600, + ) -> None: + self._redis = redis_client + self._logger = logger + self.max_queue_size = max_queue_size + self.max_executions_per_user = max_executions_per_user + self.stale_timeout_seconds = stale_timeout_seconds + + async def enqueue( + self, + execution_id: str, + event_data: dict[str, object], + priority: QueuePriority, + user_id: str, + ) -> tuple[bool, int | None, str | None]: + """Add execution to queue. Returns (success, position, error).""" + # Check queue size + queue_size = await self._redis.zcard(self.QUEUE_KEY) + if queue_size >= self.max_queue_size: + return False, None, "Queue is full" + + # Check user limit + user_count = await self._redis.hincrby(self.USER_COUNT_KEY, user_id, 0) # type: ignore[misc] + if user_count >= self.max_executions_per_user: + return False, None, f"User execution limit exceeded ({self.max_executions_per_user})" + + # Score: priority * 1e12 + timestamp (lower = higher priority, earlier = higher priority) + timestamp = time.time() + score = priority.value * 1e12 + timestamp + + # Use pipeline for atomicity + pipe = self._redis.pipeline() + + # Add to sorted set + pipe.zadd(self.QUEUE_KEY, {execution_id: score}) + + # Store event data + data_key = f"{self.DATA_KEY_PREFIX}:{execution_id}" + event_data["_enqueue_timestamp"] = timestamp + event_data["_priority"] = priority.value + event_data["_user_id"] = user_id + pipe.hset(data_key, mapping={k: json.dumps(v) if not isinstance(v, str) else v for k, v in event_data.items()}) + pipe.expire(data_key, self.stale_timeout_seconds + 60) + + # Increment user count + pipe.hincrby(self.USER_COUNT_KEY, user_id, 1) + + await pipe.execute() + + # Get position + position = await self._redis.zrank(self.QUEUE_KEY, execution_id) + + self._logger.info( + f"Enqueued execution {execution_id}. Priority: {priority.name}, " + f"Position: {position}, Queue size: {queue_size + 1}" + ) + + return True, position, None + + async def dequeue(self) -> tuple[str, dict[str, object | float | str]] | None: + """Remove and return highest priority execution. 
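Entries with missing or expired data, and entries older than stale_timeout_seconds, are skipped and cleaned up.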
Returns (execution_id, event_data) or None.""" + while True: + # Pop the lowest score (highest priority) + result = await self._redis.zpopmin(self.QUEUE_KEY, count=1) + if not result: + return None + + execution_id = result[0][0] + if isinstance(execution_id, bytes): + execution_id = execution_id.decode() + + # Get event data + data_key = f"{self.DATA_KEY_PREFIX}:{execution_id}" + raw_data = await self._redis.hgetall(data_key) # type: ignore[misc] + + if not raw_data: + # Data expired or missing, skip this entry + self._logger.warning(f"Queue entry {execution_id} has no data, skipping") + continue + + # Parse data + event_data: dict[str, object | float | str] = {} + for k, v in raw_data.items(): + key = k.decode() if isinstance(k, bytes) else k + val = v.decode() if isinstance(v, bytes) else v + try: + event_data[key] = json.loads(val) + except (json.JSONDecodeError, TypeError): + event_data[key] = val + + # Check if stale + enqueue_time_val = event_data.pop("_enqueue_timestamp", 0) + enqueue_time = float(enqueue_time_val) if isinstance(enqueue_time_val, (int, float, str)) else 0.0 + event_data.pop("_priority", None) + user_id_val = event_data.pop("_user_id", "anonymous") + user_id = str(user_id_val) + + age = time.time() - enqueue_time + if age > self.stale_timeout_seconds: + # Stale, clean up and continue + await self._redis.delete(data_key) + await self._redis.hincrby(self.USER_COUNT_KEY, user_id, -1) # type: ignore[misc] + self._logger.info(f"Skipped stale execution {execution_id} (age: {age:.2f}s)") + continue + + # Clean up + await self._redis.delete(data_key) + await self._redis.hincrby(self.USER_COUNT_KEY, user_id, -1) # type: ignore[misc] + + self._logger.info(f"Dequeued execution {execution_id}. Wait time: {age:.2f}s") + return execution_id, event_data + + async def remove(self, execution_id: str) -> bool: + """Remove specific execution from queue. 
Returns True if removed.""" + # Get user_id before removing + data_key = f"{self.DATA_KEY_PREFIX}:{execution_id}" + raw_data = await self._redis.hgetall(data_key) # type: ignore[misc] + + removed = await self._redis.zrem(self.QUEUE_KEY, execution_id) + if removed: + # Decrement user count + if raw_data: + user_id_raw = raw_data.get(b"_user_id") or raw_data.get("_user_id") + if user_id_raw: + user_id = user_id_raw.decode() if isinstance(user_id_raw, bytes) else user_id_raw + try: + user_id = json.loads(user_id) + except (json.JSONDecodeError, TypeError): + pass + await self._redis.hincrby(self.USER_COUNT_KEY, str(user_id), -1) # type: ignore[misc] + + await self._redis.delete(data_key) + self._logger.info(f"Removed execution {execution_id} from queue") + return True + return False + + async def get_position(self, execution_id: str) -> int | None: + """Get queue position of execution (0-indexed).""" + result = await self._redis.zrank(self.QUEUE_KEY, execution_id) + return int(result) if result is not None else None + + async def get_stats(self) -> QueueStats: + """Get queue statistics.""" + total_size = await self._redis.zcard(self.QUEUE_KEY) + + # Count by priority (sample first 1000) + priority_counts: dict[str, int] = {} + entries = await self._redis.zrange(self.QUEUE_KEY, 0, 999, withscores=True) + for _, score in entries: + priority_value = int(score // 1e12) + try: + priority_name = QueuePriority(priority_value).name + except ValueError: + priority_name = "UNKNOWN" + priority_counts[priority_name] = priority_counts.get(priority_name, 0) + 1 + + return QueueStats( + total_size=total_size, + priority_distribution=priority_counts, + max_queue_size=self.max_queue_size, + utilization_percent=(total_size / self.max_queue_size) * 100 if self.max_queue_size > 0 else 0, + ) + + async def cleanup_stale(self) -> int: + """Remove stale entries. Returns count removed. Call periodically.""" + removed = 0 + cutoff = time.time() - self.stale_timeout_seconds + + # Scores are priority * 1e12 + enqueue timestamp, so scan each priority band separately; + # a single "-inf" threshold at the BACKGROUND band would also sweep up fresh entries + # from higher-priority (lower-score) bands. + for priority in QueuePriority: + band_base = priority.value * 1e12 + stale_entries = await self._redis.zrangebyscore(self.QUEUE_KEY, band_base, band_base + cutoff, start=0, num=100) + for entry in stale_entries: + execution_id = entry.decode() if isinstance(entry, bytes) else entry + if await self.remove(execution_id): + removed += 1 + + if removed: + self._logger.info(f"Cleaned {removed} stale executions from queue") + + return removed diff --git a/backend/app/db/repositories/execution_state_repository.py new file mode 100644 index 00000000..e343ff02 --- /dev/null +++ b/backend/app/db/repositories/execution_state_repository.py @@ -0,0 +1,65 @@ +"""Redis-backed execution state tracking repository. + +Replaces in-memory state tracking (_active_executions sets) with Redis +for stateless, horizontally-scalable services. +""" + +from __future__ import annotations + +import logging + +import redis.asyncio as redis + + +class ExecutionStateRepository: + """Redis-backed execution state tracking. + + Provides atomic claim/release operations for executions, + replacing in-memory sets like `_active_executions`. + """ + + KEY_PREFIX = "exec:active" + + def __init__(self, redis_client: redis.Redis, logger: logging.Logger) -> None: + self._redis = redis_client + self._logger = logger + + async def try_claim(self, execution_id: str, ttl_seconds: int = 3600) -> bool: + """Atomically claim an execution. Returns True if claimed, False if already claimed. 
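Call remove() to release the claim once processing finishes.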
+ + Uses Redis SETNX for atomic check-and-set. + TTL ensures cleanup if service crashes without releasing. + """ + key = f"{self.KEY_PREFIX}:{execution_id}" + result = await self._redis.set(key, "1", nx=True, ex=ttl_seconds) + if result: + self._logger.debug(f"Claimed execution {execution_id}") + return result is not None + + async def is_active(self, execution_id: str) -> bool: + """Check if an execution is currently active/claimed.""" + key = f"{self.KEY_PREFIX}:{execution_id}" + result = await self._redis.exists(key) + return bool(result) + + async def remove(self, execution_id: str) -> bool: + """Release/remove an execution claim. Returns True if was claimed.""" + key = f"{self.KEY_PREFIX}:{execution_id}" + deleted = await self._redis.delete(key) + if deleted: + self._logger.debug(f"Released execution {execution_id}") + return bool(deleted > 0) + + async def get_active_count(self) -> int: + """Get count of active executions. For metrics only.""" + pattern = f"{self.KEY_PREFIX}:*" + count = 0 + async for _ in self._redis.scan_iter(match=pattern, count=100): + count += 1 + return count + + async def extend_ttl(self, execution_id: str, ttl_seconds: int = 3600) -> bool: + """Extend the TTL of an active execution. Returns True if extended.""" + key = f"{self.KEY_PREFIX}:{execution_id}" + result = await self._redis.expire(key, ttl_seconds) + return bool(result) diff --git a/backend/app/db/repositories/pod_state_repository.py b/backend/app/db/repositories/pod_state_repository.py new file mode 100644 index 00000000..0e652720 --- /dev/null +++ b/backend/app/db/repositories/pod_state_repository.py @@ -0,0 +1,180 @@ +"""Redis-backed pod state tracking repository. + +Replaces in-memory pod state tracking (_tracked_pods, _active_creations) +for stateless, horizontally-scalable services. +""" + +from __future__ import annotations + +import json +import logging +from dataclasses import dataclass +from datetime import datetime, timezone + +import redis.asyncio as redis + + +@dataclass +class PodState: + """State of a tracked pod.""" + + pod_name: str + execution_id: str + status: str + created_at: datetime + updated_at: datetime + metadata: dict[str, object] | None = None + + +class PodStateRepository: + """Redis-backed pod state tracking. + + Provides atomic operations for pod creation tracking, + replacing in-memory sets like `_active_creations` and `_tracked_pods`. + """ + + CREATION_KEY_PREFIX = "pod:creating" + TRACKED_KEY_PREFIX = "pod:tracked" + RESOURCE_VERSION_KEY = "pod:resource_version" + + def __init__(self, redis_client: redis.Redis, logger: logging.Logger) -> None: + self._redis = redis_client + self._logger = logger + + # --- Active Creations (for KubernetesWorker) --- + + async def try_claim_creation(self, execution_id: str, ttl_seconds: int = 300) -> bool: + """Atomically claim a pod creation slot. 
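The claim is stored with SET NX and a TTL, so a crashed worker's claim expires on its own.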
Returns True if claimed.""" + key = f"{self.CREATION_KEY_PREFIX}:{execution_id}" + result = await self._redis.set(key, "1", nx=True, ex=ttl_seconds) + if result: + self._logger.debug(f"Claimed pod creation for {execution_id}") + return result is not None + + async def release_creation(self, execution_id: str) -> bool: + """Release a pod creation claim.""" + key = f"{self.CREATION_KEY_PREFIX}:{execution_id}" + deleted = await self._redis.delete(key) + if deleted: + self._logger.debug(f"Released pod creation for {execution_id}") + return bool(deleted) + + async def get_active_creations_count(self) -> int: + """Get count of active pod creations.""" + count = 0 + async for _ in self._redis.scan_iter(match=f"{self.CREATION_KEY_PREFIX}:*", count=100): + count += 1 + return count + + async def is_creation_active(self, execution_id: str) -> bool: + """Check if a pod creation is active.""" + key = f"{self.CREATION_KEY_PREFIX}:{execution_id}" + result = await self._redis.exists(key) + return bool(result) + + # --- Tracked Pods (for PodMonitor) --- + + async def track_pod( + self, + pod_name: str, + execution_id: str, + status: str, + metadata: dict[str, object] | None = None, + ttl_seconds: int = 7200, + ) -> None: + """Track a pod's state.""" + key = f"{self.TRACKED_KEY_PREFIX}:{pod_name}" + now = datetime.now(timezone.utc).isoformat() + + data = { + "pod_name": pod_name, + "execution_id": execution_id, + "status": status, + "created_at": now, + "updated_at": now, + "metadata": json.dumps(metadata) if metadata else "{}", + } + + await self._redis.hset(key, mapping=data) # type: ignore[misc] + await self._redis.expire(key, ttl_seconds) + self._logger.debug(f"Tracking pod {pod_name} for execution {execution_id}") + + async def update_pod_status(self, pod_name: str, status: str) -> bool: + """Update a tracked pod's status. Returns True if updated.""" + key = f"{self.TRACKED_KEY_PREFIX}:{pod_name}" + exists = await self._redis.exists(key) + if not exists: + return False + + now = datetime.now(timezone.utc).isoformat() + await self._redis.hset(key, mapping={"status": status, "updated_at": now}) # type: ignore[misc] + return True + + async def untrack_pod(self, pod_name: str) -> bool: + """Remove a pod from tracking. 
Returns True if removed.""" + key = f"{self.TRACKED_KEY_PREFIX}:{pod_name}" + deleted = await self._redis.delete(key) + if deleted: + self._logger.debug(f"Untracked pod {pod_name}") + return bool(deleted) + + async def get_pod_state(self, pod_name: str) -> PodState | None: + """Get state of a tracked pod.""" + key = f"{self.TRACKED_KEY_PREFIX}:{pod_name}" + data: dict[bytes | str, bytes | str] = await self._redis.hgetall(key) # type: ignore[misc] + if not data: + return None + + def get_str(k: str) -> str: + val = data.get(k.encode(), data.get(k, "")) + return val.decode() if isinstance(val, bytes) else str(val) + + metadata_str = get_str("metadata") + try: + metadata = json.loads(metadata_str) if metadata_str else None + except json.JSONDecodeError: + metadata = None + + return PodState( + pod_name=get_str("pod_name"), + execution_id=get_str("execution_id"), + status=get_str("status"), + created_at=datetime.fromisoformat(get_str("created_at")), + updated_at=datetime.fromisoformat(get_str("updated_at")), + metadata=metadata, + ) + + async def is_pod_tracked(self, pod_name: str) -> bool: + """Check if a pod is being tracked.""" + key = f"{self.TRACKED_KEY_PREFIX}:{pod_name}" + result = await self._redis.exists(key) + return bool(result) + + async def get_tracked_pods_count(self) -> int: + """Get count of tracked pods.""" + count = 0 + async for _ in self._redis.scan_iter(match=f"{self.TRACKED_KEY_PREFIX}:*", count=100): + count += 1 + return count + + async def get_tracked_pod_names(self) -> set[str]: + """Get set of all tracked pod names.""" + names: set[str] = set() + prefix_len = len(self.TRACKED_KEY_PREFIX) + 1 + async for key in self._redis.scan_iter(match=f"{self.TRACKED_KEY_PREFIX}:*", count=100): + key_str = key.decode() if isinstance(key, bytes) else key + names.add(key_str[prefix_len:]) + return names + + # --- Resource Version (for PodMonitor watch) --- + + async def get_resource_version(self) -> str | None: + """Get the last known resource version for watch resumption.""" + result = await self._redis.get(self.RESOURCE_VERSION_KEY) + if result: + return result.decode() if isinstance(result, bytes) else result + return None + + async def set_resource_version(self, version: str) -> None: + """Store the resource version for watch resumption.""" + await self._redis.set(self.RESOURCE_VERSION_KEY, version) diff --git a/backend/app/db/repositories/resource_repository.py b/backend/app/db/repositories/resource_repository.py new file mode 100644 index 00000000..1f6b54b0 --- /dev/null +++ b/backend/app/db/repositories/resource_repository.py @@ -0,0 +1,300 @@ +"""Redis-backed resource allocation repository. + +Replaces in-memory resource tracking (ResourceManager) with Redis +for stateless, horizontally-scalable services. 
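+ +Allocation is reserved atomically with a Lua script, so concurrent workers cannot +oversubscribe the shared CPU, memory, and GPU pool.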
+""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass + +import redis.asyncio as redis + + +@dataclass +class ResourceAllocation: + """Resource allocation for an execution.""" + + execution_id: str + cpu_cores: float + memory_mb: int + gpu_count: int = 0 + + @property + def cpu_millicores(self) -> int: + """Get CPU in millicores for Kubernetes.""" + return int(self.cpu_cores * 1000) + + @property + def memory_bytes(self) -> int: + """Get memory in bytes.""" + return self.memory_mb * 1024 * 1024 + + +@dataclass +class ResourceStats: + """Resource statistics.""" + + total_cpu: float + total_memory_mb: int + total_gpu: int + available_cpu: float + available_memory_mb: int + available_gpu: int + allocation_count: int + + +class ResourceRepository: + """Redis-backed resource allocation tracking. + + Uses Redis for atomic resource allocation with Lua scripts. + Replaces in-memory ResourceManager._allocations dict. + """ + + POOL_KEY = "resource:pool" + ALLOC_KEY_PREFIX = "resource:alloc" + + # Default allocations by language + DEFAULT_ALLOCATIONS = { + "python": (0.5, 512), + "javascript": (0.5, 512), + "go": (0.25, 256), + "rust": (0.5, 512), + "java": (1.0, 1024), + "cpp": (0.5, 512), + "r": (1.0, 2048), + } + + def __init__( + self, + redis_client: redis.Redis, + logger: logging.Logger, + total_cpu_cores: float = 32.0, + total_memory_mb: int = 65536, + total_gpu_count: int = 0, + overcommit_factor: float = 1.2, + max_cpu_per_execution: float = 4.0, + max_memory_per_execution_mb: int = 8192, + min_reserve_cpu: float = 2.0, + min_reserve_memory_mb: int = 4096, + ) -> None: + self._redis = redis_client + self._logger = logger + + # Apply overcommit + self._total_cpu = total_cpu_cores * overcommit_factor + self._total_memory = int(total_memory_mb * overcommit_factor) + self._total_gpu = total_gpu_count + + self._max_cpu_per_exec = max_cpu_per_execution + self._max_memory_per_exec = max_memory_per_execution_mb + + # Adjust reserves for small pools (max 10% of total) + self._min_reserve_cpu = min(min_reserve_cpu, 0.1 * self._total_cpu) + self._min_reserve_memory = min(min_reserve_memory_mb, int(0.1 * self._total_memory)) + + async def initialize(self) -> None: + """Initialize the resource pool if not exists.""" + exists = await self._redis.exists(self.POOL_KEY) + if not exists: + await self._redis.hset( # type: ignore[misc] + self.POOL_KEY, + mapping={ + "total_cpu": str(self._total_cpu), + "total_memory": str(self._total_memory), + "total_gpu": str(self._total_gpu), + "available_cpu": str(self._total_cpu), + "available_memory": str(self._total_memory), + "available_gpu": str(self._total_gpu), + }, + ) + self._logger.info( + f"Initialized resource pool: {self._total_cpu} CPU, " + f"{self._total_memory}MB RAM, {self._total_gpu} GPU" + ) + + async def allocate( + self, + execution_id: str, + language: str, + requested_cpu: float | None = None, + requested_memory_mb: int | None = None, + requested_gpu: int = 0, + ) -> ResourceAllocation | None: + """Allocate resources for execution. 
Returns allocation or None if insufficient.""" + # Check if already allocated + alloc_key = f"{self.ALLOC_KEY_PREFIX}:{execution_id}" + existing = await self._redis.hgetall(alloc_key) # type: ignore[misc] + if existing: + self._logger.warning(f"Execution {execution_id} already has allocation") + return ResourceAllocation( + execution_id=execution_id, + cpu_cores=float(existing.get(b"cpu", existing.get("cpu", 0))), + memory_mb=int(existing.get(b"memory", existing.get("memory", 0))), + gpu_count=int(existing.get(b"gpu", existing.get("gpu", 0))), + ) + + # Determine requested resources + if requested_cpu is None or requested_memory_mb is None: + default_cpu, default_memory = self.DEFAULT_ALLOCATIONS.get(language, (0.5, 512)) + requested_cpu = requested_cpu or default_cpu + requested_memory_mb = requested_memory_mb or default_memory + + # Apply limits + requested_cpu = min(requested_cpu, self._max_cpu_per_exec) + requested_memory_mb = min(requested_memory_mb, self._max_memory_per_exec) + + # Atomic allocation using Lua script + lua_script = """ + local pool_key = KEYS[1] + local alloc_key = KEYS[2] + local req_cpu = tonumber(ARGV[1]) + local req_memory = tonumber(ARGV[2]) + local req_gpu = tonumber(ARGV[3]) + local min_cpu = tonumber(ARGV[4]) + local min_memory = tonumber(ARGV[5]) + + local avail_cpu = tonumber(redis.call('HGET', pool_key, 'available_cpu') or '0') + local avail_memory = tonumber(redis.call('HGET', pool_key, 'available_memory') or '0') + local avail_gpu = tonumber(redis.call('HGET', pool_key, 'available_gpu') or '0') + + local cpu_after = avail_cpu - req_cpu + local memory_after = avail_memory - req_memory + local gpu_after = avail_gpu - req_gpu + + if cpu_after < min_cpu or memory_after < min_memory or gpu_after < 0 then + return 0 + end + + redis.call('HSET', pool_key, 'available_cpu', tostring(cpu_after)) + redis.call('HSET', pool_key, 'available_memory', tostring(memory_after)) + redis.call('HSET', pool_key, 'available_gpu', tostring(gpu_after)) + + redis.call('HSET', alloc_key, 'cpu', tostring(req_cpu), 'memory', tostring(req_memory), + 'gpu', tostring(req_gpu)) + redis.call('EXPIRE', alloc_key, 7200) + + return 1 + """ + + result = await self._redis.eval( # type: ignore[misc] + lua_script, + 2, + self.POOL_KEY, + alloc_key, + str(requested_cpu), + str(requested_memory_mb), + str(requested_gpu), + str(self._min_reserve_cpu), + str(self._min_reserve_memory), + ) + + if not result: + pool = await self._redis.hgetall(self.POOL_KEY) # type: ignore[misc] + avail_cpu = float(pool.get(b"available_cpu", pool.get("available_cpu", 0))) + avail_memory = int(float(pool.get(b"available_memory", pool.get("available_memory", 0)))) + self._logger.warning( + f"Insufficient resources for {execution_id}. " + f"Requested: {requested_cpu} CPU, {requested_memory_mb}MB. " + f"Available: {avail_cpu} CPU, {avail_memory}MB" + ) + return None + + self._logger.info( + f"Allocated resources for {execution_id}: " + f"{requested_cpu} CPU, {requested_memory_mb}MB RAM, {requested_gpu} GPU" + ) + + return ResourceAllocation( + execution_id=execution_id, + cpu_cores=requested_cpu, + memory_mb=requested_memory_mb, + gpu_count=requested_gpu, + ) + + async def release(self, execution_id: str) -> bool: + """Release resource allocation. 
Returns True if released.""" + alloc_key = f"{self.ALLOC_KEY_PREFIX}:{execution_id}" + + # Get current allocation + alloc = await self._redis.hgetall(alloc_key) # type: ignore[misc] + if not alloc: + self._logger.warning(f"No allocation found for {execution_id}") + return False + + cpu = float(alloc.get(b"cpu", alloc.get("cpu", 0))) + memory = int(float(alloc.get(b"memory", alloc.get("memory", 0)))) + gpu = int(alloc.get(b"gpu", alloc.get("gpu", 0))) + + # Release atomically + pipe = self._redis.pipeline() + pipe.hincrbyfloat(self.POOL_KEY, "available_cpu", cpu) + pipe.hincrbyfloat(self.POOL_KEY, "available_memory", memory) + pipe.hincrby(self.POOL_KEY, "available_gpu", gpu) + pipe.delete(alloc_key) + await pipe.execute() + + self._logger.info(f"Released resources for {execution_id}: {cpu} CPU, {memory}MB RAM, {gpu} GPU") + return True + + async def get_allocation(self, execution_id: str) -> ResourceAllocation | None: + """Get current allocation for execution.""" + alloc_key = f"{self.ALLOC_KEY_PREFIX}:{execution_id}" + alloc = await self._redis.hgetall(alloc_key) # type: ignore[misc] + if not alloc: + return None + + return ResourceAllocation( + execution_id=execution_id, + cpu_cores=float(alloc.get(b"cpu", alloc.get("cpu", 0))), + memory_mb=int(float(alloc.get(b"memory", alloc.get("memory", 0)))), + gpu_count=int(alloc.get(b"gpu", alloc.get("gpu", 0))), + ) + + async def get_stats(self) -> ResourceStats: + """Get resource statistics.""" + pool = await self._redis.hgetall(self.POOL_KEY) # type: ignore[misc] + + # Decode bytes if needed (str() on a bytes value would yield "b'...'" and break float()) + def get_val(key: str, default: str = "0") -> str: + val = pool.get(key.encode(), pool.get(key, default)) + return val.decode() if isinstance(val, bytes) else str(val) + + total_cpu = float(get_val("total_cpu")) + total_memory = int(float(get_val("total_memory"))) + total_gpu = int(get_val("total_gpu")) + available_cpu = float(get_val("available_cpu")) + available_memory = int(float(get_val("available_memory"))) + available_gpu = int(get_val("available_gpu")) + + # Count allocations + count = 0 + async for _ in self._redis.scan_iter(match=f"{self.ALLOC_KEY_PREFIX}:*", count=100): + count += 1 + + return ResourceStats( + total_cpu=total_cpu, + total_memory_mb=total_memory, + total_gpu=total_gpu, + available_cpu=available_cpu, + available_memory_mb=available_memory, + available_gpu=available_gpu, + allocation_count=count, + ) + + async def can_allocate(self, cpu_cores: float, memory_mb: int, gpu_count: int = 0) -> bool: + """Check if resources can be allocated.""" + pool = await self._redis.hgetall(self.POOL_KEY) # type: ignore[misc] + + def get_val(key: str) -> float: + return float(pool.get(key.encode(), pool.get(key, 0))) + + available_cpu = get_val("available_cpu") + available_memory = get_val("available_memory") + available_gpu = get_val("available_gpu") + + return ( + (available_cpu - cpu_cores) >= self._min_reserve_cpu + and (available_memory - memory_mb) >= self._min_reserve_memory + and (available_gpu - gpu_count) >= 0 + ) diff --git a/backend/app/dlq/manager.py index 1e20dc23..c1f5472b 100644 --- a/backend/app/dlq/manager.py +++ b/backend/app/dlq/manager.py @@ -1,13 +1,19 @@ +"""DLQ Manager - stateless event handler. + +Manages Dead Letter Queue messages. Receives events, +processes them, and handles retries. No lifecycle management. 
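+ +A worker entrypoint owns the Kafka consume loop and calls handle_dlq_message(raw_message, headers) +for each record read from the DLQ topic.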
+""" + +from __future__ import annotations + import asyncio import json import logging from datetime import datetime, timezone -from typing import Any, Callable -from aiokafka import AIOKafkaConsumer, AIOKafkaProducer +from aiokafka import AIOKafkaProducer from opentelemetry.trace import SpanKind -from app.core.lifecycle import LifecycleEnabled from app.core.metrics import DLQMetrics from app.core.tracing import EventAttributes from app.core.tracing.utils import extract_trace_context, get_tracer, inject_trace_context @@ -21,7 +27,7 @@ RetryPolicy, RetryStrategy, ) -from app.domain.enums.kafka import GroupId, KafkaTopic +from app.domain.enums.kafka import KafkaTopic from app.domain.events.typed import ( DLQMessageDiscardedEvent, DLQMessageReceivedEvent, @@ -32,149 +38,118 @@ from app.settings import Settings -class DLQManager(LifecycleEnabled): +class DLQManager: + """Stateless DLQ manager - pure event handler. + + No lifecycle methods (start/stop) - receives ready-to-use dependencies from DI. + Worker entrypoint handles the consume loop. + """ + def __init__( self, settings: Settings, - consumer: AIOKafkaConsumer, producer: AIOKafkaProducer, schema_registry: SchemaRegistryManager, logger: logging.Logger, dlq_metrics: DLQMetrics, dlq_topic: KafkaTopic = KafkaTopic.DEAD_LETTER_QUEUE, retry_topic_suffix: str = "-retry", - default_retry_policy: RetryPolicy | None = None, - ): - super().__init__() - self.settings = settings - self.metrics = dlq_metrics - self.schema_registry = schema_registry - self.logger = logger - self.dlq_topic = dlq_topic - self.retry_topic_suffix = retry_topic_suffix - self.default_retry_policy = default_retry_policy or RetryPolicy( + ) -> None: + self._settings = settings + self._producer = producer + self._schema_registry = schema_registry + self._logger = logger + self._metrics = dlq_metrics + self._dlq_topic = dlq_topic + self._retry_topic_suffix = retry_topic_suffix + self._default_retry_policy = RetryPolicy( topic="default", strategy=RetryStrategy.EXPONENTIAL_BACKOFF ) - self.consumer: AIOKafkaConsumer = consumer - self.producer: AIOKafkaProducer = producer + self._retry_policies: dict[str, RetryPolicy] = {} + self._filters: list[object] = [] + self._dlq_events_topic = f"{settings.KAFKA_TOPIC_PREFIX}{KafkaTopic.DLQ_EVENTS}" + self._event_metadata = EventMetadata(service_name="dlq-manager", service_version="1.0.0") - self._process_task: asyncio.Task[None] | None = None - self._monitor_task: asyncio.Task[None] | None = None + def set_retry_policy(self, topic: str, policy: RetryPolicy) -> None: + """Set retry policy for a specific topic.""" + self._retry_policies[topic] = policy - # Topic-specific retry policies - self._retry_policies: dict[str, RetryPolicy] = {} + def set_default_retry_policy(self, policy: RetryPolicy) -> None: + """Set the default retry policy.""" + self._default_retry_policy = policy - # Message filters - self._filters: list[Callable[[DLQMessage], bool]] = [] + def add_filter(self, filter_func: object) -> None: + """Add a message filter.""" + self._filters.append(filter_func) - self._dlq_events_topic = f"{settings.KAFKA_TOPIC_PREFIX}{KafkaTopic.DLQ_EVENTS}" - self._event_metadata = EventMetadata(service_name="dlq-manager", service_version="1.0.0") + async def handle_dlq_message(self, raw_message: bytes, headers: dict[str, str]) -> None: + """Handle a DLQ message from Kafka. 
- def _kafka_msg_to_message(self, msg: Any) -> DLQMessage: - """Parse Kafka ConsumerRecord into DLQMessage.""" - data = json.loads(msg.value) - headers = {k: v.decode() for k, v in (msg.headers or [])} - return DLQMessage(**data, dlq_offset=msg.offset, dlq_partition=msg.partition, headers=headers) - - async def _on_start(self) -> None: - """Start DLQ manager.""" - # Start producer and consumer in parallel for faster startup - await asyncio.gather(self.producer.start(), self.consumer.start()) - - # Start processing tasks - self._process_task = asyncio.create_task(self._process_messages()) - self._monitor_task = asyncio.create_task(self._monitor_dlq()) - - self.logger.info("DLQ Manager started") - - async def _on_stop(self) -> None: - """Stop DLQ manager.""" - # Cancel tasks - for task in [self._process_task, self._monitor_task]: - if task: - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - - # Stop Kafka clients - await self.consumer.stop() - await self.producer.stop() - - self.logger.info("DLQ Manager stopped") - - async def _process_messages(self) -> None: - """Process DLQ messages using async iteration.""" - async for msg in self.consumer: - try: - start = asyncio.get_running_loop().time() - dlq_msg = self._kafka_msg_to_message(msg) - - # Record metrics - self.metrics.record_dlq_message_received(dlq_msg.original_topic, dlq_msg.event.event_type) - self.metrics.record_dlq_message_age((datetime.now(timezone.utc) - dlq_msg.failed_at).total_seconds()) - - # Process with tracing - ctx = extract_trace_context(dlq_msg.headers) - with get_tracer().start_as_current_span( - name="dlq.consume", - context=ctx, - kind=SpanKind.CONSUMER, - attributes={ - EventAttributes.KAFKA_TOPIC: self.dlq_topic, - EventAttributes.EVENT_TYPE: dlq_msg.event.event_type, - EventAttributes.EVENT_ID: dlq_msg.event.event_id, - }, - ): - await self._process_dlq_message(dlq_msg) - - # Commit and record duration - await self.consumer.commit() - self.metrics.record_dlq_processing_duration(asyncio.get_running_loop().time() - start, "process") + Called by worker entrypoint for each message from consume loop. 
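+
+        Expected call pattern, as a sketch (the AIOKafkaConsumer and the
+        commit policy belong to the worker entrypoint, not to this class):
+
+            async for msg in consumer:
+                headers = {k: v.decode() for k, v in (msg.headers or [])}
+                await manager.handle_dlq_message(msg.value, headers)
+                await consumer.commit()
+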
+ """ + start = asyncio.get_running_loop().time() - except Exception as e: - self.logger.error(f"Error processing DLQ message: {e}") + try: + data = json.loads(raw_message) + dlq_msg = DLQMessage(**data, headers=headers) + + self._metrics.record_dlq_message_received(dlq_msg.original_topic, dlq_msg.event.event_type) + self._metrics.record_dlq_message_age( + (datetime.now(timezone.utc) - dlq_msg.failed_at).total_seconds() + ) + + ctx = extract_trace_context(dlq_msg.headers) + with get_tracer().start_as_current_span( + name="dlq.consume", + context=ctx, + kind=SpanKind.CONSUMER, + attributes={ + EventAttributes.KAFKA_TOPIC: str(self._dlq_topic), + EventAttributes.EVENT_TYPE: dlq_msg.event.event_type, + EventAttributes.EVENT_ID: dlq_msg.event.event_id, + }, + ): + await self._process_dlq_message(dlq_msg) + + self._metrics.record_dlq_processing_duration( + asyncio.get_running_loop().time() - start, "process" + ) + + except Exception as e: + self._logger.error(f"Error processing DLQ message: {e}") async def _process_dlq_message(self, message: DLQMessage) -> None: - # Apply filters + """Process a DLQ message.""" for filter_func in self._filters: - if not filter_func(message): - self.logger.info("Message filtered out", extra={"event_id": message.event.event_id}) + if not filter_func(message): # type: ignore[operator] + self._logger.info("Message filtered out", extra={"event_id": message.event.event_id}) return - # Store in MongoDB via Beanie await self._store_message(message) - # Get retry policy for topic - retry_policy = self._retry_policies.get(message.original_topic, self.default_retry_policy) + retry_policy = self._retry_policies.get(message.original_topic, self._default_retry_policy) - # Check if should retry if not retry_policy.should_retry(message): await self._discard_message(message, "max_retries_exceeded") return - # Calculate next retry time next_retry = retry_policy.get_next_retry_time(message) - # Update message status await self._update_message_status( message.event.event_id, DLQMessageUpdate(status=DLQMessageStatus.SCHEDULED, next_retry_at=next_retry), ) - # If immediate retry, process now if retry_policy.strategy == RetryStrategy.IMMEDIATE: await self._retry_message(message) async def _store_message(self, message: DLQMessage) -> None: - # Ensure message has proper status and timestamps + """Store DLQ message in MongoDB.""" message.status = DLQMessageStatus.PENDING message.last_updated = datetime.now(timezone.utc) doc = DLQMessageDocument(**message.model_dump()) - # Upsert using Beanie existing = await DLQMessageDocument.find_one({"event.event_id": message.event.event_id}) if existing: doc.id = existing.id @@ -183,11 +158,12 @@ async def _store_message(self, message: DLQMessage) -> None: await self._emit_message_received_event(message) async def _update_message_status(self, event_id: str, update: DLQMessageUpdate) -> None: + """Update DLQ message status.""" doc = await DLQMessageDocument.find_one({"event.event_id": event_id}) if not doc: return - update_dict: dict[str, Any] = {"status": update.status, "last_updated": datetime.now(timezone.utc)} + update_dict: dict[str, object] = {"status": update.status, "last_updated": datetime.now(timezone.utc)} if update.next_retry_at is not None: update_dict["next_retry_at"] = update.next_retry_at if update.retried_at is not None: @@ -204,8 +180,8 @@ async def _update_message_status(self, event_id: str, update: DLQMessageUpdate) await doc.set(update_dict) async def _retry_message(self, message: DLQMessage) -> None: - # Send to retry topic first 
(for monitoring) - retry_topic = f"{message.original_topic}{self.retry_topic_suffix}" + """Retry a DLQ message.""" + retry_topic = f"{message.original_topic}{self._retry_topic_suffix}" hdrs: dict[str, str] = { "dlq_retry_count": str(message.retry_count + 1), @@ -215,31 +191,26 @@ async def _retry_message(self, message: DLQMessage) -> None: hdrs = inject_trace_context(hdrs) kafka_headers: list[tuple[str, bytes]] = [(k, v.encode()) for k, v in hdrs.items()] - # Get the original event event = message.event - # Send to retry topic - await self.producer.send_and_wait( + await self._producer.send_and_wait( topic=retry_topic, value=json.dumps(event.model_dump(mode="json")).encode(), key=message.event.event_id.encode(), headers=kafka_headers, ) - # Send to original topic - await self.producer.send_and_wait( + await self._producer.send_and_wait( topic=message.original_topic, value=json.dumps(event.model_dump(mode="json")).encode(), key=message.event.event_id.encode(), headers=kafka_headers, ) - # Update metrics - self.metrics.record_dlq_message_retried(message.original_topic, message.event.event_type, "success") + self._metrics.record_dlq_message_retried(message.original_topic, message.event.event_type, "success") new_retry_count = message.retry_count + 1 - # Update status await self._update_message_status( message.event.event_id, DLQMessageUpdate( @@ -249,16 +220,14 @@ async def _retry_message(self, message: DLQMessage) -> None: ), ) - # Emit DLQ message retried event await self._emit_message_retried_event(message, retry_topic, new_retry_count) - self.logger.info("Successfully retried message", extra={"event_id": message.event.event_id}) + self._logger.info("Successfully retried message", extra={"event_id": message.event.event_id}) async def _discard_message(self, message: DLQMessage, reason: str) -> None: - # Update metrics - self.metrics.record_dlq_message_discarded(message.original_topic, message.event.event_type, reason) + """Discard a DLQ message.""" + self._metrics.record_dlq_message_discarded(message.original_topic, message.event.event_type, reason) - # Update status await self._update_message_status( message.event.event_id, DLQMessageUpdate( @@ -270,57 +239,49 @@ async def _discard_message(self, message: DLQMessage, reason: str) -> None: await self._emit_message_discarded_event(message, reason) - self.logger.warning("Discarded message", extra={"event_id": message.event.event_id, "reason": reason}) + self._logger.warning("Discarded message", extra={"event_id": message.event.event_id, "reason": reason}) - async def _monitor_dlq(self) -> None: - while self.is_running: - try: - # Find messages ready for retry using Beanie - now = datetime.now(timezone.utc) - - docs = ( - await DLQMessageDocument.find( - { - "status": DLQMessageStatus.SCHEDULED, - "next_retry_at": {"$lte": now}, - } - ) - .limit(100) - .to_list() - ) - - for doc in docs: - message = DLQMessage.model_validate(doc, from_attributes=True) - await self._retry_message(message) - - # Update queue size metrics - await self._update_queue_metrics() - - # Sleep before next check - await asyncio.sleep(10) + async def check_scheduled_retries(self, batch_size: int = 100) -> int: + """Check for scheduled messages ready for retry. - except Exception as e: - self.logger.error(f"Error in DLQ monitor: {e}") - await asyncio.sleep(60) + Should be called periodically from worker entrypoint. + Returns number of messages retried. 
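+
+        Polling sketch (the 10s interval mirrors the removed in-process
+        monitor loop; task supervision is assumed to live in the worker):
+
+            while True:
+                await manager.check_scheduled_retries()
+                await asyncio.sleep(10)
+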
+ """ + now = datetime.now(timezone.utc) + + docs = ( + await DLQMessageDocument.find( + { + "status": DLQMessageStatus.SCHEDULED, + "next_retry_at": {"$lte": now}, + } + ) + .limit(batch_size) + .to_list() + ) + + count = 0 + for doc in docs: + message = DLQMessage.model_validate(doc, from_attributes=True) + await self._retry_message(message) + count += 1 + + await self._update_queue_metrics() + + return count async def _update_queue_metrics(self) -> None: - # Get counts by topic using Beanie aggregation - pipeline: list[dict[str, Any]] = [ + """Update queue size metrics.""" + pipeline: list[dict[str, object]] = [ {"$match": {"status": {"$in": [DLQMessageStatus.PENDING, DLQMessageStatus.SCHEDULED]}}}, {"$group": {"_id": "$original_topic", "count": {"$sum": 1}}}, ] async for result in DLQMessageDocument.aggregate(pipeline): - self.metrics.update_dlq_queue_size(result["_id"], result["count"]) - - def set_retry_policy(self, topic: str, policy: RetryPolicy) -> None: - self._retry_policies[topic] = policy - - def add_filter(self, filter_func: Callable[[DLQMessage], bool]) -> None: - self._filters.append(filter_func) + self._metrics.update_dlq_queue_size(result["_id"], result["count"]) async def _emit_message_received_event(self, message: DLQMessage) -> None: - """Emit a DLQMessageReceivedEvent to the DLQ events topic.""" + """Emit a DLQMessageReceivedEvent.""" event = DLQMessageReceivedEvent( dlq_event_id=message.event.event_id, original_topic=message.original_topic, @@ -333,8 +294,10 @@ async def _emit_message_received_event(self, message: DLQMessage) -> None: ) await self._produce_dlq_event(event) - async def _emit_message_retried_event(self, message: DLQMessage, retry_topic: str, new_retry_count: int) -> None: - """Emit a DLQMessageRetriedEvent to the DLQ events topic.""" + async def _emit_message_retried_event( + self, message: DLQMessage, retry_topic: str, new_retry_count: int + ) -> None: + """Emit a DLQMessageRetriedEvent.""" event = DLQMessageRetriedEvent( dlq_event_id=message.event.event_id, original_topic=message.original_topic, @@ -346,7 +309,7 @@ async def _emit_message_retried_event(self, message: DLQMessage, retry_topic: st await self._produce_dlq_event(event) async def _emit_message_discarded_event(self, message: DLQMessage, reason: str) -> None: - """Emit a DLQMessageDiscardedEvent to the DLQ events topic.""" + """Emit a DLQMessageDiscardedEvent.""" event = DLQMessageDiscardedEvent( dlq_event_id=message.event.event_id, original_topic=message.original_topic, @@ -360,26 +323,26 @@ async def _emit_message_discarded_event(self, message: DLQMessage, reason: str) async def _produce_dlq_event( self, event: DLQMessageReceivedEvent | DLQMessageRetriedEvent | DLQMessageDiscardedEvent ) -> None: - """Produce a DLQ lifecycle event to the DLQ events topic.""" + """Produce a DLQ lifecycle event.""" try: - serialized = await self.schema_registry.serialize_event(event) - await self.producer.send_and_wait( + serialized = await self._schema_registry.serialize_event(event) + await self._producer.send_and_wait( topic=self._dlq_events_topic, value=serialized, key=event.event_id.encode(), ) except Exception as e: - self.logger.error(f"Failed to emit DLQ event {event.event_type}: {e}") + self._logger.error(f"Failed to emit DLQ event {event.event_type}: {e}") async def retry_message_manually(self, event_id: str) -> bool: + """Manually retry a DLQ message.""" doc = await DLQMessageDocument.find_one({"event.event_id": event_id}) if not doc: - self.logger.error("Message not found in DLQ", 
extra={"event_id": event_id}) + self._logger.error("Message not found in DLQ", extra={"event_id": event_id}) return False - # Guard against invalid states if doc.status in {DLQMessageStatus.DISCARDED, DLQMessageStatus.RETRIED}: - self.logger.info("Skipping manual retry", extra={"event_id": event_id, "status": doc.status}) + self._logger.info("Skipping manual retry", extra={"event_id": event_id, "status": doc.status}) return False message = DLQMessage.model_validate(doc, from_attributes=True) @@ -387,14 +350,7 @@ async def retry_message_manually(self, event_id: str) -> bool: return True async def retry_messages_batch(self, event_ids: list[str]) -> DLQBatchRetryResult: - """Retry multiple DLQ messages in batch. - - Args: - event_ids: List of event IDs to retry - - Returns: - Batch result with success/failure counts and details - """ + """Retry multiple DLQ messages in batch.""" details: list[DLQRetryResult] = [] successful = 0 failed = 0 @@ -409,78 +365,24 @@ async def retry_messages_batch(self, event_ids: list[str]) -> DLQBatchRetryResul failed += 1 details.append(DLQRetryResult(event_id=event_id, status="failed", error="Retry failed")) except Exception as e: - self.logger.error(f"Error retrying message {event_id}: {e}") + self._logger.error(f"Error retrying message {event_id}: {e}") failed += 1 details.append(DLQRetryResult(event_id=event_id, status="failed", error=str(e))) return DLQBatchRetryResult(total=len(event_ids), successful=successful, failed=failed, details=details) async def discard_message_manually(self, event_id: str, reason: str) -> bool: - """Manually discard a DLQ message with state validation. - - Args: - event_id: The event ID to discard - reason: Reason for discarding - - Returns: - True if discarded, False if not found or in terminal state - """ + """Manually discard a DLQ message.""" doc = await DLQMessageDocument.find_one({"event.event_id": event_id}) if not doc: - self.logger.error("Message not found in DLQ", extra={"event_id": event_id}) + self._logger.error("Message not found in DLQ", extra={"event_id": event_id}) return False - # Guard against invalid states (terminal states) if doc.status in {DLQMessageStatus.DISCARDED, DLQMessageStatus.RETRIED}: - self.logger.info("Skipping manual discard", extra={"event_id": event_id, "status": doc.status}) + self._logger.info("Skipping manual discard", extra={"event_id": event_id, "status": doc.status}) return False message = DLQMessage.model_validate(doc, from_attributes=True) await self._discard_message(message, reason) return True - -def create_dlq_manager( - settings: Settings, - schema_registry: SchemaRegistryManager, - logger: logging.Logger, - dlq_metrics: DLQMetrics, - dlq_topic: KafkaTopic = KafkaTopic.DEAD_LETTER_QUEUE, - retry_topic_suffix: str = "-retry", - default_retry_policy: RetryPolicy | None = None, -) -> DLQManager: - topic_name = f"{settings.KAFKA_TOPIC_PREFIX}{dlq_topic}" - consumer = AIOKafkaConsumer( - topic_name, - bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS, - group_id=GroupId.DLQ_MANAGER, - enable_auto_commit=False, - auto_offset_reset="earliest", - client_id="dlq-manager-consumer", - session_timeout_ms=settings.KAFKA_SESSION_TIMEOUT_MS, - heartbeat_interval_ms=settings.KAFKA_HEARTBEAT_INTERVAL_MS, - max_poll_interval_ms=settings.KAFKA_MAX_POLL_INTERVAL_MS, - request_timeout_ms=settings.KAFKA_REQUEST_TIMEOUT_MS, - ) - producer = AIOKafkaProducer( - bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS, - client_id="dlq-manager-producer", - acks="all", - compression_type="gzip", - 
max_batch_size=16384, - linger_ms=10, - enable_idempotence=True, - ) - if default_retry_policy is None: - default_retry_policy = RetryPolicy(topic="default", strategy=RetryStrategy.EXPONENTIAL_BACKOFF) - return DLQManager( - settings=settings, - consumer=consumer, - producer=producer, - schema_registry=schema_registry, - logger=logger, - dlq_metrics=dlq_metrics, - dlq_topic=dlq_topic, - retry_topic_suffix=retry_topic_suffix, - default_retry_policy=default_retry_policy, - ) diff --git a/backend/app/events/core/__init__.py b/backend/app/events/core/__init__.py index 3b12df76..1723502a 100644 --- a/backend/app/events/core/__init__.py +++ b/backend/app/events/core/__init__.py @@ -8,15 +8,11 @@ from .types import ( ConsumerConfig, ConsumerMetrics, - ConsumerState, ProducerMetrics, - ProducerState, ) __all__ = [ # Types - "ProducerState", - "ConsumerState", "ConsumerConfig", "ProducerMetrics", "ConsumerMetrics", diff --git a/backend/app/events/core/consumer.py b/backend/app/events/core/consumer.py index d0532f37..3365af98 100644 --- a/backend/app/events/core/consumer.py +++ b/backend/app/events/core/consumer.py @@ -1,258 +1,69 @@ -import asyncio +"""Unified Kafka consumer - pure message handler. + +Handles deserialization, dispatch, and metrics for Kafka messages. +No lifecycle, no properties, no state - just handle(). +Worker gets AIOKafkaConsumer directly from DI. +""" + +from __future__ import annotations + import logging -from collections.abc import Awaitable, Callable -from datetime import datetime, timezone -from typing import Any -from aiokafka import AIOKafkaConsumer, TopicPartition -from aiokafka.errors import KafkaError +from aiokafka import ConsumerRecord from opentelemetry.trace import SpanKind from app.core.metrics import EventMetrics from app.core.tracing import EventAttributes from app.core.tracing.utils import extract_trace_context, get_tracer -from app.domain.enums.kafka import KafkaTopic from app.domain.events.typed import DomainEvent from app.events.schema.schema_registry import SchemaRegistryManager -from app.settings import Settings from .dispatcher import EventDispatcher -from .types import ConsumerConfig, ConsumerMetrics, ConsumerMetricsSnapshot, ConsumerState, ConsumerStatus class UnifiedConsumer: + """Pure message handler - deserialize, dispatch, record metrics.""" + def __init__( self, - config: ConsumerConfig, event_dispatcher: EventDispatcher, schema_registry: SchemaRegistryManager, - settings: Settings, logger: logging.Logger, event_metrics: EventMetrics, - ): - self._config = config - self.logger = logger - self._schema_registry = schema_registry + group_id: str, + ) -> None: self._dispatcher = event_dispatcher - self._consumer: AIOKafkaConsumer | None = None - self._state = ConsumerState.STOPPED - self._running = False - self._metrics = ConsumerMetrics() + self._schema_registry = schema_registry + self._logger = logger self._event_metrics = event_metrics - self._error_callback: "Callable[[Exception, DomainEvent], Awaitable[None]] | None" = None - self._consume_task: asyncio.Task[None] | None = None - self._topic_prefix = settings.KAFKA_TOPIC_PREFIX - - async def start(self, topics: list[KafkaTopic]) -> None: - self._state = self._state if self._state != ConsumerState.STOPPED else ConsumerState.STARTING - - topic_strings = [f"{self._topic_prefix}{str(topic)}" for topic in topics] - - self._consumer = AIOKafkaConsumer( - *topic_strings, - bootstrap_servers=self._config.bootstrap_servers, - group_id=self._config.group_id, - client_id=self._config.client_id, - 
auto_offset_reset=self._config.auto_offset_reset, - enable_auto_commit=self._config.enable_auto_commit, - session_timeout_ms=self._config.session_timeout_ms, - heartbeat_interval_ms=self._config.heartbeat_interval_ms, - max_poll_interval_ms=self._config.max_poll_interval_ms, - request_timeout_ms=self._config.request_timeout_ms, - fetch_min_bytes=self._config.fetch_min_bytes, - fetch_max_wait_ms=self._config.fetch_max_wait_ms, - ) - - await self._consumer.start() - self._running = True - self._consume_task = asyncio.create_task(self._consume_loop()) - - self._state = ConsumerState.RUNNING - - self.logger.info(f"Consumer started for topics: {topic_strings}") - - async def stop(self) -> None: - self._state = ( - ConsumerState.STOPPING - if self._state not in (ConsumerState.STOPPED, ConsumerState.STOPPING) - else self._state - ) - - self._running = False - - if self._consume_task: - self._consume_task.cancel() - await asyncio.gather(self._consume_task, return_exceptions=True) - self._consume_task = None - - await self._cleanup() - self._state = ConsumerState.STOPPED - - async def _cleanup(self) -> None: - if self._consumer: - await self._consumer.stop() - self._consumer = None - - async def _consume_loop(self) -> None: - self.logger.info(f"Consumer loop started for group {self._config.group_id}") - poll_count = 0 - message_count = 0 - - while self._running and self._consumer: - poll_count += 1 - if poll_count % 100 == 0: # Log every 100 polls - self.logger.debug(f"Consumer loop active: polls={poll_count}, messages={message_count}") - + self._group_id = group_id + + async def handle(self, msg: ConsumerRecord) -> DomainEvent | None: + """Handle a Kafka message - deserialize, dispatch, record metrics.""" + if msg.value is None: + return None + + event = await self._schema_registry.deserialize_event(msg.value, msg.topic) + headers = {k: v.decode() for k, v in msg.headers} + + with get_tracer().start_as_current_span( + name="kafka.consume", + context=extract_trace_context(headers), + kind=SpanKind.CONSUMER, + attributes={ + EventAttributes.KAFKA_TOPIC: msg.topic, + EventAttributes.KAFKA_PARTITION: msg.partition, + EventAttributes.KAFKA_OFFSET: msg.offset, + EventAttributes.EVENT_TYPE: event.event_type, + EventAttributes.EVENT_ID: event.event_id, + }, + ): try: - # Use getone() with timeout for single message consumption - msg = await asyncio.wait_for( - self._consumer.getone(), - timeout=0.1 - ) - - message_count += 1 - self.logger.debug( - f"Message received from topic {msg.topic}, partition {msg.partition}, offset {msg.offset}" - ) - await self._process_message(msg) - if not self._config.enable_auto_commit: - await self._consumer.commit() - - except asyncio.TimeoutError: - # No message available within timeout, continue polling - await asyncio.sleep(0.01) - except KafkaError as e: - self.logger.error(f"Consumer error: {e}") - self._metrics.processing_errors += 1 - - self.logger.warning( - f"Consumer loop ended for group {self._config.group_id}: " - f"running={self._running}, consumer={self._consumer is not None}" - ) - - async def _process_message(self, message: Any) -> None: - """Process a ConsumerRecord from aiokafka.""" - topic = message.topic - if not topic: - self.logger.warning("Message with no topic received") - return - - raw_value = message.value - if not raw_value: - self.logger.warning(f"Empty message from topic {topic}") - return - - self.logger.debug(f"Deserializing message from topic {topic}, size={len(raw_value)} bytes") - event = await 
self._schema_registry.deserialize_event(raw_value, topic) - self.logger.info(f"Deserialized event: type={event.event_type}, id={event.event_id}") - - # Extract trace context from Kafka headers and start a consumer span - # aiokafka headers are list of tuples: [(key, value), ...] - header_list = message.headers or [] - headers: dict[str, str] = {} - for k, v in header_list: - headers[str(k)] = v.decode("utf-8") if isinstance(v, (bytes, bytearray)) else (v or "") - ctx = extract_trace_context(headers) - tracer = get_tracer() - - # Dispatch event through EventDispatcher - try: - self.logger.debug(f"Dispatching {event.event_type} to handlers") - partition_val = message.partition - offset_val = message.offset - part_attr = partition_val if partition_val is not None else -1 - off_attr = offset_val if offset_val is not None else -1 - with tracer.start_as_current_span( - name="kafka.consume", - context=ctx, - kind=SpanKind.CONSUMER, - attributes={ - EventAttributes.KAFKA_TOPIC: topic, - EventAttributes.KAFKA_PARTITION: part_attr, - EventAttributes.KAFKA_OFFSET: off_attr, - EventAttributes.EVENT_TYPE: event.event_type, - EventAttributes.EVENT_ID: event.event_id, - }, - ): await self._dispatcher.dispatch(event) - self.logger.debug(f"Successfully dispatched {event.event_type}") - # Update metrics on successful dispatch - self._metrics.messages_consumed += 1 - self._metrics.bytes_consumed += len(raw_value) - self._metrics.last_message_time = datetime.now(timezone.utc) - # Record Kafka consumption metrics - self._event_metrics.record_kafka_message_consumed(topic=topic, consumer_group=self._config.group_id) - except Exception as e: - self.logger.error(f"Dispatcher error for event {event.event_type}: {e}") - self._metrics.processing_errors += 1 - # Record Kafka consumption error - self._event_metrics.record_kafka_consumption_error( - topic=topic, consumer_group=self._config.group_id, error_type=type(e).__name__ - ) - if self._error_callback: - await self._error_callback(e, event) - - def register_error_callback(self, callback: Callable[[Exception, DomainEvent], Awaitable[None]]) -> None: - self._error_callback = callback - - @property - def state(self) -> ConsumerState: - return self._state - - @property - def metrics(self) -> ConsumerMetrics: - return self._metrics - - @property - def is_running(self) -> bool: - return self._state == ConsumerState.RUNNING - - @property - def consumer(self) -> AIOKafkaConsumer | None: - return self._consumer - - def get_status(self) -> ConsumerStatus: - return ConsumerStatus( - state=self._state, - is_running=self.is_running, - group_id=self._config.group_id, - client_id=self._config.client_id, - metrics=ConsumerMetricsSnapshot( - messages_consumed=self._metrics.messages_consumed, - bytes_consumed=self._metrics.bytes_consumed, - consumer_lag=self._metrics.consumer_lag, - commit_failures=self._metrics.commit_failures, - processing_errors=self._metrics.processing_errors, - last_message_time=self._metrics.last_message_time, - last_updated=self._metrics.last_updated, - ), - ) - - async def seek_to_beginning(self) -> None: - """Seek all assigned partitions to the beginning.""" - if not self._consumer: - self.logger.warning("Cannot seek: consumer not initialized") - return - - assignment = self._consumer.assignment() - if assignment: - await self._consumer.seek_to_beginning(*assignment) - - async def seek_to_end(self) -> None: - """Seek all assigned partitions to the end.""" - if not self._consumer: - self.logger.warning("Cannot seek: consumer not initialized") - return - - 
assignment = self._consumer.assignment() - if assignment: - await self._consumer.seek_to_end(*assignment) - - async def seek_to_offset(self, topic: str, partition: int, offset: int) -> None: - """Seek a specific partition to a specific offset.""" - if not self._consumer: - self.logger.warning("Cannot seek to offset: consumer not initialized") - return + self._event_metrics.record_kafka_message_consumed(msg.topic, self._group_id) + except Exception as e: + self._logger.error(f"Dispatch error: {event.event_type}: {e}") + self._event_metrics.record_kafka_consumption_error(msg.topic, self._group_id, type(e).__name__) + raise - tp = TopicPartition(topic, partition) - self._consumer.seek(tp, offset) + return event diff --git a/backend/app/events/core/producer.py b/backend/app/events/core/producer.py index a41188c7..98ab0b74 100644 --- a/backend/app/events/core/producer.py +++ b/backend/app/events/core/producer.py @@ -1,122 +1,60 @@ +"""Unified Kafka producer - thin wrapper over AIOKafkaProducer. + +The producer receives a ready-to-use AIOKafkaProducer from DI. +No lifecycle management - DI provider handles start/stop. +""" + +from __future__ import annotations + import asyncio import json import logging import socket from datetime import datetime, timezone -from typing import Any from aiokafka import AIOKafkaProducer from aiokafka.errors import KafkaError -from app.core.lifecycle import LifecycleEnabled from app.core.metrics import EventMetrics from app.dlq.models import DLQMessage, DLQMessageStatus from app.domain.enums.kafka import KafkaTopic from app.domain.events.typed import DomainEvent from app.events.schema.schema_registry import SchemaRegistryManager from app.infrastructure.kafka.mappings import EVENT_TYPE_TO_TOPIC -from app.settings import Settings -from .types import ProducerMetrics, ProducerState +from .types import ProducerMetrics -class UnifiedProducer(LifecycleEnabled): - """Fully async Kafka producer using aiokafka.""" +class UnifiedProducer: + """Kafka producer wrapper - receives ready-to-use producer from DI. + + No lifecycle methods (start/stop) - DI provider manages AIOKafkaProducer lifecycle. 
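+
+    Usage sketch ("event", "topic" and "exc" are assumed locals; key and
+    headers are optional):
+
+        await producer.produce(event, key=event.event_id)
+        await producer.send_to_dlq(event, original_topic=topic, error=exc)
+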
+ """ def __init__( self, + producer: AIOKafkaProducer, schema_registry_manager: SchemaRegistryManager, logger: logging.Logger, - settings: Settings, event_metrics: EventMetrics, - ): - super().__init__() - self._settings = settings + producer_metrics: ProducerMetrics, + topic_prefix: str = "", + ) -> None: + self._producer = producer self._schema_registry = schema_registry_manager - self.logger = logger - self._producer: AIOKafkaProducer | None = None - self._state = ProducerState.STOPPED - self._metrics = ProducerMetrics() + self._logger = logger self._event_metrics = event_metrics - self._topic_prefix = settings.KAFKA_TOPIC_PREFIX - - @property - def is_running(self) -> bool: - return self._state == ProducerState.RUNNING - - @property - def state(self) -> ProducerState: - return self._state - - @property - def metrics(self) -> ProducerMetrics: - return self._metrics - - @property - def producer(self) -> AIOKafkaProducer | None: - return self._producer - - async def _on_start(self) -> None: - """Start the Kafka producer.""" - self._state = ProducerState.STARTING - self.logger.info("Starting producer...") - - self._producer = AIOKafkaProducer( - bootstrap_servers=self._settings.KAFKA_BOOTSTRAP_SERVERS, - client_id=f"{self._settings.SERVICE_NAME}-producer", - acks="all", - compression_type="gzip", - max_batch_size=16384, - linger_ms=10, - enable_idempotence=True, - ) - - await self._producer.start() - self._state = ProducerState.RUNNING - self.logger.info(f"Producer started: {self._settings.KAFKA_BOOTSTRAP_SERVERS}") - - def get_status(self) -> dict[str, Any]: - return { - "state": self._state, - "running": self.is_running, - "config": { - "bootstrap_servers": self._settings.KAFKA_BOOTSTRAP_SERVERS, - "client_id": f"{self._settings.SERVICE_NAME}-producer", - }, - "metrics": { - "messages_sent": self._metrics.messages_sent, - "messages_failed": self._metrics.messages_failed, - "bytes_sent": self._metrics.bytes_sent, - "queue_size": self._metrics.queue_size, - "avg_latency_ms": self._metrics.avg_latency_ms, - "last_error": self._metrics.last_error, - "last_error_time": self._metrics.last_error_time.isoformat() if self._metrics.last_error_time else None, - }, - } - - async def _on_stop(self) -> None: - """Stop the Kafka producer.""" - self._state = ProducerState.STOPPING - self.logger.info("Stopping producer...") - - if self._producer: - await self._producer.stop() - self._producer = None - - self._state = ProducerState.STOPPED - self.logger.info("Producer stopped") + self._topic_prefix = topic_prefix + self._metrics = producer_metrics async def produce( self, event_to_produce: DomainEvent, key: str | None = None, headers: dict[str, str] | None = None ) -> None: """Produce a message to Kafka.""" - if not self._producer: - self.logger.error("Producer not running") - return + topic = f"{self._topic_prefix}{EVENT_TYPE_TO_TOPIC[event_to_produce.event_type]}" try: serialized_value = await self._schema_registry.serialize_event(event_to_produce) - topic = f"{self._topic_prefix}{EVENT_TYPE_TO_TOPIC[event_to_produce.event_type]}" # Convert headers to list of tuples format header_list = [(k, v.encode()) for k, v in headers.items()] if headers else None @@ -135,24 +73,20 @@ async def produce( # Record Kafka metrics self._event_metrics.record_kafka_message_produced(topic) - self.logger.debug(f"Message [{event_to_produce}] sent to topic: {topic}") + self._logger.debug(f"Message [{event_to_produce}] sent to topic: {topic}") except KafkaError as e: self._metrics.messages_failed += 1 self._metrics.last_error 
= str(e) self._metrics.last_error_time = datetime.now(timezone.utc) self._event_metrics.record_kafka_production_error(topic=topic, error_type=type(e).__name__) - self.logger.error(f"Failed to produce message: {e}") + self._logger.error(f"Failed to produce message: {e}") raise async def send_to_dlq( self, original_event: DomainEvent, original_topic: str, error: Exception, retry_count: int = 0 ) -> None: """Send a failed event to the Dead Letter Queue.""" - if not self._producer: - self.logger.error("Producer not running, cannot send to DLQ") - return - try: # Get producer ID (hostname + task name) current_task = asyncio.current_task() @@ -202,7 +136,7 @@ async def send_to_dlq( self._event_metrics.record_kafka_message_produced(dlq_topic) self._metrics.messages_sent += 1 - self.logger.warning( + self._logger.warning( f"Event {original_event.event_id} sent to DLQ. " f"Original topic: {original_topic}, Error: {error}, " f"Retry count: {retry_count}" @@ -210,7 +144,7 @@ async def send_to_dlq( except Exception as e: # If we can't send to DLQ, log critically but don't crash - self.logger.critical( + self._logger.critical( f"Failed to send event {original_event.event_id} to DLQ: {e}. Original error: {error}", exc_info=True ) self._metrics.messages_failed += 1 diff --git a/backend/app/events/event_store_consumer.py b/backend/app/events/event_store_consumer.py deleted file mode 100644 index 1dbdb83c..00000000 --- a/backend/app/events/event_store_consumer.py +++ /dev/null @@ -1,190 +0,0 @@ -import asyncio -import logging - -from opentelemetry.trace import SpanKind - -from app.core.lifecycle import LifecycleEnabled -from app.core.metrics import EventMetrics -from app.core.tracing.utils import trace_span -from app.domain.enums.events import EventType -from app.domain.enums.kafka import GroupId, KafkaTopic -from app.domain.events.typed import DomainEvent -from app.events.core import ConsumerConfig, EventDispatcher, UnifiedConsumer, UnifiedProducer, create_dlq_error_handler -from app.events.event_store import EventStore -from app.events.schema.schema_registry import SchemaRegistryManager -from app.settings import Settings - - -class EventStoreConsumer(LifecycleEnabled): - """Consumes events from Kafka and stores them in MongoDB.""" - - def __init__( - self, - event_store: EventStore, - topics: list[KafkaTopic], - schema_registry_manager: SchemaRegistryManager, - settings: Settings, - logger: logging.Logger, - event_metrics: EventMetrics, - producer: UnifiedProducer | None = None, - group_id: GroupId = GroupId.EVENT_STORE_CONSUMER, - batch_size: int = 100, - batch_timeout_seconds: float = 5.0, - ): - super().__init__() - self.event_store = event_store - self.topics = topics - self.settings = settings - self.group_id = group_id - self.batch_size = batch_size - self.batch_timeout = batch_timeout_seconds - self.logger = logger - self.event_metrics = event_metrics - self.consumer: UnifiedConsumer | None = None - self.schema_registry_manager = schema_registry_manager - self.dispatcher = EventDispatcher(logger) - self.producer = producer # For DLQ handling - self._batch_buffer: list[DomainEvent] = [] - self._batch_lock = asyncio.Lock() - self._last_batch_time: float = 0.0 - self._batch_task: asyncio.Task[None] | None = None - - async def _on_start(self) -> None: - """Start consuming and storing events.""" - self._last_batch_time = asyncio.get_running_loop().time() - config = ConsumerConfig( - bootstrap_servers=self.settings.KAFKA_BOOTSTRAP_SERVERS, - group_id=self.group_id, - enable_auto_commit=False, - 
max_poll_records=self.batch_size, - session_timeout_ms=self.settings.KAFKA_SESSION_TIMEOUT_MS, - heartbeat_interval_ms=self.settings.KAFKA_HEARTBEAT_INTERVAL_MS, - max_poll_interval_ms=self.settings.KAFKA_MAX_POLL_INTERVAL_MS, - request_timeout_ms=self.settings.KAFKA_REQUEST_TIMEOUT_MS, - ) - - self.consumer = UnifiedConsumer( - config, - event_dispatcher=self.dispatcher, - schema_registry=self.schema_registry_manager, - settings=self.settings, - logger=self.logger, - event_metrics=self.event_metrics, - ) - - # Register handler for all event types - store everything - for event_type in EventType: - self.dispatcher.register(event_type)(self._handle_event) - - # Register error callback - use DLQ if producer is available - if self.producer: - # Use DLQ handler with retry logic - dlq_handler = create_dlq_error_handler( - producer=self.producer, - original_topic="event-store", # Generic topic name for event store - logger=self.logger, - max_retries=3, - ) - self.consumer.register_error_callback(dlq_handler) - else: - # Fallback to simple logging - self.consumer.register_error_callback(self._handle_error_with_event) - - await self.consumer.start(self.topics) - - self._batch_task = asyncio.create_task(self._batch_processor()) - - self.logger.info(f"Event store consumer started for topics: {self.topics}") - - async def _on_stop(self) -> None: - """Stop consumer.""" - await self._flush_batch() - - if self._batch_task: - self._batch_task.cancel() - try: - await self._batch_task - except asyncio.CancelledError: - pass - - if self.consumer: - await self.consumer.stop() - - self.logger.info("Event store consumer stopped") - - async def _handle_event(self, event: DomainEvent) -> None: - """Handle incoming event from dispatcher.""" - self.logger.info(f"Event store received event: {event.event_type} - {event.event_id}") - - async with self._batch_lock: - self._batch_buffer.append(event) - - if len(self._batch_buffer) >= self.batch_size: - await self._flush_batch() - - async def _handle_error_with_event(self, error: Exception, event: DomainEvent) -> None: - """Handle processing errors with event context.""" - self.logger.error(f"Error processing event {event.event_id} ({event.event_type}): {error}", exc_info=True) - - async def _batch_processor(self) -> None: - """Periodically flush batches based on timeout.""" - while self.is_running: - try: - await asyncio.sleep(1) - - async with self._batch_lock: - time_since_last_batch = asyncio.get_running_loop().time() - self._last_batch_time - - if self._batch_buffer and time_since_last_batch >= self.batch_timeout: - await self._flush_batch() - - except Exception as e: - self.logger.error(f"Error in batch processor: {e}") - - async def _flush_batch(self) -> None: - if not self._batch_buffer: - return - - batch = self._batch_buffer.copy() - self._batch_buffer.clear() - self._last_batch_time = asyncio.get_running_loop().time() - - self.logger.info(f"Event store flushing batch of {len(batch)} events") - with trace_span( - name="event_store.flush_batch", - kind=SpanKind.CONSUMER, - attributes={"events.batch.count": len(batch)}, - ): - results = await self.event_store.store_batch(batch) - - self.logger.info( - f"Stored event batch: total={results['total']}, " - f"stored={results['stored']}, duplicates={results['duplicates']}, " - f"failed={results['failed']}" - ) - - -def create_event_store_consumer( - event_store: EventStore, - topics: list[KafkaTopic], - schema_registry_manager: SchemaRegistryManager, - settings: Settings, - logger: logging.Logger, - event_metrics: 
EventMetrics, - producer: UnifiedProducer | None = None, - group_id: GroupId = GroupId.EVENT_STORE_CONSUMER, - batch_size: int = 100, - batch_timeout_seconds: float = 5.0, -) -> EventStoreConsumer: - return EventStoreConsumer( - event_store=event_store, - topics=topics, - group_id=group_id, - batch_size=batch_size, - batch_timeout_seconds=batch_timeout_seconds, - schema_registry_manager=schema_registry_manager, - settings=settings, - logger=logger, - event_metrics=event_metrics, - producer=producer, - ) diff --git a/backend/app/services/coordinator/__init__.py b/backend/app/services/coordinator/__init__.py index b3890c9d..2c79d4a3 100644 --- a/backend/app/services/coordinator/__init__.py +++ b/backend/app/services/coordinator/__init__.py @@ -1,11 +1,5 @@ from app.services.coordinator.coordinator import ExecutionCoordinator -from app.services.coordinator.queue_manager import QueueManager, QueuePriority -from app.services.coordinator.resource_manager import ResourceAllocation, ResourceManager __all__ = [ "ExecutionCoordinator", - "QueueManager", - "QueuePriority", - "ResourceManager", - "ResourceAllocation", ] diff --git a/backend/app/services/coordinator/coordinator.py b/backend/app/services/coordinator/coordinator.py index e9e0591b..b8c9e6e5 100644 --- a/backend/app/services/coordinator/coordinator.py +++ b/backend/app/services/coordinator/coordinator.py @@ -1,15 +1,24 @@ -import asyncio +"""Execution Coordinator - stateless event handler. + +Coordinates execution scheduling across the system. Receives events, +processes them, and publishes results. No lifecycle management. +All state is stored in Redis repositories. +""" + +from __future__ import annotations + import logging import time -from collections.abc import Coroutine -from typing import Any, TypeAlias from uuid import uuid4 -from app.core.lifecycle import LifecycleEnabled from app.core.metrics import CoordinatorMetrics, EventMetrics +from app.db.repositories import ( + ExecutionQueueRepository, + ExecutionStateRepository, + QueuePriority, + ResourceRepository, +) from app.db.repositories.execution_repository import ExecutionRepository -from app.domain.enums.events import EventType -from app.domain.enums.kafka import CONSUMER_GROUP_SUBSCRIPTIONS, GroupId from app.domain.enums.storage import ExecutionErrorType from app.domain.events.typed import ( CreatePodCommandEvent, @@ -20,392 +29,219 @@ ExecutionFailedEvent, ExecutionRequestedEvent, ) -from app.domain.idempotency import KeyStrategy -from app.events.core import ConsumerConfig, EventDispatcher, UnifiedConsumer, UnifiedProducer -from app.events.event_store import EventStore -from app.events.schema.schema_registry import ( - SchemaRegistryManager, -) -from app.services.coordinator.queue_manager import QueueManager, QueuePriority -from app.services.coordinator.resource_manager import ResourceAllocation, ResourceManager -from app.services.idempotency import IdempotencyManager -from app.services.idempotency.middleware import IdempotentConsumerWrapper -from app.settings import Settings +from app.events.core import UnifiedProducer -EventHandler: TypeAlias = Coroutine[Any, Any, None] -ExecutionMap: TypeAlias = dict[str, ResourceAllocation] +class ExecutionCoordinator: + """Stateless execution coordinator - pure event handler. -class ExecutionCoordinator(LifecycleEnabled): - """ - Coordinates execution scheduling across the system. - - This service: - 1. Consumes ExecutionRequested events - 2. Manages execution queue with priority - 3. Enforces rate limits - 4. Allocates resources - 5. 
Publishes ExecutionStarted events for workers + No lifecycle methods (start/stop) - receives ready-to-use dependencies from DI. + All state (active executions, queue, resources) stored in Redis. """ def __init__( self, producer: UnifiedProducer, - schema_registry_manager: SchemaRegistryManager, - settings: Settings, - event_store: EventStore, execution_repository: ExecutionRepository, - idempotency_manager: IdempotencyManager, + state_repo: ExecutionStateRepository, + queue_repo: ExecutionQueueRepository, + resource_repo: ResourceRepository, logger: logging.Logger, coordinator_metrics: CoordinatorMetrics, event_metrics: EventMetrics, - consumer_group: str = GroupId.EXECUTION_COORDINATOR, - max_concurrent_scheduling: int = 10, - scheduling_interval_seconds: float = 0.5, - ): - super().__init__() - self.logger = logger - self.metrics = coordinator_metrics + ) -> None: + self._producer = producer + self._execution_repository = execution_repository + self._state_repo = state_repo + self._queue_repo = queue_repo + self._resource_repo = resource_repo + self._logger = logger + self._metrics = coordinator_metrics self._event_metrics = event_metrics - self._settings = settings - - # Kafka configuration - self.kafka_servers = self._settings.KAFKA_BOOTSTRAP_SERVERS - self.consumer_group = consumer_group - - # Components - self.queue_manager = QueueManager( - logger=self.logger, - coordinator_metrics=coordinator_metrics, - max_queue_size=10000, - max_executions_per_user=100, - stale_timeout_seconds=3600, - ) - - self.resource_manager = ResourceManager( - logger=self.logger, - coordinator_metrics=coordinator_metrics, - total_cpu_cores=32.0, - total_memory_mb=65536, - total_gpu_count=0, - ) - - # Kafka components - self.consumer: UnifiedConsumer | None = None - self.idempotent_consumer: IdempotentConsumerWrapper | None = None - self.producer: UnifiedProducer = producer - - # Persistence via repositories - self.execution_repository = execution_repository - self.idempotency_manager = idempotency_manager - self._event_store = event_store - - # Scheduling - self.max_concurrent_scheduling = max_concurrent_scheduling - self.scheduling_interval = scheduling_interval_seconds - self._scheduling_semaphore = asyncio.Semaphore(max_concurrent_scheduling) - - # State tracking - self._scheduling_task: asyncio.Task[None] | None = None - self._active_executions: set[str] = set() - self._execution_resources: ExecutionMap = {} - self._schema_registry_manager = schema_registry_manager - self.dispatcher = EventDispatcher(logger=self.logger) - - async def _on_start(self) -> None: - """Start the coordinator service.""" - self.logger.info("Starting ExecutionCoordinator service...") - - await self.queue_manager.start() - - await self.idempotency_manager.initialize() - - consumer_config = ConsumerConfig( - bootstrap_servers=self.kafka_servers, - group_id=self.consumer_group, - enable_auto_commit=False, - session_timeout_ms=self._settings.KAFKA_SESSION_TIMEOUT_MS, - heartbeat_interval_ms=self._settings.KAFKA_HEARTBEAT_INTERVAL_MS, - max_poll_interval_ms=self._settings.KAFKA_MAX_POLL_INTERVAL_MS, - request_timeout_ms=self._settings.KAFKA_REQUEST_TIMEOUT_MS, - max_poll_records=100, # Process max 100 messages at a time for flow control - fetch_max_wait_ms=500, # Wait max 500ms for data (reduces latency) - fetch_min_bytes=1, # Return immediately if any data available - ) - - self.consumer = UnifiedConsumer( - consumer_config, - event_dispatcher=self.dispatcher, - schema_registry=self._schema_registry_manager, - 
settings=self._settings, - logger=self.logger, - event_metrics=self._event_metrics, - ) - - # Register handlers with EventDispatcher BEFORE wrapping with idempotency - @self.dispatcher.register(EventType.EXECUTION_REQUESTED) - async def handle_requested(event: ExecutionRequestedEvent) -> None: - await self._route_execution_event(event) - - @self.dispatcher.register(EventType.EXECUTION_COMPLETED) - async def handle_completed(event: ExecutionCompletedEvent) -> None: - await self._route_execution_result(event) - - @self.dispatcher.register(EventType.EXECUTION_FAILED) - async def handle_failed(event: ExecutionFailedEvent) -> None: - await self._route_execution_result(event) - - @self.dispatcher.register(EventType.EXECUTION_CANCELLED) - async def handle_cancelled(event: ExecutionCancelledEvent) -> None: - await self._route_execution_event(event) - - self.idempotent_consumer = IdempotentConsumerWrapper( - consumer=self.consumer, - idempotency_manager=self.idempotency_manager, - dispatcher=self.dispatcher, - logger=self.logger, - default_key_strategy=KeyStrategy.EVENT_BASED, # Use event ID for deduplication - default_ttl_seconds=7200, # 2 hours TTL for coordinator events - enable_for_all_handlers=True, # Enable idempotency for ALL handlers - ) - self.logger.info("COORDINATOR: Event handlers registered with idempotency protection") - - await self.idempotent_consumer.start(list(CONSUMER_GROUP_SUBSCRIPTIONS[GroupId.EXECUTION_COORDINATOR])) - - # Start scheduling task - self._scheduling_task = asyncio.create_task(self._scheduling_loop()) - - self.logger.info("ExecutionCoordinator service started successfully") - - async def _on_stop(self) -> None: - """Stop the coordinator service.""" - self.logger.info("Stopping ExecutionCoordinator service...") - - # Stop scheduling task - if self._scheduling_task: - self._scheduling_task.cancel() - try: - await self._scheduling_task - except asyncio.CancelledError: - pass - - # Stop consumer (idempotent wrapper only) - if self.idempotent_consumer: - await self.idempotent_consumer.stop() - - await self.queue_manager.stop() - - # Close idempotency manager - if hasattr(self, "idempotency_manager") and self.idempotency_manager: - await self.idempotency_manager.close() - - self.logger.info(f"ExecutionCoordinator service stopped. 
Active executions: {len(self._active_executions)}") - - async def _route_execution_event(self, event: ExecutionRequestedEvent | ExecutionCancelledEvent) -> None: - """Route execution events to appropriate handlers based on event type""" - self.logger.info( - f"COORDINATOR: Routing execution event - type: {event.event_type}, " - f"id: {event.event_id}, " - f"actual class: {type(event).__name__}" - ) - - if event.event_type == EventType.EXECUTION_REQUESTED: - await self._handle_execution_requested(event) - elif event.event_type == EventType.EXECUTION_CANCELLED: - await self._handle_execution_cancelled(event) - else: - self.logger.debug(f"Ignoring execution event type: {event.event_type}") - - async def _route_execution_result(self, event: ExecutionCompletedEvent | ExecutionFailedEvent) -> None: - """Route execution result events to appropriate handlers based on event type""" - if event.event_type == EventType.EXECUTION_COMPLETED: - await self._handle_execution_completed(event) - elif event.event_type == EventType.EXECUTION_FAILED: - await self._handle_execution_failed(event) - else: - self.logger.debug(f"Ignoring execution result event type: {event.event_type}") - - async def _handle_execution_requested(self, event: ExecutionRequestedEvent) -> None: - """Handle execution requested event - add to queue for processing""" - self.logger.info(f"HANDLER CALLED: _handle_execution_requested for event {event.event_id}") + async def handle_execution_requested(self, event: ExecutionRequestedEvent) -> None: + """Handle execution requested event - add to queue and try to schedule.""" + self._logger.info(f"Handling ExecutionRequestedEvent: {event.execution_id}") start_time = time.time() try: - # Add to queue with priority - success, position, error = await self.queue_manager.add_execution( - event, - priority=QueuePriority(event.priority), + priority = QueuePriority(event.priority) + user_id = event.metadata.user_id or "anonymous" + + # Add to Redis queue + success, position, error = await self._queue_repo.enqueue( + execution_id=event.execution_id, + event_data=event.model_dump(mode="json"), + priority=priority, + user_id=user_id, ) if not success: - # Publish queue full event await self._publish_queue_full(event, error or "Queue is full") - self.metrics.record_coordinator_execution_scheduled("queue_full") + self._metrics.record_coordinator_execution_scheduled("queue_full") return # Publish ExecutionAcceptedEvent - if position is None: - position = 0 - await self._publish_execution_accepted(event, position, event.priority) + await self._publish_execution_accepted(event, position or 0, event.priority) # Track metrics duration = time.time() - start_time - self.metrics.record_coordinator_scheduling_duration(duration) - self.metrics.record_coordinator_execution_scheduled("queued") + self._metrics.record_coordinator_scheduling_duration(duration) + self._metrics.record_coordinator_execution_scheduled("queued") - self.logger.info(f"Execution {event.execution_id} added to queue at position {position}") + self._logger.info(f"Execution {event.execution_id} added to queue at position {position}") - # Schedule immediately if at front of queue (position 0) + # If at front of queue (position 0), try to schedule immediately if position == 0: - await self._schedule_execution(event) + await self._try_schedule_next() except Exception as e: - self.logger.error(f"Failed to handle execution request {event.execution_id}: {e}", exc_info=True) - self.metrics.record_coordinator_execution_scheduled("error") + 
self._logger.error(f"Failed to handle execution request {event.execution_id}: {e}", exc_info=True) + self._metrics.record_coordinator_execution_scheduled("error") - async def _handle_execution_cancelled(self, event: ExecutionCancelledEvent) -> None: - """Handle execution cancelled event""" + async def handle_execution_completed(self, event: ExecutionCompletedEvent) -> None: + """Handle execution completed - release resources and try to schedule next.""" execution_id = event.execution_id + self._logger.info(f"Handling ExecutionCompletedEvent: {execution_id}") + + # Release resources + await self._resource_repo.release(execution_id) - removed = await self.queue_manager.remove_execution(execution_id) + # Remove from active state + await self._state_repo.remove(execution_id) - if execution_id in self._execution_resources: - await self.resource_manager.release_allocation(execution_id) - del self._execution_resources[execution_id] + # Update metrics + count = await self._state_repo.get_active_count() + self._metrics.update_coordinator_active_executions(count) - self._active_executions.discard(execution_id) - self.metrics.update_coordinator_active_executions(len(self._active_executions)) + self._logger.info(f"Execution {execution_id} completed, resources released") - if removed: - self.logger.info(f"Execution {execution_id} cancelled and removed from queue") + # Try to schedule next execution from queue + await self._try_schedule_next() - async def _handle_execution_completed(self, event: ExecutionCompletedEvent) -> None: - """Handle execution completed event""" + async def handle_execution_failed(self, event: ExecutionFailedEvent) -> None: + """Handle execution failed - release resources and try to schedule next.""" execution_id = event.execution_id + self._logger.info(f"Handling ExecutionFailedEvent: {execution_id}") - if execution_id in self._execution_resources: - await self.resource_manager.release_allocation(execution_id) - del self._execution_resources[execution_id] + # Release resources + await self._resource_repo.release(execution_id) + + # Remove from active state + await self._state_repo.remove(execution_id) - # Remove from active set - self._active_executions.discard(execution_id) - self.metrics.update_coordinator_active_executions(len(self._active_executions)) + # Update metrics + count = await self._state_repo.get_active_count() + self._metrics.update_coordinator_active_executions(count) - self.logger.info(f"Execution {execution_id} completed, resources released") + # Try to schedule next execution from queue + await self._try_schedule_next() - async def _handle_execution_failed(self, event: ExecutionFailedEvent) -> None: - """Handle execution failed event""" + async def handle_execution_cancelled(self, event: ExecutionCancelledEvent) -> None: + """Handle execution cancelled - remove from queue and release resources.""" execution_id = event.execution_id + self._logger.info(f"Handling ExecutionCancelledEvent: {execution_id}") - # Release resources - if execution_id in self._execution_resources: - await self.resource_manager.release_allocation(execution_id) - del self._execution_resources[execution_id] - - # Remove from active set - self._active_executions.discard(execution_id) - self.metrics.update_coordinator_active_executions(len(self._active_executions)) - - async def _scheduling_loop(self) -> None: - """Main scheduling loop""" - while self.is_running: - try: - # Get next execution from queue - execution = await self.queue_manager.get_next_execution() - - if execution: - # 
Schedule execution - asyncio.create_task(self._schedule_execution(execution)) - else: - # No executions in queue, wait - await asyncio.sleep(self.scheduling_interval) - - except Exception as e: - self.logger.error(f"Error in scheduling loop: {e}", exc_info=True) - await asyncio.sleep(5) # Wait before retrying + # Remove from queue if present + await self._queue_repo.remove(execution_id) + + # Release resources if allocated + await self._resource_repo.release(execution_id) + + # Remove from active state + await self._state_repo.remove(execution_id) + + # Update metrics + count = await self._state_repo.get_active_count() + self._metrics.update_coordinator_active_executions(count) + + async def _try_schedule_next(self) -> None: + """Try to schedule the next execution from the queue.""" + result = await self._queue_repo.dequeue() + if not result: + return + + execution_id, event_data = result + + # Reconstruct event from stored data + try: + event = ExecutionRequestedEvent.model_validate(event_data) + await self._schedule_execution(event) + except Exception as e: + self._logger.error(f"Failed to schedule execution {execution_id}: {e}", exc_info=True) async def _schedule_execution(self, event: ExecutionRequestedEvent) -> None: - """Schedule a single execution""" - async with self._scheduling_semaphore: - start_time = time.time() - execution_id = event.execution_id - - # Atomic check-and-claim: no await between check and add prevents TOCTOU race - # when both eager scheduling (position=0) and _scheduling_loop try to schedule - if execution_id in self._active_executions: - self.logger.debug(f"Execution {execution_id} already claimed, skipping") - return - self._active_executions.add(execution_id) - - try: - # Request resource allocation - allocation = await self.resource_manager.request_allocation( - execution_id, - event.language, - requested_cpu=None, # Use defaults for now - requested_memory_mb=None, - requested_gpu=0, - ) + """Schedule a single execution - allocate resources and publish command.""" + start_time = time.time() + execution_id = event.execution_id + + # Try to claim this execution atomically + claimed = await self._state_repo.try_claim(execution_id) + if not claimed: + self._logger.debug(f"Execution {execution_id} already claimed, skipping") + return - if not allocation: - # No resources available, release claim and requeue - self._active_executions.discard(execution_id) - await self.queue_manager.requeue_execution(event, increment_retry=False) - self.logger.info(f"No resources available for {execution_id}, requeued") - return - - # Track allocation (already in _active_executions from claim above) - self._execution_resources[execution_id] = allocation - self.metrics.update_coordinator_active_executions(len(self._active_executions)) - - # Publish execution started event for workers - self.logger.info(f"About to publish ExecutionStartedEvent for {event.execution_id}") - try: - await self._publish_execution_started(event) - self.logger.info(f"Successfully published ExecutionStartedEvent for {event.execution_id}") - except Exception as publish_error: - self.logger.error( - f"Failed to publish ExecutionStartedEvent for {event.execution_id}: {publish_error}", - exc_info=True, - ) - raise - - # Track metrics - queue_time = start_time - event.timestamp.timestamp() - priority = getattr(event, "priority", QueuePriority.NORMAL.value) - self.metrics.record_coordinator_queue_time(queue_time, QueuePriority(priority).name) - - scheduling_duration = time.time() - start_time - 
self.metrics.record_coordinator_scheduling_duration(scheduling_duration) - self.metrics.record_coordinator_execution_scheduled("scheduled") - - self.logger.info( - f"Scheduled execution {event.execution_id}. " - f"Queue time: {queue_time:.2f}s, " - f"Resources: {allocation.cpu_cores} CPU, " - f"{allocation.memory_mb}MB RAM" + try: + # Allocate resources + allocation = await self._resource_repo.allocate( + execution_id=execution_id, + language=event.language, + requested_cpu=None, + requested_memory_mb=None, + requested_gpu=0, + ) + + if not allocation: + # No resources available, release claim and requeue + await self._state_repo.remove(execution_id) + await self._queue_repo.enqueue( + execution_id=event.execution_id, + event_data=event.model_dump(mode="json"), + priority=QueuePriority(event.priority), + user_id=event.metadata.user_id or "anonymous", ) + self._logger.info(f"No resources available for {execution_id}, requeued") + return + + # Update metrics + count = await self._state_repo.get_active_count() + self._metrics.update_coordinator_active_executions(count) - except Exception as e: - self.logger.error(f"Failed to schedule execution {event.execution_id}: {e}", exc_info=True) + # Publish CreatePodCommand + await self._publish_execution_started(event) - # Release any allocated resources - if event.execution_id in self._execution_resources: - await self.resource_manager.release_allocation(event.execution_id) - del self._execution_resources[event.execution_id] + # Track metrics + queue_time = start_time - event.timestamp.timestamp() + priority = QueuePriority(event.priority) + self._metrics.record_coordinator_queue_time(queue_time, priority.name) + + scheduling_duration = time.time() - start_time + self._metrics.record_coordinator_scheduling_duration(scheduling_duration) + self._metrics.record_coordinator_execution_scheduled("scheduled") + + self._logger.info( + f"Scheduled execution {event.execution_id}. 
" + f"Queue time: {queue_time:.2f}s, " + f"Resources: {allocation.cpu_cores} CPU, {allocation.memory_mb}MB RAM" + ) + + except Exception as e: + self._logger.error(f"Failed to schedule execution {event.execution_id}: {e}", exc_info=True) - self._active_executions.discard(event.execution_id) - self.metrics.update_coordinator_active_executions(len(self._active_executions)) - self.metrics.record_coordinator_execution_scheduled("error") + # Release resources and claim + await self._resource_repo.release(execution_id) + await self._state_repo.remove(execution_id) - # Publish failure event - await self._publish_scheduling_failed(event, str(e)) + count = await self._state_repo.get_active_count() + self._metrics.update_coordinator_active_executions(count) + self._metrics.record_coordinator_execution_scheduled("error") + + # Publish failure event + await self._publish_scheduling_failed(event, str(e)) async def _build_command_metadata(self, request: ExecutionRequestedEvent) -> EventMetadata: """Build metadata for CreatePodCommandEvent with guaranteed user_id.""" - # Prefer execution record user_id to avoid missing attribution - exec_rec = await self.execution_repository.get_execution(request.execution_id) + exec_rec = await self._execution_repository.get_execution(request.execution_id) user_id: str = exec_rec.user_id if exec_rec and exec_rec.user_id else "system" return EventMetadata( @@ -416,7 +252,7 @@ async def _build_command_metadata(self, request: ExecutionRequestedEvent) -> Eve ) async def _publish_execution_started(self, request: ExecutionRequestedEvent) -> None: - """Send CreatePodCommandEvent to k8s-worker via SAGA_COMMANDS topic""" + """Send CreatePodCommandEvent to k8s-worker via SAGA_COMMANDS topic.""" metadata = await self._build_command_metadata(request) create_pod_cmd = CreatePodCommandEvent( @@ -437,64 +273,54 @@ async def _publish_execution_started(self, request: ExecutionRequestedEvent) -> metadata=metadata, ) - await self.producer.produce(event_to_produce=create_pod_cmd, key=request.execution_id) - - async def _publish_execution_accepted(self, request: ExecutionRequestedEvent, position: int, priority: int) -> None: - """Publish execution accepted event to notify that request was valid and queued""" - self.logger.info(f"Publishing ExecutionAcceptedEvent for execution {request.execution_id}") + await self._producer.produce(event_to_produce=create_pod_cmd, key=request.execution_id) + self._logger.info(f"Published CreatePodCommandEvent for {request.execution_id}") + async def _publish_execution_accepted( + self, request: ExecutionRequestedEvent, position: int, priority: int + ) -> None: + """Publish execution accepted event.""" event = ExecutionAcceptedEvent( execution_id=request.execution_id, queue_position=position, - estimated_wait_seconds=None, # Could calculate based on queue analysis + estimated_wait_seconds=None, priority=priority, metadata=request.metadata, ) - await self.producer.produce(event_to_produce=event) - self.logger.info(f"ExecutionAcceptedEvent published for {request.execution_id}") + await self._producer.produce(event_to_produce=event) + self._logger.info(f"ExecutionAcceptedEvent published for {request.execution_id}") async def _publish_queue_full(self, request: ExecutionRequestedEvent, error: str) -> None: - """Publish queue full event""" - # Get queue stats for context - queue_stats = await self.queue_manager.get_queue_stats() + """Publish queue full event.""" + queue_stats = await self._queue_repo.get_stats() event = ExecutionFailedEvent( 
execution_id=request.execution_id, error_type=ExecutionErrorType.RESOURCE_LIMIT, exit_code=-1, - stderr=f"Queue full: {error}. Queue size: {queue_stats.get('total_size', 'unknown')}", + stderr=f"Queue full: {error}. Queue size: {queue_stats.total_size}", resource_usage=None, metadata=request.metadata, error_message=error, ) - await self.producer.produce(event_to_produce=event, key=request.execution_id) + await self._producer.produce(event_to_produce=event, key=request.execution_id) async def _publish_scheduling_failed(self, request: ExecutionRequestedEvent, error: str) -> None: - """Publish scheduling failed event""" - # Get resource stats for context - resource_stats = await self.resource_manager.get_resource_stats() + """Publish scheduling failed event.""" + resource_stats = await self._resource_repo.get_stats() event = ExecutionFailedEvent( execution_id=request.execution_id, error_type=ExecutionErrorType.SYSTEM_ERROR, exit_code=-1, stderr=f"Failed to schedule execution: {error}. " - f"Available resources: CPU={resource_stats.available.cpu_cores}, " - f"Memory={resource_stats.available.memory_mb}MB", + f"Available resources: CPU={resource_stats.available_cpu}, " + f"Memory={resource_stats.available_memory_mb}MB", resource_usage=None, metadata=request.metadata, error_message=error, ) - await self.producer.produce(event_to_produce=event, key=request.execution_id) - - async def get_status(self) -> dict[str, Any]: - """Get coordinator status""" - return { - "running": self.is_running, - "active_executions": len(self._active_executions), - "queue_stats": await self.queue_manager.get_queue_stats(), - "resource_stats": await self.resource_manager.get_resource_stats(), - } + await self._producer.produce(event_to_produce=event, key=request.execution_id) diff --git a/backend/app/services/coordinator/queue_manager.py b/backend/app/services/coordinator/queue_manager.py deleted file mode 100644 index 8dab2643..00000000 --- a/backend/app/services/coordinator/queue_manager.py +++ /dev/null @@ -1,271 +0,0 @@ -import asyncio -import heapq -import logging -import time -from collections import defaultdict -from dataclasses import dataclass, field -from enum import IntEnum -from typing import Any - -from app.core.metrics import CoordinatorMetrics -from app.domain.events.typed import ExecutionRequestedEvent - - -class QueuePriority(IntEnum): - CRITICAL = 0 - HIGH = 1 - NORMAL = 5 - LOW = 8 - BACKGROUND = 10 - - -@dataclass(order=True) -class QueuedExecution: - priority: int - timestamp: float = field(compare=False) - event: ExecutionRequestedEvent = field(compare=False) - retry_count: int = field(default=0, compare=False) - - @property - def execution_id(self) -> str: - return self.event.execution_id - - @property - def user_id(self) -> str: - return self.event.metadata.user_id or "anonymous" - - @property - def age_seconds(self) -> float: - return time.time() - self.timestamp - - -class QueueManager: - def __init__( - self, - logger: logging.Logger, - coordinator_metrics: CoordinatorMetrics, - max_queue_size: int = 10000, - max_executions_per_user: int = 100, - stale_timeout_seconds: int = 3600, - ) -> None: - self.logger = logger - self.metrics = coordinator_metrics - self.max_queue_size = max_queue_size - self.max_executions_per_user = max_executions_per_user - self.stale_timeout_seconds = stale_timeout_seconds - - self._queue: list[QueuedExecution] = [] - self._queue_lock = asyncio.Lock() - self._user_execution_count: dict[str, int] = defaultdict(int) - self._execution_users: dict[str, str] = {} - 
self._cleanup_task: asyncio.Task[None] | None = None - self._running = False - - async def start(self) -> None: - if self._running: - return - - self._running = True - self._cleanup_task = asyncio.create_task(self._cleanup_stale_executions()) - self.logger.info("Queue manager started") - - async def stop(self) -> None: - if not self._running: - return - - self._running = False - - if self._cleanup_task: - self._cleanup_task.cancel() - try: - await self._cleanup_task - except asyncio.CancelledError: - pass - - self.logger.info(f"Queue manager stopped. Final queue size: {len(self._queue)}") - - async def add_execution( - self, event: ExecutionRequestedEvent, priority: QueuePriority | None = None - ) -> tuple[bool, int | None, str | None]: - async with self._queue_lock: - if len(self._queue) >= self.max_queue_size: - return False, None, "Queue is full" - - user_id = event.metadata.user_id or "anonymous" - - if self._user_execution_count[user_id] >= self.max_executions_per_user: - return False, None, f"User execution limit exceeded ({self.max_executions_per_user})" - - if priority is None: - priority = QueuePriority(event.priority) - - queued = QueuedExecution(priority=priority.value, timestamp=time.time(), event=event) - - heapq.heappush(self._queue, queued) - self._track_execution(event.execution_id, user_id) - position = self._get_queue_position(event.execution_id) - - # Update single authoritative metric for execution request queue depth - self.metrics.update_execution_request_queue_size(len(self._queue)) - - self.logger.info( - f"Added execution {event.execution_id} to queue. " - f"Priority: {priority.name}, Position: {position}, " - f"Queue size: {len(self._queue)}" - ) - - return True, position, None - - async def get_next_execution(self) -> ExecutionRequestedEvent | None: - async with self._queue_lock: - while self._queue: - queued = heapq.heappop(self._queue) - - if self._is_stale(queued): - self._untrack_execution(queued.execution_id) - self._record_removal("stale") - continue - - self._untrack_execution(queued.execution_id) - self._record_wait_time(queued) - # Update metric after removal from the queue - self.metrics.update_execution_request_queue_size(len(self._queue)) - - self.logger.info( - f"Retrieved execution {queued.execution_id} from queue. 
" - f"Wait time: {queued.age_seconds:.2f}s, Queue size: {len(self._queue)}" - ) - - return queued.event - - return None - - async def remove_execution(self, execution_id: str) -> bool: - async with self._queue_lock: - initial_size = len(self._queue) - self._queue = [q for q in self._queue if q.execution_id != execution_id] - - if len(self._queue) < initial_size: - heapq.heapify(self._queue) - self._untrack_execution(execution_id) - # Update metric after explicit removal - self.metrics.update_execution_request_queue_size(len(self._queue)) - self.logger.info(f"Removed execution {execution_id} from queue") - return True - - return False - - async def get_queue_position(self, execution_id: str) -> int | None: - async with self._queue_lock: - return self._get_queue_position(execution_id) - - async def get_queue_stats(self) -> dict[str, Any]: - async with self._queue_lock: - priority_counts: dict[str, int] = defaultdict(int) - user_counts: dict[str, int] = defaultdict(int) - - for queued in self._queue: - priority_name = QueuePriority(queued.priority).name - priority_counts[priority_name] += 1 - user_counts[queued.user_id] += 1 - - top_users = dict(sorted(user_counts.items(), key=lambda x: x[1], reverse=True)[:10]) - - return { - "total_size": len(self._queue), - "priority_distribution": dict(priority_counts), - "top_users": top_users, - "max_queue_size": self.max_queue_size, - "utilization_percent": (len(self._queue) / self.max_queue_size) * 100, - } - - async def requeue_execution( - self, event: ExecutionRequestedEvent, increment_retry: bool = True - ) -> tuple[bool, int | None, str | None]: - def _next_lower(p: QueuePriority) -> QueuePriority: - order = [ - QueuePriority.CRITICAL, - QueuePriority.HIGH, - QueuePriority.NORMAL, - QueuePriority.LOW, - QueuePriority.BACKGROUND, - ] - try: - idx = order.index(p) - except ValueError: - # Fallback: treat unknown numeric as NORMAL - idx = order.index(QueuePriority.NORMAL) - return order[min(idx + 1, len(order) - 1)] - - if increment_retry: - original_priority = QueuePriority(event.priority) - new_priority = _next_lower(original_priority) - else: - new_priority = QueuePriority(event.priority) - - return await self.add_execution(event, priority=new_priority) - - def _get_queue_position(self, execution_id: str) -> int | None: - for position, queued in enumerate(self._queue): - if queued.execution_id == execution_id: - return position - return None - - def _is_stale(self, queued: QueuedExecution) -> bool: - return queued.age_seconds > self.stale_timeout_seconds - - def _track_execution(self, execution_id: str, user_id: str) -> None: - self._user_execution_count[user_id] += 1 - self._execution_users[execution_id] = user_id - - def _untrack_execution(self, execution_id: str) -> None: - if execution_id in self._execution_users: - user_id = self._execution_users.pop(execution_id) - self._user_execution_count[user_id] -= 1 - if self._user_execution_count[user_id] <= 0: - del self._user_execution_count[user_id] - - def _record_removal(self, reason: str) -> None: - # No-op: we keep a single queue depth metric and avoid operation counters - return - - def _record_wait_time(self, queued: QueuedExecution) -> None: - self.metrics.record_queue_wait_time_by_priority( - queued.age_seconds, QueuePriority(queued.priority).name, "default" - ) - - def _update_add_metrics(self, priority: QueuePriority) -> None: - # Deprecated in favor of single execution queue depth metric - self.metrics.update_execution_request_queue_size(len(self._queue)) - - def 
_update_queue_size(self) -> None: - self.metrics.update_execution_request_queue_size(len(self._queue)) - - async def _cleanup_stale_executions(self) -> None: - while self._running: - try: - await asyncio.sleep(300) - - async with self._queue_lock: - stale_executions = [] - active_executions = [] - - for queued in self._queue: - if self._is_stale(queued): - stale_executions.append(queued) - else: - active_executions.append(queued) - - if stale_executions: - self._queue = active_executions - heapq.heapify(self._queue) - - for queued in stale_executions: - self._untrack_execution(queued.execution_id) - - # Update metric after stale cleanup - self.metrics.update_execution_request_queue_size(len(self._queue)) - self.logger.info(f"Cleaned {len(stale_executions)} stale executions from queue") - - except Exception as e: - self.logger.error(f"Error in queue cleanup: {e}") diff --git a/backend/app/services/coordinator/resource_manager.py b/backend/app/services/coordinator/resource_manager.py deleted file mode 100644 index bd0c2fbf..00000000 --- a/backend/app/services/coordinator/resource_manager.py +++ /dev/null @@ -1,324 +0,0 @@ -import asyncio -import logging -from dataclasses import dataclass - -from app.core.metrics import CoordinatorMetrics - - -@dataclass -class ResourceAllocation: - """Resource allocation for an execution""" - - cpu_cores: float - memory_mb: int - gpu_count: int = 0 - - @property - def cpu_millicores(self) -> int: - """Get CPU in millicores for Kubernetes""" - return int(self.cpu_cores * 1000) - - @property - def memory_bytes(self) -> int: - """Get memory in bytes""" - return self.memory_mb * 1024 * 1024 - - -@dataclass -class ResourcePool: - """Available resource pool""" - - total_cpu_cores: float - total_memory_mb: int - total_gpu_count: int - - available_cpu_cores: float - available_memory_mb: int - available_gpu_count: int - - # Resource limits per execution - max_cpu_per_execution: float = 4.0 - max_memory_per_execution_mb: int = 8192 - max_gpu_per_execution: int = 1 - - # Minimum resources to keep available - min_available_cpu_cores: float = 2.0 - min_available_memory_mb: int = 4096 - - -@dataclass -class ResourceGroup: - """Resource group with usage information""" - - cpu_cores: float - memory_mb: int - gpu_count: int - - -@dataclass -class ResourceStats: - """Resource statistics""" - - total: ResourceGroup - available: ResourceGroup - allocated: ResourceGroup - utilization: dict[str, float] - allocation_count: int - limits: dict[str, int | float] - - -@dataclass -class ResourceAllocationInfo: - """Information about a resource allocation""" - - execution_id: str - cpu_cores: float - memory_mb: int - gpu_count: int - cpu_percentage: float - memory_percentage: float - - -class ResourceManager: - """Manages resource allocation for executions""" - - def __init__( - self, - logger: logging.Logger, - coordinator_metrics: CoordinatorMetrics, - total_cpu_cores: float = 32.0, - total_memory_mb: int = 65536, # 64GB - total_gpu_count: int = 0, - overcommit_factor: float = 1.2, # Allow 20% overcommit - ): - self.logger = logger - self.metrics = coordinator_metrics - self.pool = ResourcePool( - total_cpu_cores=total_cpu_cores * overcommit_factor, - total_memory_mb=int(total_memory_mb * overcommit_factor), - total_gpu_count=total_gpu_count, - available_cpu_cores=total_cpu_cores * overcommit_factor, - available_memory_mb=int(total_memory_mb * overcommit_factor), - available_gpu_count=total_gpu_count, - ) - - # Adjust minimum reserve thresholds proportionally for small pools. 
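# --- Editor's aside (illustrative sketch, not part of this patch) ------------
# The reserve-capping arithmetic removed here keeps at most 10% of the
# overcommitted pool as the minimum reserve, never more than the fixed default.
# A tiny self-contained sketch of that calculation, using the defaults visible
# in this file (2.0-core / 4096 MB reserves, 1.2x overcommit); whether the new
# Redis-backed resource repository keeps the same thresholds is not shown here,
# and `capped_reserve` is a hypothetical helper name.
def capped_reserve(default_reserve: float, total_after_overcommit: float) -> float:
    # Keep at most 10% of the (overcommitted) pool in reserve, capped by the default.
    return min(default_reserve, max(0.1 * total_after_overcommit, 0.0))

# A 4-core node with 1.2x overcommit exposes 4.8 schedulable cores, so the CPU
# reserve shrinks from the 2.0-core default to 0.48 cores.
assert abs(capped_reserve(2.0, 4.0 * 1.2) - 0.48) < 1e-9
# A 64 GB node keeps the full 4096 MB default, since 10% of 65536 * 1.2 MB is larger.
assert capped_reserve(4096, 65536 * 1.2) == 4096
# -----------------------------------------------------------------------------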
- # Keep at most 10% of total as reserve (but not higher than defaults). - # This avoids refusing small, reasonable allocations on modest clusters. - self.pool.min_available_cpu_cores = min( - self.pool.min_available_cpu_cores, - max(0.1 * self.pool.total_cpu_cores, 0.0), - ) - self.pool.min_available_memory_mb = min( - self.pool.min_available_memory_mb, - max(int(0.1 * self.pool.total_memory_mb), 0), - ) - - # Track allocations - self._allocations: dict[str, ResourceAllocation] = {} - self._allocation_lock = asyncio.Lock() - - # Default allocations by language - self.default_allocations = { - "python": ResourceAllocation(cpu_cores=0.5, memory_mb=512), - "javascript": ResourceAllocation(cpu_cores=0.5, memory_mb=512), - "go": ResourceAllocation(cpu_cores=0.25, memory_mb=256), - "rust": ResourceAllocation(cpu_cores=0.5, memory_mb=512), - "java": ResourceAllocation(cpu_cores=1.0, memory_mb=1024), - "cpp": ResourceAllocation(cpu_cores=0.5, memory_mb=512), - "r": ResourceAllocation(cpu_cores=1.0, memory_mb=2048), - } - - # Update initial metrics - self._update_metrics() - - async def request_allocation( - self, - execution_id: str, - language: str, - requested_cpu: float | None = None, - requested_memory_mb: int | None = None, - requested_gpu: int = 0, - ) -> ResourceAllocation | None: - """ - Request resource allocation for execution - - Returns: - ResourceAllocation if successful, None if resources unavailable - """ - async with self._allocation_lock: - # Check if already allocated - if execution_id in self._allocations: - self.logger.warning(f"Execution {execution_id} already has allocation") - return self._allocations[execution_id] - - # Determine requested resources - if requested_cpu is None or requested_memory_mb is None: - # Use defaults based on language - default = self.default_allocations.get(language, ResourceAllocation(cpu_cores=0.5, memory_mb=512)) - requested_cpu = requested_cpu or default.cpu_cores - requested_memory_mb = requested_memory_mb or default.memory_mb - - # Apply limits - requested_cpu = min(requested_cpu, self.pool.max_cpu_per_execution) - requested_memory_mb = min(requested_memory_mb, self.pool.max_memory_per_execution_mb) - requested_gpu = min(requested_gpu, self.pool.max_gpu_per_execution) - - # Check availability (considering minimum reserves) - cpu_after = self.pool.available_cpu_cores - requested_cpu - memory_after = self.pool.available_memory_mb - requested_memory_mb - gpu_after = self.pool.available_gpu_count - requested_gpu - - if ( - cpu_after < self.pool.min_available_cpu_cores - or memory_after < self.pool.min_available_memory_mb - or gpu_after < 0 - ): - self.logger.warning( - f"Insufficient resources for execution {execution_id}. " - f"Requested: {requested_cpu} CPU, {requested_memory_mb}MB RAM, " - f"{requested_gpu} GPU. 
Available: {self.pool.available_cpu_cores} CPU, " - f"{self.pool.available_memory_mb}MB RAM, {self.pool.available_gpu_count} GPU" - ) - return None - - # Create allocation - allocation = ResourceAllocation( - cpu_cores=requested_cpu, memory_mb=requested_memory_mb, gpu_count=requested_gpu - ) - - # Update pool - self.pool.available_cpu_cores = cpu_after - self.pool.available_memory_mb = memory_after - self.pool.available_gpu_count = gpu_after - - # Track allocation - self._allocations[execution_id] = allocation - - # Update metrics - self._update_metrics() - - self.logger.info( - f"Allocated resources for execution {execution_id}: " - f"{allocation.cpu_cores} CPU, {allocation.memory_mb}MB RAM, " - f"{allocation.gpu_count} GPU" - ) - - return allocation - - async def release_allocation(self, execution_id: str) -> bool: - """Release resource allocation""" - async with self._allocation_lock: - if execution_id not in self._allocations: - self.logger.warning(f"No allocation found for execution {execution_id}") - return False - - allocation = self._allocations[execution_id] - - # Return resources to pool - self.pool.available_cpu_cores += allocation.cpu_cores - self.pool.available_memory_mb += allocation.memory_mb - self.pool.available_gpu_count += allocation.gpu_count - - # Remove allocation - del self._allocations[execution_id] - - # Update metrics - self._update_metrics() - - self.logger.info( - f"Released resources for execution {execution_id}: " - f"{allocation.cpu_cores} CPU, {allocation.memory_mb}MB RAM, " - f"{allocation.gpu_count} GPU" - ) - - return True - - async def get_allocation(self, execution_id: str) -> ResourceAllocation | None: - """Get current allocation for execution""" - async with self._allocation_lock: - return self._allocations.get(execution_id) - - async def can_allocate(self, cpu_cores: float, memory_mb: int, gpu_count: int = 0) -> bool: - """Check if resources can be allocated""" - async with self._allocation_lock: - cpu_after = self.pool.available_cpu_cores - cpu_cores - memory_after = self.pool.available_memory_mb - memory_mb - gpu_after = self.pool.available_gpu_count - gpu_count - - return ( - cpu_after >= self.pool.min_available_cpu_cores - and memory_after >= self.pool.min_available_memory_mb - and gpu_after >= 0 - ) - - async def get_resource_stats(self) -> ResourceStats: - """Get resource statistics""" - async with self._allocation_lock: - allocated_cpu = self.pool.total_cpu_cores - self.pool.available_cpu_cores - allocated_memory = self.pool.total_memory_mb - self.pool.available_memory_mb - allocated_gpu = self.pool.total_gpu_count - self.pool.available_gpu_count - - gpu_percent = (allocated_gpu / self.pool.total_gpu_count * 100) if self.pool.total_gpu_count > 0 else 0 - - return ResourceStats( - total=ResourceGroup( - cpu_cores=self.pool.total_cpu_cores, - memory_mb=self.pool.total_memory_mb, - gpu_count=self.pool.total_gpu_count, - ), - available=ResourceGroup( - cpu_cores=self.pool.available_cpu_cores, - memory_mb=self.pool.available_memory_mb, - gpu_count=self.pool.available_gpu_count, - ), - allocated=ResourceGroup(cpu_cores=allocated_cpu, memory_mb=allocated_memory, gpu_count=allocated_gpu), - utilization={ - "cpu_percent": (allocated_cpu / self.pool.total_cpu_cores * 100), - "memory_percent": (allocated_memory / self.pool.total_memory_mb * 100), - "gpu_percent": gpu_percent, - }, - allocation_count=len(self._allocations), - limits={ - "max_cpu_per_execution": self.pool.max_cpu_per_execution, - "max_memory_per_execution_mb": 
self.pool.max_memory_per_execution_mb, - "max_gpu_per_execution": self.pool.max_gpu_per_execution, - }, - ) - - async def get_allocations_by_resource_usage(self) -> list[ResourceAllocationInfo]: - """Get allocations sorted by resource usage""" - async with self._allocation_lock: - allocations = [] - for exec_id, allocation in self._allocations.items(): - allocations.append( - ResourceAllocationInfo( - execution_id=str(exec_id), - cpu_cores=allocation.cpu_cores, - memory_mb=allocation.memory_mb, - gpu_count=allocation.gpu_count, - cpu_percentage=(allocation.cpu_cores / self.pool.total_cpu_cores * 100), - memory_percentage=(allocation.memory_mb / self.pool.total_memory_mb * 100), - ) - ) - - # Sort by total resource usage - allocations.sort(key=lambda x: x.cpu_percentage + x.memory_percentage, reverse=True) - - return allocations - - def _update_metrics(self) -> None: - """Update metrics""" - cpu_usage = self.pool.total_cpu_cores - self.pool.available_cpu_cores - cpu_percent = cpu_usage / self.pool.total_cpu_cores * 100 - self.metrics.update_resource_usage("cpu", cpu_percent) - - memory_usage = self.pool.total_memory_mb - self.pool.available_memory_mb - memory_percent = memory_usage / self.pool.total_memory_mb * 100 - self.metrics.update_resource_usage("memory", memory_percent) - - gpu_usage = self.pool.total_gpu_count - self.pool.available_gpu_count - gpu_percent = gpu_usage / max(1, self.pool.total_gpu_count) * 100 - self.metrics.update_resource_usage("gpu", gpu_percent) - - self.metrics.update_coordinator_active_executions(len(self._allocations)) diff --git a/backend/app/services/event_bus.py b/backend/app/services/event_bus.py index 613b2ef6..6f9d62a0 100644 --- a/backend/app/services/event_bus.py +++ b/backend/app/services/event_bus.py @@ -1,323 +1,140 @@ +"""Event Bus - stateless pub/sub service. + +Distributed event bus for cross-instance communication via Kafka. +No lifecycle management - receives ready-to-use producer from DI. +""" + +from __future__ import annotations + import asyncio import fnmatch import json import logging from dataclasses import dataclass, field -from typing import Any, Callable +from datetime import datetime, timezone from uuid import uuid4 -from aiokafka import AIOKafkaConsumer, AIOKafkaProducer -from aiokafka.errors import KafkaError -from fastapi import Request +from aiokafka import AIOKafkaProducer +from pydantic import BaseModel, ConfigDict -from app.core.lifecycle import LifecycleEnabled from app.core.metrics import ConnectionMetrics from app.domain.enums.kafka import KafkaTopic -from app.domain.events.typed import BaseEvent, DomainEvent, domain_event_adapter from app.settings import Settings +class EventBusEvent(BaseModel): + """Represents an event on the event bus.""" + + model_config = ConfigDict(from_attributes=True) + + id: str + event_type: str + timestamp: datetime + payload: dict[str, object] + + @dataclass class Subscription: """Represents a single event subscription.""" id: str = field(default_factory=lambda: str(uuid4())) pattern: str = "" - handler: Callable[[DomainEvent], Any] = field(default=lambda _: None) - - -class EventBus(LifecycleEnabled): - """ - Distributed event bus for cross-instance communication via Kafka. - - Publishers send events to Kafka. Subscribers receive events from OTHER instances - only - self-published messages are filtered out. This design means: - - Publishers should update their own state directly before calling publish() - - Handlers only run for events from other instances (cache invalidation, etc.) 
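# --- Editor's aside (illustrative sketch, not part of this patch) ------------
# The contract described above survives the rewrite: the stateless EventBus
# further down still stamps a source_instance header on publish() and drops its
# own messages in handle_kafka_message(). A usage sketch of that pattern against
# the new publish()/subscribe() API; `event_bus`, `cache`, and the event-type
# strings are hypothetical stand-ins, not code from this patch.
async def save_user_settings(event_bus, cache, user_id: str, settings: dict) -> None:
    cache.put(user_id, settings)  # 1. the publisher updates its OWN state first
    # 2. ...then broadcasts; this instance never sees the message back, because
    #    handle_kafka_message() skips payloads whose source_instance matches its own id.
    await event_bus.publish("user_settings.updated", {"user_id": user_id})

async def wire_invalidation(event_bus, cache) -> str:
    # Other instances subscribe with a wildcard and only ever receive remote
    # updates, so the handler can invalidate its local cache unconditionally.
    async def on_remote_update(event) -> None:  # event is an EventBusEvent
        cache.invalidate(event.payload["user_id"])
    return await event_bus.subscribe("user_settings.*", on_remote_update)
# -----------------------------------------------------------------------------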
- - Supports pattern-based subscriptions using wildcards: - - execution.* - matches all execution events - - execution.123.* - matches all events for execution 123 - - *.completed - matches all completed events - """ - - def __init__(self, settings: Settings, logger: logging.Logger, connection_metrics: ConnectionMetrics) -> None: - super().__init__() - self.logger = logger - self.settings = settings - self.metrics = connection_metrics - self.producer: AIOKafkaProducer | None = None - self.consumer: AIOKafkaConsumer | None = None - self._subscriptions: dict[str, Subscription] = {} # id -> Subscription - self._pattern_index: dict[str, set[str]] = {} # pattern -> set of subscription ids - self._consumer_task: asyncio.Task[None] | None = None + handler: object = field(default=None) + + +class EventBus: + """Stateless event bus - pure pub/sub service.""" + + def __init__( + self, + producer: AIOKafkaProducer, + settings: Settings, + logger: logging.Logger, + connection_metrics: ConnectionMetrics, + ) -> None: + self._producer = producer + self._settings = settings + self._logger = logger + self._metrics = connection_metrics + self._subscriptions: dict[str, Subscription] = {} + self._pattern_index: dict[str, set[str]] = {} self._lock = asyncio.Lock() - self._topic = f"{self.settings.KAFKA_TOPIC_PREFIX}{KafkaTopic.EVENT_BUS_STREAM}" - self._instance_id = str(uuid4()) # Unique ID for filtering self-published messages - - async def _on_start(self) -> None: - """Start the event bus with Kafka backing.""" - await self._initialize_kafka() - self._consumer_task = asyncio.create_task(self._kafka_listener()) - self.logger.info("Event bus started with Kafka backing") - - async def _initialize_kafka(self) -> None: - """Initialize Kafka producer and consumer.""" - # Producer setup - self.producer = AIOKafkaProducer( - bootstrap_servers=self.settings.KAFKA_BOOTSTRAP_SERVERS, - client_id=f"event-bus-producer-{uuid4()}", - linger_ms=10, - max_batch_size=16384, - enable_idempotence=True, + self._topic = f"{settings.KAFKA_TOPIC_PREFIX}{KafkaTopic.EVENT_BUS_STREAM}" + self._instance_id = str(uuid4()) + + async def publish(self, event_type: str, data: dict[str, object]) -> None: + """Publish an event to Kafka for cross-instance distribution.""" + event = EventBusEvent( + id=str(uuid4()), + event_type=event_type, + timestamp=datetime.now(timezone.utc), + payload=data, ) - # Consumer setup - self.consumer = AIOKafkaConsumer( - self._topic, - bootstrap_servers=self.settings.KAFKA_BOOTSTRAP_SERVERS, - group_id=f"event-bus-{uuid4()}", - auto_offset_reset="latest", - enable_auto_commit=True, - client_id=f"event-bus-consumer-{uuid4()}", - session_timeout_ms=self.settings.KAFKA_SESSION_TIMEOUT_MS, - heartbeat_interval_ms=self.settings.KAFKA_HEARTBEAT_INTERVAL_MS, - max_poll_interval_ms=self.settings.KAFKA_MAX_POLL_INTERVAL_MS, - request_timeout_ms=self.settings.KAFKA_REQUEST_TIMEOUT_MS, - ) - - # Start both in parallel for faster startup - await asyncio.gather(self.producer.start(), self.consumer.start()) - - async def _on_stop(self) -> None: - """Stop the event bus and clean up resources.""" - # Cancel consumer task - if self._consumer_task and not self._consumer_task.done(): - self._consumer_task.cancel() - try: - await self._consumer_task - except asyncio.CancelledError: - pass - - # Stop Kafka components - if self.consumer: - await self.consumer.stop() - self.consumer = None - - if self.producer: - await self.producer.stop() - self.producer = None - - # Clear subscriptions - async with self._lock: - 
self._subscriptions.clear() - self._pattern_index.clear() - - self.logger.info("Event bus stopped") - - async def publish(self, event: BaseEvent) -> None: - """ - Publish a typed event to Kafka for cross-instance distribution. - - Local handlers receive events only from OTHER instances via the Kafka listener. - Publishers should update their own state directly before calling publish(). - - Args: - event: Typed domain event to publish - """ - if self.producer: - try: - value = event.model_dump_json().encode("utf-8") - key = event.event_type.encode("utf-8") - headers = [("source_instance", self._instance_id.encode("utf-8"))] - - await self.producer.send_and_wait( - topic=self._topic, - value=value, - key=key, - headers=headers, - ) - except Exception as e: - self.logger.error(f"Failed to publish to Kafka: {e}") - - async def subscribe(self, pattern: str, handler: Callable[[DomainEvent], Any]) -> str: - """ - Subscribe to events matching a pattern. - - Args: - pattern: Event pattern with wildcards (e.g., "execution.*") - handler: Async function to handle matching events + try: + await self._producer.send_and_wait( + topic=self._topic, + value=event.model_dump_json().encode(), + key=event_type.encode(), + headers=[("source_instance", self._instance_id.encode())], + ) + except Exception as e: + self._logger.error(f"Failed to publish to Kafka: {e}") - Returns: - Subscription ID for later unsubscribe - """ + async def subscribe(self, pattern: str, handler: object) -> str: + """Subscribe to events matching a pattern. Returns subscription ID.""" subscription = Subscription(pattern=pattern, handler=handler) async with self._lock: - # Store subscription self._subscriptions[subscription.id] = subscription - - # Update pattern index if pattern not in self._pattern_index: self._pattern_index[pattern] = set() self._pattern_index[pattern].add(subscription.id) + self._metrics.update_event_bus_subscribers(len(self._pattern_index[pattern]), pattern) - # Update metrics - self._update_metrics(pattern) - - self.logger.debug(f"Created subscription {subscription.id} for pattern: {pattern}") return subscription.id - async def unsubscribe(self, pattern: str, handler: Callable[[DomainEvent], Any]) -> None: - """Unsubscribe a specific handler from a pattern.""" + async def unsubscribe(self, pattern: str, handler: object) -> None: + """Unsubscribe a handler from a pattern.""" async with self._lock: - # Find subscription with matching pattern and handler - for sub_id, subscription in list(self._subscriptions.items()): - if subscription.pattern == pattern and subscription.handler == handler: - await self._remove_subscription(sub_id) + for sub_id, sub in list(self._subscriptions.items()): + if sub.pattern == pattern and sub.handler == handler: + del self._subscriptions[sub_id] + self._pattern_index[pattern].discard(sub_id) + if not self._pattern_index[pattern]: + del self._pattern_index[pattern] + self._metrics.update_event_bus_subscribers(0, pattern) + else: + self._metrics.update_event_bus_subscribers(len(self._pattern_index[pattern]), pattern) return - self.logger.warning(f"No subscription found for pattern {pattern} with given handler") - - async def _remove_subscription(self, subscription_id: str) -> None: - """Remove a subscription by ID (must be called within lock).""" - if subscription_id not in self._subscriptions: - self.logger.warning(f"Subscription {subscription_id} not found") - return - - subscription = self._subscriptions[subscription_id] - pattern = subscription.pattern - - # Remove from subscriptions - 
del self._subscriptions[subscription_id] - - # Update pattern index - if pattern in self._pattern_index: - self._pattern_index[pattern].discard(subscription_id) - if not self._pattern_index[pattern]: - del self._pattern_index[pattern] - - # Update metrics - self._update_metrics(pattern) - - self.logger.debug(f"Removed subscription {subscription_id} for pattern: {pattern}") - - async def _distribute_event(self, event: DomainEvent) -> None: - """Distribute event to all matching local subscribers.""" - # Find matching subscriptions - matching_handlers = await self._find_matching_handlers(event.event_type) - - if not matching_handlers: + async def handle_kafka_message(self, raw_message: bytes, headers: dict[str, str]) -> None: + """Handle a Kafka message. Skips messages from this instance.""" + if headers.get("source_instance") == self._instance_id: return - # Execute all handlers concurrently - results = await asyncio.gather( - *(self._invoke_handler(handler, event) for handler in matching_handlers), return_exceptions=True - ) - - # Log any errors - for _i, result in enumerate(results): - if isinstance(result, Exception): - self.logger.error(f"Handler failed for event {event.event_type}: {result}") - - async def _find_matching_handlers(self, event_type: str) -> list[Callable[[DomainEvent], Any]]: - """Find all handlers matching the event type.""" - async with self._lock: - handlers: list[Callable[[DomainEvent], Any]] = [] - for pattern, sub_ids in self._pattern_index.items(): - if fnmatch.fnmatch(event_type, pattern): - handlers.extend( - self._subscriptions[sub_id].handler for sub_id in sub_ids if sub_id in self._subscriptions - ) - return handlers - - async def _invoke_handler(self, handler: Callable[[DomainEvent], Any], event: DomainEvent) -> None: - """Invoke a single handler, handling both sync and async.""" - if asyncio.iscoroutinefunction(handler): - await handler(event) - else: - await asyncio.to_thread(handler, event) - - async def _kafka_listener(self) -> None: - """Listen for Kafka messages from OTHER instances and distribute to local subscribers.""" - if not self.consumer: - return - - self.logger.info("Kafka listener started") - try: - while self.is_running: - try: - msg = await asyncio.wait_for(self.consumer.getone(), timeout=0.1) - - # Skip messages from this instance - publisher handles its own state - headers = dict(msg.headers) if msg.headers else {} - source = headers.get("source_instance", b"").decode("utf-8") - if source == self._instance_id: - continue - - try: - event_dict = json.loads(msg.value.decode("utf-8")) - event = domain_event_adapter.validate_python(event_dict) - await self._distribute_event(event) - except Exception as e: - self.logger.error(f"Error processing Kafka message: {e}") - - except asyncio.TimeoutError: - continue - except KafkaError as e: - self.logger.error(f"Consumer error: {e}") - continue - - except asyncio.CancelledError: - self.logger.info("Kafka listener cancelled") + event = EventBusEvent.model_validate(json.loads(raw_message)) + await self._distribute_event(event) except Exception as e: - self.logger.error(f"Fatal error in Kafka listener: {e}") - - def _update_metrics(self, pattern: str) -> None: - """Update metrics for a pattern (must be called within lock).""" - if self.metrics: - count = len(self._pattern_index.get(pattern, set())) - self.metrics.update_event_bus_subscribers(count, pattern) + self._logger.error(f"Error processing Kafka message: {e}") - async def get_statistics(self) -> dict[str, Any]: - """Get event bus statistics.""" + 
async def _distribute_event(self, event: EventBusEvent) -> None: + """Distribute event to matching local subscribers.""" async with self._lock: - return { - "patterns": list(self._pattern_index.keys()), - "total_patterns": len(self._pattern_index), - "total_subscriptions": len(self._subscriptions), - "kafka_enabled": self.producer is not None, - "running": self.is_running, - } - - -class EventBusManager: - """Manages EventBus lifecycle as a singleton.""" - - def __init__(self, settings: Settings, logger: logging.Logger, connection_metrics: ConnectionMetrics) -> None: - self.settings = settings - self.logger = logger - self._connection_metrics = connection_metrics - self._event_bus: EventBus | None = None - self._lock = asyncio.Lock() - - async def get_event_bus(self) -> EventBus: - """Get or create the event bus instance.""" - async with self._lock: - if self._event_bus is None: - self._event_bus = EventBus(self.settings, self.logger, self._connection_metrics) - await self._event_bus.__aenter__() - return self._event_bus - - async def close(self) -> None: - """Stop and clean up the event bus.""" - async with self._lock: - if self._event_bus: - await self._event_bus.aclose() - self._event_bus = None - - -async def get_event_bus(request: Request) -> EventBus: - manager: EventBusManager = request.app.state.event_bus_manager - return await manager.get_event_bus() + handlers = [ + self._subscriptions[sub_id].handler + for pattern, sub_ids in self._pattern_index.items() + if fnmatch.fnmatch(event.event_type, pattern) + for sub_id in sub_ids + if sub_id in self._subscriptions + ] + + for handler in handlers: + try: + if asyncio.iscoroutinefunction(handler): + await handler(event) + else: + handler(event) # type: ignore[operator] + except Exception as e: + self._logger.error(f"Handler failed for {event.event_type}: {e}") diff --git a/backend/app/services/idempotency/middleware.py b/backend/app/services/idempotency/middleware.py index 04a4f931..4dac5287 100644 --- a/backend/app/services/idempotency/middleware.py +++ b/backend/app/services/idempotency/middleware.py @@ -1,12 +1,11 @@ +"""Idempotent event processing middleware""" + import asyncio import logging -from collections.abc import Awaitable -from typing import Any, Callable +from typing import Any, Awaitable, Callable, Dict, Set from app.domain.enums.events import EventType -from app.domain.enums.kafka import KafkaTopic from app.domain.events.typed import DomainEvent -from app.domain.idempotency import KeyStrategy from app.events.core import EventDispatcher, UnifiedConsumer from app.services.idempotency.idempotency_manager import IdempotencyManager @@ -19,9 +18,9 @@ def __init__( handler: Callable[[DomainEvent], Awaitable[None]], idempotency_manager: IdempotencyManager, logger: logging.Logger, - key_strategy: KeyStrategy = KeyStrategy.EVENT_BASED, + key_strategy: str = "event_based", custom_key_func: Callable[[DomainEvent], str] | None = None, - fields: set[str] | None = None, + fields: Set[str] | None = None, ttl_seconds: int | None = None, cache_result: bool = True, on_duplicate: Callable[[DomainEvent, Any], Any] | None = None, @@ -44,7 +43,7 @@ async def __call__(self, event: DomainEvent) -> None: ) # Generate custom key if function provided custom_key = None - if self.key_strategy == KeyStrategy.CUSTOM and self.custom_key_func: + if self.key_strategy == "custom" and self.custom_key_func: custom_key = self.custom_key_func(event) # Check idempotency @@ -93,9 +92,9 @@ async def __call__(self, event: DomainEvent) -> None: def 
idempotent_handler( idempotency_manager: IdempotencyManager, logger: logging.Logger, - key_strategy: KeyStrategy = KeyStrategy.EVENT_BASED, + key_strategy: str = "event_based", custom_key_func: Callable[[DomainEvent], str] | None = None, - fields: set[str] | None = None, + fields: Set[str] | None = None, ttl_seconds: int | None = None, cache_result: bool = True, on_duplicate: Callable[[DomainEvent, Any], Any] | None = None, @@ -128,7 +127,7 @@ def __init__( idempotency_manager: IdempotencyManager, dispatcher: EventDispatcher, logger: logging.Logger, - default_key_strategy: KeyStrategy = KeyStrategy.EVENT_BASED, + default_key_strategy: str = "event_based", default_ttl_seconds: int = 3600, enable_for_all_handlers: bool = True, ): @@ -138,22 +137,19 @@ def __init__( self.logger = logger self.default_key_strategy = default_key_strategy self.default_ttl_seconds = default_ttl_seconds - self.enable_for_all_handlers = enable_for_all_handlers - self._original_handlers: dict[EventType, list[Callable[[DomainEvent], Awaitable[None]]]] = {} + self._original_handlers: Dict[EventType, list[Callable[[DomainEvent], Awaitable[None]]]] = {} - def make_handlers_idempotent(self) -> None: - """Wrap all registered handlers with idempotency""" - self.logger.info( - f"make_handlers_idempotent called: enable_for_all={self.enable_for_all_handlers}, " - f"dispatcher={self.dispatcher is not None}" - ) - if not self.enable_for_all_handlers or not self.dispatcher: - self.logger.warning("Skipping handler wrapping - conditions not met") + if enable_for_all_handlers: + self._wrap_handlers() + + def _wrap_handlers(self) -> None: + """Wrap all registered handlers with idempotency.""" + if not self.dispatcher: + self.logger.warning("No dispatcher available for handler wrapping") return - # Store original handlers using public API self._original_handlers = self.dispatcher.get_all_handlers() - self.logger.info(f"Got {len(self._original_handlers)} event types with handlers to wrap") + self.logger.debug(f"Wrapping {len(self._original_handlers)} event types with idempotency") # Wrap each handler for event_type, handlers in self._original_handlers.items(): @@ -169,21 +165,15 @@ def make_handlers_idempotent(self) -> None: ) wrapped_handlers.append(wrapped) - # Replace handlers using public API - self.logger.info( - f"Replacing {len(handlers)} handlers for {event_type} with {len(wrapped_handlers)} wrapped handlers" - ) self.dispatcher.replace_handlers(event_type, wrapped_handlers) - self.logger.info("Handler wrapping complete") - def subscribe_idempotent_handler( self, event_type: str, handler: Callable[[DomainEvent], Awaitable[None]], - key_strategy: KeyStrategy | None = None, + key_strategy: str | None = None, custom_key_func: Callable[[DomainEvent], str] | None = None, - fields: set[str] | None = None, + fields: Set[str] | None = None, ttl_seconds: int | None = None, cache_result: bool = True, on_duplicate: Callable[[DomainEvent, Any], Any] | None = None, @@ -259,21 +249,3 @@ async def dispatch_handler(event: DomainEvent) -> None: else: # Fallback to direct consumer registration if no dispatcher self.logger.error(f"No EventDispatcher available for registering idempotent handler for {event_type}") - - async def start(self, topics: list[KafkaTopic]) -> None: - """Start the consumer with idempotency""" - self.logger.info(f"IdempotentConsumerWrapper.start called with topics: {topics}") - # Make handlers idempotent before starting - self.make_handlers_idempotent() - - # Start the consumer with required topics parameter - await 
self.consumer.start(topics) - self.logger.info("IdempotentConsumerWrapper started successfully") - - async def stop(self) -> None: - """Stop the consumer""" - await self.consumer.stop() - - # Delegate other methods to the wrapped consumer - def __getattr__(self, name: str) -> Any: - return getattr(self.consumer, name) diff --git a/backend/app/services/k8s_worker/worker.py b/backend/app/services/k8s_worker/worker.py index eafceeca..c49ec98f 100644 --- a/backend/app/services/k8s_worker/worker.py +++ b/backend/app/services/k8s_worker/worker.py @@ -1,336 +1,169 @@ +"""Kubernetes Worker - stateless event handler. + +Creates Kubernetes pods from execution events. Receives events, +processes them, and publishes results. No lifecycle management. +All state is stored in Redis repositories. +""" + +from __future__ import annotations + import asyncio import logging -import os import time from pathlib import Path -from typing import Any from kubernetes import client as k8s_client -from kubernetes import config as k8s_config from kubernetes.client.rest import ApiException -from app.core.lifecycle import LifecycleEnabled from app.core.metrics import EventMetrics, ExecutionMetrics, KubernetesMetrics -from app.domain.enums.events import EventType -from app.domain.enums.kafka import CONSUMER_GROUP_SUBSCRIPTIONS, GroupId +from app.db.repositories.pod_state_repository import PodStateRepository from app.domain.enums.storage import ExecutionErrorType from app.domain.events.typed import ( CreatePodCommandEvent, DeletePodCommandEvent, - DomainEvent, ExecutionFailedEvent, ExecutionStartedEvent, PodCreatedEvent, ) -from app.domain.idempotency import KeyStrategy -from app.events.core import ConsumerConfig, EventDispatcher, UnifiedConsumer, UnifiedProducer -from app.events.event_store import EventStore -from app.events.schema.schema_registry import ( - SchemaRegistryManager, -) +from app.events.core import UnifiedProducer from app.runtime_registry import RUNTIME_REGISTRY -from app.services.idempotency import IdempotencyManager -from app.services.idempotency.middleware import IdempotentConsumerWrapper from app.services.k8s_worker.config import K8sWorkerConfig from app.services.k8s_worker.pod_builder import PodBuilder -from app.settings import Settings -class KubernetesWorker(LifecycleEnabled): - """ - Worker service that creates Kubernetes pods from execution events. - - This service: - 1. Consumes ExecutionStarted events from Kafka - 2. Creates ConfigMaps with script content - 3. Creates Pods to execute the scripts - 4. Creates NetworkPolicies for security - 5. Publishes PodCreated events +class KubernetesWorker: + """Stateless Kubernetes worker - pure event handler. + + No lifecycle methods (start/stop) - receives ready-to-use dependencies from DI. + All state (active creations) stored in Redis via PodStateRepository. 
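# --- Editor's aside (illustrative sketch, not part of this patch) ------------
# With lifecycle gone, a worker entrypoint just resolves the handler from DI and
# feeds it decoded events; duplicate CreatePodCommandEvents are dropped by the
# Redis claim (try_claim_creation) inside handle_create_pod_command. `container`,
# `commands`, and `consume_commands` are hypothetical stand-ins for the DI and
# consumer wiring; only KubernetesWorker.handle_create_pod_command is real.
async def consume_commands(container, commands) -> None:
    async with container() as request_scope:
        worker = await request_scope.get(KubernetesWorker)
        for command in commands:  # decoded CreatePodCommandEvent instances
            # Safe to call repeatedly for the same execution_id: the first call
            # claims it in Redis, later calls log a warning and return.
            await worker.handle_create_pod_command(command)
# -----------------------------------------------------------------------------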
""" def __init__( - self, - config: K8sWorkerConfig, - producer: UnifiedProducer, - schema_registry_manager: SchemaRegistryManager, - settings: Settings, - event_store: EventStore, - idempotency_manager: IdempotencyManager, - logger: logging.Logger, - event_metrics: EventMetrics, - ): - super().__init__() + self, + config: K8sWorkerConfig, + producer: UnifiedProducer, + pod_state_repo: PodStateRepository, + v1_client: k8s_client.CoreV1Api, + networking_v1_client: k8s_client.NetworkingV1Api, + apps_v1_client: k8s_client.AppsV1Api, + logger: logging.Logger, + kubernetes_metrics: KubernetesMetrics, + execution_metrics: ExecutionMetrics, + event_metrics: EventMetrics, + ) -> None: + self._config = config + self._producer = producer + self._pod_state_repo = pod_state_repo + self._v1 = v1_client + self._networking_v1 = networking_v1_client + self._apps_v1 = apps_v1_client + self._logger = logger + self._metrics = kubernetes_metrics + self._execution_metrics = execution_metrics self._event_metrics = event_metrics - self.logger = logger - self.metrics = KubernetesMetrics(settings) - self.execution_metrics = ExecutionMetrics(settings) - self.config = config or K8sWorkerConfig() - self._settings = settings - - self.kafka_servers = self._settings.KAFKA_BOOTSTRAP_SERVERS - self._event_store = event_store - - # Kubernetes clients - self.v1: k8s_client.CoreV1Api | None = None - self.networking_v1: k8s_client.NetworkingV1Api | None = None - self.apps_v1: k8s_client.AppsV1Api | None = None - - # Components - self.pod_builder = PodBuilder(namespace=self.config.namespace, config=self.config) - self.consumer: UnifiedConsumer | None = None - self.idempotent_consumer: IdempotentConsumerWrapper | None = None - self.idempotency_manager: IdempotencyManager = idempotency_manager - self.dispatcher: EventDispatcher | None = None - self.producer: UnifiedProducer = producer - - # State tracking - self._active_creations: set[str] = set() - self._creation_semaphore = asyncio.Semaphore(self.config.max_concurrent_pods) - self._schema_registry_manager = schema_registry_manager - - async def _on_start(self) -> None: - """Start the Kubernetes worker.""" - self.logger.info("Starting KubernetesWorker service...") - self.logger.info("DEBUG: About to initialize Kubernetes client") - - if self.config.namespace == "default": - raise RuntimeError( - "KubernetesWorker namespace 'default' is forbidden. Set K8S_NAMESPACE to a dedicated namespace." 
- ) - - # Initialize Kubernetes client - self._initialize_kubernetes_client() - self.logger.info("DEBUG: Kubernetes client initialized") - - self.logger.info("Using provided producer") - - self.logger.info("Idempotency manager provided") - - # Create consumer configuration - consumer_config = ConsumerConfig( - bootstrap_servers=self.kafka_servers, - group_id=self.config.consumer_group, - enable_auto_commit=False, - session_timeout_ms=self._settings.KAFKA_SESSION_TIMEOUT_MS, - heartbeat_interval_ms=self._settings.KAFKA_HEARTBEAT_INTERVAL_MS, - max_poll_interval_ms=self._settings.KAFKA_MAX_POLL_INTERVAL_MS, - request_timeout_ms=self._settings.KAFKA_REQUEST_TIMEOUT_MS, - ) - - # Create dispatcher and register handlers for saga commands - self.dispatcher = EventDispatcher(logger=self.logger) - self.dispatcher.register_handler(EventType.CREATE_POD_COMMAND, self._handle_create_pod_command_wrapper) - self.dispatcher.register_handler(EventType.DELETE_POD_COMMAND, self._handle_delete_pod_command_wrapper) - - # Create consumer with dispatcher - self.consumer = UnifiedConsumer( - consumer_config, - event_dispatcher=self.dispatcher, - schema_registry=self._schema_registry_manager, - settings=self._settings, - logger=self.logger, - event_metrics=self._event_metrics, - ) - - # Wrap consumer with idempotency - use content hash for pod commands - self.idempotent_consumer = IdempotentConsumerWrapper( - consumer=self.consumer, - idempotency_manager=self.idempotency_manager, - dispatcher=self.dispatcher, - logger=self.logger, - default_key_strategy=KeyStrategy.CONTENT_HASH, # Hash execution_id + script for deduplication - default_ttl_seconds=3600, # 1 hour TTL for pod creation events - enable_for_all_handlers=True, # Enable idempotency for all handlers - ) - - # Start the consumer with idempotency - topics from centralized config - await self.idempotent_consumer.start(list(CONSUMER_GROUP_SUBSCRIPTIONS[GroupId.K8S_WORKER])) - - # Create daemonset for image pre-pulling - asyncio.create_task(self.ensure_image_pre_puller_daemonset()) - self.logger.info("Image pre-puller daemonset task scheduled") - - self.logger.info("KubernetesWorker service started successfully") - - async def _on_stop(self) -> None: - """Stop the Kubernetes worker.""" - self.logger.info("Stopping KubernetesWorker service...") - - # Wait for active creations to complete - if self._active_creations: - self.logger.info(f"Waiting for {len(self._active_creations)} active pod creations to complete...") - timeout = 30 - start_time = time.time() - - while self._active_creations and (time.time() - start_time) < timeout: - await asyncio.sleep(1) - - if self._active_creations: - self.logger.warning(f"Timeout waiting for pod creations, {len(self._active_creations)} still active") - - # Stop the consumer (idempotent wrapper only) - if self.idempotent_consumer: - await self.idempotent_consumer.stop() - - # Close idempotency manager - await self.idempotency_manager.close() - - # Note: producer is managed by DI container, not stopped here - - self.logger.info("KubernetesWorker service stopped") - - def _initialize_kubernetes_client(self) -> None: - """Initialize Kubernetes API clients""" - try: - # Load config - if self.config.in_cluster: - self.logger.info("Using in-cluster Kubernetes configuration") - k8s_config.load_incluster_config() - elif self.config.kubeconfig_path and os.path.exists(self.config.kubeconfig_path): - self.logger.info(f"Using kubeconfig from {self.config.kubeconfig_path}") - 
k8s_config.load_kube_config(config_file=self.config.kubeconfig_path) - else: - # Try default locations - if os.path.exists("/var/run/secrets/kubernetes.io/serviceaccount"): - self.logger.info("Detected in-cluster environment") - k8s_config.load_incluster_config() - else: - self.logger.info("Using default kubeconfig") - k8s_config.load_kube_config() - - # Get the default configuration that was set by load_kube_config - configuration = k8s_client.Configuration.get_default_copy() - - # The certificate data should already be configured by load_kube_config - # Log the configuration for debugging - self.logger.info(f"Kubernetes API host: {configuration.host}") - self.logger.info(f"SSL CA cert configured: {configuration.ssl_ca_cert is not None}") + self._pod_builder = PodBuilder(namespace=config.namespace, config=config) - # Create API clients with the configuration - api_client = k8s_client.ApiClient(configuration) - self.v1 = k8s_client.CoreV1Api(api_client) - self.networking_v1 = k8s_client.NetworkingV1Api(api_client) - self.apps_v1 = k8s_client.AppsV1Api(api_client) - - # Test connection with namespace-scoped operation - _ = self.v1.list_namespaced_pod(namespace=self.config.namespace, limit=1) - self.logger.info(f"Successfully connected to Kubernetes API, namespace {self.config.namespace} accessible") - - except Exception as e: - self.logger.error(f"Failed to initialize Kubernetes client: {e}") - raise - - async def _handle_create_pod_command_wrapper(self, event: DomainEvent) -> None: - """Wrapper for handling CreatePodCommandEvent with type safety.""" - assert isinstance(event, CreatePodCommandEvent) - self.logger.info(f"Processing create_pod_command for execution {event.execution_id} from saga {event.saga_id}") - await self._handle_create_pod_command(event) - - async def _handle_delete_pod_command_wrapper(self, event: DomainEvent) -> None: - """Wrapper for handling DeletePodCommandEvent.""" - assert isinstance(event, DeletePodCommandEvent) - self.logger.info(f"Processing delete_pod_command for execution {event.execution_id} from saga {event.saga_id}") - await self._handle_delete_pod_command(event) - - async def _handle_create_pod_command(self, command: CreatePodCommandEvent) -> None: - """Handle create pod command from saga orchestrator""" + async def handle_create_pod_command(self, command: CreatePodCommandEvent) -> None: + """Handle create pod command from saga orchestrator.""" execution_id = command.execution_id + self._logger.info(f"Processing create_pod_command for execution {execution_id} from saga {command.saga_id}") - # Check if already processing - if execution_id in self._active_creations: - self.logger.warning(f"Already creating pod for execution {execution_id}") + # Try to claim this creation atomically in Redis + claimed = await self._pod_state_repo.try_claim_creation(execution_id, ttl_seconds=300) + if not claimed: + self._logger.warning(f"Already creating pod for execution {execution_id}, skipping") return - # Create pod asynchronously - asyncio.create_task(self._create_pod_for_execution(command)) + await self._create_pod_for_execution(command) - async def _handle_delete_pod_command(self, command: DeletePodCommandEvent) -> None: - """Handle delete pod command from saga orchestrator (compensation)""" + async def handle_delete_pod_command(self, command: DeletePodCommandEvent) -> None: + """Handle delete pod command from saga orchestrator (compensation).""" execution_id = command.execution_id - self.logger.info(f"Deleting pod for execution {execution_id} due to: 
{command.reason}") + self._logger.info(f"Deleting pod for execution {execution_id} due to: {command.reason}") try: # Delete the pod pod_name = f"executor-{execution_id}" - if self.v1: - await asyncio.to_thread( - self.v1.delete_namespaced_pod, - name=pod_name, - namespace=self.config.namespace, - grace_period_seconds=30, - ) - self.logger.info(f"Successfully deleted pod {pod_name}") + await asyncio.to_thread( + self._v1.delete_namespaced_pod, + name=pod_name, + namespace=self._config.namespace, + grace_period_seconds=30, + ) + self._logger.info(f"Successfully deleted pod {pod_name}") # Delete associated ConfigMap configmap_name = f"script-{execution_id}" - if self.v1: - await asyncio.to_thread( - self.v1.delete_namespaced_config_map, name=configmap_name, namespace=self.config.namespace - ) - self.logger.info(f"Successfully deleted ConfigMap {configmap_name}") - - # NetworkPolicy cleanup is managed via a static cluster policy; no per-execution NP deletion + await asyncio.to_thread( + self._v1.delete_namespaced_config_map, + name=configmap_name, + namespace=self._config.namespace, + ) + self._logger.info(f"Successfully deleted ConfigMap {configmap_name}") except ApiException as e: if e.status == 404: - self.logger.warning(f"Resources for execution {execution_id} not found (may have already been deleted)") + self._logger.warning( + f"Resources for execution {execution_id} not found (may have already been deleted)" + ) else: - self.logger.error(f"Failed to delete resources for execution {execution_id}: {e}") + self._logger.error(f"Failed to delete resources for execution {execution_id}: {e}") async def _create_pod_for_execution(self, command: CreatePodCommandEvent) -> None: - """Create pod for execution""" - async with self._creation_semaphore: - execution_id = command.execution_id - self._active_creations.add(execution_id) - self.metrics.update_k8s_active_creations(len(self._active_creations)) - - # Queue depth is owned by the coordinator; do not modify here - - start_time = time.time() - - try: - # We now have the CreatePodCommandEvent directly from saga - script_content = command.script - entrypoint_content = await self._get_entrypoint_script() - - # Create ConfigMap - config_map = self.pod_builder.build_config_map( - command=command, script_content=script_content, entrypoint_content=entrypoint_content - ) + """Create pod for execution.""" + execution_id = command.execution_id + start_time = time.time() - await self._create_config_map(config_map) + try: + # Update metrics for active creations + active_count = await self._pod_state_repo.get_active_creations_count() + self._metrics.update_k8s_active_creations(active_count) + + # Build and create ConfigMap + script_content = command.script + entrypoint_content = await self._get_entrypoint_script() + + config_map = self._pod_builder.build_config_map( + command=command, + script_content=script_content, + entrypoint_content=entrypoint_content, + ) + await self._create_config_map(config_map) - pod = self.pod_builder.build_pod_manifest(command=command) - await self._create_pod(pod) + # Build and create Pod + pod = self._pod_builder.build_pod_manifest(command=command) + await self._create_pod(pod) - # Publish PodCreated event - await self._publish_pod_created(command, pod) + # Publish PodCreated event + await self._publish_pod_created(command, pod) - # Update metrics - duration = time.time() - start_time - self.metrics.record_k8s_pod_creation_duration(duration, command.language) - self.metrics.record_k8s_pod_created("success", command.language) + # 
Update metrics + duration = time.time() - start_time + self._metrics.record_k8s_pod_creation_duration(duration, command.language) + self._metrics.record_k8s_pod_created("success", command.language) - self.logger.info( - f"Successfully created pod {pod.metadata.name} for execution {execution_id}. " - f"Duration: {duration:.2f}s" - ) + self._logger.info( + f"Successfully created pod {pod.metadata.name} for execution {execution_id}. " + f"Duration: {duration:.2f}s" + ) - except Exception as e: - self.logger.error(f"Failed to create pod for execution {execution_id}: {e}", exc_info=True) + except Exception as e: + self._logger.error(f"Failed to create pod for execution {execution_id}: {e}", exc_info=True) + self._metrics.record_k8s_pod_created("failed", "unknown") - # Update metrics - self.metrics.record_k8s_pod_created("failed", "unknown") + # Publish failure event + await self._publish_pod_creation_failed(command, str(e)) - # Publish failure event - await self._publish_pod_creation_failed(command, str(e)) + finally: + # Release the creation claim + await self._pod_state_repo.release_creation(execution_id) - finally: - self._active_creations.discard(execution_id) - self.metrics.update_k8s_active_creations(len(self._active_creations)) + # Update metrics + active_count = await self._pod_state_repo.get_active_creations_count() + self._metrics.update_k8s_active_creations(active_count) async def _get_entrypoint_script(self) -> str: - """Get entrypoint script content""" + """Get entrypoint script content.""" entrypoint_path = Path("app/scripts/entrypoint.sh") if entrypoint_path.exists(): return await asyncio.to_thread(entrypoint_path.read_text) @@ -353,67 +186,62 @@ async def _get_entrypoint_script(self) -> str: """ async def _create_config_map(self, config_map: k8s_client.V1ConfigMap) -> None: - """Create ConfigMap in Kubernetes""" - if not self.v1: - raise RuntimeError("Kubernetes client not initialized") + """Create ConfigMap in Kubernetes.""" try: await asyncio.to_thread( - self.v1.create_namespaced_config_map, namespace=self.config.namespace, body=config_map + self._v1.create_namespaced_config_map, + namespace=self._config.namespace, + body=config_map, ) - self.metrics.record_k8s_config_map_created("success") - self.logger.debug(f"Created ConfigMap {config_map.metadata.name}") + self._metrics.record_k8s_config_map_created("success") + self._logger.debug(f"Created ConfigMap {config_map.metadata.name}") except ApiException as e: if e.status == 409: # Already exists - self.logger.warning(f"ConfigMap {config_map.metadata.name} already exists") - self.metrics.record_k8s_config_map_created("already_exists") + self._logger.warning(f"ConfigMap {config_map.metadata.name} already exists") + self._metrics.record_k8s_config_map_created("already_exists") else: - self.metrics.record_k8s_config_map_created("failed") + self._metrics.record_k8s_config_map_created("failed") raise async def _create_pod(self, pod: k8s_client.V1Pod) -> None: - """Create Pod in Kubernetes""" - if not self.v1: - raise RuntimeError("Kubernetes client not initialized") + """Create Pod in Kubernetes.""" try: - await asyncio.to_thread(self.v1.create_namespaced_pod, namespace=self.config.namespace, body=pod) - self.logger.debug(f"Created Pod {pod.metadata.name}") + await asyncio.to_thread( + self._v1.create_namespaced_pod, + namespace=self._config.namespace, + body=pod, + ) + self._logger.debug(f"Created Pod {pod.metadata.name}") except ApiException as e: if e.status == 409: # Already exists - self.logger.warning(f"Pod {pod.metadata.name} 
already exists") + self._logger.warning(f"Pod {pod.metadata.name} already exists") else: raise async def _publish_execution_started(self, command: CreatePodCommandEvent, pod: k8s_client.V1Pod) -> None: - """Publish execution started event""" + """Publish execution started event.""" event = ExecutionStartedEvent( execution_id=command.execution_id, - aggregate_id=command.execution_id, # Set aggregate_id to execution_id + aggregate_id=command.execution_id, pod_name=pod.metadata.name, node_name=pod.spec.node_name, - container_id=None, # Will be set when container actually starts + container_id=None, metadata=command.metadata, ) - if not self.producer: - self.logger.error("Producer not initialized") - return - await self.producer.produce(event_to_produce=event) + await self._producer.produce(event_to_produce=event) async def _publish_pod_created(self, command: CreatePodCommandEvent, pod: k8s_client.V1Pod) -> None: - """Publish pod created event""" + """Publish pod created event.""" event = PodCreatedEvent( execution_id=command.execution_id, pod_name=pod.metadata.name, namespace=pod.metadata.namespace, metadata=command.metadata, ) - - if not self.producer: - self.logger.error("Producer not initialized") - return - await self.producer.produce(event_to_produce=event) + await self._producer.produce(event_to_produce=event) async def _publish_pod_creation_failed(self, command: CreatePodCommandEvent, error: str) -> None: - """Publish pod creation failed event""" + """Publish pod creation failed event.""" event = ExecutionFailedEvent( execution_id=command.execution_id, error_type=ExecutionErrorType.SYSTEM_ERROR, @@ -423,33 +251,28 @@ async def _publish_pod_creation_failed(self, command: CreatePodCommandEvent, err metadata=command.metadata, error_message=str(error), ) + await self._producer.produce(event_to_produce=event, key=command.execution_id) - if not self.producer: - self.logger.error("Producer not initialized") - return - await self.producer.produce(event_to_produce=event) - - async def get_status(self) -> dict[str, Any]: - """Get worker status""" + async def get_status(self) -> dict[str, object]: + """Get worker status.""" + active_count = await self._pod_state_repo.get_active_creations_count() return { - "running": self.is_running, - "active_creations": len(self._active_creations), + "active_creations": active_count, "config": { - "namespace": self.config.namespace, - "max_concurrent_pods": self.config.max_concurrent_pods, + "namespace": self._config.namespace, + "max_concurrent_pods": self._config.max_concurrent_pods, "enable_network_policies": True, }, } async def ensure_image_pre_puller_daemonset(self) -> None: - """Ensure the runtime image pre-puller DaemonSet exists""" - if not self.apps_v1: - self.logger.warning("Kubernetes AppsV1Api client not initialized. Skipping DaemonSet creation.") - return + """Ensure the runtime image pre-puller DaemonSet exists. + This should be called once at startup from the worker entrypoint, + not as a background task. 
+ """ daemonset_name = "runtime-image-pre-puller" - namespace = self.config.namespace - await asyncio.sleep(5) + namespace = self._config.namespace try: init_containers = [] @@ -457,7 +280,7 @@ async def ensure_image_pre_puller_daemonset(self) -> None: for i, image_ref in enumerate(sorted(list(all_images))): sanitized_image_ref = image_ref.split("/")[-1].replace(":", "-").replace(".", "-").replace("_", "-") - self.logger.info(f"DAEMONSET: before: {image_ref} -> {sanitized_image_ref}") + self._logger.info(f"DAEMONSET: before: {image_ref} -> {sanitized_image_ref}") container_name = f"pull-{i}-{sanitized_image_ref}" init_containers.append( { @@ -468,7 +291,7 @@ async def ensure_image_pre_puller_daemonset(self) -> None: } ) - manifest: dict[str, Any] = { + manifest: dict[str, object] = { "apiVersion": "apps/v1", "kind": "DaemonSet", "metadata": {"name": daemonset_name, "namespace": namespace}, @@ -488,24 +311,31 @@ async def ensure_image_pre_puller_daemonset(self) -> None: try: await asyncio.to_thread( - self.apps_v1.read_namespaced_daemon_set, name=daemonset_name, namespace=namespace + self._apps_v1.read_namespaced_daemon_set, + name=daemonset_name, + namespace=namespace, ) - self.logger.info(f"DaemonSet '{daemonset_name}' exists. Replacing to ensure it is up-to-date.") + self._logger.info(f"DaemonSet '{daemonset_name}' exists. Replacing to ensure it is up-to-date.") await asyncio.to_thread( - self.apps_v1.replace_namespaced_daemon_set, name=daemonset_name, namespace=namespace, body=manifest + self._apps_v1.replace_namespaced_daemon_set, + name=daemonset_name, + namespace=namespace, + body=manifest, ) - self.logger.info(f"DaemonSet '{daemonset_name}' replaced successfully.") + self._logger.info(f"DaemonSet '{daemonset_name}' replaced successfully.") except ApiException as e: if e.status == 404: - self.logger.info(f"DaemonSet '{daemonset_name}' not found. Creating...") + self._logger.info(f"DaemonSet '{daemonset_name}' not found. 
Creating...") await asyncio.to_thread( - self.apps_v1.create_namespaced_daemon_set, namespace=namespace, body=manifest + self._apps_v1.create_namespaced_daemon_set, + namespace=namespace, + body=manifest, ) - self.logger.info(f"DaemonSet '{daemonset_name}' created successfully.") + self._logger.info(f"DaemonSet '{daemonset_name}' created successfully.") else: raise except ApiException as e: - self.logger.error(f"K8s API error applying DaemonSet '{daemonset_name}': {e.reason}", exc_info=True) + self._logger.error(f"K8s API error applying DaemonSet '{daemonset_name}': {e.reason}", exc_info=True) except Exception as e: - self.logger.error(f"Unexpected error applying image-puller DaemonSet: {e}", exc_info=True) + self._logger.error(f"Unexpected error applying image-puller DaemonSet: {e}", exc_info=True) diff --git a/backend/app/services/kafka_event_service.py b/backend/app/services/kafka_event_service.py index ac2207ca..9c152b97 100644 --- a/backend/app/services/kafka_event_service.py +++ b/backend/app/services/kafka_event_service.py @@ -1,7 +1,7 @@ import logging import time from datetime import datetime, timezone -from typing import Any +from typing import Any, Dict from uuid import uuid4 from opentelemetry import trace @@ -21,12 +21,12 @@ class KafkaEventService: def __init__( - self, - event_repository: EventRepository, - kafka_producer: UnifiedProducer, - settings: Settings, - logger: logging.Logger, - event_metrics: EventMetrics, + self, + event_repository: EventRepository, + kafka_producer: UnifiedProducer, + settings: Settings, + logger: logging.Logger, + event_metrics: EventMetrics, ): self.event_repository = event_repository self.kafka_producer = kafka_producer @@ -35,12 +35,12 @@ def __init__( self.settings = settings async def publish_event( - self, - event_type: EventType, - payload: dict[str, Any], - aggregate_id: str | None, - correlation_id: str | None = None, - metadata: EventMetadata | None = None, + self, + event_type: EventType, + payload: Dict[str, Any], + aggregate_id: str | None, + correlation_id: str | None = None, + metadata: EventMetadata | None = None, ) -> str: """ Publish an event to Kafka and store an audit copy via the repository @@ -90,7 +90,7 @@ async def publish_event( await self.event_repository.store_event(domain_event) # Prepare headers - headers: dict[str, str] = { + headers: Dict[str, str] = { "event_type": event_type, "correlation_id": event_metadata.correlation_id or "", "service": event_metadata.service_name, @@ -113,12 +113,12 @@ async def publish_event( return domain_event.event_id async def publish_execution_event( - self, - event_type: EventType, - execution_id: str, - status: str, - metadata: EventMetadata | None = None, - error_message: str | None = None, + self, + event_type: EventType, + execution_id: str, + status: str, + metadata: EventMetadata | None = None, + error_message: str | None = None, ) -> str: """Publish execution-related event using provided metadata (no framework coupling).""" self.logger.info( @@ -154,13 +154,13 @@ async def publish_execution_event( return event_id async def publish_pod_event( - self, - event_type: EventType, - pod_name: str, - execution_id: str, - namespace: str = "integr8scode", - status: str | None = None, - metadata: EventMetadata | None = None, + self, + event_type: EventType, + pod_name: str, + execution_id: str, + namespace: str = "integr8scode", + status: str | None = None, + metadata: EventMetadata | None = None, ) -> str: """Publish pod-related event""" payload = {"pod_name": pod_name, "execution_id": 
execution_id, "namespace": namespace} @@ -185,7 +185,7 @@ async def publish_domain_event(self, event: DomainEvent, key: str | None = None) start_time = time.time() await self.event_repository.store_event(event) - headers: dict[str, str] = { + headers: Dict[str, str] = { "event_type": event.event_type, "correlation_id": event.metadata.correlation_id or "", "service": event.metadata.service_name, @@ -201,7 +201,3 @@ async def publish_domain_event(self, event: DomainEvent, key: str | None = None) self.metrics.record_event_processing_duration(time.time() - start_time, event.event_type) self.logger.info("Domain event published", extra={"event_id": event.event_id}) return event.event_id - - async def close(self) -> None: - """Close event service resources""" - await self.kafka_producer.aclose() diff --git a/backend/app/services/notification_service.py b/backend/app/services/notification_service.py index 2d005fbc..1e37d987 100644 --- a/backend/app/services/notification_service.py +++ b/backend/app/services/notification_service.py @@ -1,17 +1,21 @@ +"""Notification Service - stateless event handler. + +Handles notification creation and delivery. Receives events, +processes them, and delivers notifications. No lifecycle management. +""" + +from __future__ import annotations + import asyncio import logging from dataclasses import dataclass, field from datetime import UTC, datetime, timedelta -from typing import Awaitable, Callable import httpx -from app.core.lifecycle import LifecycleEnabled -from app.core.metrics import EventMetrics, NotificationMetrics +from app.core.metrics import NotificationMetrics from app.core.tracing.utils import add_span_attributes from app.db.repositories.notification_repository import NotificationRepository -from app.domain.enums.events import EventType -from app.domain.enums.kafka import GroupId from app.domain.enums.notification import ( NotificationChannel, NotificationSeverity, @@ -20,13 +24,9 @@ from app.domain.enums.user import UserRole from app.domain.events.typed import ( DomainEvent, - EventMetadata, ExecutionCompletedEvent, ExecutionFailedEvent, ExecutionTimeoutEvent, - NotificationAllReadEvent, - NotificationCreatedEvent, - NotificationReadEvent, ) from app.domain.notification import ( DomainNotification, @@ -39,25 +39,13 @@ NotificationThrottledError, NotificationValidationError, ) -from app.events.core import ConsumerConfig, EventDispatcher, UnifiedConsumer -from app.events.schema.schema_registry import SchemaRegistryManager -from app.infrastructure.kafka.mappings import get_topic_for_event from app.schemas_pydantic.sse import RedisNotificationMessage -from app.services.event_bus import EventBusManager -from app.services.kafka_event_service import KafkaEventService +from app.services.event_bus import EventBus from app.services.sse.redis_bus import SSERedisBus from app.settings import Settings -# Constants ENTITY_EXECUTION_TAG = "entity:execution" -# Type aliases -type EventPayload = dict[str, object] -type NotificationContext = dict[str, object] -type ChannelHandler = Callable[[DomainNotification, DomainNotificationSubscription], Awaitable[None]] -type SystemNotificationStats = dict[str, int] -type SlackMessage = dict[str, object] - @dataclass class ThrottleCache: @@ -67,11 +55,11 @@ class ThrottleCache: _lock: asyncio.Lock = field(default_factory=asyncio.Lock) async def check_throttle( - self, - user_id: str, - severity: NotificationSeverity, - window_hours: int, - max_per_hour: int, + self, + user_id: str, + severity: NotificationSeverity, + window_hours: 
int, + max_per_hour: int, ) -> bool: """Check if notification should be throttled.""" key = f"{user_id}:{severity}" @@ -82,14 +70,11 @@ async def check_throttle( if key not in self._entries: self._entries[key] = [] - # Clean old entries self._entries[key] = [ts for ts in self._entries[key] if ts > window_start] - # Check limit if len(self._entries[key]) >= max_per_hour: return True - # Add new entry self._entries[key].append(now) return False @@ -105,157 +90,137 @@ class SystemConfig: throttle_exempt: bool -class NotificationService(LifecycleEnabled): +class NotificationService: + """Stateless notification service - pure event handler. + + No lifecycle methods (start/stop) - receives ready-to-use dependencies from DI. + Worker entrypoint handles the consume loop. + """ + def __init__( - self, - notification_repository: NotificationRepository, - event_service: KafkaEventService, - event_bus_manager: EventBusManager, - schema_registry_manager: SchemaRegistryManager, - sse_bus: SSERedisBus, - settings: Settings, - logger: logging.Logger, - notification_metrics: NotificationMetrics, - event_metrics: EventMetrics, + self, + notification_repository: NotificationRepository, + event_bus: EventBus, + sse_bus: SSERedisBus, + settings: Settings, + logger: logging.Logger, + notification_metrics: NotificationMetrics, ) -> None: - super().__init__() - self.repository = notification_repository - self.event_service = event_service - self.event_bus_manager = event_bus_manager - self.metrics = notification_metrics - self._event_metrics = event_metrics - self.settings = settings - self.schema_registry_manager = schema_registry_manager - self.sse_bus = sse_bus - self.logger = logger - - # State + self._repository = notification_repository + self._event_bus = event_bus + self._sse_bus = sse_bus + self._settings = settings + self._logger = logger + self._metrics = notification_metrics self._throttle_cache = ThrottleCache() - # Tasks - self._tasks: set[asyncio.Task[None]] = set() - - self._consumer: UnifiedConsumer | None = None - self._dispatcher: EventDispatcher | None = None - self._consumer_task: asyncio.Task[None] | None = None - - self.logger.info( - "NotificationService initialized", - extra={ - "repository": type(notification_repository).__name__, - "event_service": type(event_service).__name__, - "schema_registry": type(schema_registry_manager).__name__, - }, - ) - - # Channel handlers mapping - self._channel_handlers: dict[NotificationChannel, ChannelHandler] = { + self._channel_handlers: dict[NotificationChannel, object] = { NotificationChannel.IN_APP: self._send_in_app, NotificationChannel.WEBHOOK: self._send_webhook, NotificationChannel.SLACK: self._send_slack, } - async def _on_start(self) -> None: - """Start the notification service with Kafka consumer.""" - self.logger.info("Starting notification service...") - self._start_background_tasks() - await self._subscribe_to_events() - self.logger.info("Notification service started with Kafka consumer") - - async def _on_stop(self) -> None: - """Stop the notification service.""" - self.logger.info("Stopping notification service...") - - # Cancel all tasks - for task in self._tasks: - task.cancel() - - # Wait for cancellation - if self._tasks: - await asyncio.gather(*self._tasks, return_exceptions=True) - - # Stop consumer - if self._consumer: - await self._consumer.stop() - - # Clear cache - await self._throttle_cache.clear() - - self.logger.info("Notification service stopped") - - def _start_background_tasks(self) -> None: - """Start background 
processing tasks.""" - tasks = [ - asyncio.create_task(self._process_pending_notifications()), - asyncio.create_task(self._cleanup_old_notifications()), - ] - - for task in tasks: - self._tasks.add(task) - task.add_done_callback(self._tasks.discard) - - async def _subscribe_to_events(self) -> None: - """Subscribe to relevant events for notifications.""" - # Configure consumer for notification-relevant events - consumer_config = ConsumerConfig( - bootstrap_servers=self.settings.KAFKA_BOOTSTRAP_SERVERS, - group_id=GroupId.NOTIFICATION_SERVICE, - max_poll_records=10, - enable_auto_commit=True, - auto_offset_reset="latest", - session_timeout_ms=self.settings.KAFKA_SESSION_TIMEOUT_MS, - heartbeat_interval_ms=self.settings.KAFKA_HEARTBEAT_INTERVAL_MS, - max_poll_interval_ms=self.settings.KAFKA_MAX_POLL_INTERVAL_MS, - request_timeout_ms=self.settings.KAFKA_REQUEST_TIMEOUT_MS, + self._logger.info("NotificationService initialized") + + async def handle_execution_event(self, event: DomainEvent) -> None: + """Handle execution result events. + + Called by worker entrypoint for each event. + """ + try: + if isinstance(event, ExecutionCompletedEvent): + await self._handle_execution_completed(event) + elif isinstance(event, ExecutionFailedEvent): + await self._handle_execution_failed(event) + elif isinstance(event, ExecutionTimeoutEvent): + await self._handle_execution_timeout(event) + else: + self._logger.warning(f"Unhandled execution event type: {event.event_type}") + except Exception as e: + self._logger.error(f"Error handling execution event: {e}", exc_info=True) + + async def _handle_execution_completed(self, event: ExecutionCompletedEvent) -> None: + """Handle execution completed event.""" + user_id = event.metadata.user_id + if not user_id: + self._logger.error("No user_id in event metadata") + return + + title = f"Execution Completed: {event.execution_id}" + duration = event.resource_usage.execution_time_wall_seconds if event.resource_usage else 0.0 + body = f"Your execution completed successfully. Duration: {duration:.2f}s." 
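+        # Successful runs are delivered at MEDIUM severity; the full event payload
+        # (minus envelope fields such as metadata and timestamp) is attached as notification metadata below.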
+ await self.create_notification( + user_id=user_id, + subject=title, + body=body, + severity=NotificationSeverity.MEDIUM, + tags=["execution", "completed", ENTITY_EXECUTION_TAG, f"exec:{event.execution_id}"], + metadata=event.model_dump( + exclude={"metadata", "event_type", "event_version", "timestamp", "aggregate_id", "topic"} + ), ) - execution_results_topic = get_topic_for_event(EventType.EXECUTION_COMPLETED) - - # Log topics for debugging - self.logger.info(f"Notification service will subscribe to topics: {execution_results_topic}") - - # Create dispatcher and register handlers for specific event types - self._dispatcher = EventDispatcher(logger=self.logger) - # Use a single handler for execution result events (simpler and less brittle) - self._dispatcher.register_handler(EventType.EXECUTION_COMPLETED, self._handle_execution_event) - self._dispatcher.register_handler(EventType.EXECUTION_FAILED, self._handle_execution_event) - self._dispatcher.register_handler(EventType.EXECUTION_TIMEOUT, self._handle_execution_event) - - # Create consumer with dispatcher - self._consumer = UnifiedConsumer( - consumer_config, - event_dispatcher=self._dispatcher, - schema_registry=self.schema_registry_manager, - settings=self.settings, - logger=self.logger, - event_metrics=self._event_metrics, + async def _handle_execution_failed(self, event: ExecutionFailedEvent) -> None: + """Handle execution failed event.""" + user_id = event.metadata.user_id + if not user_id: + self._logger.error("No user_id in event metadata") + return + + event_data = event.model_dump( + exclude={"metadata", "event_type", "event_version", "timestamp", "aggregate_id", "topic"} ) + event_data["stdout"] = event_data["stdout"][:200] + event_data["stderr"] = event_data["stderr"][:200] - # Start consumer - await self._consumer.start([execution_results_topic]) + title = f"Execution Failed: {event.execution_id}" + body = f"Your execution failed: {event.error_message}" + await self.create_notification( + user_id=user_id, + subject=title, + body=body, + severity=NotificationSeverity.HIGH, + tags=["execution", "failed", ENTITY_EXECUTION_TAG, f"exec:{event.execution_id}"], + metadata=event_data, + ) - # Start consumer task - self._consumer_task = asyncio.create_task(self._run_consumer()) - self._tasks.add(self._consumer_task) - self._consumer_task.add_done_callback(self._tasks.discard) + async def _handle_execution_timeout(self, event: ExecutionTimeoutEvent) -> None: + """Handle execution timeout event.""" + user_id = event.metadata.user_id + if not user_id: + self._logger.error("No user_id in event metadata") + return - self.logger.info("Notification service subscribed to execution events") + title = f"Execution Timeout: {event.execution_id}" + body = f"Your execution timed out after {event.timeout_seconds}s." 
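+        # Timeouts are delivered at HIGH severity, matching failed executions.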
+ await self.create_notification( + user_id=user_id, + subject=title, + body=body, + severity=NotificationSeverity.HIGH, + tags=["execution", "timeout", ENTITY_EXECUTION_TAG, f"exec:{event.execution_id}"], + metadata=event.model_dump( + exclude={"metadata", "event_type", "event_version", "timestamp", "aggregate_id", "topic"} + ), + ) async def create_notification( - self, - user_id: str, - subject: str, - body: str, - tags: list[str], - severity: NotificationSeverity = NotificationSeverity.MEDIUM, - channel: NotificationChannel = NotificationChannel.IN_APP, - scheduled_for: datetime | None = None, - action_url: str | None = None, - metadata: NotificationContext | None = None, + self, + user_id: str, + subject: str, + body: str, + tags: list[str], + severity: NotificationSeverity = NotificationSeverity.MEDIUM, + channel: NotificationChannel = NotificationChannel.IN_APP, + scheduled_for: datetime | None = None, + action_url: str | None = None, + metadata: dict[str, object] | None = None, ) -> DomainNotification: + """Create a new notification.""" if not tags: raise NotificationValidationError("tags must be a non-empty list") - self.logger.info( + + self._logger.info( f"Creating notification for user {user_id}", extra={ "user_id": user_id, @@ -266,26 +231,24 @@ async def create_notification( }, ) - # Check throttling if await self._throttle_cache.check_throttle( - user_id, - severity, - window_hours=self.settings.NOTIF_THROTTLE_WINDOW_HOURS, - max_per_hour=self.settings.NOTIF_THROTTLE_MAX_PER_HOUR, + user_id, + severity, + window_hours=self._settings.NOTIF_THROTTLE_WINDOW_HOURS, + max_per_hour=self._settings.NOTIF_THROTTLE_MAX_PER_HOUR, ): error_msg = ( f"Notification rate limit exceeded for user {user_id}. " - f"Max {self.settings.NOTIF_THROTTLE_MAX_PER_HOUR} " - f"per {self.settings.NOTIF_THROTTLE_WINDOW_HOURS} hour(s)" + f"Max {self._settings.NOTIF_THROTTLE_MAX_PER_HOUR} " + f"per {self._settings.NOTIF_THROTTLE_WINDOW_HOURS} hour(s)" ) - self.logger.warning(error_msg) + self._logger.warning(error_msg) raise NotificationThrottledError( user_id, - self.settings.NOTIF_THROTTLE_MAX_PER_HOUR, - self.settings.NOTIF_THROTTLE_WINDOW_HOURS, + self._settings.NOTIF_THROTTLE_MAX_PER_HOUR, + self._settings.NOTIF_THROTTLE_WINDOW_HOURS, ) - # Create notification create_data = DomainNotificationCreate( user_id=user_id, channel=channel, @@ -298,26 +261,16 @@ async def create_notification( metadata=metadata or {}, ) - # Save to database - notification = await self.repository.create_notification(create_data) + notification = await self._repository.create_notification(create_data) - # Publish event - event_bus = await self.event_bus_manager.get_event_bus() - await event_bus.publish( - NotificationCreatedEvent( - notification_id=str(notification.notification_id), - user_id=user_id, - subject=subject, - body=body, - severity=severity, - tags=notification.tags, - channels=[channel], - metadata=EventMetadata( - service_name=self.settings.SERVICE_NAME, - service_version=self.settings.SERVICE_VERSION, - user_id=user_id, - ), - ) + await self._event_bus.publish( + "notifications.created", + { + "notification_id": str(notification.notification_id), + "user_id": user_id, + "severity": str(severity), + "tags": notification.tags, + }, ) await self._deliver_notification(notification) @@ -325,23 +278,21 @@ async def create_notification( return notification async def create_system_notification( - self, - title: str, - message: str, - severity: NotificationSeverity = NotificationSeverity.MEDIUM, - tags: list[str] | None = 
None, - metadata: dict[str, object] | None = None, - target_users: list[str] | None = None, - target_roles: list[UserRole] | None = None, - ) -> SystemNotificationStats: - """Create system notifications with streamlined control flow. - - Returns stats with totals and created/failed/throttled counts. - """ + self, + title: str, + message: str, + severity: NotificationSeverity = NotificationSeverity.MEDIUM, + tags: list[str] | None = None, + metadata: dict[str, object] | None = None, + target_users: list[str] | None = None, + target_roles: list[UserRole] | None = None, + ) -> dict[str, int]: + """Create system notifications with streamlined control flow.""" cfg = SystemConfig( - severity=severity, throttle_exempt=(severity in (NotificationSeverity.HIGH, NotificationSeverity.URGENT)) + severity=severity, + throttle_exempt=(severity in (NotificationSeverity.HIGH, NotificationSeverity.URGENT)), ) - base_context: NotificationContext = {"message": message, **(metadata or {})} + base_context: dict[str, object] = {"message": message, **(metadata or {})} users = await self._resolve_targets(target_users, target_roles) if not users: @@ -354,14 +305,16 @@ async def worker(uid: str) -> str: return await self._create_system_for_user(uid, cfg, title, base_context, tags or ["system"]) results = ( - [await worker(u) for u in users] if len(users) <= 20 else await asyncio.gather(*(worker(u) for u in users)) + [await worker(u) for u in users] + if len(users) <= 20 + else await asyncio.gather(*(worker(u) for u in users)) ) created = sum(1 for r in results if r == "created") throttled = sum(1 for r in results if r == "throttled") failed = sum(1 for r in results if r == "failed") - self.logger.info( + self._logger.info( "System notification completed", extra={ "severity": cfg.severity, @@ -376,31 +329,31 @@ async def worker(uid: str) -> str: return {"total_users": len(users), "created": created, "failed": failed, "throttled": throttled} async def _resolve_targets( - self, - target_users: list[str] | None, - target_roles: list[UserRole] | None, + self, + target_users: list[str] | None, + target_roles: list[UserRole] | None, ) -> list[str]: if target_users is not None: return target_users if target_roles: - return await self.repository.get_users_by_roles(target_roles) - return await self.repository.get_active_users(days=30) + return await self._repository.get_users_by_roles(target_roles) + return await self._repository.get_active_users(days=30) async def _create_system_for_user( - self, - user_id: str, - cfg: SystemConfig, - title: str, - base_context: NotificationContext, - tags: list[str], + self, + user_id: str, + cfg: SystemConfig, + title: str, + base_context: dict[str, object], + tags: list[str], ) -> str: try: if not cfg.throttle_exempt: throttled = await self._throttle_cache.check_throttle( user_id, cfg.severity, - window_hours=self.settings.NOTIF_THROTTLE_WINDOW_HOURS, - max_per_hour=self.settings.NOTIF_THROTTLE_MAX_PER_HOUR, + window_hours=self._settings.NOTIF_THROTTLE_WINDOW_HOURS, + max_per_hour=self._settings.NOTIF_THROTTLE_MAX_PER_HOUR, ) if throttled: return "throttled" @@ -416,27 +369,29 @@ async def _create_system_for_user( ) return "created" except Exception as e: - self.logger.error( - "Failed to create system notification for user", extra={"user_id": user_id, "error": str(e)} + self._logger.error( + "Failed to create system notification for user", + extra={"user_id": user_id, "error": str(e)}, ) return "failed" async def _send_in_app( - self, notification: DomainNotification, subscription: 
DomainNotificationSubscription + self, + notification: DomainNotification, + subscription: DomainNotificationSubscription, ) -> None: - """Send in-app notification via SSE bus (fan-out to connected clients).""" + """Send in-app notification via SSE bus.""" await self._publish_notification_sse(notification) async def _send_webhook( - self, notification: DomainNotification, subscription: DomainNotificationSubscription + self, + notification: DomainNotification, + subscription: DomainNotificationSubscription, ) -> None: """Send webhook notification.""" webhook_url = notification.webhook_url or subscription.webhook_url if not webhook_url: - raise ValueError( - f"No webhook URL configured for user {notification.user_id} on channel {notification.channel}. " - f"Configure in notification settings." - ) + raise ValueError(f"No webhook URL configured for user {notification.user_id}") payload = { "notification_id": str(notification.notification_id), @@ -453,15 +408,6 @@ async def _send_webhook( headers = notification.webhook_headers or {} headers["Content-Type"] = "application/json" - self.logger.debug( - f"Sending webhook notification to {webhook_url}", - extra={ - "notification_id": str(notification.notification_id), - "payload_size": len(str(payload)), - "webhook_url": webhook_url, - }, - ) - add_span_attributes( **{ "notification.id": str(notification.notification_id), @@ -472,25 +418,17 @@ async def _send_webhook( async with httpx.AsyncClient() as client: response = await client.post(webhook_url, json=payload, headers=headers, timeout=30.0) response.raise_for_status() - self.logger.debug( - "Webhook delivered successfully", - extra={ - "notification_id": str(notification.notification_id), - "status_code": response.status_code, - "response_time_ms": int(response.elapsed.total_seconds() * 1000), - }, - ) - async def _send_slack(self, notification: DomainNotification, subscription: DomainNotificationSubscription) -> None: + async def _send_slack( + self, + notification: DomainNotification, + subscription: DomainNotificationSubscription, + ) -> None: """Send Slack notification.""" if not subscription.slack_webhook: - raise ValueError( - f"No Slack webhook URL configured for user {notification.user_id}. " - f"Please configure Slack integration in notification settings." 
- ) + raise ValueError(f"No Slack webhook URL configured for user {notification.user_id}") - # Format message for Slack - slack_message: SlackMessage = { + slack_message: dict[str, object] = { "text": notification.subject, "attachments": [ { @@ -502,20 +440,12 @@ async def _send_slack(self, notification: DomainNotification, subscription: Doma ], } - # Add action button if URL provided if notification.action_url: attachments = slack_message.get("attachments", []) if attachments and isinstance(attachments, list): - attachments[0]["actions"] = [{"type": "button", "text": "View Details", "url": notification.action_url}] - - self.logger.debug( - "Sending Slack notification", - extra={ - "notification_id": str(notification.notification_id), - "has_action": notification.action_url is not None, - "priority_color": self._get_slack_color(notification.severity), - }, - ) + attachments[0]["actions"] = [ + {"type": "button", "text": "View Details", "url": notification.action_url} + ] add_span_attributes( **{ @@ -526,172 +456,170 @@ async def _send_slack(self, notification: DomainNotification, subscription: Doma async with httpx.AsyncClient() as client: response = await client.post(subscription.slack_webhook, json=slack_message, timeout=30.0) response.raise_for_status() - self.logger.debug( - "Slack notification delivered successfully", - extra={"notification_id": str(notification.notification_id), "status_code": response.status_code}, - ) def _get_slack_color(self, priority: NotificationSeverity) -> str: """Get Slack color based on severity.""" return { - NotificationSeverity.LOW: "#36a64f", # Green - NotificationSeverity.MEDIUM: "#ff9900", # Orange - NotificationSeverity.HIGH: "#ff0000", # Red - NotificationSeverity.URGENT: "#990000", # Dark Red - }.get(priority, "#808080") # Default gray - - async def _process_pending_notifications(self) -> None: - """Process pending notifications in background.""" - while self.is_running: - try: - # Find pending notifications - notifications = await self.repository.find_pending_notifications( - batch_size=self.settings.NOTIF_PENDING_BATCH_SIZE - ) + NotificationSeverity.LOW: "#36a64f", + NotificationSeverity.MEDIUM: "#ff9900", + NotificationSeverity.HIGH: "#ff0000", + NotificationSeverity.URGENT: "#990000", + }.get(priority, "#808080") - # Process each notification - for notification in notifications: - if not self.is_running: - break - await self._deliver_notification(notification) - - # Sleep between batches - await asyncio.sleep(5) - - except Exception as e: - self.logger.error(f"Error processing pending notifications: {e}") - await asyncio.sleep(10) - - async def _cleanup_old_notifications(self) -> None: - """Cleanup old notifications periodically.""" - while self.is_running: - try: - # Run cleanup once per day - await asyncio.sleep(86400) # 24 hours - - if not self.is_running: - break - - # Delete old notifications - deleted_count = await self.repository.cleanup_old_notifications(self.settings.NOTIF_OLD_DAYS) - - self.logger.info(f"Cleaned up {deleted_count} old notifications") - - except Exception as e: - self.logger.error(f"Error cleaning up old notifications: {e}") - - async def _run_consumer(self) -> None: - """Run the event consumer loop.""" - while self.is_running: - try: - # Consumer handles polling internally - await asyncio.sleep(1) - except asyncio.CancelledError: - self.logger.info("Notification consumer task cancelled") - break - except Exception as e: - self.logger.error(f"Error in notification consumer loop: {e}") - await asyncio.sleep(5) - - 
async def _handle_execution_timeout_typed(self, event: ExecutionTimeoutEvent) -> None: - """Handle typed execution timeout event.""" - user_id = event.metadata.user_id - if not user_id: - self.logger.error("No user_id in event metadata") - return + async def process_pending_notifications(self, batch_size: int = 10) -> int: + """Process pending notifications. - title = f"Execution Timeout: {event.execution_id}" - body = f"Your execution timed out after {event.timeout_seconds}s." - await self.create_notification( - user_id=user_id, - subject=title, - body=body, - severity=NotificationSeverity.HIGH, - tags=["execution", "timeout", ENTITY_EXECUTION_TAG, f"exec:{event.execution_id}"], - metadata=event.model_dump( - exclude={"metadata", "event_type", "event_version", "timestamp", "aggregate_id", "topic"} - ), - ) + Should be called periodically from worker entrypoint. + Returns number of notifications processed. + """ + notifications = await self._repository.find_pending_notifications(batch_size=batch_size) + count = 0 - async def _handle_execution_completed_typed(self, event: ExecutionCompletedEvent) -> None: - """Handle typed execution completed event.""" - user_id = event.metadata.user_id - if not user_id: - self.logger.error("No user_id in event metadata") + for notification in notifications: + await self._deliver_notification(notification) + count += 1 + + return count + + async def cleanup_old_notifications(self, days: int = 30) -> int: + """Cleanup old notifications. + + Should be called periodically from worker entrypoint. + Returns number of notifications deleted. + """ + return await self._repository.cleanup_old_notifications(days) + + async def _should_skip_notification( + self, + notification: DomainNotification, + subscription: DomainNotificationSubscription, + ) -> str | None: + """Check if notification should be skipped based on subscription filters.""" + if not subscription.enabled: + return f"User {notification.user_id} has {notification.channel} disabled" + + if subscription.severities and notification.severity not in subscription.severities: + return f"Notification severity '{notification.severity}' filtered by user preferences" + + if subscription.include_tags and not any( + tag in subscription.include_tags for tag in (notification.tags or []) + ): + return f"Notification tags {notification.tags} not in include list" + + if subscription.exclude_tags and any( + tag in subscription.exclude_tags for tag in (notification.tags or []) + ): + return f"Notification tags {notification.tags} excluded by preferences" + + return None + + async def _deliver_notification(self, notification: DomainNotification) -> None: + """Deliver notification through configured channel.""" + claimed = await self._repository.try_claim_pending(notification.notification_id) + if not claimed: return - title = f"Execution Completed: {event.execution_id}" - duration = event.resource_usage.execution_time_wall_seconds if event.resource_usage else 0.0 - body = f"Your execution completed successfully. Duration: {duration:.2f}s." 
- await self.create_notification( - user_id=user_id, - subject=title, - body=body, - severity=NotificationSeverity.MEDIUM, - tags=["execution", "completed", ENTITY_EXECUTION_TAG, f"exec:{event.execution_id}"], - metadata=event.model_dump( - exclude={"metadata", "event_type", "event_version", "timestamp", "aggregate_id", "topic"} - ), + self._logger.info( + f"Delivering notification {notification.notification_id}", + extra={ + "notification_id": str(notification.notification_id), + "user_id": notification.user_id, + "channel": notification.channel, + "severity": notification.severity, + "tags": list(notification.tags or []), + }, ) - async def _handle_execution_event(self, event: DomainEvent) -> None: - """Unified handler for execution result events.""" - try: - if isinstance(event, ExecutionCompletedEvent): - await self._handle_execution_completed_typed(event) - elif isinstance(event, ExecutionFailedEvent): - await self._handle_execution_failed_typed(event) - elif isinstance(event, ExecutionTimeoutEvent): - await self._handle_execution_timeout_typed(event) - else: - self.logger.warning(f"Unhandled execution event type: {event.event_type}") - except Exception as e: - self.logger.error(f"Error handling execution event: {e}", exc_info=True) + subscription = await self._repository.get_subscription(notification.user_id, notification.channel) - async def _handle_execution_failed_typed(self, event: ExecutionFailedEvent) -> None: - """Handle typed execution failed event.""" - user_id = event.metadata.user_id - if not user_id: - self.logger.error("No user_id in event metadata") + skip_reason = await self._should_skip_notification(notification, subscription) + if skip_reason: + self._logger.info(skip_reason) + await self._repository.update_notification( + notification.notification_id, + notification.user_id, + DomainNotificationUpdate(status=NotificationStatus.SKIPPED, error_message=skip_reason), + ) return - # Use model_dump to get all event data - event_data = event.model_dump( - exclude={"metadata", "event_type", "event_version", "timestamp", "aggregate_id", "topic"} - ) + start_time = asyncio.get_running_loop().time() + try: + handler = self._channel_handlers.get(notification.channel) + if handler is None: + raise ValueError(f"No handler configured for channel: {notification.channel}") - # Truncate stdout/stderr for notification context - event_data["stdout"] = event_data["stdout"][:200] - event_data["stderr"] = event_data["stderr"][:200] + await handler(notification, subscription) # type: ignore + delivery_time = asyncio.get_running_loop().time() - start_time - title = f"Execution Failed: {event.execution_id}" - body = f"Your execution failed: {event.error_message}" - await self.create_notification( - user_id=user_id, - subject=title, - body=body, - severity=NotificationSeverity.HIGH, - tags=["execution", "failed", ENTITY_EXECUTION_TAG, f"exec:{event.execution_id}"], - metadata=event_data, + await self._repository.update_notification( + notification.notification_id, + notification.user_id, + DomainNotificationUpdate(status=NotificationStatus.DELIVERED, delivered_at=datetime.now(UTC)), + ) + + self._logger.info( + f"Successfully delivered notification {notification.notification_id}", + extra={ + "notification_id": str(notification.notification_id), + "channel": notification.channel, + "delivery_time_ms": int(delivery_time * 1000), + }, + ) + + self._metrics.record_notification_sent( + notification.severity, channel=notification.channel, severity=notification.severity + ) + 
self._metrics.record_notification_delivery_time(delivery_time, notification.severity) + + except Exception as e: + self._logger.error( + f"Failed to deliver notification {notification.notification_id}: {str(e)}", + exc_info=True, + ) + + new_retry_count = notification.retry_count + 1 + error_message = f"Delivery failed via {notification.channel}: {str(e)}" + failed_at = datetime.now(UTC) + + notif_status = NotificationStatus.PENDING \ + if new_retry_count < notification.max_retries else NotificationStatus.FAILED + await self._repository.update_notification( + notification.notification_id, + notification.user_id, + DomainNotificationUpdate( + status=notif_status, + failed_at=failed_at, + error_message=error_message, + retry_count=new_retry_count, + ), + ) + + async def _publish_notification_sse(self, notification: DomainNotification) -> None: + """Publish an in-app notification to the SSE bus.""" + message = RedisNotificationMessage( + notification_id=notification.notification_id, + severity=notification.severity, + status=notification.status, + tags=list(notification.tags or []), + subject=notification.subject, + body=notification.body, + action_url=notification.action_url or "", + created_at=notification.created_at, ) + await self._sse_bus.publish_notification(notification.user_id, message) async def mark_as_read(self, user_id: str, notification_id: str) -> bool: """Mark notification as read.""" - success = await self.repository.mark_as_read(notification_id, user_id) + success = await self._repository.mark_as_read(notification_id, user_id) - event_bus = await self.event_bus_manager.get_event_bus() if success: - await event_bus.publish( - NotificationReadEvent( - notification_id=str(notification_id), - user_id=user_id, - read_at=datetime.now(UTC), - metadata=EventMetadata( - service_name=self.settings.SERVICE_NAME, - service_version=self.settings.SERVICE_VERSION, - user_id=user_id, - ), - ) + await self._event_bus.publish( + "notifications.read", + { + "notification_id": str(notification_id), + "user_id": user_id, + "read_at": datetime.now(UTC).isoformat(), + }, ) else: raise NotificationNotFoundError(notification_id) @@ -700,21 +628,20 @@ async def mark_as_read(self, user_id: str, notification_id: str) -> bool: async def get_unread_count(self, user_id: str) -> int: """Get count of unread notifications.""" - return await self.repository.get_unread_count(user_id) + return await self._repository.get_unread_count(user_id) async def list_notifications( - self, - user_id: str, - status: NotificationStatus | None = None, - limit: int = 20, - offset: int = 0, - include_tags: list[str] | None = None, - exclude_tags: list[str] | None = None, - tag_prefix: str | None = None, + self, + user_id: str, + status: NotificationStatus | None = None, + limit: int = 20, + offset: int = 0, + include_tags: list[str] | None = None, + exclude_tags: list[str] | None = None, + tag_prefix: str | None = None, ) -> DomainNotificationListResult: """List notifications with pagination.""" - # Get notifications - notifications = await self.repository.list_notifications( + notifications = await self._repository.list_notifications( user_id=user_id, status=status, skip=offset, @@ -724,9 +651,8 @@ async def list_notifications( tag_prefix=tag_prefix, ) - # Get counts total, unread_count = await asyncio.gather( - self.repository.count_notifications( + self._repository.count_notifications( user_id=user_id, status=status, include_tags=include_tags, @@ -736,21 +662,24 @@ async def list_notifications( 
self.get_unread_count(user_id), ) - return DomainNotificationListResult(notifications=notifications, total=total, unread_count=unread_count) + return DomainNotificationListResult( + notifications=notifications, + total=total, + unread_count=unread_count, + ) async def update_subscription( - self, - user_id: str, - channel: NotificationChannel, - enabled: bool, - webhook_url: str | None = None, - slack_webhook: str | None = None, - severities: list[NotificationSeverity] | None = None, - include_tags: list[str] | None = None, - exclude_tags: list[str] | None = None, + self, + user_id: str, + channel: NotificationChannel, + enabled: bool, + webhook_url: str | None = None, + slack_webhook: str | None = None, + severities: list[NotificationSeverity] | None = None, + include_tags: list[str] | None = None, + exclude_tags: list[str] | None = None, ) -> DomainNotificationSubscription: """Update notification subscription preferences.""" - # Validate channel-specific requirements if channel == NotificationChannel.WEBHOOK and enabled: if not webhook_url: raise NotificationValidationError("webhook_url is required when enabling WEBHOOK") @@ -762,7 +691,6 @@ async def update_subscription( if not slack_webhook.startswith("https://hooks.slack.com/"): raise NotificationValidationError("slack_webhook must be a valid Slack webhook URL") - # Build update data update_data = DomainSubscriptionUpdate( enabled=enabled, webhook_url=webhook_url, @@ -772,193 +700,27 @@ async def update_subscription( exclude_tags=exclude_tags, ) - return await self.repository.upsert_subscription(user_id, channel, update_data) + return await self._repository.upsert_subscription(user_id, channel, update_data) async def mark_all_as_read(self, user_id: str) -> int: """Mark all notifications as read for a user.""" - count = await self.repository.mark_all_as_read(user_id) + count = await self._repository.mark_all_as_read(user_id) - event_bus = await self.event_bus_manager.get_event_bus() if count > 0: - await event_bus.publish( - NotificationAllReadEvent( - user_id=user_id, - count=count, - read_at=datetime.now(UTC), - metadata=EventMetadata( - service_name=self.settings.SERVICE_NAME, - service_version=self.settings.SERVICE_VERSION, - user_id=user_id, - ), - ) + await self._event_bus.publish( + "notifications.all_read", + {"user_id": user_id, "count": count, "read_at": datetime.now(UTC).isoformat()}, ) return count async def get_subscriptions(self, user_id: str) -> dict[NotificationChannel, DomainNotificationSubscription]: """Get all notification subscriptions for a user.""" - return await self.repository.get_all_subscriptions(user_id) + return await self._repository.get_all_subscriptions(user_id) async def delete_notification(self, user_id: str, notification_id: str) -> bool: """Delete a notification.""" - deleted = await self.repository.delete_notification(str(notification_id), user_id) + deleted = await self._repository.delete_notification(str(notification_id), user_id) if not deleted: raise NotificationNotFoundError(notification_id) return deleted - - async def _publish_notification_sse(self, notification: DomainNotification) -> None: - """Publish an in-app notification to the SSE bus for realtime delivery.""" - message = RedisNotificationMessage( - notification_id=notification.notification_id, - severity=notification.severity, - status=notification.status, - tags=list(notification.tags or []), - subject=notification.subject, - body=notification.body, - action_url=notification.action_url or "", - created_at=notification.created_at, - ) 
- await self.sse_bus.publish_notification(notification.user_id, message) - - async def _should_skip_notification( - self, notification: DomainNotification, subscription: DomainNotificationSubscription - ) -> str | None: - """Check if notification should be skipped based on subscription filters. - - Returns skip reason if should skip, None otherwise. - """ - if not subscription.enabled: - return f"User {notification.user_id} has {notification.channel} disabled; skipping delivery." - - if subscription.severities and notification.severity not in subscription.severities: - return ( - f"Notification severity '{notification.severity}' filtered by user preferences " - f"for {notification.channel}" - ) - - if subscription.include_tags and not any(tag in subscription.include_tags for tag in (notification.tags or [])): - return f"Notification tags {notification.tags} not in include list for {notification.channel}" - - if subscription.exclude_tags and any(tag in subscription.exclude_tags for tag in (notification.tags or [])): - return f"Notification tags {notification.tags} excluded by preferences for {notification.channel}" - - return None - - async def _deliver_notification(self, notification: DomainNotification) -> None: - """Deliver notification through configured channel using safe state transitions.""" - # Attempt to claim this notification for sending - claimed = await self.repository.try_claim_pending(notification.notification_id) - if not claimed: - return - - self.logger.info( - f"Delivering notification {notification.notification_id}", - extra={ - "notification_id": str(notification.notification_id), - "user_id": notification.user_id, - "channel": notification.channel, - "severity": notification.severity, - "tags": list(notification.tags or []), - }, - ) - - # Check user subscription for the channel - subscription = await self.repository.get_subscription(notification.user_id, notification.channel) - - # Check if notification should be skipped - skip_reason = await self._should_skip_notification(notification, subscription) - if skip_reason: - self.logger.info(skip_reason) - await self.repository.update_notification( - notification.notification_id, - notification.user_id, - DomainNotificationUpdate(status=NotificationStatus.SKIPPED, error_message=skip_reason), - ) - return - - # Send through channel - start_time = asyncio.get_running_loop().time() - try: - handler = self._channel_handlers.get(notification.channel) - if handler is None: - raise ValueError( - f"No handler configured for notification channel: {notification.channel}. 
" - f"Available channels: {list(self._channel_handlers.keys())}" - ) - - self.logger.debug(f"Using handler {handler.__name__} for channel {notification.channel}") - await handler(notification, subscription) - delivery_time = asyncio.get_running_loop().time() - start_time - - # Mark delivered - await self.repository.update_notification( - notification.notification_id, - notification.user_id, - DomainNotificationUpdate(status=NotificationStatus.DELIVERED, delivered_at=datetime.now(UTC)), - ) - - self.logger.info( - f"Successfully delivered notification {notification.notification_id}", - extra={ - "notification_id": str(notification.notification_id), - "channel": notification.channel, - "delivery_time_ms": int(delivery_time * 1000), - }, - ) - - # Metrics (use tag string or severity) - self.metrics.record_notification_sent( - notification.severity, channel=notification.channel, severity=notification.severity - ) - self.metrics.record_notification_delivery_time(delivery_time, notification.severity) - - except Exception as e: - error_details = { - "notification_id": str(notification.notification_id), - "channel": notification.channel, - "error_type": type(e).__name__, - "error_message": str(e), - "retry_count": notification.retry_count, - "max_retries": notification.max_retries, - } - - self.logger.error( - f"Failed to deliver notification {notification.notification_id}: {str(e)}", - extra=error_details, - exc_info=True, - ) - - new_retry_count = notification.retry_count + 1 - error_message = f"Delivery failed via {notification.channel}: {str(e)}" - failed_at = datetime.now(UTC) - - # Schedule retry if under limit - if new_retry_count < notification.max_retries: - retry_time = datetime.now(UTC) + timedelta(minutes=self.settings.NOTIF_RETRY_DELAY_MINUTES) - self.logger.info( - f"Scheduled retry {new_retry_count}/{notification.max_retries} for {notification.notification_id}", - extra={"retry_at": retry_time.isoformat()}, - ) - # Will be retried - keep as PENDING but with scheduled_for - # Note: scheduled_for not in DomainNotificationUpdate, so we update status fields only - await self.repository.update_notification( - notification.notification_id, - notification.user_id, - DomainNotificationUpdate( - status=NotificationStatus.PENDING, - failed_at=failed_at, - error_message=error_message, - retry_count=new_retry_count, - ), - ) - else: - await self.repository.update_notification( - notification.notification_id, - notification.user_id, - DomainNotificationUpdate( - status=NotificationStatus.FAILED, - failed_at=failed_at, - error_message=error_message, - retry_count=new_retry_count, - ), - ) diff --git a/backend/app/services/pod_monitor/monitor.py b/backend/app/services/pod_monitor/monitor.py index ecbb4556..046cbee9 100644 --- a/backend/app/services/pod_monitor/monitor.py +++ b/backend/app/services/pod_monitor/monitor.py @@ -1,35 +1,28 @@ +"""Pod Monitor - stateless event handler. + +Monitors Kubernetes pods and publishes lifecycle events. Receives events, +processes them, and publishes results. No lifecycle management. +All state is stored in Redis via PodStateRepository. 
+""" + +from __future__ import annotations + import asyncio import logging import time -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager from dataclasses import dataclass -from enum import auto from typing import Any from kubernetes import client as k8s_client -from kubernetes.client.rest import ApiException -from app.core.k8s_clients import K8sClients, close_k8s_clients, create_k8s_clients -from app.core.lifecycle import LifecycleEnabled from app.core.metrics import KubernetesMetrics from app.core.utils import StringEnum +from app.db.repositories.pod_state_repository import PodStateRepository from app.domain.events.typed import DomainEvent -from app.services.kafka_event_service import KafkaEventService +from app.events.core import UnifiedProducer from app.services.pod_monitor.config import PodMonitorConfig from app.services.pod_monitor.event_mapper import PodEventMapper -# Type aliases -type PodName = str -type ResourceVersion = str -type EventType = str -type KubeEvent = dict[str, Any] -type StatusDict = dict[str, Any] - -# Constants -MAX_BACKOFF_SECONDS: int = 300 # 5 minutes -RECONCILIATION_LOG_INTERVAL: int = 60 # 1 minute - class WatchEventType(StringEnum): """Kubernetes watch event types.""" @@ -39,33 +32,13 @@ class WatchEventType(StringEnum): DELETED = "DELETED" -class MonitorState(StringEnum): - """Pod monitor states.""" - - IDLE = auto() - RUNNING = auto() - STOPPING = auto() - STOPPED = auto() - - class ErrorType(StringEnum): """Error types for metrics.""" - RESOURCE_VERSION_EXPIRED = auto() - API_ERROR = auto() - UNEXPECTED = auto() - PROCESSING_ERROR = auto() - - -@dataclass(frozen=True, slots=True) -class WatchContext: - """Immutable context for watch operations.""" - - namespace: str - label_selector: str - field_selector: str | None - timeout_seconds: int - resource_version: ResourceVersion | None + RESOURCE_VERSION_EXPIRED = "resource_version_expired" + API_ERROR = "api_error" + UNEXPECTED = "unexpected" + PROCESSING_ERROR = "processing_error" @dataclass(frozen=True, slots=True) @@ -74,206 +47,70 @@ class PodEvent: event_type: WatchEventType pod: k8s_client.V1Pod - resource_version: ResourceVersion | None + resource_version: str | None @dataclass(frozen=True, slots=True) class ReconciliationResult: """Result of state reconciliation.""" - missing_pods: set[PodName] - extra_pods: set[PodName] + missing_pods: set[str] + extra_pods: set[str] duration_seconds: float success: bool error: str | None = None -class PodMonitor(LifecycleEnabled): - """ - Monitors Kubernetes pods and publishes lifecycle events. +class PodMonitor: + """Stateless pod monitor - pure event handler. + + No lifecycle methods (start/stop) - receives ready-to-use dependencies from DI. + All state (tracked pods, resource version) stored in Redis via PodStateRepository. - This service watches pods with specific labels using the K8s watch API, - maps Kubernetes events to application events, and publishes them to Kafka. - Events are stored in the events collection AND published to Kafka via KafkaEventService. 
+ Worker entrypoint handles the watch loop: + watch = Watch() + for event in watch.stream(...): + await monitor.handle_raw_event(event) """ def __init__( self, config: PodMonitorConfig, - kafka_event_service: KafkaEventService, - logger: logging.Logger, - k8s_clients: K8sClients, + producer: UnifiedProducer, + pod_state_repo: PodStateRepository, + v1_client: k8s_client.CoreV1Api, event_mapper: PodEventMapper, + logger: logging.Logger, kubernetes_metrics: KubernetesMetrics, ) -> None: - """Initialize the pod monitor with all required dependencies. - - All dependencies must be provided - use create_pod_monitor() factory - for automatic dependency creation in production. - """ - super().__init__() - self.logger = logger - self.config = config - - # Kubernetes clients (required, no nullability) - self._clients = k8s_clients - self._v1 = k8s_clients.v1 - self._watch = k8s_clients.watch - - # Components (required, no nullability) + self._config = config + self._producer = producer + self._pod_state_repo = pod_state_repo + self._v1 = v1_client self._event_mapper = event_mapper - self._kafka_event_service = kafka_event_service - - # State - self._state = MonitorState.IDLE - self._tracked_pods: set[PodName] = set() - self._reconnect_attempts: int = 0 - self._last_resource_version: ResourceVersion | None = None - - # Tasks - self._watch_task: asyncio.Task[None] | None = None - self._reconcile_task: asyncio.Task[None] | None = None - - # Metrics + self._logger = logger self._metrics = kubernetes_metrics - @property - def state(self) -> MonitorState: - """Get current monitor state.""" - return self._state - - async def _on_start(self) -> None: - """Start the pod monitor.""" - self.logger.info("Starting PodMonitor service...") - - # Verify K8s connectivity (all clients already injected via __init__) - await asyncio.to_thread(self._v1.get_api_resources) - self.logger.info("Successfully connected to Kubernetes API") - - # Start monitoring - self._state = MonitorState.RUNNING - self._watch_task = asyncio.create_task(self._watch_pods()) - - # Start reconciliation if enabled - if self.config.enable_state_reconciliation: - self._reconcile_task = asyncio.create_task(self._reconciliation_loop()) - - self.logger.info("PodMonitor service started successfully") - - async def _on_stop(self) -> None: - """Stop the pod monitor.""" - self.logger.info("Stopping PodMonitor service...") - self._state = MonitorState.STOPPING - - # Cancel tasks - tasks = [t for t in [self._watch_task, self._reconcile_task] if t] - for task in tasks: - task.cancel() - - # Wait for cancellation - if tasks: - await asyncio.gather(*tasks, return_exceptions=True) - - # Close watch - if self._watch: - self._watch.stop() - - # Clear state - self._tracked_pods.clear() - self._event_mapper.clear_cache() - - self._state = MonitorState.STOPPED - self.logger.info("PodMonitor service stopped") - - async def _watch_pods(self) -> None: - """Main watch loop for pods.""" - while self._state == MonitorState.RUNNING: - try: - self._reconnect_attempts = 0 - await self._watch_pod_events() - - except ApiException as e: - match e.status: - case 410: # Gone - resource version too old - self.logger.warning("Resource version expired, resetting watch") - self._last_resource_version = None - self._metrics.record_pod_monitor_watch_error(ErrorType.RESOURCE_VERSION_EXPIRED) - case _: - self.logger.error(f"API error in watch: {e}") - self._metrics.record_pod_monitor_watch_error(ErrorType.API_ERROR) - - await self._handle_watch_error() - - except Exception as e: - 
self.logger.error(f"Unexpected error in watch: {e}", exc_info=True) - self._metrics.record_pod_monitor_watch_error(ErrorType.UNEXPECTED) - await self._handle_watch_error() - - async def _watch_pod_events(self) -> None: - """Watch for pod events.""" - # self._v1 and self._watch are guaranteed initialized by start() - - context = WatchContext( - namespace=self.config.namespace, - label_selector=self.config.label_selector, - field_selector=self.config.field_selector, - timeout_seconds=self.config.watch_timeout_seconds, - resource_version=self._last_resource_version, - ) - - self.logger.info(f"Starting pod watch with selector: {context.label_selector}, namespace: {context.namespace}") - - # Create watch stream - kwargs = { - "namespace": context.namespace, - "label_selector": context.label_selector, - "timeout_seconds": context.timeout_seconds, - } - - if context.field_selector: - kwargs["field_selector"] = context.field_selector - - if context.resource_version: - kwargs["resource_version"] = context.resource_version - - # Watch stream (clients guaranteed by __init__) - stream = self._watch.stream(self._v1.list_namespaced_pod, **kwargs) + async def handle_raw_event(self, raw_event: dict[str, Any]) -> None: + """Process a raw Kubernetes watch event. + Called by worker entrypoint for each event from watch stream. + """ try: - for event in stream: - if self._state != MonitorState.RUNNING: - break - - await self._process_raw_event(event) - - finally: - # Store resource version for next watch - self._update_resource_version(stream) - - def _update_resource_version(self, stream: Any) -> None: - """Update last resource version from stream.""" - try: - if stream._stop_event and stream._stop_event.resource_version: - self._last_resource_version = stream._stop_event.resource_version - except AttributeError: - pass - - async def _process_raw_event(self, raw_event: KubeEvent) -> None: - """Process a raw Kubernetes watch event.""" - try: - # Parse event event = PodEvent( event_type=WatchEventType(raw_event["type"].upper()), pod=raw_event["object"], resource_version=( - raw_event["object"].metadata.resource_version if raw_event["object"].metadata else None + raw_event["object"].metadata.resource_version + if raw_event["object"].metadata + else None ), ) await self._process_pod_event(event) except (KeyError, ValueError) as e: - self.logger.error(f"Invalid event format: {e}") + self._logger.error(f"Invalid event format: {e}") self._metrics.record_pod_monitor_watch_error(ErrorType.PROCESSING_ERROR) async def _process_pod_event(self, event: PodEvent) -> None: @@ -281,25 +118,38 @@ async def _process_pod_event(self, event: PodEvent) -> None: start_time = time.time() try: - # Update resource version + # Update resource version in Redis if event.resource_version: - self._last_resource_version = event.resource_version + await self._pod_state_repo.set_resource_version(event.resource_version) # Skip ignored phases pod_phase = event.pod.status.phase if event.pod.status else None - if pod_phase in self.config.ignored_pod_phases: + if pod_phase in self._config.ignored_pod_phases: return - # Update tracked pods + # Get pod info pod_name = event.pod.metadata.name + execution_id = ( + event.pod.metadata.labels.get("execution-id") + if event.pod.metadata and event.pod.metadata.labels + else None + ) + + # Update tracked pods in Redis match event.event_type: case WatchEventType.ADDED | WatchEventType.MODIFIED: - self._tracked_pods.add(pod_name) + if execution_id: + await self._pod_state_repo.track_pod( + pod_name=pod_name, + 
execution_id=execution_id, + status=pod_phase or "Unknown", + ) case WatchEventType.DELETED: - self._tracked_pods.discard(pod_name) + await self._pod_state_repo.untrack_pod(pod_name) # Update metrics - self._metrics.update_pod_monitor_pods_watched(len(self._tracked_pods)) + tracked_count = await self._pod_state_repo.get_tracked_pods_count() + self._metrics.update_pod_monitor_pods_watched(tracked_count) # Map to application events app_events = self._event_mapper.map_pod_event(event.pod, event.event_type) @@ -310,7 +160,7 @@ async def _process_pod_event(self, event: PodEvent) -> None: # Log event if app_events: - self.logger.info( + self._logger.info( f"Processed {event.event_type} event for pod {pod_name} " f"(phase: {pod_phase or 'Unknown'}), " f"published {len(app_events)} events" @@ -321,11 +171,11 @@ async def _process_pod_event(self, event: PodEvent) -> None: self._metrics.record_pod_monitor_event_processing_duration(duration, event.event_type) except Exception as e: - self.logger.error(f"Error processing pod event: {e}", exc_info=True) + self._logger.error(f"Error processing pod event: {e}", exc_info=True) self._metrics.record_pod_monitor_watch_error(ErrorType.PROCESSING_ERROR) async def _publish_event(self, event: DomainEvent, pod: k8s_client.V1Pod) -> None: - """Publish event to Kafka and store in events collection.""" + """Publish event to Kafka.""" try: # Add correlation ID from pod labels if pod.metadata and pod.metadata.labels: @@ -334,94 +184,74 @@ async def _publish_event(self, event: DomainEvent, pod: k8s_client.V1Pod) -> Non execution_id = getattr(event, "execution_id", None) or event.aggregate_id key = str(execution_id or (pod.metadata.name if pod.metadata else "unknown")) - await self._kafka_event_service.publish_domain_event(event=event, key=key) + await self._producer.produce(event_to_produce=event, key=key) phase = pod.status.phase if pod.status else "Unknown" self._metrics.record_pod_monitor_event_published(event.event_type, phase) except Exception as e: - self.logger.error(f"Error publishing event: {e}", exc_info=True) - - async def _handle_watch_error(self) -> None: - """Handle watch errors with exponential backoff.""" - self._reconnect_attempts += 1 - - if self._reconnect_attempts > self.config.max_reconnect_attempts: - self.logger.error( - f"Max reconnect attempts ({self.config.max_reconnect_attempts}) exceeded, stopping pod monitor" - ) - self._state = MonitorState.STOPPING - return - - # Calculate exponential backoff - backoff = min(self.config.watch_reconnect_delay * (2 ** (self._reconnect_attempts - 1)), MAX_BACKOFF_SECONDS) - - self.logger.info( - f"Reconnecting watch in {backoff}s " - f"(attempt {self._reconnect_attempts}/{self.config.max_reconnect_attempts})" - ) - - self._metrics.increment_pod_monitor_watch_reconnects() - await asyncio.sleep(backoff) - - async def _reconciliation_loop(self) -> None: - """Periodically reconcile state with Kubernetes.""" - while self._state == MonitorState.RUNNING: - try: - await asyncio.sleep(self.config.reconcile_interval_seconds) - - if self._state == MonitorState.RUNNING: - result = await self._reconcile_state() - self._log_reconciliation_result(result) + self._logger.error(f"Error publishing event: {e}", exc_info=True) - except Exception as e: - self.logger.error(f"Error in reconciliation loop: {e}", exc_info=True) + async def reconcile_state(self) -> ReconciliationResult: + """Reconcile tracked pods with actual Kubernetes state. 
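Both handle_raw_event and the reconcile_state method introduced just below are driven from the worker entrypoint. A hedged sketch of that loop, built only from the config fields used in this module; the real workers/run_pod_monitor.py changed by this patch may differ, and in practice the blocking watch generator would likely be offloaded with asyncio.to_thread:

    import asyncio

    from kubernetes import watch

    async def reconcile_periodically(monitor, interval_seconds: int) -> None:
        while True:
            await asyncio.sleep(interval_seconds)
            await monitor.reconcile_state()

    async def run_monitor(monitor, v1, config) -> None:
        if config.enable_state_reconciliation:
            # Keep a reference so the task is not garbage collected.
            reconcile_task = asyncio.create_task(
                reconcile_periodically(monitor, config.reconcile_interval_seconds)
            )

        w = watch.Watch()
        while True:
            # Blocking generator consumed inline for brevity.
            for raw_event in w.stream(
                v1.list_namespaced_pod,
                namespace=config.namespace,
                label_selector=config.label_selector,
                timeout_seconds=config.watch_timeout_seconds,
            ):
                await monitor.handle_raw_event(raw_event)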
- async def _reconcile_state(self) -> ReconciliationResult: - """Reconcile tracked pods with actual state.""" + Should be called periodically from worker entrypoint if reconciliation + is enabled in config. + """ start_time = time.time() try: - self.logger.info("Starting pod state reconciliation") + self._logger.info("Starting pod state reconciliation") - # List all pods matching selector (clients guaranteed by __init__) + # List all pods matching selector pods = await asyncio.to_thread( - self._v1.list_namespaced_pod, namespace=self.config.namespace, label_selector=self.config.label_selector + self._v1.list_namespaced_pod, + namespace=self._config.namespace, + label_selector=self._config.label_selector, ) - # Get current pod names + # Get current pod names from K8s current_pods = {pod.metadata.name for pod in pods.items} + # Get tracked pods from Redis + tracked_pods = await self._pod_state_repo.get_tracked_pod_names() + # Find differences - missing_pods = current_pods - self._tracked_pods - extra_pods = self._tracked_pods - current_pods + missing_pods = current_pods - tracked_pods + extra_pods = tracked_pods - current_pods - # Process missing pods + # Process missing pods (add them to tracking) for pod in pods.items: if pod.metadata.name in missing_pods: - self.logger.info(f"Reconciling missing pod: {pod.metadata.name}") + self._logger.info(f"Reconciling missing pod: {pod.metadata.name}") event = PodEvent( - event_type=WatchEventType.ADDED, pod=pod, resource_version=pod.metadata.resource_version + event_type=WatchEventType.ADDED, + pod=pod, + resource_version=pod.metadata.resource_version, ) await self._process_pod_event(event) - # Remove extra pods + # Remove stale pods from Redis for pod_name in extra_pods: - self.logger.info(f"Removing stale pod from tracking: {pod_name}") - self._tracked_pods.discard(pod_name) + self._logger.info(f"Removing stale pod from tracking: {pod_name}") + await self._pod_state_repo.untrack_pod(pod_name) # Update metrics - self._metrics.update_pod_monitor_pods_watched(len(self._tracked_pods)) + tracked_count = await self._pod_state_repo.get_tracked_pods_count() + self._metrics.update_pod_monitor_pods_watched(tracked_count) self._metrics.record_pod_monitor_reconciliation_run("success") duration = time.time() - start_time return ReconciliationResult( - missing_pods=missing_pods, extra_pods=extra_pods, duration_seconds=duration, success=True + missing_pods=missing_pods, + extra_pods=extra_pods, + duration_seconds=duration, + success=True, ) except Exception as e: - self.logger.error(f"Failed to reconcile state: {e}", exc_info=True) + self._logger.error(f"Failed to reconcile state: {e}", exc_info=True) self._metrics.record_pod_monitor_reconciliation_run("failed") return ReconciliationResult( @@ -431,74 +261,3 @@ async def _reconcile_state(self) -> ReconciliationResult: success=False, error=str(e), ) - - def _log_reconciliation_result(self, result: ReconciliationResult) -> None: - """Log reconciliation result.""" - if result.success: - self.logger.info( - f"Reconciliation completed in {result.duration_seconds:.2f}s. 
" - f"Found {len(result.missing_pods)} missing, " - f"{len(result.extra_pods)} extra pods" - ) - else: - self.logger.error(f"Reconciliation failed after {result.duration_seconds:.2f}s: {result.error}") - - async def get_status(self) -> StatusDict: - """Get monitor status.""" - return { - "state": self._state, - "tracked_pods": len(self._tracked_pods), - "reconnect_attempts": self._reconnect_attempts, - "last_resource_version": self._last_resource_version, - "config": { - "namespace": self.config.namespace, - "label_selector": self.config.label_selector, - "enable_reconciliation": self.config.enable_state_reconciliation, - }, - } - - -@asynccontextmanager -async def create_pod_monitor( - config: PodMonitorConfig, - kafka_event_service: KafkaEventService, - logger: logging.Logger, - kubernetes_metrics: KubernetesMetrics, - k8s_clients: K8sClients | None = None, - event_mapper: PodEventMapper | None = None, -) -> AsyncIterator[PodMonitor]: - """Create and manage a pod monitor instance. - - This factory handles production dependency creation: - - Creates K8sClients if not provided (using config settings) - - Creates PodEventMapper if not provided - - Cleans up created K8sClients on exit - """ - # Track whether we created clients (so we know to close them) - owns_clients = k8s_clients is None - - if k8s_clients is None: - k8s_clients = create_k8s_clients( - logger=logger, - kubeconfig_path=config.kubeconfig_path, - in_cluster=config.in_cluster, - ) - - if event_mapper is None: - event_mapper = PodEventMapper(logger=logger, k8s_api=k8s_clients.v1) - - monitor = PodMonitor( - config=config, - kafka_event_service=kafka_event_service, - logger=logger, - k8s_clients=k8s_clients, - event_mapper=event_mapper, - kubernetes_metrics=kubernetes_metrics, - ) - - try: - async with monitor: - yield monitor - finally: - if owns_clients: - close_k8s_clients(k8s_clients) diff --git a/backend/app/services/result_processor/processor.py b/backend/app/services/result_processor/processor.py index 464584c7..c7d5f1c7 100644 --- a/backend/app/services/result_processor/processor.py +++ b/backend/app/services/result_processor/processor.py @@ -1,19 +1,21 @@ +"""Result Processor - stateless event handler. + +Processes execution completion events and stores results. +Receives events, processes them, and publishes results. No lifecycle management. 
+""" + +from __future__ import annotations + import logging -from enum import auto -from typing import Any from pydantic import BaseModel, ConfigDict, Field -from app.core.lifecycle import LifecycleEnabled from app.core.metrics import EventMetrics, ExecutionMetrics -from app.core.utils import StringEnum from app.db.repositories.execution_repository import ExecutionRepository -from app.domain.enums.events import EventType from app.domain.enums.execution import ExecutionStatus from app.domain.enums.kafka import CONSUMER_GROUP_SUBSCRIPTIONS, GroupId, KafkaTopic from app.domain.enums.storage import ExecutionErrorType, StorageType from app.domain.events.typed import ( - DomainEvent, EventMetadata, ExecutionCompletedEvent, ExecutionFailedEvent, @@ -22,22 +24,10 @@ ResultStoredEvent, ) from app.domain.execution import ExecutionNotFoundError, ExecutionResultDomain -from app.domain.idempotency import KeyStrategy -from app.events.core import ConsumerConfig, EventDispatcher, UnifiedConsumer, UnifiedProducer -from app.events.schema.schema_registry import SchemaRegistryManager -from app.services.idempotency import IdempotencyManager -from app.services.idempotency.middleware import IdempotentConsumerWrapper +from app.events.core import UnifiedProducer from app.settings import Settings -class ProcessingState(StringEnum): - """Processing state enumeration.""" - - IDLE = auto() - PROCESSING = auto() - STOPPED = auto() - - class ResultProcessorConfig(BaseModel): """Configuration for result processor.""" @@ -52,126 +42,33 @@ class ResultProcessorConfig(BaseModel): processing_timeout: int = Field(default=300) -class ResultProcessor(LifecycleEnabled): - """Service for processing execution completion events and storing results.""" +class ResultProcessor: + """Stateless result processor - pure event handler. + + No lifecycle methods (start/stop) - receives ready-to-use dependencies from DI. + Worker entrypoint handles the consume loop. 
+ """ def __init__( self, execution_repo: ExecutionRepository, producer: UnifiedProducer, - schema_registry: SchemaRegistryManager, settings: Settings, - idempotency_manager: IdempotencyManager, logger: logging.Logger, execution_metrics: ExecutionMetrics, event_metrics: EventMetrics, + config: ResultProcessorConfig | None = None, ) -> None: - """Initialize the result processor.""" - super().__init__() - self.config = ResultProcessorConfig() self._execution_repo = execution_repo self._producer = producer - self._schema_registry = schema_registry self._settings = settings + self._logger = logger self._metrics = execution_metrics self._event_metrics = event_metrics - self._idempotency_manager: IdempotencyManager = idempotency_manager - self._state = ProcessingState.IDLE - self._consumer: IdempotentConsumerWrapper | None = None - self._dispatcher: EventDispatcher | None = None - self.logger = logger - - async def _on_start(self) -> None: - """Start the result processor.""" - self.logger.info("Starting ResultProcessor...") - - # Initialize idempotency manager (safe to call multiple times) - await self._idempotency_manager.initialize() - self.logger.info("Idempotency manager initialized for ResultProcessor") - - self._dispatcher = self._create_dispatcher() - self._consumer = await self._create_consumer() - self._state = ProcessingState.PROCESSING - self.logger.info("ResultProcessor started successfully with idempotency protection") - - async def _on_stop(self) -> None: - """Stop the result processor.""" - self.logger.info("Stopping ResultProcessor...") - self._state = ProcessingState.STOPPED - - if self._consumer: - await self._consumer.stop() - - await self._idempotency_manager.close() - # Note: producer is managed by DI container, not stopped here - self.logger.info("ResultProcessor stopped") - - def _create_dispatcher(self) -> EventDispatcher: - """Create and configure event dispatcher with handlers.""" - dispatcher = EventDispatcher(logger=self.logger) - - # Register handlers for specific event types - dispatcher.register_handler(EventType.EXECUTION_COMPLETED, self._handle_completed_wrapper) - dispatcher.register_handler(EventType.EXECUTION_FAILED, self._handle_failed_wrapper) - dispatcher.register_handler(EventType.EXECUTION_TIMEOUT, self._handle_timeout_wrapper) - - return dispatcher - - async def _create_consumer(self) -> IdempotentConsumerWrapper: - """Create and configure idempotent Kafka consumer.""" - consumer_config = ConsumerConfig( - bootstrap_servers=self._settings.KAFKA_BOOTSTRAP_SERVERS, - group_id=self.config.consumer_group, - max_poll_records=1, - enable_auto_commit=True, - auto_offset_reset="earliest", - session_timeout_ms=self._settings.KAFKA_SESSION_TIMEOUT_MS, - heartbeat_interval_ms=self._settings.KAFKA_HEARTBEAT_INTERVAL_MS, - max_poll_interval_ms=self._settings.KAFKA_MAX_POLL_INTERVAL_MS, - request_timeout_ms=self._settings.KAFKA_REQUEST_TIMEOUT_MS, - ) - - # Create consumer with schema registry and dispatcher - if not self._dispatcher: - raise RuntimeError("Event dispatcher not initialized") - - base_consumer = UnifiedConsumer( - consumer_config, - event_dispatcher=self._dispatcher, - schema_registry=self._schema_registry, - settings=self._settings, - logger=self.logger, - event_metrics=self._event_metrics, - ) - wrapper = IdempotentConsumerWrapper( - consumer=base_consumer, - idempotency_manager=self._idempotency_manager, - dispatcher=self._dispatcher, - logger=self.logger, - default_key_strategy=KeyStrategy.CONTENT_HASH, - default_ttl_seconds=7200, - 
enable_for_all_handlers=True, - ) - await wrapper.start(self.config.topics) - return wrapper - - # Wrappers accepting DomainEvent to satisfy dispatcher typing + self._config = config or ResultProcessorConfig() - async def _handle_completed_wrapper(self, event: DomainEvent) -> None: - assert isinstance(event, ExecutionCompletedEvent) - await self._handle_completed(event) - - async def _handle_failed_wrapper(self, event: DomainEvent) -> None: - assert isinstance(event, ExecutionFailedEvent) - await self._handle_failed(event) - - async def _handle_timeout_wrapper(self, event: DomainEvent) -> None: - assert isinstance(event, ExecutionTimeoutEvent) - await self._handle_timeout(event) - - async def _handle_completed(self, event: ExecutionCompletedEvent) -> None: + async def handle_execution_completed(self, event: ExecutionCompletedEvent) -> None: """Handle execution completed event.""" - exec_obj = await self._execution_repo.get_execution(event.execution_id) if exec_obj is None: raise ExecutionNotFoundError(event.execution_id) @@ -190,7 +87,7 @@ async def _handle_completed(self, event: ExecutionCompletedEvent) -> None: # Calculate and record memory utilization percentage settings_limit = self._settings.K8S_POD_MEMORY_LIMIT - memory_limit_mib = int(settings_limit.rstrip("Mi")) # TODO: Less brittle acquisition of limit + memory_limit_mib = int(settings_limit.rstrip("Mi")) memory_percent = (memory_mib / memory_limit_mib) * 100 self._metrics.memory_utilization_percent.record( memory_percent, attributes={"lang_and_version": lang_and_version} @@ -203,20 +100,18 @@ async def _handle_completed(self, event: ExecutionCompletedEvent) -> None: stdout=event.stdout, stderr=event.stderr, resource_usage=event.resource_usage, - metadata=event.metadata, + metadata=event.metadata.model_dump(), ) try: await self._execution_repo.write_terminal_result(result) await self._publish_result_stored(result) except Exception as e: - self.logger.error(f"Failed to handle ExecutionCompletedEvent: {e}", exc_info=True) + self._logger.error(f"Failed to handle ExecutionCompletedEvent: {e}", exc_info=True) await self._publish_result_failed(event.execution_id, str(e)) - async def _handle_failed(self, event: ExecutionFailedEvent) -> None: + async def handle_execution_failed(self, event: ExecutionFailedEvent) -> None: """Handle execution failed event.""" - - # Fetch execution to get language and version for metrics exec_obj = await self._execution_repo.get_execution(event.execution_id) if exec_obj is None: raise ExecutionNotFoundError(event.execution_id) @@ -232,19 +127,19 @@ async def _handle_failed(self, event: ExecutionFailedEvent) -> None: stdout=event.stdout, stderr=event.stderr, resource_usage=event.resource_usage, - metadata=event.metadata, + metadata=event.metadata.model_dump(), error_type=event.error_type, ) + try: await self._execution_repo.write_terminal_result(result) await self._publish_result_stored(result) except Exception as e: - self.logger.error(f"Failed to handle ExecutionFailedEvent: {e}", exc_info=True) + self._logger.error(f"Failed to handle ExecutionFailedEvent: {e}", exc_info=True) await self._publish_result_failed(event.execution_id, str(e)) - async def _handle_timeout(self, event: ExecutionTimeoutEvent) -> None: + async def handle_execution_timeout(self, event: ExecutionTimeoutEvent) -> None: """Handle execution timeout event.""" - exec_obj = await self._execution_repo.get_execution(event.execution_id) if exec_obj is None: raise ExecutionNotFoundError(event.execution_id) @@ -263,19 +158,19 @@ async def 
_handle_timeout(self, event: ExecutionTimeoutEvent) -> None: stdout=event.stdout, stderr=event.stderr, resource_usage=event.resource_usage, - metadata=event.metadata, + metadata=event.metadata.model_dump(), error_type=ExecutionErrorType.TIMEOUT, ) + try: await self._execution_repo.write_terminal_result(result) await self._publish_result_stored(result) except Exception as e: - self.logger.error(f"Failed to handle ExecutionTimeoutEvent: {e}", exc_info=True) + self._logger.error(f"Failed to handle ExecutionTimeoutEvent: {e}", exc_info=True) await self._publish_result_failed(event.execution_id, str(e)) async def _publish_result_stored(self, result: ExecutionResultDomain) -> None: """Publish result stored event.""" - size_bytes = len(result.stdout) + len(result.stderr) event = ResultStoredEvent( execution_id=result.execution_id, @@ -292,7 +187,6 @@ async def _publish_result_stored(self, result: ExecutionResultDomain) -> None: async def _publish_result_failed(self, execution_id: str, error_message: str) -> None: """Publish result processing failed event.""" - event = ResultFailedEvent( execution_id=execution_id, error=error_message, @@ -303,10 +197,3 @@ async def _publish_result_failed(self, execution_id: str, error_message: str) -> ) await self._producer.produce(event_to_produce=event, key=execution_id) - - async def get_status(self) -> dict[str, Any]: - """Get processor status.""" - return { - "state": self._state, - "consumer_active": self._consumer is not None, - } diff --git a/backend/app/services/saga/__init__.py b/backend/app/services/saga/__init__.py index e89535ae..ec47a201 100644 --- a/backend/app/services/saga/__init__.py +++ b/backend/app/services/saga/__init__.py @@ -12,7 +12,7 @@ RemoveFromQueueCompensation, ValidateExecutionStep, ) -from app.services.saga.saga_orchestrator import SagaOrchestrator, create_saga_orchestrator +from app.services.saga.saga_orchestrator import SagaOrchestrator from app.services.saga.saga_step import CompensationStep, SagaContext, SagaStep __all__ = [ @@ -34,5 +34,4 @@ "ReleaseResourcesCompensation", "RemoveFromQueueCompensation", "DeletePodCompensation", - "create_saga_orchestrator", ] diff --git a/backend/app/services/saga/saga_orchestrator.py b/backend/app/services/saga/saga_orchestrator.py index 7d607bb6..e032d6a7 100644 --- a/backend/app/services/saga/saga_orchestrator.py +++ b/backend/app/services/saga/saga_orchestrator.py @@ -1,11 +1,18 @@ -import asyncio +"""Saga Orchestrator - stateless event handler. + +Orchestrates saga execution and compensation. Receives events, +processes them, and publishes results. No lifecycle management. +All state is stored in SagaRepository (MongoDB). 
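With lifecycle management removed, topic subscription and event dispatch move to the worker. A hedged sketch using the orchestrator helpers defined below and the existing get_topic_for_event mapping; `consume_events` is a hypothetical stand-in for whatever consumer the entrypoint wires up, and the real workers/run_saga_orchestrator.py may differ:

    from app.infrastructure.kafka.mappings import get_topic_for_event

    async def run_orchestrator(orchestrator, consume_events) -> None:
        topics = {get_topic_for_event(et) for et in orchestrator.get_trigger_event_types()}
        async for event in consume_events(topics):  # hypothetical async event source
            await orchestrator.handle_event(event)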
+""" + +from __future__ import annotations + import logging from datetime import UTC, datetime, timedelta from uuid import uuid4 from opentelemetry.trace import SpanKind -from app.core.lifecycle import LifecycleEnabled from app.core.metrics import EventMetrics from app.core.tracing import EventAttributes from app.core.tracing.utils import get_tracer @@ -14,165 +21,79 @@ from app.domain.enums.events import EventType from app.domain.enums.saga import SagaState from app.domain.events.typed import DomainEvent, EventMetadata, SagaCancelledEvent -from app.domain.idempotency import KeyStrategy from app.domain.saga.models import Saga, SagaConfig -from app.events.core import ConsumerConfig, EventDispatcher, UnifiedConsumer, UnifiedProducer -from app.events.event_store import EventStore -from app.events.schema.schema_registry import SchemaRegistryManager -from app.infrastructure.kafka.mappings import get_topic_for_event -from app.services.idempotency import IdempotentConsumerWrapper -from app.services.idempotency.idempotency_manager import IdempotencyManager -from app.settings import Settings +from app.events.core import UnifiedProducer from .base_saga import BaseSaga from .execution_saga import ExecutionSaga from .saga_step import SagaContext -class SagaOrchestrator(LifecycleEnabled): - """Orchestrates saga execution and compensation""" +class SagaOrchestrator: + """Stateless saga orchestrator - pure event handler. + + No lifecycle methods (start/stop) - receives ready-to-use dependencies from DI. + All state stored in SagaRepository. Worker entrypoint handles the consume loop. + """ def __init__( self, config: SagaConfig, saga_repository: SagaRepository, producer: UnifiedProducer, - schema_registry_manager: SchemaRegistryManager, - settings: Settings, - event_store: EventStore, - idempotency_manager: IdempotencyManager, resource_allocation_repository: ResourceAllocationRepository, logger: logging.Logger, event_metrics: EventMetrics, - ): - super().__init__() - self.config = config - self._sagas: dict[str, type[BaseSaga]] = {} - self._running_instances: dict[str, Saga] = {} - self._consumer: IdempotentConsumerWrapper | None = None - self._idempotency_manager: IdempotencyManager = idempotency_manager + ) -> None: + self._config = config + self._repo = saga_repository self._producer = producer - self._schema_registry_manager = schema_registry_manager - self._settings = settings - self._event_store = event_store - self._repo: SagaRepository = saga_repository - self._alloc_repo: ResourceAllocationRepository = resource_allocation_repository - self._tasks: list[asyncio.Task[None]] = [] - self.logger = logger + self._alloc_repo = resource_allocation_repository + self._logger = logger self._event_metrics = event_metrics + self._sagas: dict[str, type[BaseSaga]] = {} + + # Register default sagas + self._register_default_sagas() def register_saga(self, saga_class: type[BaseSaga]) -> None: + """Register a saga class.""" self._sagas[saga_class.get_name()] = saga_class - self.logger.info(f"Registered saga: {saga_class.get_name()}") + self._logger.info(f"Registered saga: {saga_class.get_name()}") def _register_default_sagas(self) -> None: + """Register default sagas.""" self.register_saga(ExecutionSaga) - self.logger.info("Registered default sagas") - - async def _on_start(self) -> None: - """Start the saga orchestrator.""" - self.logger.info(f"Starting saga orchestrator: {self.config.name}") - - self._register_default_sagas() - - await self._start_consumer() - - timeout_task = 
asyncio.create_task(self._check_timeouts()) - self._tasks.append(timeout_task) + self._logger.info("Registered default sagas") - self.logger.info("Saga orchestrator started") + def get_trigger_event_types(self) -> set[EventType]: + """Get all event types that trigger sagas. - async def _on_stop(self) -> None: - """Stop the saga orchestrator.""" - self.logger.info("Stopping saga orchestrator...") - - if self._consumer: - await self._consumer.stop() - - await self._idempotency_manager.close() - - for task in self._tasks: - if not task.done(): - task.cancel() - - if self._tasks: - await asyncio.gather(*self._tasks, return_exceptions=True) - - self.logger.info("Saga orchestrator stopped") - - async def _start_consumer(self) -> None: - self.logger.info(f"Registered sagas: {list(self._sagas.keys())}") - topics = set() - event_types_to_register = set() + Helper for worker entrypoint to know which topics to subscribe to. + """ + event_types: set[EventType] = set() for saga_class in self._sagas.values(): - trigger_event_types = saga_class.get_trigger_events() - self.logger.info(f"Saga {saga_class.get_name()} triggers on event types: {trigger_event_types}") - - # Convert event types to topics for subscription - for event_type in trigger_event_types: - topic = get_topic_for_event(event_type) - topics.add(topic) - event_types_to_register.add(event_type) - self.logger.debug(f"Event type {event_type} maps to topic {topic}") - - # Also register handlers for completion events so execution sagas can complete - completion_event_types = { + trigger_events = saga_class.get_trigger_events() + event_types.update(trigger_events) + + # Also include completion events + completion_events = { EventType.EXECUTION_COMPLETED, EventType.EXECUTION_FAILED, EventType.EXECUTION_TIMEOUT, } - for event_type in completion_event_types: - topic = get_topic_for_event(event_type) - topics.add(topic) - event_types_to_register.add(event_type) - self.logger.debug(f"Completion event type {event_type} maps to topic {topic}") - - if not topics: - self.logger.warning("No trigger events found in registered sagas") - return + event_types.update(completion_events) - consumer_config = ConsumerConfig( - bootstrap_servers=self._settings.KAFKA_BOOTSTRAP_SERVERS, - group_id=f"saga-{self.config.name}", - enable_auto_commit=False, - session_timeout_ms=self._settings.KAFKA_SESSION_TIMEOUT_MS, - heartbeat_interval_ms=self._settings.KAFKA_HEARTBEAT_INTERVAL_MS, - max_poll_interval_ms=self._settings.KAFKA_MAX_POLL_INTERVAL_MS, - request_timeout_ms=self._settings.KAFKA_REQUEST_TIMEOUT_MS, - ) + return event_types - dispatcher = EventDispatcher(logger=self.logger) - for event_type in event_types_to_register: - dispatcher.register_handler(event_type, self._handle_event) - self.logger.info(f"Registered handler for event type: {event_type}") - - base_consumer = UnifiedConsumer( - config=consumer_config, - event_dispatcher=dispatcher, - schema_registry=self._schema_registry_manager, - settings=self._settings, - logger=self.logger, - event_metrics=self._event_metrics, - ) - self._consumer = IdempotentConsumerWrapper( - consumer=base_consumer, - idempotency_manager=self._idempotency_manager, - dispatcher=dispatcher, - logger=self.logger, - default_key_strategy=KeyStrategy.EVENT_BASED, - default_ttl_seconds=7200, - enable_for_all_handlers=False, - ) + async def handle_event(self, event: DomainEvent) -> None: + """Handle incoming event. 
- assert self._consumer is not None - await self._consumer.start(list(topics)) - - self.logger.info(f"Saga consumer started for topics: {topics}") + Called by worker entrypoint for each event. + """ + self._logger.info(f"Saga orchestrator handling event: type={event.event_type}, id={event.event_id}") - async def _handle_event(self, event: DomainEvent) -> None: - """Handle incoming event""" - self.logger.info(f"Saga orchestrator handling event: type={event.event_type}, id={event.event_id}") try: # Check if this is a completion event that should update an existing saga completion_events = { @@ -187,88 +108,87 @@ async def _handle_event(self, event: DomainEvent) -> None: # Check if this event should trigger a new saga saga_triggered = False for saga_name, saga_class in self._sagas.items(): - self.logger.debug(f"Checking if {saga_name} should be triggered by {event.event_type}") + self._logger.debug(f"Checking if {saga_name} should be triggered by {event.event_type}") if self._should_trigger_saga(saga_class, event): - self.logger.info(f"Event {event.event_type} triggers saga {saga_name}") + self._logger.info(f"Event {event.event_type} triggers saga {saga_name}") saga_triggered = True saga_id = await self._start_saga(saga_name, event) if not saga_id: raise RuntimeError(f"Failed to create saga {saga_name} for event {event.event_id}") if not saga_triggered: - self.logger.debug(f"Event {event.event_type} did not trigger any saga") + self._logger.debug(f"Event {event.event_type} did not trigger any saga") except Exception as e: - self.logger.error(f"Error handling event {event.event_id}: {e}", exc_info=True) + self._logger.error(f"Error handling event {event.event_id}: {e}", exc_info=True) raise async def _handle_completion_event(self, event: DomainEvent) -> None: """Handle execution completion events to update saga state.""" execution_id = getattr(event, "execution_id", None) if not execution_id: - self.logger.warning(f"Completion event {event.event_type} has no execution_id") + self._logger.warning(f"Completion event {event.event_type} has no execution_id") return - # Find the execution saga specifically (not other saga types) + # Find the execution saga specifically saga = await self._repo.get_saga_by_execution_and_name(execution_id, ExecutionSaga.get_name()) if not saga: - self.logger.debug(f"No execution_saga found for execution {execution_id}") + self._logger.debug(f"No execution_saga found for execution {execution_id}") return # Only update if saga is still in a running state if saga.state not in (SagaState.RUNNING, SagaState.CREATED): - self.logger.debug(f"Saga {saga.saga_id} already in terminal state {saga.state}") + self._logger.debug(f"Saga {saga.saga_id} already in terminal state {saga.state}") return # Update saga state based on completion event type if event.event_type == EventType.EXECUTION_COMPLETED: - self.logger.info(f"Marking saga {saga.saga_id} as COMPLETED due to execution completion") + self._logger.info(f"Marking saga {saga.saga_id} as COMPLETED due to execution completion") saga.state = SagaState.COMPLETED saga.completed_at = datetime.now(UTC) elif event.event_type == EventType.EXECUTION_TIMEOUT: timeout_seconds = getattr(event, "timeout_seconds", None) - self.logger.info(f"Marking saga {saga.saga_id} as TIMEOUT after {timeout_seconds}s") + self._logger.info(f"Marking saga {saga.saga_id} as TIMEOUT after {timeout_seconds}s") saga.state = SagaState.TIMEOUT saga.error_message = f"Execution timed out after {timeout_seconds} seconds" saga.completed_at = datetime.now(UTC) else: # 
EXECUTION_FAILED error_msg = getattr(event, "error_message", None) or f"Execution {event.event_type}" - self.logger.info(f"Marking saga {saga.saga_id} as FAILED: {error_msg}") + self._logger.info(f"Marking saga {saga.saga_id} as FAILED: {error_msg}") saga.state = SagaState.FAILED saga.error_message = error_msg saga.completed_at = datetime.now(UTC) await self._save_saga(saga) - self._running_instances.pop(saga.saga_id, None) def _should_trigger_saga(self, saga_class: type[BaseSaga], event: DomainEvent) -> bool: + """Check if event should trigger a saga.""" trigger_event_types = saga_class.get_trigger_events() should_trigger = event.event_type in trigger_event_types - self.logger.debug( + self._logger.debug( f"Saga {saga_class.get_name()} triggers on {trigger_event_types}, " f"event is {event.event_type}, should trigger: {should_trigger}" ) return should_trigger async def _start_saga(self, saga_name: str, trigger_event: DomainEvent) -> str | None: - """Start a new saga instance""" - self.logger.info(f"Starting saga {saga_name} for event {trigger_event.event_type}") + """Start a new saga instance.""" + self._logger.info(f"Starting saga {saga_name} for event {trigger_event.event_type}") saga_class = self._sagas.get(saga_name) if not saga_class: raise ValueError(f"Unknown saga: {saga_name}") execution_id = getattr(trigger_event, "execution_id", None) - self.logger.debug(f"Extracted execution_id={execution_id} from event") + self._logger.debug(f"Extracted execution_id={execution_id} from event") if not execution_id: - self.logger.warning(f"Could not extract execution ID from event: {trigger_event}") + self._logger.warning(f"Could not extract execution ID from event: {trigger_event}") return None existing = await self._repo.get_saga_by_execution_and_name(execution_id, saga_name) if existing: - self.logger.info(f"Saga {saga_name} already exists for execution {execution_id}") - saga_id: str = existing.saga_id - return saga_id + self._logger.info(f"Saga {saga_name} already exists for execution {execution_id}") + return existing.saga_id instance = Saga( saga_id=str(uuid4()), @@ -278,25 +198,21 @@ async def _start_saga(self, saga_name: str, trigger_event: DomainEvent) -> str | ) await self._save_saga(instance) - self._running_instances[instance.saga_id] = instance - - self.logger.info(f"Started saga {saga_name} (ID: {instance.saga_id}) for execution {execution_id}") + self._logger.info(f"Started saga {saga_name} (ID: {instance.saga_id}) for execution {execution_id}") + # Execute saga steps synchronously saga = saga_class() - # Inject runtime dependencies explicitly (no DI via context) try: saga.bind_dependencies( producer=self._producer, alloc_repo=self._alloc_repo, - publish_commands=bool(getattr(self.config, "publish_commands", False)), + publish_commands=bool(getattr(self._config, "publish_commands", False)), ) except Exception: - # Back-compat: if saga doesn't support binding, it will fallback to context where needed pass context = SagaContext(instance.saga_id, execution_id) - - asyncio.create_task(self._execute_saga(saga, instance, context, trigger_event)) + await self._execute_saga(saga, instance, context, trigger_event) return instance.saga_id @@ -307,24 +223,17 @@ async def _execute_saga( context: SagaContext, trigger_event: DomainEvent, ) -> None: - """Execute saga steps""" + """Execute saga steps synchronously.""" tracer = get_tracer() try: - # Get saga steps steps = saga.get_steps() - # Execute each step for step in steps: - if not self.is_running: - break - - # Update current step 
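The executor below treats each step as a small object exposing name, execute(context), and get_compensation(); compensations expose name and compensate(context). A duck-typed illustration of that contract (the real SagaStep and CompensationStep base classes in saga_step.py add their own structure, and the names here are sketches, not the real step classes):

    class AllocateStepSketch:
        name = "allocate_resources_sketch"

        async def execute(self, context) -> bool:
            context.set("allocated", True)  # SagaContext.set, as used in cancel_saga
            return True  # False would trigger compensation of prior steps

        def get_compensation(self):
            return ReleaseCompensationSketch()

    class ReleaseCompensationSketch:
        name = "release_resources_sketch"

        async def compensate(self, context) -> bool:
            return True  # undo the allocation recorded above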
instance.current_step = step.name await self._save_saga(instance) - self.logger.info(f"Executing saga step: {step.name} for saga {instance.saga_id}") + self._logger.info(f"Executing saga step: {step.name} for saga {instance.saga_id}") - # Execute step within a span with tracer.start_as_current_span( name="saga.step", kind=SpanKind.INTERNAL, @@ -339,8 +248,6 @@ async def _execute_saga( if success: instance.completed_steps.append(step.name) - - # Persist only safe, public context (no ephemeral objects) instance.context_data = context.to_public_dict() await self._save_saga(instance) @@ -348,162 +255,126 @@ async def _execute_saga( if compensation: context.add_compensation(compensation) else: - # Step failed, start compensation - self.logger.error(f"Saga step {step.name} failed for saga {instance.saga_id}") + self._logger.error(f"Saga step {step.name} failed for saga {instance.saga_id}") - if self.config.enable_compensation: + if self._config.enable_compensation: await self._compensate_saga(instance, context) else: await self._fail_saga(instance, "Step failed without compensation") - return - # All steps completed successfully - # Execution saga waits for external completion events (EXECUTION_COMPLETED/FAILED) + # All steps completed if instance.saga_name == ExecutionSaga.get_name(): - self.logger.info(f"Saga {instance.saga_id} steps done, waiting for execution completion event") + self._logger.info(f"Saga {instance.saga_id} steps done, waiting for execution completion event") else: await self._complete_saga(instance) except Exception as e: - self.logger.error(f"Error executing saga {instance.saga_id}: {e}", exc_info=True) + self._logger.error(f"Error executing saga {instance.saga_id}: {e}", exc_info=True) - if self.config.enable_compensation: + if self._config.enable_compensation: await self._compensate_saga(instance, context) else: await self._fail_saga(instance, str(e)) async def _compensate_saga(self, instance: Saga, context: SagaContext) -> None: - """Execute compensation steps""" - self.logger.info(f"Starting compensation for saga {instance.saga_id}") + """Execute compensation steps.""" + self._logger.info(f"Starting compensation for saga {instance.saga_id}") - # Only update state if not already cancelled if instance.state != SagaState.CANCELLED: instance.state = SagaState.COMPENSATING await self._save_saga(instance) - # Execute compensations in reverse order for compensation in reversed(context.compensations): try: - self.logger.info(f"Executing compensation: {compensation.name} for saga {instance.saga_id}") - + self._logger.info(f"Executing compensation: {compensation.name} for saga {instance.saga_id}") success = await compensation.compensate(context) if success: instance.compensated_steps.append(compensation.name) else: - self.logger.error(f"Compensation {compensation.name} failed for saga {instance.saga_id}") + self._logger.error(f"Compensation {compensation.name} failed for saga {instance.saga_id}") except Exception as e: - self.logger.error(f"Error in compensation {compensation.name}: {e}", exc_info=True) + self._logger.error(f"Error in compensation {compensation.name}: {e}", exc_info=True) - # Mark saga as failed or keep as cancelled if instance.state == SagaState.CANCELLED: - # Keep cancelled state but update compensated steps instance.updated_at = datetime.now(UTC) await self._save_saga(instance) - self.logger.info(f"Saga {instance.saga_id} compensation completed after cancellation") + self._logger.info(f"Saga {instance.saga_id} compensation completed after cancellation") else: 
- # Mark as failed for non-cancelled compensations await self._fail_saga(instance, "Saga compensated due to failure") async def _complete_saga(self, instance: Saga) -> None: - """Mark saga as completed""" + """Mark saga as completed.""" instance.state = SagaState.COMPLETED instance.completed_at = datetime.now(UTC) await self._save_saga(instance) - - # Remove from running instances - self._running_instances.pop(instance.saga_id, None) - - self.logger.info(f"Saga {instance.saga_id} completed successfully") + self._logger.info(f"Saga {instance.saga_id} completed successfully") async def _fail_saga(self, instance: Saga, error_message: str) -> None: - """Mark saga as failed""" + """Mark saga as failed.""" instance.state = SagaState.FAILED instance.error_message = error_message instance.completed_at = datetime.now(UTC) await self._save_saga(instance) + self._logger.error(f"Saga {instance.saga_id} failed: {error_message}") - # Remove from running instances - self._running_instances.pop(instance.saga_id, None) - - self.logger.error(f"Saga {instance.saga_id} failed: {error_message}") + async def check_timeouts(self) -> int: + """Check for saga timeouts. - async def _check_timeouts(self) -> None: - """Check for saga timeouts""" - while self.is_running: - try: - # Check every 30 seconds - await asyncio.sleep(30) - - cutoff_time = datetime.now(UTC) - timedelta(seconds=self.config.timeout_seconds) - - timed_out = await self._repo.find_timed_out_sagas(cutoff_time) - - for instance in timed_out: - self.logger.warning(f"Saga {instance.saga_id} timed out") - - instance.state = SagaState.TIMEOUT - instance.error_message = f"Saga timed out after {self.config.timeout_seconds} seconds" - instance.completed_at = datetime.now(UTC) - - await self._save_saga(instance) - self._running_instances.pop(instance.saga_id, None) + Should be called periodically from worker entrypoint. + Returns number of timed out sagas. + """ + cutoff_time = datetime.now(UTC) - timedelta(seconds=self._config.timeout_seconds) + timed_out = await self._repo.find_timed_out_sagas(cutoff_time) + count = 0 + + for instance in timed_out: + self._logger.warning(f"Saga {instance.saga_id} timed out") + instance.state = SagaState.TIMEOUT + instance.error_message = f"Saga timed out after {self._config.timeout_seconds} seconds" + instance.completed_at = datetime.now(UTC) + await self._save_saga(instance) + count += 1 - except Exception as e: - self.logger.error(f"Error checking timeouts: {e}") + return count async def _save_saga(self, instance: Saga) -> None: - """Persist saga through repository""" + """Persist saga through repository.""" instance.updated_at = datetime.now(UTC) await self._repo.upsert_saga(instance) async def get_saga_status(self, saga_id: str) -> Saga | None: - """Get saga instance status""" - # Check memory first - if saga_id in self._running_instances: - return self._running_instances[saga_id] - + """Get saga instance status.""" return await self._repo.get_saga(saga_id) async def get_execution_sagas(self, execution_id: str) -> list[Saga]: - """Get all sagas for an execution, sorted by created_at descending (newest first)""" + """Get all sagas for an execution.""" result = await self._repo.get_sagas_by_execution(execution_id) return result.sagas async def cancel_saga(self, saga_id: str) -> bool: - """Cancel a running saga and trigger compensation. 
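check_timeouts is now a one-shot sweep, so the periodic cadence also belongs to the worker entrypoint. A minimal sketch; the 30-second interval mirrors the removed in-class loop and is only an assumption for the new entrypoint:

    import asyncio
    import logging

    logger = logging.getLogger(__name__)

    async def run_timeout_sweeper(orchestrator, interval_seconds: int = 30) -> None:
        while True:
            await asyncio.sleep(interval_seconds)
            count = await orchestrator.check_timeouts()
            if count:
                logger.info("Marked %d saga(s) as timed out", count)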
- - Args: - saga_id: The ID of the saga to cancel - - Returns: - True if cancelled successfully, False otherwise - """ + """Cancel a running saga and trigger compensation.""" try: - # Get saga instance saga_instance = await self.get_saga_status(saga_id) if not saga_instance: - self.logger.error("Saga not found", extra={"saga_id": saga_id}) + self._logger.error("Saga not found", extra={"saga_id": saga_id}) return False - # Check if saga can be cancelled if saga_instance.state not in [SagaState.RUNNING, SagaState.CREATED]: - self.logger.warning( - "Cannot cancel saga in current state. Only RUNNING or CREATED sagas can be cancelled.", + self._logger.warning( + "Cannot cancel saga in current state", extra={"saga_id": saga_id, "state": saga_instance.state}, ) return False - # Update state to CANCELLED saga_instance.state = SagaState.CANCELLED saga_instance.error_message = "Saga cancelled by user request" saga_instance.completed_at = datetime.now(UTC) - # Log cancellation with user context if available user_id = saga_instance.context_data.get("user_id") - self.logger.info( + self._logger.info( "Saga cancellation initiated", extra={ "saga_id": saga_id, @@ -512,38 +383,28 @@ async def cancel_saga(self, saga_id: str) -> bool: }, ) - # Save state await self._save_saga(saga_instance) - # Remove from running instances - self._running_instances.pop(saga_id, None) - - # Publish cancellation event - if self._producer and self.config.store_events: + if self._config.store_events: await self._publish_saga_cancelled_event(saga_instance) - # Trigger compensation if saga was running and has completed steps - if saga_instance.completed_steps and self.config.enable_compensation: - # Get saga class + if saga_instance.completed_steps and self._config.enable_compensation: saga_class = self._sagas.get(saga_instance.saga_name) if saga_class: - # Create saga instance and context saga = saga_class() try: saga.bind_dependencies( producer=self._producer, alloc_repo=self._alloc_repo, - publish_commands=bool(getattr(self.config, "publish_commands", False)), + publish_commands=bool(getattr(self._config, "publish_commands", False)), ) except Exception: pass - context = SagaContext(saga_instance.saga_id, saga_instance.execution_id) - # Restore context data + context = SagaContext(saga_instance.saga_id, saga_instance.execution_id) for key, value in saga_instance.context_data.items(): context.set(key, value) - # Get steps and build compensation list steps = saga.get_steps() for step in steps: if step.name in saga_instance.completed_steps: @@ -551,19 +412,18 @@ async def cancel_saga(self, saga_id: str) -> bool: if compensation: context.add_compensation(compensation) - # Execute compensation await self._compensate_saga(saga_instance, context) else: - self.logger.error( + self._logger.error( "Saga class not found for compensation", extra={"saga_name": saga_instance.saga_name, "saga_id": saga_id}, ) - self.logger.info("Saga cancelled successfully", extra={"saga_id": saga_id}) + self._logger.info("Saga cancelled successfully", extra={"saga_id": saga_id}) return True except Exception as e: - self.logger.error( + self._logger.error( "Error cancelling saga", extra={"saga_id": saga_id, "error": str(e)}, exc_info=True, @@ -571,11 +431,7 @@ async def cancel_saga(self, saga_id: str) -> bool: return False async def _publish_saga_cancelled_event(self, saga_instance: Saga) -> None: - """Publish saga cancelled event. 
- - Args: - saga_instance: The cancelled saga instance - """ + """Publish saga cancelled event.""" try: cancelled_by = saga_instance.context_data.get("user_id") if saga_instance.context_data else None metadata = EventMetadata( @@ -596,53 +452,8 @@ async def _publish_saga_cancelled_event(self, saga_instance: Saga) -> None: metadata=metadata, ) - if self._producer: - await self._producer.produce(event_to_produce=event, key=saga_instance.execution_id) - - self.logger.info(f"Published cancellation event for saga {saga_instance.saga_id}") + await self._producer.produce(event_to_produce=event, key=saga_instance.execution_id) + self._logger.info(f"Published cancellation event for saga {saga_instance.saga_id}") except Exception as e: - self.logger.error(f"Failed to publish saga cancellation event: {e}") - - -def create_saga_orchestrator( - saga_repository: SagaRepository, - producer: UnifiedProducer, - schema_registry_manager: SchemaRegistryManager, - settings: Settings, - event_store: EventStore, - idempotency_manager: IdempotencyManager, - resource_allocation_repository: ResourceAllocationRepository, - config: SagaConfig, - logger: logging.Logger, - event_metrics: EventMetrics, -) -> SagaOrchestrator: - """Factory function to create a saga orchestrator. - - Args: - saga_repository: Repository for saga persistence - producer: Kafka producer instance - schema_registry_manager: Schema registry manager for event serialization - settings: Application settings - event_store: Event store instance for event sourcing - idempotency_manager: Manager for idempotent event processing - resource_allocation_repository: Repository for resource allocations - config: Saga configuration - logger: Logger instance - event_metrics: Event metrics for tracking Kafka consumption - - Returns: - A new saga orchestrator instance - """ - return SagaOrchestrator( - config, - saga_repository=saga_repository, - producer=producer, - schema_registry_manager=schema_registry_manager, - settings=settings, - event_store=event_store, - idempotency_manager=idempotency_manager, - resource_allocation_repository=resource_allocation_repository, - logger=logger, - event_metrics=event_metrics, - ) + self._logger.error(f"Failed to publish saga cancellation event: {e}") diff --git a/backend/app/services/sse/kafka_redis_bridge.py b/backend/app/services/sse/kafka_redis_bridge.py index 0a4eb780..83d7604e 100644 --- a/backend/app/services/sse/kafka_redis_bridge.py +++ b/backend/app/services/sse/kafka_redis_bridge.py @@ -1,151 +1,84 @@ +"""SSE Kafka Redis Bridge - stateless event handler. + +Bridges Kafka events to Redis channels for SSE delivery. +No lifecycle management - worker entrypoint handles the consume loop. +""" + from __future__ import annotations -import asyncio import logging -from app.core.lifecycle import LifecycleEnabled -from app.core.metrics import EventMetrics from app.domain.enums.events import EventType -from app.domain.enums.kafka import CONSUMER_GROUP_SUBSCRIPTIONS, GroupId from app.domain.events.typed import DomainEvent -from app.events.core import ConsumerConfig, EventDispatcher, UnifiedConsumer -from app.events.schema.schema_registry import SchemaRegistryManager from app.services.sse.redis_bus import SSERedisBus -from app.settings import Settings - -class SSEKafkaRedisBridge(LifecycleEnabled): - """ - Bridges Kafka events to Redis channels for SSE delivery. 
- - - Consumes relevant Kafka topics using a small consumer pool - - Deserializes events and publishes them to Redis via SSERedisBus - - Keeps no in-process buffers; delivery to clients is via Redis only +# Event types relevant for SSE streaming +RELEVANT_EVENT_TYPES: set[EventType] = { + EventType.EXECUTION_REQUESTED, + EventType.EXECUTION_QUEUED, + EventType.EXECUTION_STARTED, + EventType.EXECUTION_RUNNING, + EventType.EXECUTION_COMPLETED, + EventType.EXECUTION_FAILED, + EventType.EXECUTION_TIMEOUT, + EventType.EXECUTION_CANCELLED, + EventType.RESULT_STORED, + EventType.POD_CREATED, + EventType.POD_SCHEDULED, + EventType.POD_RUNNING, + EventType.POD_SUCCEEDED, + EventType.POD_FAILED, + EventType.POD_TERMINATED, + EventType.POD_DELETED, +} + + +class SSEKafkaRedisBridge: + """Stateless SSE bridge - pure event handler. + + No lifecycle methods (start/stop) - receives ready-to-use dependencies from DI. + Worker entrypoint handles the consume loop. """ def __init__( - self, - schema_registry: SchemaRegistryManager, - settings: Settings, - event_metrics: EventMetrics, - sse_bus: SSERedisBus, - logger: logging.Logger, + self, + sse_bus: SSERedisBus, + logger: logging.Logger, ) -> None: - super().__init__() - self.schema_registry = schema_registry - self.settings = settings - self.event_metrics = event_metrics - self.sse_bus = sse_bus - self.logger = logger - - self.num_consumers = settings.SSE_CONSUMER_POOL_SIZE - self.consumers: list[UnifiedConsumer] = [] - - async def _on_start(self) -> None: - """Start the SSE Kafka→Redis bridge.""" - self.logger.info(f"Starting SSE Kafka→Redis bridge with {self.num_consumers} consumers") - - # Phase 1: Build all consumers and track them immediately (no I/O) - self.consumers = [self._build_consumer(i) for i in range(self.num_consumers)] - - # Phase 2: Start all in parallel - already tracked in self.consumers for cleanup - topics = list(CONSUMER_GROUP_SUBSCRIPTIONS[GroupId.WEBSOCKET_GATEWAY]) - await asyncio.gather(*[c.start(topics) for c in self.consumers]) - - self.logger.info("SSE Kafka→Redis bridge started successfully") - - async def _on_stop(self) -> None: - """Stop the SSE Kafka→Redis bridge.""" - self.logger.info("Stopping SSE Kafka→Redis bridge") - await asyncio.gather(*[c.stop() for c in self.consumers], return_exceptions=True) - self.consumers.clear() - self.logger.info("SSE Kafka→Redis bridge stopped") - - def _build_consumer(self, consumer_index: int) -> UnifiedConsumer: - """Build a consumer instance without starting it.""" - config = ConsumerConfig( - bootstrap_servers=self.settings.KAFKA_BOOTSTRAP_SERVERS, - group_id="sse-bridge-pool", - client_id=f"sse-bridge-{consumer_index}", - enable_auto_commit=True, - auto_offset_reset="latest", - max_poll_interval_ms=self.settings.KAFKA_MAX_POLL_INTERVAL_MS, - session_timeout_ms=self.settings.KAFKA_SESSION_TIMEOUT_MS, - heartbeat_interval_ms=self.settings.KAFKA_HEARTBEAT_INTERVAL_MS, - request_timeout_ms=self.settings.KAFKA_REQUEST_TIMEOUT_MS, - ) - - dispatcher = EventDispatcher(logger=self.logger) - self._register_routing_handlers(dispatcher) - - return UnifiedConsumer( - config=config, - event_dispatcher=dispatcher, - schema_registry=self.schema_registry, - settings=self.settings, - logger=self.logger, - event_metrics=self.event_metrics, - ) - - def _register_routing_handlers(self, dispatcher: EventDispatcher) -> None: - """Publish relevant events to Redis channels keyed by execution_id.""" - relevant_events = [ - EventType.EXECUTION_REQUESTED, - EventType.EXECUTION_QUEUED, - 
EventType.EXECUTION_STARTED, - EventType.EXECUTION_RUNNING, - EventType.EXECUTION_COMPLETED, - EventType.EXECUTION_FAILED, - EventType.EXECUTION_TIMEOUT, - EventType.EXECUTION_CANCELLED, - EventType.RESULT_STORED, - EventType.POD_CREATED, - EventType.POD_SCHEDULED, - EventType.POD_RUNNING, - EventType.POD_SUCCEEDED, - EventType.POD_FAILED, - EventType.POD_TERMINATED, - EventType.POD_DELETED, - ] - - async def route_event(event: DomainEvent) -> None: - data = event.model_dump() - execution_id = data.get("execution_id") - if not execution_id: - self.logger.debug(f"Event {event.event_type} has no execution_id") - return - try: - await self.sse_bus.publish_event(execution_id, event) - self.logger.info(f"Published {event.event_type} to Redis for {execution_id}") - except Exception as e: - self.logger.error( - f"Failed to publish {event.event_type} to Redis for {execution_id}: {e}", - exc_info=True, - ) - - for et in relevant_events: - dispatcher.register_handler(et, route_event) - - def get_stats(self) -> dict[str, int | bool]: + self._sse_bus = sse_bus + self._logger = logger + + @staticmethod + def get_relevant_event_types() -> set[EventType]: + """Get event types that should be routed to SSE. + + Helper for worker entrypoint to know which topics to subscribe to. + """ + return RELEVANT_EVENT_TYPES + + async def handle_event(self, event: DomainEvent) -> None: + """Handle an event and route to SSE bus. + + Called by worker entrypoint for each event from consume loop. + """ + data = event.model_dump() + execution_id = data.get("execution_id") + + if not execution_id: + self._logger.debug(f"Event {event.event_type} has no execution_id") + return + + try: + await self._sse_bus.publish_event(execution_id, event) + self._logger.debug(f"Published {event.event_type} to Redis for {execution_id}") + except Exception as e: + self._logger.error( + f"Failed to publish {event.event_type} to Redis for {execution_id}: {e}", + exc_info=True, + ) + + async def get_status(self) -> dict[str, list[str]]: + """Get bridge status.""" return { - "num_consumers": len(self.consumers), - "active_executions": 0, - "total_buffers": 0, - "is_running": self.is_running, + "relevant_event_types": [str(et) for et in RELEVANT_EVENT_TYPES], } - - -def create_sse_kafka_redis_bridge( - schema_registry: SchemaRegistryManager, - settings: Settings, - event_metrics: EventMetrics, - sse_bus: SSERedisBus, - logger: logging.Logger, -) -> SSEKafkaRedisBridge: - return SSEKafkaRedisBridge( - schema_registry=schema_registry, - settings=settings, - event_metrics=event_metrics, - sse_bus=sse_bus, - logger=logger, - ) diff --git a/backend/app/services/sse/sse_service.py b/backend/app/services/sse/sse_service.py index 9cdb13ee..6afb1f90 100644 --- a/backend/app/services/sse/sse_service.py +++ b/backend/app/services/sse/sse_service.py @@ -2,13 +2,12 @@ import logging from collections.abc import AsyncGenerator from datetime import datetime, timezone -from typing import Any +from typing import Any, Dict from app.core.metrics import ConnectionMetrics from app.db.repositories.sse_repository import SSERepository from app.domain.enums.events import EventType -from app.domain.enums.sse import SSEControlEvent, SSEHealthStatus, SSENotificationEvent -from app.domain.sse import SSEHealthDomain +from app.domain.enums.sse import SSEControlEvent, SSENotificationEvent from app.schemas_pydantic.execution import ExecutionResult from app.schemas_pydantic.sse import ( RedisNotificationMessage, @@ -50,7 +49,7 @@ def __init__( self.metrics = connection_metrics 
self.heartbeat_interval = getattr(settings, "SSE_HEARTBEAT_INTERVAL", 30) - async def create_execution_stream(self, execution_id: str, user_id: str) -> AsyncGenerator[dict[str, Any], None]: + async def create_execution_stream(self, execution_id: str, user_id: str) -> AsyncGenerator[Dict[str, Any], None]: connection_id = f"sse_{execution_id}_{datetime.now(timezone.utc).timestamp()}" shutdown_event = await self.shutdown_manager.register_connection(execution_id, connection_id) @@ -125,7 +124,7 @@ async def _stream_events_redis( subscription: Any, shutdown_event: asyncio.Event, include_heartbeat: bool = True, - ) -> AsyncGenerator[dict[str, Any], None]: + ) -> AsyncGenerator[Dict[str, Any], None]: last_heartbeat = datetime.now(timezone.utc) while True: if shutdown_event.is_set(): @@ -195,7 +194,7 @@ async def _build_sse_event_from_redis(self, execution_id: str, msg: RedisSSEMess } ) - async def create_notification_stream(self, user_id: str) -> AsyncGenerator[dict[str, Any], None]: + async def create_notification_stream(self, user_id: str) -> AsyncGenerator[Dict[str, Any], None]: subscription = None try: @@ -258,23 +257,10 @@ async def create_notification_stream(self, user_id: str) -> AsyncGenerator[dict[ if subscription is not None: await asyncio.shield(subscription.close()) - async def get_health_status(self) -> SSEHealthDomain: - router_stats = self.router.get_stats() - return SSEHealthDomain( - status=SSEHealthStatus.DRAINING if self.shutdown_manager.is_shutting_down() else SSEHealthStatus.HEALTHY, - kafka_enabled=True, - active_connections=router_stats["active_executions"], - active_executions=router_stats["active_executions"], - active_consumers=router_stats["num_consumers"], - max_connections_per_user=5, - shutdown=self.shutdown_manager.get_shutdown_status(), - timestamp=datetime.now(timezone.utc), - ) - - def _format_sse_event(self, event: SSEExecutionEventData) -> dict[str, Any]: + def _format_sse_event(self, event: SSEExecutionEventData) -> Dict[str, Any]: """Format typed SSE event for sse-starlette.""" return {"data": event.model_dump_json(exclude_none=True)} - def _format_notification_event(self, event: SSENotificationEventData) -> dict[str, Any]: + def _format_notification_event(self, event: SSENotificationEventData) -> Dict[str, Any]: """Format typed notification SSE event for sse-starlette.""" return {"data": event.model_dump_json(exclude_none=True)} diff --git a/backend/app/services/sse/sse_shutdown_manager.py b/backend/app/services/sse/sse_shutdown_manager.py index c30ee855..dc799d10 100644 --- a/backend/app/services/sse/sse_shutdown_manager.py +++ b/backend/app/services/sse/sse_shutdown_manager.py @@ -3,7 +3,6 @@ import time from enum import Enum -from app.core.lifecycle import LifecycleEnabled from app.core.metrics import ConnectionMetrics from app.domain.sse import ShutdownStatus @@ -56,9 +55,6 @@ def __init__( self._connection_callbacks: dict[str, asyncio.Event] = {} # connection_id -> shutdown event self._draining_connections: set[str] = set() - # Router reference (set during initialization) - self._router: LifecycleEnabled | None = None - # Synchronization self._lock = asyncio.Lock() self._shutdown_event = asyncio.Event() @@ -73,10 +69,6 @@ def __init__( extra={"drain_timeout": drain_timeout, "notification_timeout": notification_timeout}, ) - def set_router(self, router: LifecycleEnabled) -> None: - """Set the router reference for shutdown coordination.""" - self._router = router - async def register_connection(self, execution_id: str, connection_id: str) -> asyncio.Event 
| None: """ Register a new SSE connection. @@ -259,10 +251,6 @@ async def _force_close_connections(self) -> None: self._connection_callbacks.clear() self._draining_connections.clear() - # If we have a router, tell it to stop accepting new subscriptions - if self._router: - await self._router.aclose() - self.metrics.update_sse_draining_connections(0) self.logger.info("Force close phase complete") @@ -305,31 +293,3 @@ async def _wait_for_complete(self) -> None: """Wait for shutdown to complete""" while not self._shutdown_complete: await asyncio.sleep(0.1) - - -def create_sse_shutdown_manager( - logger: logging.Logger, - connection_metrics: ConnectionMetrics, - drain_timeout: float = 30.0, - notification_timeout: float = 5.0, - force_close_timeout: float = 10.0, -) -> SSEShutdownManager: - """Factory function to create an SSE shutdown manager. - - Args: - logger: Logger instance - connection_metrics: Connection metrics for tracking SSE connections - drain_timeout: Time to wait for connections to close gracefully - notification_timeout: Time to wait for shutdown notifications to be sent - force_close_timeout: Time before force closing connections - - Returns: - A new SSE shutdown manager instance - """ - return SSEShutdownManager( - logger=logger, - connection_metrics=connection_metrics, - drain_timeout=drain_timeout, - notification_timeout=notification_timeout, - force_close_timeout=force_close_timeout, - ) diff --git a/backend/app/services/user_settings_service.py b/backend/app/services/user_settings_service.py index fc9964d0..981f0164 100644 --- a/backend/app/services/user_settings_service.py +++ b/backend/app/services/user_settings_service.py @@ -8,7 +8,6 @@ from app.db.repositories.user_settings_repository import UserSettingsRepository from app.domain.enums import Theme from app.domain.enums.events import EventType -from app.domain.events.typed import DomainEvent, EventMetadata, UserSettingsUpdatedEvent from app.domain.user import ( DomainEditorSettings, DomainNotificationSettings, @@ -17,9 +16,8 @@ DomainUserSettingsChangedEvent, DomainUserSettingsUpdate, ) -from app.services.event_bus import EventBusManager +from app.services.event_bus import EventBus, EventBusEvent from app.services.kafka_event_service import KafkaEventService -from app.settings import Settings _settings_adapter = TypeAdapter(DomainUserSettings) _update_adapter = TypeAdapter(DomainUserSettingsUpdate) @@ -30,12 +28,12 @@ def __init__( self, repository: UserSettingsRepository, event_service: KafkaEventService, - settings: Settings, + event_bus: EventBus, logger: logging.Logger, ) -> None: self.repository = repository self.event_service = event_service - self.settings = settings + self._event_bus = event_bus self.logger = logger self._cache_ttl = timedelta(minutes=5) self._max_cache_size = 1000 @@ -43,8 +41,6 @@ def __init__( maxsize=self._max_cache_size, ttl=self._cache_ttl.total_seconds(), ) - self._event_bus_manager: EventBusManager | None = None - self._subscription_id: str | None = None self.logger.info( "UserSettingsService initialized", @@ -60,20 +56,19 @@ async def get_user_settings(self, user_id: str) -> DomainUserSettings: return await self.get_user_settings_fresh(user_id) - async def initialize(self, event_bus_manager: EventBusManager) -> None: + async def setup_event_subscription(self) -> None: """Subscribe to settings update events for cross-instance cache invalidation. Note: EventBus filters out self-published messages, so this handler only - runs for events from OTHER instances. 
+ runs for events from OTHER instances. Called by DI provider after construction. """ - self._event_bus_manager = event_bus_manager - bus = await event_bus_manager.get_event_bus() - async def _handle(evt: DomainEvent) -> None: - if isinstance(evt, UserSettingsUpdatedEvent): - await self.invalidate_cache(evt.user_id) + async def _handle(evt: EventBusEvent) -> None: + uid = evt.payload.get("user_id") + if uid: + await self.invalidate_cache(str(uid)) - self._subscription_id = await bus.subscribe(f"{EventType.USER_SETTINGS_UPDATED}*", _handle) + await self._event_bus.subscribe("user.settings.updated*", _handle) async def get_user_settings_fresh(self, user_id: str) -> DomainUserSettings: """Bypass cache and rebuild settings from snapshot + events.""" @@ -114,20 +109,7 @@ async def update_user_settings( changes_json = _update_adapter.dump_python(updates, exclude_none=True, mode="json") await self._publish_settings_event(user_id, changes_json, reason) - if self._event_bus_manager is not None: - bus = await self._event_bus_manager.get_event_bus() - await bus.publish( - UserSettingsUpdatedEvent( - user_id=user_id, - changed_fields=list(changes_json.keys()), - reason=reason, - metadata=EventMetadata( - service_name=self.settings.SERVICE_NAME, - service_version=self.settings.SERVICE_VERSION, - user_id=user_id, - ), - ) - ) + await self._event_bus.publish("user.settings.updated", {"user_id": user_id}) self._add_to_cache(user_id, new_settings) if (await self.repository.count_events_since_snapshot(user_id)) >= 10: diff --git a/backend/di_lifecycle_refactor_plan.md b/backend/di_lifecycle_refactor_plan.md new file mode 100644 index 00000000..e19e65f6 --- /dev/null +++ b/backend/di_lifecycle_refactor_plan.md @@ -0,0 +1,64 @@ +# Backend DI Lifecycle Refactor Plan + +## Goals +- Push all service lifecycle management into Dishka providers; no ad-hoc threads, `is_running` flags, or manual `__aenter__/__aexit__`, `_on_start` / `_on_stop` hooks. +- Keep **zero lifecycle helper files or task-group utilities** in app code; Dishka provider primitives alone manage ownership, and container close is the only shutdown signal. +- Simplify worker entrypoints and FastAPI startup so shutting down the DI container is the only teardown required. + +## Current Pain Points (code refs) +- `app/core/lifecycle.py` mixes lifecycle concerns into every long-running service and leaks `is_running` flags across the codebase. +- Kafka-facing services (`events/core/consumer.py`, `events/core/producer.py`, `events/event_store_consumer.py`, `services/result_processor/processor.py`, `services/k8s_worker/worker.py`, `services/coordinator/coordinator.py`, `services/sse/kafka_redis_bridge.py`, `services/notification_service.py`, `dlq/manager.py`) start background loops via `asyncio.create_task` and manage stop signals manually. +- FastAPI lifespan (`app/core/dishka_lifespan.py`) manually enters/starts multiple services and stacks callbacks; workers use `while service.is_running` loops (e.g., `workers/run_saga_orchestrator.py`). +- `app/core/adaptive_sampling.py` spins a raw thread for periodic work. +- `EventBusManager` caches an `EventBus` started via `__aenter__`, duplicating lifecycle logic. + +## Target Architecture (Dishka-centric) +- Use Dishka `@provide` async-generator providers directly—no extra lifecycle helper modules. Providers must **not** call `start/stop/__aenter__/__aexit__`; objects are usable immediately after construction and simply released when the container closes. 
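+
+  As a concrete illustration of that shape, here is a minimal provider sketch for the stateless `SSEKafkaRedisBridge` introduced in this patch. The `ServiceProvider` name is hypothetical, and the sketch assumes the container also provides `SSERedisBus` and a `logging.Logger`; in the real codebase this wiring belongs in the existing provider classes:
+
+  ```python
+  import logging
+
+  from dishka import Provider, Scope, provide
+
+  from app.services.sse.kafka_redis_bridge import SSEKafkaRedisBridge
+  from app.services.sse.redis_bus import SSERedisBus
+
+
+  class ServiceProvider(Provider):
+      """Hypothetical provider, shown for illustration only."""
+
+      scope = Scope.APP
+
+      @provide
+      def sse_bridge(self, sse_bus: SSERedisBus, logger: logging.Logger) -> SSEKafkaRedisBridge:
+          # Plain construction: no start()/stop(), no __aenter__/__aexit__, nothing to finalize.
+          # The instance is simply released when the container closes.
+          return SSEKafkaRedisBridge(sse_bus=sse_bus, logger=logger)
+  ```
+
+  A worker or the FastAPI app then resolves the bridge from the container and calls `handle_event` from its own consume loop; closing the container is the only teardown.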
+- Keep services as pure orchestrators/handlers that assume dependencies are already constructed; no lifecycle methods (`start`, `stop`, `aclose`, `__aenter__`, `__aexit__`, `is_running`). +- FastAPI lifecycle only needs to create/close the container; bootstrap work (schema registry init, Beanie init, rate-limit seeding) runs inside an `APP`-scoped provider without explicit start/stop calls. +- Worker entrypoints resolve services (already constructed by providers), then block on a shutdown event; container.close() just releases references—no service teardown calls. + +## Step-by-Step Refactor +1) **Inline provider construction (no helper files, no start/stop)** + - Keep everything inside existing provider classes; do not add `core/di/*` helper modules. + - Providers construct dependencies once and yield them; **no start/stop or context-manager calls** inside providers. Objects must be drop-safe when the container releases them. + +2) **Retire `LifecycleEnabled`** + - Remove the base class and delete `is_running` state from all services. + - Convert services to plain classes whose constructors accept already-started collaborators (producer, dispatcher, repositories, etc.). + - Where a class only wrapped start/stop (e.g., `UnifiedProducer`, `UnifiedConsumer`), replace with lightweight functions or data classes plus provider-managed runners. + +3) **Kafka-facing services → passive, no lifecycle** + - Event ingestors (`EventStoreConsumer`, `ResultProcessor`, `NotificationService`, `SSEKafkaRedisBridge`, `DLQManager`, `ExecutionCoordinator`, `KubernetesWorker`, `SagaOrchestrator`, `PodMonitor`) become passive components: construction wires handlers/clients, but there is **no start/stop**. Message handling is invoked explicitly by callers (per-request or explicit trigger), not via background loops. + - Delete `_batch_task`, `_scheduling_task`, `_process_task`, and any `asyncio.create_task` usage. No runners, no background scheduling, no threads/processes. + +4) **FastAPI bootstrap simplification** + - Remove custom lifespan entirely; rely on FastAPI default lifecycle. + - Perform one-time bootstrap (schemas, Beanie, rate limits, tracing/metrics wiring) directly in `main.py` before constructing the app object and DI container. No dedicated provider or lifespan hook. + - Wiring stays declarative in `main.py`; providers stay free of bootstrap side-effects. + +5) **Worker entrypoints overhaul** + - Use signal-driven shutdown only: install handlers, wait on shutdown event, then rely on Dishka to close the container (no explicit teardown). Avoid polling loops, task groups, runners, lifecycle files. + - Providers should be sync where possible; prefer simple constructors over async generators so container cleanup is automatic and implicit. + - Each worker script builds settings → container, resolves the needed service (already constructed by its provider), logs readiness, then waits on the shutdown event. + +6) **Adaptive sampling & other threads** + - Replace `AdaptiveSampler` thread with on-demand, stateless computation (pure function or cached calculator). No background loop, no thread, no task group. + - Audit and remove any `threading.Thread` or `multiprocessing` usage; prefer synchronous or explicitly awaited calls executed by the caller. + +7) **Testing & migration** + - Update unit tests to drop assertions around `is_running` and context-manager behavior; add tests that closing the DI container cancels consumer loops and flushes Kafka commits. 
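+
+     One possible shape for the container-close test described here and in the integration-test note below, sketched with dishka primitives only. `FakeConsumer` and the other names are made up for illustration; a provider only needs a finalizer like this when the object it yields is not drop-safe:
+
+     ```python
+     from collections.abc import AsyncIterator
+
+     import pytest
+     from dishka import Provider, Scope, make_async_container, provide
+
+
+     class FakeConsumer:
+         """Stand-in for a Kafka consumer; records whether it was released."""
+
+         def __init__(self) -> None:
+             self.closed = False
+
+         async def aclose(self) -> None:
+             self.closed = True
+
+
+     class FakeConsumerProvider(Provider):
+         scope = Scope.APP
+
+         @provide
+         async def consumer(self) -> AsyncIterator[FakeConsumer]:
+             consumer = FakeConsumer()
+             yield consumer
+             await consumer.aclose()  # finalizer runs when the container closes
+
+
+     @pytest.mark.asyncio
+     async def test_container_close_releases_consumer() -> None:
+         container = make_async_container(FakeConsumerProvider())
+         consumer = await container.get(FakeConsumer)
+         assert consumer.closed is False
+
+         await container.close()  # container close is the only shutdown signal
+
+         assert consumer.closed is True
+     ```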
+ - Add a narrow integration test that spins an APP-scoped container with a fake consumer to verify provider-managed shutdown. + - Keep `uv run pytest` as the execution path; prefer `PYTEST_ADDOPTS=` override to disable xdist when debugging lifecycle issues locally. + +## Risks / Open Questions +- Kafka libraries typically expect explicit `start/stop`; shifting to construct-and-drop may require swapping implementations (e.g., per-call producer/consumer) to avoid leaks. +- Some startup routines (Kubernetes config load) are blocking; may still need threadpool execution even without explicit lifecycle. +- Need to ensure Dishka container close is called in all entrypoints so provider objects are released. + +## Definition of Done +- No class in `app/` inherits `LifecycleEnabled`; the file is removed. +- No service exposes `is_running` or `__aenter__/__aexit__`; lifecycle lives exclusively in providers. +- FastAPI and worker entrypoints use container close as the sole shutdown hook. +- No background loops/tasks/threads/processes; services do work only when explicitly invoked. +- Unit and targeted integration tests pass under `uv run pytest` with minimal/no external dependencies. diff --git a/backend/tests/e2e/core/test_dishka_lifespan.py b/backend/tests/e2e/core/test_dishka_lifespan.py index 39aada74..25d4de31 100644 --- a/backend/tests/e2e/core/test_dishka_lifespan.py +++ b/backend/tests/e2e/core/test_dishka_lifespan.py @@ -89,11 +89,11 @@ async def test_sse_bridge_available(self, scope: AsyncContainer) -> None: assert bridge is not None @pytest.mark.asyncio - async def test_event_store_consumer_available( + async def test_event_store_available( self, scope: AsyncContainer ) -> None: - """Event store consumer is available after lifespan.""" - from app.events.event_store_consumer import EventStoreConsumer + """Event store is available after lifespan.""" + from app.events.event_store import EventStore - consumer = await scope.get(EventStoreConsumer) - assert consumer is not None + store = await scope.get(EventStore) + assert store is not None diff --git a/backend/tests/e2e/dlq/test_dlq_manager.py b/backend/tests/e2e/dlq/test_dlq_manager.py index 381f90e2..d8888138 100644 --- a/backend/tests/e2e/dlq/test_dlq_manager.py +++ b/backend/tests/e2e/dlq/test_dlq_manager.py @@ -6,8 +6,7 @@ import pytest from aiokafka import AIOKafkaConsumer, AIOKafkaProducer -from app.core.metrics import DLQMetrics -from app.dlq.manager import create_dlq_manager +from app.dlq.manager import DLQManager from app.domain.enums.events import EventType from app.domain.enums.kafka import KafkaTopic from app.domain.events.typed import DLQMessageReceivedEvent @@ -28,9 +27,8 @@ @pytest.mark.asyncio async def test_dlq_manager_persists_and_emits_event(scope: AsyncContainer, test_settings: Settings) -> None: """Test that DLQ manager persists messages and emits DLQMessageReceivedEvent.""" - schema_registry = SchemaRegistryManager(test_settings, _test_logger) - dlq_metrics: DLQMetrics = await scope.get(DLQMetrics) - manager = create_dlq_manager(settings=test_settings, schema_registry=schema_registry, logger=_test_logger, dlq_metrics=dlq_metrics) + schema_registry: SchemaRegistryManager = await scope.get(SchemaRegistryManager) + manager: DLQManager = await scope.get(DLQManager) prefix = test_settings.KAFKA_TOPIC_PREFIX ev = make_execution_requested_event(execution_id=f"exec-dlq-persist-{uuid.uuid4().hex[:8]}") @@ -72,31 +70,24 @@ async def consume_dlq_events() -> None: "producer_id": "tests", } - # Produce to DLQ topic BEFORE starting 
consumers (auto_offset_reset="earliest") - producer = AIOKafkaProducer(bootstrap_servers=test_settings.KAFKA_BOOTSTRAP_SERVERS) - await producer.start() - try: - await producer.send_and_wait( - topic=f"{prefix}{str(KafkaTopic.DEAD_LETTER_QUEUE)}", - key=ev.event_id.encode(), - value=json.dumps(payload).encode(), - ) - finally: - await producer.stop() - - # Start consumer for DLQ events + # Start consumer for DLQ events before producing await consumer.start() consume_task = asyncio.create_task(consume_dlq_events()) try: - # Start manager - it will consume from DLQ, persist, and emit DLQMessageReceivedEvent - async with manager: - # Await the DLQMessageReceivedEvent - true async, no polling - received = await asyncio.wait_for(received_future, timeout=15.0) - assert received.dlq_event_id == ev.event_id - assert received.event_type == EventType.DLQ_MESSAGE_RECEIVED - assert received.original_event_type == str(EventType.EXECUTION_REQUESTED) - assert received.error == "handler failed" + # Now produce to DLQ topic and call manager.handle_dlq_message directly + raw_message = json.dumps(payload).encode() + headers: dict[str, str] = {} + + # Manager handles the message (stateless handler) + await manager.handle_dlq_message(raw_message, headers) + + # Await the DLQMessageReceivedEvent - true async, no polling + received = await asyncio.wait_for(received_future, timeout=15.0) + assert received.dlq_event_id == ev.event_id + assert received.event_type == EventType.DLQ_MESSAGE_RECEIVED + assert received.original_event_type == str(EventType.EXECUTION_REQUESTED) + assert received.error == "handler failed" finally: consume_task.cancel() try: diff --git a/backend/tests/e2e/events/test_consume_roundtrip.py b/backend/tests/e2e/events/test_consume_roundtrip.py index 3b7d969b..3c64f706 100644 --- a/backend/tests/e2e/events/test_consume_roundtrip.py +++ b/backend/tests/e2e/events/test_consume_roundtrip.py @@ -3,13 +3,13 @@ import uuid import pytest +from aiokafka import AIOKafkaConsumer from app.core.metrics import EventMetrics from app.domain.enums.events import EventType from app.domain.enums.kafka import KafkaTopic from app.domain.events.typed import DomainEvent from app.events.core import UnifiedConsumer, UnifiedProducer from app.events.core.dispatcher import EventDispatcher -from app.events.core.types import ConsumerConfig from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas from app.settings import Settings from dishka import AsyncContainer @@ -43,22 +43,25 @@ async def _handle(_event: DomainEvent) -> None: received.set() group_id = f"test-consumer.{uuid.uuid4().hex[:6]}" - config = ConsumerConfig( + + # Create AIOKafkaConsumer directly for test + topic = f"{settings.KAFKA_TOPIC_PREFIX}{KafkaTopic.EXECUTION_EVENTS}" + kafka_consumer = AIOKafkaConsumer( + topic, bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS, group_id=group_id, enable_auto_commit=True, auto_offset_reset="earliest", ) + await kafka_consumer.start() - consumer = UnifiedConsumer( - config, - dispatcher, + handler = UnifiedConsumer( + event_dispatcher=dispatcher, schema_registry=registry, - settings=settings, logger=_test_logger, event_metrics=event_metrics, + group_id=group_id, ) - await consumer.start([KafkaTopic.EXECUTION_EVENTS]) try: # Produce a request event @@ -67,6 +70,13 @@ async def _handle(_event: DomainEvent) -> None: await producer.produce(evt, key=execution_id) # Wait for the handler to be called - await asyncio.wait_for(received.wait(), timeout=10.0) + async def consume_until_received() 
-> None: + async for msg in kafka_consumer: + await handler.handle(msg) + await kafka_consumer.commit() + if received.is_set(): + break + + await asyncio.wait_for(consume_until_received(), timeout=10.0) finally: - await consumer.stop() + await kafka_consumer.stop() diff --git a/backend/tests/e2e/events/test_consumer_lifecycle.py b/backend/tests/e2e/events/test_consumer_lifecycle.py index 98c53a08..a7e102d9 100644 --- a/backend/tests/e2e/events/test_consumer_lifecycle.py +++ b/backend/tests/e2e/events/test_consumer_lifecycle.py @@ -2,9 +2,10 @@ from uuid import uuid4 import pytest +from aiokafka import AIOKafkaConsumer from app.core.metrics import EventMetrics from app.domain.enums.kafka import KafkaTopic -from app.events.core import ConsumerConfig, EventDispatcher, UnifiedConsumer +from app.events.core import EventDispatcher, UnifiedConsumer from app.events.schema.schema_registry import SchemaRegistryManager from app.settings import Settings from dishka import AsyncContainer @@ -17,30 +18,29 @@ @pytest.mark.asyncio -async def test_consumer_start_status_seek_and_stop(scope: AsyncContainer) -> None: +async def test_consumer_seek_operations(scope: AsyncContainer) -> None: + """Test AIOKafkaConsumer seek operations work correctly.""" registry: SchemaRegistryManager = await scope.get(SchemaRegistryManager) settings: Settings = await scope.get(Settings) - event_metrics: EventMetrics = await scope.get(EventMetrics) - cfg = ConsumerConfig( + + group_id = f"test-consumer-{uuid4().hex[:6]}" + + # Create AIOKafkaConsumer directly + topic = f"{settings.KAFKA_TOPIC_PREFIX}{KafkaTopic.EXECUTION_EVENTS}" + kafka_consumer = AIOKafkaConsumer( + topic, bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS, - group_id=f"test-consumer-{uuid4().hex[:6]}", + group_id=group_id, + enable_auto_commit=True, + auto_offset_reset="earliest", ) - disp = EventDispatcher(logger=_test_logger) - c = UnifiedConsumer( - cfg, - event_dispatcher=disp, - schema_registry=registry, - settings=settings, - logger=_test_logger, - event_metrics=event_metrics, - ) - await c.start([KafkaTopic.EXECUTION_EVENTS]) + await kafka_consumer.start() + try: - st = c.get_status() - assert st.state == "running" and st.is_running is True - # Exercise seek functions; don't force specific partition offsets - await c.seek_to_beginning() - await c.seek_to_end() - # No need to sleep; just ensure we can call seek APIs while running + # Exercise seek functions on AIOKafkaConsumer directly + assignment = kafka_consumer.assignment() + if assignment: + await kafka_consumer.seek_to_beginning(*assignment) + await kafka_consumer.seek_to_end(*assignment) finally: - await c.stop() + await kafka_consumer.stop() diff --git a/backend/tests/e2e/events/test_event_dispatcher.py b/backend/tests/e2e/events/test_event_dispatcher.py index 2ead3aa3..126bdf04 100644 --- a/backend/tests/e2e/events/test_event_dispatcher.py +++ b/backend/tests/e2e/events/test_event_dispatcher.py @@ -3,13 +3,13 @@ import uuid import pytest +from aiokafka import AIOKafkaConsumer from app.core.metrics import EventMetrics from app.domain.enums.events import EventType from app.domain.enums.kafka import KafkaTopic from app.domain.events.typed import DomainEvent from app.events.core import UnifiedConsumer, UnifiedProducer from app.events.core.dispatcher import EventDispatcher -from app.events.core.types import ConsumerConfig from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas from app.settings import Settings from dishka import AsyncContainer @@ -44,22 +44,26 @@ 
async def h1(_e: DomainEvent) -> None: async def h2(_e: DomainEvent) -> None: h2_called.set() - # Real consumer against execution-events - cfg = ConsumerConfig( + group_id = f"dispatcher-it.{uuid.uuid4().hex[:6]}" + + # Create AIOKafkaConsumer directly for test + topic = f"{settings.KAFKA_TOPIC_PREFIX}{KafkaTopic.EXECUTION_EVENTS}" + kafka_consumer = AIOKafkaConsumer( + topic, bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS, - group_id=f"dispatcher-it.{uuid.uuid4().hex[:6]}", + group_id=group_id, enable_auto_commit=True, auto_offset_reset="earliest", ) - consumer = UnifiedConsumer( - cfg, - dispatcher, + await kafka_consumer.start() + + handler = UnifiedConsumer( + event_dispatcher=dispatcher, schema_registry=registry, - settings=settings, logger=_test_logger, event_metrics=event_metrics, + group_id=group_id, ) - await consumer.start([KafkaTopic.EXECUTION_EVENTS]) # Produce a request event via DI producer: UnifiedProducer = await scope.get(UnifiedProducer) @@ -67,6 +71,13 @@ async def h2(_e: DomainEvent) -> None: await producer.produce(evt, key="k") try: - await asyncio.wait_for(asyncio.gather(h1_called.wait(), h2_called.wait()), timeout=10.0) + async def consume_until_handled() -> None: + async for msg in kafka_consumer: + await handler.handle(msg) + await kafka_consumer.commit() + if h1_called.is_set() and h2_called.is_set(): + break + + await asyncio.wait_for(consume_until_handled(), timeout=10.0) finally: - await consumer.stop() + await kafka_consumer.stop() diff --git a/backend/tests/e2e/events/test_producer_roundtrip.py b/backend/tests/e2e/events/test_producer_roundtrip.py index 8340610b..ed1a4cdb 100644 --- a/backend/tests/e2e/events/test_producer_roundtrip.py +++ b/backend/tests/e2e/events/test_producer_roundtrip.py @@ -2,11 +2,8 @@ from uuid import uuid4 import pytest -from app.core.metrics import EventMetrics -from app.events.core import UnifiedProducer -from app.events.schema.schema_registry import SchemaRegistryManager +from app.events.core import ProducerMetrics, UnifiedProducer from app.infrastructure.kafka.mappings import get_topic_for_event -from app.settings import Settings from dishka import AsyncContainer from tests.conftest import make_execution_requested_event @@ -17,25 +14,20 @@ @pytest.mark.asyncio -async def test_unified_producer_start_produce_send_to_dlq_stop( - scope: AsyncContainer, test_settings: Settings -) -> None: - schema: SchemaRegistryManager = await scope.get(SchemaRegistryManager) - event_metrics: EventMetrics = await scope.get(EventMetrics) - prod = UnifiedProducer( - schema, - logger=_test_logger, - settings=test_settings, - event_metrics=event_metrics, - ) - - async with prod: - ev = make_execution_requested_event(execution_id=f"exec-{uuid4().hex[:8]}") - await prod.produce(ev) - - # Exercise send_to_dlq path - topic = str(get_topic_for_event(ev.event_type)) - await prod.send_to_dlq(ev, original_topic=topic, error=RuntimeError("forced"), retry_count=1) - - st = prod.get_status() - assert st["running"] is True and st["state"] == "running" +async def test_unified_producer_produce_and_send_to_dlq(scope: AsyncContainer) -> None: + # Get producer and metrics from DI + prod: UnifiedProducer = await scope.get(UnifiedProducer) + metrics: ProducerMetrics = await scope.get(ProducerMetrics) + + initial_sent = metrics.messages_sent + + # Produce an event + ev = make_execution_requested_event(execution_id=f"exec-{uuid4().hex[:8]}") + await prod.produce(ev) + + # Exercise send_to_dlq path + topic = str(get_topic_for_event(ev.event_type)) + await 
prod.send_to_dlq(ev, original_topic=topic, error=RuntimeError("forced"), retry_count=1) + + # Verify metrics are being tracked + assert metrics.messages_sent >= initial_sent + 2 diff --git a/backend/tests/e2e/idempotency/test_consumer_idempotent.py b/backend/tests/e2e/idempotency/test_consumer_idempotent.py index 2ffae6ae..749a0ea3 100644 --- a/backend/tests/e2e/idempotency/test_consumer_idempotent.py +++ b/backend/tests/e2e/idempotency/test_consumer_idempotent.py @@ -3,11 +3,12 @@ import uuid import pytest +from aiokafka import AIOKafkaConsumer from app.core.metrics import EventMetrics from app.domain.enums.events import EventType from app.domain.enums.kafka import KafkaTopic from app.domain.events.typed import DomainEvent -from app.events.core import ConsumerConfig, EventDispatcher, UnifiedConsumer, UnifiedProducer +from app.events.core import EventDispatcher, UnifiedConsumer, UnifiedProducer from app.events.core.dispatcher import EventDispatcher as Disp from app.events.schema.schema_registry import SchemaRegistryManager from app.domain.idempotency import KeyStrategy @@ -57,23 +58,28 @@ async def handle(_ev: DomainEvent) -> None: await producer.produce(ev, key=execution_id) await producer.produce(ev, key=execution_id) - # Real consumer with idempotent wrapper - cfg = ConsumerConfig( + group_id = f"test-idem-consumer.{uuid.uuid4().hex[:6]}" + + # Create AIOKafkaConsumer directly + topic = f"{settings.KAFKA_TOPIC_PREFIX}{KafkaTopic.EXECUTION_EVENTS}" + kafka_consumer = AIOKafkaConsumer( + topic, bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS, - group_id=f"test-idem-consumer.{uuid.uuid4().hex[:6]}", + group_id=group_id, enable_auto_commit=True, auto_offset_reset="earliest", ) - base = UnifiedConsumer( - cfg, + await kafka_consumer.start() + + handler = UnifiedConsumer( event_dispatcher=disp, schema_registry=registry, - settings=settings, logger=_test_logger, event_metrics=event_metrics, + group_id=group_id, ) wrapper = IdempotentConsumerWrapper( - consumer=base, + consumer=handler, idempotency_manager=idm, dispatcher=disp, default_key_strategy=KeyStrategy.EVENT_BASED, @@ -81,10 +87,16 @@ async def handle(_ev: DomainEvent) -> None: logger=_test_logger, ) - await wrapper.start([KafkaTopic.EXECUTION_EVENTS]) try: - # Await the future directly - true async, no polling - await asyncio.wait_for(handled_future, timeout=10.0) + # Consume until handler is called + async def consume_until_handled() -> None: + async for msg in kafka_consumer: + await handler.handle(msg) + await kafka_consumer.commit() + if handled_future.done(): + break + + await asyncio.wait_for(consume_until_handled(), timeout=10.0) assert seen["n"] >= 1 finally: - await wrapper.stop() + await kafka_consumer.stop() diff --git a/backend/tests/e2e/result_processor/test_result_processor.py b/backend/tests/e2e/result_processor/test_result_processor.py deleted file mode 100644 index 4f5b11f3..00000000 --- a/backend/tests/e2e/result_processor/test_result_processor.py +++ /dev/null @@ -1,134 +0,0 @@ -import asyncio -import logging -import uuid - -import pytest -from app.core.database_context import Database -from app.core.metrics import EventMetrics, ExecutionMetrics -from app.db.repositories.execution_repository import ExecutionRepository -from app.domain.enums.events import EventType -from app.domain.enums.execution import ExecutionStatus -from app.domain.enums.kafka import KafkaTopic -from app.domain.events.typed import ( - EventMetadata, - ExecutionCompletedEvent, - ResourceUsageDomain, - ResultStoredEvent, -) -from 
app.domain.execution import DomainExecutionCreate -from app.events.core import UnifiedConsumer, UnifiedProducer -from app.events.core.dispatcher import EventDispatcher -from app.events.core.types import ConsumerConfig -from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas -from app.services.idempotency import IdempotencyManager -from app.services.result_processor.processor import ResultProcessor -from app.settings import Settings -from dishka import AsyncContainer - -# xdist_group: Kafka consumer creation can crash librdkafka when multiple workers -# instantiate Consumer() objects simultaneously. Serial execution prevents this. -pytestmark = [ - pytest.mark.e2e, - pytest.mark.kafka, - pytest.mark.mongodb, - pytest.mark.xdist_group("kafka_consumers"), -] - -_test_logger = logging.getLogger("test.result_processor.processor") - - -@pytest.mark.asyncio -async def test_result_processor_persists_and_emits(scope: AsyncContainer) -> None: - # Ensure schemas - registry: SchemaRegistryManager = await scope.get(SchemaRegistryManager) - settings: Settings = await scope.get(Settings) - event_metrics: EventMetrics = await scope.get(EventMetrics) - execution_metrics: ExecutionMetrics = await scope.get(ExecutionMetrics) - await initialize_event_schemas(registry) - - # Dependencies - db: Database = await scope.get(Database) - repo: ExecutionRepository = await scope.get(ExecutionRepository) - producer: UnifiedProducer = await scope.get(UnifiedProducer) - idem: IdempotencyManager = await scope.get(IdempotencyManager) - - # Create a base execution to satisfy ResultProcessor lookup - created = await repo.create_execution(DomainExecutionCreate( - script="print('x')", - user_id="u1", - lang="python", - lang_version="3.11", - status=ExecutionStatus.RUNNING, - )) - execution_id = created.execution_id - - # Build and start the processor - processor = ResultProcessor( - execution_repo=repo, - producer=producer, - schema_registry=registry, - settings=settings, - idempotency_manager=idem, - logger=_test_logger, - execution_metrics=execution_metrics, - event_metrics=event_metrics, - ) - - # Setup a small consumer to capture ResultStoredEvent - dispatcher = EventDispatcher(logger=_test_logger) - stored_received = asyncio.Event() - - @dispatcher.register(EventType.RESULT_STORED) - async def _stored(event: ResultStoredEvent) -> None: - if event.execution_id == execution_id: - stored_received.set() - - group_id = f"rp-test.{uuid.uuid4().hex[:6]}" - cconf = ConsumerConfig( - bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS, - group_id=group_id, - enable_auto_commit=True, - auto_offset_reset="earliest", - ) - stored_consumer = UnifiedConsumer( - cconf, - dispatcher, - schema_registry=registry, - settings=settings, - logger=_test_logger, - event_metrics=event_metrics, - ) - - # Produce the event BEFORE starting consumers (auto_offset_reset="earliest" will read it) - usage = ResourceUsageDomain( - execution_time_wall_seconds=0.5, - cpu_time_jiffies=100, - clk_tck_hertz=100, - peak_memory_kb=1024, - ) - evt = ExecutionCompletedEvent( - execution_id=execution_id, - exit_code=0, - stdout="hello", - stderr="", - resource_usage=usage, - metadata=EventMetadata(service_name="tests", service_version="1.0.0"), - ) - await producer.produce(evt, key=execution_id) - - # Start consumers after producing - await stored_consumer.start([KafkaTopic.EXECUTION_RESULTS]) - - try: - async with processor: - # Await the ResultStoredEvent - signals that processing is complete - await 
asyncio.wait_for(stored_received.wait(), timeout=12.0) - - # Now verify DB persistence - should be done since event was emitted - doc = await db.get_collection("executions").find_one({"execution_id": execution_id}) - assert doc is not None, f"Execution {execution_id} not found in DB after ResultStoredEvent" - assert doc.get("status") == ExecutionStatus.COMPLETED, ( - f"Expected COMPLETED status, got {doc.get('status')}" - ) - finally: - await stored_consumer.stop() diff --git a/backend/tests/e2e/services/coordinator/test_execution_coordinator.py b/backend/tests/e2e/services/coordinator/test_execution_coordinator.py index 5406c7b4..472ebc0e 100644 --- a/backend/tests/e2e/services/coordinator/test_execution_coordinator.py +++ b/backend/tests/e2e/services/coordinator/test_execution_coordinator.py @@ -3,148 +3,36 @@ from dishka import AsyncContainer from tests.conftest import make_execution_requested_event -pytestmark = [pytest.mark.e2e, pytest.mark.kafka] +pytestmark = [pytest.mark.e2e, pytest.mark.kafka, pytest.mark.redis] -class TestHandleExecutionRequested: - """Tests for _handle_execution_requested method.""" +class TestExecutionCoordinator: + """Tests for ExecutionCoordinator handler methods.""" @pytest.mark.asyncio - async def test_handle_requested_schedules_execution( - self, scope: AsyncContainer - ) -> None: - """Handler schedules execution immediately.""" + async def test_handle_requested_does_not_raise(self, scope: AsyncContainer) -> None: + """Handler processes execution request without error.""" coord: ExecutionCoordinator = await scope.get(ExecutionCoordinator) - ev = make_execution_requested_event(execution_id="e-sched-1") + ev = make_execution_requested_event(execution_id="e-test-1") - await coord._handle_execution_requested(ev) # noqa: SLF001 - - assert "e-sched-1" in coord._active_executions # noqa: SLF001 + # Should not raise + await coord.handle_execution_requested(ev) @pytest.mark.asyncio - async def test_handle_requested_with_priority( - self, scope: AsyncContainer - ) -> None: + async def test_handle_requested_with_priority(self, scope: AsyncContainer) -> None: """Handler respects execution priority.""" coord: ExecutionCoordinator = await scope.get(ExecutionCoordinator) - ev = make_execution_requested_event( - execution_id="e-priority-1", - priority=10, # High priority - ) - - await coord._handle_execution_requested(ev) # noqa: SLF001 - - assert "e-priority-1" in coord._active_executions # noqa: SLF001 - - @pytest.mark.asyncio - async def test_handle_requested_unique_executions( - self, scope: AsyncContainer - ) -> None: - """Each execution gets unique tracking.""" - coord: ExecutionCoordinator = await scope.get(ExecutionCoordinator) - - ev1 = make_execution_requested_event(execution_id="e-unique-1") - ev2 = make_execution_requested_event(execution_id="e-unique-2") - - await coord._handle_execution_requested(ev1) # noqa: SLF001 - await coord._handle_execution_requested(ev2) # noqa: SLF001 - - assert "e-unique-1" in coord._active_executions # noqa: SLF001 - assert "e-unique-2" in coord._active_executions # noqa: SLF001 - - -class TestGetStatus: - """Tests for get_status method.""" - - @pytest.mark.asyncio - async def test_get_status_returns_dict(self, scope: AsyncContainer) -> None: - """Get status returns dictionary with coordinator info.""" - coord: ExecutionCoordinator = await scope.get(ExecutionCoordinator) - - status = await coord.get_status() - - assert isinstance(status, dict) - assert "running" in status - assert "active_executions" in status - assert 
"queue_stats" in status - assert "resource_stats" in status - - @pytest.mark.asyncio - async def test_get_status_tracks_active_executions( - self, scope: AsyncContainer - ) -> None: - """Status tracks number of active executions.""" - coord: ExecutionCoordinator = await scope.get(ExecutionCoordinator) - - initial_status = await coord.get_status() - initial_active = initial_status.get("active_executions", 0) - - # Add execution - ev = make_execution_requested_event(execution_id="e-status-track-1") - await coord._handle_execution_requested(ev) # noqa: SLF001 - - new_status = await coord.get_status() - new_active = new_status.get("active_executions", 0) - - assert new_active == initial_active + 1, ( - f"Expected exactly one more active execution: {initial_active} -> {new_active}" - ) - - -class TestQueueManager: - """Tests for queue manager integration.""" - - @pytest.mark.asyncio - async def test_queue_manager_initialized(self, scope: AsyncContainer) -> None: - """Queue manager is properly initialized.""" - coord: ExecutionCoordinator = await scope.get(ExecutionCoordinator) - - assert coord.queue_manager is not None - assert hasattr(coord.queue_manager, "add_execution") - assert hasattr(coord.queue_manager, "get_next_execution") - - -class TestResourceManager: - """Tests for resource manager integration.""" - - @pytest.mark.asyncio - async def test_resource_manager_initialized( - self, scope: AsyncContainer - ) -> None: - """Resource manager is properly initialized.""" - coord: ExecutionCoordinator = await scope.get(ExecutionCoordinator) - - assert coord.resource_manager is not None - assert hasattr(coord.resource_manager, "request_allocation") - assert hasattr(coord.resource_manager, "release_allocation") - - @pytest.mark.asyncio - async def test_resource_manager_has_pool( - self, scope: AsyncContainer - ) -> None: - """Resource manager has resource pool configured.""" - coord: ExecutionCoordinator = await scope.get(ExecutionCoordinator) - - # Check resource manager has pool with capacity - assert coord.resource_manager.pool is not None - assert coord.resource_manager.pool.total_cpu_cores > 0 - assert coord.resource_manager.pool.total_memory_mb > 0 - - -class TestCoordinatorLifecycle: - """Tests for coordinator lifecycle.""" - - @pytest.mark.asyncio - async def test_coordinator_has_consumer(self, scope: AsyncContainer) -> None: - """Coordinator has Kafka consumer configured.""" - coord: ExecutionCoordinator = await scope.get(ExecutionCoordinator) + ev = make_execution_requested_event(execution_id="e-priority-1", priority=10) - # Consumer is set up during start, may be None before - assert hasattr(coord, "consumer") + await coord.handle_execution_requested(ev) @pytest.mark.asyncio - async def test_coordinator_has_producer(self, scope: AsyncContainer) -> None: - """Coordinator has Kafka producer configured.""" + async def test_coordinator_resolves_from_di(self, scope: AsyncContainer) -> None: + """Coordinator can be resolved from DI container.""" coord: ExecutionCoordinator = await scope.get(ExecutionCoordinator) - assert coord.producer is not None + assert coord is not None + assert hasattr(coord, "handle_execution_requested") + assert hasattr(coord, "handle_execution_completed") + assert hasattr(coord, "handle_execution_failed") + assert hasattr(coord, "handle_execution_cancelled") diff --git a/backend/tests/e2e/services/events/test_event_bus.py b/backend/tests/e2e/services/events/test_event_bus.py index 5d87b290..f5f083eb 100644 --- a/backend/tests/e2e/services/events/test_event_bus.py +++ 
b/backend/tests/e2e/services/events/test_event_bus.py @@ -1,11 +1,11 @@ import asyncio +from datetime import datetime, timezone +from uuid import uuid4 import pytest from aiokafka import AIOKafkaProducer -from app.domain.enums.events import EventType from app.domain.enums.kafka import KafkaTopic -from app.domain.events.typed import DomainEvent, EventMetadata, UserSettingsUpdatedEvent -from app.services.event_bus import EventBusManager +from app.services.event_bus import EventBus, EventBusEvent from app.settings import Settings from dishka import AsyncContainer @@ -15,24 +15,23 @@ @pytest.mark.asyncio async def test_event_bus_publish_subscribe(scope: AsyncContainer, test_settings: Settings) -> None: """Test EventBus receives events from other instances (cross-instance communication).""" - manager: EventBusManager = await scope.get(EventBusManager) - bus = await manager.get_event_bus() + bus: EventBus = await scope.get(EventBus) # Future resolves when handler receives the event - no polling needed - received_future: asyncio.Future[DomainEvent] = asyncio.get_running_loop().create_future() + received_future: asyncio.Future[EventBusEvent] = asyncio.get_running_loop().create_future() - async def handler(event: DomainEvent) -> None: + async def handler(event: EventBusEvent) -> None: if not received_future.done(): received_future.set_result(event) - await bus.subscribe(f"{EventType.USER_SETTINGS_UPDATED}*", handler) + await bus.subscribe("test.*", handler) # Simulate message from another instance by producing directly to Kafka - event = UserSettingsUpdatedEvent( - user_id="test-user", - changed_fields=["theme"], - reason="test", - metadata=EventMetadata(service_name="test", service_version="1.0"), + event = EventBusEvent( + id=str(uuid4()), + event_type="test.created", + timestamp=datetime.now(timezone.utc), + payload={"x": 1}, ) topic = f"{test_settings.KAFKA_TOPIC_PREFIX}{KafkaTopic.EVENT_BUS_STREAM}" @@ -42,7 +41,7 @@ async def handler(event: DomainEvent) -> None: await producer.send_and_wait( topic=topic, value=event.model_dump_json().encode("utf-8"), - key=EventType.USER_SETTINGS_UPDATED.encode("utf-8"), + key=b"test.created", headers=[("source_instance", b"other-instance")], ) finally: @@ -50,6 +49,4 @@ async def handler(event: DomainEvent) -> None: # Await the future directly - true async, no polling received = await asyncio.wait_for(received_future, timeout=10.0) - assert received.event_type == EventType.USER_SETTINGS_UPDATED - assert isinstance(received, UserSettingsUpdatedEvent) - assert received.user_id == "test-user" + assert received.event_type == "test.created" diff --git a/backend/tests/e2e/services/sse/test_partitioned_event_router.py b/backend/tests/e2e/services/sse/test_partitioned_event_router.py deleted file mode 100644 index 6bb6b71f..00000000 --- a/backend/tests/e2e/services/sse/test_partitioned_event_router.py +++ /dev/null @@ -1,81 +0,0 @@ -import asyncio -import logging -from uuid import uuid4 - -import pytest -import redis.asyncio as redis -from app.core.metrics import EventMetrics -from app.events.core import EventDispatcher -from app.events.schema.schema_registry import SchemaRegistryManager -from app.schemas_pydantic.sse import RedisSSEMessage -from app.services.sse.kafka_redis_bridge import SSEKafkaRedisBridge -from app.services.sse.redis_bus import SSERedisBus -from app.settings import Settings - -from tests.conftest import make_execution_requested_event - -pytestmark = [pytest.mark.e2e, pytest.mark.redis] - -_test_logger = 
logging.getLogger("test.services.sse.partitioned_event_router_integration") - - -@pytest.mark.asyncio -async def test_router_bridges_to_redis(redis_client: redis.Redis, test_settings: Settings) -> None: - suffix = uuid4().hex[:6] - bus = SSERedisBus( - redis_client, - exec_prefix=f"sse:exec:{suffix}:", - notif_prefix=f"sse:notif:{suffix}:", - logger=_test_logger, - ) - router = SSEKafkaRedisBridge( - schema_registry=SchemaRegistryManager(settings=test_settings, logger=_test_logger), - settings=test_settings, - event_metrics=EventMetrics(test_settings), - sse_bus=bus, - logger=_test_logger, - ) - disp = EventDispatcher(logger=_test_logger) - router._register_routing_handlers(disp) - - # Open Redis subscription for our execution id - execution_id = f"e-{uuid4().hex[:8]}" - subscription = await bus.open_subscription(execution_id) - - ev = make_execution_requested_event(execution_id=execution_id) - handler = disp.get_handlers(ev.event_type)[0] - await handler(ev) - - # Await the subscription directly - true async, no polling - msg = await asyncio.wait_for(subscription.get(RedisSSEMessage), timeout=2.0) - assert msg is not None - assert str(msg.event_type) == str(ev.event_type) - - -@pytest.mark.asyncio -async def test_router_start_and_stop(redis_client: redis.Redis, test_settings: Settings) -> None: - test_settings.SSE_CONSUMER_POOL_SIZE = 1 - suffix = uuid4().hex[:6] - router = SSEKafkaRedisBridge( - schema_registry=SchemaRegistryManager(settings=test_settings, logger=_test_logger), - settings=test_settings, - event_metrics=EventMetrics(test_settings), - sse_bus=SSERedisBus( - redis_client, - exec_prefix=f"sse:exec:{suffix}:", - notif_prefix=f"sse:notif:{suffix}:", - logger=_test_logger, - ), - logger=_test_logger, - ) - - await router.__aenter__() - stats = router.get_stats() - assert stats["num_consumers"] == 1 - await router.aclose() - assert router.get_stats()["num_consumers"] == 0 - # idempotent start/stop - await router.__aenter__() - await router.__aenter__() - await router.aclose() - await router.aclose() diff --git a/backend/tests/e2e/test_k8s_worker_create_pod.py b/backend/tests/e2e/test_k8s_worker_create_pod.py index c43bb2e5..d1efcf80 100644 --- a/backend/tests/e2e/test_k8s_worker_create_pod.py +++ b/backend/tests/e2e/test_k8s_worker_create_pod.py @@ -2,13 +2,7 @@ import uuid import pytest -from app.core.metrics import EventMetrics from app.domain.events.typed import CreatePodCommandEvent, EventMetadata -from app.events.core import UnifiedProducer -from app.events.event_store import EventStore -from app.events.schema.schema_registry import SchemaRegistryManager -from app.services.idempotency import IdempotencyManager -from app.services.k8s_worker.config import K8sWorkerConfig from app.services.k8s_worker.worker import KubernetesWorker from app.settings import Settings from dishka import AsyncContainer @@ -25,27 +19,10 @@ async def test_worker_creates_configmap_and_pod( ) -> None: ns = test_settings.K8S_NAMESPACE - schema: SchemaRegistryManager = await scope.get(SchemaRegistryManager) - store: EventStore = await scope.get(EventStore) - producer: UnifiedProducer = await scope.get(UnifiedProducer) - idem: IdempotencyManager = await scope.get(IdempotencyManager) - event_metrics: EventMetrics = await scope.get(EventMetrics) + # Get worker from DI (already configured with dependencies) + worker: KubernetesWorker = await scope.get(KubernetesWorker) - cfg = K8sWorkerConfig(namespace=ns, max_concurrent_pods=1) - worker = KubernetesWorker( - config=cfg, - producer=producer, - 
schema_registry_manager=schema, - settings=test_settings, - event_store=store, - idempotency_manager=idem, - logger=_test_logger, - event_metrics=event_metrics, - ) - - # Initialize k8s clients using worker's own method - worker._initialize_kubernetes_client() # noqa: SLF001 - if worker.v1 is None: + if worker._v1 is None: # noqa: SLF001 pytest.skip("Kubernetes cluster not available") exec_id = uuid.uuid4().hex[:8] @@ -68,7 +45,7 @@ async def test_worker_creates_configmap_and_pod( ) # Build and create ConfigMap + Pod - cm = worker.pod_builder.build_config_map( + cm = worker._pod_builder.build_config_map( # noqa: SLF001 command=cmd, script_content=cmd.script, entrypoint_content=await worker._get_entrypoint_script(), # noqa: SLF001 @@ -80,15 +57,15 @@ async def test_worker_creates_configmap_and_pod( pytest.skip(f"Insufficient permissions or namespace not found: {e}") raise - pod = worker.pod_builder.build_pod_manifest(cmd) + pod = worker._pod_builder.build_pod_manifest(cmd) # noqa: SLF001 await worker._create_pod(pod) # noqa: SLF001 # Verify resources exist - got_cm = worker.v1.read_namespaced_config_map(name=f"script-{exec_id}", namespace=ns) + got_cm = worker._v1.read_namespaced_config_map(name=f"script-{exec_id}", namespace=ns) # noqa: SLF001 assert got_cm is not None - got_pod = worker.v1.read_namespaced_pod(name=f"executor-{exec_id}", namespace=ns) + got_pod = worker._v1.read_namespaced_pod(name=f"executor-{exec_id}", namespace=ns) # noqa: SLF001 assert got_pod is not None # Cleanup - worker.v1.delete_namespaced_pod(name=f"executor-{exec_id}", namespace=ns) - worker.v1.delete_namespaced_config_map(name=f"script-{exec_id}", namespace=ns) + worker._v1.delete_namespaced_pod(name=f"executor-{exec_id}", namespace=ns) # noqa: SLF001 + worker._v1.delete_namespaced_config_map(name=f"script-{exec_id}", namespace=ns) # noqa: SLF001 diff --git a/backend/tests/unit/services/coordinator/test_queue_manager.py b/backend/tests/unit/services/coordinator/test_queue_manager.py deleted file mode 100644 index 671b19a7..00000000 --- a/backend/tests/unit/services/coordinator/test_queue_manager.py +++ /dev/null @@ -1,41 +0,0 @@ -import logging - -import pytest -from app.core.metrics import CoordinatorMetrics -from app.domain.events.typed import ExecutionRequestedEvent -from app.services.coordinator.queue_manager import QueueManager, QueuePriority - -from tests.conftest import make_execution_requested_event - -_test_logger = logging.getLogger("test.services.coordinator.queue_manager") - -pytestmark = pytest.mark.unit - - -def ev(execution_id: str, priority: int = QueuePriority.NORMAL.value) -> ExecutionRequestedEvent: - return make_execution_requested_event(execution_id=execution_id, priority=priority) - - -@pytest.mark.asyncio -async def test_requeue_execution_increments_priority(coordinator_metrics: CoordinatorMetrics) -> None: - qm = QueueManager(max_queue_size=10, logger=_test_logger, coordinator_metrics=coordinator_metrics) - await qm.start() - # Use NORMAL priority which can be incremented to LOW - e = ev("x", priority=QueuePriority.NORMAL.value) - await qm.add_execution(e) - await qm.requeue_execution(e, increment_retry=True) - nxt = await qm.get_next_execution() - assert nxt is not None - await qm.stop() - - -@pytest.mark.asyncio -async def test_queue_stats_empty_and_after_add(coordinator_metrics: CoordinatorMetrics) -> None: - qm = QueueManager(max_queue_size=5, logger=_test_logger, coordinator_metrics=coordinator_metrics) - await qm.start() - stats0 = await qm.get_queue_stats() - assert 
stats0["total_size"] == 0 - await qm.add_execution(ev("a")) - st = await qm.get_queue_stats() - assert st["total_size"] == 1 - await qm.stop() diff --git a/backend/tests/unit/services/coordinator/test_resource_manager.py b/backend/tests/unit/services/coordinator/test_resource_manager.py deleted file mode 100644 index 3624dae6..00000000 --- a/backend/tests/unit/services/coordinator/test_resource_manager.py +++ /dev/null @@ -1,61 +0,0 @@ -import logging - -import pytest -from app.core.metrics import CoordinatorMetrics -from app.services.coordinator.resource_manager import ResourceManager - -_test_logger = logging.getLogger("test.services.coordinator.resource_manager") - - -@pytest.mark.asyncio -async def test_request_allocation_defaults_and_limits(coordinator_metrics: CoordinatorMetrics) -> None: - rm = ResourceManager(total_cpu_cores=8.0, total_memory_mb=16384, total_gpu_count=0, logger=_test_logger, coordinator_metrics=coordinator_metrics) - - # Default for python - alloc = await rm.request_allocation("e1", "python") - assert alloc is not None - - assert alloc.cpu_cores > 0 - assert alloc.memory_mb > 0 - - # Respect per-exec max cap - alloc2 = await rm.request_allocation("e2", "python", requested_cpu=100.0, requested_memory_mb=999999) - assert alloc2 is not None - assert alloc2.cpu_cores <= rm.pool.max_cpu_per_execution - assert alloc2.memory_mb <= rm.pool.max_memory_per_execution_mb - - -@pytest.mark.asyncio -async def test_release_and_can_allocate(coordinator_metrics: CoordinatorMetrics) -> None: - rm = ResourceManager(total_cpu_cores=4.0, total_memory_mb=8192, total_gpu_count=0, logger=_test_logger, coordinator_metrics=coordinator_metrics) - - a = await rm.request_allocation("e1", "python", requested_cpu=1.0, requested_memory_mb=512) - assert a is not None - - ok = await rm.release_allocation("e1") - assert ok is True - - # After release, can allocate near limits while preserving headroom. - # Use a tiny epsilon to avoid edge rounding issues in >= comparisons. 
- epsilon_cpu = 1e-6 - epsilon_mem = 1 - can = await rm.can_allocate(cpu_cores=rm.pool.total_cpu_cores - rm.pool.min_available_cpu_cores - epsilon_cpu, - memory_mb=rm.pool.total_memory_mb - rm.pool.min_available_memory_mb - epsilon_mem, - gpu_count=0) - assert can is True - - -@pytest.mark.asyncio -async def test_resource_stats(coordinator_metrics: CoordinatorMetrics) -> None: - rm = ResourceManager(total_cpu_cores=2.0, total_memory_mb=4096, total_gpu_count=0, logger=_test_logger, coordinator_metrics=coordinator_metrics) - # Make sure the allocation succeeds - alloc = await rm.request_allocation("e1", "python", requested_cpu=0.5, requested_memory_mb=256) - assert alloc is not None, "Allocation should have succeeded" - - stats = await rm.get_resource_stats() - - assert stats.total.cpu_cores > 0 - assert stats.available.cpu_cores >= 0 - assert stats.allocated.cpu_cores > 0 # Should be > 0 since we allocated - assert stats.utilization["cpu_percent"] >= 0 - assert stats.allocation_count >= 1 # Should be at least 1 (may have system allocations) diff --git a/backend/tests/unit/services/pod_monitor/test_monitor.py b/backend/tests/unit/services/pod_monitor/test_monitor.py index 283d428e..8916af97 100644 --- a/backend/tests/unit/services/pod_monitor/test_monitor.py +++ b/backend/tests/unit/services/pod_monitor/test_monitor.py @@ -1,3 +1,5 @@ +"""Tests for stateless PodMonitor handler.""" + import asyncio import logging import types @@ -5,37 +7,24 @@ from unittest.mock import MagicMock import pytest -from app.core import k8s_clients as k8s_clients_module -from app.core.k8s_clients import K8sClients +from kubernetes import client as k8s_client + from app.core.metrics import EventMetrics, KubernetesMetrics -from app.db.repositories.event_repository import EventRepository -from app.domain.events.typed import ( - DomainEvent, - EventMetadata, - ExecutionCompletedEvent, - ExecutionStartedEvent, - ResourceUsageDomain, -) +from app.db.repositories.pod_state_repository import PodStateRepository +from app.domain.events.typed import DomainEvent, EventMetadata, ExecutionCompletedEvent +from app.domain.execution.models import ResourceUsageDomain from app.events.core import UnifiedProducer -from app.services.kafka_event_service import KafkaEventService from app.services.pod_monitor.config import PodMonitorConfig from app.services.pod_monitor.event_mapper import PodEventMapper from app.services.pod_monitor.monitor import ( - MonitorState, PodEvent, PodMonitor, ReconciliationResult, WatchEventType, - create_pod_monitor, ) -from app.settings import Settings -from kubernetes.client.rest import ApiException from tests.unit.services.pod_monitor.conftest import ( - MockWatchStream, - Pod, make_mock_v1_api, - make_mock_watch, make_pod, ) @@ -44,55 +33,52 @@ _test_logger = logging.getLogger("test.pod_monitor") -# ===== Test doubles for KafkaEventService dependencies ===== - - -class FakeEventRepository(EventRepository): - """In-memory event repository for testing.""" - - def __init__(self) -> None: - super().__init__(_test_logger) - self.stored_events: list[DomainEvent] = [] - - async def store_event(self, event: DomainEvent) -> str: - self.stored_events.append(event) - return event.event_id - - class FakeUnifiedProducer(UnifiedProducer): """Fake producer that captures events without Kafka.""" def __init__(self) -> None: - # Don't call super().__init__ - we don't need real Kafka self.produced_events: list[tuple[DomainEvent, str | None]] = [] - self.logger = _test_logger + self._logger = _test_logger async def produce( 
self, event_to_produce: DomainEvent, key: str | None = None, headers: dict[str, str] | None = None ) -> None: self.produced_events.append((event_to_produce, key)) - async def aclose(self) -> None: - pass +class FakePodStateRepository: + """Fake pod state repository for testing.""" -def create_test_kafka_event_service(event_metrics: EventMetrics) -> tuple[KafkaEventService, FakeUnifiedProducer]: - """Create real KafkaEventService with fake dependencies for testing.""" - fake_producer = FakeUnifiedProducer() - fake_repo = FakeEventRepository() - settings = Settings() # Uses defaults/env vars + def __init__(self) -> None: + self._tracked: set[str] = set() + self._resource_version: str | None = None - service = KafkaEventService( - event_repository=fake_repo, - kafka_producer=fake_producer, - settings=settings, - logger=_test_logger, - event_metrics=event_metrics, - ) - return service, fake_producer + async def track_pod( + self, pod_name: str, execution_id: str, status: str, + metadata: dict[str, object] | None = None, ttl_seconds: int = 7200, + ) -> None: + self._tracked.add(pod_name) + async def untrack_pod(self, pod_name: str) -> bool: + if pod_name in self._tracked: + self._tracked.discard(pod_name) + return True + return False -# ===== Helpers to create test instances with pure DI ===== + async def is_pod_tracked(self, pod_name: str) -> bool: + return pod_name in self._tracked + + async def get_tracked_pod_names(self) -> set[str]: + return self._tracked.copy() + + async def get_tracked_pods_count(self) -> int: + return len(self._tracked) + + async def get_resource_version(self) -> str | None: + return self._resource_version + + async def set_resource_version(self, version: str) -> None: + self._resource_version = version class SpyMapper: @@ -108,43 +94,28 @@ def map_pod_event(self, pod: Any, event_type: WatchEventType) -> list[Any]: # n return [] -def make_k8s_clients_di( - events: list[dict[str, Any]] | None = None, - resource_version: str = "rv1", - pods: list[Pod] | None = None, - logs: str = "{}", -) -> K8sClients: - """Create K8sClients for DI with mocks.""" - v1 = make_mock_v1_api(logs=logs, pods=pods) - watch = make_mock_watch(events or [], resource_version) - return K8sClients( - api_client=MagicMock(), - v1=v1, - apps_v1=MagicMock(), - networking_v1=MagicMock(), - watch=watch, - ) - - def make_pod_monitor( event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, config: PodMonitorConfig | None = None, - kafka_service: KafkaEventService | None = None, - k8s_clients: K8sClients | None = None, + producer: UnifiedProducer | None = None, + pod_state_repo: PodStateRepository | None = None, + v1_client: k8s_client.CoreV1Api | None = None, event_mapper: PodEventMapper | None = None, ) -> PodMonitor: """Create PodMonitor with sensible test defaults.""" cfg = config or PodMonitorConfig() - clients = k8s_clients or make_k8s_clients_di() + prod = producer or FakeUnifiedProducer() + repo = pod_state_repo or FakePodStateRepository() + v1 = v1_client or make_mock_v1_api("{}") mapper = event_mapper or PodEventMapper(logger=_test_logger, k8s_api=make_mock_v1_api("{}")) - service = kafka_service or create_test_kafka_event_service(event_metrics)[0] return PodMonitor( config=cfg, - kafka_event_service=service, - logger=_test_logger, - k8s_clients=clients, + producer=prod, + pod_state_repo=repo, # type: ignore[arg-type] + v1_client=v1, event_mapper=mapper, + logger=_test_logger, kubernetes_metrics=kubernetes_metrics, ) @@ -153,251 +124,127 @@ def make_pod_monitor( @pytest.mark.asyncio 
-async def test_start_and_stop_lifecycle(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - cfg.enable_state_reconciliation = False - - spy = SpyMapper() - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg, event_mapper=spy) # type: ignore[arg-type] +async def test_handle_raw_event_tracks_pod(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: + """Test that handle_raw_event tracks new pods.""" + fake_repo = FakePodStateRepository() + pm = make_pod_monitor(event_metrics, kubernetes_metrics, pod_state_repo=fake_repo) # type: ignore[arg-type] - # Replace _watch_pods to avoid real watch loop - async def _quick_watch() -> None: - return None + pod = make_pod(name="test-pod", phase="Running", labels={"execution-id": "e1"}, resource_version="v1") + raw_event = {"type": "ADDED", "object": pod} - pm._watch_pods = _quick_watch # type: ignore[method-assign] + await pm.handle_raw_event(raw_event) - await pm.__aenter__() - assert pm.state == MonitorState.RUNNING - - await pm.aclose() - final_state: MonitorState = pm.state - assert final_state == MonitorState.STOPPED - assert spy.cleared is True + assert "test-pod" in fake_repo._tracked @pytest.mark.asyncio -async def test_watch_pod_events_flow_and_publish(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - cfg.enable_state_reconciliation = False +async def test_handle_raw_event_untracks_deleted_pod(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: + """Test that handle_raw_event untracks deleted pods.""" + fake_repo = FakePodStateRepository() + fake_repo._tracked.add("test-pod") + pm = make_pod_monitor(event_metrics, kubernetes_metrics, pod_state_repo=fake_repo) # type: ignore[arg-type] - pod = make_pod(name="p", phase="Succeeded", labels={"execution-id": "e1"}, term_exit=0, resource_version="rv1") - k8s_clients = make_k8s_clients_di(events=[{"type": "MODIFIED", "object": pod}], resource_version="rv2") + pod = make_pod(name="test-pod", phase="Succeeded", labels={"execution-id": "e1"}, resource_version="v2") + raw_event = {"type": "DELETED", "object": pod} - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg, k8s_clients=k8s_clients) - pm._state = MonitorState.RUNNING + await pm.handle_raw_event(raw_event) - await pm._watch_pod_events() - assert pm._last_resource_version == "rv2" + assert "test-pod" not in fake_repo._tracked @pytest.mark.asyncio -async def test_process_raw_event_invalid_and_handle_watch_error(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) +async def test_handle_raw_event_updates_resource_version(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: + """Test that handle_raw_event updates resource version.""" + fake_repo = FakePodStateRepository() + pm = make_pod_monitor(event_metrics, kubernetes_metrics, pod_state_repo=fake_repo) # type: ignore[arg-type] - await pm._process_raw_event({}) + pod = make_pod(name="test-pod", phase="Running", labels={"execution-id": "e1"}, resource_version="v123") + raw_event = {"type": "ADDED", "object": pod} - pm.config.watch_reconnect_delay = 0 - pm._reconnect_attempts = 0 - await pm._handle_watch_error() - await pm._handle_watch_error() - assert pm._reconnect_attempts >= 2 + await pm.handle_raw_event(raw_event) + + assert fake_repo._resource_version == "v123" 
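For orientation, here is a minimal sketch of the flow these handle_raw_event tests pin down. It is inferred from the assertions above, not taken from the real PodMonitor code: the attribute names (_pod_state_repo, the delegation to _process_pod_event) and the exact ordering are assumptions.

# Illustrative only - reconstructed from the test assertions above, not the real implementation.
from app.services.pod_monitor.monitor import PodEvent, WatchEventType

async def handle_raw_event_sketch(monitor, raw_event: dict) -> None:
    pod = raw_event.get("object")
    event_type = raw_event.get("type")
    if pod is None or getattr(pod, "metadata", None) is None or event_type is None:
        return  # invalid events are dropped without raising

    phase = pod.status.phase if getattr(pod, "status", None) else None
    if phase in monitor.config.ignored_pod_phases:  # attribute name assumed
        return  # pods in ignored phases are never tracked

    if pod.metadata.resource_version:
        await monitor._pod_state_repo.set_resource_version(pod.metadata.resource_version)

    # _process_pod_event maps the pod via the event mapper, publishes the mapped
    # events through the producer, and tracks/untracks the pod in the repository.
    await monitor._process_pod_event(
        PodEvent(
            event_type=WatchEventType(event_type),
            pod=pod,
            resource_version=pod.metadata.resource_version,
        )
    )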
@pytest.mark.asyncio -async def test_get_status(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - cfg.namespace = "test-ns" - cfg.label_selector = "app=test" - cfg.enable_state_reconciliation = True +async def test_handle_raw_event_invalid_event(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: + """Test that handle_raw_event handles invalid events gracefully.""" + pm = make_pod_monitor(event_metrics, kubernetes_metrics) - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) - pm._tracked_pods = {"pod1", "pod2"} - pm._reconnect_attempts = 3 - pm._last_resource_version = "v123" + # Should not raise for empty event + await pm.handle_raw_event({}) - status = await pm.get_status() - assert "idle" in status["state"].lower() - assert status["tracked_pods"] == 2 - assert status["reconnect_attempts"] == 3 - assert status["last_resource_version"] == "v123" - assert status["config"]["namespace"] == "test-ns" - assert status["config"]["label_selector"] == "app=test" - assert status["config"]["enable_reconciliation"] is True + # Should not raise for event without object + await pm.handle_raw_event({"type": "ADDED"}) @pytest.mark.asyncio -async def test_reconciliation_loop_and_state(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: +async def test_handle_raw_event_ignored_phase(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: + """Test that handle_raw_event ignores configured phases.""" cfg = PodMonitorConfig() - cfg.enable_state_reconciliation = True - cfg.reconcile_interval_seconds = 0 # sleep(0) yields control immediately - - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) - pm._state = MonitorState.RUNNING - - reconcile_called: list[bool] = [] - - async def mock_reconcile() -> ReconciliationResult: - reconcile_called.append(True) - return ReconciliationResult(missing_pods={"p1"}, extra_pods={"p2"}, duration_seconds=0.1, success=True) - - evt = asyncio.Event() - - async def wrapped_reconcile() -> ReconciliationResult: - res = await mock_reconcile() - evt.set() - return res + cfg.ignored_pod_phases = ["Unknown"] + fake_repo = FakePodStateRepository() + pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg, pod_state_repo=fake_repo) # type: ignore[arg-type] - pm._reconcile_state = wrapped_reconcile # type: ignore[method-assign] + pod = make_pod(name="ignored-pod", phase="Unknown", labels={"execution-id": "e1"}, resource_version="v1") + raw_event = {"type": "ADDED", "object": pod} - task = asyncio.create_task(pm._reconciliation_loop()) - await asyncio.wait_for(evt.wait(), timeout=1.0) - pm._state = MonitorState.STOPPED - task.cancel() - with pytest.raises(asyncio.CancelledError): - await task + await pm.handle_raw_event(raw_event) - assert len(reconcile_called) > 0 + # Pod should not be tracked due to ignored phase + assert "ignored-pod" not in fake_repo._tracked @pytest.mark.asyncio -async def test_reconcile_state_success(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: +async def test_reconcile_state_finds_missing_pods(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: + """Test that reconcile_state identifies missing pods.""" cfg = PodMonitorConfig() cfg.namespace = "test" cfg.label_selector = "app=test" pod1 = make_pod(name="pod1", phase="Running", resource_version="v1") pod2 = make_pod(name="pod2", phase="Running", resource_version="v1") - k8s_clients = 
make_k8s_clients_di(pods=[pod1, pod2]) - - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg, k8s_clients=k8s_clients) - pm._tracked_pods = {"pod2", "pod3"} - processed: list[str] = [] + mock_v1 = MagicMock() + mock_v1.list_namespaced_pod.return_value = MagicMock(items=[pod1, pod2]) - async def mock_process(event: PodEvent) -> None: - processed.append(event.pod.metadata.name) + fake_repo = FakePodStateRepository() + fake_repo._tracked.add("pod2") + fake_repo._tracked.add("pod3") # Extra pod not in K8s - pm._process_pod_event = mock_process # type: ignore[method-assign] + pm = make_pod_monitor( + event_metrics, kubernetes_metrics, config=cfg, pod_state_repo=fake_repo, v1_client=mock_v1 # type: ignore[arg-type] + ) - result = await pm._reconcile_state() + result = await pm.reconcile_state() assert result.success is True assert result.missing_pods == {"pod1"} assert result.extra_pods == {"pod3"} - assert "pod1" in processed - assert "pod3" not in pm._tracked_pods @pytest.mark.asyncio -async def test_reconcile_state_exception(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: +async def test_reconcile_state_handles_api_error(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: + """Test that reconcile_state handles API errors gracefully.""" cfg = PodMonitorConfig() - fail_v1 = MagicMock() - fail_v1.list_namespaced_pod.side_effect = RuntimeError("API error") + mock_v1 = MagicMock() + mock_v1.list_namespaced_pod.side_effect = RuntimeError("API error") - k8s_clients = K8sClients( - api_client=MagicMock(), - v1=fail_v1, - apps_v1=MagicMock(), - networking_v1=MagicMock(), - watch=make_mock_watch([]), - ) + pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg, v1_client=mock_v1) - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg, k8s_clients=k8s_clients) + result = await pm.reconcile_state() - result = await pm._reconcile_state() assert result.success is False assert result.error is not None assert "API error" in result.error @pytest.mark.asyncio -async def test_process_pod_event_full_flow(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - cfg.ignored_pod_phases = ["Unknown"] - - class MockMapper: - def map_pod_event(self, pod: Any, event_type: WatchEventType) -> list[Any]: # noqa: ARG002 - class Event: - event_type = types.SimpleNamespace(value="test_event") - metadata = types.SimpleNamespace(correlation_id=None) - aggregate_id = "agg1" - - return [Event()] - - def clear_cache(self) -> None: - pass - - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg, event_mapper=MockMapper()) # type: ignore[arg-type] - - published: list[Any] = [] - - async def mock_publish(event: Any, pod: Any) -> None: # noqa: ARG001 - published.append(event) - - pm._publish_event = mock_publish # type: ignore[method-assign] - - event = PodEvent( - event_type=WatchEventType.ADDED, - pod=make_pod(name="test-pod", phase="Running"), - resource_version="v1", - ) - - await pm._process_pod_event(event) - assert "test-pod" in pm._tracked_pods - assert pm._last_resource_version == "v1" - assert len(published) == 1 - - event_del = PodEvent( - event_type=WatchEventType.DELETED, - pod=make_pod(name="test-pod", phase="Succeeded"), - resource_version="v2", - ) - - await pm._process_pod_event(event_del) - assert "test-pod" not in pm._tracked_pods - assert pm._last_resource_version == "v2" - - event_ignored = PodEvent( - event_type=WatchEventType.ADDED, - 
pod=make_pod(name="ignored-pod", phase="Unknown"), - resource_version="v3", - ) - - published.clear() - await pm._process_pod_event(event_ignored) - assert len(published) == 0 - - -@pytest.mark.asyncio -async def test_process_pod_event_exception_handling(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - - class FailMapper: - def map_pod_event(self, pod: Any, event_type: WatchEventType) -> list[Any]: - raise RuntimeError("Mapping failed") - - def clear_cache(self) -> None: - pass - - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg, event_mapper=FailMapper()) # type: ignore[arg-type] - - event = PodEvent( - event_type=WatchEventType.ADDED, - pod=make_pod(name="fail-pod", phase="Pending"), - resource_version=None, - ) - - # Should not raise - errors are caught and logged - await pm._process_pod_event(event) - - -@pytest.mark.asyncio -async def test_publish_event_full_flow(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - service, fake_producer = create_test_kafka_event_service(event_metrics) - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg, kafka_service=service) +async def test_publish_event(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: + """Test that events are published correctly.""" + fake_producer = FakeUnifiedProducer() + pm = make_pod_monitor(event_metrics, kubernetes_metrics, producer=fake_producer) event = ExecutionCompletedEvent( execution_id="exec1", @@ -415,387 +262,58 @@ async def test_publish_event_full_flow(event_metrics: EventMetrics, kubernetes_m @pytest.mark.asyncio -async def test_publish_event_exception_handling(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - - class FailingProducer(FakeUnifiedProducer): - async def produce( - self, event_to_produce: DomainEvent, key: str | None = None, headers: dict[str, str] | None = None - ) -> None: - raise RuntimeError("Publish failed") - - # Create service with failing producer - failing_producer = FailingProducer() - fake_repo = FakeEventRepository() - failing_service = KafkaEventService( - event_repository=fake_repo, - kafka_producer=failing_producer, - settings=Settings(), - logger=_test_logger, - event_metrics=event_metrics, - ) - - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg, kafka_service=failing_service) - - event = ExecutionStartedEvent( - execution_id="exec1", - pod_name="test-pod", - metadata=EventMetadata(service_name="test", service_version="1.0"), - ) - - # Use pod with no metadata to exercise edge case - pod = make_pod(name="no-meta-pod", phase="Pending") - pod.metadata = None # type: ignore[assignment] - - # Should not raise - errors are caught and logged - await pm._publish_event(event, pod) - - -@pytest.mark.asyncio -async def test_handle_watch_error_max_attempts(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - cfg.max_reconnect_attempts = 2 - - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) - pm._state = MonitorState.RUNNING - pm._reconnect_attempts = 2 - - await pm._handle_watch_error() - - assert pm._state == MonitorState.STOPPING - - -@pytest.mark.asyncio -async def test_watch_pods_main_loop(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) - pm._state 
= MonitorState.RUNNING - - watch_count: list[int] = [] - - async def mock_watch() -> None: - watch_count.append(1) - if len(watch_count) > 2: - pm._state = MonitorState.STOPPED - - async def mock_handle_error() -> None: - pass - - pm._watch_pod_events = mock_watch # type: ignore[method-assign] - pm._handle_watch_error = mock_handle_error # type: ignore[method-assign] - - await pm._watch_pods() - assert len(watch_count) > 2 - - -@pytest.mark.asyncio -async def test_watch_pods_api_exception(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) - pm._state = MonitorState.RUNNING - - async def mock_watch() -> None: - raise ApiException(status=410) - - error_handled: list[bool] = [] - - async def mock_handle() -> None: - error_handled.append(True) - pm._state = MonitorState.STOPPED - - pm._watch_pod_events = mock_watch # type: ignore[method-assign] - pm._handle_watch_error = mock_handle # type: ignore[method-assign] - - await pm._watch_pods() - - assert pm._last_resource_version is None - assert len(error_handled) > 0 - - -@pytest.mark.asyncio -async def test_watch_pods_generic_exception(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) - pm._state = MonitorState.RUNNING - - async def mock_watch() -> None: - raise RuntimeError("Unexpected error") - - error_handled: list[bool] = [] - - async def mock_handle() -> None: - error_handled.append(True) - pm._state = MonitorState.STOPPED - - pm._watch_pod_events = mock_watch # type: ignore[method-assign] - pm._handle_watch_error = mock_handle # type: ignore[method-assign] - - await pm._watch_pods() - assert len(error_handled) > 0 - - -@pytest.mark.asyncio -async def test_create_pod_monitor_context_manager(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, monkeypatch: pytest.MonkeyPatch) -> None: - """Test create_pod_monitor factory with auto-created dependencies.""" - # Mock create_k8s_clients to avoid real K8s connection - mock_v1 = make_mock_v1_api() - mock_watch = make_mock_watch([]) - mock_clients = K8sClients( - api_client=MagicMock(), - v1=mock_v1, - apps_v1=MagicMock(), - networking_v1=MagicMock(), - watch=mock_watch, - ) - - def mock_create_clients( - logger: logging.Logger, # noqa: ARG001 - kubeconfig_path: str | None = None, # noqa: ARG001 - in_cluster: bool | None = None, # noqa: ARG001 - ) -> K8sClients: - return mock_clients - - monkeypatch.setattr(k8s_clients_module, "create_k8s_clients", mock_create_clients) - monkeypatch.setattr("app.services.pod_monitor.monitor.create_k8s_clients", mock_create_clients) - - cfg = PodMonitorConfig() - cfg.enable_state_reconciliation = False - - service, _ = create_test_kafka_event_service(event_metrics) - - # Use the actual create_pod_monitor which will use our mocked create_k8s_clients - async with create_pod_monitor(cfg, service, _test_logger, kubernetes_metrics=kubernetes_metrics) as monitor: - assert monitor.state == MonitorState.RUNNING - - final_state: MonitorState = monitor.state - assert final_state == MonitorState.STOPPED - - -@pytest.mark.asyncio -async def test_create_pod_monitor_with_injected_k8s_clients(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - """Test create_pod_monitor with injected K8sClients (DI path).""" - cfg = PodMonitorConfig() - cfg.enable_state_reconciliation = False - - service, _ = 
create_test_kafka_event_service(event_metrics) - - mock_v1 = make_mock_v1_api() - mock_watch = make_mock_watch([]) - mock_k8s_clients = K8sClients( - api_client=MagicMock(), - v1=mock_v1, - apps_v1=MagicMock(), - networking_v1=MagicMock(), - watch=mock_watch, - ) - - async with create_pod_monitor( - cfg, service, _test_logger, k8s_clients=mock_k8s_clients, kubernetes_metrics=kubernetes_metrics - ) as monitor: - assert monitor.state == MonitorState.RUNNING - assert monitor._clients is mock_k8s_clients - assert monitor._v1 is mock_v1 - - final_state: MonitorState = monitor.state - assert final_state == MonitorState.STOPPED - - -@pytest.mark.asyncio -async def test_start_already_running(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - """Test idempotent start via __aenter__.""" - cfg = PodMonitorConfig() - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) - - # Simulate already started state - pm._lifecycle_started = True - pm._state = MonitorState.RUNNING - - # Should be idempotent - just return self - await pm.__aenter__() - - -@pytest.mark.asyncio -async def test_stop_already_stopped(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - """Test idempotent stop via aclose().""" - cfg = PodMonitorConfig() - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) - pm._state = MonitorState.STOPPED - # Not started, so aclose should be a no-op - - await pm.aclose() - - -@pytest.mark.asyncio -async def test_stop_with_tasks(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - """Test cleanup of tasks on aclose().""" - cfg = PodMonitorConfig() - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) - pm._state = MonitorState.RUNNING - pm._lifecycle_started = True - - async def dummy_task() -> None: - await asyncio.Event().wait() - - pm._watch_task = asyncio.create_task(dummy_task()) - pm._reconcile_task = asyncio.create_task(dummy_task()) - pm._tracked_pods = {"pod1"} - - await pm.aclose() - - assert pm._state == MonitorState.STOPPED - assert len(pm._tracked_pods) == 0 - - -def test_update_resource_version(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) - - class Stream: - _stop_event = types.SimpleNamespace(resource_version="v123") - - pm._update_resource_version(Stream()) - assert pm._last_resource_version == "v123" - - class BadStream: - pass - - pm._update_resource_version(BadStream()) - - -@pytest.mark.asyncio -async def test_process_raw_event_with_metadata(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) - - processed: list[PodEvent] = [] - - async def mock_process(event: PodEvent) -> None: - processed.append(event) - - pm._process_pod_event = mock_process # type: ignore[method-assign] - - raw_event = { - "type": "ADDED", - "object": types.SimpleNamespace(metadata=types.SimpleNamespace(resource_version="v1")), - } - - await pm._process_raw_event(raw_event) - assert len(processed) == 1 - assert processed[0].resource_version == "v1" - - raw_event_no_meta = {"type": "MODIFIED", "object": types.SimpleNamespace(metadata=None)} - - await pm._process_raw_event(raw_event_no_meta) - assert len(processed) == 2 - assert processed[1].resource_version is None - - -@pytest.mark.asyncio -async def 
test_watch_pods_api_exception_other_status(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) - pm._state = MonitorState.RUNNING - - async def mock_watch() -> None: - raise ApiException(status=500) - - error_handled: list[bool] = [] - - async def mock_handle() -> None: - error_handled.append(True) - pm._state = MonitorState.STOPPED - - pm._watch_pod_events = mock_watch # type: ignore[method-assign] - pm._handle_watch_error = mock_handle # type: ignore[method-assign] - - await pm._watch_pods() - assert len(error_handled) > 0 - - -@pytest.mark.asyncio -async def test_watch_pod_events_with_field_selector(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - cfg.field_selector = "status.phase=Running" - cfg.enable_state_reconciliation = False - - watch_kwargs: list[dict[str, Any]] = [] - - tracking_v1 = MagicMock() - - def track_list(namespace: str, label_selector: str) -> None: - watch_kwargs.append({"namespace": namespace, "label_selector": label_selector}) - return None - - tracking_v1.list_namespaced_pod.side_effect = track_list - - tracking_watch = MagicMock() +async def test_process_pod_event_publishes_mapped_events(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: + """Test that _process_pod_event publishes events from mapper.""" + fake_producer = FakeUnifiedProducer() + fake_repo = FakePodStateRepository() - def track_stream(func: Any, **kwargs: Any) -> MockWatchStream: # noqa: ARG001 - watch_kwargs.append(kwargs) - return MockWatchStream([], "rv1") + class MockMapper: + def map_pod_event(self, pod: Any, event_type: WatchEventType) -> list[Any]: # noqa: ARG002 + return [ + ExecutionCompletedEvent( + execution_id="e1", + aggregate_id="e1", + exit_code=0, + resource_usage=ResourceUsageDomain(), + metadata=EventMetadata(service_name="test", service_version="1.0"), + ) + ] - tracking_watch.stream.side_effect = track_stream - tracking_watch.stop.return_value = None + def clear_cache(self) -> None: + pass - k8s_clients = K8sClients( - api_client=MagicMock(), - v1=tracking_v1, - apps_v1=MagicMock(), - networking_v1=MagicMock(), - watch=tracking_watch, + pm = make_pod_monitor( + event_metrics, + kubernetes_metrics, + producer=fake_producer, + pod_state_repo=fake_repo, # type: ignore[arg-type] + event_mapper=MockMapper(), # type: ignore[arg-type] ) - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg, k8s_clients=k8s_clients) - pm._state = MonitorState.RUNNING - - await pm._watch_pod_events() - - assert any("field_selector" in kw for kw in watch_kwargs) - - -@pytest.mark.asyncio -async def test_reconciliation_loop_exception(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - cfg.enable_state_reconciliation = True - cfg.reconcile_interval_seconds = 0 # sleep(0) yields control immediately - - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) - pm._state = MonitorState.RUNNING - - hit = asyncio.Event() + pod = make_pod(name="test-pod", phase="Running", labels={"execution-id": "e1"}) + event = PodEvent(event_type=WatchEventType.ADDED, pod=pod, resource_version="v1") - async def raising() -> ReconciliationResult: - hit.set() - raise RuntimeError("Reconcile error") - - pm._reconcile_state = raising # type: ignore[method-assign] + await pm._process_pod_event(event) - task = 
asyncio.create_task(pm._reconciliation_loop()) - await asyncio.wait_for(hit.wait(), timeout=1.0) - pm._state = MonitorState.STOPPED - task.cancel() - with pytest.raises(asyncio.CancelledError): - await task + assert len(fake_producer.produced_events) == 1 + assert "test-pod" in fake_repo._tracked @pytest.mark.asyncio -async def test_start_with_reconciliation(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - cfg.enable_state_reconciliation = True - - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) +async def test_process_pod_event_handles_mapper_error(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: + """Test that _process_pod_event handles mapper errors gracefully.""" - async def mock_watch() -> None: - return None + class FailMapper: + def map_pod_event(self, pod: Any, event_type: WatchEventType) -> list[Any]: + raise RuntimeError("Mapping failed") - async def mock_reconcile() -> None: - return None + def clear_cache(self) -> None: + pass - pm._watch_pods = mock_watch # type: ignore[method-assign] - pm._reconciliation_loop = mock_reconcile # type: ignore[method-assign] + pm = make_pod_monitor(event_metrics, kubernetes_metrics, event_mapper=FailMapper()) # type: ignore[arg-type] - await pm.__aenter__() - assert pm._watch_task is not None - assert pm._reconcile_task is not None + pod = make_pod(name="fail-pod", phase="Pending") + event = PodEvent(event_type=WatchEventType.ADDED, pod=pod, resource_version=None) - await pm.aclose() + # Should not raise - errors are caught and logged + await pm._process_pod_event(event) diff --git a/backend/tests/unit/services/result_processor/test_processor.py b/backend/tests/unit/services/result_processor/test_processor.py index c13fe0ab..90f12556 100644 --- a/backend/tests/unit/services/result_processor/test_processor.py +++ b/backend/tests/unit/services/result_processor/test_processor.py @@ -1,16 +1,9 @@ -import logging -from unittest.mock import MagicMock - import pytest -from app.core.metrics import EventMetrics, ExecutionMetrics -from app.domain.enums.events import EventType from app.domain.enums.kafka import CONSUMER_GROUP_SUBSCRIPTIONS, GroupId, KafkaTopic -from app.services.result_processor.processor import ResultProcessor, ResultProcessorConfig +from app.services.result_processor.processor import ResultProcessorConfig pytestmark = pytest.mark.unit -_test_logger = logging.getLogger("test.services.result_processor.processor") - class TestResultProcessorConfig: def test_default_values(self) -> None: @@ -27,24 +20,3 @@ def test_custom_values(self) -> None: config = ResultProcessorConfig(batch_size=20, processing_timeout=600) assert config.batch_size == 20 assert config.processing_timeout == 600 - - -def test_create_dispatcher_registers_handlers( - execution_metrics: ExecutionMetrics, event_metrics: EventMetrics -) -> None: - rp = ResultProcessor( - execution_repo=MagicMock(), - producer=MagicMock(), - schema_registry=MagicMock(), - settings=MagicMock(), - idempotency_manager=MagicMock(), - logger=_test_logger, - execution_metrics=execution_metrics, - event_metrics=event_metrics, - ) - dispatcher = rp._create_dispatcher() - assert dispatcher is not None - assert EventType.EXECUTION_COMPLETED in dispatcher._handlers - assert EventType.EXECUTION_FAILED in dispatcher._handlers - assert EventType.EXECUTION_TIMEOUT in dispatcher._handlers - diff --git a/backend/tests/unit/services/saga/test_saga_orchestrator_unit.py 
b/backend/tests/unit/services/saga/test_saga_orchestrator_unit.py index 8f2b35f9..c0bc6628 100644 --- a/backend/tests/unit/services/saga/test_saga_orchestrator_unit.py +++ b/backend/tests/unit/services/saga/test_saga_orchestrator_unit.py @@ -10,13 +10,9 @@ from app.domain.events.typed import DomainEvent, ExecutionRequestedEvent from app.domain.saga.models import Saga, SagaConfig from app.events.core import UnifiedProducer -from app.events.event_store import EventStore -from app.events.schema.schema_registry import SchemaRegistryManager -from app.services.idempotency.idempotency_manager import IdempotencyManager from app.services.saga.base_saga import BaseSaga from app.services.saga.saga_orchestrator import SagaOrchestrator from app.services.saga.saga_step import CompensationStep, SagaContext, SagaStep -from app.settings import Settings from tests.conftest import make_execution_requested_event @@ -52,23 +48,6 @@ async def produce( return None -class _FakeIdem(IdempotencyManager): - """Fake IdempotencyManager for testing.""" - - def __init__(self) -> None: - pass # Skip parent __init__ - - async def close(self) -> None: - return None - - -class _FakeStore(EventStore): - """Fake EventStore for testing.""" - - def __init__(self) -> None: - pass # Skip parent __init__ - - class _FakeAlloc(ResourceAllocationRepository): """Fake ResourceAllocationRepository for testing.""" @@ -105,10 +84,6 @@ def _orch(event_metrics: EventMetrics) -> SagaOrchestrator: config=SagaConfig(name="t", enable_compensation=True, store_events=True, publish_commands=False), saga_repository=_FakeRepo(), producer=_FakeProd(), - schema_registry_manager=MagicMock(spec=SchemaRegistryManager), - settings=MagicMock(spec=Settings), - event_store=_FakeStore(), - idempotency_manager=_FakeIdem(), resource_allocation_repository=_FakeAlloc(), logger=_test_logger, event_metrics=event_metrics, @@ -119,11 +94,9 @@ def _orch(event_metrics: EventMetrics) -> SagaOrchestrator: async def test_min_success_flow(event_metrics: EventMetrics) -> None: orch = _orch(event_metrics) orch.register_saga(_Saga) - # Set orchestrator running state via lifecycle property - orch._lifecycle_started = True - await orch._handle_event(make_execution_requested_event(execution_id="e")) - # basic sanity; deep behavior covered by integration - assert orch.is_running is True + # Stateless orchestrator - just call handle_event directly + await orch.handle_event(make_execution_requested_event(execution_id="e")) + # Basic sanity - no exception means success; deep behavior covered by integration @pytest.mark.asyncio @@ -133,10 +106,6 @@ async def test_should_trigger_and_existing_short_circuit(event_metrics: EventMet config=SagaConfig(name="t", enable_compensation=True, store_events=True, publish_commands=False), saga_repository=fake_repo, producer=_FakeProd(), - schema_registry_manager=MagicMock(spec=SchemaRegistryManager), - settings=MagicMock(spec=Settings), - event_store=_FakeStore(), - idempotency_manager=_FakeIdem(), resource_allocation_repository=_FakeAlloc(), logger=_test_logger, event_metrics=event_metrics, diff --git a/backend/tests/unit/services/sse/test_kafka_redis_bridge.py b/backend/tests/unit/services/sse/test_kafka_redis_bridge.py index 6fa5d1ef..d2fd5ebd 100644 --- a/backend/tests/unit/services/sse/test_kafka_redis_bridge.py +++ b/backend/tests/unit/services/sse/test_kafka_redis_bridge.py @@ -1,15 +1,10 @@ import logging -from unittest.mock import MagicMock import pytest -from app.core.metrics import EventMetrics from app.domain.enums.events import 
EventType from app.domain.events.typed import DomainEvent, EventMetadata, ExecutionStartedEvent -from app.events.core import EventDispatcher -from app.events.schema.schema_registry import SchemaRegistryManager from app.services.sse.kafka_redis_bridge import SSEKafkaRedisBridge from app.services.sse.redis_bus import SSERedisBus -from app.settings import Settings pytestmark = pytest.mark.unit @@ -31,34 +26,42 @@ def _make_metadata() -> EventMetadata: @pytest.mark.asyncio -async def test_register_and_route_events_without_kafka() -> None: - # Build the bridge but don't call start(); directly test routing handlers +async def test_handle_event_routes_to_redis_bus() -> None: + """Test that handle_event routes events to Redis bus.""" fake_bus = _FakeBus() - mock_settings = MagicMock(spec=Settings) - mock_settings.KAFKA_BOOTSTRAP_SERVERS = "kafka:9092" - mock_settings.SSE_CONSUMER_POOL_SIZE = 1 bridge = SSEKafkaRedisBridge( - schema_registry=MagicMock(spec=SchemaRegistryManager), - settings=mock_settings, - event_metrics=MagicMock(spec=EventMetrics), sse_bus=fake_bus, logger=_test_logger, ) - disp = EventDispatcher(_test_logger) - bridge._register_routing_handlers(disp) - handlers = disp.get_handlers(EventType.EXECUTION_STARTED) - assert len(handlers) > 0 - # Event with empty execution_id is ignored - h = handlers[0] - await h(ExecutionStartedEvent(execution_id="", pod_name="p", metadata=_make_metadata())) + await bridge.handle_event( + ExecutionStartedEvent(execution_id="", pod_name="p", metadata=_make_metadata()) + ) assert fake_bus.published == [] # Proper event is published - await h(ExecutionStartedEvent(execution_id="exec-123", pod_name="p", metadata=_make_metadata())) + await bridge.handle_event( + ExecutionStartedEvent(execution_id="exec-123", pod_name="p", metadata=_make_metadata()) + ) assert fake_bus.published and fake_bus.published[-1][0] == "exec-123" - s = bridge.get_stats() - assert s["num_consumers"] == 0 and s["is_running"] is False + +@pytest.mark.asyncio +async def test_get_status_returns_relevant_event_types() -> None: + """Test that get_status returns relevant event types.""" + fake_bus = _FakeBus() + bridge = SSEKafkaRedisBridge(sse_bus=fake_bus, logger=_test_logger) + + status = await bridge.get_status() + assert "relevant_event_types" in status + assert len(status["relevant_event_types"]) > 0 + + +def test_get_relevant_event_types() -> None: + """Test static method returns relevant event types.""" + event_types = SSEKafkaRedisBridge.get_relevant_event_types() + assert EventType.EXECUTION_STARTED in event_types + assert EventType.EXECUTION_COMPLETED in event_types + assert EventType.RESULT_STORED in event_types diff --git a/backend/tests/unit/services/sse/test_shutdown_manager.py b/backend/tests/unit/services/sse/test_shutdown_manager.py index 05f6e023..1dfc07cc 100644 --- a/backend/tests/unit/services/sse/test_shutdown_manager.py +++ b/backend/tests/unit/services/sse/test_shutdown_manager.py @@ -2,28 +2,22 @@ import logging import pytest -from app.core.lifecycle import LifecycleEnabled + from app.core.metrics import ConnectionMetrics from app.services.sse.sse_shutdown_manager import SSEShutdownManager _test_logger = logging.getLogger("test.services.sse.shutdown_manager") -class _FakeRouter(LifecycleEnabled): - """Fake router that tracks whether aclose was called.""" - - def __init__(self) -> None: - super().__init__() - self.stopped = False - self._lifecycle_started = True # Simulate already-started router - - async def _on_stop(self) -> None: - self.stopped = True - - 
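For context on what these shutdown-manager tests exercise, a hypothetical sketch of how a connection handler would use the manager follows. register_connection, initiate_shutdown and get_shutdown_status appear in the tests; the None-return on rejection and the unregister_connection call are assumptions inferred from the test names, not shown in this patch.

# Hypothetical usage sketch, assuming the API inferred from the tests above.
import asyncio
from app.services.sse.sse_shutdown_manager import SSEShutdownManager

async def stream_until_shutdown(mgr: SSEShutdownManager, execution_id: str, conn_id: str) -> None:
    shutdown_event = await mgr.register_connection(execution_id, conn_id)
    if shutdown_event is None:
        return  # presumably rejected because shutdown is already in progress
    try:
        while not shutdown_event.is_set():
            await asyncio.sleep(0.1)  # stand-in for pushing SSE frames to the client
    finally:
        await mgr.unregister_connection(execution_id, conn_id)  # assumed cleanup call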
@pytest.mark.asyncio async def test_shutdown_graceful_notify_and_drain(connection_metrics: ConnectionMetrics) -> None: - mgr = SSEShutdownManager(drain_timeout=1.0, notification_timeout=0.01, force_close_timeout=0.1, logger=_test_logger, connection_metrics=connection_metrics) + mgr = SSEShutdownManager( + drain_timeout=1.0, + notification_timeout=0.01, + force_close_timeout=0.1, + logger=_test_logger, + connection_metrics=connection_metrics, + ) # Register two connections and arrange that they unregister when notified ev1 = await mgr.register_connection("e1", "c1") @@ -46,12 +40,14 @@ async def on_shutdown(event: asyncio.Event, cid: str) -> None: @pytest.mark.asyncio -async def test_shutdown_force_close_calls_router_stop_and_rejects_new(connection_metrics: ConnectionMetrics) -> None: +async def test_shutdown_force_close_and_rejects_new(connection_metrics: ConnectionMetrics) -> None: mgr = SSEShutdownManager( - drain_timeout=0.01, notification_timeout=0.01, force_close_timeout=0.01, logger=_test_logger, connection_metrics=connection_metrics + drain_timeout=0.01, + notification_timeout=0.01, + force_close_timeout=0.01, + logger=_test_logger, + connection_metrics=connection_metrics, ) - router = _FakeRouter() - mgr.set_router(router) # Register a connection but never unregister -> force close path ev = await mgr.register_connection("e1", "c1") @@ -59,7 +55,6 @@ async def test_shutdown_force_close_calls_router_stop_and_rejects_new(connection # Initiate shutdown await mgr.initiate_shutdown() - assert router.stopped is True assert mgr.is_shutting_down() is True status = mgr.get_shutdown_status() assert status.draining_connections == 0 @@ -71,7 +66,13 @@ async def test_shutdown_force_close_calls_router_stop_and_rejects_new(connection @pytest.mark.asyncio async def test_get_shutdown_status_transitions(connection_metrics: ConnectionMetrics) -> None: - m = SSEShutdownManager(drain_timeout=0.01, notification_timeout=0.0, force_close_timeout=0.0, logger=_test_logger, connection_metrics=connection_metrics) + m = SSEShutdownManager( + drain_timeout=0.01, + notification_timeout=0.0, + force_close_timeout=0.0, + logger=_test_logger, + connection_metrics=connection_metrics, + ) st0 = m.get_shutdown_status() assert st0.phase == "ready" await m.initiate_shutdown() diff --git a/backend/tests/unit/services/sse/test_sse_service.py b/backend/tests/unit/services/sse/test_sse_service.py index 3c86a15a..c33298ce 100644 --- a/backend/tests/unit/services/sse/test_sse_service.py +++ b/backend/tests/unit/services/sse/test_sse_service.py @@ -10,9 +10,8 @@ from app.db.repositories.sse_repository import SSERepository from app.domain.enums.events import EventType from app.domain.enums.execution import ExecutionStatus -from app.domain.events import ResourceUsageDomain -from app.domain.execution import DomainExecution -from app.domain.sse import ShutdownStatus, SSEExecutionStatusDomain, SSEHealthDomain +from app.domain.execution import DomainExecution, ResourceUsageDomain +from app.domain.sse import ShutdownStatus, SSEExecutionStatusDomain from app.services.sse.kafka_redis_bridge import SSEKafkaRedisBridge from app.services.sse.redis_bus import SSERedisBus, SSERedisSubscription from app.services.sse.sse_service import SSEService @@ -240,12 +239,3 @@ async def test_notification_stream_connected_and_heartbeat_and_message(connectio # Give the generator a chance to observe the flag and finish with pytest.raises(StopAsyncIteration): await asyncio.wait_for(agen.__anext__(), timeout=0.2) - - -@pytest.mark.asyncio -async def 
test_health_status_shape(connection_metrics: ConnectionMetrics) -> None: - svc = SSEService(repository=_FakeRepo(), router=_FakeRouter(), sse_bus=_FakeBus(), shutdown_manager=_FakeShutdown(), - settings=_make_fake_settings(), logger=_test_logger, connection_metrics=connection_metrics) - h = await svc.get_health_status() - assert isinstance(h, SSEHealthDomain) - assert h.active_consumers == 3 and h.active_executions == 2 diff --git a/backend/tests/unit/services/sse/test_sse_shutdown_manager.py b/backend/tests/unit/services/sse/test_sse_shutdown_manager.py index fc7ffb3b..e15c427a 100644 --- a/backend/tests/unit/services/sse/test_sse_shutdown_manager.py +++ b/backend/tests/unit/services/sse/test_sse_shutdown_manager.py @@ -2,7 +2,7 @@ import logging import pytest -from app.core.lifecycle import LifecycleEnabled + from app.core.metrics import ConnectionMetrics from app.services.sse.sse_shutdown_manager import SSEShutdownManager @@ -11,22 +11,15 @@ _test_logger = logging.getLogger("test.services.sse.sse_shutdown_manager") -class _FakeRouter(LifecycleEnabled): - """Fake router that tracks whether aclose was called.""" - - def __init__(self) -> None: - super().__init__() - self.stopped = False - self._lifecycle_started = True # Simulate already-started router - - async def _on_stop(self) -> None: - self.stopped = True - - @pytest.mark.asyncio async def test_register_unregister_and_shutdown_flow(connection_metrics: ConnectionMetrics) -> None: - mgr = SSEShutdownManager(drain_timeout=0.5, notification_timeout=0.1, force_close_timeout=0.1, logger=_test_logger, connection_metrics=connection_metrics) - mgr.set_router(_FakeRouter()) + mgr = SSEShutdownManager( + drain_timeout=0.5, + notification_timeout=0.1, + force_close_timeout=0.1, + logger=_test_logger, + connection_metrics=connection_metrics, + ) # Register two connections e1 = await mgr.register_connection("exec-1", "c1") @@ -52,8 +45,13 @@ async def test_register_unregister_and_shutdown_flow(connection_metrics: Connect @pytest.mark.asyncio async def test_reject_new_connection_during_shutdown(connection_metrics: ConnectionMetrics) -> None: - mgr = SSEShutdownManager(drain_timeout=0.5, notification_timeout=0.01, force_close_timeout=0.01, - logger=_test_logger, connection_metrics=connection_metrics) + mgr = SSEShutdownManager( + drain_timeout=0.5, + notification_timeout=0.01, + force_close_timeout=0.01, + logger=_test_logger, + connection_metrics=connection_metrics, + ) # Pre-register one active connection - shutdown will block waiting for it e = await mgr.register_connection("e", "c0") assert e is not None diff --git a/backend/workers/dlq_processor.py b/backend/workers/dlq_processor.py index 97598539..47b3ec53 100644 --- a/backend/workers/dlq_processor.py +++ b/backend/workers/dlq_processor.py @@ -46,14 +46,14 @@ def _configure_retry_policies(manager: DLQManager, logger: logging.Logger) -> No topic="websocket-events", strategy=RetryStrategy.FIXED_INTERVAL, max_retries=10, base_delay_seconds=10 ), ) - manager.default_retry_policy = RetryPolicy( + manager.set_default_retry_policy(RetryPolicy( topic="default", strategy=RetryStrategy.EXPONENTIAL_BACKOFF, max_retries=4, base_delay_seconds=60, max_delay_seconds=1800, retry_multiplier=2.5, - ) + )) def _configure_filters(manager: DLQManager, testing: bool, logger: logging.Logger) -> None: diff --git a/backend/workers/run_coordinator.py b/backend/workers/run_coordinator.py index 12004bf1..969b240d 100644 --- a/backend/workers/run_coordinator.py +++ b/backend/workers/run_coordinator.py @@ -1,15 +1,21 @@ 
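Regarding the default retry policy configured for the DLQ processor above: assuming the usual exponential-backoff formula delay = min(base * multiplier**attempt, max_delay) (the exact formula RetryPolicy applies is not shown in this patch), the configured parameters imply roughly the following schedule.

# Assumed formula; base=60s, multiplier=2.5, max_delay=1800s, max_retries=4.
base, multiplier, max_delay = 60.0, 2.5, 1800.0
for attempt in range(4):
    delay = min(base * multiplier**attempt, max_delay)
    print(f"retry {attempt + 1}: {delay:.0f}s")
# retry 1: 60s, retry 2: 150s, retry 3: 375s, retry 4: 938s (the 1800s cap is never hit)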
+"""Coordinator worker entrypoint - stateless event processing. + +Consumes execution events from Kafka and dispatches to ExecutionCoordinator handlers. +DI container manages all lifecycle - worker just iterates over consumer. +""" + import asyncio import logging -import signal +from aiokafka import AIOKafkaConsumer from app.core.container import create_coordinator_container from app.core.database_context import Database from app.core.logging import setup_logger from app.core.tracing import init_tracing from app.db.docs import ALL_DOCUMENTS from app.domain.enums.kafka import GroupId +from app.events.core import UnifiedConsumer from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas -from app.services.coordinator.coordinator import ExecutionCoordinator from app.settings import Settings from beanie import init_beanie @@ -18,6 +24,7 @@ async def run_coordinator(settings: Settings) -> None: """Run the execution coordinator service.""" container = create_coordinator_container(settings) + logger = await container.get(logging.Logger) logger.info("Starting ExecutionCoordinator with DI container...") @@ -27,27 +34,18 @@ async def run_coordinator(settings: Settings) -> None: schema_registry = await container.get(SchemaRegistryManager) await initialize_event_schemas(schema_registry) - # Services are already started by the DI container providers - coordinator = await container.get(ExecutionCoordinator) - - # Shutdown event - signal handlers just set this - shutdown_event = asyncio.Event() - loop = asyncio.get_running_loop() - for sig in (signal.SIGINT, signal.SIGTERM): - loop.add_signal_handler(sig, shutdown_event.set) - - logger.info("ExecutionCoordinator started and running") - - try: - # Wait for shutdown signal or service to stop - while coordinator.is_running and not shutdown_event.is_set(): - await asyncio.sleep(60) - status = await coordinator.get_status() - logger.info(f"Coordinator status: {status}") - finally: - # Container cleanup stops everything - logger.info("Initiating graceful shutdown...") - await container.close() + kafka_consumer = await container.get(AIOKafkaConsumer) + handler = await container.get(UnifiedConsumer) + + logger.info("ExecutionCoordinator started, consuming events...") + + async for msg in kafka_consumer: + await handler.handle(msg) + await kafka_consumer.commit() + + logger.info("ExecutionCoordinator shutdown complete") + + await container.close() def main() -> None: diff --git a/backend/workers/run_event_replay.py b/backend/workers/run_event_replay.py index 95c38dad..757cb369 100644 --- a/backend/workers/run_event_replay.py +++ b/backend/workers/run_event_replay.py @@ -1,60 +1,38 @@ +"""Event replay worker entrypoint - stateless replay service. + +Provides event replay capability. DI container manages all lifecycle. +This service doesn't consume from Kafka - it's an HTTP-driven replay service. 
+""" + import asyncio import logging -from contextlib import AsyncExitStack from app.core.container import create_event_replay_container from app.core.database_context import Database from app.core.logging import setup_logger from app.core.tracing import init_tracing from app.db.docs import ALL_DOCUMENTS -from app.events.core import UnifiedProducer -from app.services.event_replay.replay_service import EventReplayService from app.settings import Settings from beanie import init_beanie -async def cleanup_task(replay_service: EventReplayService, logger: logging.Logger, interval_hours: int = 6) -> None: - """Periodically clean up old replay sessions""" - while True: - try: - await asyncio.sleep(interval_hours * 3600) - removed = await replay_service.cleanup_old_sessions(older_than_hours=48) - logger.info(f"Cleaned up {removed} old replay sessions") - except Exception as e: - logger.error(f"Error during cleanup: {e}") - - async def run_replay_service(settings: Settings) -> None: - """Run the event replay service with cleanup task.""" + """Run the event replay service.""" container = create_event_replay_container(settings) + logger = await container.get(logging.Logger) logger.info("Starting EventReplayService with DI container...") db = await container.get(Database) await init_beanie(database=db, document_models=ALL_DOCUMENTS) - producer = await container.get(UnifiedProducer) - replay_service = await container.get(EventReplayService) - - logger.info("Event replay service initialized") - - async with AsyncExitStack() as stack: - stack.push_async_callback(container.close) - await stack.enter_async_context(producer) - - task = asyncio.create_task(cleanup_task(replay_service, logger)) - - async def _cancel_task() -> None: - task.cancel() - try: - await task - except asyncio.CancelledError: - pass + logger.info("Event replay service initialized and ready") - stack.push_async_callback(_cancel_task) + # Service is HTTP-driven, wait for external shutdown + await asyncio.Event().wait() - await asyncio.Event().wait() + await container.close() def main() -> None: diff --git a/backend/workers/run_k8s_worker.py b/backend/workers/run_k8s_worker.py index d3b857ad..9e11a25b 100644 --- a/backend/workers/run_k8s_worker.py +++ b/backend/workers/run_k8s_worker.py @@ -1,15 +1,21 @@ +"""Kubernetes worker entrypoint - stateless event processing. + +Consumes pod creation events from Kafka and dispatches to KubernetesWorker handlers. +DI container manages all lifecycle - worker just iterates over consumer. 
+""" + import asyncio import logging -import signal +from aiokafka import AIOKafkaConsumer from app.core.container import create_k8s_worker_container from app.core.database_context import Database from app.core.logging import setup_logger from app.core.tracing import init_tracing from app.db.docs import ALL_DOCUMENTS from app.domain.enums.kafka import GroupId +from app.events.core import UnifiedConsumer from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas -from app.services.k8s_worker.worker import KubernetesWorker from app.settings import Settings from beanie import init_beanie @@ -27,27 +33,18 @@ async def run_kubernetes_worker(settings: Settings) -> None: schema_registry = await container.get(SchemaRegistryManager) await initialize_event_schemas(schema_registry) - # Services are already started by the DI container providers - worker = await container.get(KubernetesWorker) - - # Shutdown event - signal handlers just set this - shutdown_event = asyncio.Event() - loop = asyncio.get_running_loop() - for sig in (signal.SIGINT, signal.SIGTERM): - loop.add_signal_handler(sig, shutdown_event.set) - - logger.info("KubernetesWorker started and running") - - try: - # Wait for shutdown signal or service to stop - while worker.is_running and not shutdown_event.is_set(): - await asyncio.sleep(60) - status = await worker.get_status() - logger.info(f"Kubernetes worker status: {status}") - finally: - # Container cleanup stops everything - logger.info("Initiating graceful shutdown...") - await container.close() + kafka_consumer = await container.get(AIOKafkaConsumer) + handler = await container.get(UnifiedConsumer) + + logger.info("KubernetesWorker started, consuming events...") + + async for msg in kafka_consumer: + await handler.handle(msg) + await kafka_consumer.commit() + + logger.info("KubernetesWorker shutdown complete") + + await container.close() def main() -> None: diff --git a/backend/workers/run_pod_monitor.py b/backend/workers/run_pod_monitor.py index 4b4dd325..23997ed7 100644 --- a/backend/workers/run_pod_monitor.py +++ b/backend/workers/run_pod_monitor.py @@ -1,20 +1,24 @@ +"""Pod monitor worker entrypoint - consumes pod events from Kafka. + +Same pattern as other workers - pure Kafka consumer. +K8s watch is externalized to a separate component that publishes to Kafka. 
+""" + import asyncio import logging -import signal +from aiokafka import AIOKafkaConsumer from app.core.container import create_pod_monitor_container from app.core.database_context import Database from app.core.logging import setup_logger from app.core.tracing import init_tracing from app.db.docs import ALL_DOCUMENTS from app.domain.enums.kafka import GroupId +from app.events.core import UnifiedConsumer from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas -from app.services.pod_monitor.monitor import MonitorState, PodMonitor from app.settings import Settings from beanie import init_beanie -RECONCILIATION_LOG_INTERVAL: int = 60 - async def run_pod_monitor(settings: Settings) -> None: """Run the pod monitor service.""" @@ -29,27 +33,18 @@ async def run_pod_monitor(settings: Settings) -> None: schema_registry = await container.get(SchemaRegistryManager) await initialize_event_schemas(schema_registry) - # Services are already started by the DI container providers - monitor = await container.get(PodMonitor) - - # Shutdown event - signal handlers just set this - shutdown_event = asyncio.Event() - loop = asyncio.get_running_loop() - for sig in (signal.SIGINT, signal.SIGTERM): - loop.add_signal_handler(sig, shutdown_event.set) - - logger.info("PodMonitor started and running") - - try: - # Wait for shutdown signal or service to stop - while monitor.state == MonitorState.RUNNING and not shutdown_event.is_set(): - await asyncio.sleep(RECONCILIATION_LOG_INTERVAL) - status = await monitor.get_status() - logger.info(f"Pod monitor status: {status}") - finally: - # Container cleanup stops everything - logger.info("Initiating graceful shutdown...") - await container.close() + kafka_consumer = await container.get(AIOKafkaConsumer) + handler = await container.get(UnifiedConsumer) + + logger.info("PodMonitor started, consuming events...") + + async for msg in kafka_consumer: + await handler.handle(msg) + await kafka_consumer.commit() + + logger.info("PodMonitor shutdown complete") + + await container.close() def main() -> None: diff --git a/backend/workers/run_result_processor.py b/backend/workers/run_result_processor.py index 5431b011..c7b557db 100644 --- a/backend/workers/run_result_processor.py +++ b/backend/workers/run_result_processor.py @@ -1,74 +1,50 @@ +"""Result processor worker entrypoint - stateless event processing. + +Consumes execution completion events from Kafka and dispatches to ResultProcessor handlers. +DI container manages all lifecycle - worker just iterates over consumer. 
+""" + import asyncio import logging -import signal -from contextlib import AsyncExitStack +from aiokafka import AIOKafkaConsumer from app.core.container import create_result_processor_container +from app.core.database_context import Database from app.core.logging import setup_logger -from app.core.metrics import EventMetrics, ExecutionMetrics from app.core.tracing import init_tracing from app.db.docs import ALL_DOCUMENTS -from app.db.repositories.execution_repository import ExecutionRepository from app.domain.enums.kafka import GroupId -from app.events.core import UnifiedProducer -from app.events.schema.schema_registry import SchemaRegistryManager -from app.services.idempotency import IdempotencyManager -from app.services.result_processor.processor import ProcessingState, ResultProcessor +from app.events.core import UnifiedConsumer +from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas from app.settings import Settings from beanie import init_beanie -from pymongo.asynchronous.mongo_client import AsyncMongoClient async def run_result_processor(settings: Settings) -> None: - - db_client: AsyncMongoClient[dict[str, object]] = AsyncMongoClient( - settings.MONGODB_URL, tz_aware=True, serverSelectionTimeoutMS=5000 - ) - await init_beanie(database=db_client[settings.DATABASE_NAME], document_models=ALL_DOCUMENTS) + """Run the result processor service.""" container = create_result_processor_container(settings) - producer = await container.get(UnifiedProducer) - schema_registry = await container.get(SchemaRegistryManager) - idempotency_manager = await container.get(IdempotencyManager) - execution_repo = await container.get(ExecutionRepository) - execution_metrics = await container.get(ExecutionMetrics) - event_metrics = await container.get(EventMetrics) logger = await container.get(logging.Logger) - logger.info(f"Beanie ODM initialized with {len(ALL_DOCUMENTS)} document models") - - # ResultProcessor is manually created (not from DI), so we own its lifecycle - processor = ResultProcessor( - execution_repo=execution_repo, - producer=producer, - schema_registry=schema_registry, - settings=settings, - idempotency_manager=idempotency_manager, - logger=logger, - execution_metrics=execution_metrics, - event_metrics=event_metrics, - ) - - # Shutdown event - signal handlers just set this - shutdown_event = asyncio.Event() - loop = asyncio.get_running_loop() - for sig in (signal.SIGINT, signal.SIGTERM): - loop.add_signal_handler(sig, shutdown_event.set) - - # We own the processor, so we use async with to manage its lifecycle - async with AsyncExitStack() as stack: - stack.callback(db_client.close) - stack.push_async_callback(container.close) - await stack.enter_async_context(processor) - - logger.info("ResultProcessor started and running") - - # Wait for shutdown signal or service to stop - while processor._state == ProcessingState.PROCESSING and not shutdown_event.is_set(): - await asyncio.sleep(60) - status = await processor.get_status() - logger.info(f"ResultProcessor status: {status}") - - logger.info("Initiating graceful shutdown...") + logger.info("Starting ResultProcessor with DI container...") + + db = await container.get(Database) + await init_beanie(database=db, document_models=ALL_DOCUMENTS) + + schema_registry = await container.get(SchemaRegistryManager) + await initialize_event_schemas(schema_registry) + + kafka_consumer = await container.get(AIOKafkaConsumer) + handler = await container.get(UnifiedConsumer) + + logger.info("ResultProcessor started, 
consuming events...") + + async for msg in kafka_consumer: + await handler.handle(msg) + await kafka_consumer.commit() + + logger.info("ResultProcessor shutdown complete") + + await container.close() def main() -> None: diff --git a/backend/workers/run_saga_orchestrator.py b/backend/workers/run_saga_orchestrator.py index 7fd0c359..3a230be8 100644 --- a/backend/workers/run_saga_orchestrator.py +++ b/backend/workers/run_saga_orchestrator.py @@ -1,21 +1,27 @@ +"""Saga orchestrator worker entrypoint - stateless event processing. + +Consumes execution events from Kafka and dispatches to SagaOrchestrator handlers. +DI container manages all lifecycle - worker just iterates over consumer. +""" + import asyncio import logging -import signal +from aiokafka import AIOKafkaConsumer from app.core.container import create_saga_orchestrator_container from app.core.database_context import Database from app.core.logging import setup_logger from app.core.tracing import init_tracing from app.db.docs import ALL_DOCUMENTS from app.domain.enums.kafka import GroupId +from app.events.core import UnifiedConsumer from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas -from app.services.saga import SagaOrchestrator from app.settings import Settings from beanie import init_beanie async def run_saga_orchestrator(settings: Settings) -> None: - """Run the saga orchestrator.""" + """Run the saga orchestrator service.""" container = create_saga_orchestrator_container(settings) logger = await container.get(logging.Logger) @@ -27,27 +33,18 @@ async def run_saga_orchestrator(settings: Settings) -> None: schema_registry = await container.get(SchemaRegistryManager) await initialize_event_schemas(schema_registry) - # Services are already started by the DI container providers - orchestrator = await container.get(SagaOrchestrator) + kafka_consumer = await container.get(AIOKafkaConsumer) + handler = await container.get(UnifiedConsumer) - # Shutdown event - signal handlers just set this - shutdown_event = asyncio.Event() - loop = asyncio.get_running_loop() - for sig in (signal.SIGINT, signal.SIGTERM): - loop.add_signal_handler(sig, shutdown_event.set) + logger.info("SagaOrchestrator started, consuming events...") - logger.info("Saga orchestrator started and running") + async for msg in kafka_consumer: + await handler.handle(msg) + await kafka_consumer.commit() - try: - # Wait for shutdown signal or service to stop - while orchestrator.is_running and not shutdown_event.is_set(): - await asyncio.sleep(1) - finally: - # Container cleanup stops everything - logger.info("Initiating graceful shutdown...") - await container.close() + logger.info("SagaOrchestrator shutdown complete") - logger.warning("Saga orchestrator stopped") + await container.close() def main() -> None: From 528aaa5a3b7fb1f10cfa61749663567ce523df17 Mon Sep 17 00:00:00 2001 From: HardMax71 Date: Tue, 27 Jan 2026 16:52:08 +0100 Subject: [PATCH 2/2] mypy and failed test fix --- backend/app/services/idempotency/middleware.py | 11 ++++++----- backend/app/services/result_processor/processor.py | 6 +++--- .../tests/unit/services/pod_monitor/test_monitor.py | 2 +- backend/tests/unit/services/sse/test_sse_service.py | 3 ++- 4 files changed, 12 insertions(+), 10 deletions(-) diff --git a/backend/app/services/idempotency/middleware.py b/backend/app/services/idempotency/middleware.py index 4dac5287..7fd3d1e3 100644 --- a/backend/app/services/idempotency/middleware.py +++ b/backend/app/services/idempotency/middleware.py @@ -6,6 +6,7 @@ from 
app.domain.enums.events import EventType from app.domain.events.typed import DomainEvent +from app.domain.idempotency import KeyStrategy from app.events.core import EventDispatcher, UnifiedConsumer from app.services.idempotency.idempotency_manager import IdempotencyManager @@ -18,7 +19,7 @@ def __init__( handler: Callable[[DomainEvent], Awaitable[None]], idempotency_manager: IdempotencyManager, logger: logging.Logger, - key_strategy: str = "event_based", + key_strategy: KeyStrategy = KeyStrategy.EVENT_BASED, custom_key_func: Callable[[DomainEvent], str] | None = None, fields: Set[str] | None = None, ttl_seconds: int | None = None, @@ -43,7 +44,7 @@ async def __call__(self, event: DomainEvent) -> None: ) # Generate custom key if function provided custom_key = None - if self.key_strategy == "custom" and self.custom_key_func: + if self.key_strategy == KeyStrategy.CUSTOM and self.custom_key_func: custom_key = self.custom_key_func(event) # Check idempotency @@ -92,7 +93,7 @@ async def __call__(self, event: DomainEvent) -> None: def idempotent_handler( idempotency_manager: IdempotencyManager, logger: logging.Logger, - key_strategy: str = "event_based", + key_strategy: KeyStrategy = KeyStrategy.EVENT_BASED, custom_key_func: Callable[[DomainEvent], str] | None = None, fields: Set[str] | None = None, ttl_seconds: int | None = None, @@ -127,7 +128,7 @@ def __init__( idempotency_manager: IdempotencyManager, dispatcher: EventDispatcher, logger: logging.Logger, - default_key_strategy: str = "event_based", + default_key_strategy: KeyStrategy = KeyStrategy.EVENT_BASED, default_ttl_seconds: int = 3600, enable_for_all_handlers: bool = True, ): @@ -171,7 +172,7 @@ def subscribe_idempotent_handler( self, event_type: str, handler: Callable[[DomainEvent], Awaitable[None]], - key_strategy: str | None = None, + key_strategy: KeyStrategy | None = None, custom_key_func: Callable[[DomainEvent], str] | None = None, fields: Set[str] | None = None, ttl_seconds: int | None = None, diff --git a/backend/app/services/result_processor/processor.py b/backend/app/services/result_processor/processor.py index c7d5f1c7..a71a4e4b 100644 --- a/backend/app/services/result_processor/processor.py +++ b/backend/app/services/result_processor/processor.py @@ -100,7 +100,7 @@ async def handle_execution_completed(self, event: ExecutionCompletedEvent) -> No stdout=event.stdout, stderr=event.stderr, resource_usage=event.resource_usage, - metadata=event.metadata.model_dump(), + metadata=event.metadata, ) try: @@ -127,7 +127,7 @@ async def handle_execution_failed(self, event: ExecutionFailedEvent) -> None: stdout=event.stdout, stderr=event.stderr, resource_usage=event.resource_usage, - metadata=event.metadata.model_dump(), + metadata=event.metadata, error_type=event.error_type, ) @@ -158,7 +158,7 @@ async def handle_execution_timeout(self, event: ExecutionTimeoutEvent) -> None: stdout=event.stdout, stderr=event.stderr, resource_usage=event.resource_usage, - metadata=event.metadata.model_dump(), + metadata=event.metadata, error_type=ExecutionErrorType.TIMEOUT, ) diff --git a/backend/tests/unit/services/pod_monitor/test_monitor.py b/backend/tests/unit/services/pod_monitor/test_monitor.py index 8916af97..5a233fbd 100644 --- a/backend/tests/unit/services/pod_monitor/test_monitor.py +++ b/backend/tests/unit/services/pod_monitor/test_monitor.py @@ -12,7 +12,7 @@ from app.core.metrics import EventMetrics, KubernetesMetrics from app.db.repositories.pod_state_repository import PodStateRepository from app.domain.events.typed import DomainEvent, 
EventMetadata, ExecutionCompletedEvent -from app.domain.execution.models import ResourceUsageDomain +from app.domain.events.typed import ResourceUsageDomain from app.events.core import UnifiedProducer from app.services.pod_monitor.config import PodMonitorConfig from app.services.pod_monitor.event_mapper import PodEventMapper diff --git a/backend/tests/unit/services/sse/test_sse_service.py b/backend/tests/unit/services/sse/test_sse_service.py index c33298ce..310907c6 100644 --- a/backend/tests/unit/services/sse/test_sse_service.py +++ b/backend/tests/unit/services/sse/test_sse_service.py @@ -10,7 +10,8 @@ from app.db.repositories.sse_repository import SSERepository from app.domain.enums.events import EventType from app.domain.enums.execution import ExecutionStatus -from app.domain.execution import DomainExecution, ResourceUsageDomain +from app.domain.events.typed import ResourceUsageDomain +from app.domain.execution import DomainExecution from app.domain.sse import ShutdownStatus, SSEExecutionStatusDomain from app.services.sse.kafka_redis_bridge import SSEKafkaRedisBridge from app.services.sse.redis_bus import SSERedisBus, SSERedisSubscription
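
For reference, the consume loop that every converted worker (coordinator, k8s worker, pod monitor, result processor, saga orchestrator) now shares reduces to the sketch below. This is a minimal, non-authoritative illustration: run_worker is a hypothetical helper, the container.get / handler.handle calls mirror what the diff uses, and it assumes the DI provider configures the AIOKafkaConsumer with manual offset commits (enable_auto_commit=False), which the explicit commit() after each message implies.

    import logging

    from aiokafka import AIOKafkaConsumer
    from dishka import AsyncContainer

    from app.events.core import UnifiedConsumer


    async def run_worker(container: AsyncContainer, logger: logging.Logger) -> None:
        # Resolve the raw Kafka consumer and the event dispatcher from the DI container.
        consumer = await container.get(AIOKafkaConsumer)
        handler = await container.get(UnifiedConsumer)

        logger.info("Worker started, consuming events...")

        # Iterate until the consumer is stopped by container teardown.
        async for msg in consumer:
            await handler.handle(msg)   # dispatch one event to its registered handler
            await consumer.commit()     # commit the offset only after successful handling

        logger.info("Worker shutdown complete")
        await container.close()

Committing only after handle() succeeds gives at-least-once delivery, which is why the idempotency middleware (see the KeyStrategy changes above) still guards handlers against duplicate events.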