Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions bec_lib/bec_lib/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,54 @@ class MessageEndpoints:
Class for message endpoints.
"""

@staticmethod
def shared_memory_info():
    """
    Endpoint on which the shared memory allocation info is published.

    Messages on this endpoint are of type messages.SharedMemAllocationInfo and
    are sent with the SET_PUBLISH message operation.

    Returns:
        EndpointInfo: Endpoint for shared memory information.
    """
    return EndpointInfo(
        endpoint=f"{EndpointType.INFO.value}/shared_memory/info/",
        message_type=messages.SharedMemAllocationInfo,
        message_op=MessageOp.SET_PUBLISH,
    )

@staticmethod
def shared_memory_allocate():
    """
    Endpoint for requesting the allocation of a shared memory object.

    Requests are messages.SharedMemAllocationRequest messages sent with the
    STREAM message operation.

    Returns:
        EndpointInfo: Endpoint for shared memory allocation.
    """
    return EndpointInfo(
        endpoint=f"{EndpointType.INFO.value}/shared_memory/allocate",
        message_type=messages.SharedMemAllocationRequest,
        message_op=MessageOp.STREAM,
    )

@staticmethod
def shared_memory_deallocate():
    """
    Endpoint for requesting the deallocation of a shared memory object.

    Requests are messages.SharedMemDeallocationRequest messages sent with the
    STREAM message operation.

    Returns:
        EndpointInfo: Endpoint for shared memory deallocation.
    """
    return EndpointInfo(
        endpoint=f"{EndpointType.INFO.value}/shared_memory/deallocate",
        message_type=messages.SharedMemDeallocationRequest,
        message_op=MessageOp.STREAM,
    )

# devices feedback
@staticmethod
def device_status(device: str):
Expand Down
36 changes: 36 additions & 0 deletions bec_lib/bec_lib/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@

from bec_lib.metadata_schema import get_metadata_schema_for_scan

# TODO: remove the bec_server dependency.
from bec_server.shared_memory.models import PayloadDescriptor, SharedMemInfo


class ProcedureWorkerStatus(Enum):
RUNNING = auto()
Expand Down Expand Up @@ -94,6 +97,39 @@ def __hash__(self) -> int:
return self.model_dump_json().__hash__()


class SharedMemAllocationInfo(BECMessage):
    """
    Snapshot of all currently allocated shared memory objects.

    The shared memory manager publishes this message; a fresh snapshot is sent
    whenever a shared memory object is created or destroyed.
    """

    msg_type: ClassVar[str] = "shared_mem_allocation_info"

    # Nested mapping: the outer key is the id of the requesting client, the inner key is
    # the signal name, and the value describes the shared memory object for that pair.
    # Design alternative under consideration: a flat dict[str, SharedMemInfo] keyed by the
    # dotted signal name, which would identify an entry directly.
    info: dict[str, dict[str, SharedMemInfo]]


class SharedMemAllocationRequest(BECMessage):
    """Message to request the allocation of a shared memory object."""

    msg_type: ClassVar[str] = "shared_mem_allocation_request"

    # Identifier of the client that asks for the allocation.
    client_id: str
    # Number of slots the allocated buffer should provide.
    slots: int
    # Description of the payload stored in each slot.
    payload_desc: PayloadDescriptor
    # Name of the signal the buffer is allocated for (optional).
    signal: str | None = None


class SharedMemDeallocationRequest(BECMessage):
    """
    Message to request deallocation of a shared memory object.

    The shared memory object is identified by the requesting client together with the
    signal name. Both the client helper (which sends ``client_id`` and ``signal``) and the
    manager (which looks the buffer up via ``request.signal``) rely on the ``signal`` field,
    so it is part of the message. ``shared_mem_info`` stays available for callers that
    already hold the full descriptor, but it is no longer mandatory.
    """

    msg_type: ClassVar[str] = "shared_mem_deallocation_request"

    # Identifier of the client that owns the shared memory object.
    client_id: str
    # Name of the signal whose shared memory object should be released.
    signal: str | None = None
    # Optional full description of the shared memory object to release.
    shared_mem_info: SharedMemInfo | None = None


class BundleMessage(BECMessage):
"""Message type to send a bundle of BECMessages.

Expand Down
Empty file.
Empty file.
36 changes: 36 additions & 0 deletions bec_server/bec_server/shared_memory/cli/launch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Description: Launch the shared memory manager server.
# This script is the entry point for the Shared Memory Manager Server. It is called either
# by the bec-shared-mem-manager entry point or directly from the command line.
import threading

