diff --git a/bq/app.py b/bq/app.py index 5e44aa7..a765404 100644 --- a/bq/app.py +++ b/bq/app.py @@ -1,3 +1,4 @@ +import socketserver import functools import importlib import json @@ -13,6 +14,7 @@ from importlib.metadata import version from wsgiref.simple_server import make_server from wsgiref.simple_server import WSGIRequestHandler +from wsgiref.simple_server import WSGIServer import venusian from sqlalchemy import func @@ -53,6 +55,10 @@ def log_message(self, format, *args): ) + +class ThreadingWSGIServer(socketserver.ThreadingMixIn, WSGIServer): + daemon_threads = True + class BeanQueue: def __init__( self, @@ -70,6 +76,8 @@ def __init__( self._worker_update_shutdown_event: threading.Event = threading.Event() # noop if metrics thread is not started yet, shutdown if it is started self._metrics_server_shutdown: typing.Callable[[], None] = lambda: None + self._health_ok: bool = False + self._health_info: dict = {} def create_default_engine(self): # Use thread-safe connection pool when thread pool executor is enabled @@ -195,6 +203,10 @@ def update_workers( db.commit() if current_worker.state != models.WorkerState.RUNNING: + self._health_ok = False + self._health_info = { + "state": str(current_worker.state), + } # This probably means we are somehow very slow to update the heartbeat in time, or the timeout window # is set too short. It could also be the administrator update the worker state to something else than # RUNNING. Regardless the reason, let's stop processing. @@ -203,6 +215,7 @@ def update_workers( current_worker.id, current_worker.state, ) + self._health_ok = False sys.exit(0) do_shutdown = self._worker_update_shutdown_event.wait( @@ -214,51 +227,34 @@ def update_workers( current_worker.last_heartbeat = func.now() db.add(current_worker) db.commit() + self._health_ok = (current_worker.state == models.WorkerState.RUNNING) + self._health_info = { + "state": str(current_worker.state), + } def _serve_http_request( self, worker_id: typing.Any, environ: dict, start_response: typing.Callable ) -> list[bytes]: path = environ["PATH_INFO"] if path == "/healthz": - db = self.make_session() - worker_service = self._make_worker_service(db) - worker = worker_service.get_worker(worker_id) - if worker is not None and worker.state == models.WorkerState.RUNNING: - start_response( - "200 OK", - [ - ("Content-Type", "application/json"), - ], - ) - return [ - json.dumps(dict(status="ok", worker_id=str(worker_id))).encode( - "utf8" - ) - ] + if self._health_ok: + start_response("200 OK", [("Content-Type", "application/json")]) + return [json.dumps(dict( + status="ok", + worker_id=str(worker_id), + **self._health_info, + )).encode("utf8")] else: - logger.warning("Bad worker %s state %s", worker_id, worker.state) start_response( "500 Internal Server Error", - [ - ("Content-Type", "application/json"), - ], + [("Content-Type", "application/json")], ) - return [ - json.dumps( - dict( - status="internal error", - worker_id=str(worker_id), - state=str(worker.state), - ) - ).encode("utf8") - ] - # TODO: add other metrics endpoints - start_response( - "404 NOT FOUND", - [ - ("Content-Type", "application/json"), - ], - ) + return [json.dumps(dict( + status="error", + worker_id=str(worker_id), + **self._health_info, + )).encode("utf8")] + start_response("404 NOT FOUND", [("Content-Type", "application/json")]) return [json.dumps(dict(status="not found")).encode("utf8")] def run_metrics_http_server(self, worker_id: typing.Any): @@ -269,6 +265,7 @@ def run_metrics_http_server(self, worker_id: typing.Any): port, functools.partial(self._serve_http_request, worker_id), handler_class=WSGIRequestHandlerWithLogger, + server_class=ThreadingWSGIServer, ) as httpd: # expose graceful shutdown to the main thread self._metrics_server_shutdown = httpd.shutdown @@ -475,6 +472,8 @@ def process_tasks( db.add(worker) dispatch_service.listen(channels) db.commit() + self._health_ok = True + self._health_info = {"state": "RUNNING"} metrics_server_thread = None if self.config.METRICS_HTTP_SERVER_ENABLED: diff --git a/tests/unit/test_healthcheck.py b/tests/unit/test_healthcheck.py new file mode 100644 index 0000000..ffe4eda --- /dev/null +++ b/tests/unit/test_healthcheck.py @@ -0,0 +1,112 @@ +import json +from unittest.mock import MagicMock, patch + +import pytest + +from bq.app import BeanQueue + + +def _make_environ(path: str) -> dict: + """Build a minimal WSGI environ dict.""" + return {"PATH_INFO": path, "REQUEST_METHOD": "GET"} + + +@pytest.fixture +def bq(): + """Create a BeanQueue with stubbed config (no real DB needed).""" + with patch("bq.app.Config") as MockConfig: + config = MockConfig.return_value + config.DATABASE_URL = "postgresql://test@localhost/test" + config.METRICS_HTTP_SERVER_INTERFACE = "127.0.0.1" + config.METRICS_HTTP_SERVER_PORT = 0 + config.METRICS_HTTP_SERVER_ENABLED = False + config.METRICS_HTTP_SERVER_LOG_LEVEL = "WARNING" + config.WORKER_HEARTBEAT_PERIOD = 30 + config.WORKER_HEARTBEAT_TIMEOUT = 60 + config.MAX_WORKER_THREADS = 1 + instance = BeanQueue(config=config) + yield instance + + +class TestHealthzEndpoint: + """Tests for the /healthz HTTP handler.""" + + def test_healthz_returns_200_when_healthy(self, bq): + bq._health_ok = True + bq._health_info = {"state": "RUNNING"} + + start_response = MagicMock() + result = bq._serve_http_request("42", _make_environ("/healthz"), start_response) + + start_response.assert_called_once_with( + "200 OK", [("Content-Type", "application/json")] + ) + body = json.loads(result[0]) + assert body["status"] == "ok" + assert body["worker_id"] == "42" + assert body["state"] == "RUNNING" + + def test_healthz_returns_500_when_unhealthy(self, bq): + bq._health_ok = False + bq._health_info = {"state": "SHUTDOWN"} + + start_response = MagicMock() + result = bq._serve_http_request("42", _make_environ("/healthz"), start_response) + + start_response.assert_called_once_with( + "500 Internal Server Error", + [("Content-Type", "application/json")], + ) + body = json.loads(result[0]) + assert body["status"] == "error" + assert body["worker_id"] == "42" + assert body["state"] == "SHUTDOWN" + + def test_healthz_returns_500_before_worker_initialized(self, bq): + """Before process_tasks runs, _health_ok is False and _health_info is empty.""" + start_response = MagicMock() + result = bq._serve_http_request("1", _make_environ("/healthz"), start_response) + + start_response.assert_called_once_with( + "500 Internal Server Error", + [("Content-Type", "application/json")], + ) + body = json.loads(result[0]) + assert body["status"] == "error" + assert body["worker_id"] == "1" + + def test_healthz_does_not_create_db_session(self, bq): + """The critical fix: /healthz must never touch the DB.""" + bq._health_ok = True + bq._health_info = {"state": "RUNNING"} + + bq.make_session = MagicMock() + start_response = MagicMock() + bq._serve_http_request("42", _make_environ("/healthz"), start_response) + + bq.make_session.assert_not_called() + + def test_unknown_path_returns_404(self, bq): + start_response = MagicMock() + result = bq._serve_http_request("42", _make_environ("/unknown"), start_response) + + start_response.assert_called_once_with( + "404 NOT FOUND", [("Content-Type", "application/json")] + ) + body = json.loads(result[0]) + assert body["status"] == "not found" + + def test_404_does_not_create_db_session(self, bq): + bq.make_session = MagicMock() + start_response = MagicMock() + bq._serve_http_request("42", _make_environ("/anything"), start_response) + + bq.make_session.assert_not_called() + + +class TestHealthStateInitialization: + """Tests that _health_ok defaults correctly.""" + + def test_defaults_to_unhealthy(self, bq): + assert bq._health_ok is False + assert bq._health_info == {}