Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions marimo/_server/scratchpad.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@

from marimo._ai._tools.types import CodeExecutionResult
from marimo._messaging.cell_output import CellChannel
from marimo._messaging.notification import CellNotification
from marimo._messaging.notification import (
CellNotification,
CompletedRunNotification,
)
from marimo._messaging.serde import deserialize_kernel_message
from marimo._runtime.scratch import SCRATCH_CELL_ID
from marimo._session.extensions.types import EventAwareExtension
Expand Down Expand Up @@ -99,13 +102,17 @@ def on_notification_sent(
) -> None:
del session
msg = deserialize_kernel_message(notification)

# Check for completion sentinel FIRST (waits for downstream errors)
if isinstance(msg, CompletedRunNotification):
self._queue.put_nowait(None) # sentinel
return

if not isinstance(msg, CellNotification):
return

if msg.cell_id == SCRATCH_CELL_ID:
self._queue.put_nowait(msg)
if msg.status == "idle":
self._queue.put_nowait(None) # sentinel
else:
if msg.console is not None:
# Stream console output from cells run by _code_mode
Expand Down
2 changes: 1 addition & 1 deletion tests/_convert/markdown/snapshots/dataflow.md.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ A marimo notebook is a directed acyclic graph in which nodes represent
cells and edges represent data dependencies. marimo creates this graph by
analyzing each cell (without running it) to determine its

- references ("refs*), the global variables it reads but doesn't define;
- references ("refs"), the global variables it reads but doesn't define;
- definitions ("defs"), the global variables it defines.

There is an edge from one cell to another if the latter cell references any
Expand Down
180 changes: 179 additions & 1 deletion tests/_server/test_scratchpad.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright 2026 Marimo. All rights reserved.
from __future__ import annotations

import asyncio
import json
from unittest.mock import MagicMock

Expand All @@ -13,7 +14,10 @@
MarimoExceptionRaisedError,
MarimoSyntaxError,
)
from marimo._messaging.notification import CellNotification
from marimo._messaging.notification import (
CellNotification,
CompletedRunNotification,
)
from marimo._runtime.scratch import SCRATCH_CELL_ID
from marimo._server.scratchpad import (
ScratchCellListener,
Expand Down Expand Up @@ -416,6 +420,11 @@ async def test_stream_basic(self) -> None:
session, serialize_kernel_message(notif)
)

# Send CompletedRunNotification to trigger sentinel
listener.on_notification_sent(
session, serialize_kernel_message(CompletedRunNotification())
)

events: list[str] = []
async for event in listener.stream():
events.append(event)
Expand Down Expand Up @@ -467,6 +476,11 @@ async def test_captures_other_cell_console(self) -> None:
idle = CellNotification(cell_id=SCRATCH_CELL_ID, status="idle")
listener.on_notification_sent(session, serialize_kernel_message(idle))

# Send CompletedRunNotification to trigger sentinel
listener.on_notification_sent(
session, serialize_kernel_message(CompletedRunNotification())
)

events: list[str] = []
async for event in listener.stream():
events.append(event)
Expand Down Expand Up @@ -516,3 +530,167 @@ async def consume() -> None:
name, payload = _parse_sse(events[0])
assert name == "stdout"
assert payload["data"] == "partial\n"

@pytest.mark.asyncio
async def test_state_setter_cascade_error_captured(
self,
) -> None:
"""Downstream error from mo.state setter cascade must be captured.

Bug scenario 1: Scratchpad calls set_x(0), which triggers a downstream
cell (result = 1 / get_x()) to re-run and error. The error must be
captured BEFORE CompletedRunNotification is broadcast.
"""
from marimo._messaging.serde import serialize_kernel_message

listener = ScratchCellListener()
event_bus = MagicMock()
session = MagicMock()
listener.on_attach(session, event_bus)

# Send scratch cell running
listener.on_notification_sent(
session,
serialize_kernel_message(
CellNotification(cell_id=SCRATCH_CELL_ID, status="running")
),
)