from bec_lib.bec_service import parse_cmdline_args
from bec_lib.logger import bec_logger
from bec_lib.redis_connector import RedisConnector
from bec_server.shared_memory.manager import SharedMemoryManager

logger = bec_logger.logger
bec_logger.level = bec_logger.LOGLEVEL.INFO


def main():
    """
    Start the shared memory manager server and block until it is interrupted.

    The configuration is taken from the command line arguments. On Ctrl+C
    (KeyboardInterrupt) the manager is shut down.
    """
    _, _, config = parse_cmdline_args()

    manager = SharedMemoryManager(config=config, connector_cls=RedisConnector)
    manager.start()

    try:
        # The event is never set by anyone else; waiting on it simply parks the main thread.
        stop_event = threading.Event()
        logger.success(
            f"Started Shared Memory Manager server (id: {manager._service_id}). Press Ctrl+C to stop."
        )
        stop_event.wait()
    except KeyboardInterrupt:
        manager.shutdown()
        stop_event.set()


if __name__ == "__main__":
main()
139 changes: 139 additions & 0 deletions bec_server/bec_server/shared_memory/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from bec_lib.endpoints import MessageEndpoints
from bec_lib.logger import bec_logger
from bec_server.shared_memory.models import PayloadDescriptor
from bec_server.shared_memory.ring_buffer import RingBufferView

if TYPE_CHECKING:
import numpy as np

from bec_lib.connector import MessageObject
from bec_lib.messages import SharedMemAllocationInfo
from bec_lib.redis_connector import RedisConnector

logger = bec_logger.logger


# TODO one per service, or N per service.
class SharedMemoryClient:
    """
    Client for interacting with shared memory objects managed by the SharedMemoryManager.

    The client subscribes to the shared memory info endpoint and keeps one RingBufferView
    per signal that is allocated for this client's name.
    """

    def __init__(self, name: str, connector: RedisConnector):
        """
        Args:
            name: Client id; used to select this client's entries from info updates.
            connector: Connector used to send requests and to receive info updates.
        """
        self.name = name
        self.connector = connector
        # signal name to ring buffer view mapping
        self._ring_buffer_views: dict[str, RingBufferView] = {}
        self.start()

    def start(self):
        """Start the client by subscribing to the shared memory info endpoint."""
        self.connector.register(MessageEndpoints.shared_memory_info(), cb=self._handle_info_update)

    def _handle_info_update(self, info: MessageObject) -> None:
        """
        Synchronise the local ring buffer views with a published allocation snapshot.

        Every update carries the complete allocation table, so signals that already have
        a view are expected and left untouched. Views are created for signals that are
        new to this client and closed for signals that are no longer listed.

        Args:
            info: Message object whose ``value`` is a SharedMemAllocationInfo.
        """
        snapshot: SharedMemAllocationInfo = info.value
        client_info = snapshot.info.get(self.name, {})

        for signal, buff_info in client_info.items():
            if signal not in self._ring_buffer_views:
                self._ring_buffer_views[signal] = RingBufferView(descriptor=buff_info.buffer_desc)
            # NOTE(review): an existing view is kept even if the descriptor changed between
            # snapshots (deallocate + reallocate of the same signal) - confirm this is fine.

        # Always diff against the snapshot instead of comparing sizes: a snapshot may add
        # one signal and drop another, which leaves the sizes equal.
        for signal in set(self._ring_buffer_views) - set(client_info):
            self._ring_buffer_views.pop(signal).close()

    def request_allocation(
        self, signal_name: str, slots: int, payload_desc: PayloadDescriptor | dict
    ) -> None:
        """
        Request the allocation of a shared memory object.

        Args:
            signal_name: Name of the signal the buffer is requested for.
            slots: Number of slots to request.
            payload_desc: Payload description; a dict is validated into a PayloadDescriptor.
        """
        if isinstance(payload_desc, dict):
            payload_desc = PayloadDescriptor.model_validate(payload_desc)

        self.connector.xadd(
            MessageEndpoints.shared_memory_allocate(),
            {
                "client_id": self.name,
                "slots": slots,
                "payload_desc": payload_desc,
                "signal": signal_name,
            },
            max_size=1000,  # Keep history of 1000 allocation requests
        )

    def request_deallocation(self, signal_name: str) -> None:
        """
        Request the deallocation of a shared memory object.

        Args:
            signal_name: Name of the signal whose buffer should be released.
        """
        self.connector.xadd(
            MessageEndpoints.shared_memory_deallocate(),
            {"client_id": self.name, "signal": signal_name},
            max_size=1000,  # Keep history of 1000 deallocation requests
        )

    def read_from_buffer(
        self, signal_name: str, index: int, timeout: float | None = None
    ) -> np.ndarray:
        """
        Read data from the shared memory buffer associated with the given signal name.

        Args:
            signal_name: Name of the signal to read from.
            index: Slot index to read.
            timeout: Forwarded to RingBufferView.copy_data. The view is expected to raise a
                TimeoutError if the data cannot be read in time (implemented by the view,
                not visible here). This call is meant to block during write/read operations.

        Returns:
            np.ndarray: Copy of the data as returned by RingBufferView.copy_data.

        Raises:
            ValueError: If no buffer view exists for ``signal_name``.
        """
        # TODO add option to wait receiving an update on a specific signal in the buffer
        # Also block until there is an update on the specific index in the buffer.
        # Should there be a consume logic???
        buff = self._ring_buffer_views.get(signal_name)
        if buff is None:
            raise ValueError(f"No buffer found for signal name: {signal_name}")
        return buff.copy_data(index, timeout)

    def write_to_buffer(
        self, signal_name: str, data: np.ndarray, timeout: float | None = None
    ) -> int:
        """
        Write data to the next ring position associated with the given signal name.

        Args:
            signal_name: Name of the signal to write to.
            data: Payload to write.
            timeout: Forwarded to RingBufferView.write_data as ``acquire_timeout``. The view
                is expected to raise a TimeoutError if the data cannot be written in time
                (implemented by the view, not visible here).

        Returns:
            int: The slot index containing the newly written payload.

        Raises:
            ValueError: If no buffer view exists for ``signal_name``.
        """
        buff = self._ring_buffer_views.get(signal_name)
        if buff is None:
            raise ValueError(f"No buffer found for signal name: {signal_name}")
        return buff.write_data(data=data, acquire_timeout=timeout)

    def shutdown(self) -> None:
        """Clean up resources and all shared memory views."""
        try:
            # Unsubscribe first so a late info update cannot create new views while the
            # existing ones are being closed.
            self.connector.unregister(
                MessageEndpoints.shared_memory_info(), cb=self._handle_info_update
            )
        finally:
            for view in self._ring_buffer_views.values():
                view.close()
            self._ring_buffer_views.clear()


