diff --git a/marimo/_server/scratchpad.py b/marimo/_server/scratchpad.py index f395f0df28c..1508e1ab8d6 100644 --- a/marimo/_server/scratchpad.py +++ b/marimo/_server/scratchpad.py @@ -15,7 +15,10 @@ from marimo._ai._tools.types import CodeExecutionResult from marimo._messaging.cell_output import CellChannel -from marimo._messaging.notification import CellNotification +from marimo._messaging.notification import ( + CellNotification, + CompletedRunNotification, +) from marimo._messaging.serde import deserialize_kernel_message from marimo._runtime.scratch import SCRATCH_CELL_ID from marimo._session.extensions.types import EventAwareExtension @@ -99,13 +102,17 @@ def on_notification_sent( ) -> None: del session msg = deserialize_kernel_message(notification) + + # Check for completion sentinel FIRST (waits for downstream errors) + if isinstance(msg, CompletedRunNotification): + self._queue.put_nowait(None) # sentinel + return + if not isinstance(msg, CellNotification): return if msg.cell_id == SCRATCH_CELL_ID: self._queue.put_nowait(msg) - if msg.status == "idle": - self._queue.put_nowait(None) # sentinel else: if msg.console is not None: # Stream console output from cells run by _code_mode diff --git a/tests/_convert/markdown/snapshots/dataflow.md.txt b/tests/_convert/markdown/snapshots/dataflow.md.txt index 59c192ea354..985bfaf879e 100644 --- a/tests/_convert/markdown/snapshots/dataflow.md.txt +++ b/tests/_convert/markdown/snapshots/dataflow.md.txt @@ -38,7 +38,7 @@ A marimo notebook is a directed acyclic graph in which nodes represent cells and edges represent data dependencies. marimo creates this graph by analyzing each cell (without running it) to determine its -- references ("refs*), the global variables it reads but doesn't define; +- references ("refs"), the global variables it reads but doesn't define; - definitions ("defs"), the global variables it defines. There is an edge from one cell to another if the latter cell references any diff --git a/tests/_server/test_scratchpad.py b/tests/_server/test_scratchpad.py index 565eed36e69..25e33ab4dec 100644 --- a/tests/_server/test_scratchpad.py +++ b/tests/_server/test_scratchpad.py @@ -1,6 +1,7 @@ # Copyright 2026 Marimo. All rights reserved. from __future__ import annotations +import asyncio import json from unittest.mock import MagicMock @@ -13,7 +14,10 @@ MarimoExceptionRaisedError, MarimoSyntaxError, ) -from marimo._messaging.notification import CellNotification +from marimo._messaging.notification import ( + CellNotification, + CompletedRunNotification, +) from marimo._runtime.scratch import SCRATCH_CELL_ID from marimo._server.scratchpad import ( ScratchCellListener, @@ -416,6 +420,11 @@ async def test_stream_basic(self) -> None: session, serialize_kernel_message(notif) ) + # Send CompletedRunNotification to trigger sentinel + listener.on_notification_sent( + session, serialize_kernel_message(CompletedRunNotification()) + ) + events: list[str] = [] async for event in listener.stream(): events.append(event) @@ -467,6 +476,11 @@ async def test_captures_other_cell_console(self) -> None: idle = CellNotification(cell_id=SCRATCH_CELL_ID, status="idle") listener.on_notification_sent(session, serialize_kernel_message(idle)) + # Send CompletedRunNotification to trigger sentinel + listener.on_notification_sent( + session, serialize_kernel_message(CompletedRunNotification()) + ) + events: list[str] = [] async for event in listener.stream(): events.append(event) @@ -516,3 +530,167 @@ async def consume() -> None: name, payload = _parse_sse(events[0]) assert name == "stdout" assert payload["data"] == "partial\n" + + @pytest.mark.asyncio + async def test_state_setter_cascade_error_captured( + self, + ) -> None: + """Downstream error from mo.state setter cascade must be captured. + + Bug scenario 1: Scratchpad calls set_x(0), which triggers a downstream + cell (result = 1 / get_x()) to re-run and error. The error must be + captured BEFORE CompletedRunNotification is broadcast. + """ + from marimo._messaging.serde import serialize_kernel_message + + listener = ScratchCellListener() + event_bus = MagicMock() + session = MagicMock() + listener.on_attach(session, event_bus) + + # Send scratch cell running + listener.on_notification_sent( + session, + serialize_kernel_message( + CellNotification(cell_id=SCRATCH_CELL_ID, status="running") + ), + ) + + # Send downstream error (e.g., result = 1 / get_x()) + err = MarimoExceptionRaisedError( + msg="ZeroDivisionError: division by zero", + exception_type="ZeroDivisionError", + raising_cell=None, + ) + listener.on_notification_sent( + session, + serialize_kernel_message( + CellNotification( + cell_id="downstream_cell", # type: ignore[arg-type] + output=CellOutput.errors([err]), + status="idle", + ) + ), + ) + + # Send scratch cell idle (OLD sentinel - should be ignored) + listener.on_notification_sent( + session, + serialize_kernel_message( + CellNotification(cell_id=SCRATCH_CELL_ID, status="idle") + ), + ) + + # Send CompletedRunNotification (the real sentinel, sent after state_updates flush) + listener.on_notification_sent( + session, + serialize_kernel_message(CompletedRunNotification()), + ) + + await listener.wait() + + # Downstream error must be captured + assert listener.child_error_summaries == [ + "cell 'downstream_cell' raised ZeroDivisionError" + ] + + @pytest.mark.asyncio + async def test_run_cell_cascade_error_captured( + self, + ) -> None: + """Downstream error from ctx.run_cell() cascade must be captured. + + Bug scenario 2: Scratchpad runs a notebook cell via ctx.run_cell(cid), + which triggers downstream reactive execution. The downstream cell error must be + captured BEFORE CompletedRunNotification is broadcast. + """ + from marimo._messaging.serde import serialize_kernel_message + + listener = ScratchCellListener() + event_bus = MagicMock() + session = MagicMock() + listener.on_attach(session, event_bus) + + # Send scratch cell running + listener.on_notification_sent( + session, + serialize_kernel_message( + CellNotification(cell_id=SCRATCH_CELL_ID, status="running") + ), + ) + + # Send downstream reactive cell error (e.g., cell ran by ctx.run_cell()) + err = MarimoExceptionRaisedError( + msg="ValueError: invalid value", + exception_type="ValueError", + raising_cell=None, + ) + listener.on_notification_sent( + session, + serialize_kernel_message( + CellNotification( + cell_id="reactive_cell", # type: ignore[arg-type] + output=CellOutput.errors([err]), + status="idle", + ) + ), + ) + + # Send scratch cell idle (OLD sentinel - should be ignored) + listener.on_notification_sent( + session, + serialize_kernel_message( + CellNotification(cell_id=SCRATCH_CELL_ID, status="idle") + ), + ) + + # Send CompletedRunNotification (the real sentinel) + listener.on_notification_sent( + session, + serialize_kernel_message(CompletedRunNotification()), + ) + + await listener.wait() + + # Downstream error must be captured + assert listener.child_error_summaries == [ + "cell 'reactive_cell' raised ValueError" + ] + + @pytest.mark.asyncio + async def test_scratch_cell_idle_does_not_trigger_sentinel( + self, + ) -> None: + """Scratch cell idle alone should NOT trigger sentinel. + + Guards against regression: we must wait for CompletedRunNotification + even if scratch cell goes idle first. + """ + from marimo._messaging.serde import serialize_kernel_message + + listener = ScratchCellListener() + event_bus = MagicMock() + session = MagicMock() + listener.on_attach(session, event_bus) + + # Send scratch cell running + listener.on_notification_sent( + session, + serialize_kernel_message( + CellNotification(cell_id=SCRATCH_CELL_ID, status="running") + ), + ) + + # Send scratch cell idle WITHOUT CompletedRunNotification + listener.on_notification_sent( + session, + serialize_kernel_message( + CellNotification(cell_id=SCRATCH_CELL_ID, status="idle") + ), + ) + + # Wait a short time - listener should NOT have returned yet + await asyncio.wait_for(listener.wait(timeout=0.1), timeout=0.2) + + # Listener should have timed out, not returned + assert listener.timed_out is True