1- import contextlib
21import unittest
32import os
43import textwrap
4+ import contextlib
55import importlib
66import sys
77import socket
@@ -216,33 +216,13 @@ def requires_subinterpreters(meth):
216216# Simple wrapper functions for RemoteUnwinder
217217# ============================================================================
218218
219- # Errors that can occur transiently when reading process memory without synchronization
220- RETRIABLE_ERRORS = (
221- "Task list appears corrupted",
222- "Invalid linked list structure reading remote memory",
223- "Unknown error reading memory",
224- "Unhandled frame owner",
225- "Failed to parse initial frame",
226- "Failed to process frame chain",
227- "Failed to unwind stack",
228- )
229-
230-
231- def _is_retriable_error(exc):
232- """Check if an exception is a transient error that should be retried."""
233- msg = str(exc)
234- return any(msg.startswith(err) or err in msg for err in RETRIABLE_ERRORS)
235-
236-
def get_stack_trace(pid):
    """Return the stack trace of process *pid*, retrying on RuntimeError.

    A fresh ``RemoteUnwinder`` (all threads, debug enabled) is created on
    each attempt of the ``busy_retry(SHORT_TIMEOUT)`` loop, and the first
    successful ``get_stack_trace()`` result is returned.

    Raises:
        RuntimeError: if the retry loop ends without a successful read.
            The most recent RuntimeError seen, if any, is chained as the
            cause so the underlying failure stays visible.
    """
    last_error = None
    for _ in busy_retry(SHORT_TIMEOUT):
        try:
            unwinder = RemoteUnwinder(pid, all_threads=True, debug=True)
            return unwinder.get_stack_trace()
        except RuntimeError as exc:
            # Every RuntimeError is retried; remember the latest one so it
            # is not silently lost if we eventually give up.
            last_error = exc
    raise RuntimeError("Failed to get stack trace after retries") from last_error
247227
248228
@@ -252,9 +232,7 @@ def get_async_stack_trace(pid):
252232 unwinder = RemoteUnwinder(pid, debug=True)
253233 return unwinder.get_async_stack_trace()
254234 except RuntimeError as e:
255- if _is_retriable_error(e):
256- continue
257- raise
235+ continue
258236 raise RuntimeError("Failed to get async stack trace after retries")
259237
260238
@@ -264,9 +242,7 @@ def get_all_awaited_by(pid):
264242 unwinder = RemoteUnwinder(pid, debug=True)
265243 return unwinder.get_all_awaited_by()
266244 except RuntimeError as e:
267- if _is_retriable_error(e):
268- continue
269- raise
245+ continue
270246 raise RuntimeError("Failed to get all awaited_by after retries")
271247
272248
@@ -2268,18 +2244,13 @@ def make_unwinder(cache_frames=True):
def _get_frames_with_retry(self, unwinder, required_funcs):
    """Poll *unwinder* until some thread's stack contains *required_funcs*.

    Makes up to ``MAX_TRIES`` attempts. On each attempt every thread of
    every interpreter returned by ``unwinder.get_stack_trace()`` is
    checked, and the first thread whose frame function names include all
    of *required_funcs* has its ``frame_info`` returned.  OSError and
    RuntimeError raised during an attempt are ignored, and each
    unsuccessful attempt is followed by a 0.1 second sleep.

    Returns ``None`` when no attempt finds a matching thread.
    """
    for _attempt in range(MAX_TRIES):
        try:
            for interpreter in unwinder.get_stack_trace():
                for thread_info in interpreter.threads:
                    names = {frame.funcname for frame in thread_info.frame_info}
                    if required_funcs.issubset(names):
                        return thread_info.frame_info
        except (OSError, RuntimeError):
            # Ignored on purpose: the attempt is simply repeated.
            pass
        time.sleep(0.1)
    return None
