diff --git a/README.md b/README.md index 62a7c82..d77924f 100644 --- a/README.md +++ b/README.md @@ -66,7 +66,8 @@ hivewatch.init( - `runs/.jsonl` for the complete event history - `runs/.map.json` for map-ready metadata that can be loaded directly later -Serve the dashboard separately: +Serve the dashboard separately. Without `--run-id`, the server watches the run +directory and streams new events as they arrive: ```bash hivewatch map run --runs-dir runs --port 7070 @@ -78,7 +79,10 @@ Open one specific saved run in static mode: hivewatch map run --runs-dir runs --run-id run-abc123 ``` -The bundled `examples/hivewatch_map.html` viewer loads map metadata first and falls back to the JSONL-derived event history for older runs. This keeps local development and later replay workflows compatible with the same viewer. +The bundled viewer loads map metadata first and falls back to the JSONL-derived +event history for older runs. It displays non-geographic client metadata in the +sidebar automatically. Prefix a client metadata key with `_` when the value +should remain in the saved artifact but stay hidden from the map card. ### Package layout @@ -190,6 +194,11 @@ hivewatch.init(emitters=[MyEmitter()]) | `lat` / `lng` / `country` | float/str | Client location metadata for map visualization | | `base_round` | int | For asynchronous FL, staleness is `round - base_round` | +Unknown client fields are preserved in `ClientUpdate` objects and the local +JSONL/`.map.json` artifacts. The map dashboard displays visible non-geographic +fields automatically. Use a leading underscore for values that should be stored +but hidden from the bundled map viewer, for example `_debug_score`. 
+ ## Logged Metrics ### Weights & Biases @@ -236,7 +245,8 @@ FL Server ▼ hivewatch ├── WandbEmitter → wandb.ai dashboard - └── MLflowEmitter → MLflow UI (localhost:5000) + ├── MLflowEmitter → MLflow UI (localhost:5000) + └── SSEEmitter → local JSONL/.map.json artifacts and map dashboard ``` `hivewatch` does not depend on a specific transport layer or FL framework. Applications bridge their training framework to `hivewatch` in the same way they would bridge it to another experiment tracking backend. diff --git a/docs/appfl.rst b/docs/appfl.rst index 87a66ad..9967a0d 100644 --- a/docs/appfl.rst +++ b/docs/appfl.rst @@ -51,4 +51,6 @@ HiveWatch understands common FL metadata such as: * ``lat``, ``lng``, ``city``, and ``country`` for map views Unknown keys are preserved in the client payload, so you can attach -framework-specific metadata without breaking the core schema. +framework-specific metadata without breaking the core schema. The bundled map +viewer displays non-geographic client metadata automatically. Prefix a key with +``_`` when it should be saved but hidden from the map card. diff --git a/docs/emitters.rst b/docs/emitters.rst index 60b3be1..8a846c0 100644 --- a/docs/emitters.rst +++ b/docs/emitters.rst @@ -26,6 +26,21 @@ This emitter writes: * ``runs/.jsonl`` with the full event stream * ``runs/.map.json`` with map-ready metadata +Set ``serve_map=False`` when training and dashboard serving should be separate. +This is useful for batch jobs, shared run directories, or environments where a +long-running dashboard process already exists. + +Server metadata, such as the server location or host label, can be recorded +with the runtime API: + +.. code-block:: python + + hw.set_server_metadata(city="Chicago", country="US", lat=41.7, lng=-87.9) + +Unknown client metadata is preserved in both artifacts. Fields prefixed with +``_`` remain available in the saved data but are hidden by the bundled map +viewer. 
+ Weights & Biases Emitter ======================== diff --git a/docs/index.rst b/docs/index.rst index 9d9d174..f80d788 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -10,7 +10,7 @@ forcing you into a single training framework. With HiveWatch you can: * log client and round metrics from custom training loops, -* stream and replay runs locally with the built-in map dashboard, +* stream, inspect, and replay runs locally with the built-in map dashboard, * send metrics to Weights & Biases, MLflow, or both, and * integrate observability into APPFL workflows with minimal glue code. @@ -55,8 +55,8 @@ HiveWatch is organized around a few clear technical components. Map Dashboard ^^^^^^^^^^^^^ - Visualizes client geography and run progress from saved ``.jsonl`` and - ``.map.json`` artifacts. + Visualizes client geography, training metadata, and replay progress + from saved ``.jsonl`` and ``.map.json`` artifacts. .. grid-item-card:: diff --git a/docs/map.rst b/docs/map.rst index 731360e..9705c86 100644 --- a/docs/map.rst +++ b/docs/map.rst @@ -3,7 +3,8 @@ Map Dashboard ============= HiveWatch ships with a local dashboard flow based on saved run artifacts and a -small HTTP server. +small HTTP server. The same viewer supports live monitoring and replay of +completed runs. Generate run artifacts ---------------------- @@ -27,16 +28,19 @@ Serve the dashboard hivewatch map run --runs-dir runs --port 7070 +Without ``--run-id``, the server watches the run directory and publishes new +JSONL events as they arrive. + Useful flags ------------ * ``--host`` changes the bind address. +* ``--run-id`` opens one saved run in static replay mode. * ``--map-path`` points at a custom HTML viewer. * ``--poll-interval`` controls how often the runs directory is rescanned. -The bundled viewer in the ``examples/`` directory loads map metadata first and -falls back to the raw event history for older runs. That means the same viewer -works for both live monitoring and deferred replay. 
+The bundled viewer loads ``.map.json`` metadata first and falls back to the raw +event history for older runs. Common map commands ------------------- @@ -46,6 +50,9 @@ Common map commands # Watch a runs directory for live updates hivewatch map run --runs-dir runs --port 7070 + # Open one completed run in static replay mode + hivewatch map run --runs-dir runs --run-id run-abc123 + # Serve a custom viewer HTML file hivewatch map run --runs-dir runs --map-path /path/to/viewer.html @@ -58,3 +65,27 @@ Deferred viewing Because HiveWatch persists map metadata separately from the raw JSONL log, you can train first and inspect later. This keeps a live workflow and a replay workflow compatible with the same dashboard interface. + +Client metadata in the map +-------------------------- + +The sidebar shows client fields from the run artifacts. Geo and system identity +fields such as ``lat``, ``lng``, ``city``, ``country``, ``ip``, and +``client_id`` are used internally or omitted from the card. Other scalar +metadata is displayed automatically. + +Use a leading underscore for metadata that should be preserved but hidden from +the bundled viewer: + +.. code-block:: python + + hw.log_client_update( + client_id="client-1", + round=round_num, + current_local_steps=200, + blocking=True, + _debug_score=0.92, + ) + +The map also includes a draggable event log. Drag it by the header to move it; +double-click the header to reset its saved position. diff --git a/docs/quickstart.rst b/docs/quickstart.rst index 2a9ec84..c7deff3 100644 --- a/docs/quickstart.rst +++ b/docs/quickstart.rst @@ -27,6 +27,7 @@ Minimal example local_accuracy=0.70 + round_num * 0.03, local_loss=0.90 - round_num * 0.08, num_samples=500, + current_local_steps=200, bytes_sent=8192, lat=41.88, lng=-87.63, @@ -46,6 +47,11 @@ This creates two run artifacts under ``runs/``: * ``.jsonl`` with the full event stream. * ``.map.json`` with map-ready metadata for replay and deferred viewing. 
+Client metadata that is not part of the core schema is preserved. The map +dashboard displays non-geographic client fields automatically; prefix a field +with ``_`` when it should remain in the artifact but stay hidden from the map +card, for example ``_debug_score``. + Serve the dashboard ------------------- diff --git a/src/hivewatch/cli.py b/src/hivewatch/cli.py index 717d08a..0a075cc 100644 --- a/src/hivewatch/cli.py +++ b/src/hivewatch/cli.py @@ -33,8 +33,9 @@ def main(argv=None): args = parser.parse_args(argv) if args.command == "map" and args.map_command == "run": - # Disable watch mode when loading a specific run - watch_mode = args.run_id + # Static run replay reads the selected JSONL directly; the directory + # watcher is only needed for live dashboards. + watch_mode = args.run_id is None server = MapServer( host=args.host, diff --git a/src/hivewatch/emitters/sse_emitter.py b/src/hivewatch/emitters/sse_emitter.py index 237bcb2..28d8129 100644 --- a/src/hivewatch/emitters/sse_emitter.py +++ b/src/hivewatch/emitters/sse_emitter.py @@ -308,20 +308,4 @@ def _start_server(self): @staticmethod def _client_dict(c: ClientUpdate) -> dict: - return { - "client_id": c.client_id, - "round": c.round, - "lat": c.lat, - "lng": c.lng, - "city": c.city, - "country": c.country, - "local_accuracy": c.local_accuracy, - "local_loss": c.local_loss, - "num_samples": c.num_samples, - "gradient_norm": c.gradient_norm, - "bytes_sent": c.bytes_sent, - "train_time_sec": c.train_time_sec, - "cpu_pct": c.cpu_pct, - "ram_mb": c.ram_mb, - "status": c.status, - } + return c.to_dict() diff --git a/src/hivewatch/map/hivewatch_map.html b/src/hivewatch/map/hivewatch_map.html index 4df0a93..7eb6cc1 100644 --- a/src/hivewatch/map/hivewatch_map.html +++ b/src/hivewatch/map/hivewatch_map.html @@ -70,7 +70,7 @@ .round-stat { background:var(--bg); border:1px solid var(--border); border-radius:6px; padding:8px 10px; } .round-stat .val { font-size:16px; font-weight:600; color:var(--text); 
line-height:1; } .round-stat .lbl { font-size:9px; color:var(--muted); margin-top:3px; } - .client-list { flex:1; overflow-y:auto; padding:8px; } + .client-list { flex:1; overflow-y:auto; padding:8px; min-height:0; } .client-list::-webkit-scrollbar { width:4px; } .client-list::-webkit-scrollbar-thumb { background:var(--border); border-radius:2px; } .client-card { padding:10px 12px; border-radius:8px; border:1px solid var(--border); margin-bottom:6px; cursor:pointer; transition:all 0.2s; background:var(--surface); } @@ -82,9 +82,9 @@ .client-badge.dropped { background:var(--red); } .client-location { font-size:10px; color:var(--muted); margin-bottom:5px; } .client-metrics { display:grid; grid-template-columns:1fr 1fr; gap:4px; } - .metric-chip { background:var(--bg); border:1px solid var(--border); border-radius:4px; padding:3px 6px; font-size:10px; } - .metric-chip .val { color:var(--text); font-weight:500; } - .metric-chip .lbl { color:var(--muted); font-size:9px; } + .metric-chip { background:var(--bg); border:1px solid var(--border); border-radius:4px; padding:5px 7px; font-size:10px; min-width:0; } + .metric-chip .val { color:var(--text); font-weight:600; line-height:1.25; overflow-wrap:anywhere; } + .metric-chip .lbl { color:var(--muted); font-size:9px; line-height:1.25; overflow-wrap:anywhere; } .map-area { flex:1; position:relative; display:flex; flex-direction:column; min-height:0; } #map { flex:1; z-index:1; min-height:0; } .leaflet-tile-pane { filter: var(--tile-filter); transition: filter 0.3s; } @@ -130,9 +130,12 @@ .speed-select { background:var(--bg); border:1px solid var(--border); border-radius:6px; padding:4px 8px; font-size:10px; font-family:'JetBrains Mono',monospace; color:var(--text); cursor:pointer; } .playback-mode { font-size:10px; color:var(--muted); flex-shrink:0; } .playback-mode.replay { color:var(--orange); font-weight:600; } - .log-panel { position:absolute; bottom:16px; right:16px; width:290px; background:var(--surface); border:1px 
solid var(--border); border-radius:10px; z-index:999; overflow:hidden; box-shadow:0 4px 20px rgba(0,0,0,0.12); opacity:0.95; } - .log-header { padding:8px 12px; border-bottom:1px solid var(--border); font-size:9px; letter-spacing:1px; text-transform:uppercase; color:var(--muted); display:flex; justify-content:space-between; } - .log-entries { max-height:110px; overflow-y:auto; padding:6px; } + .log-panel { background:var(--surface); border:1px solid var(--border); border-radius:10px; overflow:hidden; box-shadow:0 4px 20px rgba(0,0,0,0.12); opacity:0.95; } + .log-panel.floating { position:fixed; width:290px; max-width:calc(100vw - 24px); z-index:2000; } + .log-panel.dragging { opacity:0.88; } + .sidebar-log-panel { margin:0 12px 12px; flex-shrink:0; } + .log-header { padding:8px 12px; border-bottom:1px solid var(--border); font-size:9px; letter-spacing:1px; text-transform:uppercase; color:var(--muted); display:flex; justify-content:space-between; cursor:move; user-select:none; touch-action:none; } + .log-entries { max-height:180px; overflow-y:auto; padding:6px; } .log-entries::-webkit-scrollbar { width:3px; } .log-entries::-webkit-scrollbar-thumb { background:var(--border); } .log-entry { font-size:10px; padding:3px 6px; border-radius:4px; margin-bottom:2px; display:flex; gap:8px; animation:fadeIn 0.3s ease; } @@ -197,6 +200,13 @@
Waiting for clients…
+ @@ -209,14 +219,6 @@ Events received: 0 -
-
- Event Log - 0 events -
-
-
-
@@ -233,7 +235,7 @@ - +
live
@@ -433,6 +435,8 @@ frame: null, startedAt: 0, cycleMs: state.roundDurationMs, + replay: false, + endpointTimer: null, }; function accColor(acc) { @@ -546,6 +550,23 @@ cancelAnimationFrame(networkAnim.frame); networkAnim.frame = null; } + if (networkAnim.endpointTimer) { + clearTimeout(networkAnim.endpointTimer); + networkAnim.endpointTimer = null; + } +} + +function setReplayPacketsAtFinalState() { + networkAnim.endpointTimer = null; + Object.keys(lines).forEach(id => { + const client = state.clients[id]; + if (!packets[id]) return; + const packetLL = client ? getPacketClientLL(id, client) : SERVER_LL; + packets[id].uplink.setLatLng(SERVER_LL); + packets[id].downlink + .setStyle({ opacity: 1, fillOpacity: 0.95 }) + .setLatLng(packetLL); + }); } function animateNetworkTraffic(now) { @@ -558,26 +579,42 @@ return; } - const cycle = ((now - networkAnim.startedAt) % networkAnim.cycleMs) / networkAnim.cycleMs; - const uploadProgress = pingPong(cycle); - const downloadProgress = pingPong((cycle + 0.5) % 1); + const elapsed = now - networkAnim.startedAt; + const cycle = networkAnim.replay + ? Math.min(1, elapsed / networkAnim.cycleMs) + : ((elapsed % networkAnim.cycleMs) / networkAnim.cycleMs); + const uploadPhase = 0.46; + const downloadStart = 0.56; + const uploadProgress = networkAnim.replay ? Math.min(1, cycle / uploadPhase) : pingPong(cycle); + const downloadProgress = networkAnim.replay ? Math.max(0, (cycle - downloadStart) / (1 - downloadStart)) : pingPong((cycle + 0.5) % 1); + const showDownload = !networkAnim.replay || cycle >= downloadStart; activeIds.forEach(id => { const client = state.clients[id]; const packetLL = getPacketClientLL(id, client); ensureClientPackets(id, packetLL, accColor(client.local_accuracy)); packets[id].uplink.setLatLng(lerpLatLng(packetLL, SERVER_LL, uploadProgress)); - packets[id].downlink.setLatLng(lerpLatLng(SERVER_LL, packetLL, downloadProgress)); + packets[id].downlink + .setStyle({ opacity: showDownload ? 
1 : 0, fillOpacity: showDownload ? 0.95 : 0 }) + .setLatLng(lerpLatLng(SERVER_LL, packetLL, downloadProgress)); }); + if (networkAnim.replay && cycle >= 1) { + networkAnim.frame = null; + return; + } networkAnim.frame = requestAnimationFrame(animateNetworkTraffic); } -function syncNetworkAnimation(durationMs) { - networkAnim.cycleMs = Math.max(1200, durationMs || state.roundDurationMs || 4000); +function syncNetworkAnimation(durationMs, minDurationMs = 1200, replay = false) { + networkAnim.cycleMs = Math.max(minDurationMs, durationMs || state.roundDurationMs || 4000); networkAnim.startedAt = performance.now(); + networkAnim.replay = replay; stopNetworkAnimation(); networkAnim.frame = requestAnimationFrame(animateNetworkTraffic); + if (replay) { + networkAnim.endpointTimer = setTimeout(setReplayPacketsAtFinalState, networkAnim.cycleMs); + } } function upsertClient(id, data) { @@ -660,10 +697,7 @@
📍 ${c.city||"Unknown"}, ${c.country||""}
-
${fmtAcc(c.local_accuracy)}
accuracy
-
${fmtLoss(c.local_loss)}
loss
-
${c.num_samples||"—"}
samples
-
${c.gradient_norm!=null?parseFloat(c.gradient_norm).toFixed(2):"—"}
grad norm
+ ${renderClientMetrics(c)}
 `).join("") || @@ -692,6 +726,123 @@ if (lastEvt) document.getElementById("dbg-last").textContent = lastEvt; document.getElementById("dbg-count").textContent = state.evtCount; } +const HIDDEN_CLIENT_METADATA_KEYS = new Set([ + "client_id", "id", "lat", "lng", "latitude", "longitude", + "city", "country", "region", "timezone", "postal", "zip", + "ip", "org", "host", "hostname", "fallback", "status", +]); + +function escapeHtml(value) { + return String(value) + .replaceAll("&", "&amp;") + .replaceAll("<", "&lt;") + .replaceAll(">", "&gt;") + .replaceAll('"', "&quot;") + .replaceAll("'", "&#39;"); +} + +function shouldShowClientMetadata(key, value) { + if (!key || key.startsWith("_") || value == null) return false; + const lower = key.toLowerCase(); + if (HIDDEN_CLIENT_METADATA_KEYS.has(lower)) return false; + if (lower.endsWith("_lat") || lower.endsWith("_lng")) return false; + if (lower.includes("latitude") || lower.includes("longitude")) return false; + if (Array.isArray(value)) return value.length > 0; + if (typeof value === "object") return Object.keys(value).length > 0; + return true; +} + +function formatClientMetadataValue(value) { + if (typeof value === "boolean") return value ? "true" : "false"; + if (typeof value === "number") { + if (!Number.isFinite(value)) return "—"; + return Number.isInteger(value) ? String(value) : value.toFixed(4); + } + if (Array.isArray(value) || typeof value === "object") return JSON.stringify(value); + return String(value); +} + +function renderClientMetrics(client) { + return Object.entries(client) + .filter(([key, value]) => shouldShowClientMetadata(key, value)) + .map(([key, value]) => ` +
+
${escapeHtml(formatClientMetadataValue(value))}
+
${escapeHtml(key)}
+
+ `).join("") || + '
No visible metadata
'; +} +function initDraggableLogPanel() { + const panel = document.querySelector(".log-panel"); + const handle = panel?.querySelector(".log-header"); + if (!panel || !handle) return; + const storageKey = "hivewatch-log-panel-position"; + let drag = null; + + function clampPosition(left, top) { + const width = panel.offsetWidth || 290; + const height = panel.offsetHeight || 160; + return { + left: Math.max(8, Math.min(left, window.innerWidth - width - 8)), + top: Math.max(8, Math.min(top, window.innerHeight - height - 8)), + }; + } + + function floatPanel(left, top) { + const pos = clampPosition(left, top); + panel.classList.add("floating"); + panel.style.left = `${pos.left}px`; + panel.style.top = `${pos.top}px`; + panel.style.right = "auto"; + panel.style.bottom = "auto"; + return pos; + } + + try { + const saved = JSON.parse(localStorage.getItem(storageKey) || "null"); + if (saved) floatPanel(saved.left, saved.top); + } catch (_) {} + + handle.addEventListener("pointerdown", (event) => { + if (event.button !== 0) return; + const rect = panel.getBoundingClientRect(); + const pos = floatPanel(rect.left, rect.top); + drag = { + dx: event.clientX - pos.left, + dy: event.clientY - pos.top, + }; + panel.classList.add("dragging"); + handle.setPointerCapture(event.pointerId); + }); + + handle.addEventListener("pointermove", (event) => { + if (!drag) return; + const pos = floatPanel(event.clientX - drag.dx, event.clientY - drag.dy); + localStorage.setItem(storageKey, JSON.stringify(pos)); + }); + + function endDrag(event) { + if (!drag) return; + drag = null; + panel.classList.remove("dragging"); + try { handle.releasePointerCapture(event.pointerId); } catch (_) {} + } + + handle.addEventListener("pointerup", endDrag); + handle.addEventListener("pointercancel", endDrag); + handle.addEventListener("dblclick", () => { + localStorage.removeItem(storageKey); + panel.classList.remove("floating", "dragging"); + panel.removeAttribute("style"); + }); + 
window.addEventListener("resize", () => { + if (!panel.classList.contains("floating")) return; + const rect = panel.getBoundingClientRect(); + const pos = floatPanel(rect.left, rect.top); + localStorage.setItem(storageKey, JSON.stringify(pos)); + }); +} function handleLiveEvent(msg) { if (typeof msg === "string") { try { msg = JSON.parse(msg); } catch(e) { return; } @@ -812,7 +963,7 @@ if (snapshot.duration != null) state.roundDurationMs = Math.max(1200, snapshot.duration * 1000); snapshot.clients.forEach(c => upsertClient(c.client_id, c)); - syncNetworkAnimation(state.roundDurationMs); + syncNetworkAnimation(getReplayAnimationDuration(snapshot), 90, true); addLog( `Round ${snapshot.round} — acc=${fmtAcc(snapshot.globalAcc)} loss=${fmtLoss(snapshot.globalLoss)}`, @@ -967,7 +1118,16 @@ function getPlaybackDelay(snapshot) { const base = snapshot?.duration != null ? Math.max(250, snapshot.duration * 1000) : 1000; - return Math.max(50, Math.round(base * (pb.speed / 1000))); + return Math.max(120, Math.round(base * (pb.speed / 1000))); +} + +function getReplayAnimationDuration(snapshot) { + return Math.max(90, Math.round(getPlaybackDelay(snapshot) * 0.9)); +} + +function getCurrentPlaybackSnapshot() { + if (!pb.rounds.length) return null; + return pb.rounds[Math.max(0, Math.min(pb.index - 1, pb.rounds.length - 1))]; } function startPlayback() { @@ -1039,7 +1199,18 @@ updateProgress(); } -function setSpeed(val) { pb.speed = parseInt(val); } +function setSpeed(val) { + const nextSpeed = parseInt(val, 10); + if (!Number.isFinite(nextSpeed)) return; + pb.speed = nextSpeed; + if (!pb.playing) return; + + const snapshot = getCurrentPlaybackSnapshot(); + const delay = getPlaybackDelay(snapshot); + if (pb.timer) clearTimeout(pb.timer); + syncNetworkAnimation(getReplayAnimationDuration(snapshot), 90, true); + pb.timer = setTimeout(stepPlayback, delay); +} function connect() { const es = new EventSource(SSE_URL); @@ -1119,6 +1290,7 @@ } }; } +initDraggableLogPanel(); if 
(METADATA_URL || EVENTS_URL) { loadRunFromSource({ metadataUrl: METADATA_URL, diff --git a/src/hivewatch/map/server.py b/src/hivewatch/map/server.py index 0860f48..0b5d50a 100644 --- a/src/hivewatch/map/server.py +++ b/src/hivewatch/map/server.py @@ -94,6 +94,9 @@ def _serve_map(self): content = candidate.read_bytes() self.send_response(200) self.send_header("Content-Type", "text/html; charset=utf-8") + self.send_header("Cache-Control", "no-store, no-cache, must-revalidate, max-age=0") + self.send_header("Pragma", "no-cache") + self.send_header("Expires", "0") self._cors() self.end_headers() self.wfile.write(content) @@ -290,6 +293,10 @@ def stop(self): self._server.shutdown() self._server.server_close() self._server = None + if self._watch_thread is not None and self._watch_thread.is_alive(): + self._watch_thread.join(timeout=max(1.0, self.poll_interval * 2)) + self._thread = None + self._watch_thread = None def publish(self, payload: dict): if payload.get("run_id"): @@ -330,3 +337,37 @@ def read_events(self, jsonl_path: Path) -> List[dict]: except json.JSONDecodeError as exc: logger.warning("[hivewatch/map] could not parse %s: %s", jsonl_path, exc) return events + + def _watch_paths(self) -> List[Path]: + if self._fixed_run_id and self._live_run_id: + return [self.runs_dir / f"{self._live_run_id}.jsonl"] + return sorted(self.runs_dir.glob("*.jsonl")) + + def _prime_watch_offsets(self): + for jsonl_path in self._watch_paths(): + if jsonl_path.exists(): + self._seen_offsets[jsonl_path] = jsonl_path.stat().st_size + + def _watch_loop(self): + while self._server is not None: + for jsonl_path in self._watch_paths(): + if not jsonl_path.exists(): + continue + + offset = self._seen_offsets.get(jsonl_path, 0) + try: + with jsonl_path.open("r", encoding="utf-8") as handle: + handle.seek(offset) + for line in handle: + line = line.strip() + if not line: + continue + try: + self.publish(json.loads(line)) + except json.JSONDecodeError as exc: + 
logger.warning("[hivewatch/map] could not parse %s: %s", jsonl_path, exc) + self._seen_offsets[jsonl_path] = handle.tell() + except OSError as exc: + logger.warning("[hivewatch/map] could not watch %s: %s", jsonl_path, exc) + + time.sleep(self.poll_interval) diff --git a/tests/cli_test.py b/tests/cli_test.py new file mode 100644 index 0000000..d427ba8 --- /dev/null +++ b/tests/cli_test.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +from hivewatch import cli + + +class FakeMapServer: + instances = [] + + def __init__(self, **kwargs): + self.kwargs = kwargs + self.started = False + self.stopped = False + FakeMapServer.instances.append(self) + + def start(self): + self.started = True + + def serve_forever(self): + return None + + def stop(self): + self.stopped = True + + +def test_map_run_cli_watches_live_directory_when_no_run_id(monkeypatch): + FakeMapServer.instances = [] + monkeypatch.setattr(cli, "MapServer", FakeMapServer) + + assert cli.main(["map", "run", "--runs-dir", "runs"]) == 0 + + server = FakeMapServer.instances[0] + assert server.started is True + assert server.kwargs["watch"] is True + assert server.kwargs["run_id"] is None + + +def test_map_run_cli_disables_watcher_for_static_run(monkeypatch): + FakeMapServer.instances = [] + monkeypatch.setattr(cli, "MapServer", FakeMapServer) + + assert cli.main(["map", "run", "--runs-dir", "runs", "--run-id", "run-abc"]) == 0 + + server = FakeMapServer.instances[0] + assert server.kwargs["watch"] is False + assert server.kwargs["run_id"] == "run-abc" diff --git a/tests/map_metadata_test.py b/tests/map_metadata_test.py new file mode 100644 index 0000000..32d814a --- /dev/null +++ b/tests/map_metadata_test.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +from hivewatch.map.metadata import build_map_metadata_from_events, build_rounds_from_events + + +def test_build_rounds_merges_client_updates_and_status_events(): + events = [ + { + "event_type": "client_update", + "round": 2, + "clients": [ + { 
+ "client_id": "client-a", + "local_accuracy": 0.6, + "current_local_steps": 100, + "_hidden_acc": 0.1, + } + ], + }, + { + "event_type": "round_end", + "round": 2, + "round_metrics": { + "global_accuracy": 0.75, + "global_loss": 0.4, + "round_duration_sec": 3.2, + "gradient_divergence": 0.12, + }, + "clients": [ + { + "client_id": "client-a", + "local_loss": 0.25, + "current_local_steps": None, + }, + {"client_id": "client-b", "local_accuracy": 0.7}, + ], + }, + {"event_type": "comm_failure", "round": 3, "client_id": "client-b"}, + ] + + rounds = build_rounds_from_events(events) + + assert [round_state["round"] for round_state in rounds] == [2, 3] + round_two = rounds[0] + assert round_two["globalAcc"] == 0.75 + assert round_two["globalLoss"] == 0.4 + assert round_two["duration"] == 3.2 + assert round_two["divergence"] == 0.12 + + clients = {client["client_id"]: client for client in round_two["clients"]} + assert clients["client-a"]["local_accuracy"] == 0.6 + assert clients["client-a"]["local_loss"] == 0.25 + assert clients["client-a"]["current_local_steps"] == 100 + assert clients["client-a"]["_hidden_acc"] == 0.1 + assert clients["client-b"]["local_accuracy"] == 0.7 + + assert rounds[1]["clients"] == [{"client_id": "client-b", "status": "failed"}] + + +def test_build_map_metadata_includes_run_server_rounds_and_finish_time(): + events = [ + { + "event_type": "init", + "run_id": "run-1234", + "algorithm": "FedAvg", + "config": {"epochs": 2}, + "started_at": "2026-01-01T00:00:00+00:00", + }, + { + "event_type": "server_metadata", + "server": {"lat": 41.7, "lng": -87.9, "city": "Chicago"}, + }, + {"event_type": "round_end", "round": 1, "round_metrics": {}, "clients": []}, + {"event_type": "finished", "timestamp": "2026-01-01T00:01:00+00:00"}, + ] + + metadata = build_map_metadata_from_events(events) + + assert metadata["schema_version"] == 1 + assert metadata["run_id"] == "run-1234" + assert metadata["algorithm"] == "FedAvg" + assert metadata["config"] == {"epochs": 
2} + assert metadata["server"] == {"lat": 41.7, "lng": -87.9, "city": "Chicago"} + assert metadata["finished_at"] == "2026-01-01T00:01:00+00:00" + assert metadata["rounds"][0]["round"] == 1 diff --git a/tests/map_server_test.py b/tests/map_server_test.py new file mode 100644 index 0000000..b00aa1d --- /dev/null +++ b/tests/map_server_test.py @@ -0,0 +1,132 @@ +from __future__ import annotations + +import json +import queue +import urllib.error +import urllib.request + +import pytest + +from hivewatch.map.server import MapServer + + +def write_jsonl(path, events): + path.write_text( + "".join(json.dumps(event) + "\n" for event in events), + encoding="utf-8", + ) + + +def get_json(base_url: str, path: str): + with urllib.request.urlopen(base_url + path, timeout=2) as response: + return json.loads(response.read().decode("utf-8")) + + +def get_text_response(base_url: str, path: str): + response = urllib.request.urlopen(base_url + path, timeout=2) + try: + return response, response.read().decode("utf-8") + except Exception: + response.close() + raise + + +@pytest.fixture +def running_server(tmp_path): + servers = [] + + def start(**kwargs): + server = MapServer( + host="127.0.0.1", + port=0, + runs_dir=str(tmp_path), + **kwargs, + ) + server.start() + servers.append(server) + port = server._server.server_address[1] + return server, f"http://127.0.0.1:{port}" + + yield start + + for server in servers: + server.stop() + + +def test_map_server_serves_runs_events_metadata_and_map_file(tmp_path, running_server): + events = [ + { + "event_type": "init", + "run_id": "run-abc", + "algorithm": "FedAvg", + "config": {"epochs": 1}, + "started_at": "2026-01-01T00:00:00+00:00", + }, + { + "event_type": "round_end", + "round": 1, + "round_metrics": {"global_accuracy": 0.9}, + "clients": [{"client_id": "client-1", "local_accuracy": 0.8}], + }, + ] + write_jsonl(tmp_path / "run-abc.jsonl", events) + map_file = tmp_path / "viewer.html" + map_file.write_text("custom map", 
encoding="utf-8") + + _, base_url = running_server(map_path=str(map_file)) + + runs = get_json(base_url, "/runs") + assert runs == [ + { + "run_id": "run-abc", + "algorithm": "FedAvg", + "started_at": "2026-01-01T00:00:00+00:00", + "num_events": 2, + "file": "run-abc.jsonl", + "metadata_file": "run-abc.map.json", + "has_metadata": False, + } + ] + assert get_json(base_url, "/runs/run-abc/events") == events + metadata = get_json(base_url, "/runs/run-abc/metadata") + assert metadata["run_id"] == "run-abc" + assert metadata["rounds"][0]["clients"][0]["client_id"] == "client-1" + + response, body = get_text_response(base_url, "/map") + assert body == "custom map" + assert response.headers["Cache-Control"].startswith("no-store") + assert response.headers["Access-Control-Allow-Origin"] == "*" + + +def test_map_server_prefers_prebuilt_metadata_file(tmp_path, running_server): + write_jsonl( + tmp_path / "run-abc.jsonl", + [{"event_type": "init", "run_id": "run-abc", "algorithm": "FedAvg"}], + ) + metadata = {"schema_version": 99, "rounds": [{"round": 7}]} + (tmp_path / "run-abc.map.json").write_text(json.dumps(metadata), encoding="utf-8") + + _, base_url = running_server() + + assert get_json(base_url, "/runs/run-abc/metadata") == metadata + + +def test_map_server_returns_404_for_missing_run(tmp_path, running_server): + _, base_url = running_server() + + with pytest.raises(urllib.error.HTTPError) as excinfo: + urllib.request.urlopen(base_url + "/runs/missing/events", timeout=2) + + assert excinfo.value.code == 404 + + +def test_map_server_watch_mode_publishes_new_jsonl_events(tmp_path, running_server): + server, _ = running_server(watch=True, poll_interval=0.01) + subscriber = server._subscribe() + event = {"event_type": "init", "run_id": "run-live", "algorithm": "FedAvg"} + + write_jsonl(tmp_path / "run-live.jsonl", [event]) + + message = subscriber.get(timeout=2) + assert message == f"data: {json.dumps(event)}\n\n" + assert server._live_run_id == "run-live" diff --git 
a/tests/schema_test.py b/tests/schema_test.py new file mode 100644 index 0000000..61c2bdc --- /dev/null +++ b/tests/schema_test.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from hivewatch.schema import ClientUpdate + + +def test_client_update_preserves_unknown_metadata_fields(): + client = ClientUpdate.from_dict( + { + "client_id": "client-1", + "round": 4, + "local_accuracy": 0.8, + "current_local_steps": 200, + "blocking": True, + "_hidden_acc": 0.99, + } + ) + + assert client.extra == { + "current_local_steps": 200, + "blocking": True, + "_hidden_acc": 0.99, + } + assert client.to_dict()["current_local_steps"] == 200 + assert client.to_dict()["blocking"] is True + assert client.to_dict()["_hidden_acc"] == 0.99 + + +def test_client_update_staleness_is_derived_when_base_round_is_known(): + client = ClientUpdate(client_id="client-1", round=5, base_round=2) + + assert client.staleness == 3 diff --git a/tests/sse_emitter_test.py b/tests/sse_emitter_test.py new file mode 100644 index 0000000..146530e --- /dev/null +++ b/tests/sse_emitter_test.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +import json + +from hivewatch.emitters.sse_emitter import SSEEmitter +from hivewatch.schema import ClientUpdate, RoundSummary + + +def read_jsonl(path): + return [json.loads(line) for line in path.read_text(encoding="utf-8").splitlines()] + + +def test_sse_emitter_persists_jsonl_and_map_metadata_with_custom_client_fields(tmp_path): + emitter = SSEEmitter(runs_dir=str(tmp_path), serve_map=False) + emitter.on_init("run-custom", "FedAvg", {"epochs": 3}) + emitter.on_server_metadata({"lat": 41.7, "lng": -87.9, "city": "Chicago"}) + + client = ClientUpdate.from_dict( + { + "client_id": "client-1", + "round": 1, + "local_accuracy": 0.71, + "local_loss": 0.33, + "num_samples": 64, + "lat": 1.3, + "lng": 103.8, + "city": "Singapore", + "country": "SG", + "current_local_steps": 200, + "blocking": True, + "_hidden_acc": 0.99, + } + ) + emitter.on_client_update(client) 
+ emitter.on_round( + RoundSummary( + round=1, + global_accuracy=0.8, + global_loss=0.2, + num_selected=1, + num_completed=1, + round_duration_sec=2.5, + gradient_divergence=0.0, + ), + [client], + ) + emitter.finish() + + events = read_jsonl(tmp_path / "run-custom.jsonl") + assert [event["event_type"] for event in events] == [ + "init", + "server_metadata", + "client_update", + "round_end", + "finished", + ] + client_update = events[2]["clients"][0] + assert client_update["current_local_steps"] == 200 + assert client_update["blocking"] is True + assert client_update["_hidden_acc"] == 0.99 + + metadata = json.loads((tmp_path / "run-custom.map.json").read_text(encoding="utf-8")) + assert metadata["server"]["city"] == "Chicago" + assert metadata["finished_at"] == events[-1]["timestamp"] + round_state = metadata["rounds"][0] + assert round_state["globalAcc"] == 0.8 + assert round_state["duration"] == 2.5 + map_client = round_state["clients"][0] + assert map_client["current_local_steps"] == 200 + assert map_client["blocking"] is True + assert map_client["_hidden_acc"] == 0.99 + + +def test_sse_emitter_merges_dropout_and_comm_failure_into_map_metadata(tmp_path): + emitter = SSEEmitter(runs_dir=str(tmp_path), serve_map=False) + emitter.on_init("run-status", "FedAvg", {}) + + emitter.on_dropout(2, "client-drop", "timeout") + emitter.on_comm_failure(2, "client-fail", "socket closed") + emitter.finish() + + metadata = json.loads((tmp_path / "run-status.map.json").read_text(encoding="utf-8")) + clients = { + client["client_id"]: client + for client in metadata["rounds"][0]["clients"] + } + assert clients["client-drop"]["status"] == "dropped" + assert clients["client-fail"]["status"] == "failed"
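For readers who want to see the watch-mode idea in isolation, here is a minimal sketch of offset-based JSONL tailing, similar in spirit to the `_watch_loop` added in `src/hivewatch/map/server.py`. It is not the project's implementation: the function name `read_new_events` and the plain `offsets` dictionary are hypothetical, and unlike the diff it reads in binary mode and leaves an incomplete trailing line for the next poll.

```python
import json
from pathlib import Path
from typing import Dict, List


def read_new_events(path: Path, offsets: Dict[Path, int]) -> List[dict]:
    """Return events appended to `path` since the byte offset stored in `offsets`.

    Hypothetical helper for illustration only; it is not part of hivewatch.
    """
    events: List[dict] = []
    offset = offsets.get(path, 0)
    with path.open("rb") as handle:
        handle.seek(offset)
        for raw in handle:
            if not raw.endswith(b"\n"):
                # The writer has not finished this line yet; retry on the next poll.
                break
            offset += len(raw)
            try:
                text = raw.decode("utf-8").strip()
                if text:
                    events.append(json.loads(text))
            except (UnicodeDecodeError, json.JSONDecodeError):
                # Skip a malformed line instead of stalling the tail.
                continue
    # Remember how far we got so the next call only sees new lines.
    offsets[path] = offset
    return events
```

A caller would invoke this once per poll interval for each `*.jsonl` file in the runs directory and forward the returned events to subscribers. Tracking byte offsets rather than line counts keeps each poll proportional to the amount of new data, not to the size of the whole run log.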