Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions marimo/_runtime/patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,19 @@
import functools
import sys
import textwrap
import threading
import types
from collections.abc import Callable
from typing import TYPE_CHECKING, Any

from marimo import _loggers
from marimo._ast.parse import ast_parse
from marimo._dependencies.dependencies import DependencyManager
from marimo._runtime import marimo_browser, marimo_pdb
from marimo._utils.platform import is_pyodide

LOGGER = _loggers.marimo_logger()

Unpatch = Callable[[], None]

if TYPE_CHECKING:
Expand Down Expand Up @@ -60,6 +64,58 @@ def patch_sys_module(module: types.ModuleType) -> None:
sys.modules[module.__name__] = module


# Refcount-based save/restore for the process's ``sys.modules['__main__']``.
# Run-mode (thread) kernels share sys.modules with the host; pairing every
# ``patch_main_module`` with save/restore keeps the host's main module from
# leaking the kernel's synthetic module after the session ends.
# Held by ``save_main_module`` and ``restore_main_module`` around every read
# and write of the two variables below.
_main_save_lock = threading.Lock()
# Number of ``save_main_module`` calls that have not yet been matched by a
# ``restore_main_module`` call.
_main_save_count = 0
# The ``sys.modules['__main__']`` entry captured when the count went from 0
# to 1. ``None`` when nothing is captured (it is reset after the last restore)
# or when ``sys.modules`` had no ``__main__`` entry at capture time.
_main_original: types.ModuleType | None = None


def save_main_module() -> None:
    """Acquire one reference on the host's ``sys.modules['__main__']``.

    Intended to bracket a ``patch_main_module`` call together with
    ``restore_main_module``. Only the acquisition that moves the reference
    count from zero snapshots the current ``sys.modules['__main__']``; any
    overlapping acquisitions reuse that snapshot, and the final matching
    ``restore_main_module`` reinstates it.

    Saves and restores have to be balanced by the caller. A save without a
    matching restore keeps the count above zero for the rest of the process,
    so the snapshot is never put back. A restore without a matching save
    does nothing.
    """
    global _main_save_count, _main_original
    with _main_save_lock:
        is_first_holder = _main_save_count == 0
        if is_first_holder:
            # Snapshot before counting, so a failure here leaves the count
            # untouched.
            _main_original = sys.modules.get("__main__")
            LOGGER.debug(
                "save_main_module captured original __main__=%r",
                _main_original,
            )
        _main_save_count += 1


def restore_main_module() -> None:
    """Drop one reference previously taken with ``save_main_module``.

    When the dropped reference was the last one outstanding, the snapshot of
    ``sys.modules['__main__']`` taken by the first save is written back and
    then forgotten. If no reference is outstanding, the call does nothing.
    """
    global _main_save_count, _main_original
    with _main_save_lock:
        if not _main_save_count:
            # Unbalanced restore: nothing was saved, so nothing to release.
            return
        _main_save_count -= 1
        if _main_save_count != 0 or _main_original is None:
            # Either other holders remain, or there is no snapshot to
            # reinstate.
            return
        original = _main_original
        sys.modules["__main__"] = original
        LOGGER.debug(
            "restore_main_module restored __main__ to %r",
            original,
        )
        _main_original = None


def patch_pyodide_networking() -> None:
import pyodide_http # type: ignore

Expand Down
151 changes: 114 additions & 37 deletions marimo/_session/managers/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@

LOGGER = _loggers.marimo_logger()

# Seconds to wait for the RUN-mode kernel thread to exit after we send
# StopKernelCommand. Long enough that a cooperative kernel winds down
# cleanly, short enough that a stuck kernel does not block a server.
# Used as the ``timeout`` for the kernel thread's ``join`` and reported in
# the warning logged when the thread is still alive afterwards.
_THREAD_KERNEL_JOIN_TIMEOUT_S = 5.0


