Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 215 additions & 0 deletions backend/tests/market/test_stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
"""Tests for SSE streaming endpoint."""

from __future__ import annotations

from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from app.market.cache import PriceCache
from app.market.stream import _generate_events, create_stream_router


class TestCreateStreamRouter:
"""Tests for the create_stream_router factory."""

def test_returns_router(self):
"""Test that create_stream_router returns an APIRouter."""
from fastapi import APIRouter

cache = PriceCache()
router = create_stream_router(cache)
assert isinstance(router, APIRouter)

def test_router_has_prices_route(self):
"""Test that the router has a /prices route registered."""
cache = PriceCache()
router = create_stream_router(cache)

routes = [r.path for r in router.routes]
assert "/prices" in routes

def test_router_prefix(self):
"""Test that the router has the expected prefix."""
from app.market import stream

assert stream.router.prefix == "/api/stream"


@pytest.mark.asyncio
class TestGenerateEvents:
"""Tests for the _generate_events async generator."""

def _make_request(self, *, disconnected_after: int = 3) -> MagicMock:
"""Create a mock Request that disconnects after N calls."""
request = MagicMock()
request.client = MagicMock()
request.client.host = "127.0.0.1"

call_count = 0

async def is_disconnected():
nonlocal call_count
call_count += 1
return call_count > disconnected_after

request.is_disconnected = is_disconnected
return request

async def test_yields_retry_directive_first(self):
"""Test that the first yielded value is the retry directive."""
cache = PriceCache()
request = self._make_request(disconnected_after=0)

gen = _generate_events(cache, request, interval=0)
first = await gen.__anext__()
assert first == "retry: 1000\n\n"

async def test_yields_data_when_cache_has_prices(self):
"""Test that price data is yielded when the cache has entries."""
import json

cache = PriceCache()
cache.update("AAPL", 190.50)
request = self._make_request(disconnected_after=2)

events = []
async for event in _generate_events(cache, request, interval=0):
events.append(event)

# First event is retry directive; subsequent events are data
data_events = [e for e in events if e.startswith("data: ")]
assert len(data_events) >= 1

payload = json.loads(data_events[0][len("data: "):].strip())
assert "AAPL" in payload
assert payload["AAPL"]["price"] == 190.50
assert payload["AAPL"]["ticker"] == "AAPL"

async def test_no_data_event_for_empty_cache(self):
"""Test that no data event is sent when the cache is empty."""
cache = PriceCache()
request = self._make_request(disconnected_after=2)

events = []
async for event in _generate_events(cache, request, interval=0):
events.append(event)

data_events = [e for e in events if e.startswith("data: ")]
assert len(data_events) == 0

async def test_includes_all_tickers(self):
"""Test that all cached tickers appear in the data event."""
import json

cache = PriceCache()
cache.update("AAPL", 190.50)
cache.update("GOOGL", 175.25)
cache.update("MSFT", 420.00)
request = self._make_request(disconnected_after=2)

async for event in _generate_events(cache, request, interval=0):
if event.startswith("data: "):
payload = json.loads(event[len("data: "):].strip())
assert set(payload.keys()) == {"AAPL", "GOOGL", "MSFT"}
break

async def test_data_event_format(self):
"""Test that data events are properly SSE-formatted."""
cache = PriceCache()
cache.update("AAPL", 190.50)
request = self._make_request(disconnected_after=2)

async for event in _generate_events(cache, request, interval=0):
if event.startswith("data: "):
# SSE format: "data: <payload>\n\n"
assert event.endswith("\n\n")
break

async def test_stops_on_disconnect(self):
"""Test that the generator stops when the client disconnects."""
cache = PriceCache()
cache.update("AAPL", 190.50)
# Disconnect immediately
request = self._make_request(disconnected_after=0)

events = []
async for event in _generate_events(cache, request, interval=0):
events.append(event)

# Only the retry directive should be yielded (disconnect on first check)
assert len(events) == 1
assert events[0] == "retry: 1000\n\n"

async def test_version_change_detection(self):
"""Test that events are only sent when the cache version changes."""
import json

cache = PriceCache()
cache.update("AAPL", 190.50)

call_count = 0

async def is_disconnected():
nonlocal call_count
call_count += 1
return call_count > 5 # Allow several iterations

request = MagicMock()
request.client = MagicMock()
request.client.host = "127.0.0.1"
request.is_disconnected = is_disconnected

events = []
async for event in _generate_events(cache, request, interval=0):
events.append(event)

data_events = [e for e in events if e.startswith("data: ")]
# Version only changes on cache.update(); since we don't update again,
# we should see at most one data event (on first version change detection)
assert len(data_events) == 1

async def test_no_client_host(self):
"""Test graceful handling when request has no client info."""
cache = PriceCache()
request = MagicMock()
request.client = None # No client info

call_count = 0

async def is_disconnected():
nonlocal call_count
call_count += 1
return call_count > 1

request.is_disconnected = is_disconnected

events = []
async for event in _generate_events(cache, request, interval=0):
events.append(event)

# Should still work without client host info
assert events[0] == "retry: 1000\n\n"

async def test_price_update_fields_in_payload(self):
"""Test that each ticker payload contains all expected fields."""
import json

cache = PriceCache()
cache.update("AAPL", 190.00)
cache.update("AAPL", 191.00) # Second update so direction != flat
request = self._make_request(disconnected_after=2)

async for event in _generate_events(cache, request, interval=0):
if event.startswith("data: "):
payload = json.loads(event[len("data: "):].strip())
aapl = payload["AAPL"]
assert "ticker" in aapl
assert "price" in aapl
assert "previous_price" in aapl
assert "timestamp" in aapl
assert "change" in aapl
assert "change_percent" in aapl
assert "direction" in aapl
assert aapl["direction"] == "up"
break
Loading