Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 76 additions & 2 deletions marimo/_ast/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import os
import sys
import threading
import weakref
from collections.abc import (
Callable,
Iterable,
Expand Down Expand Up @@ -83,14 +84,66 @@


class _Namespace(Mapping[str, object]):
"""Thin read-only mapping returned by ``app.run()``.

Performs additional ref-counting to prevent memory leaks, and correctly
clear module scope.
"""

def __init__(
self, dictionary: dict[str, object], owner: Cell | App
self,
dictionary: dict[str, object],
owner: Cell | App,
_module_dict: dict[str, Any] | None = None,
_required_refs: dict[str, set[str]] | None = None,
) -> None:
self._dict = dictionary
self._owner = owner
self._module_dict = _module_dict
# Closures capture tracked/module_dict, NOT self, to avoid
# preventing the _Namespace from being collected.
self._tracked: set[str] = set()
if _module_dict is not None:
tracked = self._tracked
required_refs = _required_refs or {}

def _on_namespace_collected() -> None:
# Compute names needed by any tracked variable
needed: set[str] = set()
for key in tracked:
needed |= required_refs.get(key, set())
# Only remove tracked keys whose deps are satisfied
removable = tracked - needed
for key in removable:
_module_dict.pop(key, None)
# Narrow tracked to removable — only those will
# fire finalizers when external refs drop.
tracked.difference_update(needed)
if not tracked:
_module_dict.clear()

weakref.finalize(self, _on_namespace_collected)

def __getitem__(self, item: str) -> object:
return self._dict[item]
val = self._dict[item]
if self._module_dict is not None and item not in self._tracked:
tracked = self._tracked
module_dict = self._module_dict
try:

def _on_value_collected() -> None:
tracked.discard(item)
if not tracked:
module_dict.clear()

weakref.finalize(val, _on_value_collected)
tracked.add(item)
except TypeError:
pass # not weakref-able (int, str, …)
return val

def __contains__(self, item: object) -> bool:
return item in self._dict

def __iter__(self) -> Iterator[str]:
return iter(self._dict)
Expand Down Expand Up @@ -618,12 +671,33 @@ def _flatten_outputs(self, outputs: dict[CellId_t, Any]) -> Sequence[Any]:
if not self._graph.is_disabled(cid) and cid in outputs
)

@staticmethod
def _build_required_refs(
graph: dataflow.DirectedGraph,
) -> dict[str, set[str]]:
"""Map each defined name to the set of names it depends on.

Used by _Namespace to avoid removing values that tracked
functions still need from the module dict.
"""
required_refs: dict[str, set[str]] = {}
for cell in graph.cells.values():
for name, vdata_list in cell.variable_data.items():
refs: set[str] = set()
for vdata in vdata_list:
refs |= vdata.required_refs
if refs:
required_refs[name] = refs
return required_refs

def _globals_to_defs(self, glbls: dict[str, Any]) -> _Namespace:
return _Namespace(
dictionary={
name: glbls[name] for name in self._defs if name in glbls
},
owner=self,
_module_dict=glbls,
_required_refs=self._build_required_refs(self._graph),
)

def run(
Expand Down
136 changes: 136 additions & 0 deletions tests/_ast/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -1844,3 +1844,139 @@ def test_get_runner(self, k: Kernel) -> None:
registry.remove_runner(app)
registry.remove_runner(other)
assert not registry._runners


class TestAppRunMemoryLeak:
"""Ensure memory released after app.run()"""

@staticmethod
def test_scoped_call_frees_memory() -> None:
import weakref

app = App()

@app.cell
def _():
class _Payload:
pass

payload = _Payload()
result = 42
return result, payload

def run_and_extract() -> tuple[int, weakref.ref[object]]:
_, defs = app.run()
ref = weakref.ref(defs["payload"])
return int(defs["result"]), ref

value, ref = run_and_extract()

assert value == 42
assert ref() is None, (
"Cell objects not freed after scoped app.run() returned"
)

@staticmethod
def test_accessed_values_freed_when_not_held() -> None:
"""Values accessed via defs[key] are tracked and removed from
the module dict on namespace collection. If no strong reference
is held externally, they are freed immediately."""
import weakref

app = App()

@app.cell
def _():
class _Tracker:
pass

tracker = _Tracker()
result = 42
return result, tracker

def run_and_extract() -> tuple[int, weakref.ref[object]]:
_, defs = app.run()
ref = weakref.ref(defs["tracker"])
return int(defs["result"]), ref

value, ref = run_and_extract()
assert value == 42
assert ref() is None

@staticmethod
def test_scoped_call_function_behavior() -> None:
"""
Check that scoped variables are capable of working until all references
are cleaned up.
"""
import weakref

app = App()

@app.cell
def _():
class _Tracker:
pass

tracked = _Tracker()
implicitly_tracked = _Tracker()

unreferenced = _Tracker()

def pure():
return 1

def uses_global():
return len([tracked, implicitly_tracked])

return pure, uses_global, tracker

def scope_values():
def run_and_extract() -> (
tuple[object, object, weakref.ref[object]]
):
_, defs = app.run()
ref = weakref.ref(defs["tracked"])
unref = weakref.ref(defs["unreferenced"])
assert defs["pure"]() == 1
assert defs["uses_global"]() == 2
return defs["pure"], defs["uses_global"], unref, ref

pure, uses_global, unref, ref = run_and_extract()

# tracker was accessed (tracked) — freed since no strong
# ref held externally.
assert unref() is None
assert ref() is not None
assert pure() == 1
# Known limitation: uses_global's dependency was tracked
# and freed, so __globals__ lookup fails.
assert uses_global() == 2
return weakref.ref(uses_global)

# No strong references exist at this point, so we should clean up.
uses_global_ref = scope_values()
# All strong refs dropped — function collected.
assert uses_global_ref() is None

@staticmethod
def test_repeated_runs_dont_accumulate() -> None:
"""Multiple app.run() calls should not accumulate live objects."""
import weakref

app = App()

@app.cell
def _():
class _Ephemeral:
pass

obj = _Ephemeral()
return (obj,)

prev = weakref.ref(type("_Dead", (), {})())
for i in range(5):
_, defs = app.run()
assert prev() is None, f"Run {i}: previous object not freed"
prev = weakref.ref(defs["obj"])
assert prev() is not None
Loading