class KernelManagerImpl(KernelManager):
"""Kernel manager using multiprocessing Process or threading Thread.
Expand Down Expand Up @@ -63,6 +68,13 @@ def __init__(
self._read_conn: TypedConnection[KernelMessage] | None = None
self._virtual_file_storage = virtual_file_storage

# Tracks an outstanding save_main_module awaiting its paired
# restore. Guarded by _main_save_lock so concurrent close_kernel
# or start_kernel calls on the same instance cannot double-release
# the __main__ refcount.
self._main_save_lock: threading.Lock = threading.Lock()
self._main_save_outstanding: bool = False

def start_kernel(self) -> None:
# We use a process in edit mode so that we can interrupt the app
# with a SIGINT; we don't mind the additional memory consumption,
Expand Down Expand Up @@ -121,36 +133,44 @@ def launch_kernel_with_cleanup(

install_thread_local_proxies()

assert self.queue_manager.stream_queue is not None
# Make threads daemons so killing the server immediately brings
# down all client sessions
self.kernel_task = threading.Thread(
target=launch_kernel_with_cleanup,
args=(
self.queue_manager.control_queue,
self.queue_manager.set_ui_element_queue,
self.queue_manager.completion_queue,
self.queue_manager.input_queue,
self.queue_manager.stream_queue,
# IPC not used in run mode
None,
is_edit_mode,
self.configs,
self.app_metadata,
self.config_manager.get_config(hide_secrets=False),
self._virtual_file_storage,
self.redirect_console_to_browser,
# win32 interrupt queue
None,
# profile path
None,
# log level
GLOBAL_SETTINGS.LOG_LEVEL,
),
# daemon threads can create child processes, unlike
# daemon processes
daemon=True,
)
self._save_host_main_module()
try:
assert self.queue_manager.stream_queue is not None
# daemon threads can create child processes (unlike daemon
# processes); daemon=True so a server shutdown tears all
# sessions down immediately.
self.kernel_task = threading.Thread(
target=launch_kernel_with_cleanup,
args=(
self.queue_manager.control_queue,
self.queue_manager.set_ui_element_queue,
self.queue_manager.completion_queue,
self.queue_manager.input_queue,
self.queue_manager.stream_queue,
# IPC not used in run mode
None,
is_edit_mode,
self.configs,
self.app_metadata,
self.config_manager.get_config(hide_secrets=False),
self._virtual_file_storage,
self.redirect_console_to_browser,
# win32 interrupt queue
None,
# profile path
None,
# log level
GLOBAL_SETTINGS.LOG_LEVEL,
),
daemon=True,
)
self.kernel_task.start()
except BaseException:
# The thread never started; release the save we just made so
# the __main__ refcount does not leak.
self._restore_host_main_module()
raise
return

self.kernel_task.start() # type: ignore
if listener is not None:
Expand Down Expand Up @@ -238,17 +258,74 @@ def interrupt_kernel(self) -> None:
LOGGER.debug("Sending SIGINT to kernel")
os.kill(self.kernel_task.pid, signal.SIGINT)

def _save_host_main_module(self) -> None:
"""Pair ``patch_main_module`` in the run-mode kernel with a restore.

Idempotent: a second call after a save-without-restore does not
double-increment the process-wide refcount, so a caller that
accidentally invokes ``start_kernel`` twice will not leak.

Lazy import so the session managers package does not circularly
depend on the runtime's patches module at type-checking time.
"""
with self._main_save_lock:
if self._main_save_outstanding:
return
from marimo._runtime.patches import save_main_module

save_main_module()
self._main_save_outstanding = True

def _restore_host_main_module(self) -> None:
"""Release the save made by ``_save_host_main_module``.

Idempotent: safe to call more than once and safe against
concurrent callers. Only the first call after a save decrements
the process-wide refcount, so repeated or overlapping
``close_kernel`` invocations do not corrupt other sessions.
"""
with self._main_save_lock:
if not self._main_save_outstanding:
return
from marimo._runtime.patches import restore_main_module

restore_main_module()
self._main_save_outstanding = False

def _stop_and_join_run_mode_kernel(self) -> bool:
"""Stop the run-mode kernel thread and wait for it to exit.

Returns ``True`` if the thread is known to be idle (either it was
already finished or it exited within the bounded join window) and
the caller may run post-session cleanup such as restoring host
``__main__``. Returns ``False`` if the thread is still alive after
the join timeout; the caller must skip cleanup because modifying
process state under a live cell corrupts in-flight execution.
"""
assert isinstance(self.kernel_task, threading.Thread)
if not self.kernel_task.is_alive():
return True
self.queue_manager.put_control_request(commands.StopKernelCommand())
# The kernel has already received the stop command, so a short
# upper bound is enough for a cooperative wind-down. A server that
# closes a session never pauses longer than this even if the kernel
# is stuck.
self.kernel_task.join(timeout=_THREAD_KERNEL_JOIN_TIMEOUT_S)
if self.kernel_task.is_alive():
LOGGER.warning(
"RUN-mode kernel thread did not exit within %.1fs of "
"StopKernelCommand; skipping host __main__ restore.",
_THREAD_KERNEL_JOIN_TIMEOUT_S,
)
return False
return True

def close_kernel(self) -> None:
assert self.kernel_task is not None, "kernel not started"

if isinstance(self.kernel_task, threading.Thread):
# in run mode
if self.kernel_task.is_alive():
# We don't join the kernel thread because we don't want to server
# to block on it finishing
self.queue_manager.put_control_request(
commands.StopKernelCommand()
)
if self._stop_and_join_run_mode_kernel():
self._restore_host_main_module()
else:
# otherwise we have something that is `ProcessLike`
if self.profile_path is not None and self.kernel_task.is_alive():
Expand Down
17 changes: 17 additions & 0 deletions tests/_runtime/_patches_spawn_target.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright 2026 Marimo. All rights reserved.
"""Spawn-pickleable target for test_patches.py.

Defined in a real importable module so spawn workers can resolve it via
``_fixup_main_from_path`` without relying on the test module's __main__.
"""

from __future__ import annotations


def noop_target() -> None:
    """Do nothing; serve as a ``multiprocessing.Process`` target.

    What the function does is irrelevant to the test. The test only cares
    whether the spawn worker can be set up at all.
    """
    return None
Loading
Loading