From 05db25003a2e37935cb6ecd36980e509bc661909 Mon Sep 17 00:00:00 2001 From: jacobcbeaudin <51803634+jacobcbeaudin@users.noreply.github.com> Date: Tue, 21 Apr 2026 13:51:21 -0700 Subject: [PATCH 1/2] fix(runtime): restore sys.modules['__main__'] after run-mode kernel --- marimo/_runtime/patches.py | 56 +++++++++ marimo/_session/managers/kernel.py | 151 ++++++++++++++++------ tests/_runtime/_patches_spawn_target.py | 17 +++ tests/_runtime/test_patches.py | 115 ++++++++++++++++- tests/_session/managers/test_kernel.py | 160 ++++++++++++++++++++++++ 5 files changed, 461 insertions(+), 38 deletions(-) create mode 100644 tests/_runtime/_patches_spawn_target.py create mode 100644 tests/_session/managers/test_kernel.py diff --git a/marimo/_runtime/patches.py b/marimo/_runtime/patches.py index e37c5ccfe72..1a84bd85064 100644 --- a/marimo/_runtime/patches.py +++ b/marimo/_runtime/patches.py @@ -6,13 +6,17 @@ import functools import sys import textwrap +import threading import types from collections.abc import Callable from typing import TYPE_CHECKING, Any +from marimo import _loggers from marimo._ast.parse import ast_parse from marimo._dependencies.dependencies import DependencyManager from marimo._runtime import marimo_browser, marimo_pdb + +LOGGER = _loggers.marimo_logger() from marimo._utils.platform import is_pyodide Unpatch = Callable[[], None] @@ -60,6 +64,58 @@ def patch_sys_module(module: types.ModuleType) -> None: sys.modules[module.__name__] = module +# Refcount-based save/restore for the process's ``sys.modules['__main__']``. +# Run-mode (thread) kernels share sys.modules with the host; pairing every +# ``patch_main_module`` with save/restore keeps the host's main module from +# leaking the kernel's synthetic module after the session ends. +_main_save_lock = threading.Lock() +_main_save_count = 0 +_main_original: types.ModuleType | None = None + + +def save_main_module() -> None: + """Pair with ``restore_main_module`` to scope a ``patch_main_module`` call. + + The first call in a process captures the current ``sys.modules['__main__']``; + subsequent overlapping calls share that capture and the last paired + ``restore_main_module`` puts the original back. + + Callers must balance save/restore. An unbalanced save increments the + refcount permanently for the process lifetime and prevents the original + ``__main__`` from ever being restored; an unbalanced restore is a no-op. + """ + global _main_save_count, _main_original + with _main_save_lock: + if _main_save_count == 0: + _main_original = sys.modules.get("__main__") + LOGGER.debug( + "save_main_module captured original __main__=%r", + _main_original, + ) + _main_save_count += 1 + + +def restore_main_module() -> None: + """Release a reference acquired by ``save_main_module``. + + Restores the captured original ``sys.modules['__main__']`` when the last + outstanding reference is released. Calling with no outstanding reference + is a no-op. + """ + global _main_save_count, _main_original + with _main_save_lock: + if _main_save_count == 0: + return + _main_save_count -= 1 + if _main_save_count == 0 and _main_original is not None: + sys.modules["__main__"] = _main_original + LOGGER.debug( + "restore_main_module restored __main__ to %r", + _main_original, + ) + _main_original = None + + def patch_pyodide_networking() -> None: import pyodide_http # type: ignore diff --git a/marimo/_session/managers/kernel.py b/marimo/_session/managers/kernel.py index cdc8beb7440..7450e8fdfa7 100644 --- a/marimo/_session/managers/kernel.py +++ b/marimo/_session/managers/kernel.py @@ -32,6 +32,11 @@ LOGGER = _loggers.marimo_logger() +# Seconds to wait for the RUN-mode kernel thread to exit after we send +# StopKernelCommand. Long enough that a cooperative kernel winds down +# cleanly, short enough that a stuck kernel does not block a server. +_THREAD_KERNEL_JOIN_TIMEOUT_S = 5.0 + class KernelManagerImpl(KernelManager): """Kernel manager using multiprocessing Process or threading Thread. @@ -63,6 +68,13 @@ def __init__( self._read_conn: TypedConnection[KernelMessage] | None = None self._virtual_file_storage = virtual_file_storage + # Tracks an outstanding save_main_module awaiting its paired + # restore. Guarded by _main_save_lock so concurrent close_kernel + # or start_kernel calls on the same instance cannot double-release + # the __main__ refcount. + self._main_save_lock: threading.Lock = threading.Lock() + self._main_save_outstanding: bool = False + def start_kernel(self) -> None: # We use a process in edit mode so that we can interrupt the app # with a SIGINT; we don't mind the additional memory consumption, @@ -121,36 +133,44 @@ def launch_kernel_with_cleanup( install_thread_local_proxies() - assert self.queue_manager.stream_queue is not None - # Make threads daemons so killing the server immediately brings - # down all client sessions - self.kernel_task = threading.Thread( - target=launch_kernel_with_cleanup, - args=( - self.queue_manager.control_queue, - self.queue_manager.set_ui_element_queue, - self.queue_manager.completion_queue, - self.queue_manager.input_queue, - self.queue_manager.stream_queue, - # IPC not used in run mode - None, - is_edit_mode, - self.configs, - self.app_metadata, - self.config_manager.get_config(hide_secrets=False), - self._virtual_file_storage, - self.redirect_console_to_browser, - # win32 interrupt queue - None, - # profile path - None, - # log level - GLOBAL_SETTINGS.LOG_LEVEL, - ), - # daemon threads can create child processes, unlike - # daemon processes - daemon=True, - ) + self._save_host_main_module() + try: + assert self.queue_manager.stream_queue is not None + # daemon threads can create child processes (unlike daemon + # processes); daemon=True so a server shutdown tears all + # sessions down immediately. + self.kernel_task = threading.Thread( + target=launch_kernel_with_cleanup, + args=( + self.queue_manager.control_queue, + self.queue_manager.set_ui_element_queue, + self.queue_manager.completion_queue, + self.queue_manager.input_queue, + self.queue_manager.stream_queue, + # IPC not used in run mode + None, + is_edit_mode, + self.configs, + self.app_metadata, + self.config_manager.get_config(hide_secrets=False), + self._virtual_file_storage, + self.redirect_console_to_browser, + # win32 interrupt queue + None, + # profile path + None, + # log level + GLOBAL_SETTINGS.LOG_LEVEL, + ), + daemon=True, + ) + self.kernel_task.start() + except BaseException: + # The thread never started; release the save we just made so + # the __main__ refcount does not leak. + self._restore_host_main_module() + raise + return self.kernel_task.start() # type: ignore if listener is not None: @@ -238,17 +258,74 @@ def interrupt_kernel(self) -> None: LOGGER.debug("Sending SIGINT to kernel") os.kill(self.kernel_task.pid, signal.SIGINT) + def _save_host_main_module(self) -> None: + """Pair ``patch_main_module`` in the run-mode kernel with a restore. + + Idempotent: a second call after a save-without-restore does not + double-increment the process-wide refcount, so a caller that + accidentally invokes ``start_kernel`` twice will not leak. + + Lazy import so the session managers package does not circularly + depend on the runtime's patches module at type-checking time. + """ + with self._main_save_lock: + if self._main_save_outstanding: + return + from marimo._runtime.patches import save_main_module + + save_main_module() + self._main_save_outstanding = True + + def _restore_host_main_module(self) -> None: + """Release the save made by ``_save_host_main_module``. + + Idempotent: safe to call more than once and safe against + concurrent callers. Only the first call after a save decrements + the process-wide refcount, so repeated or overlapping + ``close_kernel`` invocations do not corrupt other sessions. + """ + with self._main_save_lock: + if not self._main_save_outstanding: + return + from marimo._runtime.patches import restore_main_module + + restore_main_module() + self._main_save_outstanding = False + + def _stop_and_join_run_mode_kernel(self) -> bool: + """Stop the run-mode kernel thread and wait for it to exit. + + Returns ``True`` if the thread is known to be idle (either it was + already finished or it exited within the bounded join window) and + the caller may run post-session cleanup such as restoring host + ``__main__``. Returns ``False`` if the thread is still alive after + the join timeout; the caller must skip cleanup because modifying + process state under a live cell corrupts in-flight execution. + """ + assert isinstance(self.kernel_task, threading.Thread) + if not self.kernel_task.is_alive(): + return True + self.queue_manager.put_control_request(commands.StopKernelCommand()) + # The kernel has already received the stop command, so a short + # upper bound is enough for a cooperative wind-down. A server that + # closes a session never pauses longer than this even if the kernel + # is stuck. + self.kernel_task.join(timeout=_THREAD_KERNEL_JOIN_TIMEOUT_S) + if self.kernel_task.is_alive(): + LOGGER.warning( + "RUN-mode kernel thread did not exit within %.1fs of " + "StopKernelCommand; skipping host __main__ restore.", + _THREAD_KERNEL_JOIN_TIMEOUT_S, + ) + return False + return True + def close_kernel(self) -> None: assert self.kernel_task is not None, "kernel not started" if isinstance(self.kernel_task, threading.Thread): - # in run mode - if self.kernel_task.is_alive(): - # We don't join the kernel thread because we don't want to server - # to block on it finishing - self.queue_manager.put_control_request( - commands.StopKernelCommand() - ) + if self._stop_and_join_run_mode_kernel(): + self._restore_host_main_module() else: # otherwise we have something that is `ProcessLike` if self.profile_path is not None and self.kernel_task.is_alive(): diff --git a/tests/_runtime/_patches_spawn_target.py b/tests/_runtime/_patches_spawn_target.py new file mode 100644 index 00000000000..05c135994a1 --- /dev/null +++ b/tests/_runtime/_patches_spawn_target.py @@ -0,0 +1,17 @@ +# Copyright 2026 Marimo. All rights reserved. +"""Spawn-pickleable target for test_patches.py. + +Defined in a real importable module so spawn workers can resolve it via +``_fixup_main_from_path`` without relying on the test module's __main__. +""" + +from __future__ import annotations + + +def noop_target() -> None: + """No-op target for multiprocessing.Process. + + The point of the test is whether the spawn worker can be set up at all, + not what it does. + """ + return diff --git a/tests/_runtime/test_patches.py b/tests/_runtime/test_patches.py index 8065d724c00..76b7da0d9ce 100644 --- a/tests/_runtime/test_patches.py +++ b/tests/_runtime/test_patches.py @@ -2,7 +2,9 @@ from __future__ import annotations import io +import multiprocessing import sys +import threading from typing import TYPE_CHECKING from unittest.mock import Mock, patch @@ -10,10 +12,16 @@ from marimo._dependencies.dependencies import DependencyManager from marimo._runtime.capture import capture_stderr -from marimo._runtime.patches import patch_polars_write_json +from marimo._runtime.patches import ( + patch_main_module, + patch_polars_write_json, + restore_main_module, + save_main_module, +) from marimo._runtime.runtime import Kernel from marimo._utils.platform import is_pyodide from tests._messaging.mocks import MockStream +from tests._runtime._patches_spawn_target import noop_target from tests.conftest import ExecReqProvider if TYPE_CHECKING: @@ -254,3 +262,108 @@ def test_polars_write_json_patch(tmp_path: Path): # Test it fails again with pytest.raises(ValueError, match="Test error"): df.write_json(file_path) + + +class TestSaveRestoreMainModule: + """save_main_module / restore_main_module refcount helpers. + + patch_main_module mutates sys.modules['__main__'] without a restore + path, leaking a synthetic module into any host that shares sys.modules + with the kernel (i.e. RUN-mode thread kernels). The helpers let + callers explicitly scope the mutation. + """ + + def test_restore_returns_original_main(self) -> None: + original = sys.modules["__main__"] + save_main_module() + try: + patch_main_module( + file=None, input_override=None, print_override=None + ) + assert sys.modules["__main__"] is not original + finally: + restore_main_module() + assert sys.modules["__main__"] is original + + def test_refcounted_multiple_savers_share_original(self) -> None: + original = sys.modules["__main__"] + # Two overlapping sessions each call save. + save_main_module() + save_main_module() + try: + patch_main_module( + file=None, input_override=None, print_override=None + ) + mutated = sys.modules["__main__"] + assert mutated is not original + # First release: refcount still > 0, no restore yet. + restore_main_module() + assert sys.modules["__main__"] is mutated + finally: + # Last release: restore to the originally captured value. + restore_main_module() + assert sys.modules["__main__"] is original + + def test_over_release_is_noop(self) -> None: + original = sys.modules["__main__"] + # No outstanding save; restore should do nothing. + restore_main_module() + assert sys.modules["__main__"] is original + + def test_concurrent_save_restore_cycles_preserve_main(self) -> None: + """Concurrent save/restore cycles converge back to the original main.""" + original = sys.modules["__main__"] + n_threads = 8 + iterations_per_thread = 25 + + def worker() -> None: + for _ in range(iterations_per_thread): + save_main_module() + patch_main_module( + file=None, input_override=None, print_override=None + ) + restore_main_module() + + threads = [threading.Thread(target=worker) for _ in range(n_threads)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert sys.modules["__main__"] is original + + def test_save_when_main_is_absent_is_safe(self) -> None: + """Saving when sys.modules has no __main__ skips the restore write.""" + saved_main = sys.modules.pop("__main__", None) + try: + save_main_module() + # A save that captured None must not install a None __main__. + restore_main_module() + assert "__main__" not in sys.modules + finally: + if saved_main is not None: + sys.modules["__main__"] = saved_main + + def test_host_spawn_works_after_save_and_restore(self) -> None: + """End-to-end: restoring __main__ lets the host spawn subprocesses. + + Matches the reproduction in the bug report. Without the restore, + ``multiprocessing.get_context('spawn').Process(...).start()`` fails + because the synthetic __main__ has no handle to module-level + targets. With the restore, it succeeds. + """ + save_main_module() + try: + patch_main_module( + file=None, input_override=None, print_override=None + ) + finally: + restore_main_module() + + proc = multiprocessing.get_context("spawn").Process(target=noop_target) + proc.start() + proc.join(timeout=10) + assert proc.exitcode == 0, ( + "spawn subprocess failed after save/patch/restore cycle; " + "restore did not recover host's __main__" + ) diff --git a/tests/_session/managers/test_kernel.py b/tests/_session/managers/test_kernel.py new file mode 100644 index 00000000000..ac599b7849c --- /dev/null +++ b/tests/_session/managers/test_kernel.py @@ -0,0 +1,160 @@ +# Copyright 2026 Marimo. All rights reserved. +"""Tests for KernelManagerImpl's host __main__ save/restore methods. + +These methods are the application-layer glue between ``KernelManagerImpl`` +and the refcounted save/restore primitives in ``marimo._runtime.patches``. +They own the per-instance flag and lock that make the wrappers idempotent +against double ``close_kernel`` calls and concurrent callers. +""" + +from __future__ import annotations + +import sys +import threading + +from marimo._runtime import patches +from marimo._session.managers.kernel import KernelManagerImpl + + +def _make_bare_kernel_manager() -> KernelManagerImpl: + """Return an uninitialized KernelManagerImpl with only the state that + ``_save_host_main_module`` and ``_restore_host_main_module`` touch. + + Bypasses ``__init__`` so the test does not need to construct a real + queue manager, session config, etc. + """ + km = object.__new__(KernelManagerImpl) + km._main_save_lock = threading.Lock() + km._main_save_outstanding = False + return km + + +def _expected_refcount() -> int: + """Current value of the module-level save refcount for assertions.""" + return patches._main_save_count + + +class TestSaveHostMainModule: + def test_save_sets_outstanding_flag(self) -> None: + km = _make_bare_kernel_manager() + assert km._main_save_outstanding is False + km._save_host_main_module() + try: + assert km._main_save_outstanding is True + finally: + km._restore_host_main_module() + + def test_save_is_idempotent(self) -> None: + """A second save call without a restore must not double-increment.""" + km = _make_bare_kernel_manager() + before = _expected_refcount() + km._save_host_main_module() + try: + after_first = _expected_refcount() + km._save_host_main_module() + after_second = _expected_refcount() + assert after_first - before == 1 + assert after_second == after_first + finally: + km._restore_host_main_module() + assert _expected_refcount() == before + + def test_save_without_restore_leaves_refcount_held(self) -> None: + """Documents the leak contract: an unbalanced save holds the + process refcount and the per-instance flag until a paired restore + releases them. A caller that never calls ``close_kernel`` (e.g. + process exits abnormally) leaves the host ``__main__`` polluted; + this is the expected degraded behavior. + """ + km = _make_bare_kernel_manager() + before = _expected_refcount() + km._save_host_main_module() + try: + assert km._main_save_outstanding is True + assert _expected_refcount() == before + 1 + finally: + # Release so the leak does not cascade into other tests. + km._restore_host_main_module() + assert km._main_save_outstanding is False + assert _expected_refcount() == before + + +class TestRestoreHostMainModule: + def test_restore_without_save_is_noop(self) -> None: + km = _make_bare_kernel_manager() + before = _expected_refcount() + km._restore_host_main_module() + assert km._main_save_outstanding is False + assert _expected_refcount() == before + + def test_restore_is_idempotent(self) -> None: + """Blocker repro: double close_kernel must not over-release. + + Two ``KernelManagerImpl`` instances share the process refcount; + if one calls ``_restore_host_main_module`` twice, the second call + must not decrement the refcount and prematurely restore ``__main__`` + while the other instance is still holding a save. + """ + other = _make_bare_kernel_manager() + me = _make_bare_kernel_manager() + # ``other`` mimics a second session still holding a save. + other._save_host_main_module() + try: + me._save_host_main_module() + with_both = _expected_refcount() + + me._restore_host_main_module() + after_first_release = _expected_refcount() + assert after_first_release == with_both - 1 + assert me._main_save_outstanding is False + + # Second call from the same instance: no effect on refcount. + me._restore_host_main_module() + assert _expected_refcount() == after_first_release + assert me._main_save_outstanding is False + finally: + other._restore_host_main_module() + + def test_concurrent_restores_do_not_double_release(self) -> None: + """Many threads racing on restore release the refcount exactly once.""" + other = _make_bare_kernel_manager() + me = _make_bare_kernel_manager() + other._save_host_main_module() + try: + me._save_host_main_module() + before = _expected_refcount() + + barrier = threading.Barrier(8) + + def worker() -> None: + barrier.wait() + me._restore_host_main_module() + + threads = [threading.Thread(target=worker) for _ in range(8)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert me._main_save_outstanding is False + assert _expected_refcount() == before - 1 + finally: + other._restore_host_main_module() + + +class TestSaveRestoreIntegration: + def test_cycle_returns_main_to_host_original(self) -> None: + """Full save-patch-restore via the wrapper methods leaves main intact.""" + km = _make_bare_kernel_manager() + original = sys.modules["__main__"] + + km._save_host_main_module() + try: + patches.patch_main_module( + file=None, input_override=None, print_override=None + ) + assert sys.modules["__main__"] is not original + finally: + km._restore_host_main_module() + + assert sys.modules["__main__"] is original From f74bee9f108f3de9e6a6aff187c123d9258edce5 Mon Sep 17 00:00:00 2001 From: jacobcbeaudin <51803634+jacobcbeaudin@users.noreply.github.com> Date: Tue, 21 Apr 2026 13:58:09 -0700 Subject: [PATCH 2/2] chore: ruff formatting --- marimo/_runtime/patches.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/marimo/_runtime/patches.py b/marimo/_runtime/patches.py index 1a84bd85064..6ede331c665 100644 --- a/marimo/_runtime/patches.py +++ b/marimo/_runtime/patches.py @@ -15,9 +15,9 @@ from marimo._ast.parse import ast_parse from marimo._dependencies.dependencies import DependencyManager from marimo._runtime import marimo_browser, marimo_pdb +from marimo._utils.platform import is_pyodide LOGGER = _loggers.marimo_logger() -from marimo._utils.platform import is_pyodide Unpatch = Callable[[], None]