Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 34 additions & 35 deletions bq/app.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import socketserver
import functools
import importlib
import json
Expand All @@ -13,6 +14,7 @@
from importlib.metadata import version
from wsgiref.simple_server import make_server
from wsgiref.simple_server import WSGIRequestHandler
from wsgiref.simple_server import WSGIServer

import venusian
from sqlalchemy import func
Expand Down Expand Up @@ -53,6 +55,10 @@ def log_message(self, format, *args):
)



class ThreadingWSGIServer(socketserver.ThreadingMixIn, WSGIServer):
daemon_threads = True

class BeanQueue:
def __init__(
self,
Expand All @@ -70,6 +76,8 @@ def __init__(
self._worker_update_shutdown_event: threading.Event = threading.Event()
# noop if metrics thread is not started yet, shutdown if it is started
self._metrics_server_shutdown: typing.Callable[[], None] = lambda: None
self._health_ok: bool = False
self._health_info: dict = {}

def create_default_engine(self):
# Use thread-safe connection pool when thread pool executor is enabled
Expand Down Expand Up @@ -195,6 +203,10 @@ def update_workers(
db.commit()

if current_worker.state != models.WorkerState.RUNNING:
self._health_ok = False
self._health_info = {
"state": str(current_worker.state),
}
# This probably means we are somehow very slow to update the heartbeat in time, or the timeout window
# is set too short. It could also be the administrator update the worker state to something else than
# RUNNING. Regardless the reason, let's stop processing.
Expand All @@ -203,6 +215,7 @@ def update_workers(
current_worker.id,
current_worker.state,
)
self._health_ok = False
sys.exit(0)

do_shutdown = self._worker_update_shutdown_event.wait(
Expand All @@ -214,51 +227,34 @@ def update_workers(
current_worker.last_heartbeat = func.now()
db.add(current_worker)
db.commit()
self._health_ok = (current_worker.state == models.WorkerState.RUNNING)
self._health_info = {
"state": str(current_worker.state),
}

def _serve_http_request(
self, worker_id: typing.Any, environ: dict, start_response: typing.Callable
) -> list[bytes]:
path = environ["PATH_INFO"]
if path == "/healthz":
db = self.make_session()
worker_service = self._make_worker_service(db)
worker = worker_service.get_worker(worker_id)
if worker is not None and worker.state == models.WorkerState.RUNNING:
start_response(
"200 OK",
[
("Content-Type", "application/json"),
],
)
return [
json.dumps(dict(status="ok", worker_id=str(worker_id))).encode(
"utf8"
)
]
if self._health_ok:
start_response("200 OK", [("Content-Type", "application/json")])
return [json.dumps(dict(
status="ok",
worker_id=str(worker_id),
**self._health_info,
)).encode("utf8")]
else:
logger.warning("Bad worker %s state %s", worker_id, worker.state)
start_response(
"500 Internal Server Error",
[
("Content-Type", "application/json"),
],
[("Content-Type", "application/json")],
)
return [
json.dumps(
dict(
status="internal error",
worker_id=str(worker_id),
state=str(worker.state),
)
).encode("utf8")
]
# TODO: add other metrics endpoints
start_response(
"404 NOT FOUND",
[
("Content-Type", "application/json"),
],
)
return [json.dumps(dict(
status="error",
worker_id=str(worker_id),
**self._health_info,
)).encode("utf8")]
start_response("404 NOT FOUND", [("Content-Type", "application/json")])
return [json.dumps(dict(status="not found")).encode("utf8")]

def run_metrics_http_server(self, worker_id: typing.Any):
Expand All @@ -269,6 +265,7 @@ def run_metrics_http_server(self, worker_id: typing.Any):
port,
functools.partial(self._serve_http_request, worker_id),
handler_class=WSGIRequestHandlerWithLogger,
server_class=ThreadingWSGIServer,
) as httpd:
# expose graceful shutdown to the main thread
self._metrics_server_shutdown = httpd.shutdown
Expand Down Expand Up @@ -475,6 +472,8 @@ def process_tasks(
db.add(worker)
dispatch_service.listen(channels)
db.commit()
self._health_ok = True
self._health_info = {"state": "RUNNING"}

metrics_server_thread = None
if self.config.METRICS_HTTP_SERVER_ENABLED:
Expand Down
112 changes: 112 additions & 0 deletions tests/unit/test_healthcheck.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import json
from unittest.mock import MagicMock, patch

import pytest

from bq.app import BeanQueue


def _make_environ(path: str) -> dict:
"""Build a minimal WSGI environ dict."""
return {"PATH_INFO": path, "REQUEST_METHOD": "GET"}


@pytest.fixture
def bq():
"""Create a BeanQueue with stubbed config (no real DB needed)."""
with patch("bq.app.Config") as MockConfig:
config = MockConfig.return_value
config.DATABASE_URL = "postgresql://test@localhost/test"
config.METRICS_HTTP_SERVER_INTERFACE = "127.0.0.1"
config.METRICS_HTTP_SERVER_PORT = 0
config.METRICS_HTTP_SERVER_ENABLED = False
config.METRICS_HTTP_SERVER_LOG_LEVEL = "WARNING"
config.WORKER_HEARTBEAT_PERIOD = 30
config.WORKER_HEARTBEAT_TIMEOUT = 60
config.MAX_WORKER_THREADS = 1
instance = BeanQueue(config=config)
yield instance


class TestHealthzEndpoint:
"""Tests for the /healthz HTTP handler."""

def test_healthz_returns_200_when_healthy(self, bq):
bq._health_ok = True
bq._health_info = {"state": "RUNNING"}

start_response = MagicMock()
result = bq._serve_http_request("42", _make_environ("/healthz"), start_response)

start_response.assert_called_once_with(
"200 OK", [("Content-Type", "application/json")]
)
body = json.loads(result[0])
assert body["status"] == "ok"
assert body["worker_id"] == "42"
assert body["state"] == "RUNNING"

def test_healthz_returns_500_when_unhealthy(self, bq):
bq._health_ok = False
bq._health_info = {"state": "SHUTDOWN"}

start_response = MagicMock()
result = bq._serve_http_request("42", _make_environ("/healthz"), start_response)

start_response.assert_called_once_with(
"500 Internal Server Error",
[("Content-Type", "application/json")],
)
body = json.loads(result[0])
assert body["status"] == "error"
assert body["worker_id"] == "42"
assert body["state"] == "SHUTDOWN"

def test_healthz_returns_500_before_worker_initialized(self, bq):
"""Before process_tasks runs, _health_ok is False and _health_info is empty."""
start_response = MagicMock()
result = bq._serve_http_request("1", _make_environ("/healthz"), start_response)

start_response.assert_called_once_with(
"500 Internal Server Error",
[("Content-Type", "application/json")],
)
body = json.loads(result[0])
assert body["status"] == "error"
assert body["worker_id"] == "1"

def test_healthz_does_not_create_db_session(self, bq):
"""The critical fix: /healthz must never touch the DB."""
bq._health_ok = True
bq._health_info = {"state": "RUNNING"}

bq.make_session = MagicMock()
start_response = MagicMock()
bq._serve_http_request("42", _make_environ("/healthz"), start_response)

bq.make_session.assert_not_called()

def test_unknown_path_returns_404(self, bq):
start_response = MagicMock()
result = bq._serve_http_request("42", _make_environ("/unknown"), start_response)

start_response.assert_called_once_with(
"404 NOT FOUND", [("Content-Type", "application/json")]
)
body = json.loads(result[0])
assert body["status"] == "not found"

def test_404_does_not_create_db_session(self, bq):
bq.make_session = MagicMock()
start_response = MagicMock()
bq._serve_http_request("42", _make_environ("/anything"), start_response)

bq.make_session.assert_not_called()


class TestHealthStateInitialization:
"""Tests that _health_ok defaults correctly."""

def test_defaults_to_unhealthy(self, bq):
assert bq._health_ok is False
assert bq._health_info == {}