diff --git a/bec_lib/bec_lib/endpoints.py b/bec_lib/bec_lib/endpoints.py
index 930e3bbb0..3a5033713 100644
--- a/bec_lib/bec_lib/endpoints.py
+++ b/bec_lib/bec_lib/endpoints.py
@@ -83,6 +83,54 @@ class MessageEndpoints:
     Class for message endpoints.
     """
 
+    @staticmethod
+    def shared_memory_info():
+        """
+        Endpoint for shared memory information. This endpoint is used to publish the shared memory information using
+        a messages.SharedMemAllocationInfo message.
+
+        Returns:
+            EndpointInfo: Endpoint for shared memory information.
+        """
+        endpoint = f"{EndpointType.INFO.value}/shared_memory/info/"
+        return EndpointInfo(
+            endpoint=endpoint,
+            message_type=messages.SharedMemAllocationInfo,
+            message_op=MessageOp.SET_PUBLISH,
+        )
+
+    @staticmethod
+    def shared_memory_allocate():
+        """
+        Endpoint for shared memory allocation. This endpoint is used to request the allocation of a shared memory object using
+        a messages.SharedMemAllocationRequest message.
+
+        Returns:
+            EndpointInfo: Endpoint for shared memory allocation.
+        """
+        endpoint = f"{EndpointType.INFO.value}/shared_memory/allocate"
+        return EndpointInfo(
+            endpoint=endpoint,
+            message_type=messages.SharedMemAllocationRequest,
+            message_op=MessageOp.STREAM,
+        )
+
+    @staticmethod
+    def shared_memory_deallocate():
+        """
+        Endpoint for shared memory deallocation. This endpoint is used to request the deallocation of a shared memory object using
+        a messages.SharedMemDeallocationRequest message.
+
+        Returns:
+            EndpointInfo: Endpoint for shared memory deallocation.
+        """
+        endpoint = f"{EndpointType.INFO.value}/shared_memory/deallocate"
+        return EndpointInfo(
+            endpoint=endpoint,
+            message_type=messages.SharedMemDeallocationRequest,
+            message_op=MessageOp.STREAM,
+        )
+
     # devices feedback
     @staticmethod
     def device_status(device: str):
diff --git a/bec_lib/bec_lib/messages.py b/bec_lib/bec_lib/messages.py
index b94fc6938..d242ff470 100644
--- a/bec_lib/bec_lib/messages.py
+++ b/bec_lib/bec_lib/messages.py
@@ -17,6 +17,9 @@
 
 from bec_lib.metadata_schema import get_metadata_schema_for_scan
 
+# TODO remove bec_server dependency..
+from bec_server.shared_memory.models import PayloadDescriptor, SharedMemInfo
+
 
 class ProcedureWorkerStatus(Enum):
     RUNNING = auto()
@@ -94,6 +97,40 @@ def __hash__(self) -> int:
         return self.model_dump_json().__hash__()
 
 
+class SharedMemAllocationInfo(BECMessage):
+    """
+    This message is published by the shared memory manager and contains a list of all currently allocated shared memory objects.
+    Once shared memory objects are created or destroyed, this message will publish the updated list of shared memory objects.
+    """
+
+    msg_type: ClassVar[str] = "shared_mem_allocation_info"
+
+    # Consider structure with dict[str, SharedMemInfo]. signal dotted name as key, which allows to identify this directly
+    # Alternatively, dict[str, dict[str, SharedMemInfo]] with device name as key, and then signal name as nested key
+    info: dict[str, dict[str, SharedMemInfo]]
+
+
+class SharedMemAllocationRequest(BECMessage):
+    """Message to request the allocation of a shared memory object."""
+
+    msg_type: ClassVar[str] = "shared_mem_allocation_request"
+
+    client_id: str
+    slots: int
+    payload_desc: PayloadDescriptor
+    signal: str | None = None
+
+
+class SharedMemDeallocationRequest(BECMessage):
+    """Message to request deallocation of the shared memory object identified by client_id and signal."""
+
+    msg_type: ClassVar[str] = "shared_mem_deallocation_request"
+
+    client_id: str
+    signal: str | None = None
+    shared_mem_info: SharedMemInfo | None = None
+
+
 class BundleMessage(BECMessage):
     """Message type to send a bundle of BECMessages.
 