# Send downstream error (e.g., result = 1 / get_x())
err = MarimoExceptionRaisedError(
msg="ZeroDivisionError: division by zero",
exception_type="ZeroDivisionError",
raising_cell=None,
)
listener.on_notification_sent(
session,
serialize_kernel_message(
CellNotification(
cell_id="downstream_cell", # type: ignore[arg-type]
output=CellOutput.errors([err]),
status="idle",
)
),
)

# Send scratch cell idle (OLD sentinel - should be ignored)
listener.on_notification_sent(
session,
serialize_kernel_message(
CellNotification(cell_id=SCRATCH_CELL_ID, status="idle")
),
)

# Send CompletedRunNotification (the real sentinel, sent after state_updates flush)
listener.on_notification_sent(
session,
serialize_kernel_message(CompletedRunNotification()),
)

await listener.wait()

# Downstream error must be captured
assert listener.child_error_summaries == [
"cell 'downstream_cell' raised ZeroDivisionError"
]

@pytest.mark.asyncio
async def test_run_cell_cascade_error_captured(
self,
) -> None:
"""Downstream error from ctx.run_cell() cascade must be captured.

Bug scenario 2: Scratchpad runs a notebook cell via ctx.run_cell(cid),
which triggers downstream reactive execution. The downstream cell error must be
captured BEFORE CompletedRunNotification is broadcast.
"""
from marimo._messaging.serde import serialize_kernel_message

listener = ScratchCellListener()
event_bus = MagicMock()
session = MagicMock()
listener.on_attach(session, event_bus)

# Send scratch cell running
listener.on_notification_sent(
session,
serialize_kernel_message(
CellNotification(cell_id=SCRATCH_CELL_ID, status="running")
),
)

# Send downstream reactive cell error (e.g., cell ran by ctx.run_cell())
err = MarimoExceptionRaisedError(
msg="ValueError: invalid value",
exception_type="ValueError",
raising_cell=None,
)
listener.on_notification_sent(
session,
serialize_kernel_message(
CellNotification(
cell_id="reactive_cell", # type: ignore[arg-type]
output=CellOutput.errors([err]),
status="idle",
)
),
)

# Send scratch cell idle (OLD sentinel - should be ignored)
listener.on_notification_sent(
session,
serialize_kernel_message(
CellNotification(cell_id=SCRATCH_CELL_ID, status="idle")
),
)

# Send CompletedRunNotification (the real sentinel)
listener.on_notification_sent(
session,
serialize_kernel_message(CompletedRunNotification()),
)

await listener.wait()

# Downstream error must be captured
assert listener.child_error_summaries == [
"cell 'reactive_cell' raised ValueError"
]

@pytest.mark.asyncio
async def test_scratch_cell_idle_does_not_trigger_sentinel(
self,
) -> None:
"""Scratch cell idle alone should NOT trigger sentinel.

Guards against regression: we must wait for CompletedRunNotification
even if scratch cell goes idle first.
"""
from marimo._messaging.serde import serialize_kernel_message

listener = ScratchCellListener()
event_bus = MagicMock()
session = MagicMock()
listener.on_attach(session, event_bus)

# Send scratch cell running
listener.on_notification_sent(
session,
serialize_kernel_message(
CellNotification(cell_id=SCRATCH_CELL_ID, status="running")
),
)

# Send scratch cell idle WITHOUT CompletedRunNotification
listener.on_notification_sent(
session,
serialize_kernel_message(
CellNotification(cell_id=SCRATCH_CELL_ID, status="idle")
),
)

# Wait a short time - listener should NOT have returned yet
await asyncio.wait_for(listener.wait(timeout=0.1), timeout=0.2)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this wait_for can be removed.

Suggested change
await asyncio.wait_for(listener.wait(timeout=0.1), timeout=0.2)
await listener.wait(timeout=0.1)


# Listener should have timed out, not returned
assert listener.timed_out is True
Loading