diff --git a/marimo/_ast/app.py b/marimo/_ast/app.py
index 2b3df508f1e..b47043316e5 100644
--- a/marimo/_ast/app.py
+++ b/marimo/_ast/app.py
@@ -6,6 +6,7 @@
 import os
 import sys
 import threading
+import weakref
 from collections.abc import (
     Callable,
     Iterable,
@@ -83,14 +84,70 @@
 
 
 class _Namespace(Mapping[str, object]):
+    """Thin read-only mapping returned by ``app.run()``.
+
+    Performs additional ref-counting to prevent memory leaks, and correctly
+    clears module scope.
+    """
+
     def __init__(
-        self, dictionary: dict[str, object], owner: Cell | App
+        self,
+        dictionary: dict[str, object],
+        owner: Cell | App,
+        _module_dict: dict[str, Any] | None = None,
+        _required_refs: dict[str, set[str]] | None = None,
     ) -> None:
         self._dict = dictionary
         self._owner = owner
+        self._module_dict = _module_dict
+        # Closures capture tracked/module_dict, NOT self, to avoid
+        # preventing the _Namespace from being collected.
+        self._tracked: set[str] = set()
+        if _module_dict is not None:
+            tracked = self._tracked
+            required_refs = _required_refs or {}
+
+            def _on_namespace_collected() -> None:
+                # Compute names needed by any tracked variable
+                needed: set[str] = set()
+                for key in tracked:
+                    needed |= required_refs.get(key, set())
+                # Only remove tracked keys that no tracked variable needs
+                removable = tracked - needed
+                for key in removable:
+                    _module_dict.pop(key, None)
+                # Narrow tracked to removable — only those will
+                # fire finalizers when external refs drop.
+                # NOTE(review): if every tracked key is also in `needed`
+                # (e.g. mutually dependent definitions), `tracked` ends up
+                # empty and the module dict is cleared even if those values
+                # are still referenced externally — confirm this is intended.
+                tracked.difference_update(needed)
+                if not tracked:
+                    _module_dict.clear()
+
+            weakref.finalize(self, _on_namespace_collected)
 
     def __getitem__(self, item: str) -> object:
-        return self._dict[item]
+        val = self._dict[item]
+        if self._module_dict is not None and item not in self._tracked:
+            tracked = self._tracked
+            module_dict = self._module_dict
+            try:
+
+                def _on_value_collected() -> None:
+                    tracked.discard(item)
+                    if not tracked:
+                        module_dict.clear()
+
+                weakref.finalize(val, _on_value_collected)
+                tracked.add(item)
+            except TypeError:
+                pass  # not weakref-able (int, str, …)
+        return val
+
+    def __contains__(self, item: object) -> bool:
+        return item in self._dict
 
     def __iter__(self) -> Iterator[str]:
         return iter(self._dict)
@@ -618,12 +675,33 @@ def _flatten_outputs(self, outputs: dict[CellId_t, Any]) -> Sequence[Any]:
             if not self._graph.is_disabled(cid) and cid in outputs
         )
 
