Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions dimos/core/coordination/_test_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright 2026 Dimensional Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from reactivex.disposable import Disposable

from dimos.core.core import rpc
from dimos.core.module import Module
from dimos.core.stream import In, Out


class AliceModule(Module):
    """Minimal test module that replies to every greeting it receives.

    Subscribes to the ``greetings`` input stream on start and publishes a
    canned reply on the ``response`` output stream for each message.
    """

    # Inbound stream of greeting strings.
    greetings: In[str]
    # Outbound stream carrying Alice's replies.
    response: Out[str]

    @rpc
    def start(self) -> None:
        super().start()
        subscription = self.greetings.subscribe(self._on_greetings)
        self._disposables.add(Disposable(subscription))

    @rpc
    def stop(self) -> None:
        super().stop()

    def _on_greetings(self, greeting: str) -> None:
        reply = f"Hello {greeting} from Alice"
        self.response.publish(reply)
175 changes: 166 additions & 9 deletions dimos/core/coordination/module_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from collections import defaultdict
from collections.abc import Mapping
import importlib
import shutil
import sys
import threading
Expand All @@ -34,7 +35,7 @@
from dimos.utils.safe_thread_map import safe_thread_map

if TYPE_CHECKING:
from dimos.core.coordination.blueprints import Blueprint
from dimos.core.coordination.blueprints import Blueprint, _BlueprintAtom
from dimos.core.rpc_client import ModuleProxy, ModuleProxyProtocol

logger = setup_logger()
Expand All @@ -55,7 +56,11 @@ def __init__(
cls.deployment_identifier: cls(g=g) for cls in manager_types
}
self._deployed_modules = {}
self._deployed_atoms: dict[type[ModuleBase], _BlueprintAtom] = {}
self._resolved_module_refs: dict[tuple[type[ModuleBase], str], type[ModuleBase]] = {}
self._transport_registry: dict[tuple[str, type], PubSubTransport[Any]] = {}
self._class_aliases: dict[type[ModuleBase], type[ModuleBase]] = {}
self._module_transports: dict[type[ModuleBase], dict[str, PubSubTransport[Any]]] = {}
self._started = False

def start(self) -> None:
Expand Down Expand Up @@ -168,13 +173,20 @@ def start_all_modules(self) -> None:

safe_thread_map(modules, lambda m: m.start())

self._send_on_system_modules()

def _resolve_class(self, cls: type[ModuleBase]) -> type[ModuleBase]:
    """Map *cls* through any recorded class alias, or return it unchanged."""
    try:
        return self._class_aliases[cls]
    except KeyError:
        return cls

def get_instance(self, module: type[ModuleBase]) -> ModuleProxy:
    """Return the deployed proxy for *module*, following class aliases first."""
    resolved = self._resolve_class(module)
    return self._deployed_modules.get(resolved)  # type: ignore[return-value, no-any-return]

def _send_on_system_modules(self) -> None:
    """Notify every deployed module that supports it of the full module list."""
    snapshot = list(self._deployed_modules.values())
    for proxy in snapshot:
        # Not every module implements the hook; skip those that don't.
        if not hasattr(proxy, "on_system_modules"):
            continue
        proxy.on_system_modules(snapshot)

def get_instance(self, module: type[ModuleBase]) -> ModuleProxy:
    # Look up the deployed proxy keyed directly by the module class.
    # NOTE(review): a same-named definition nearby resolves class aliases
    # before the lookup — this appears to be the pre-change version shown in
    # the diff; confirm only one definition survives in the merged file.
    return self._deployed_modules.get(module)  # type: ignore[return-value, no-any-return]

def _connect_streams(self, blueprint: Blueprint) -> None:
streams: dict[tuple[str, type], list[tuple[type, str]]] = defaultdict(list)