diff --git a/bec_server/bec_server/shared_memory/__init__.py b/bec_server/bec_server/shared_memory/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/bec_server/bec_server/shared_memory/cli/__init__.py b/bec_server/bec_server/shared_memory/cli/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/bec_server/bec_server/shared_memory/cli/launch.py b/bec_server/bec_server/shared_memory/cli/launch.py
new file mode 100644
index 000000000..28bfd5f5a
--- /dev/null
+++ b/bec_server/bec_server/shared_memory/cli/launch.py
@@ -0,0 +1,36 @@
+# Description: Launch the shared memory manager server.
+# This script is the entry point for the Shared Memory Manager Server. It is called either
+# by the bec-shared-mem-manager entry point or directly from the command line.
+import threading
+
+from bec_lib.bec_service import parse_cmdline_args
+from bec_lib.logger import bec_logger
+from bec_lib.redis_connector import RedisConnector
+from bec_server.shared_memory.manager import SharedMemoryManager
+
+logger = bec_logger.logger
+bec_logger.level = bec_logger.LOGLEVEL.INFO
+
+
+def main():
+    """
+    Launch the shared memory manager server.
+    """
+    _, _, config = parse_cmdline_args()
+
+    bec_server = SharedMemoryManager(config=config, connector_cls=RedisConnector)
+    bec_server.start()
+
+    try:
+        event = threading.Event()
+        logger.success(
+            f"Started Shared Memory Manager server (id: {bec_server._service_id}). Press Ctrl+C to stop."
+        )
+        event.wait()
+    except KeyboardInterrupt:
+        bec_server.shutdown()
+        event.set()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/bec_server/bec_server/shared_memory/client.py b/bec_server/bec_server/shared_memory/client.py
new file mode 100644
index 000000000..674d86ba6
--- /dev/null
+++ b/bec_server/bec_server/shared_memory/client.py
@@ -0,0 +1,134 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from bec_lib.endpoints import MessageEndpoints
+from bec_lib.logger import bec_logger
+from bec_server.shared_memory.models import PayloadDescriptor
+from bec_server.shared_memory.ring_buffer import RingBufferView
+
+if TYPE_CHECKING:
+    import numpy as np
+
+    from bec_lib.connector import MessageObject
+    from bec_lib.messages import SharedMemAllocationInfo
+    from bec_lib.redis_connector import RedisConnector
+
+logger = bec_logger.logger
+
+
+# TODO one per service, or N per service.
+class SharedMemoryClient:
+    """Client for interacting with shared memory objects managed by the SharedMemoryManager."""
+
+    def __init__(self, name: str, connector: RedisConnector):
+        self.name = name
+        self.connector = connector
+        # signal name to ring buffer view mapping
+        self._ring_buffer_views: dict[str, RingBufferView] = {}
+        self.start()
+
+    def start(self):
+        """Start the client by subscribing to the shared memory object."""
+        self.connector.register(MessageEndpoints.shared_memory_info(), cb=self._handle_info_update)
+
+    def _handle_info_update(self, info: MessageObject) -> None:
+        """Synchronise the local ring buffer views with the published allocation info."""
+        msg: SharedMemAllocationInfo = info.value
+        # The manager always publishes the complete allocation table, so signals that already
+        # have a view are expected here and are left untouched.
+        client_info = msg.info.get(self.name, {})
+
+        for signal, buff_info in client_info.items():
+            if signal not in self._ring_buffer_views:
+                self._ring_buffer_views[signal] = RingBufferView(descriptor=buff_info.buffer_desc)
+
+        # Views without a matching entry refer to shared memory objects that were deallocated.
+        for name in set(self._ring_buffer_views) - set(client_info):
+            view = self._ring_buffer_views.pop(name)
+            view.close()
+
+    def request_allocation(
+        self, signal_name: str, slots: int, payload_desc: PayloadDescriptor | dict
+    ) -> None:
+        """Request the allocation of a shared memory object."""
+        if isinstance(payload_desc, dict):
+            payload_desc = PayloadDescriptor.model_validate(payload_desc)
+
+        self.connector.xadd(
+            MessageEndpoints.shared_memory_allocate(),
+            {
+                "client_id": self.name,
+                "slots": slots,
+                "payload_desc": payload_desc,
+                "signal": signal_name,
+            },
+            max_size=1000,  # Keep history of 1000 allocation requests
+        )
+
+    def request_deallocation(self, signal_name: str) -> None:
+        """Request the deallocation of a shared memory object."""
+        self.connector.xadd(
+            MessageEndpoints.shared_memory_deallocate(),
+            {"client_id": self.name, "signal": signal_name},
+            max_size=1000,  # Keep history of 1000 deallocation requests
+        )
+
+    def read_from_buffer(
+        self, signal_name: str, index: int, timeout: float | None = None
+    ) -> np.ndarray:
+        """
+        Read data from the shared memory buffer associated with the given signal name.
+        If timeout is provided, the method will wait for the specified time and raise a TimeoutError if it cannot
+        read the data within that time frame. Please be aware, this is meant to block during write/read operations.
+        """
+        # TODO add option to wait receiving an update on a specific signal in the buffer
+        # Also block until there is an update on the specific index in the buffer.
+        # Should there be a consume logic???
+        buff = self._ring_buffer_views.get(signal_name)
+        if buff is None:
+            raise ValueError(f"No buffer found for signal name: {signal_name}")
+        return buff.copy_data(index, timeout)
+
+    def write_to_buffer(
+        self, signal_name: str, data: np.ndarray, timeout: float | None = None
+    ) -> int:
+        """
+        Write data to the next ring position associated with the given signal name.
+        If timeout is provided, the method will wait for the specified time and raise a TimeoutError if it cannot
+        write the data within that time frame. Please be aware, this is meant to block during write/read operations.
+
+        Returns:
+            int: The slot index containing the newly written payload.
+        """
+        buff = self._ring_buffer_views.get(signal_name)
+        if buff is None:
+            raise ValueError(f"No buffer found for signal name: {signal_name}")
+        return buff.write_data(data=data, acquire_timeout=timeout)
+
+    def shutdown(self) -> None:
+        """Clean up resources and all shared memory views."""
+        # Stop receiving updates first so the callback cannot re-create views during cleanup.
+        self.connector.unregister(
+            MessageEndpoints.shared_memory_info(), cb=self._handle_info_update
+        )
+        for view in self._ring_buffer_views.values():
+            view.close()
+        self._ring_buffer_views.clear()
+
+
+if __name__ == "__main__":
+    import time
+
+    import numpy as np
+
+    from bec_lib.redis_connector import RedisConnector
+
+    array = np.random.rand(5, 5)
+    connector = RedisConnector(bootstrap="localhost:6379")
+    client = SharedMemoryClient(name="test_client", connector=connector)
+    client.request_allocation(
+        signal_name="test_signal", slots=10, payload_desc=PayloadDescriptor.from_numpy(array)
+    )
+    time.sleep(1)  # Wait for the allocation to be processed
+    print(client._ring_buffer_views)
diff --git a/bec_server/bec_server/shared_memory/manager.py b/bec_server/bec_server/shared_memory/manager.py
new file mode 100644
index 000000000..fad2412dc
--- /dev/null
+++ b/bec_server/bec_server/shared_memory/manager.py
@@ -0,0 +1,102 @@
+from __future__ import annotations
+
+import threading
+from collections import defaultdict
+from typing import TYPE_CHECKING, Literal, Tuple
+
+from bec_lib import messages
+from bec_lib.bec_service import BECService
+from bec_lib.endpoints import MessageEndpoints
+from bec_lib.logger import bec_logger
+from bec_server.shared_memory.models import SharedMemInfo
+from bec_server.shared_memory.ring_buffer import RingBuffer
+
+SUPPORTED_DATATYPES = Literal["str", "float", "byte", "np.array", "list", "dict"]
+
+if TYPE_CHECKING:
+    from bec_lib.redis_connector import MessageObject, RedisConnector
+
+logger = bec_logger.logger
+
+
+class SharedMemoryManager(BECService):
+    """
+    Service to manage shared memory objects. It keeps track of all allocated shared memory objects and their descriptors.
+    It also handles the creation and destruction of shared memory objects, and publishes the updated list of shared memory objects
+    whenever a new shared memory object is created or destroyed.
+    """
+
+    def __init__(self, config, connector_cls: type[RedisConnector]) -> None:
+        super().__init__(config, connector_cls, unique_service=True)
+        # Shared memory objects are stored in a dictionary with the client_id and signal name tuple as key
+        # and the RingBuffer instance as value
+        self._shared_memory_objects: dict[Tuple[str, str], RingBuffer] = {}
+        self._shared_memory_info: dict[str, dict[str, SharedMemInfo]] = defaultdict(
+            dict
+        )  # Nested dict with client_id as key, and dict with signal name and ShareMemInfo as value
+        self.lock = threading.RLock()
+
+    def _allocate_memory(self, request: messages.SharedMemAllocationRequest) -> None:
+        """Callback function to handle shared memory allocation requests."""
+        if isinstance(request, dict):
+            request = messages.SharedMemAllocationRequest.model_validate(request)
+        key = (request.client_id, request.signal)
+        # Check and insert under the same lock so that concurrent requests for the same key
+        # cannot both create (and leak) a ring buffer.
+        with self.lock:
+            if key in self._shared_memory_objects:
+                logger.error(
+                    f"Shared memory object for client {request.client_id} and signal {request.signal} already exists. Ignoring request."
+                )
+            else:
+                buff = RingBuffer(
+                    slots=request.slots, payload=request.payload_desc, name_suffix=request.signal
+                )
+                self._shared_memory_objects[key] = buff
+                self._shared_memory_info[request.client_id][request.signal] = SharedMemInfo(
+                    client_id=request.client_id, buffer_desc=buff.descriptor, signal=request.signal
+                )
+            self._publish_allocation_info(self._shared_memory_info)
+
+    def _deallocate_memory(self, request: messages.SharedMemDeallocationRequest) -> None:
+        """Callback function to handle shared memory deallocation requests."""
+        if isinstance(request, dict):
+            request = messages.SharedMemDeallocationRequest.model_validate(request)
+        # Look up and remove under the same lock so concurrent requests cannot destroy twice.
+        with self.lock:
+            buff = self._shared_memory_objects.pop((request.client_id, request.signal), None)
+            if buff is None:
+                logger.error(
+                    f"Shared memory object for client {request.client_id} and signal {request.signal} does not exist. Cannot deallocate."
+                )
+            else:
+                buff.destroy()
+                self._shared_memory_info[request.client_id].pop(request.signal, None)
+            self._publish_allocation_info(self._shared_memory_info)
+
+    def _publish_allocation_info(self, info: dict[str, dict[str, SharedMemInfo]]) -> None:
+        """Publish the updated list of allocated shared memory objects."""
+        self.connector.set_and_publish(
+            MessageEndpoints.shared_memory_info(), messages.SharedMemAllocationInfo(info=info)
+        )
+
+    def start(self) -> None:
+        """start the shared memory manager server"""
+        self.connector.register(MessageEndpoints.shared_memory_allocate(), cb=self._allocate_memory)
+        self.connector.register(
+            MessageEndpoints.shared_memory_deallocate(), cb=self._deallocate_memory
+        )
+
+    def stop(self) -> None:
+        with self.lock:
+            for buff in self._shared_memory_objects.values():
+                buff.destroy()
+            self._shared_memory_objects.clear()
+            self._shared_memory_info.clear()
+            self._publish_allocation_info({})
+        # Cleanup bec service related resources
+
+    def shutdown(self) -> None:
+        """Shutdown the shared memory manager server and destroy all shared memory objects."""
+        self.stop()
+        super().shutdown()
diff --git a/bec_server/bec_server/shared_memory/models.py b/bec_server/bec_server/shared_memory/models.py
new file mode 100644
index 000000000..038afe78e
--- /dev/null
+++ b/bec_server/bec_server/shared_memory/models.py
@@ -0,0 +1,124 @@
+from __future__ import annotations
+
+import sys
+from typing import Literal, Tuple
+
+import numpy as np
+from pydantic import BaseModel, ConfigDict
+
+
+class SharedMemInfo(BaseModel):
+    """
+    Store information about the shared memory object. It holds the client_id, the buffer descriptor and,
+    optionally, the dotted name of the signal for which this shared memory object is relevant.
+    """
+
+    model_config = ConfigDict(validate_assignment=True)
+    client_id: str
+    buffer_desc: RingBufferDescriptor
+    signal: str | None = None  # dotted signal name, e.g. "eiger.preview"
+
+
+class DTypeDescriptor(BaseModel):
+    kind: Literal["uint", "int", "float", "bool"]
+    itemsize: int
+    byte_order: Literal["little", "big"] = "little"
+
+    @classmethod
+    def from_numpy(cls, dtype: np.dtype) -> DTypeDescriptor:
+        """Class method to create DTypeDescriptor from numpy dtype."""
+        dtype = np.dtype(dtype)
+        kind_map = {"u": "uint", "i": "int", "f": "float", "b": "bool"}
+        if dtype.kind not in kind_map:
+            raise ValueError(f"Unsupported dtype kind: {dtype.kind!r}")
+
+        byte_order = dtype.byteorder
+        if byte_order in ("=", "|"):
+            byte_order = sys.byteorder
+        elif byte_order == "<":
+            byte_order = "little"
+        elif byte_order == ">":
+            byte_order = "big"
+        else:
+            raise ValueError(f"Unsupported byte order: {dtype.byteorder!r}")
+
+        return cls(kind=kind_map[dtype.kind], itemsize=dtype.itemsize, byte_order=byte_order)
+
+    @property
+    def numpy_dtype(self) -> np.dtype:
+        """Return the corresponding numpy dtype for this DTypeDescriptor."""
+        byte_order_char = {"little": "<", "big": ">"}[self.byte_order]
+        kind_char = {"uint": "u", "int": "i", "float": "f", "bool": "b"}[self.kind]
+        dtype_str = f"{byte_order_char}{kind_char}{self.itemsize}"
+        return np.dtype(dtype_str)
+
+
+class PayloadDescriptor(BaseModel):
+    """Descriptor for the data payload stored in each slot of the ring buffer."""
+
+    nbytes: int
+    shape: Tuple[int, ...]
+    dtype: DTypeDescriptor
+    layout: Literal["C"] = "C"
+
+    @classmethod
+    def from_numpy(cls, array: np.ndarray) -> PayloadDescriptor:
+        """Class method to create PayloadDescriptor from a numpy array."""
+        return cls(
+            nbytes=array.nbytes,
+            shape=array.shape,
+            dtype=DTypeDescriptor.from_numpy(array.dtype),
+            layout="C",  # "C" is the only supported slot layout, whatever the source array's layout
+        )
+
+
+class RingBufferDescriptor(BaseModel):
+    """Information required to attach to a shared ring buffer."""
+
+    name: str
+    metadata_lock_id: str
+    slot_lock_ids: Tuple[str, ...]
+    slots: int
+    payload: PayloadDescriptor
+
+
+# class AvailableDataAnalysisMethods(messages.BECMessage):
+#     """Message published by the DAP server on which analysis methods are available."""
+
+#     methods: list[str]
+
+
+# TODO maybe not needed to warm up, could automatically start a DAP worker once a shared memory object is created,
+# Then DataAnalysisRegisterRequest is designed to register analysis methods for the shared memory object, and
+# DataAnalysisTrigger is designed to trigger the analysis of the shared memory object.
+# DataAnalysisResponse is designed to send the results back to the client.
+# class DataAnalysisRequestWarmup(BECMessage):
+#     """Message to request a data analysis"""
+
+#     shared_mem: SharedMemDescriptor
+
+
+# class DataAnalysisRegisterRequest(BECMessage):
+#     """Message to request processing of a shared memory object."""
+
+#     shared_mem: SharedMemDescriptor
+#     methods: list[str]
+#     client_id: str
+#     device: str | None = None
+
+
+# class DataAnalysisTrigger(BECMessage):
+#     """Message to request processing of a shared memory object."""
+
+#     shared_mem: SharedMemDescriptor
+#     index: int
+
+
+# class DataAnalysisResponse(BECMessage):
+#     """Message to request processing of a shared memory object."""
+
+#     shared_mem: SharedMemDescriptor
+#     index: int
+#     results: dict
+#     client_id: str
+#     device: str | None = None
diff --git a/bec_server/bec_server/shared_memory/ring_buffer.py b/bec_server/bec_server/shared_memory/ring_buffer.py
new file mode 100644
index 000000000..6d0072d3d
--- /dev/null
+++ b/bec_server/bec_server/shared_memory/ring_buffer.py
@@ -0,0 +1,331 @@
+from __future__ import annotations
+
+import struct
+from contextlib import contextmanager
+from functools import wraps
+from multiprocessing import shared_memory
+from threading import RLock
+from typing import Any, Callable, Iterator
+from uuid import uuid4
+
+import numpy as np
+import posix_ipc
+
+from bec_server.shared_memory.models import DTypeDescriptor, PayloadDescriptor, RingBufferDescriptor
+
+# pylint: disable=c-extension-no-member
+
+MAGIC = b"BEC_RING"
+MAX_NDIM = 8
+_HEADER = struct.Struct("<8sIQIBBBB8Q")
+HEADER_SIZE = _HEADER.size
+_NEXT_POSITION_OFFSET = struct.calcsize("<8sIQ")  # cursor follows magic, slots and nbytes
+
+_KIND_TO_CODE = {"uint": 1, "int": 2, "float": 3, "bool": 4}
+_CODE_TO_KIND = {code: kind for kind, code in _KIND_TO_CODE.items()}
+_BYTE_ORDER_TO_CODE = {"little": 1, "big": 2}
+_CODE_TO_BYTE_ORDER = {code: order for order, code in _BYTE_ORDER_TO_CODE.items()}
+
+
+def not_destroyed(method: Callable[..., Any]) -> Callable[..., Any]:
+    """Check that a shared-memory handle is still open before accessing it."""
+
+    @wraps(method)
+    def wrapper(self: RingBufferView, *args: Any, **kwargs: Any) -> Any:
+        if self.destroyed:
+            raise RuntimeError(
+                f"Cannot perform operation on a destroyed {self.__class__.__name__} object with name {self.name!r}."
+            )
+        return method(self, *args, **kwargs)
+
+    return wrapper
+
+
+class RingBufferView:
+    """Attached handle for accessing a ring buffer without owning its resources."""
+
+    def __init__(
+        self,
+        descriptor: RingBufferDescriptor,
+        shm: shared_memory.SharedMemory | None = None,
+        *,
+        owns_memory: bool = False,
+    ):
+        if len(descriptor.slot_lock_ids) != descriptor.slots:
+            raise ValueError("Ring buffer descriptor must provide exactly one lock per slot.")
+        self._descriptor = descriptor
+        self._owns_memory = owns_memory
+        self.__destroyed = False
+        self._lifecycle_lock = RLock()
+        self._metadata_lock: posix_ipc.Semaphore | None = None
+        self._slot_locks: list[posix_ipc.Semaphore] = []
+        self._shm = shm if shm is not None else shared_memory.SharedMemory(name=descriptor.name)
+        try:
+            # Open the semaphores inside the try so a failure half-way releases what is already open.
+            self._metadata_lock = posix_ipc.Semaphore(descriptor.metadata_lock_id, flags=0)
+            for lock_id in descriptor.slot_lock_ids:
+                self._slot_locks.append(posix_ipc.Semaphore(lock_id, flags=0))
+            self._slots, self._payload_descriptor = self._read_header()
+            if self._slots != descriptor.slots or self._payload_descriptor != descriptor.payload:
+                raise ValueError("Ring buffer descriptor does not match shared-memory metadata.")
+        except Exception:
+            self._close_handles()
+            raise
+
+    @staticmethod
+    def _encode_header(
+        slots: int, payload: PayloadDescriptor, next_write_position: int = 0
+    ) -> bytes:
+        if not 0 < slots:
+            raise ValueError("Ring buffer must contain at least one slot.")
+        if len(payload.shape) > MAX_NDIM:
+            raise ValueError(f"Ring buffer payload supports at most {MAX_NDIM} dimensions.")
+        dimensions = (*payload.shape, *((0,) * (MAX_NDIM - len(payload.shape))))
+        return _HEADER.pack(
+            MAGIC,
+            slots,
+            payload.nbytes,
+            next_write_position,
+            _KIND_TO_CODE[payload.dtype.kind],
+            payload.dtype.itemsize,
+            _BYTE_ORDER_TO_CODE[payload.dtype.byte_order],
+            len(payload.shape),
+            *dimensions,
+        )
+
+    def _read_header(self) -> tuple[int, PayloadDescriptor]:
+        (
+            magic,
+            slots,
+            bytes_per_slot,
+            _next_write_position,
+            kind_code,
+            itemsize,
+            byte_order_code,
+            ndim,
+            *dimensions,
+        ) = _HEADER.unpack_from(self._shm.buf)
+        if magic != MAGIC:
+            raise ValueError("Shared memory does not contain a BEC ring buffer header.")
+        if ndim > MAX_NDIM:
+            raise ValueError("Shared memory contains an invalid ring buffer payload shape.")
+        try:
+            dtype = DTypeDescriptor(
+                kind=_CODE_TO_KIND[kind_code],
+                itemsize=itemsize,
+                byte_order=_CODE_TO_BYTE_ORDER[byte_order_code],
+            )
+        except KeyError as exc:
+            raise ValueError(
+                "Shared memory contains an invalid ring buffer payload dtype."
+            ) from exc
+        payload = PayloadDescriptor(
+            nbytes=bytes_per_slot, shape=tuple(dimensions[:ndim]), dtype=dtype, layout="C"
+        )
+        if (
+            payload.nbytes
+            != int(np.prod(payload.shape, dtype=np.int64)) * dtype.numpy_dtype.itemsize
+        ):
+            raise ValueError("Shared memory contains inconsistent ring buffer payload metadata.")
+        return slots, payload
+
+    def _write_next_position(self, position: int) -> None:
+        struct.pack_into("<I", self._shm.buf, _NEXT_POSITION_OFFSET, position)
+
+    def _read_next_position(self) -> int:
+        return struct.unpack_from("<I", self._shm.buf, _NEXT_POSITION_OFFSET)[0]
+
+    @contextmanager
+    def _acquire(
+        self, semaphore: posix_ipc.Semaphore, timeout: float | None, operation: str
+    ) -> Iterator[None]:
+        acquired = False
+        try:
+            semaphore.acquire(timeout=None if timeout in (None, 0) else timeout)
+            acquired = True
+            yield
+        except posix_ipc.BusyError:
+            raise TimeoutError(
+                f"Could not acquire lock for {operation} buffer {self.name!r} within {timeout} seconds."
+            ) from None
+        finally:
+            if acquired:
+                semaphore.release()
+
+    def _validate_index(self, index: int) -> None:
+        if index < 0 or index >= self.slots:
+            raise IndexError(
+                f"Index {index} is out of bounds for ring buffer with {self.slots} slots."
+            )
+
+    def _validate_payload(self, data: np.ndarray) -> None:
+        descriptor = PayloadDescriptor.from_numpy(data)
+        if descriptor != self.payload_descriptor:
+            raise ValueError(
+                f"Data shape/dtype {descriptor.shape}/{descriptor.dtype} does not match expected "
+                f"shape/dtype {self.payload_descriptor.shape}/{self.payload_descriptor.dtype}"
+            )
+
+    def _array_for_slot(self, index: int) -> np.ndarray:
+        return np.ndarray(
+            shape=self.payload_descriptor.shape,
+            dtype=self.payload_descriptor.dtype.numpy_dtype,
+            buffer=self._shm.buf,
+            offset=HEADER_SIZE + index * self.bytes_per_slot,
+        )
+
+    def _claim_next_write_position(self, acquire_timeout: float | None = 0) -> int:
+        with self._acquire(self._metadata_lock, acquire_timeout, "selecting a write slot in"):
+            index = self._read_next_position()
+            self._write_next_position((index + 1) % self.slots)
+            return index
+
+    @not_destroyed
+    def copy_data(self, index: int, acquire_timeout: float | None = 0) -> np.ndarray:
+        """Copy one identified payload slot under that slot's semaphore."""
+        self._validate_index(index)
+        with self._acquire(self._slot_locks[index], acquire_timeout, "reading from"):
+            return self._array_for_slot(index).copy()
+
+    @not_destroyed
+    def write_data(self, data: np.ndarray, acquire_timeout: float | None = 0) -> int:
+        """Write to the next circular slot and return the written slot index."""
+        self._validate_payload(data)
+        index = self._claim_next_write_position(acquire_timeout)
+        self._write_data_at(index, data, acquire_timeout)
+        return index
+
+    @not_destroyed
+    def write_data_at(
+        self, index: int, data: np.ndarray, acquire_timeout: float | None = 0
+    ) -> None:
+        """Write directly to an identified slot without advancing the write cursor."""
+        self._validate_index(index)
+        self._validate_payload(data)
+        self._write_data_at(index, data, acquire_timeout)
+
+    def _write_data_at(
+        self, index: int, data: np.ndarray, acquire_timeout: float | None = 0
+    ) -> None:
+        with self._acquire(self._slot_locks[index], acquire_timeout, "writing to"):
+            np.copyto(self._array_for_slot(index), data)
+
+    @property
+    def descriptor(self) -> RingBufferDescriptor:
+        return self._descriptor
+
+    @property
+    def destroyed(self) -> bool:
+        return self.__destroyed
+
+    @property
+    def name(self) -> str:
+        return self._descriptor.name
+
+    @property
+    def slots(self) -> int:
+        return self._slots
+
+    @property
+    def bytes_per_slot(self) -> int:
+        return self._payload_descriptor.nbytes
+
+    @property
+    def payload_descriptor(self) -> PayloadDescriptor:
+        return self._payload_descriptor
+
+    @property
+    @not_destroyed
+    def next_write_position(self) -> int:
+        with self._acquire(self._metadata_lock, None, "reading metadata from"):
+            return self._read_next_position()
+
+    def _close_handles(self) -> None:
+        for lock in self._slot_locks:
+            lock.close()
+        if self._metadata_lock is not None:
+            self._metadata_lock.close()
+        self._shm.close()
+
+    def close(self) -> None:
+        """Close local handles without unlinking owner-managed resources."""
+        if self.destroyed:
+            return
+        with self._lifecycle_lock:
+            if self.destroyed:
+                return
+            self._close_handles()
+            self.__destroyed = True
+
+    def destroy(self) -> None:
+        """Compatibility alias for attached clients; attached handles only close resources."""
+        self.close()
+
+
+class RingBuffer(RingBufferView):
+    """Owner of a shared ring buffer and its semaphore resources."""
+
+    def __init__(self, slots: int, payload: PayloadDescriptor, name_suffix: str = ""):
+        name = f"bec_psm_{uuid4().hex[:6]}"
+        header = self._encode_header(slots, payload)
+        shm = shared_memory.SharedMemory(
+            create=True, size=HEADER_SIZE + slots * payload.nbytes, name=name
+        )
+        shm.buf[:HEADER_SIZE] = header
+        metadata_lock_name = f"{name}_md_lock"
+        slot_lock_names = tuple(f"{name}_slot__lock_{index}" for index in range(slots))
+        created_locks: list[posix_ipc.Semaphore] = []
+        try:
+            created_locks.append(
+                posix_ipc.Semaphore(
+                    metadata_lock_name, flags=posix_ipc.O_CREAT | posix_ipc.O_EXCL, initial_value=1
+                )
+            )
+            created_locks.extend(
+                posix_ipc.Semaphore(
+                    lock_name, flags=posix_ipc.O_CREAT | posix_ipc.O_EXCL, initial_value=1
+                )
+                for lock_name in slot_lock_names
+            )
+            for lock in created_locks:
+                lock.close()
+            descriptor = RingBufferDescriptor(
+                name=shm.name,
+                metadata_lock_id=metadata_lock_name,
+                slot_lock_ids=slot_lock_names,
+                slots=slots,
+                payload=payload,
+            )
+            super().__init__(descriptor=descriptor, shm=shm, owns_memory=True)
+        except Exception:
+            for lock in created_locks:
+                try:
+                    lock.close()
+                except (OSError, posix_ipc.Error):  # handle may already be closed; keep cleaning up
+                    pass
+                try:
+                    posix_ipc.unlink_semaphore(lock.name)
+                except posix_ipc.ExistentialError:
+                    pass
+            shm.close()
+            shm.unlink()
+            raise
+
+    def destroy(self) -> None:
+        """Close and unlink all resources created for this owned ring buffer."""
+        if self.destroyed:
+            return
+        descriptor = self.descriptor
+        self.close()
+        self._shm.unlink()
+        for lock_id in (descriptor.metadata_lock_id, *descriptor.slot_lock_ids):
+            try:
+                posix_ipc.unlink_semaphore(lock_id)
+            except posix_ipc.ExistentialError:
+                pass
+
+    @classmethod
+    def _name_suffix(cls, name: str, suffix: str, max_length: int = 63) -> str:
+        if suffix:
+            name = f"{name}_{suffix}"
+        return name[:max_length]
diff --git a/bec_server/pyproject.toml b/bec_server/pyproject.toml
index 2598be6ef..f0cfa79b8 100644
--- a/bec_server/pyproject.toml
+++ b/bec_server/pyproject.toml
@@ -27,6 +27,7 @@ dependencies = [
     "pyyaml~=6.0",
     "python-dotenv~=1.0",
     "rich>=13.7,<16.0",
+    "posix_ipc~=1.0",
 ]
 
 [project.optional-dependencies]