if __name__ == "__main__":
    import time

    import numpy as np

    from bec_lib.redis_connector import RedisConnector

    # Manual smoke test: request a buffer for a random 5x5 array from a manager reachable
    # through a local Redis instance and print the views this client ends up with.
    sample = np.random.rand(5, 5)
    redis = RedisConnector(bootstrap="localhost:6379")
    demo_client = SharedMemoryClient(name="test_client", connector=redis)
    descriptor = PayloadDescriptor.from_numpy(sample)
    demo_client.request_allocation(signal_name="test_signal", slots=10, payload_desc=descriptor)
    time.sleep(1)  # Wait for the allocation to be processed
    print(demo_client._ring_buffer_views)
104 changes: 104 additions & 0 deletions bec_server/bec_server/shared_memory/manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from __future__ import annotations

import threading
from collections import defaultdict
from typing import TYPE_CHECKING, Literal, Tuple

from bec_lib import messages
from bec_lib.bec_service import BECService
from bec_lib.endpoints import MessageEndpoints
from bec_lib.logger import bec_logger
from bec_server.shared_memory.models import SharedMemInfo
from bec_server.shared_memory.ring_buffer import RingBuffer

SUPPORTED_DATATYPES = Literal["str", "float", "byte", "np.array", "list", "dict"]

if TYPE_CHECKING:
from bec_lib.redis_connector import MessageObject, RedisConnector

logger = bec_logger.logger


