diff --git a/claude/lightcone/hooks.json b/claude/lightcone/hooks.json index 46eccfd9..1560e86a 100644 --- a/claude/lightcone/hooks.json +++ b/claude/lightcone/hooks.json @@ -15,6 +15,17 @@ ] } ], + "SessionEnd": [ + { + "hooks": [ + { + "type": "command", + "command": "bash ${CLAUDE_PROJECT_DIR}/.claude/scripts/session-end.sh", + "timeout": 10 + } + ] + } + ], "PostToolUse": [ { "matcher": "Write|Edit", diff --git a/claude/lightcone/scripts/session-end.sh b/claude/lightcone/scripts/session-end.sh new file mode 100755 index 00000000..2c0a44d7 --- /dev/null +++ b/claude/lightcone/scripts/session-end.sh @@ -0,0 +1,17 @@ +#!/bin/bash +# SessionEnd hook: stop the project's session-scoped Dask scheduler. +# +# Best-effort and silent. The scheduler self-shuts on idle-timeout +# (see lightcone.engine.dask_daemon) so failure to fire here only +# delays cleanup; it does not leak resources indefinitely. + +input=$(cat) +cwd=$(echo "$input" | jq -r '.cwd // empty') + +[ -z "$cwd" ] && exit 0 +cd "$cwd" 2>/dev/null || exit 0 +[ -f "astra.yaml" ] || exit 0 +command -v lc &>/dev/null || exit 0 + +lc dask stop >/dev/null 2>&1 || true +exit 0 diff --git a/src/lightcone/cli/commands.py b/src/lightcone/cli/commands.py index 1de8bd78..d994d57a 100644 --- a/src/lightcone/cli/commands.py +++ b/src/lightcone/cli/commands.py @@ -373,9 +373,10 @@ def run( ) -> None: """Materialize outputs declared in astra.yaml. - Always dispatches through a Dask cluster: a ``LocalCluster`` on a - workstation, srun-launched workers inside a SLURM allocation, or an - existing scheduler if ``DASK_SCHEDULER_ADDRESS`` is set. + Always dispatches through a Dask cluster: the session-scoped + scheduler (spawned on first run, reused thereafter), or an existing + one if ``DASK_SCHEDULER_ADDRESS`` is set. See + :mod:`lightcone.engine.dask_daemon` for the lifecycle. 
""" _abort_on_perlmutter_login() @@ -474,9 +475,7 @@ def run( except RunLockBusyError as e: raise click.ClickException(str(e)) - with cluster_for_run( - verbose=verbose, local_directory=str(rundirs.dask_local) - ) as scheduler_addr: + with cluster_for_run(project_path=project, verbose=verbose) as scheduler_addr: env = { **os.environ, "DASK_SCHEDULER_ADDRESS": scheduler_addr, @@ -817,6 +816,32 @@ def _ensure_images(project: Path, *, runtime: str, force: bool = False) -> None: raise click.ClickException(str(e)) +# ============================================================================= +# lc dask +# ============================================================================= + + +@main.group() +def dask() -> None: + """Manage the session-scoped Dask scheduler.""" + + +@dask.command("stop") +def dask_stop() -> None: + """Shut down the session-scoped Dask scheduler for this project. + + Best-effort: silent when no scheduler is running. Wired to the + SessionEnd Claude Code hook so closing a session frees the + scheduler's resources promptly; otherwise the scheduler self-shuts + after its idle timeout. + """ + from lightcone.engine.dask_daemon import stop + + project = _project_root() + if stop(project): + console.print("[dim]Sent SIGTERM to Dask scheduler.[/dim]") + + # Register eval subgroup (requires optional 'eval' extra) try: from lightcone.eval.cli import eval_group diff --git a/src/lightcone/engine/dask_cluster.py b/src/lightcone/engine/dask_cluster.py index 03049e5b..047c3539 100644 --- a/src/lightcone/engine/dask_cluster.py +++ b/src/lightcone/engine/dask_cluster.py @@ -1,35 +1,33 @@ # mypy: disable-error-code="no-untyped-call" -"""Cluster lifecycle for ``lc run``. +"""Cluster connection point for ``lc run``. -One context manager, three branches: +Two branches: -- ``DASK_SCHEDULER_ADDRESS`` is already set → yield it as-is. We don't own - the cluster, so we don't tear it down. 
-- ``SLURM_JOB_ID`` is set → start an in-process scheduler via - ``LocalCluster(n_workers=0)``, then ``srun`` one ``dask worker`` per node - across the allocation. Workers advertise the node's full resources; - per-rule ``threads`` / ``mem_mb`` / ``gpus`` map to per-task constraints. -- Neither → ``LocalCluster()`` sized to the local machine. +- ``DASK_SCHEDULER_ADDRESS`` set → use it as-is. We don't own the + cluster, so we don't tear it down. Kept as the escape hatch / CI + override / "user has their own scheduler" path. +- Otherwise → :func:`lightcone.engine.dask_daemon.ensure_scheduler` + returns the address of a session-scoped scheduler (spawning the + daemon if needed). The daemon outlives any single ``lc run`` so + successive runs in the same Claude session reuse it. -The scheduler is always in-process (driven by ``lc run`` itself) so its -lifetime equals the run's lifetime — no service to manage, no orphaned -schedulers if the driver crashes. +The node-shape helpers and resource-key constants live here because +both the daemon (when it spins up a cluster) and the executor plugin +(when it requests per-task resources) consume them. """ from __future__ import annotations -import logging import os -import shutil -import socket -import subprocess from collections.abc import Iterator from contextlib import contextmanager from dataclasses import dataclass +from pathlib import Path # Resource keys advertised by workers and requested per-task. These strings -# form a contract between the worker bootstrap (here) and the executor plugin -# (snakemake_executor_plugin_dask.executor). Dask matches by string equality. +# form a contract between the worker bootstrap (in :mod:`dask_daemon`) and +# the executor plugin (snakemake_executor_plugin_dask.executor). Dask +# matches by string equality. 
RESOURCE_CPUS = "cpus" RESOURCE_MEMORY = "memory" RESOURCE_GPUS = "gpus" @@ -67,9 +65,9 @@ def _resource_dict(shape: _NodeShape) -> dict[str, float]: """Resource keys advertised by a worker for this node shape. Single source of truth for which keys workers expose — both the - in-process LocalCluster and the srun-launched ``dask worker``s - advertise the same set so the executor's per-task requests resolve - on either path. + laptop ``LocalCluster`` and the srun-launched ``dask worker``s + advertise the same set so the executor's per-task requests + resolve on either path. """ res: dict[str, float] = {RESOURCE_CPUS: float(shape.cpus)} if shape.mem_bytes: @@ -80,23 +78,21 @@ def _resource_dict(shape: _NodeShape) -> dict[str, float]: def _resources_arg(shape: _NodeShape) -> str: - """Format `--resources` for `dask worker`.""" + """Format ``--resources`` for ``dask worker``.""" return " ".join(f"{k}={int(v)}" for k, v in _resource_dict(shape).items()) @contextmanager def cluster_for_run( *, + project_path: Path, verbose: bool = False, - local_directory: str | None = None, ) -> Iterator[str]: - """Yield a Dask scheduler address valid for the duration of `lc run`. + """Yield a Dask scheduler address for the duration of one ``lc run``. - *local_directory*, when given, is where dask workers stage their - spilled task data and internal state files. ``lc run`` resolves it - to a path under :mod:`lightcone.engine.scratch` so on NERSC the - spill lands on Lustre instead of DVS-mounted home/CFS (where small- - file I/O is slow and can pressure the gateway nodes). + The scheduler outlives the run — see :mod:`lightcone.engine.dask_daemon`. + No teardown happens here; the daemon self-shuts on idle-timeout, or + on SIGTERM from the SessionEnd hook. 
""" if addr := os.environ.get("DASK_SCHEDULER_ADDRESS"): if verbose: @@ -104,149 +100,9 @@ def cluster_for_run( yield addr return - if "SLURM_JOB_ID" in os.environ: - with _slurm_backed_cluster( - verbose=verbose, local_directory=local_directory - ) as addr: - yield addr - return - - with _local_cluster( - verbose=verbose, local_directory=local_directory - ) as addr: - yield addr - - -@contextmanager -def _local_cluster( - *, verbose: bool, local_directory: str | None -) -> Iterator[str]: - from dask.distributed import LocalCluster - - shape = _detect_node_shape() - # Workers must advertise every key the executor may request — Dask - # matches by exact key presence — or rules with ``mem_mb`` / - # ``gpus_per_task`` would never schedule on a workstation. - cluster = LocalCluster( - n_workers=1, - threads_per_worker=shape.cpus, - resources=_resource_dict(shape), - dashboard_address=":0", - local_directory=local_directory, - silence_logs=logging.INFO if verbose else logging.WARNING, - ) - if verbose: - print( - f"→ Local Dask cluster ({shape.cpus} threads); " - f"scheduler at {cluster.scheduler_address}" - ) - try: - yield cluster.scheduler_address - finally: - cluster.close() - - -@contextmanager -def _slurm_backed_cluster( - *, verbose: bool, local_directory: str | None -) -> Iterator[str]: - from dask.distributed import LocalCluster - - if shutil.which("dask") is None: - raise RuntimeError( - "`dask` CLI is not on PATH inside the SLURM allocation. " - "Install lightcone-cli (and its `distributed` dep) into the " - "environment activated by your sbatch/salloc." - ) - - shape = _detect_node_shape() - nnodes = int(os.environ.get("SLURM_NNODES") or 1) - - # Default LocalCluster binds the scheduler to 127.0.0.1, which workers - # on remote nodes cannot reach. Bind to the driver's hostname so srun- - # launched workers across the allocation can connect. SLURMD_NODENAME - # is the SLURM-canonical name; gethostname() is a sane fallback. 
- scheduler_host = os.environ.get("SLURMD_NODENAME") or socket.gethostname() - cluster = LocalCluster( - n_workers=0, - host=scheduler_host, - dashboard_address=":0", - local_directory=local_directory, - silence_logs=logging.INFO if verbose else logging.WARNING, - ) - addr = cluster.scheduler_address + from lightcone.engine.dask_daemon import ensure_scheduler + addr = ensure_scheduler(project_path) if verbose: - print( - f"→ SLURM allocation detected ({nnodes} node(s), " - f"{shape.cpus} cpu/node, {shape.gpus} gpu/node); " - f"launching workers via srun. Scheduler: {addr}" - ) - - worker_cmd = [ - "srun", - f"--ntasks={nnodes}", - "--ntasks-per-node=1", - "dask", - "worker", - addr, - "--nthreads", - str(shape.cpus), - "--nworkers", - "1", - "--resources", - _resources_arg(shape), - "--no-dashboard", - # Each srun task is a single run-scoped worker; an auto-restart - # nanny adds no value (srun won't relaunch the task either) and - # logs "Worker process died unexpectedly" when retire_workers - # asks the worker to exit on shutdown. - "--no-nanny", - ] - if local_directory: - worker_cmd.extend(["--local-directory", local_directory]) - # Hide the worker's INFO-level connection chatter (Nanny start, - # scheduler registration, etc.) — useful only when debugging the - # cluster itself. WARNING+ still surface real issues. The newer - # `dask worker` CLI dropped `--silence-logs`, so we drive it via - # Dask's config env var instead; srun inherits env by default. 
- worker_env = dict(os.environ) - if not verbose: - worker_env.setdefault("DASK_LOGGING__DISTRIBUTED", "warning") - workers = subprocess.Popen(worker_cmd, env=worker_env) - - try: - from dask.distributed import Client - - client = Client(addr) - try: - client.wait_for_workers(n_workers=nnodes, timeout=120) - if verbose: - print(f"→ {nnodes} dask worker(s) registered.") - finally: - client.close() - yield addr - finally: - # Graceful shutdown: ask the scheduler to retire workers so each - # `dask worker` process exits on its own. srun then sees its task - # exit with code 0 and terminates silently. SIGTERM-ing srun - # directly (the prior path) prints "srun: forcing job - # termination" / "task 0: Killed" to stderr on every clean run. - try: - client = Client(addr, timeout="10s") - try: - client.retire_workers(close_workers=True, remove=True) - finally: - client.close() - except Exception: - pass - try: - workers.wait(timeout=20) - except subprocess.TimeoutExpired: - workers.terminate() - try: - workers.wait(timeout=10) - except subprocess.TimeoutExpired: - workers.kill() - workers.wait() - cluster.close() + print(f"→ Reusing session scheduler at {addr}") + yield addr diff --git a/src/lightcone/engine/dask_daemon.py b/src/lightcone/engine/dask_daemon.py new file mode 100644 index 00000000..474f9ed0 --- /dev/null +++ b/src/lightcone/engine/dask_daemon.py @@ -0,0 +1,462 @@ +# mypy: disable-error-code="no-untyped-call" +"""Session-scoped Dask scheduler. + +One scheduler per *execution context*. On a workstation that's the +project; inside a SLURM allocation it's the allocation (so workers +spawned via ``srun`` are reused across every ``lc run`` in the +allocation, not respawned each time). 
The key: + + slurm-<jobid> if SLURM_JOB_ID is set + <project_hash> otherwise + +Storage is per-key under the resolved scratch root:: + + <scratch>/.lightcone/dask-scheduler/<key>/ + ├── owner.lock # flock'd by the daemon for its lifetime + ├── spawn.lock # serializes concurrent ensure() racers + ├── scheduler.json # Dask's native scheduler-file (address, …) + ├── meta.json # {pid, host, started_at, mode, …} + ├── scheduler.log # daemon stdout+stderr (detached) + └── spill/ # worker spill / local-directory + +Crash safety rests on a single primitive: ``flock`` is released by the +kernel when the holding process dies (clean exit, crash, or SIGKILL). +Liveness is therefore probed by trying to acquire the lock — never by +PID file or heartbeat. + +If everything else fails, the scheduler self-shuts after ``IDLE_TIMEOUT`` +of inactivity (Dask's built-in ``Scheduler.idle_timeout``); a stale +``scheduler.json`` from a SIGKILL'd daemon is detected by ``ensure``'s +TCP probe and replaced. The SessionEnd hook calls :func:`stop` for +prompt cleanup; idle-timeout is the safety net. +""" +from __future__ import annotations + +import argparse +import fcntl +import json +import logging +import os +import shutil +import signal +import socket +import subprocess +import sys +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any +from urllib.parse import urlparse + +from lightcone.engine.scratch import project_hash, resolve_scratch_root + +#: How long the scheduler tolerates being idle before self-shutting. +#: Tuned to span "user steps away mid-conversation" without lingering +#: forever on an abandoned project. +IDLE_TIMEOUT = "30 minutes" + +#: Cap on how long ``ensure_scheduler`` waits for the daemon to come up +#: and write ``scheduler.json``. Local boot is sub-second; SLURM boot +#: includes ``srun`` worker registration which dominates. +SPAWN_WAIT_SECONDS = 60 + +#: Probe budget for an existing scheduler.
Two seconds is generous for +#: TCP connect on localhost or an HPC fabric, and short enough that a +#: dead-but-still-listed scheduler doesn't stall ``lc run``. +PROBE_TIMEOUT_SECONDS = 2.0 + + +@dataclass(frozen=True) +class SchedulerDirs: + """Per-key paths for a session-scoped scheduler.""" + + root: Path + owner_lock: Path + spawn_lock: Path + scheduler_file: Path + meta_file: Path + log_file: Path + spill: Path + + +def scheduler_key(project_path: Path) -> str: + """Identity for the scheduler's lifecycle scope. + + Inside a SLURM allocation every ``lc run`` for any project shares + one scheduler keyed by ``SLURM_JOB_ID`` — workers spawned via + ``srun`` are tied to the allocation and outlive a single run. On a + laptop the natural unit is the project. + """ + if jid := os.environ.get("SLURM_JOB_ID"): + return f"slurm-{jid}" + return project_hash(project_path) + + +def scheduler_dirs(project_path: Path) -> SchedulerDirs: + """Resolve and create the per-key scheduler directory.""" + root = ( + resolve_scratch_root(project_path) + / ".lightcone" + / "dask-scheduler" + / scheduler_key(project_path) + ) + spill = root / "spill" + for d in (root, spill): + d.mkdir(parents=True, exist_ok=True) + return SchedulerDirs( + root=root, + owner_lock=root / "owner.lock", + spawn_lock=root / "spawn.lock", + scheduler_file=root / "scheduler.json", + meta_file=root / "meta.json", + log_file=root / "scheduler.log", + spill=spill, + ) + + +# --------------------------------------------------------------------------- +# Public API: ensure / stop +# --------------------------------------------------------------------------- + + +def ensure_scheduler(project_path: Path) -> str: + """Return the address of a live session-scoped scheduler. + + Connects to an existing one if present, otherwise spawns a detached + daemon and waits for it to come up. Idempotent and concurrent-safe: + multiple callers race through ``spawn.lock`` and converge on the + same scheduler. 
+ """ + dirs = scheduler_dirs(project_path) + + if (addr := _read_address(dirs.scheduler_file)) and _probe(addr): + return addr + + # Slow path. Serialize spawn races on a separate flock so a + # second ensure() doesn't double-spawn while the first is still + # waiting for scheduler.json. + spawn_fd = os.open(dirs.spawn_lock, os.O_RDWR | os.O_CREAT, 0o644) + try: + fcntl.flock(spawn_fd, fcntl.LOCK_EX) + + # Re-probe: another caller may have spawned while we waited. + if (addr := _read_address(dirs.scheduler_file)) and _probe(addr): + return addr + + # Stale scheduler.json (daemon died) — clean before spawning so + # a partial-state read can't return a dead address. + for f in (dirs.scheduler_file, dirs.meta_file): + f.unlink(missing_ok=True) + + _spawn_daemon(project_path, dirs) + + deadline = time.monotonic() + SPAWN_WAIT_SECONDS + while time.monotonic() < deadline: + if (addr := _read_address(dirs.scheduler_file)) and _probe(addr): + return addr + time.sleep(0.2) + + raise RuntimeError( + f"Dask scheduler did not come up within {SPAWN_WAIT_SECONDS}s. " + f"See {dirs.log_file} for daemon output." + ) + finally: + os.close(spawn_fd) + + +def stop(project_path: Path) -> bool: + """Best-effort SIGTERM the running scheduler. Returns True if signalled. + + Quiet on every "nothing to stop" path: no meta file, malformed + meta, dead PID, foreign PID. The SessionEnd hook calls this without + caring about the result. 
+ """ + dirs = scheduler_dirs(project_path) + try: + meta = json.loads(dirs.meta_file.read_text()) + pid = int(meta["pid"]) + except (FileNotFoundError, json.JSONDecodeError, KeyError, ValueError): + return False + try: + os.kill(pid, signal.SIGTERM) + except (ProcessLookupError, PermissionError): + return False + return True + + +# --------------------------------------------------------------------------- +# Internals +# --------------------------------------------------------------------------- + + +def _read_address(scheduler_file: Path) -> str | None: + try: + return str(json.loads(scheduler_file.read_text())["address"]) + except (FileNotFoundError, json.JSONDecodeError, KeyError, OSError): + return None + + +def _probe(addr: str) -> bool: + """Cheap liveness probe: TCP connect to scheduler host:port. + + Avoids ``Client(addr)`` since that brings up an event loop for what + should be a sub-second decision. A dead scheduler whose + ``scheduler.json`` survived a crash will fail to connect; a live + one accepts immediately. 
+ """ + try: + u = urlparse(addr) + if not u.hostname or not u.port: + return False + with socket.create_connection( + (u.hostname, u.port), timeout=PROBE_TIMEOUT_SECONDS + ): + return True + except (OSError, ValueError): + return False + + +def _spawn_daemon(project_path: Path, dirs: SchedulerDirs) -> None: + """Detach a daemon process running ``python -m`` this module.""" + log = open(dirs.log_file, "ab", buffering=0) + try: + subprocess.Popen( + [ + sys.executable, + "-m", + "lightcone.engine.dask_daemon", + "--project", + str(project_path), + ], + stdin=subprocess.DEVNULL, + stdout=log, + stderr=log, + start_new_session=True, + close_fds=True, + ) + finally: + log.close() + + +# --------------------------------------------------------------------------- +# Daemon entrypoint (``python -m lightcone.engine.dask_daemon``) +# --------------------------------------------------------------------------- + + +def _serve(project_path: Path) -> int: + """Run the long-lived scheduler. Exit 0 when it's gone.""" + dirs = scheduler_dirs(project_path) + + # Dedup: if another daemon already holds the lock, exit silently. + # This makes Popen-twice safe — only one daemon ever runs per key. + fd = os.open(dirs.owner_lock, os.O_RDWR | os.O_CREAT, 0o644) + try: + fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB) + except BlockingIOError: + os.close(fd) + return 0 + + in_slurm = "SLURM_JOB_ID" in os.environ + + # Write meta.json *before* starting the cluster so that as soon as + # ``ensure`` sees ``scheduler.json`` (which Dask writes during + # cluster construction), ``stop`` can already find the PID. Without + # this ordering there's a window where ``ensure`` returns but + # ``stop`` is silently a no-op. The address lives in + # ``scheduler.json`` — no need to duplicate it here. 
+ dirs.meta_file.write_text( + json.dumps( + { + "pid": os.getpid(), + "host": socket.gethostname(), + "started_at": time.time(), + "key": scheduler_key(project_path), + "mode": "slurm" if in_slurm else "local", + } + ) + ) + + cluster, workers = ( + _start_slurm_cluster(dirs) if in_slurm else _start_local_cluster(dirs) + ) + + try: + _block_until_done(cluster) + finally: + _shutdown(cluster, workers, dirs) + + return 0 + + +def _start_local_cluster( + dirs: SchedulerDirs, +) -> tuple[Any, subprocess.Popen[bytes] | None]: + """LocalCluster sized to the host machine.""" + from dask.distributed import LocalCluster + + from lightcone.engine.dask_cluster import _detect_node_shape, _resource_dict + + shape = _detect_node_shape() + cluster = LocalCluster( + n_workers=1, + threads_per_worker=shape.cpus, + resources=_resource_dict(shape), + dashboard_address=":0", + local_directory=str(dirs.spill), + scheduler_kwargs={ + "idle_timeout": IDLE_TIMEOUT, + "scheduler_file": str(dirs.scheduler_file), + }, + silence_logs=logging.WARNING, + ) + return cluster, None + + +def _start_slurm_cluster( + dirs: SchedulerDirs, +) -> tuple[Any, subprocess.Popen[bytes] | None]: + """In-process scheduler + one ``srun``-launched worker per node.""" + from dask.distributed import Client, LocalCluster + + from lightcone.engine.dask_cluster import ( + _detect_node_shape, + _resources_arg, + ) + + if shutil.which("dask") is None: + raise RuntimeError( + "`dask` CLI is not on PATH inside the SLURM allocation. " + "Install lightcone-cli (and its `distributed` dep) into " + "the environment activated by your sbatch/salloc." + ) + + shape = _detect_node_shape() + nnodes = int(os.environ.get("SLURM_NNODES") or 1) + + # Bind the scheduler to a hostname workers on remote nodes can + # reach. Default 127.0.0.1 silently fails wait_for_workers. 
+ scheduler_host = os.environ.get("SLURMD_NODENAME") or socket.gethostname() + cluster = LocalCluster( + n_workers=0, + host=scheduler_host, + dashboard_address=":0", + local_directory=str(dirs.spill), + scheduler_kwargs={ + "idle_timeout": IDLE_TIMEOUT, + "scheduler_file": str(dirs.scheduler_file), + }, + silence_logs=logging.WARNING, + ) + addr = cluster.scheduler_address + + worker_cmd = [ + "srun", + f"--ntasks={nnodes}", + "--ntasks-per-node=1", + "dask", + "worker", + addr, + "--nthreads", + str(shape.cpus), + "--nworkers", + "1", + "--resources", + _resources_arg(shape), + "--no-dashboard", + "--no-nanny", + "--local-directory", + str(dirs.spill), + ] + worker_env = dict(os.environ) + worker_env.setdefault("DASK_LOGGING__DISTRIBUTED", "warning") + workers = subprocess.Popen(worker_cmd, env=worker_env) + + client = Client(addr) + try: + client.wait_for_workers(n_workers=nnodes, timeout=120) + finally: + client.close() + return cluster, workers + + +def _block_until_done(cluster: Any) -> None: + """Sleep until the cluster shuts itself down or we're SIGTERM'd.""" + stopping = {"flag": False} + + def _on_term(signum: int, frame: object) -> None: + stopping["flag"] = True + + signal.signal(signal.SIGTERM, _on_term) + signal.signal(signal.SIGINT, _on_term) + + # Idle-timeout closes the scheduler from inside Dask; we observe it + # via ``cluster.status``. Polling at 2s is invisible in human-time + # and bounded against a 30-minute idle window. 
+ while not stopping["flag"]: + status = getattr(cluster, "status", None) + if status is not None and str(status).rsplit(".", maxsplit=1)[-1] != "running": + break + time.sleep(2) + + +def _shutdown( + cluster: Any, + workers: subprocess.Popen[bytes] | None, + dirs: SchedulerDirs, +) -> None: + """Gracefully retire workers, close the cluster, clean up files.""" + from dask.distributed import Client + + address = getattr(cluster, "scheduler_address", None) + if workers is not None and address: + # Ask the scheduler to retire workers so each ``dask worker`` + # exits cleanly. SIGTERM-ing srun directly prints noisy + # "forcing job termination" lines on every clean shutdown. + try: + client = Client(address, timeout="10s") + try: + client.retire_workers(close_workers=True, remove=True) + finally: + client.close() + except Exception: + pass + try: + workers.wait(timeout=20) + except subprocess.TimeoutExpired: + workers.terminate() + try: + workers.wait(timeout=10) + except subprocess.TimeoutExpired: + workers.kill() + workers.wait() + + try: + cluster.close() + except Exception: + pass + + # Best-effort cleanup. Stale files don't affect correctness — the + # next ensure() probes liveness and rewrites — but keeping the dir + # tidy is a UX nicety. + for f in (dirs.scheduler_file, dirs.meta_file): + try: + f.unlink(missing_ok=True) + except OSError: + pass + + +def _main(argv: list[str] | None = None) -> int: + parser = argparse.ArgumentParser( + prog="python -m lightcone.engine.dask_daemon", + description="Run a session-scoped Dask scheduler. Internal entrypoint; " + "users invoke this indirectly via `lc run` and `lc dask stop`.", + ) + parser.add_argument( + "--project", required=True, type=Path, help="Project root path." 
+ ) + args = parser.parse_args(argv) + return _serve(args.project) + + +if __name__ == "__main__": + raise SystemExit(_main()) diff --git a/src/lightcone/engine/scratch.py b/src/lightcone/engine/scratch.py index f61bc73d..af398b27 100644 --- a/src/lightcone/engine/scratch.py +++ b/src/lightcone/engine/scratch.py @@ -46,7 +46,6 @@ class RunDirs: root: Path # ``/.lightcone`` snakemake_state: Path # ``/.lightcone/snakemake//.snakemake`` - dask_local: Path # ``/.lightcone/dask/`` lock_path: Path # ``/.lightcone/locks/.lock`` # Project-level sentinel for the run-exclusion flock. Held for the # duration of one ``lc run`` to prevent concurrent invocations on @@ -98,19 +97,21 @@ def prepare_run_dirs(project_path: Path, *, run_id: str | None = None) -> RunDir """Create and return per-run scratch sub-directories. *run_id* defaults to the current PID — unique per ``lc run`` - invocation, easily mappable to a process for debugging. Lock and - dask-local dirs are run-scoped (cleaned per invocation); snakemake - state is project-scoped (persistent across invocations). + invocation, easily mappable to a process for debugging. The + cross-node stdout lock is run-scoped (cleaned per invocation); + snakemake state and the project run-lock are project-scoped + (persistent across invocations). Dask-cluster spill lives under + :mod:`lightcone.engine.dask_daemon`'s scheduler-keyed dir, which + outlives any single run. 
""" scratch = resolve_scratch_root(project_path) root = scratch / ".lightcone" rid = run_id or str(os.getpid()) pkey = project_hash(project_path) snakemake_state = root / "snakemake" / pkey / ".snakemake" - dask_local = root / "dask" / rid lock_path = root / "locks" / f"{rid}.lock" run_lock_path = root / "locks" / f"{pkey}.run-lock" - for d in (root, snakemake_state.parent, dask_local, lock_path.parent): + for d in (root, snakemake_state.parent, lock_path.parent): d.mkdir(parents=True, exist_ok=True) # Touch lockfiles so workers can ``flock`` them without racing on # ``O_CREAT``. Empty file is fine — flock is independent of contents. @@ -119,7 +120,6 @@ def prepare_run_dirs(project_path: Path, *, run_id: str | None = None) -> RunDir return RunDirs( root=root, snakemake_state=snakemake_state, - dask_local=dask_local, lock_path=lock_path, run_lock_path=run_lock_path, ) diff --git a/tests/test_dask_cluster.py b/tests/test_dask_cluster.py index 4bb6b473..8f6a3c7c 100644 --- a/tests/test_dask_cluster.py +++ b/tests/test_dask_cluster.py @@ -1,15 +1,18 @@ -"""Unit tests for the cluster bootstrap. +"""Tests for the cluster-connection helpers. -We test the routing decision (which branch fires given env vars) and the -node-shape detection. The actual `LocalCluster` spin-up is exercised in a -single smoke test; the `srun`-backed path is mocked because real -multi-node testing requires SLURM. +Two surfaces: + +- node-shape detection and resource-key formatting (pure functions + consumed by both the daemon and the executor plugin), +- ``cluster_for_run``'s routing: explicit ``DASK_SCHEDULER_ADDRESS`` + vs. session-scoped scheduler via :mod:`dask_daemon`. + +The daemon's own behavior is exercised in ``test_dask_daemon.py``. 
""" from __future__ import annotations -from contextlib import contextmanager -from unittest.mock import patch +from pathlib import Path import pytest @@ -19,6 +22,7 @@ RESOURCE_MEMORY, _detect_node_shape, _NodeShape, + _resource_dict, _resources_arg, cluster_for_run, ) @@ -37,6 +41,9 @@ def _clean_env(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.delenv(var, raising=False) +# ---- node shape ---------------------------------------------------------- + + def test_detect_shape_falls_back_to_os(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr("os.cpu_count", lambda: 8) shape = _detect_node_shape() @@ -54,6 +61,16 @@ def test_detect_shape_reads_slurm_env(monkeypatch: pytest.MonkeyPatch) -> None: assert shape.gpus == 4 +def test_resource_dict_minimal() -> None: + res = _resource_dict(_NodeShape(cpus=8, mem_bytes=0, gpus=0)) + assert res == {RESOURCE_CPUS: 8.0} + + +def test_resource_dict_full() -> None: + res = _resource_dict(_NodeShape(cpus=64, mem_bytes=256_000_000_000, gpus=4)) + assert set(res.keys()) == {RESOURCE_CPUS, RESOURCE_MEMORY, RESOURCE_GPUS} + + def test_resources_arg_minimal() -> None: arg = _resources_arg(_NodeShape(cpus=8, mem_bytes=0, gpus=0)) assert arg == "cpus=8" @@ -64,228 +81,43 @@ def test_resources_arg_full() -> None: assert arg == "cpus=64 memory=256000000000 gpus=4" -def test_existing_scheduler_address_yields_unchanged( - monkeypatch: pytest.MonkeyPatch, -) -> None: - monkeypatch.setenv("DASK_SCHEDULER_ADDRESS", "tcp://example:8786") - - with cluster_for_run() as addr: - assert addr == "tcp://example:8786" - - -def test_no_env_uses_local_cluster() -> None: - """The local-cluster branch should actually start a (tiny) cluster.""" - sentinel: dict[str, str] = {} - - @contextmanager - def _fake_local(*, verbose: bool, local_directory: str | None = None): - sentinel["called"] = "local" - yield "tcp://stub:9999" - - with patch("lightcone.engine.dask_cluster._local_cluster", _fake_local): - with cluster_for_run() as addr: - 
assert addr == "tcp://stub:9999" - assert sentinel["called"] == "local" - +# ---- cluster_for_run routing -------------------------------------------- -def test_slurm_env_takes_slurm_path(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setenv("SLURM_JOB_ID", "12345") - sentinel: dict[str, str] = {} - @contextmanager - def _fake_slurm(*, verbose: bool, local_directory: str | None = None): - sentinel["called"] = "slurm" - yield "tcp://stub:9999" - - with patch("lightcone.engine.dask_cluster._slurm_backed_cluster", _fake_slurm): - with cluster_for_run() as addr: - assert addr == "tcp://stub:9999" - assert sentinel["called"] == "slurm" - - -def test_existing_scheduler_address_wins_over_slurm( - monkeypatch: pytest.MonkeyPatch, +def test_existing_scheduler_address_yields_unchanged( + monkeypatch: pytest.MonkeyPatch, tmp_path: Path ) -> None: - """If both are set, the explicit address takes precedence.""" - monkeypatch.setenv("DASK_SCHEDULER_ADDRESS", "tcp://existing:8786") - monkeypatch.setenv("SLURM_JOB_ID", "12345") - - @contextmanager - def _should_not_run(*, verbose: bool, local_directory: str | None = None): - raise AssertionError("slurm path should not have been taken") - yield # pragma: no cover - - with patch("lightcone.engine.dask_cluster._slurm_backed_cluster", _should_not_run): - with cluster_for_run() as addr: - assert addr == "tcp://existing:8786" + """When the user (or CI) supplies an address, we use it verbatim and + never reach into the daemon — that's the escape hatch the env var is + for, and going through ``ensure`` would fight the user's setup.""" + monkeypatch.setenv("DASK_SCHEDULER_ADDRESS", "tcp://example:8786") + def _should_not_be_called(_: Path) -> str: + raise AssertionError("ensure_scheduler must not run when env is set") -def test_slurm_backed_cluster_binds_to_routable_host( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Multi-node SLURM allocations need the scheduler bound to a hostname - workers on other nodes can reach. 
The default LocalCluster host of - 127.0.0.1 fails silently with `wait_for_workers` timeouts. - """ - monkeypatch.setenv("SLURM_JOB_ID", "12345") - monkeypatch.setenv("SLURM_NNODES", "2") - monkeypatch.setenv("SLURMD_NODENAME", "nid001234") monkeypatch.setattr( - "lightcone.engine.dask_cluster.shutil.which", lambda _: "/usr/bin/dask" + "lightcone.engine.dask_daemon.ensure_scheduler", _should_not_be_called ) - captured: dict[str, object] = {} - - class _FakeCluster: - def __init__(self, **kwargs: object) -> None: - captured.update(kwargs) - self.scheduler_address = "tcp://nid001234:8786" - - def close(self) -> None: - pass - - class _FakeClient: - def __init__(self, addr: str) -> None: - captured["client_addr"] = addr - - def wait_for_workers(self, n_workers: int, timeout: int) -> None: - pass - - def close(self) -> None: - pass - - class _FakePopen: - def __init__(self, cmd: list[str], **kwargs: object) -> None: - captured["worker_cmd"] = cmd - captured["worker_kwargs"] = kwargs - - def terminate(self) -> None: - pass - - def wait(self, timeout: int | None = None) -> int: - return 0 - - def kill(self) -> None: - pass - - monkeypatch.setattr("dask.distributed.LocalCluster", _FakeCluster) - monkeypatch.setattr("dask.distributed.Client", _FakeClient) - monkeypatch.setattr("subprocess.Popen", _FakePopen) - - from lightcone.engine.dask_cluster import _slurm_backed_cluster - - with _slurm_backed_cluster(verbose=False, local_directory=None) as addr: - assert addr == "tcp://nid001234:8786" - - assert captured.get("host") == "nid001234", ( - f"LocalCluster must be told to bind to the SLURM nodename so remote " - f"workers can connect; got host={captured.get('host')!r}" - ) + with cluster_for_run(project_path=tmp_path) as addr: + assert addr == "tcp://example:8786" -def test_slurm_backed_cluster_falls_back_to_gethostname( - monkeypatch: pytest.MonkeyPatch, +def test_no_env_calls_ensure_scheduler( + monkeypatch: pytest.MonkeyPatch, tmp_path: Path ) -> None: - """Without 
SLURMD_NODENAME, fall back to socket.gethostname().""" - monkeypatch.setenv("SLURM_JOB_ID", "12345") - monkeypatch.setenv("SLURM_NNODES", "1") - monkeypatch.delenv("SLURMD_NODENAME", raising=False) - monkeypatch.setattr( - "lightcone.engine.dask_cluster.shutil.which", lambda _: "/usr/bin/dask" - ) - monkeypatch.setattr( - "lightcone.engine.dask_cluster.socket.gethostname", lambda: "host-fallback" - ) - - captured: dict[str, object] = {} - - class _FakeCluster: - def __init__(self, **kwargs: object) -> None: - captured.update(kwargs) - self.scheduler_address = "tcp://host-fallback:8786" - - def close(self) -> None: - pass - - class _FakeClient: - def __init__(self, addr: str) -> None: - pass - - def wait_for_workers(self, n_workers: int, timeout: int) -> None: - pass - - def close(self) -> None: - pass - - class _FakePopen: - def __init__(self, cmd: list[str], **kwargs: object) -> None: - pass - - def terminate(self) -> None: - pass + """Default path: cluster_for_run delegates to ensure_scheduler with + the project path so the daemon picks the right scratch dir.""" + seen: dict[str, Path] = {} - def wait(self, timeout: int | None = None) -> int: - return 0 + def _fake_ensure(project: Path) -> str: + seen["project"] = project + return "tcp://stub:9999" - def kill(self) -> None: - pass - - monkeypatch.setattr("dask.distributed.LocalCluster", _FakeCluster) - monkeypatch.setattr("dask.distributed.Client", _FakeClient) - monkeypatch.setattr("subprocess.Popen", _FakePopen) - - from lightcone.engine.dask_cluster import _slurm_backed_cluster - - with _slurm_backed_cluster(verbose=False, local_directory=None): - pass - - assert captured.get("host") == "host-fallback" - - -def test_local_cluster_advertises_memory_and_gpus( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """Dask only schedules a task on a worker that advertises every - requested resource key — so the local worker must expose mem and - gpus too, otherwise rules with ``mem_mb``/``gpus_per_task`` hang. 
- """ monkeypatch.setattr( - "lightcone.engine.dask_cluster._detect_node_shape", - lambda: _NodeShape(cpus=4, mem_bytes=16_000_000_000, gpus=2), + "lightcone.engine.dask_daemon.ensure_scheduler", _fake_ensure ) - captured: dict[str, object] = {} - - class _FakeCluster: - def __init__(self, **kwargs: object) -> None: - captured.update(kwargs) - self.scheduler_address = "tcp://stub:0" - - def close(self) -> None: - pass - - monkeypatch.setattr("dask.distributed.LocalCluster", _FakeCluster) - - from lightcone.engine.dask_cluster import _local_cluster - - with _local_cluster(verbose=False, local_directory=None): - pass - - resources = captured.get("resources") - assert isinstance(resources, dict) - assert set(resources.keys()) == {RESOURCE_CPUS, RESOURCE_MEMORY, RESOURCE_GPUS} - - -@pytest.mark.slow -def test_local_cluster_smoke() -> None: - """End-to-end: a real LocalCluster spins up, accepts a task, tears down.""" - from dask.distributed import Client - - from lightcone.engine.dask_cluster import _local_cluster - - with _local_cluster(verbose=False, local_directory=None) as addr: - client = Client(addr) - try: - assert client.submit(lambda x: x + 1, 41).result() == 42 - finally: - client.close() + with cluster_for_run(project_path=tmp_path) as addr: + assert addr == "tcp://stub:9999" + assert seen["project"] == tmp_path diff --git a/tests/test_dask_daemon.py b/tests/test_dask_daemon.py new file mode 100644 index 00000000..ac977984 --- /dev/null +++ b/tests/test_dask_daemon.py @@ -0,0 +1,226 @@ +"""Tests for the session-scoped Dask scheduler daemon. + +What we cover here: + +- ``scheduler_key`` switches on ``SLURM_JOB_ID``, +- ``scheduler_dirs`` lays out the per-key directory under scratch, +- ``ensure_scheduler`` reuses a live scheduler (no spawn), +- ``ensure_scheduler`` spawns a daemon when nothing is alive, cleans + stale state, and times out cleanly, +- ``stop`` is silently a no-op when there's no scheduler, and signals + the recorded PID otherwise. 
+ +We stub ``_spawn_daemon`` and the TCP probe (``_probe``) — actually +spinning up Dask is deliberately left out of these unit tests. +""" + +from __future__ import annotations + +import json +import os +import signal +from pathlib import Path + +import pytest + +from lightcone.engine.dask_daemon import ( + SPAWN_WAIT_SECONDS, + ensure_scheduler, + scheduler_dirs, + scheduler_key, + stop, +) +from lightcone.engine.scratch import LIGHTCONE_SCRATCH_ENV, project_hash + + +@pytest.fixture +def project(tmp_path: Path) -> Path: + p = tmp_path / "proj" + p.mkdir() + (p / "astra.yaml").write_text("outputs: []\n") + return p + + +@pytest.fixture(autouse=True) +def _clean_env(monkeypatch: pytest.MonkeyPatch) -> None: + for var in ("SLURM_JOB_ID", "DASK_SCHEDULER_ADDRESS"): + monkeypatch.delenv(var, raising=False) + import socket + + monkeypatch.setattr(socket, "gethostname", lambda: "unknown-host-x") + + +@pytest.fixture +def scratch(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> Path: + s = tmp_path / "scratch" + monkeypatch.setenv(LIGHTCONE_SCRATCH_ENV, str(s)) + return s + + +# ---- key + dirs ---------------------------------------------------------- + + +def test_scheduler_key_defaults_to_project_hash(project: Path) -> None: + assert scheduler_key(project) == project_hash(project) + + +def test_scheduler_key_uses_slurm_job_id( + monkeypatch: pytest.MonkeyPatch, project: Path +) -> None: + """Inside an allocation the natural lifecycle scope is the + allocation, not the project — workers spawned via srun belong to + SLURM_JOB_ID, and a switch of project mid-allocation should reuse + the same workers.""" + monkeypatch.setenv("SLURM_JOB_ID", "12345") + assert scheduler_key(project) == "slurm-12345" + + +def test_scheduler_dirs_layout(scratch: Path, project: Path) -> None: + d = scheduler_dirs(project) + assert d.root == scratch / ".lightcone" / "dask-scheduler" / project_hash(project) + assert d.root.is_dir() + assert d.spill.is_dir() + # The lock-file paths must 
be under root so a daemon flock'ing + # owner.lock from inside the scheduler keeps the kernel-level + # liveness invariant we rely on. + for f in (d.owner_lock, d.spawn_lock, d.scheduler_file, d.meta_file, d.log_file): + assert f.parent == d.root + + +# ---- ensure_scheduler ---------------------------------------------------- + + +def test_ensure_reuses_live_scheduler( + monkeypatch: pytest.MonkeyPatch, scratch: Path, project: Path +) -> None: + """Fast path: scheduler.json points to an address that responds — + return it without going near the spawn lock.""" + d = scheduler_dirs(project) + d.scheduler_file.write_text(json.dumps({"address": "tcp://live:8786"})) + + monkeypatch.setattr( + "lightcone.engine.dask_daemon._probe", lambda addr: addr == "tcp://live:8786" + ) + + def _no_spawn(*a: object, **kw: object) -> None: + raise AssertionError("must not spawn when an existing scheduler is live") + + monkeypatch.setattr("lightcone.engine.dask_daemon._spawn_daemon", _no_spawn) + + assert ensure_scheduler(project) == "tcp://live:8786" + + +def test_ensure_spawns_when_nothing_running( + monkeypatch: pytest.MonkeyPatch, scratch: Path, project: Path +) -> None: + """No scheduler.json on disk → spawn the daemon and wait for it to + write the address. 
We simulate the daemon by writing scheduler.json + from inside the spawn stub.""" + d = scheduler_dirs(project) + spawned: dict[str, bool] = {} + + def _fake_spawn(_proj: Path, dirs: object) -> None: + spawned["yes"] = True + d.scheduler_file.write_text(json.dumps({"address": "tcp://fresh:7000"})) + + monkeypatch.setattr("lightcone.engine.dask_daemon._spawn_daemon", _fake_spawn) + monkeypatch.setattr("lightcone.engine.dask_daemon._probe", lambda addr: True) + + assert ensure_scheduler(project) == "tcp://fresh:7000" + assert spawned["yes"] is True + + +def test_ensure_clears_stale_address_before_spawn( + monkeypatch: pytest.MonkeyPatch, scratch: Path, project: Path +) -> None: + """A scheduler.json from a SIGKILL'd daemon points at a dead + address — it must be cleared before spawn, otherwise a partial-read + race between the new daemon writing the file and our caller reading + it could yield the stale address.""" + d = scheduler_dirs(project) + d.scheduler_file.write_text(json.dumps({"address": "tcp://dead:1"})) + d.meta_file.write_text(json.dumps({"pid": 999999, "address": "tcp://dead:1"})) + + probe_calls = {"n": 0} + + def _probe(addr: str) -> bool: + probe_calls["n"] += 1 + # First call (fast path) probes the dead address: fail. + # Subsequent calls (post-spawn) succeed. + return addr == "tcp://fresh:8000" + + def _fake_spawn(_proj: Path, dirs: object) -> None: + # By the time we're called, stale files must be gone — that's + # the actual invariant under test. 
+ assert not d.scheduler_file.exists() + assert not d.meta_file.exists() + d.scheduler_file.write_text(json.dumps({"address": "tcp://fresh:8000"})) + + monkeypatch.setattr("lightcone.engine.dask_daemon._probe", _probe) + monkeypatch.setattr("lightcone.engine.dask_daemon._spawn_daemon", _fake_spawn) + + assert ensure_scheduler(project) == "tcp://fresh:8000" + assert probe_calls["n"] >= 2 # at least the fast-path probe + post-spawn probe + + +def test_ensure_times_out_with_clear_error( + monkeypatch: pytest.MonkeyPatch, scratch: Path, project: Path +) -> None: + """If the daemon never writes scheduler.json, ensure must surface + that with a pointer to the daemon's log — silent hang here would + look like a wedged ``lc run`` to the user.""" + monkeypatch.setattr("lightcone.engine.dask_daemon._probe", lambda addr: False) + monkeypatch.setattr("lightcone.engine.dask_daemon._spawn_daemon", lambda *a, **k: None) + monkeypatch.setattr( + "lightcone.engine.dask_daemon.SPAWN_WAIT_SECONDS", 0.3 + ) + + with pytest.raises(RuntimeError, match=r"did not come up"): + ensure_scheduler(project) + # Sanity: the constant we mocked is what controls the budget. 
+ assert SPAWN_WAIT_SECONDS # original (untouched) constant still imports + + +# ---- stop ---------------------------------------------------------------- + + +def test_stop_is_silent_when_nothing_running(scratch: Path, project: Path) -> None: + assert stop(project) is False + + +def test_stop_handles_corrupt_meta(scratch: Path, project: Path) -> None: + d = scheduler_dirs(project) + d.meta_file.write_text("not json") + assert stop(project) is False + + +def test_stop_signals_recorded_pid( + monkeypatch: pytest.MonkeyPatch, scratch: Path, project: Path +) -> None: + d = scheduler_dirs(project) + d.meta_file.write_text(json.dumps({"pid": 42})) + + sent: dict[str, tuple[int, int]] = {} + + def _kill(pid: int, sig: int) -> None: + sent["call"] = (pid, sig) + + monkeypatch.setattr(os, "kill", _kill) + assert stop(project) is True + assert sent["call"] == (42, signal.SIGTERM) + + +def test_stop_silent_on_dead_pid( + monkeypatch: pytest.MonkeyPatch, scratch: Path, project: Path +) -> None: + """A PID from an old run that no longer exists is the common case + after SIGKILL — stop must swallow ProcessLookupError and report + ``False`` so the SessionEnd hook stays quiet.""" + d = scheduler_dirs(project) + d.meta_file.write_text(json.dumps({"pid": 1234567})) + + def _kill(pid: int, sig: int) -> None: + raise ProcessLookupError + + monkeypatch.setattr(os, "kill", _kill) + assert stop(project) is False diff --git a/tests/test_scratch.py b/tests/test_scratch.py index aa5684f8..7a72ed6d 100644 --- a/tests/test_scratch.py +++ b/tests/test_scratch.py @@ -122,12 +122,10 @@ def test_prepare_run_dirs_creates_layout( monkeypatch.setenv(LIGHTCONE_SCRATCH_ENV, str(tmp_path / "scratch")) rd = prepare_run_dirs(project, run_id="42") assert rd.root == tmp_path / "scratch" / ".lightcone" - assert rd.dask_local == rd.root / "dask" / "42" assert rd.lock_path == rd.root / "locks" / "42.lock" assert rd.snakemake_state.parent.parent == rd.root / "snakemake" # Every path that callers rely on must 
exist on return. assert rd.root.is_dir() - assert rd.dask_local.is_dir() assert rd.lock_path.is_file() assert rd.snakemake_state.parent.is_dir()