diff --git a/bec_server/tests/tests_shared_memory/test_ring_buffer.py b/bec_server/tests/tests_shared_memory/test_ring_buffer.py
new file mode 100644
index 000000000..29edc778e
--- /dev/null
+++ b/bec_server/tests/tests_shared_memory/test_ring_buffer.py
@@ -0,0 +1,185 @@
+from multiprocessing import shared_memory
+
+import numpy as np
+import posix_ipc
+import pytest
+
+from bec_server.shared_memory.models import PayloadDescriptor
+from bec_server.shared_memory.ring_buffer import HEADER_SIZE, RingBuffer, RingBufferView
+
+
+@pytest.fixture
+def payload() -> PayloadDescriptor:
+    return PayloadDescriptor.from_numpy(np.zeros((4,), dtype=np.float64))
+
+
+@pytest.fixture
+def ring_buffer(payload: PayloadDescriptor):
+    buffer = RingBuffer(slots=2, payload=payload)
+    yield buffer
+    buffer.destroy()
+
+
+def test_descriptor_exposes_attachment_resources_and_payload(
+    ring_buffer: RingBuffer, payload: PayloadDescriptor
+):
+    assert ring_buffer.descriptor.name == ring_buffer.name
+    assert ring_buffer.descriptor.slots == 2
+    assert ring_buffer.descriptor.payload == payload
+    assert ring_buffer.descriptor.metadata_lock_id
+    assert len(ring_buffer.descriptor.slot_lock_ids) == 2
+    assert len(set(ring_buffer.descriptor.slot_lock_ids)) == 2
+
+
+def test_header_reconstructs_payload_description_for_attached_view(
+    ring_buffer: RingBuffer, payload: PayloadDescriptor
+):
+    view = RingBufferView(ring_buffer.descriptor)
+    try:
+        assert view.slots == ring_buffer.descriptor.slots
+        assert view.bytes_per_slot == payload.nbytes
+        assert view.payload_descriptor == payload
+    finally:
+        view.close()
+
+
+def test_attached_view_rejects_descriptor_payload_mismatch(ring_buffer: RingBuffer):
+    invalid_payload = PayloadDescriptor.from_numpy(np.zeros((2,), dtype=np.float64))
+    descriptor = ring_buffer.descriptor.model_copy(update={"payload": invalid_payload})
+
+    with pytest.raises(ValueError, match="does not match shared-memory metadata"):
+        RingBufferView(descriptor)
+
+
+def test_write_data_uses_circular_position_and_returns_written_slot(ring_buffer: RingBuffer):
+    first = np.array([1, 2, 3, 4], dtype=np.float64)
+    second = np.array([5, 6, 7, 8], dtype=np.float64)
+    third = np.array([9, 10, 11, 12], dtype=np.float64)
+
+    assert ring_buffer.next_write_position == 0
+    assert ring_buffer.write_data(first) == 0
+    assert ring_buffer.next_write_position == 1
+    assert ring_buffer.write_data(second) == 1
+    assert ring_buffer.next_write_position == 0
+    assert ring_buffer.write_data(third) == 0
+    assert ring_buffer.next_write_position == 1
+    np.testing.assert_array_equal(ring_buffer.copy_data(0), third)
+    np.testing.assert_array_equal(ring_buffer.copy_data(1), second)
+
+
+def test_explicit_write_uses_header_payload_offset_without_advancing_cursor(
+    ring_buffer: RingBuffer, payload: PayloadDescriptor
+):
+    data = np.arange(4, dtype=np.float64)
+
+    ring_buffer.write_data_at(1, data)
+
+    raw_payload = np.ndarray(
+        payload.shape,
+        dtype=payload.dtype.numpy_dtype,
+        buffer=ring_buffer._shm.buf,
+        offset=HEADER_SIZE + payload.nbytes,
+    )
+    np.testing.assert_array_equal(raw_payload, data)
+    assert ring_buffer.next_write_position == 0
+
+
+def test_attached_view_shares_cursor_and_payload_storage(ring_buffer: RingBuffer):
+    view = RingBufferView(ring_buffer.descriptor)
+    try:
+        written_from_view = np.array([1, 2, 3, 4], dtype=np.float64)
+        written_from_owner = np.array([5, 6, 7, 8], dtype=np.float64)
+
+        assert view.write_data(written_from_view) == 0
+        np.testing.assert_array_equal(ring_buffer.copy_data(0), written_from_view)
+
+        assert ring_buffer.write_data(written_from_owner) == 1
+        np.testing.assert_array_equal(view.copy_data(1), written_from_owner)
+        assert view.next_write_position == 0
+    finally:
+        view.close()
+
+
+def test_each_buffer_has_distinct_shared_memory_and_semaphore_names(payload: PayloadDescriptor):
+    first = RingBuffer(slots=2, payload=payload)
+    second = RingBuffer(slots=2, payload=payload)
+    try:
+        assert first.name != second.name
+        assert first.descriptor.metadata_lock_id != second.descriptor.metadata_lock_id
+        assert first.descriptor.slot_lock_ids != second.descriptor.slot_lock_ids
+    finally:
+        first.destroy()
+        second.destroy()
+
+
+def test_slot_locks_do_not_block_access_to_other_slots(ring_buffer: RingBuffer):
+    first_slot_lock = posix_ipc.Semaphore(ring_buffer.descriptor.slot_lock_ids[0])
+    try:
+        first_slot_lock.acquire()
+        ring_buffer.write_data_at(1,
np.arange(4, dtype=np.float64), acquire_timeout=0.01) + finally: + first_slot_lock.release() + first_slot_lock.close() + + +@pytest.mark.parametrize("index", [-1, 2]) +def test_copy_data_rejects_indices_outside_slots(ring_buffer: RingBuffer, index: int): + with pytest.raises(IndexError, match="out of bounds"): + ring_buffer.copy_data(index) + + +@pytest.mark.parametrize("index", [-1, 2]) +def test_write_data_at_rejects_indices_outside_slots(ring_buffer: RingBuffer, index: int): + with pytest.raises(IndexError, match="out of bounds"): + ring_buffer.write_data_at(index, np.zeros((4,), dtype=np.float64)) + + +@pytest.mark.parametrize( + "data", [np.zeros((2,), dtype=np.float64), np.zeros((4,), dtype=np.float32)] +) +def test_write_data_rejects_payload_shape_or_dtype_mismatch( + ring_buffer: RingBuffer, data: np.ndarray +): + with pytest.raises(ValueError, match="does not match expected"): + ring_buffer.write_data(data) + + +def test_destroy_is_idempotent_and_rejects_further_operations( + ring_buffer: RingBuffer, payload: PayloadDescriptor +): + ring_buffer.destroy() + ring_buffer.destroy() + + with pytest.raises(RuntimeError, match="destroyed"): + ring_buffer.write_data(np.zeros(payload.shape, dtype=payload.dtype.numpy_dtype)) + + +def test_only_creator_owns_shared_memory_resources(ring_buffer: RingBuffer): + view = RingBufferView(ring_buffer.descriptor) + try: + assert ring_buffer._owns_memory is True + assert view._owns_memory is False + finally: + view.close() + + +def test_closing_view_keeps_owner_resources_attachable(ring_buffer: RingBuffer): + view = RingBufferView(ring_buffer.descriptor) + view.close() + + attached = RingBufferView(ring_buffer.descriptor) + attached.close() + assert ring_buffer.next_write_position == 0 + + +def test_destroying_owner_unlinks_shared_memory_and_semaphores(ring_buffer: RingBuffer): + descriptor = ring_buffer.descriptor + + ring_buffer.destroy() + + with pytest.raises(FileNotFoundError): + 
shared_memory.SharedMemory(name=descriptor.name) + with pytest.raises(posix_ipc.ExistentialError): + posix_ipc.Semaphore(descriptor.metadata_lock_id) + with pytest.raises(posix_ipc.ExistentialError): + posix_ipc.Semaphore(descriptor.slot_lock_ids[0])