diff --git a/conftest.py b/conftest.py index 5939d7578..7019b788e 100644 --- a/conftest.py +++ b/conftest.py @@ -178,13 +178,53 @@ def pytest_addoption(parser): def _install_session_timeout(timeout_s: int) -> None: def _handler(signum, frame): + from simpler_setup import parallel_scheduler as _ps # noqa: PLC0415 + print( - f"\n{'=' * 40}\n" - f"[pytest] TIMEOUT: session exceeded {timeout_s}s " - f"({timeout_s // 60}min) limit, aborting\n" - f"{'=' * 40}", + f"\n{'=' * 40}\n[pytest] TIMEOUT: session exceeded {timeout_s}s ({timeout_s // 60}min) limit\n{'=' * 40}", flush=True, ) + + # If the dispatcher is mid-flight, surface every stuck child: + # 1. SIGUSR1 each pid so its faulthandler dumps all-thread tracebacks + # (Python + C frames) into its own stdout — pumped into output_lines. + # 2. Briefly let pumps drain those bytes before we tear everything down. + # 3. Print each in-flight job's tail buffer in a HUNG group so the log + # contains the actual cause, not just the timeout banner. + # 4. SIGTERM/SIGKILL the children so they don't outlive us as orphans + # holding NPU device state. 
+ state = _ps._active_state + if state is not None and state.running: + for p in list(state.running): + try: + if hasattr(signal, "SIGUSR1"): + p.send_signal(signal.SIGUSR1) + except (ProcessLookupError, OSError): + pass + + time.sleep(2.0) + + now = time.monotonic() + for p, rj in list(state.running.items()): + elapsed = now - rj.start_time + tail = "".join(rj.output_lines[-200:]) + print( + f"::group::HUNG {rj.job.label} pid={p.pid} devices={rj.device_ids} elapsed={elapsed:.1f}s", + flush=True, + ) + if tail: + print(tail, end="" if tail.endswith("\n") else "\n", flush=True) + print("::endgroup::", flush=True) + print( + f"*** HUNG: {rj.job.label} (devices={rj.device_ids}) — expand group above ***", + flush=True, + ) + + try: + _ps._terminate_all(state) + except Exception: # noqa: BLE001 + pass + os._exit(TIMEOUT_EXIT_CODE) # signal.alarm / SIGALRM are Unix-only; skip silently on platforms without @@ -194,6 +234,30 @@ def _handler(signum, frame): signal.alarm(timeout_s) +def _install_child_faulthandler() -> None: + """In dispatched child pytest processes, let SIGUSR1 dump all-thread stacks. + + The parent dispatcher's session-timeout handler sends SIGUSR1 to every + in-flight child before tearing the run down. ``faulthandler.register`` + runs in the C signal handler, so it works even when the main thread is + blocked inside a native call that doesn't release the GIL (NPU runtime, + nanobind into C++) — exactly the case Python-level watchdogs miss. + + Always-on ``faulthandler.enable()`` also gives us a stack on real crashes + (SIGSEGV/SIGABRT) instead of a silent exit. + """ + import faulthandler # noqa: PLC0415 + + faulthandler.enable() + if hasattr(signal, "SIGUSR1"): + try: + faulthandler.register(signal.SIGUSR1, chain=False, all_threads=True) + except (ValueError, RuntimeError): + # Fails when stdout/stderr can't be duped (rare in child subprocs); + # leave faulthandler.enable() in place and continue. 
+ pass + + def pytest_configure(config): """Register custom markers and apply global config.""" config.addinivalue_line("markers", "platforms(list): supported platforms for standalone ST functions") @@ -240,6 +304,13 @@ def pytest_configure(config): if timeout and timeout > 0: _install_session_timeout(timeout) + # Always register SIGUSR1 → faulthandler. In dispatched child pytest + # processes this is what the parent's session-timeout handler relies on + # to extract a stack from a hung run. In the parent dispatcher itself + # it's harmless and lets a developer query "what is this process doing?" + # interactively with `kill -USR1 <pid>`. + _install_child_faulthandler() + # xdist worker: bind this process to a single device id from the --device range. # The dispatcher (or the user) supplies --device 0-7; xdist spawns N workers # labelled gw0..gwN-1. We slice device_ids[worker_index] so each worker owns diff --git a/simpler_setup/parallel_scheduler.py b/simpler_setup/parallel_scheduler.py index 8fa28b258..fcbdb8393 100644 --- a/simpler_setup/parallel_scheduler.py +++ b/simpler_setup/parallel_scheduler.py @@ -79,6 +79,13 @@ class _RunState: cancelled: bool = False +# Module-global handle to the active _RunState while run_jobs is in flight. +# Read by conftest's session-timeout SIGALRM handler so it can SIGUSR1 every +# stuck child and dump their captured buffers before tearing the parent down. +# At most one run_jobs is active at a time (the dispatcher is single-threaded). +_active_state: _RunState | None = None + + def _device_range_str(ids: list[int]) -> str: """Format a device-id list as a CLI-friendly range or comma list. @@ -176,7 +183,12 @@ def run_jobs( f"{len(device_ids)}; widen --device range or shrink the case's device_count" ) + # Module-global is intentional: the conftest SIGALRM handler runs in a + # signal context with no reference to this function's locals, so it has + # to find the in-flight state through a known module attribute. 
+ global _active_state # noqa: PLW0603 state = _RunState(free_devices=list(device_ids)) + _active_state = state queue = list(jobs) def _pump_stdout(p: subprocess.Popen, sink: list[str]) -> None: @@ -231,6 +243,13 @@ def _try_launch_head() -> bool: output_lines=output_lines, pump_thread=pump, ) + # Emit at launch (not just at completion) so a hung child is locatable: + # the last START without a matching PASS/FAIL line in _emit_group output + # is the case that's stuck. + print( + f"[scheduler] START {head.label} pid={p.pid} devices={allocated}", + flush=True, + ) return True def _reap_one() -> JobResult | None: @@ -311,6 +330,7 @@ def _reap_one() -> JobResult | None: duration_s=duration, ) ) + _active_state = None return state.results