Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 75 additions & 4 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,13 +178,53 @@ def pytest_addoption(parser):

def _install_session_timeout(timeout_s: int) -> None:
def _handler(signum, frame):
from simpler_setup import parallel_scheduler as _ps # noqa: PLC0415

print(
f"\n{'=' * 40}\n"
f"[pytest] TIMEOUT: session exceeded {timeout_s}s "
f"({timeout_s // 60}min) limit, aborting\n"
f"{'=' * 40}",
f"\n{'=' * 40}\n[pytest] TIMEOUT: session exceeded {timeout_s}s ({timeout_s // 60}min) limit\n{'=' * 40}",
flush=True,
)

# If the dispatcher is mid-flight, surface every stuck child:
# 1. SIGUSR1 each pid so its faulthandler dumps all-thread tracebacks
# (Python + C frames) into its own stdout — pumped into output_lines.
# 2. Briefly let pumps drain those bytes before we tear everything down.
# 3. Print each in-flight job's tail buffer in a HUNG group so the log
# contains the actual cause, not just the timeout banner.
# 4. SIGTERM/SIGKILL the children so they don't outlive us as orphans
# holding NPU device state.
state = _ps._active_state
if state is not None and state.running:
for p in list(state.running):
try:
if hasattr(signal, "SIGUSR1"):
p.send_signal(signal.SIGUSR1)
except (ProcessLookupError, OSError):
pass

time.sleep(2.0)

now = time.monotonic()
for p, rj in list(state.running.items()):
elapsed = now - rj.start_time
tail = "".join(rj.output_lines[-200:])
print(
f"::group::HUNG {rj.job.label} pid={p.pid} devices={rj.device_ids} elapsed={elapsed:.1f}s",
flush=True,
)
if tail:
print(tail, end="" if tail.endswith("\n") else "\n", flush=True)
print("::endgroup::", flush=True)
print(
f"*** HUNG: {rj.job.label} (devices={rj.device_ids}) — expand group above ***",
flush=True,
)

try:
_ps._terminate_all(state)
except Exception: # noqa: BLE001
pass

os._exit(TIMEOUT_EXIT_CODE)

# signal.alarm / SIGALRM are Unix-only; skip silently on platforms without
Expand All @@ -194,6 +234,30 @@ def _handler(signum, frame):
signal.alarm(timeout_s)


def _install_child_faulthandler() -> None:
"""In dispatched child pytest processes, let SIGUSR1 dump all-thread stacks.

The parent dispatcher's session-timeout handler sends SIGUSR1 to every
in-flight child before tearing the run down. ``faulthandler.register``
runs in the C signal handler, so it works even when the main thread is
blocked inside a native call that doesn't release the GIL (NPU runtime,
nanobind into C++) — exactly the case Python-level watchdogs miss.

Always-on ``faulthandler.enable()`` also gives us a stack on real crashes
(SIGSEGV/SIGABRT) instead of a silent exit.
"""
import faulthandler # noqa: PLC0415

faulthandler.enable()
if hasattr(signal, "SIGUSR1"):
try:
faulthandler.register(signal.SIGUSR1, chain=False, all_threads=True)
except (ValueError, RuntimeError):
# Fails when stdout/stderr can't be duped (rare in child subprocs);
# leave faulthandler.enable() in place and continue.
pass


def pytest_configure(config):
"""Register custom markers and apply global config."""
config.addinivalue_line("markers", "platforms(list): supported platforms for standalone ST functions")
Expand Down Expand Up @@ -240,6 +304,13 @@ def pytest_configure(config):
if timeout and timeout > 0:
_install_session_timeout(timeout)

# Always register SIGUSR1 → faulthandler. In dispatched child pytest
# processes this is what the parent's session-timeout handler relies on
# to extract a stack from a hung run. In the parent dispatcher itself
# it's harmless and lets a developer query "what is this process doing?"
# interactively with `kill -USR1 <pid>`.
_install_child_faulthandler()

# xdist worker: bind this process to a single device id from the --device range.
# The dispatcher (or the user) supplies --device 0-7; xdist spawns N workers
# labelled gw0..gwN-1. We slice device_ids[worker_index] so each worker owns
Expand Down
20 changes: 20 additions & 0 deletions simpler_setup/parallel_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,13 @@ class _RunState:
cancelled: bool = False


# Module-global handle to the active _RunState while run_jobs is in flight.
# Read by conftest's session-timeout SIGALRM handler so it can SIGUSR1 every
# stuck child and dump their captured buffers before tearing the parent down.
# At most one run_jobs is active at a time (the dispatcher is single-threaded).
_active_state: _RunState | None = None


def _device_range_str(ids: list[int]) -> str:
"""Format a device-id list as a CLI-friendly range or comma list.

Expand Down Expand Up @@ -176,7 +183,12 @@ def run_jobs(
f"{len(device_ids)}; widen --device range or shrink the case's device_count"
)

# Module-global is intentional: the conftest SIGALRM handler runs in a
# signal context with no reference to this function's locals, so it has
# to find the in-flight state through a known module attribute.
global _active_state # noqa: PLW0603
state = _RunState(free_devices=list(device_ids))
_active_state = state
queue = list(jobs)

def _pump_stdout(p: subprocess.Popen, sink: list[str]) -> None:
Expand Down Expand Up @@ -231,6 +243,13 @@ def _try_launch_head() -> bool:
output_lines=output_lines,
pump_thread=pump,
)
# Emit at launch (not just at completion) so a hung child is locatable:
# the last START without a matching PASS/FAIL line in _emit_group output
# is the case that's stuck.
print(
f"[scheduler] START {head.label} pid={p.pid} devices={allocated}",
flush=True,
)
return True

def _reap_one() -> JobResult | None:
Expand Down Expand Up @@ -311,6 +330,7 @@ def _reap_one() -> JobResult | None:
duration_s=duration,
)
)
_active_state = None

return state.results

Expand Down
Loading