22852256
@@ -2802,70 +2773,39 @@ def foo2():
28022773 make_unwinder,
28032774 ):
28042775 unwinder = make_unwinder(cache_frames=True)
2805- buffer = b""
2806-
2807- def recv_msg():
2808- """Receive a single message from socket."""
2809- nonlocal buffer
2810- while b"\n" not in buffer:
2811- chunk = client_socket.recv(256)
2812- if not chunk:
2813- return None
2814- buffer += chunk
2815- msg, buffer = buffer.split(b"\n", 1)
2816- return msg
2817-
2818- def get_thread_frames(target_funcs):
2819- """Get frames for thread matching target functions."""
2820- retries = 0
2821- for _ in busy_retry(SHORT_TIMEOUT):
2822- if retries >= 5:
2823- break
2824- retries += 1
2825- # On Windows, ReadProcessMemory can fail with OSError
2826- # (WinError 299) when frame pointers are in flux
2827- with contextlib.suppress(RuntimeError, OSError):
2828- traces = unwinder.get_stack_trace()
2829- for interp in traces:
2830- for thread in interp.threads:
2831- funcs = [f.funcname for f in thread.frame_info]
2832- if any(f in funcs for f in target_funcs):
2833- return funcs
2834- return None
2776+
2777+ # Message dispatch table: signal -> required functions for that thread
2778+ dispatch = {
2779+ b"t1:baz1": {"baz1", "bar1", "foo1"},
2780+ b"t2:baz2": {"baz2", "bar2", "foo2"},
2781+ b"t1:blech1": {"blech1", "foo1"},
2782+ b"t2:blech2": {"blech2", "foo2"},
2783+ }
28352784
28362785 # Track results for each sync point
28372786 results = {}
28382787
2839- # Process 4 sync points: baz1, baz2, blech1, blech2
2840- # With the lock, threads are serialized - handle one at a time
2841- for _ in range(4):
2842- msg = recv_msg()
2843- self.assertIsNotNone(msg, "Expected message from subprocess")
2844-
2845- # Determine which thread/function and take snapshot
2846- if msg == b"t1:baz1":
2847- funcs = get_thread_frames(["baz1", "bar1", "foo1"])
2848- self.assertIsNotNone(funcs, "Thread 1 not found at baz1")
2849- results["t1:baz1"] = funcs
2850- elif msg == b"t2:baz2":
2851- funcs = get_thread_frames(["baz2", "bar2", "foo2"])
2852- self.assertIsNotNone(funcs, "Thread 2 not found at baz2")
2853- results["t2:baz2"] = funcs
2854- elif msg == b"t1:blech1":
2855- funcs = get_thread_frames(["blech1", "foo1"])
2856- self.assertIsNotNone(funcs, "Thread 1 not found at blech1")
2857- results["t1:blech1"] = funcs
2858- elif msg == b"t2:blech2":
2859- funcs = get_thread_frames(["blech2", "foo2"])
2860- self.assertIsNotNone(funcs, "Thread 2 not found at blech2")
2861- results["t2:blech2"] = funcs
2862-
2863- # Release thread to continue
2788+ # Process 4 sync points (order depends on thread scheduling)
2789+ buffer = _wait_for_signal(client_socket, b"\n")
2790+ for i in range(4):
2791+ # Extract first message from buffer
2792+ msg, sep, buffer = buffer.partition(b"\n")
2793+ self.assertIn(msg, dispatch, f"Unexpected message: {msg!r}")
2794+
2795+ # Sample frames for the thread at this sync point
2796+ required_funcs = dispatch[msg]
2797+ frames = self._get_frames_with_retry(unwinder, required_funcs)
2798+ self.assertIsNotNone(frames, f"Thread not found for {msg!r}")
2799+ results[msg] = [f.funcname for f in frames]
2800+
2801+ # Release thread and wait for next message (if not last)
28642802 client_socket.sendall(b"k")
2803+ if i < 3:
2804+ buffer += _wait_for_signal(client_socket, b"\n")
28652805
28662806 # Validate Phase 1: baz snapshots
2867- t1_baz = results.get("t1:baz1")
2868- t2_baz = results.get("t2:baz2")
2807+ t1_baz = results.get(b"t1:baz1")
2808+ t2_baz = results.get(b"t2:baz2")
28692809 self.assertIsNotNone(t1_baz, "Missing t1:baz1 snapshot")
28702810 self.assertIsNotNone(t2_baz, "Missing t2:baz2 snapshot")
28712811
@@ -2890,8 +2830,8 @@ def get_thread_frames(target_funcs):
28902830 self.assertNotIn("foo1", t2_baz)
28912831
28922832 # Validate Phase 2: blech snapshots (cache invalidation test)
2893- t1_blech = results.get("t1:blech1")
2894- t2_blech = results.get("t2:blech2")
2833+ t1_blech = results.get(b"t1:blech1")
2834+ t2_blech = results.get(b"t2:blech2")
28952835 self.assertIsNotNone(t1_blech, "Missing t1:blech1 snapshot")
28962836 self.assertIsNotNone(t2_blech, "Missing t2:blech2 snapshot")
28972837
0 commit comments