From de636f91d659d501f26090188c5bc159facef245 Mon Sep 17 00:00:00 2001 From: "claude[bot]" <41898282+claude[bot]@users.noreply.github.com> Date: Sat, 28 Mar 2026 12:39:29 +0000 Subject: [PATCH] Add unit tests for SSE streaming module (test_stream.py) Complete the market data test suite by adding tests for the SSE streaming endpoint in stream.py, which was the only untested module. Tests cover router creation, event generation, version-based change detection, disconnect handling, and SSE payload format. Co-authored-by: team2human --- backend/tests/market/test_stream.py | 215 ++++++++++++++++++++++++++++ 1 file changed, 215 insertions(+) create mode 100644 backend/tests/market/test_stream.py diff --git a/backend/tests/market/test_stream.py b/backend/tests/market/test_stream.py new file mode 100644 index 00000000..d112be19 --- /dev/null +++ b/backend/tests/market/test_stream.py @@ -0,0 +1,215 @@ +"""Tests for SSE streaming endpoint.""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from app.market.cache import PriceCache +from app.market.stream import _generate_events, create_stream_router + + +class TestCreateStreamRouter: + """Tests for the create_stream_router factory.""" + + def test_returns_router(self): + """Test that create_stream_router returns an APIRouter.""" + from fastapi import APIRouter + + cache = PriceCache() + router = create_stream_router(cache) + assert isinstance(router, APIRouter) + + def test_router_has_prices_route(self): + """Test that the router has a /prices route registered.""" + cache = PriceCache() + router = create_stream_router(cache) + + routes = [r.path for r in router.routes] + assert "/prices" in routes + + def test_router_prefix(self): + """Test that the router has the expected prefix.""" + from app.market import stream + + assert stream.router.prefix == "/api/stream" + + +@pytest.mark.asyncio +class TestGenerateEvents: + """Tests for the _generate_events async generator.""" + + def _make_request(self, *, disconnected_after: int = 3) -> MagicMock: + """Create a mock Request that disconnects after N calls.""" + request = MagicMock() + request.client = MagicMock() + request.client.host = "127.0.0.1" + + call_count = 0 + + async def is_disconnected(): + nonlocal call_count + call_count += 1 + return call_count > disconnected_after + + request.is_disconnected = is_disconnected + return request + + async def test_yields_retry_directive_first(self): + """Test that the first yielded value is the retry directive.""" + cache = PriceCache() + request = self._make_request(disconnected_after=0) + + gen = _generate_events(cache, request, interval=0) + first = await gen.__anext__() + assert first == "retry: 1000\n\n" + + async def test_yields_data_when_cache_has_prices(self): + """Test that price data is yielded when the cache has entries.""" + import json + + cache = PriceCache() + cache.update("AAPL", 190.50) + request = self._make_request(disconnected_after=2) + + events = [] + async for event in _generate_events(cache, request, interval=0): + events.append(event) + + # First event is retry directive; subsequent events are data + data_events = [e for e in events if e.startswith("data: ")] + assert len(data_events) >= 1 + + payload = json.loads(data_events[0][len("data: "):].strip()) + assert "AAPL" in payload + assert payload["AAPL"]["price"] == 190.50 + assert payload["AAPL"]["ticker"] == "AAPL" + + async def test_no_data_event_for_empty_cache(self): + """Test that no data event is sent when the cache is empty.""" + cache = PriceCache() + request = self._make_request(disconnected_after=2) + + events = [] + async for event in _generate_events(cache, request, interval=0): + events.append(event) + + data_events = [e for e in events if e.startswith("data: ")] + assert len(data_events) == 0 + + async def test_includes_all_tickers(self): + """Test that all cached tickers appear in the data event.""" + import json + + cache = PriceCache() + cache.update("AAPL", 190.50) + cache.update("GOOGL", 175.25) + cache.update("MSFT", 420.00) + request = self._make_request(disconnected_after=2) + + async for event in _generate_events(cache, request, interval=0): + if event.startswith("data: "): + payload = json.loads(event[len("data: "):].strip()) + assert set(payload.keys()) == {"AAPL", "GOOGL", "MSFT"} + break + + async def test_data_event_format(self): + """Test that data events are properly SSE-formatted.""" + cache = PriceCache() + cache.update("AAPL", 190.50) + request = self._make_request(disconnected_after=2) + + async for event in _generate_events(cache, request, interval=0): + if event.startswith("data: "): + # SSE format: "data: \n\n" + assert event.endswith("\n\n") + break + + async def test_stops_on_disconnect(self): + """Test that the generator stops when the client disconnects.""" + cache = PriceCache() + cache.update("AAPL", 190.50) + # Disconnect immediately + request = self._make_request(disconnected_after=0) + + events = [] + async for event in _generate_events(cache, request, interval=0): + events.append(event) + + # Only the retry directive should be yielded (disconnect on first check) + assert len(events) == 1 + assert events[0] == "retry: 1000\n\n" + + async def test_version_change_detection(self): + """Test that events are only sent when the cache version changes.""" + import json + + cache = PriceCache() + cache.update("AAPL", 190.50) + + call_count = 0 + + async def is_disconnected(): + nonlocal call_count + call_count += 1 + return call_count > 5 # Allow several iterations + + request = MagicMock() + request.client = MagicMock() + request.client.host = "127.0.0.1" + request.is_disconnected = is_disconnected + + events = [] + async for event in _generate_events(cache, request, interval=0): + events.append(event) + + data_events = [e for e in events if e.startswith("data: ")] + # Version only changes on cache.update(); since we don't update again, + # we should see at most one data event (on first version change detection) + assert len(data_events) == 1 + + async def test_no_client_host(self): + """Test graceful handling when request has no client info.""" + cache = PriceCache() + request = MagicMock() + request.client = None # No client info + + call_count = 0 + + async def is_disconnected(): + nonlocal call_count + call_count += 1 + return call_count > 1 + + request.is_disconnected = is_disconnected + + events = [] + async for event in _generate_events(cache, request, interval=0): + events.append(event) + + # Should still work without client host info + assert events[0] == "retry: 1000\n\n" + + async def test_price_update_fields_in_payload(self): + """Test that each ticker payload contains all expected fields.""" + import json + + cache = PriceCache() + cache.update("AAPL", 190.00) + cache.update("AAPL", 191.00) # Second update so direction != flat + request = self._make_request(disconnected_after=2) + + async for event in _generate_events(cache, request, interval=0): + if event.startswith("data: "): + payload = json.loads(event[len("data: "):].strip()) + aapl = payload["AAPL"] + assert "ticker" in aapl + assert "price" in aapl + assert "previous_price" in aapl + assert "timestamp" in aapl + assert "change" in aapl + assert "change_percent" in aapl + assert "direction" in aapl + assert aapl["direction"] == "up" + break