class SharedMemoryManager(BECService):
    """
    Service to manage shared memory objects.

    It keeps track of all allocated shared memory objects and their descriptors, handles
    the creation and destruction of shared memory objects, and publishes the updated list
    of shared memory objects whenever one is created or destroyed.
    """

    def __init__(self, config, connector_cls: type[RedisConnector]) -> None:
        """
        Args:
            config: Service configuration, passed on to BECService.
            connector_cls: Connector class, passed on to BECService.
        """
        super().__init__(config, connector_cls, unique_service=True)
        # RingBuffer instances keyed by the (client_id, signal name) tuple
        self._shared_memory_objects: dict[Tuple[str, str], RingBuffer] = {}
        # Nested dict with client_id as key, and dict with signal name and SharedMemInfo as value
        self._shared_memory_info: dict[str, dict[str, SharedMemInfo]] = defaultdict(dict)
        # Guards both dictionaries above; re-entrant so nested helper calls may take it again
        self.lock = threading.RLock()

    def _allocate_memory(self, request: messages.SharedMemAllocationRequest) -> None:
        """
        Callback function to handle shared memory allocation requests.

        A request for a (client, signal) pair that already has a buffer is rejected; the
        existing buffer is kept and the current allocation info is republished.

        Args:
            request: Allocation request, either as message or as a dict that validates
                into a SharedMemAllocationRequest.
        """
        if isinstance(request, dict):
            request = messages.SharedMemAllocationRequest.model_validate(request)
        key = (request.client_id, request.signal)

        # The existence check and the insertion happen under the same lock so that two
        # concurrent requests for the same key cannot both create a buffer.
        with self.lock:
            if key in self._shared_memory_objects:
                logger.error(
                    f"Shared memory object for client {request.client_id} and signal {request.signal} already exists. Ignoring allocation request."
                )
                # TODO should this republish the info?
                self._publish_allocation_info(self._shared_memory_info)
                return

            buff = RingBuffer(
                slots=request.slots, payload=request.payload_desc, name_suffix=request.signal
            )
            self._shared_memory_objects[key] = buff
            self._shared_memory_info[request.client_id][request.signal] = SharedMemInfo(
                client_id=request.client_id, buffer_desc=buff.descriptor, signal=request.signal
            )
            self._publish_allocation_info(self._shared_memory_info)

    def _deallocate_memory(self, request: messages.SharedMemDeallocationRequest) -> None:
        """
        Callback function to handle shared memory deallocation requests.

        Args:
            request: Deallocation request, either as message or as a dict that validates
                into a SharedMemDeallocationRequest.
        """
        if isinstance(request, dict):
            request = messages.SharedMemDeallocationRequest.model_validate(request)
        key = (request.client_id, request.signal)

        # Look-up and removal are a single step under the lock (no check-then-act gap).
        with self.lock:
            buff = self._shared_memory_objects.pop(key, None)
            if buff is None:
                logger.error(
                    f"Shared memory object for client {request.client_id} and signal {request.signal} does not exist. Cannot deallocate."
                )
                # TODO should this republish the info?
                self._publish_allocation_info(self._shared_memory_info)
                return

            buff.destroy()
            # Use .get so the defaultdict does not create an entry as a side effect.
            client_entries = self._shared_memory_info.get(request.client_id)
            if client_entries is not None:
                client_entries.pop(request.signal, None)
                if not client_entries:
                    # Drop clients without allocations so the published info does not
                    # accumulate empty entries over time.
                    del self._shared_memory_info[request.client_id]
            self._publish_allocation_info(self._shared_memory_info)

    def _publish_allocation_info(self, info: dict[str, dict[str, SharedMemInfo]]) -> None:
        """
        Publish the updated list of allocated shared memory objects.

        Args:
            info: Mapping of client id to a mapping of signal name to SharedMemInfo.
        """
        self.connector.set_and_publish(
            MessageEndpoints.shared_memory_info(), messages.SharedMemAllocationInfo(info=info)
        )

    def start(self) -> None:
        """Start the shared memory manager server by registering the request callbacks."""
        self.connector.register(MessageEndpoints.shared_memory_allocate(), cb=self._allocate_memory)
        self.connector.register(
            MessageEndpoints.shared_memory_deallocate(), cb=self._deallocate_memory
        )

    def stop(self) -> None:
        """Destroy all shared memory objects and publish an empty allocation info."""
        with self.lock:
            for buff in self._shared_memory_objects.values():
                buff.destroy()
            self._shared_memory_objects.clear()
            self._shared_memory_info.clear()
            self._publish_allocation_info({})

    def shutdown(self) -> None:
        """Shutdown the shared memory manager server and destroy all shared memory objects."""
        self.stop()
        # Cleanup bec service related resources
        super().shutdown()
Loading
Loading