diff --git a/backend/routers/rail.py b/backend/routers/rail.py index f7efc15..e94c012 100644 --- a/backend/routers/rail.py +++ b/backend/routers/rail.py @@ -22,6 +22,9 @@ _mem: dict = {"data": None, "ts": 0.0} _fetch_lock = asyncio.Lock() +# In-process cache for poller-written GTFS route shapes +_gtfs_mem: dict = {"data": None, "ts": 0.0} + def _build_query() -> str: return ( @@ -119,3 +122,30 @@ async def get_rail_tracks(): logger.warning("[rail] Redis write failed (cache not persisted): %s", exc) return geojson + + +@router.get("/rail/gtfs-shapes") +async def get_gtfs_shapes(): + """TriMet GTFS route shapes as GeoJSON (written by poller into Redis, 24-hour TTL). + + Returns an empty FeatureCollection when the poller hasn't populated the cache yet + (e.g., TRIMET_GTFS_ENABLED=false or first boot before the static GTFS has been fetched). + """ + redis_key = "cache:gtfs:trimet:shapes" + now = time.monotonic() + + if _gtfs_mem["data"] is not None and (now - _gtfs_mem["ts"]) < _CACHE_TTL_S: + return _gtfs_mem["data"] + + try: + r = get_redis() + cached_raw = await r.get(redis_key) + if cached_raw: + geojson = json.loads(cached_raw) + _gtfs_mem["data"] = geojson + _gtfs_mem["ts"] = now + return geojson + except Exception as exc: + logger.warning("[rail] Redis GTFS shapes read failed: %s", exc) + + return {"type": "FeatureCollection", "features": []} diff --git a/frontend/src/components/layers/RailLayer.tsx b/frontend/src/components/layers/RailLayer.tsx index 9995532..0e0a308 100644 --- a/frontend/src/components/layers/RailLayer.tsx +++ b/frontend/src/components/layers/RailLayer.tsx @@ -8,86 +8,111 @@ interface Props { map: maplibregl.Map } -const SRC_ID = 'rail-tracks-src' -const LINE_ID = 'rail-tracks-line' +const OSM_SRC_ID = 'rail-tracks-src' +const OSM_LINE_ID = 'rail-tracks-line' +const GTFS_SRC_ID = 'rail-gtfs-src' +const GTFS_LINE_ID = 'rail-gtfs-line' export function RailLayer({ map }: Props) { const railTracksVisible = useCivicStore(s => s.railTracksVisible) - const loadedRef = useRef(false) + const osmLoadedRef = useRef(false) + const gtfsLoadedRef = useRef(false) useEffect(() => { if (!map || typeof map.getSource !== 'function') return - if (loadedRef.current) return - const load = async () => { + // OSM tracks — basemap-style rail geometry for all mainline/freight/Amtrak + const loadOsm = async () => { + if (osmLoadedRef.current) return try { const res = await fetch(`${API_BASE}/rail/tracks`, { headers: authHeaders() }) if (!res.ok) return const geojson = await res.json() + if (map.getSource(OSM_SRC_ID)) return - if (map.getSource(SRC_ID)) return // already added (strict-mode double-effect guard) + map.addSource(OSM_SRC_ID, { type: 'geojson', data: geojson }) + // Insert OSM below the GTFS layer if it already loaded; otherwise append + const beforeId = map.getLayer(GTFS_LINE_ID) ? GTFS_LINE_ID : undefined + map.addLayer( + { + id: OSM_LINE_ID, + type: 'line', + source: OSM_SRC_ID, + layout: { 'line-join': 'round', 'line-cap': 'round' }, + paint: { + 'line-color': [ + 'match', ['get', 'railway'], + 'light_rail', '#a78bfa', + '#b45309', + ], + 'line-width': ['interpolate', ['linear'], ['zoom'], 6, 1, 10, 1.5, 14, 2.5], + 'line-opacity': 0.45, + }, + }, + beforeId, + ) + osmLoadedRef.current = true + } catch { /* retry via interval */ } + } - map.addSource(SRC_ID, { type: 'geojson', data: geojson }) + // GTFS shapes — official TriMet route geometry with per-route brand colors + const loadGtfs = async () => { + if (gtfsLoadedRef.current) return + try { + const res = await fetch(`${API_BASE}/rail/gtfs-shapes`, { headers: authHeaders() }) + if (!res.ok) return + const geojson = await res.json() + // Empty means the poller hasn't run yet; retry later + if (!geojson.features?.length) return + if (map.getSource(GTFS_SRC_ID)) return + + map.addSource(GTFS_SRC_ID, { type: 'geojson', data: geojson }) map.addLayer({ - id: LINE_ID, + id: GTFS_LINE_ID, type: 'line', - source: SRC_ID, - layout: { - 'line-join': 'round', - 'line-cap': 'round', - }, + source: GTFS_SRC_ID, + layout: { 'line-join': 'round', 'line-cap': 'round' }, paint: { - 'line-color': [ - 'match', - ['get', 'railway'], - 'light_rail', '#a78bfa', // violet for light rail / MAX - '#b45309', // amber-700 rust for mainline freight/Amtrak - ], - 'line-width': [ - 'interpolate', ['linear'], ['zoom'], - 6, 1, - 10, 2, - 14, 3, - ], - 'line-opacity': 0.7, + 'line-color': ['get', 'route_color'], + 'line-width': ['interpolate', ['linear'], ['zoom'], 6, 2, 10, 3.5, 14, 6], + 'line-opacity': 0.9, }, }) - - loadedRef.current = true - } catch { - // Overpass API may be slow on first load; retry will fire via the interval below - } + gtfsLoadedRef.current = true + } catch { /* retry via interval */ } } - // Initial attempt + const loadAll = () => { loadOsm(); loadGtfs() } + if (map.isStyleLoaded()) { - load() + loadAll() } else { - map.once('load', load) + map.once('load', loadAll) } - // Retry every 30 s until the backend Overpass cache warms up after a rebuild + // Retry every 30 s until both sources are loaded (Overpass and poller cache warm-up) const retryInterval = setInterval(() => { - if (!loadedRef.current) load() + if (!osmLoadedRef.current || !gtfsLoadedRef.current) loadAll() }, 30_000) return () => clearInterval(retryInterval) }, [map]) - // Toggle layer visibility when railTracksVisible changes useEffect(() => { if (!map || typeof map.getLayer !== 'function') return - if (!map.getLayer(LINE_ID)) return - map.setLayoutProperty(LINE_ID, 'visibility', railTracksVisible ? 'visible' : 'none') + const vis = railTracksVisible ? 'visible' : 'none' + if (map.getLayer(OSM_LINE_ID)) map.setLayoutProperty(OSM_LINE_ID, 'visibility', vis) + if (map.getLayer(GTFS_LINE_ID)) map.setLayoutProperty(GTFS_LINE_ID, 'visibility', vis) }, [map, railTracksVisible]) - // Cleanup on unmount useEffect(() => { return () => { if (!map || typeof map.getLayer !== 'function') return try { - if (map.getLayer(LINE_ID)) map.removeLayer(LINE_ID) - if (map.getSource(SRC_ID)) map.removeSource(SRC_ID) + if (map.getLayer(GTFS_LINE_ID)) map.removeLayer(GTFS_LINE_ID) + if (map.getSource(GTFS_SRC_ID)) map.removeSource(GTFS_SRC_ID) + if (map.getLayer(OSM_LINE_ID)) map.removeLayer(OSM_LINE_ID) + if (map.getSource(OSM_SRC_ID)) map.removeSource(OSM_SRC_ID) } catch { /* ignore */ } } }, [map]) diff --git a/poller/pollers/gtfs_rt.py b/poller/pollers/gtfs_rt.py index ef85f86..31ebd45 100644 --- a/poller/pollers/gtfs_rt.py +++ b/poller/pollers/gtfs_rt.py @@ -1,6 +1,7 @@ import asyncio import csv import io +import json import logging import time import zipfile @@ -8,7 +9,7 @@ import httpx -from bus import publish_entity +from bus import get_bus, publish_entity from config import settings from .base import BasePoller @@ -161,6 +162,8 @@ async def _ensure_route_map(self, state: _FeedState) -> dict[str, dict]: "type": int(row.get("route_type", -1)), "short_name": row.get("route_short_name", "").strip(), "long_name": row.get("route_long_name", "").strip(), + "color": row.get("route_color", "").strip(), + "text_color": row.get("route_text_color", "FFFFFF").strip(), } state.route_map = route_map @@ -168,11 +171,102 @@ async def _ensure_route_map(self, state: _FeedState) -> dict[str, dict]: logger.info( "[gtfs_rt:%s] loaded %d routes from static GTFS", feed.name, len(route_map) ) + + # Build and cache route shape GeoJSON without blocking vehicle position polling + asyncio.create_task( + self._build_and_cache_shapes(zf, route_map, set(feed.route_types), feed.name) + ) except Exception as exc: logger.warning("[gtfs_rt:%s] static GTFS fetch failed: %s", feed.name, exc) return state.route_map + # ── GTFS shape cache builder ─────────────────────────────────────────────── + + async def _build_and_cache_shapes( + self, + zf: zipfile.ZipFile, + route_map: dict[str, dict], + allowed_types: set[int], + feed_name: str, + ) -> None: + """Parse shapes.txt + trips.txt from the static GTFS zip and write one + MultiLineString GeoJSON feature per route to Redis.""" + try: + namelist = zf.namelist() + + # trips.txt: first shape_id seen per route_id wins + shape_to_route: dict[str, str] = {} + if "trips.txt" in namelist: + with zf.open("trips.txt") as f: + for row in csv.DictReader(io.TextIOWrapper(f, "utf-8")): + sid = row.get("shape_id", "").strip() + rid = row.get("route_id", "").strip() + if sid and rid and sid not in shape_to_route: + shape_to_route[sid] = rid + + # shapes.txt: collect (seq, lon, lat) tuples per shape_id + shape_pts: dict[str, list[tuple[int, float, float]]] = {} + if "shapes.txt" in namelist: + with zf.open("shapes.txt") as f: + for row in csv.DictReader(io.TextIOWrapper(f, "utf-8")): + sid = row.get("shape_id", "").strip() + if not sid: + continue + try: + lat = float(row["shape_pt_lat"]) + lon = float(row["shape_pt_lon"]) + seq = int(row.get("shape_pt_sequence", 0)) + except (KeyError, ValueError): + continue + shape_pts.setdefault(sid, []).append((seq, lon, lat)) + + # Sort each shape by sequence and flatten to [lon, lat] pairs + sorted_shapes: dict[str, list[list[float]]] = {} + for sid, pts in shape_pts.items(): + pts.sort(key=lambda x: x[0]) + sorted_shapes[sid] = [[lon, lat] for _, lon, lat in pts] + + # Group shapes into one MultiLineString per route (rail types only) + route_lines: dict[str, list[list[list[float]]]] = {} + for shape_id, route_id in shape_to_route.items(): + info = route_map.get(route_id) + if info is None or info["type"] not in allowed_types: + continue + coords = sorted_shapes.get(shape_id) + if not coords or len(coords) < 2: + continue + route_lines.setdefault(route_id, []).append(coords) + + features = [] + for route_id, lines in route_lines.items(): + info = route_map[route_id] + raw_color = info.get("color", "").strip() + raw_text = info.get("text_color", "FFFFFF").strip() + features.append({ + "type": "Feature", + "geometry": {"type": "MultiLineString", "coordinates": lines}, + "properties": { + "route_id": route_id, + "route_short_name": info["short_name"], + "route_long_name": info["long_name"], + "route_type": info["type"], + "route_color": f"#{raw_color}" if raw_color else "#a78bfa", + "route_text_color": f"#{raw_text}" if raw_text else "#FFFFFF", + }, + }) + + geojson = {"type": "FeatureCollection", "features": features} + redis_key = f"cache:gtfs:{feed_name}:shapes" + r = await get_bus() + await r.set(redis_key, json.dumps(geojson), ex=int(_STATIC_CACHE_TTL) + 3600) + logger.info( + "[gtfs_rt:%s] cached %d route shapes to Redis (%s)", + feed_name, len(features), redis_key, + ) + except Exception as exc: + logger.warning("[gtfs_rt:%s] shape cache build failed: %s", feed_name, exc) + # ── GTFS-RT fetch + parse ───────────────────────────────────────────────── async def _poll_once(self, state: _FeedState):