+    @staticmethod
+    def _build_required_refs(
+        graph: dataflow.DirectedGraph,
+    ) -> dict[str, set[str]]:
+        """Map each defined name to the set of names it depends on.
+
+        Used by _Namespace to avoid removing values that tracked
+        functions still need from the module dict.
+        """
+        required_refs: dict[str, set[str]] = {}
+        for cell in graph.cells.values():
+            for name, vdata_list in cell.variable_data.items():
+                refs: set[str] = set()
+                for vdata in vdata_list:
+                    refs |= vdata.required_refs
+                if refs:
+                    required_refs[name] = refs
+        return required_refs
+
     def _globals_to_defs(self, glbls: dict[str, Any]) -> _Namespace:
         return _Namespace(
             dictionary={
                 name: glbls[name] for name in self._defs if name in glbls
             },
             owner=self,
+            _module_dict=glbls,
+            _required_refs=self._build_required_refs(self._graph),
         )
 
     def run(
diff --git a/tests/_ast/test_app.py b/tests/_ast/test_app.py
index a03cae79838..3788de8c3d4 100644
--- a/tests/_ast/test_app.py
+++ b/tests/_ast/test_app.py
@@ -1844,3 +1844,139 @@ def test_get_runner(self, k: Kernel) -> None:
         registry.remove_runner(app)
         registry.remove_runner(other)
         assert not registry._runners
+
+
+class TestAppRunMemoryLeak:
+    """Ensure memory released after app.run()"""
+
+    @staticmethod
+    def test_scoped_call_frees_memory() -> None:
+        import weakref
+
+        app = App()
+
+        @app.cell
+        def _():
+            class _Payload:
+                pass
+
+            payload = _Payload()
+            result = 42
+            return result, payload
+
+        def run_and_extract() -> tuple[int, weakref.ref[object]]:
+            _, defs = app.run()
+            ref = weakref.ref(defs["payload"])
+            return int(defs["result"]), ref
+
+        value, ref = run_and_extract()
+
+        assert value == 42
+        assert ref() is None, (
+            "Cell objects not freed after scoped app.run() returned"
+        )
+
+    @staticmethod
+    def test_accessed_values_freed_when_not_held() -> None:
+        """Values accessed via defs[key] are tracked and removed from
+        the module dict on namespace collection. If no strong reference
+        is held externally, they are freed immediately."""
+        import weakref
+
+        app = App()
+
+        @app.cell
+        def _():
+            class _Tracker:
+                pass
+
+            tracker = _Tracker()
+            result = 42
+            return result, tracker
+
+        def run_and_extract() -> tuple[int, weakref.ref[object]]:
+            _, defs = app.run()
+            ref = weakref.ref(defs["tracker"])
+            return int(defs["result"]), ref
+
+        value, ref = run_and_extract()
+        assert value == 42
+        assert ref() is None
+
+    @staticmethod
+    def test_scoped_call_function_behavior() -> None:
+        """
+        Check that scoped variables are capable of working until all references
+        are cleaned up.
+        """
+        import weakref
+
+        app = App()
+
+        @app.cell
+        def _():
+            class _Tracker:
+                pass
+
+            tracked = _Tracker()
+            implicitly_tracked = _Tracker()
+
+            unreferenced = _Tracker()
+
+            def pure():
+                return 1
+
+            def uses_global():
+                return len([tracked, implicitly_tracked])
+
+            return pure, uses_global, tracked
+
+        def scope_values():
+            def run_and_extract() -> (
+                tuple[object, object, weakref.ref[object], weakref.ref[object]]
+            ):
+                _, defs = app.run()
+                ref = weakref.ref(defs["tracked"])
+                unref = weakref.ref(defs["unreferenced"])
+                assert defs["pure"]() == 1
+                assert defs["uses_global"]() == 2
+                return defs["pure"], defs["uses_global"], unref, ref
+
+            pure, uses_global, unref, ref = run_and_extract()
+
+            # unreferenced was accessed but no held function needs it,
+            # so it is freed; tracked is still needed by uses_global.
+            assert unref() is None
+            assert ref() is not None
+            assert pure() == 1
+            # uses_global's dependencies are kept in the module dict
+            # while it is alive, so the __globals__ lookup still works.
+            assert uses_global() == 2
+            return weakref.ref(uses_global)
+
+        # No strong references exist at this point, so we should clean up.
+        uses_global_ref = scope_values()
+        # All strong refs dropped — function collected.
+        assert uses_global_ref() is None
+
+    @staticmethod
+    def test_repeated_runs_dont_accumulate() -> None:
+        """Multiple app.run() calls should not accumulate live objects."""
+        import weakref
+
+        app = App()
+
+        @app.cell
+        def _():
+            class _Ephemeral:
+                pass
+
+            obj = _Ephemeral()
+            return (obj,)
+
+        prev = weakref.ref(type("_Dead", (), {})())
+        for i in range(5):
+            _, defs = app.run()
+            assert prev() is None, f"Run {i}: previous object not freed"
+            prev = weakref.ref(defs["obj"])
+        assert prev() is not None