Expand All @@ -194,6 +206,7 @@ def _connect_streams(self, blueprint: Blueprint) -> None:
for module, original_name in streams[key]:
instance = self.get_instance(module) # type: ignore[assignment]
instance.set_transport(original_name, transport) # type: ignore[union-attr]
self._module_transports.setdefault(module, {})[original_name] = transport
logger.info(
"Transport",
name=remapped_name,
Expand Down Expand Up @@ -285,15 +298,153 @@ def load_blueprint(
safe_thread_map(new_modules, lambda m: m.build())
safe_thread_map(new_modules, lambda m: m.start())

# Re-notify all modules about the updated module list.
all_modules = list(self._deployed_modules.values())
for module in all_modules:
if hasattr(module, "on_system_modules"):
module.on_system_modules(all_modules)
self._send_on_system_modules()

def load_module(self, module_class: type[ModuleBase[Any]], **kwargs: Any) -> None:
    """Deploy a single module by expanding it into a blueprint and loading it."""
    blueprint = module_class.blueprint(**kwargs)
    self.load_blueprint(blueprint)

def unload_module(self, module_class: type[ModuleBase]) -> None:
    """Stop a single deployed module and drop it from coordinator state.

    The module's worker-side instance is stopped and undeployed (which may
    shut down a now-empty worker process).  Stream transports and other
    modules' proxy references are deliberately left untouched; callers that
    will redeploy the module (e.g. ``restart_module``) handle rewiring.

    Raises:
        ValueError: if the module is not currently deployed.
        NotImplementedError: for non-python deployments.
    """
    module_class = self._resolve_class(module_class)
    if module_class not in self._deployed_modules:
        raise ValueError(f"{module_class.__name__} is not deployed")
    if module_class.deployment != "python":
        raise NotImplementedError(
            f"unload_module only supports python deployment, got {module_class.deployment!r}"
        )

    proxy = self._deployed_modules[module_class]

    # Best-effort stop: a misbehaving module must not block its own removal.
    try:
        proxy.stop()
    except Exception:
        logger.error(
            "Error stopping module during unload",
            module=module_class.__name__,
            exc_info=True,
        )

    worker_manager = cast("WorkerManagerPython", self._managers["python"])
    try:
        worker_manager.undeploy(proxy)
    except Exception:
        logger.error(
            "Error undeploying module from worker",
            module=module_class.__name__,
            exc_info=True,
        )

    # Purge every piece of coordinator bookkeeping that mentions the class.
    del self._deployed_modules[module_class]
    self._deployed_atoms.pop(module_class, None)
    self._module_transports.pop(module_class, None)
    self._class_aliases = {
        old: new for old, new in self._class_aliases.items() if new is not module_class
    }
    self._resolved_module_refs = {
        ref: target
        for ref, target in self._resolved_module_refs.items()
        if ref[0] is not module_class and target is not module_class
    }

def restart_module(
    self,
    module_class: type[ModuleBase],
    *,
    reload_source: bool = True,
) -> ModuleProxyProtocol:
    """Restart a single deployed module in place.

    Unloads *module_class*, optionally reloads its source file via
    ``importlib.reload`` so edited code is picked up, then redeploys it
    onto a fresh worker process, reconnects its streams to the existing
    transports, and re-injects the new proxy into every other module that
    held a reference to it.
    """
    module_class = self._resolve_class(module_class)
    if module_class not in self._deployed_modules:
        raise ValueError(f"{module_class.__name__} is not deployed")
    if module_class.deployment != "python":
        raise NotImplementedError(
            f"restart_module only supports python deployment, got {module_class.deployment!r}"
        )

    # Snapshot everything unload_module is about to discard: the original
    # deploy kwargs, the stream transports, and both directions of module
    # references (who points at this module, and who it points at).
    old_atom = self._deployed_atoms[module_class]
    kwargs = dict(old_atom.kwargs)
    saved_transports = dict(self._module_transports.get(module_class, {}))
    inbound_refs = [
        (consumer, ref_name)
        for (consumer, ref_name), target in self._resolved_module_refs.items()
        if target is module_class
    ]
    outbound_refs = [
        (ref_name, target)
        for (consumer, ref_name), target in self._resolved_module_refs.items()
        if consumer is module_class
    ]

    self.unload_module(module_class)

    if reload_source:
        # Re-import the defining module if it was somehow dropped from
        # sys.modules, then reload it so edited source is picked up.
        source_mod = sys.modules.get(module_class.__module__)
        if source_mod is None:
            source_mod = importlib.import_module(module_class.__module__)
        importlib.reload(source_mod)
        new_class = cast("type[ModuleBase]", getattr(source_mod, module_class.__name__))
    else:
        new_class = module_class

    if new_class is not module_class:
        # A reload produced a distinct class object: repoint any existing
        # aliases at the new class and alias the stale class to it, so
        # lookups via the old object keep working.
        for old_cls in list(self._class_aliases):
            if self._class_aliases[old_cls] is module_class:
                self._class_aliases[old_cls] = new_class
        self._class_aliases[module_class] = new_class

    python_wm = cast("WorkerManagerPython", self._managers["python"])
    new_proxy = python_wm.deploy_fresh(new_class, self._global_config, kwargs)
    self._deployed_modules[new_class] = new_proxy

    # Rebuild the blueprint atom from the fresh class so stream metadata
    # reflects any source edits.  NOTE(review): assumes the blueprint has
    # exactly one active atom — confirm for composite blueprints.
    new_bp = new_class.blueprint(**kwargs)
    new_atom = new_bp.active_blueprints[0]
    self._deployed_atoms[new_class] = new_atom

    # Reattach the previously-connected transports by stream name; streams
    # that had no transport before stay unconnected.
    for stream_ref in new_atom.streams:
        transport = saved_transports.get(stream_ref.name)
        if transport is not None:
            new_proxy.set_transport(stream_ref.name, transport)
    self._module_transports[new_class] = {
        s.name: t for s in new_atom.streams if (t := saved_transports.get(s.name)) is not None
    }

    # Re-inject the fresh proxy into every module that referenced the old one.
    for consumer_class, ref_name in inbound_refs:
        consumer_proxy = self._deployed_modules.get(consumer_class)
        if consumer_proxy is None:
            continue
        setattr(consumer_proxy, ref_name, new_proxy)
        consumer_proxy.set_module_ref(ref_name, new_proxy)  # type: ignore[attr-defined]
        self._resolved_module_refs[consumer_class, ref_name] = new_class

    # Restore the restarted module's own references to its targets.
    for ref_name, target_class in outbound_refs:
        target_proxy = self._deployed_modules.get(target_class)
        if target_proxy is None:
            continue
        setattr(new_proxy, ref_name, target_proxy)
        new_proxy.set_module_ref(ref_name, target_proxy)  # type: ignore[attr-defined]
        self._resolved_module_refs[new_class, ref_name] = target_class

    new_proxy.build()
    new_proxy.start()

    self._send_on_system_modules()

    return new_proxy

def loop(self) -> None:
stop = threading.Event()
try:
Expand Down Expand Up @@ -433,6 +584,9 @@ def _deploy_all_modules(

module_coordinator.deploy_parallel(module_specs)

for bp in blueprint.active_blueprints:
module_coordinator._deployed_atoms[bp.module] = bp


def _ref_msg(module_name: str, ref: object, spec_name: str, detail: str) -> str:
return (
Expand Down Expand Up @@ -578,6 +732,9 @@ def _connect_module_refs(
target_instance = module_coordinator.get_instance(target_module) # type: ignore[type-var,arg-type]
setattr(base_instance, ref_name, target_instance)
base_instance.set_module_ref(ref_name, target_instance)
module_coordinator._resolved_module_refs[base_module, ref_name] = cast(
"type[ModuleBase]", target_module
)

for (base_module, ref_name), proxy in disabled_ref_proxies.items():
base_instance = module_coordinator.get_instance(base_module)
Expand Down
21 changes: 21 additions & 0 deletions dimos/core/coordination/python_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,20 @@ def deploy_module(
finally:
self._reserved = max(0, self._reserved - 1)

def undeploy_module(self, module_id: int) -> None:
    """Ask the worker process to stop module *module_id* and forget it locally.

    Raises:
        RuntimeError: if the worker process has not been started, or if the
            worker reports an error while undeploying.
    """
    if self._conn is None:
        raise RuntimeError("Worker process not started")

    request = {"type": "undeploy_module", "module_id": module_id}
    # Hold the lock across the send/recv pair so concurrent RPCs can't
    # interleave their request/response messages on the pipe.
    with self._lock:
        self._conn.send(request)
        reply = self._conn.recv()

    error = reply.get("error")
    if error:
        raise RuntimeError(f"Failed to undeploy module: {error}")

    self._modules.pop(module_id, None)

def suppress_console(self) -> None:
if self._conn is None:
return
Expand Down Expand Up @@ -366,6 +380,13 @@ def _worker_loop(conn: Connection, instances: dict[int, Any], worker_id: int) ->
result = method(*request.get("args", ()), **request.get("kwargs", {}))
response["result"] = result

elif req_type == "undeploy_module":
module_id = request["module_id"]
instance = instances.pop(module_id, None)
if instance is not None:
instance.stop()
response["result"] = True

elif req_type == "suppress_console":
_suppress_console_output()
response["result"] = True
Expand Down
Loading
Loading