From cc6dbc1dee624f40e7c88a0bf5afda10ac297eea Mon Sep 17 00:00:00 2001 From: appel_c Date: Wed, 6 May 2026 14:52:49 +0200 Subject: [PATCH 01/16] wip first draft --- .../bec_server/shared_memory/__init__.py | 0 .../bec_server/shared_memory/manager.py | 64 +++++++++++++ .../bec_server/shared_memory/ring_buffer.py | 91 +++++++++++++++++++ 3 files changed, 155 insertions(+) create mode 100644 bec_server/bec_server/shared_memory/__init__.py create mode 100644 bec_server/bec_server/shared_memory/manager.py create mode 100644 bec_server/bec_server/shared_memory/ring_buffer.py diff --git a/bec_server/bec_server/shared_memory/__init__.py b/bec_server/bec_server/shared_memory/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bec_server/bec_server/shared_memory/manager.py b/bec_server/bec_server/shared_memory/manager.py new file mode 100644 index 000000000..7c63970f6 --- /dev/null +++ b/bec_server/bec_server/shared_memory/manager.py @@ -0,0 +1,64 @@ +from typing import Literal, TypeVar + +from bec_lib.messages import BECMessage + +SUPPORTED_DATATYPES = Literal["str", "float", "byte", "np.array", "list", "dict"] + + +class SharedMemRequestAllocation(BECMessage): + """Message to send to the shared memory manager to create a new shared memory object.""" + + sender: Literal["device", "client"] + device: str | None = None + + +class SharedMemDescriptor(BECMessage): + """Message with metadata about the shared memory created in the shared memory manager.""" + + id: str + max_index: int + owner: Literal["device", "client"] + device: str | None = None + shape: tuple[int, ...] + dtype: SUPPORTED_DATATYPES + + +class AvailableDataAnalysisMethods(BECMessage): + """Message published by the DAP server on which analysis methods are available.""" + + methods: list[str] + + +class DataAnalysisRequestWarmup(BECMessage): + """Message to request a data analysis""" + + shared_mem: SharedMemDescriptor + + +class DataAnalysisRequest(BECMessage): + """Message to request processing of a shared memory object.""" + + shared_mem: SharedMemDescriptor + index: int + methods: list[str] + + +class DataAnalysisResponse(BECMessage): + """Message to request processing of a shared memory object.""" + + shared_mem: SharedMemDescriptor + index: int + methods: list[str] + results: dict + + +class SharedMemoryManager: + + def shutdown(self): + """Shutdown method, should clean up all shared memory objects.""" + + def create_shared_mem(self, msg: SharedMemRequestAllocation) -> str: + """Creates a shared memory object under a unique name.""" + + def _publish_shared_mem_info(self, msg: SharedMemDescriptor): + """Publish information about a shared memory object.""" diff --git a/bec_server/bec_server/shared_memory/ring_buffer.py b/bec_server/bec_server/shared_memory/ring_buffer.py new file mode 100644 index 000000000..736b47c00 --- /dev/null +++ b/bec_server/bec_server/shared_memory/ring_buffer.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +from dataclasses import dataclass +from multiprocessing import shared_memory +from typing import Any, Literal, Tuple + +from pydantic import BaseModel + + +@dataclass(frozen=True) +class DTypeDescriptor: + kind: Literal["uint", "int", "float", "bool"] + itemsize: int + byte_order: Literal["little", "big"] = "little" + + +@dataclass(frozen=True) +class PayloadDescriptor: + nbytes: int + shape: Tuple[int, ...] + dtype: DTypeDescriptor + layout: Literal["C"] = "C" + + +class SharedRingBufferDescriptor(BaseModel): + """Descriptor for SharedRingBuffer object.""" + + name: str + max_index: int + bytes_per_index: int + payload: PayloadDescriptor + # owner: Literal["device" , "client"] # To be checked if needed + + +class SharedRingBuffer: + """Descriptor for RingBuffer Object to share memory across processes.""" + + def __init__( + self, + shm: shared_memory.SharedMemory, + max_index: int, + bytes_per_index: int, + owns_memory: bool = False, + ): + self._shm = shm + self._max_index = max_index + self._bytes_per_index = bytes_per_index + self._owns_memory = owns_memory + + @property + def name(self): + """Name of shared ring buffer""" + return self._shm.name + + @property + def max_index(self): + """Max Index of shared ring buffer""" + return self._max_index + + @property + def bytes_per_index(self): + """Bytes per index in shared ring buffer""" + return self._bytes_per_index + + @classmethod + def create(cls, max_index: int, bytes_per_index: int) -> SharedRingBuffer: + """Create a new shared memory location and SharedRingBuffer object.""" + total_size = max_index * bytes_per_index + shm = shared_memory.SharedMemory(create=True, size=total_size) + return cls(shm, max_index=max_index, bytes_per_index=bytes_per_index, owns_memory=True) + + @classmethod + def attach(cls, descriptor: SharedRingBufferDescriptor) -> SharedRingBuffer: + """Create SharedRingBuffer by attaching to an existing shared memory object by descriptor name.""" + shm = shared_memory.SharedMemory(name=descriptor.name) + return cls(shm, max_index=descriptor.max_index, bytes_per_index=descriptor.bytes_per_index) + + def data(self, index: int) -> Any: + """Get data from SharedRingBuffer from index.""" + start = index * self.bytes_per_index + stop = start + self.bytes_per_index + return self._shm[start:stop] + + def close(self): + """Close the shared memory object.""" + self._shm.close() + + def unlink(self): + if not self._owns_memory: + raise RuntimeError(f"Can't unlike memory {self.name} that is not owned by this process") + self._shm.unlink() From 6c58d9f9d7343591ce9a01ec59d8956bd9dad4dd Mon Sep 17 00:00:00 2001 From: appel_c Date: Wed, 6 May 2026 17:23:18 +0200 Subject: [PATCH 02/16] wip, some help from chat gpt, to be reviewed and cleaned up. 4096x4096 has 7.07us write on avergage -> 19GB/s --- .../bec_server/shared_memory/ring_buffer.py | 187 +++++++++++++++--- 1 file changed, 159 insertions(+), 28 deletions(-) diff --git a/bec_server/bec_server/shared_memory/ring_buffer.py b/bec_server/bec_server/shared_memory/ring_buffer.py index 736b47c00..a967104f1 100644 --- a/bec_server/bec_server/shared_memory/ring_buffer.py +++ b/bec_server/bec_server/shared_memory/ring_buffer.py @@ -1,51 +1,94 @@ from __future__ import annotations -from dataclasses import dataclass +import sys +from contextlib import contextmanager +from enum import IntEnum from multiprocessing import shared_memory -from typing import Any, Literal, Tuple +from typing import Iterator, Literal, Tuple +import numpy as np from pydantic import BaseModel -@dataclass(frozen=True) -class DTypeDescriptor: +class SlotState(IntEnum): + """State of the data at memory slot.""" + + READY_TO_WRITE = 0 + WRITING = 1 + READY_TO_READ = 2 + READING = 3 + + +class DTypeDescriptor(BaseModel): kind: Literal["uint", "int", "float", "bool"] itemsize: int byte_order: Literal["little", "big"] = "little" - -@dataclass(frozen=True) -class PayloadDescriptor: + @classmethod + def from_numpy(cls, dtype: np.dtype) -> "DTypeDescriptor": + dtype = np.dtype(dtype) + kind_map = {"u": "uint", "i": "int", "f": "float", "b": "bool"} + if dtype.kind not in kind_map: + raise ValueError(f"Unsupported dtype kind: {dtype.kind!r}") + + byte_order = dtype.byteorder + if byte_order in ("=", "|"): + byte_order = sys.byteorder + elif byte_order == "<": + byte_order = "little" + elif byte_order == ">": + byte_order = "big" + else: + raise ValueError(f"Unsupported byte order: {dtype.byteorder!r}") + + return cls(kind=kind_map[dtype.kind], itemsize=dtype.itemsize, byte_order=byte_order) + + +class PayloadDescriptor(BaseModel): nbytes: int shape: Tuple[int, ...] dtype: DTypeDescriptor layout: Literal["C"] = "C" + @classmethod + def from_numpy(cls, array: np.ndarray) -> "PayloadDescriptor": + + return cls( + nbytes=array.nbytes, + shape=array.shape, + dtype=DTypeDescriptor.from_numpy(array.dtype), + layout="C" if array.flags.c_contiguous else "C", + ) + class SharedRingBufferDescriptor(BaseModel): """Descriptor for SharedRingBuffer object.""" name: str - max_index: int - bytes_per_index: int + slots: int + bytes_per_slot: int + slot_state_bytes: int payload: PayloadDescriptor - # owner: Literal["device" , "client"] # To be checked if needed class SharedRingBuffer: """Descriptor for RingBuffer Object to share memory across processes.""" + SLOT_STATE_BYTES = 1 + def __init__( self, shm: shared_memory.SharedMemory, - max_index: int, - bytes_per_index: int, + payload: PayloadDescriptor, + slots: int, + bytes_per_slot: int, owns_memory: bool = False, ): self._shm = shm - self._max_index = max_index - self._bytes_per_index = bytes_per_index + self._slots = slots + self._bytes_per_slot = bytes_per_slot self._owns_memory = owns_memory + self._payload = payload @property def name(self): @@ -53,33 +96,119 @@ def name(self): return self._shm.name @property - def max_index(self): + def slots(self): """Max Index of shared ring buffer""" - return self._max_index + return self._slots @property - def bytes_per_index(self): + def bytes_per_slot(self): """Bytes per index in shared ring buffer""" - return self._bytes_per_index + return self._bytes_per_slot + + @property + def payload(self): + return self._payload @classmethod - def create(cls, max_index: int, bytes_per_index: int) -> SharedRingBuffer: + def create(cls, slots: int, payload: PayloadDescriptor | dict) -> SharedRingBuffer: """Create a new shared memory location and SharedRingBuffer object.""" - total_size = max_index * bytes_per_index + if isinstance(payload, dict): + payload = PayloadDescriptor.model_validate(payload) + bytes_per_slot = payload.nbytes + cls.SLOT_STATE_BYTES + total_size = slots * (bytes_per_slot) shm = shared_memory.SharedMemory(create=True, size=total_size) - return cls(shm, max_index=max_index, bytes_per_index=bytes_per_index, owns_memory=True) + ring_buffer = cls( + shm, slots=slots, bytes_per_slot=bytes_per_slot, payload=payload, owns_memory=True + ) + for slot in range(slots): + ring_buffer.set_state(slot, SlotState.READY_TO_WRITE.value) + return ring_buffer @classmethod def attach(cls, descriptor: SharedRingBufferDescriptor) -> SharedRingBuffer: """Create SharedRingBuffer by attaching to an existing shared memory object by descriptor name.""" shm = shared_memory.SharedMemory(name=descriptor.name) - return cls(shm, max_index=descriptor.max_index, bytes_per_index=descriptor.bytes_per_index) - - def data(self, index: int) -> Any: - """Get data from SharedRingBuffer from index.""" - start = index * self.bytes_per_index - stop = start + self.bytes_per_index - return self._shm[start:stop] + return cls( + shm, + slots=descriptor.slots, + bytes_per_slot=descriptor.bytes_per_slot, + payload=descriptor.payload, + ) + + def descriptor(self) -> SharedRingBufferDescriptor: + """Create a serializable descriptor for this ring buffer.""" + + return SharedRingBufferDescriptor( + name=self.name, + slots=self.slots, + bytes_per_slot=self.bytes_per_slot, + slot_state_bytes=self.SLOT_STATE_BYTES, + payload=self.payload, + ) + + def _data(self, index: int) -> memoryview: + """Return payload memory for one slot.""" + start = self._payload_offset(index) + stop = start + self.payload.nbytes + return self._shm.buf[start:stop] + + def _slot_offset(self, index: int) -> int: + self._validate_index(index) + return index * self.bytes_per_slot + + def _slot_state_offset(self, index: int) -> int: + return self._slot_offset(index) + + def _payload_offset(self, index: int) -> int: + return self._slot_offset(index) + self.SLOT_STATE_BYTES + + def _validate_index(self, index: int) -> None: + if not 0 <= index < self.slots: + raise IndexError(f"Index {index} outside valid range 0..{self.slots - 1}.") + + def state(self, index: int) -> SlotState: + """Return the current slot state.""" + offset = self._slot_state_offset(index) + return SlotState(self._shm.buf[offset]) + + def set_state(self, index: int, state: SlotState) -> None: + """Set the current slot state.""" + offset = self._slot_state_offset(index) + self._shm.buf[offset] = int(state) + + @contextmanager + def read_slot(self, index: int, force: bool = False) -> Iterator[memoryview]: + """Read from a slot and mark it writable afterwards.""" + if force: + valid_read_states = [SlotState.READY_TO_READ.value, SlotState.READY_TO_WRITE.value] + else: + valid_read_states = [SlotState.READY_TO_WRITE.value] + while not self.state(index) in valid_read_states: + ... + self.set_state(index, SlotState.READING) + try: + yield self._data(index) + finally: + self.set_state(index, SlotState.READY_TO_WRITE) + + @contextmanager + def write_slot(self, index: int, force: bool = False) -> Iterator[memoryview]: + """Write to a slot and mark it readable afterwards.""" + if force: + valid_write_states = [SlotState.READY_TO_READ.value, SlotState.READY_TO_WRITE.value] + else: + valid_write_states = [SlotState.READY_TO_WRITE.value] + + while not self.state(index) in valid_write_states: + ... + self.set_state(index, SlotState.WRITING) + try: + yield self._data(index) + except Exception as exc: + self.set_state(index, SlotState.READY_TO_WRITE) + raise exc + else: + self.set_state(index, SlotState.READY_TO_READ) def close(self): """Close the shared memory object.""" @@ -89,3 +218,5 @@ def unlink(self): if not self._owns_memory: raise RuntimeError(f"Can't unlike memory {self.name} that is not owned by this process") self._shm.unlink() + + # TODO shutdown procedure for proper clean up of resources.. From 6e6a1c46232d8e2d152a26bc936ff819a3c9d08d Mon Sep 17 00:00:00 2001 From: appel_c Date: Wed, 13 May 2026 22:53:15 +0200 Subject: [PATCH 03/16] wip --- .../bec_server/shared_memory/ring_buffer.py | 89 +++++++++---------- 1 file changed, 43 insertions(+), 46 deletions(-) diff --git a/bec_server/bec_server/shared_memory/ring_buffer.py b/bec_server/bec_server/shared_memory/ring_buffer.py index a967104f1..507dd295f 100644 --- a/bec_server/bec_server/shared_memory/ring_buffer.py +++ b/bec_server/bec_server/shared_memory/ring_buffer.py @@ -7,6 +7,7 @@ from typing import Iterator, Literal, Tuple import numpy as np +import posix_ipc from pydantic import BaseModel @@ -26,6 +27,7 @@ class DTypeDescriptor(BaseModel): @classmethod def from_numpy(cls, dtype: np.dtype) -> "DTypeDescriptor": + """Class method to create DTypeDescriptor from numpy dtype.""" dtype = np.dtype(dtype) kind_map = {"u": "uint", "i": "int", "f": "float", "b": "bool"} if dtype.kind not in kind_map: @@ -43,8 +45,18 @@ def from_numpy(cls, dtype: np.dtype) -> "DTypeDescriptor": return cls(kind=kind_map[dtype.kind], itemsize=dtype.itemsize, byte_order=byte_order) + @property + def numpy_dtype(self) -> np.dtype: + """Return the corresponding numpy dtype for this DTypeDescriptor.""" + byte_order_char = {"little": "<", "big": ">"}[self.byte_order] + kind_char = {"uint": "u", "int": "i", "float": "f", "bool": "b"}[self.kind] + dtype_str = f"{byte_order_char}{kind_char}{self.itemsize}" + return np.dtype(dtype_str) + class PayloadDescriptor(BaseModel): + """Descriptor for the data payload stored in each slot of the ring buffer.""" + nbytes: int shape: Tuple[int, ...] dtype: DTypeDescriptor @@ -52,7 +64,7 @@ class PayloadDescriptor(BaseModel): @classmethod def from_numpy(cls, array: np.ndarray) -> "PayloadDescriptor": - + """Class method to create PayloadDescriptor from a numpy array.""" return cls( nbytes=array.nbytes, shape=array.shape, @@ -65,17 +77,15 @@ class SharedRingBufferDescriptor(BaseModel): """Descriptor for SharedRingBuffer object.""" name: str + lock_id: str slots: int bytes_per_slot: int - slot_state_bytes: int payload: PayloadDescriptor class SharedRingBuffer: """Descriptor for RingBuffer Object to share memory across processes.""" - SLOT_STATE_BYTES = 1 - def __init__( self, shm: shared_memory.SharedMemory, @@ -83,12 +93,18 @@ def __init__( slots: int, bytes_per_slot: int, owns_memory: bool = False, + lock_id: str | None = None, ): self._shm = shm self._slots = slots self._bytes_per_slot = bytes_per_slot self._owns_memory = owns_memory self._payload = payload + self._semaphore_lock = ( + posix_ipc.Semaphore(shm.name + "_lock", flags=posix_ipc.O_CREAT, initial_value=1) + if lock_id is None + else posix_ipc.Semaphore(lock_id, flags=0) + ) @property def name(self): @@ -114,14 +130,12 @@ def create(cls, slots: int, payload: PayloadDescriptor | dict) -> SharedRingBuff """Create a new shared memory location and SharedRingBuffer object.""" if isinstance(payload, dict): payload = PayloadDescriptor.model_validate(payload) - bytes_per_slot = payload.nbytes + cls.SLOT_STATE_BYTES + bytes_per_slot = payload.nbytes total_size = slots * (bytes_per_slot) shm = shared_memory.SharedMemory(create=True, size=total_size) ring_buffer = cls( shm, slots=slots, bytes_per_slot=bytes_per_slot, payload=payload, owns_memory=True ) - for slot in range(slots): - ring_buffer.set_state(slot, SlotState.READY_TO_WRITE.value) return ring_buffer @classmethod @@ -133,6 +147,7 @@ def attach(cls, descriptor: SharedRingBufferDescriptor) -> SharedRingBuffer: slots=descriptor.slots, bytes_per_slot=descriptor.bytes_per_slot, payload=descriptor.payload, + lock_id=descriptor.lock_id, ) def descriptor(self) -> SharedRingBufferDescriptor: @@ -142,13 +157,13 @@ def descriptor(self) -> SharedRingBufferDescriptor: name=self.name, slots=self.slots, bytes_per_slot=self.bytes_per_slot, - slot_state_bytes=self.SLOT_STATE_BYTES, payload=self.payload, + lock_id=self._semaphore_lock.name, ) def _data(self, index: int) -> memoryview: """Return payload memory for one slot.""" - start = self._payload_offset(index) + start = self._slot_offset(index) stop = start + self.payload.nbytes return self._shm.buf[start:stop] @@ -156,59 +171,32 @@ def _slot_offset(self, index: int) -> int: self._validate_index(index) return index * self.bytes_per_slot - def _slot_state_offset(self, index: int) -> int: - return self._slot_offset(index) - - def _payload_offset(self, index: int) -> int: - return self._slot_offset(index) + self.SLOT_STATE_BYTES - def _validate_index(self, index: int) -> None: if not 0 <= index < self.slots: raise IndexError(f"Index {index} outside valid range 0..{self.slots - 1}.") def state(self, index: int) -> SlotState: """Return the current slot state.""" - offset = self._slot_state_offset(index) + offset = self._slot_offset(index) return SlotState(self._shm.buf[offset]) - def set_state(self, index: int, state: SlotState) -> None: - """Set the current slot state.""" - offset = self._slot_state_offset(index) - self._shm.buf[offset] = int(state) - @contextmanager - def read_slot(self, index: int, force: bool = False) -> Iterator[memoryview]: + def read_slot(self, index: int, timeout_lock: float = 0) -> Iterator[memoryview]: """Read from a slot and mark it writable afterwards.""" - if force: - valid_read_states = [SlotState.READY_TO_READ.value, SlotState.READY_TO_WRITE.value] - else: - valid_read_states = [SlotState.READY_TO_WRITE.value] - while not self.state(index) in valid_read_states: - ... - self.set_state(index, SlotState.READING) try: + self._semaphore_lock.acquire(timeout=timeout_lock) yield self._data(index) finally: - self.set_state(index, SlotState.READY_TO_WRITE) + self._semaphore_lock.release() @contextmanager - def write_slot(self, index: int, force: bool = False) -> Iterator[memoryview]: + def write_slot(self, index: int, timeout_lock: float = 0) -> Iterator[memoryview]: """Write to a slot and mark it readable afterwards.""" - if force: - valid_write_states = [SlotState.READY_TO_READ.value, SlotState.READY_TO_WRITE.value] - else: - valid_write_states = [SlotState.READY_TO_WRITE.value] - - while not self.state(index) in valid_write_states: - ... - self.set_state(index, SlotState.WRITING) try: + self._semaphore_lock.acquire(timeout=timeout_lock) yield self._data(index) - except Exception as exc: - self.set_state(index, SlotState.READY_TO_WRITE) - raise exc - else: - self.set_state(index, SlotState.READY_TO_READ) + finally: + self._semaphore_lock.release() def close(self): """Close the shared memory object.""" @@ -216,7 +204,16 @@ def close(self): def unlink(self): if not self._owns_memory: - raise RuntimeError(f"Can't unlike memory {self.name} that is not owned by this process") + raise RuntimeError(f"Can't unlink memory {self.name} that is not owned by this process") + self.close() self._shm.unlink() + posix_ipc.unlink_semaphore(self._semaphore_lock.name) + + def shutdown(self): + """Close and unlink the shared memory object if owned.""" + self.close() + if self._owns_memory: + self.unlink() + - # TODO shutdown procedure for proper clean up of resources.. +# TODO to be tested, check if semaphore locking works From fbd3518fb68c84989d6d33a67af4483565f6fedc Mon Sep 17 00:00:00 2001 From: appel_c Date: Wed, 13 May 2026 22:53:26 +0200 Subject: [PATCH 04/16] w --- .../bec_server/shared_memory/manager.py | 25 ++++++++++++++++--- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/bec_server/bec_server/shared_memory/manager.py b/bec_server/bec_server/shared_memory/manager.py index 7c63970f6..49ef6a799 100644 --- a/bec_server/bec_server/shared_memory/manager.py +++ b/bec_server/bec_server/shared_memory/manager.py @@ -1,10 +1,13 @@ -from typing import Literal, TypeVar +from typing import Literal from bec_lib.messages import BECMessage SUPPORTED_DATATYPES = Literal["str", "float", "byte", "np.array", "list", "dict"] +################# +## Messages +################# class SharedMemRequestAllocation(BECMessage): """Message to send to the shared memory manager to create a new shared memory object.""" @@ -16,6 +19,7 @@ class SharedMemDescriptor(BECMessage): """Message with metadata about the shared memory created in the shared memory manager.""" id: str + lock_id: str max_index: int owner: Literal["device", "client"] device: str | None = None @@ -29,18 +33,30 @@ class AvailableDataAnalysisMethods(BECMessage): methods: list[str] +# TODO maybe not needed to warm up, could automatically start a DAP worker once a shared memory object is created, +# Then DataAnalysisRegisterRequest is designed to register analysis methods for the shared memory object, and +# DataAnalysisTrigger is designed to trigger the analysis of the shared memory object. +# DataAnalysisResponse is designed to send the results back to the client. class DataAnalysisRequestWarmup(BECMessage): """Message to request a data analysis""" shared_mem: SharedMemDescriptor -class DataAnalysisRequest(BECMessage): +class DataAnalysisRegisterRequest(BECMessage): """Message to request processing of a shared memory object.""" shared_mem: SharedMemDescriptor - index: int methods: list[str] + client_id: str + device: str | None = None + + +class DataAnalysisTrigger(BECMessage): + """Message to request processing of a shared memory object.""" + + shared_mem: SharedMemDescriptor + index: int class DataAnalysisResponse(BECMessage): @@ -48,8 +64,9 @@ class DataAnalysisResponse(BECMessage): shared_mem: SharedMemDescriptor index: int - methods: list[str] results: dict + client_id: str + device: str | None = None class SharedMemoryManager: From 399c9c7e6b12472f2b9baf0478aa1a7b9c2ffe65 Mon Sep 17 00:00:00 2001 From: appel_c Date: Fri, 15 May 2026 13:44:57 +0200 Subject: [PATCH 05/16] wip add endpoints and messages --- bec_lib/bec_lib/endpoints.py | 48 ++++++++++++++++++++++++++++++++++++ bec_lib/bec_lib/messages.py | 34 +++++++++++++++++++++++++ 2 files changed, 82 insertions(+) diff --git a/bec_lib/bec_lib/endpoints.py b/bec_lib/bec_lib/endpoints.py index 930e3bbb0..58b46c307 100644 --- a/bec_lib/bec_lib/endpoints.py +++ b/bec_lib/bec_lib/endpoints.py @@ -83,6 +83,54 @@ class MessageEndpoints: Class for message endpoints. """ + @staticmethod + def shared_memory_info(client_id: str): + """ + Endpoint for shared memory information. This endpoint is used to publish the shared memory information using + a messages.SharedMemAllocationInfo message. + + Returns: + EndpointInfo: Endpoint for shared memory information. + """ + endpoint = f"{EndpointType.INFO.value}/shared_memory/info/{client_id}" + return EndpointInfo( + endpoint=endpoint, + message_type=messages.SharedMemAllocationInfo, + message_op=MessageOp.SET_PUBLISH, + ) + + @staticmethod + def shared_memory_allocate(): + """ + Endpoint for shared memory allocation. This endpoint is used to request the allocation of a shared memory object using + a messages.SharedMemAllocationRequest message. + + Returns: + EndpointInfo: Endpoint for shared memory allocation. + """ + endpoint = f"{EndpointType.INFO.value}/shared_memory/allocate" + return EndpointInfo( + endpoint=endpoint, + message_type=messages.SharedMemAllocationRequest, + message_op=MessageOp.STREAM, + ) + + @staticmethod + def shared_memory_deallocate(): + """ + Endpoint for shared memory deallocation. This endpoint is used to request the deallocation of a shared memory object using + a messages.SharedMemDeallocationRequest message. + + Returns: + EndpointInfo: Endpoint for shared memory deallocation. + """ + endpoint = f"{EndpointType.INFO.value}/shared_memory/deallocate" + return EndpointInfo( + endpoint=endpoint, + message_type=messages.SharedMemDeallocationRequest, + message_op=MessageOp.STREAM, + ) + # devices feedback @staticmethod def device_status(device: str): diff --git a/bec_lib/bec_lib/messages.py b/bec_lib/bec_lib/messages.py index b94fc6938..36e1a2f19 100644 --- a/bec_lib/bec_lib/messages.py +++ b/bec_lib/bec_lib/messages.py @@ -17,6 +17,9 @@ from bec_lib.metadata_schema import get_metadata_schema_for_scan +# TODO remove bec_server depencency.. +from bec_server.shared_memory.models import PayloadDescriptor, SharedMemInfo + class ProcedureWorkerStatus(Enum): RUNNING = auto() @@ -94,6 +97,37 @@ def __hash__(self) -> int: return self.model_dump_json().__hash__() +class SharedMemAllocationInfo(BECMessage): + """ + This message is published by the shared memory manager and contains a list of all currently allocated shared memory objects. + Once shared memory objects are created or destroyed, this message will publish the updated list of shared memory objects. + """ + + msg_type: ClassVar[str] = "shared_mem_allocation_info" + + info: list[SharedMemInfo] + + +class SharedMemAllocationRequest(BECMessage): + """Message to request information about a shared memory object.""" + + msg_type: ClassVar[str] = "shared_mem_allocation_request" + + client_id: str + slots: int + payload_desc: PayloadDescriptor + signal: str | None = None + + +class SharedMemDeallocationRequest(BECMessage): + """Message to request deallocation of a shared memory object.""" + + msg_type: ClassVar[str] = "shared_mem_deallocation_request" + + client_id: str + shared_mem_info: SharedMemInfo + + class BundleMessage(BECMessage): """Message type to send a bundle of BECMessages. From 38c2a2e89784e863386cca9f4e45fc9831c9ceb7 Mon Sep 17 00:00:00 2001 From: appel_c Date: Fri, 15 May 2026 13:45:14 +0200 Subject: [PATCH 06/16] refactor manager, buffer and client --- .../bec_server/shared_memory/cli/launch.py | 36 ++ bec_server/bec_server/shared_memory/client.py | 120 ++++++ .../bec_server/shared_memory/manager.py | 171 ++++---- bec_server/bec_server/shared_memory/models.py | 124 ++++++ .../bec_server/shared_memory/ring_buffer.py | 384 +++++++++--------- 5 files changed, 573 insertions(+), 262 deletions(-) create mode 100644 bec_server/bec_server/shared_memory/cli/launch.py create mode 100644 bec_server/bec_server/shared_memory/client.py create mode 100644 bec_server/bec_server/shared_memory/models.py diff --git a/bec_server/bec_server/shared_memory/cli/launch.py b/bec_server/bec_server/shared_memory/cli/launch.py new file mode 100644 index 000000000..28bfd5f5a --- /dev/null +++ b/bec_server/bec_server/shared_memory/cli/launch.py @@ -0,0 +1,36 @@ +# Description: Launch the shared memory manager server. +# This script is the entry point for the Shared Memory Manager Server. It is called either +# by the bec-shared-mem-manager entry point or directly from the command line. +import threading + +from bec_lib.bec_service import parse_cmdline_args +from bec_lib.logger import bec_logger +from bec_lib.redis_connector import RedisConnector +from bec_server.shared_memory.manager import SharedMemoryManager + +logger = bec_logger.logger +bec_logger.level = bec_logger.LOGLEVEL.INFO + + +def main(): + """ + Launch the shared memory manager server. + """ + _, _, config = parse_cmdline_args() + + bec_server = SharedMemoryManager(config=config, connector_cls=RedisConnector) + bec_server.start() + + try: + event = threading.Event() + logger.success( + f"Started Shared Memory Manager server (id: {bec_server._service_id}). Press Ctrl+C to stop." + ) + event.wait() + except KeyboardInterrupt: + bec_server.shutdown() + event.set() + + +if __name__ == "__main__": + main() diff --git a/bec_server/bec_server/shared_memory/client.py b/bec_server/bec_server/shared_memory/client.py new file mode 100644 index 000000000..df8b5c285 --- /dev/null +++ b/bec_server/bec_server/shared_memory/client.py @@ -0,0 +1,120 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from bec_lib.endpoints import MessageEndpoints +from bec_lib.messages import SharedMemAllocationInfo +from bec_server.shared_memory.models import PayloadDescriptor +from bec_server.shared_memory.ring_buffer import RingBufferView + +if TYPE_CHECKING: + import numpy as np + + from bec_lib.redis_connector import RedisConnector + + +# TODO one per service, or N per service. +class SharedMemoryClient: + """Client for interacting with shared memory objects managed by the SharedMemoryManager.""" + + def __init__(self, name: str, connector: RedisConnector): + self.name = name + self.connector = connector + self._ring_buffer_views: dict[str, RingBufferView] = {} + self._signal_to_buffer_mapping: dict[str, str] = ( + {} + ) # Mapping from signal names to buffer names + + def start(self): + """Start the client by subscribing to the shared memory object.""" + self.connector.register( + MessageEndpoints.shared_memory_info(self.name), cb=self._handle_info_update + ) + + def _handle_info_update(self, info: SharedMemAllocationInfo) -> None: + """Handle updates to the shared memory information.""" + if isinstance(info, dict): + info = SharedMemAllocationInfo.model_validate(info) + # Any info update can potentially contain relevant information for creating or deleting ring buffer views. + info_updates = [] + for buff_info in info.info: + info_updates.append(buff_info.buffer_desc.name) + if buff_info.buffer_desc.name not in self._ring_buffer_views: + self._ring_buffer_views[buff_info.buffer_desc.name] = RingBufferView( + descriptor=buff_info.buffer_desc + ) + self._signal_to_buffer_mapping[buff_info.buffer_desc.signal_name] = ( + buff_info.buffer_desc.name + ) + if len(info.info) < len(self._ring_buffer_views): + # Some shared memory objects have been deallocated. Remove them from the local dictionary. + to_be_removed = set(self._ring_buffer_views.keys()) - set(info_updates) + for name in to_be_removed: + view = self._ring_buffer_views.pop(name) + view.destroy() + self._signal_to_buffer_mapping.pop(view.descriptor.signal_name, None) + + def request_allocation( + self, signal_name: str, slots: int, payload_desc: PayloadDescriptor | dict + ) -> None: + """Request the allocation of a shared memory object.""" + if isinstance(payload_desc, dict): + payload_desc = PayloadDescriptor.model_validate(payload_desc) + + self.connector.xadd( + MessageEndpoints.shared_memory_allocate(), + { + "client_id": self.name, + "slots": slots, + "payload_desc": payload_desc, + "signal": signal_name, + }, + max_size=1000, # Keep history of 1000 allocation requests + ) + + def request_deallocation(self, signal_name: str) -> None: + """Request the deallocation of a shared memory object.""" + self.connector.xadd( + MessageEndpoints.shared_memory_deallocate(), + {"client_id": self.name, "signal": signal_name}, + max_size=1000, # Keep history of 1000 deallocation requests + ) + + def read_from_buffer( + self, signal_name: str, index: int, timeout: float | None = None + ) -> np.ndarray: + """ + Read data from the shared memory buffer associated with the given signal name. + If timeout is provided, the method will wait for the specified time and raise a TimeoutError if it cannot + read the data within that time frame. Please be aware, this is meant to block during write/read operations. + """ + if signal_name not in self._signal_to_buffer_mapping: + raise ValueError(f"No buffer found for signal name: {signal_name}") + buffer_name = self._signal_to_buffer_mapping[signal_name] + if buffer_name not in self._ring_buffer_views: + raise ValueError(f"No ring buffer view found for buffer name: {buffer_name}") + return self._ring_buffer_views[buffer_name].copy_data(index, timeout) + + def write_to_buffer( + self, signal_name: str, index: int, data: np.ndarray, timeout: float | None = None + ) -> None: + """ + Write data to the shared memory buffer associated with the given signal name. + If timeout is provided, the method will wait for the specified time and raise a TimeoutError if it cannot + write the data within that time frame. Please be aware, this is meant to block during write/read operations. + """ + if signal_name not in self._signal_to_buffer_mapping: + raise ValueError(f"No buffer found for signal name: {signal_name}") + buffer_name = self._signal_to_buffer_mapping[signal_name] + if buffer_name not in self._ring_buffer_views: + raise ValueError(f"No ring buffer view found for buffer name: {buffer_name}") + self._ring_buffer_views[buffer_name].write_data( + index=index, data=data, acquire_timeout=timeout + ) + + def shutdown(self) -> None: + """Clean up resources and all shared memory views.""" + for view in self._ring_buffer_views.values(): + view.destroy() + self._ring_buffer_views.clear() + self._signal_to_buffer_mapping.clear() diff --git a/bec_server/bec_server/shared_memory/manager.py b/bec_server/bec_server/shared_memory/manager.py index 49ef6a799..7994a913b 100644 --- a/bec_server/bec_server/shared_memory/manager.py +++ b/bec_server/bec_server/shared_memory/manager.py @@ -1,81 +1,98 @@ -from typing import Literal +from __future__ import annotations -from bec_lib.messages import BECMessage +import threading +from typing import TYPE_CHECKING, Literal -SUPPORTED_DATATYPES = Literal["str", "float", "byte", "np.array", "list", "dict"] - - -################# -## Messages -################# -class SharedMemRequestAllocation(BECMessage): - """Message to send to the shared memory manager to create a new shared memory object.""" - - sender: Literal["device", "client"] - device: str | None = None - - -class SharedMemDescriptor(BECMessage): - """Message with metadata about the shared memory created in the shared memory manager.""" - - id: str - lock_id: str - max_index: int - owner: Literal["device", "client"] - device: str | None = None - shape: tuple[int, ...] - dtype: SUPPORTED_DATATYPES - - -class AvailableDataAnalysisMethods(BECMessage): - """Message published by the DAP server on which analysis methods are available.""" - - methods: list[str] - - -# TODO maybe not needed to warm up, could automatically start a DAP worker once a shared memory object is created, -# Then DataAnalysisRegisterRequest is designed to register analysis methods for the shared memory object, and -# DataAnalysisTrigger is designed to trigger the analysis of the shared memory object. -# DataAnalysisResponse is designed to send the results back to the client. -class DataAnalysisRequestWarmup(BECMessage): - """Message to request a data analysis""" - - shared_mem: SharedMemDescriptor +from bec_lib import messages +from bec_lib.bec_service import BECService +from bec_lib.endpoints import MessageEndpoints +from bec_lib.logger import bec_logger +from bec_server.shared_memory.models import SharedMemInfo +from bec_server.shared_memory.ring_buffer import RingBuffer +SUPPORTED_DATATYPES = Literal["str", "float", "byte", "np.array", "list", "dict"] -class DataAnalysisRegisterRequest(BECMessage): - """Message to request processing of a shared memory object.""" - - shared_mem: SharedMemDescriptor - methods: list[str] - client_id: str - device: str | None = None - - -class DataAnalysisTrigger(BECMessage): - """Message to request processing of a shared memory object.""" - - shared_mem: SharedMemDescriptor - index: int - - -class DataAnalysisResponse(BECMessage): - """Message to request processing of a shared memory object.""" - - shared_mem: SharedMemDescriptor - index: int - results: dict - client_id: str - device: str | None = None - - -class SharedMemoryManager: - - def shutdown(self): - """Shutdown method, should clean up all shared memory objects.""" - - def create_shared_mem(self, msg: SharedMemRequestAllocation) -> str: - """Creates a shared memory object under a unique name.""" - - def _publish_shared_mem_info(self, msg: SharedMemDescriptor): - """Publish information about a shared memory object.""" +if TYPE_CHECKING: + from bec_lib.redis_connector import MessageObject, RedisConnector + +logger = bec_logger.logger + + +class SharedMemoryManager(BECService): + """ + Service to manage shared memory objects. It keeps track of all allocated shared memory objects and their descriptors. + It also handles the creation and destruction of shared memory objects, and publishes the updated list of shared memory objects + whenever a new shared memory object is created or destroyed. + """ + + def __init__(self, config, connector_cls: type[RedisConnector]) -> None: + super().__init__(config, connector_cls, unique_service=True) + self._shared_memory_objects: dict[str, RingBuffer] = {} + self.lock = threading.RLock() + + def _allocate_memory(self, request: messages.SharedMemAllocationRequest) -> None: + """Callback function to handle shared memory allocation requests.""" + if isinstance(request, dict): + request = messages.SharedMemAllocationRequest.model_validate(request) + if request.client_id in self._shared_memory_objects: + logger.error( + f"Shared memory object for client {request.client_id} already exists. Overwriting." + ) + return + + buff = RingBuffer( + slots=request.slots, payload=request.payload_desc, name_suffix=request.signal + ) + with self.lock: + self._shared_memory_objects[request.client_id] = buff + self._publish_allocation_info(client_id=request.client_id) + + def _deallocate_memory(self, request: messages.SharedMemDeallocationRequest) -> None: + """Callback function to handle shared memory deallocation requests.""" + if isinstance(request, dict): + request = messages.SharedMemDeallocationRequest.model_validate(request) + if request.client_id not in self._shared_memory_objects: + logger.error( + f"Shared memory object for client {request.client_id} does not exist. Cannot deallocate." + ) + return + + with self.lock: + buff = self._shared_memory_objects.pop(request.client_id) + buff.destroy() + self._publish_allocation_info(client_id=request.client_id) + + def _publish_allocation_info(self, client_id: str = "*") -> None: + """Publish the updated list of allocated shared memory objects.""" + with self.lock: + info = [ + SharedMemInfo(client_id=client_id, buffer_desc=buff.descriptor) + for client_id, buff in self._shared_memory_objects.items() + ] + # Maybe use regex here.. + if client_id != "*": + info = [buff_info for buff_info in info if buff_info.client_id == client_id] + self.connector.set_and_publish( + MessageEndpoints.shared_memory_info(client_id), + messages.SharedMemAllocationInfo(info=info), + ) + + def start(self) -> None: + """start the shared memory manager server""" + self.connector.register(MessageEndpoints.shared_memory_allocate(), cb=self._allocate_memory) + self.connector.register( + MessageEndpoints.shared_memory_deallocate(), cb=self._deallocate_memory + ) + + def stop(self) -> None: + with self.lock: + for buff in self._shared_memory_objects.values(): + buff.destroy() + self._shared_memory_objects.clear() + self._publish_allocation_info() + # Cleanup bec service related resources + + def shutdown(self) -> None: + """Shutdown the shared memory manager server and destroy all shared memory objects.""" + self.stop() + super().shutdown() diff --git a/bec_server/bec_server/shared_memory/models.py b/bec_server/bec_server/shared_memory/models.py new file mode 100644 index 000000000..fe403a0c6 --- /dev/null +++ b/bec_server/bec_server/shared_memory/models.py @@ -0,0 +1,124 @@ +from __future__ import annotations + +import sys +from typing import Literal, Tuple + +import numpy as np +from pydantic import BaseModel, ConfigDict + + +class SharedMemInfo(BaseModel): + """ + Store information about the shared memory object. This message has the client_id, the buffer descriptor and + the potentially a list of devices for which this shared memory object is relevant. + """ + + model_config = ConfigDict(validate_assignment=True) + client_id: str + buffer_desc: RingBufferDescriptor + signal: str | None = None # dotted signal name, e.g. "eiger.preview" + + +class DTypeDescriptor(BaseModel): + kind: Literal["uint", "int", "float", "bool"] + itemsize: int + byte_order: Literal["little", "big"] = "little" + + @classmethod + def from_numpy(cls, dtype: np.dtype) -> DTypeDescriptor: + """Class method to create DTypeDescriptor from numpy dtype.""" + dtype = np.dtype(dtype) + kind_map = {"u": "uint", "i": "int", "f": "float", "b": "bool"} + if dtype.kind not in kind_map: + raise ValueError(f"Unsupported dtype kind: {dtype.kind!r}") + + byte_order = dtype.byteorder + if byte_order in ("=", "|"): + byte_order = sys.byteorder + elif byte_order == "<": + byte_order = "little" + elif byte_order == ">": + byte_order = "big" + else: + raise ValueError(f"Unsupported byte order: {dtype.byteorder!r}") + + return cls(kind=kind_map[dtype.kind], itemsize=dtype.itemsize, byte_order=byte_order) + + @property + def numpy_dtype(self) -> np.dtype: + """Return the corresponding numpy dtype for this DTypeDescriptor.""" + byte_order_char = {"little": "<", "big": ">"}[self.byte_order] + kind_char = {"uint": "u", "int": "i", "float": "f", "bool": "b"}[self.kind] + dtype_str = f"{byte_order_char}{kind_char}{self.itemsize}" + return np.dtype(dtype_str) + + +class PayloadDescriptor(BaseModel): + """Descriptor for the data payload stored in each slot of the ring buffer.""" + + nbytes: int + shape: Tuple[int, ...] + dtype: DTypeDescriptor + layout: Literal["C"] = "C" + + @classmethod + def from_numpy(cls, array: np.ndarray) -> PayloadDescriptor: + """Class method to create PayloadDescriptor from a numpy array.""" + return cls( + nbytes=array.nbytes, + shape=array.shape, + dtype=DTypeDescriptor.from_numpy(array.dtype), + layout="C" if array.flags.c_contiguous else "C", + ) + + +class RingBufferDescriptor(BaseModel): + """Descriptor for SharedRingBuffer object.""" + + name: str + lock_id: str + slots: int + bytes_per_slot: int + payload: PayloadDescriptor + + +# class AvailableDataAnalysisMethods(messages.BECMessage): +# """Message published by the DAP server on which analysis methods are available.""" + +# methods: list[str] + + +# TODO maybe not needed to warm up, could automatically start a DAP worker once a shared memory object is created, +# Then DataAnalysisRegisterRequest is designed to register analysis methods for the shared memory object, and +# DataAnalysisTrigger is designed to trigger the analysis of the shared memory object. +# DataAnalysisResponse is designed to send the results back to the client. +# class DataAnalysisRequestWarmup(BECMessage): +# """Message to request a data analysis""" + +# shared_mem: SharedMemDescriptor + + +# class DataAnalysisRegisterRequest(BECMessage): +# """Message to request processing of a shared memory object.""" + +# shared_mem: SharedMemDescriptor +# methods: list[str] +# client_id: str +# device: str | None = None + + +# class DataAnalysisTrigger(BECMessage): +# """Message to request processing of a shared memory object.""" + +# shared_mem: SharedMemDescriptor +# index: int + + +# class DataAnalysisResponse(BECMessage): +# """Message to request processing of a shared memory object.""" + +# shared_mem: SharedMemDescriptor +# index: int +# results: dict +# client_id: str +# device: str | None = None diff --git a/bec_server/bec_server/shared_memory/ring_buffer.py b/bec_server/bec_server/shared_memory/ring_buffer.py index 507dd295f..81911a888 100644 --- a/bec_server/bec_server/shared_memory/ring_buffer.py +++ b/bec_server/bec_server/shared_memory/ring_buffer.py @@ -1,219 +1,233 @@ from __future__ import annotations -import sys -from contextlib import contextmanager -from enum import IntEnum -from multiprocessing import shared_memory -from typing import Iterator, Literal, Tuple +from functools import wraps +from multiprocessing import resource_tracker, shared_memory +from threading import RLock +from typing import Any, Callable +from uuid import uuid4 import numpy as np import posix_ipc -from pydantic import BaseModel +from bec_server.shared_memory.models import PayloadDescriptor, RingBufferDescriptor -class SlotState(IntEnum): - """State of the data at memory slot.""" +# pylint: disable=c-extension-no-member - READY_TO_WRITE = 0 - WRITING = 1 - READY_TO_READ = 2 - READING = 3 +def not_destroyed(method: Callable[..., Any]) -> Callable[..., Any]: + """Decorator to check if the RingBufferView has been destroyed before allowing method execution.""" -class DTypeDescriptor(BaseModel): - kind: Literal["uint", "int", "float", "bool"] - itemsize: int - byte_order: Literal["little", "big"] = "little" + @wraps(method) + def wrapper(self: RingBufferView, *args: Any, **kwargs: Any) -> Any: + if self.destroyed: + raise RuntimeError( + f"Cannot perform operation on a destroyed {self.__class__.__name__} object with name {self.name!r}." + ) + return method(self, *args, **kwargs) - @classmethod - def from_numpy(cls, dtype: np.dtype) -> "DTypeDescriptor": - """Class method to create DTypeDescriptor from numpy dtype.""" - dtype = np.dtype(dtype) - kind_map = {"u": "uint", "i": "int", "f": "float", "b": "bool"} - if dtype.kind not in kind_map: - raise ValueError(f"Unsupported dtype kind: {dtype.kind!r}") - - byte_order = dtype.byteorder - if byte_order in ("=", "|"): - byte_order = sys.byteorder - elif byte_order == "<": - byte_order = "little" - elif byte_order == ">": - byte_order = "big" - else: - raise ValueError(f"Unsupported byte order: {dtype.byteorder!r}") - - return cls(kind=kind_map[dtype.kind], itemsize=dtype.itemsize, byte_order=byte_order) - - @property - def numpy_dtype(self) -> np.dtype: - """Return the corresponding numpy dtype for this DTypeDescriptor.""" - byte_order_char = {"little": "<", "big": ">"}[self.byte_order] - kind_char = {"uint": "u", "int": "i", "float": "f", "bool": "b"}[self.kind] - dtype_str = f"{byte_order_char}{kind_char}{self.itemsize}" - return np.dtype(dtype_str) - - -class PayloadDescriptor(BaseModel): - """Descriptor for the data payload stored in each slot of the ring buffer.""" - - nbytes: int - shape: Tuple[int, ...] - dtype: DTypeDescriptor - layout: Literal["C"] = "C" - - @classmethod - def from_numpy(cls, array: np.ndarray) -> "PayloadDescriptor": - """Class method to create PayloadDescriptor from a numpy array.""" - return cls( - nbytes=array.nbytes, - shape=array.shape, - dtype=DTypeDescriptor.from_numpy(array.dtype), - layout="C" if array.flags.c_contiguous else "C", - ) - - -class SharedRingBufferDescriptor(BaseModel): - """Descriptor for SharedRingBuffer object.""" - - name: str - lock_id: str - slots: int - bytes_per_slot: int - payload: PayloadDescriptor + return wrapper -class SharedRingBuffer: - """Descriptor for RingBuffer Object to share memory across processes.""" +class RingBufferView: + """ + Class for handling shared RingBuffer objects from clients, which attach to an existing shared memory object + defined by a RingBufferDescriptor. The view can be used to read/write to the buffer without owning the memory. + It can not be used to create new shared memory objects, which is reserved for the RingBuffer class. + """ def __init__( - self, - shm: shared_memory.SharedMemory, - payload: PayloadDescriptor, - slots: int, - bytes_per_slot: int, - owns_memory: bool = False, - lock_id: str | None = None, + self, descriptor: RingBufferDescriptor, shm: shared_memory.SharedMemory | None = None ): - self._shm = shm - self._slots = slots - self._bytes_per_slot = bytes_per_slot - self._owns_memory = owns_memory - self._payload = payload - self._semaphore_lock = ( - posix_ipc.Semaphore(shm.name + "_lock", flags=posix_ipc.O_CREAT, initial_value=1) - if lock_id is None - else posix_ipc.Semaphore(lock_id, flags=0) - ) + self._descriptor = descriptor + self._shm = shm if shm is not None else shared_memory.SharedMemory(name=descriptor.name) + self._owns_memory = shm is None + self._semaphore_lock = posix_ipc.Semaphore(descriptor.lock_id, flags=0) + self.__destroyed = False + self._lock = RLock() + # # TODO: Check why this might be needed, but to be sure to lock is accidently kept. + # self._semaphore_lock.release() + + ############ + # API + ############ + + @not_destroyed + def copy_data(self, index: int, acquire_timeout: float = 0) -> np.ndarray: + """ + Returns a copy of the data at the given slot index as a numpy array. While the data is being copied, + the shared memory is locked to prevent concurrent modifications. Once copied, the shared memory is released. + NOTE: The additional argument acquire_timeout can be used to specify a timeout for acquiring the lock. The + default value of 0 means that it will wait indefinitely until the lock is acquired. If the lock cannot + be acquired within the specified timeout, a TimeoutError will be raised. Please NOTE that this feature + requires the underlying OS to support timeouts for posix semaphores, which is for example not the case for MAC OS. + + Args: + index (int): The slot index to copy data from. + acquire_timeout (float): The timeout in seconds to acquire the lock. If 0, it will wait indefinitely. + + Returns: + np.ndarray: A copy of the data at the specified slot index. + Raises: + TimeoutError: If the lock cannot be acquired within the specified timeout. + """ + with self._lock: + try: + self._semaphore_lock.acquire(timeout=acquire_timeout) + array = np.ndarray( + shape=self.payload_descriptor.shape, + dtype=self.payload_descriptor.dtype.numpy_dtype, + buffer=self._shm.buf, + offset=index * self.bytes_per_slot, + ) + local_copy = array.copy() # Make a local copy of the data + except posix_ipc.BusyError: + # pylint: disable=raise-missing-from + raise TimeoutError( + f"Could not acquire lock for reading from buffer {self.name!r} within {acquire_timeout} seconds." + ) + finally: + self._semaphore_lock.release() + return local_copy + + @not_destroyed + def write_data(self, index: int, data: np.ndarray, acquire_timeout: float = 0) -> None: + """ + Writes the given numpy array data to the specified slot index in the shared memory. While the data is being + written, the shared memory is locked to prevent concurrent modifications. Once the data is written, the shared + memory is released. + NOTE: The additional argument acquire_timeout can be used to specify a timeout for acquiring the lock. The + default value of 0 means that it will wait indefinitely until the lock is acquired. If the lock cannot + be acquired within the specified timeout, a TimeoutError will be raised. Please NOTE that this feature + requires the underlying OS to support timeouts for posix semaphores, which is for example not the case for MAC OS. + + Args: + index (int): The slot index to write data to. + data (np.ndarray): The numpy array data to write to the shared memory. + acquire_timeout (float): The timeout in seconds to acquire the lock. If 0, it will wait indefinitely. + + Raises: + ValueError: If the size of the data does not match the expected size defined by the + payload descriptor. + TimeoutError: If the lock cannot be acquired within the specified timeout. + """ + descriptor = PayloadDescriptor.from_numpy(data) + if descriptor != self.payload_descriptor: + raise ValueError( + f"Data shape/dtype {descriptor.shape}/{descriptor.dtype} does not match expected " + f"shape/dtype {self.payload_descriptor.shape}/{self.payload_descriptor.dtype}" + ) + with self._lock: + try: + self._semaphore_lock.acquire(timeout=acquire_timeout) + array = np.ndarray( + shape=self.payload_descriptor.shape, + dtype=self.payload_descriptor.dtype.numpy_dtype, + buffer=self._shm.buf, + offset=index * self.bytes_per_slot, + ) + np.copyto(array, data) # Copy data into shared memory + except posix_ipc.BusyError: + # pylint: disable=raise-missing-from + raise TimeoutError( + f"Could not acquire lock for reading from buffer {self.name!r} within {acquire_timeout} seconds." + ) + finally: + self._semaphore_lock.release() + + ############ + # Properties + ############ + + @property + def descriptor(self) -> RingBufferDescriptor: + """Return the descriptor for this RingBuffer.""" + return self._descriptor + + @property + def destroyed(self) -> bool: + """Indicates whether the view has been destroyed.""" + return self.__destroyed @property def name(self): """Name of shared ring buffer""" - return self._shm.name + return self._descriptor.name @property def slots(self): """Max Index of shared ring buffer""" - return self._slots + return self._descriptor.slots @property def bytes_per_slot(self): """Bytes per index in shared ring buffer""" - return self._bytes_per_slot + return self._descriptor.bytes_per_slot @property - def payload(self): - return self._payload - - @classmethod - def create(cls, slots: int, payload: PayloadDescriptor | dict) -> SharedRingBuffer: - """Create a new shared memory location and SharedRingBuffer object.""" - if isinstance(payload, dict): - payload = PayloadDescriptor.model_validate(payload) - bytes_per_slot = payload.nbytes - total_size = slots * (bytes_per_slot) - shm = shared_memory.SharedMemory(create=True, size=total_size) - ring_buffer = cls( - shm, slots=slots, bytes_per_slot=bytes_per_slot, payload=payload, owns_memory=True + def payload_descriptor(self): + """Payload descriptor for the data stored in the ring buffer.""" + return self._descriptor.payload + + def destroy(self): + """ + Destroy the shared memory object. The method can be called multiple times but only the first call will have an effect. + """ + if self.destroyed: + return + with self._lock: + # Semaphore lock + self._semaphore_lock.release() # Make sure to release upon closing to avoid deadlocks if the lock is still held by this process + # Shared memory + self._shm.close() + # Cleanup depends on whether the memory is owned by this view or not. + if self._owns_memory: + self._semaphore_lock.unlink() + self._shm.unlink() + else: + # NOTE: From Python 3.13 onwards, we can use the track=False option when creating the reference + # For views not owning the memory, we have to manually unregister it. + # pylint: disable=protected-access + resource_tracker.unregister(self._shm._name, "shared_memory") + self._semaphore_lock.close() + + # to avoid registering the shared memory with the resource tracker. + self.__destroyed = True + + +class RingBuffer(RingBufferView): + """ + RingBuffer class that owns the shared memory. If created, it will create a new sharedMemory object together with a semaphore lock. + + Args: + slots (int): The number of slots in the ring buffer. + payload (PayloadDescriptor): The descriptor for the data payload stored in each slot of the ring buffer. + name_suffix (str): An optional suffix to append to the shared memory and semaphore names for identification. + """ + + def __init__(self, slots: int, payload: PayloadDescriptor, name_suffix: str = ""): + name = f"bec_psm_{uuid4().hex[:6]}" + shm = shared_memory.SharedMemory( + create=True, + size=slots * payload.nbytes, + name=RingBuffer._name_suffix(name, name_suffix), ) - return ring_buffer - - @classmethod - def attach(cls, descriptor: SharedRingBufferDescriptor) -> SharedRingBuffer: - """Create SharedRingBuffer by attaching to an existing shared memory object by descriptor name.""" - shm = shared_memory.SharedMemory(name=descriptor.name) - return cls( - shm, - slots=descriptor.slots, - bytes_per_slot=descriptor.bytes_per_slot, - payload=descriptor.payload, - lock_id=descriptor.lock_id, + lock_name = f"{name}_lock" + semaphore_lock = posix_ipc.Semaphore( + RingBuffer._name_suffix(lock_name, name_suffix), + flags=posix_ipc.O_CREAT, + initial_value=1, ) - - def descriptor(self) -> SharedRingBufferDescriptor: - """Create a serializable descriptor for this ring buffer.""" - - return SharedRingBufferDescriptor( - name=self.name, - slots=self.slots, - bytes_per_slot=self.bytes_per_slot, - payload=self.payload, - lock_id=self._semaphore_lock.name, + self._descriptor = RingBufferDescriptor( + name=shm.name, + lock_id=semaphore_lock.name, + slots=slots, + bytes_per_slot=payload.nbytes, + payload=payload, ) + super().__init__(descriptor=self._descriptor, shm=shm) - def _data(self, index: int) -> memoryview: - """Return payload memory for one slot.""" - start = self._slot_offset(index) - stop = start + self.payload.nbytes - return self._shm.buf[start:stop] - - def _slot_offset(self, index: int) -> int: - self._validate_index(index) - return index * self.bytes_per_slot - - def _validate_index(self, index: int) -> None: - if not 0 <= index < self.slots: - raise IndexError(f"Index {index} outside valid range 0..{self.slots - 1}.") - - def state(self, index: int) -> SlotState: - """Return the current slot state.""" - offset = self._slot_offset(index) - return SlotState(self._shm.buf[offset]) - - @contextmanager - def read_slot(self, index: int, timeout_lock: float = 0) -> Iterator[memoryview]: - """Read from a slot and mark it writable afterwards.""" - try: - self._semaphore_lock.acquire(timeout=timeout_lock) - yield self._data(index) - finally: - self._semaphore_lock.release() - - @contextmanager - def write_slot(self, index: int, timeout_lock: float = 0) -> Iterator[memoryview]: - """Write to a slot and mark it readable afterwards.""" - try: - self._semaphore_lock.acquire(timeout=timeout_lock) - yield self._data(index) - finally: - self._semaphore_lock.release() - - def close(self): - """Close the shared memory object.""" - self._shm.close() - - def unlink(self): - if not self._owns_memory: - raise RuntimeError(f"Can't unlink memory {self.name} that is not owned by this process") - self.close() - self._shm.unlink() - posix_ipc.unlink_semaphore(self._semaphore_lock.name) - - def shutdown(self): - """Close and unlink the shared memory object if owned.""" - self.close() - if self._owns_memory: - self.unlink() - - -# TODO to be tested, check if semaphore locking works + @classmethod + def _name_suffix(cls, name: str, suffix: str, max_length: int = 63) -> str: + if suffix: + name = f"{name}_{suffix}" + return name[:max_length] From 5afdac054c9f8ecb79c56551dd921a7deecf61ff Mon Sep 17 00:00:00 2001 From: appel_c Date: Fri, 15 May 2026 13:45:20 +0200 Subject: [PATCH 07/16] w --- bec_server/bec_server/shared_memory/cli/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 bec_server/bec_server/shared_memory/cli/__init__.py diff --git a/bec_server/bec_server/shared_memory/cli/__init__.py b/bec_server/bec_server/shared_memory/cli/__init__.py new file mode 100644 index 000000000..e69de29bb From 5723deba1ffc213e55a5a2e81d98a03a540d1090 Mon Sep 17 00:00:00 2001 From: appel_c Date: Fri, 15 May 2026 13:45:27 +0200 Subject: [PATCH 08/16] test --- .../tests_shared_memory/test_ring_buffer.py | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 bec_server/tests/tests_shared_memory/test_ring_buffer.py diff --git a/bec_server/tests/tests_shared_memory/test_ring_buffer.py b/bec_server/tests/tests_shared_memory/test_ring_buffer.py new file mode 100644 index 000000000..6e17e3fb0 --- /dev/null +++ b/bec_server/tests/tests_shared_memory/test_ring_buffer.py @@ -0,0 +1,27 @@ +import numpy as np + +from bec_server.shared_memory.ring_buffer import PayloadDescriptor, SharedRingBuffer + + +def test_shutdown_after_slot_context_releases_exported_memoryview(): + payload = PayloadDescriptor.from_numpy(np.zeros((4,), dtype=np.float64)) + ring_buffer = SharedRingBuffer.create(slots=2, payload=payload) + + with ring_buffer.write_slot(0) as view: + array = np.ndarray(payload.shape, dtype=payload.dtype.numpy_dtype, buffer=view) + array[:] = 1 + + ring_buffer.shutdown() + + +def test_create_uses_fresh_shared_memory_and_lock_names(): + payload = PayloadDescriptor.from_numpy(np.zeros((4,), dtype=np.float64)) + first = SharedRingBuffer.create(slots=2, payload=payload) + second = SharedRingBuffer.create(slots=2, payload=payload) + + try: + assert first.name != second.name + assert first.descriptor().lock_id != second.descriptor().lock_id + finally: + first.shutdown() + second.shutdown() From 9ae28e848390b99d2b2f1b77cf2b62ef349adcd1 Mon Sep 17 00:00:00 2001 From: appel_c Date: Fri, 15 May 2026 14:02:50 +0200 Subject: [PATCH 09/16] wip --- bec_server/bec_server/shared_memory/client.py | 21 ++++++++++++++++++ .../bec_server/shared_memory/manager.py | 22 ++++++++++--------- .../bec_server/shared_memory/ring_buffer.py | 12 ++-------- 3 files changed, 35 insertions(+), 20 deletions(-) diff --git a/bec_server/bec_server/shared_memory/client.py b/bec_server/bec_server/shared_memory/client.py index df8b5c285..e70bd00fc 100644 --- a/bec_server/bec_server/shared_memory/client.py +++ b/bec_server/bec_server/shared_memory/client.py @@ -24,6 +24,7 @@ def __init__(self, name: str, connector: RedisConnector): self._signal_to_buffer_mapping: dict[str, str] = ( {} ) # Mapping from signal names to buffer names + self.start() def start(self): """Start the client by subscribing to the shared memory object.""" @@ -118,3 +119,23 @@ def shutdown(self) -> None: view.destroy() self._ring_buffer_views.clear() self._signal_to_buffer_mapping.clear() + self.connector.unregister( + MessageEndpoints.shared_memory_info(self.name), cb=self._handle_info_update + ) + + +if __name__ == "__main__": + import time + + import numpy as np + + from bec_lib.redis_connector import RedisConnector + + array = np.random.rand(5, 5) + connector = RedisConnector(bootstrap="localhost:6379") + client = SharedMemoryClient(name="test_client", connector=connector) + client.request_allocation( + signal_name="test_signal", slots=10, payload_desc=PayloadDescriptor.from_numpy(array) + ) + time.sleep(1) # Wait for the allocation to be processed + print(client._ring_buffer_views) diff --git a/bec_server/bec_server/shared_memory/manager.py b/bec_server/bec_server/shared_memory/manager.py index 7994a913b..0d2d19613 100644 --- a/bec_server/bec_server/shared_memory/manager.py +++ b/bec_server/bec_server/shared_memory/manager.py @@ -1,7 +1,7 @@ from __future__ import annotations import threading -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Literal, Tuple from bec_lib import messages from bec_lib.bec_service import BECService @@ -27,16 +27,18 @@ class SharedMemoryManager(BECService): def __init__(self, config, connector_cls: type[RedisConnector]) -> None: super().__init__(config, connector_cls, unique_service=True) - self._shared_memory_objects: dict[str, RingBuffer] = {} + # Shared memory objects are stored in a dictionary with the client_id and signal name tuple as key + # and the RingBuffer instance as value + self._shared_memory_objects: dict[Tuple[str, str], RingBuffer] = {} self.lock = threading.RLock() def _allocate_memory(self, request: messages.SharedMemAllocationRequest) -> None: """Callback function to handle shared memory allocation requests.""" if isinstance(request, dict): request = messages.SharedMemAllocationRequest.model_validate(request) - if request.client_id in self._shared_memory_objects: + if (request.client_id, request.signal) in self._shared_memory_objects: logger.error( - f"Shared memory object for client {request.client_id} already exists. Overwriting." + f"Shared memory object for client {request.client_id} and signal {request.signal} already exists. Overwriting." ) return @@ -44,21 +46,21 @@ def _allocate_memory(self, request: messages.SharedMemAllocationRequest) -> None slots=request.slots, payload=request.payload_desc, name_suffix=request.signal ) with self.lock: - self._shared_memory_objects[request.client_id] = buff + self._shared_memory_objects[(request.client_id, request.signal)] = buff self._publish_allocation_info(client_id=request.client_id) def _deallocate_memory(self, request: messages.SharedMemDeallocationRequest) -> None: """Callback function to handle shared memory deallocation requests.""" if isinstance(request, dict): request = messages.SharedMemDeallocationRequest.model_validate(request) - if request.client_id not in self._shared_memory_objects: + if (request.client_id, request.signal) not in self._shared_memory_objects: logger.error( - f"Shared memory object for client {request.client_id} does not exist. Cannot deallocate." + f"Shared memory object for client {request.client_id} and signal {request.signal} does not exist. Cannot deallocate." ) return with self.lock: - buff = self._shared_memory_objects.pop(request.client_id) + buff = self._shared_memory_objects.pop((request.client_id, request.signal)) buff.destroy() self._publish_allocation_info(client_id=request.client_id) @@ -66,8 +68,8 @@ def _publish_allocation_info(self, client_id: str = "*") -> None: """Publish the updated list of allocated shared memory objects.""" with self.lock: info = [ - SharedMemInfo(client_id=client_id, buffer_desc=buff.descriptor) - for client_id, buff in self._shared_memory_objects.items() + SharedMemInfo(client_id=client_id, buffer_desc=buff.descriptor, signal=signal_name) + for (client_id, signal_name), buff in self._shared_memory_objects.items() ] # Maybe use regex here.. if client_id != "*": diff --git a/bec_server/bec_server/shared_memory/ring_buffer.py b/bec_server/bec_server/shared_memory/ring_buffer.py index 81911a888..4526a6251 100644 --- a/bec_server/bec_server/shared_memory/ring_buffer.py +++ b/bec_server/bec_server/shared_memory/ring_buffer.py @@ -206,17 +206,9 @@ class RingBuffer(RingBufferView): def __init__(self, slots: int, payload: PayloadDescriptor, name_suffix: str = ""): name = f"bec_psm_{uuid4().hex[:6]}" - shm = shared_memory.SharedMemory( - create=True, - size=slots * payload.nbytes, - name=RingBuffer._name_suffix(name, name_suffix), - ) + shm = shared_memory.SharedMemory(create=True, size=slots * payload.nbytes, name=name) lock_name = f"{name}_lock" - semaphore_lock = posix_ipc.Semaphore( - RingBuffer._name_suffix(lock_name, name_suffix), - flags=posix_ipc.O_CREAT, - initial_value=1, - ) + semaphore_lock = posix_ipc.Semaphore(lock_name, flags=posix_ipc.O_CREAT, initial_value=1) self._descriptor = RingBufferDescriptor( name=shm.name, lock_id=semaphore_lock.name, From 91b32ee0ed05e6d2216415896af3076a0bd8dbb7 Mon Sep 17 00:00:00 2001 From: appel_c Date: Fri, 15 May 2026 14:35:18 +0200 Subject: [PATCH 10/16] wip --- bec_server/bec_server/shared_memory/client.py | 12 +++++------- bec_server/bec_server/shared_memory/manager.py | 4 ++++ bec_server/bec_server/shared_memory/ring_buffer.py | 8 ++++++++ 3 files changed, 17 insertions(+), 7 deletions(-) diff --git a/bec_server/bec_server/shared_memory/client.py b/bec_server/bec_server/shared_memory/client.py index e70bd00fc..caac023b0 100644 --- a/bec_server/bec_server/shared_memory/client.py +++ b/bec_server/bec_server/shared_memory/client.py @@ -3,13 +3,14 @@ from typing import TYPE_CHECKING from bec_lib.endpoints import MessageEndpoints -from bec_lib.messages import SharedMemAllocationInfo from bec_server.shared_memory.models import PayloadDescriptor from bec_server.shared_memory.ring_buffer import RingBufferView if TYPE_CHECKING: import numpy as np + from bec_lib.connector import MessageObject + from bec_lib.messages import SharedMemAllocationInfo from bec_lib.redis_connector import RedisConnector @@ -32,10 +33,9 @@ def start(self): MessageEndpoints.shared_memory_info(self.name), cb=self._handle_info_update ) - def _handle_info_update(self, info: SharedMemAllocationInfo) -> None: + def _handle_info_update(self, info: MessageObject) -> None: """Handle updates to the shared memory information.""" - if isinstance(info, dict): - info = SharedMemAllocationInfo.model_validate(info) + info: SharedMemAllocationInfo = info.value # Any info update can potentially contain relevant information for creating or deleting ring buffer views. info_updates = [] for buff_info in info.info: @@ -44,9 +44,7 @@ def _handle_info_update(self, info: SharedMemAllocationInfo) -> None: self._ring_buffer_views[buff_info.buffer_desc.name] = RingBufferView( descriptor=buff_info.buffer_desc ) - self._signal_to_buffer_mapping[buff_info.buffer_desc.signal_name] = ( - buff_info.buffer_desc.name - ) + self._signal_to_buffer_mapping[buff_info.signal] = buff_info.buffer_desc.name if len(info.info) < len(self._ring_buffer_views): # Some shared memory objects have been deallocated. Remove them from the local dictionary. to_be_removed = set(self._ring_buffer_views.keys()) - set(info_updates) diff --git a/bec_server/bec_server/shared_memory/manager.py b/bec_server/bec_server/shared_memory/manager.py index 0d2d19613..335b9f83d 100644 --- a/bec_server/bec_server/shared_memory/manager.py +++ b/bec_server/bec_server/shared_memory/manager.py @@ -40,6 +40,8 @@ def _allocate_memory(self, request: messages.SharedMemAllocationRequest) -> None logger.error( f"Shared memory object for client {request.client_id} and signal {request.signal} already exists. Overwriting." ) + # TODO should this republish the info? + self._publish_allocation_info(client_id=request.client_id) return buff = RingBuffer( @@ -57,6 +59,8 @@ def _deallocate_memory(self, request: messages.SharedMemDeallocationRequest) -> logger.error( f"Shared memory object for client {request.client_id} and signal {request.signal} does not exist. Cannot deallocate." ) + # TODO should this republish the info? + self._publish_allocation_info(client_id=request.client_id) return with self.lock: diff --git a/bec_server/bec_server/shared_memory/ring_buffer.py b/bec_server/bec_server/shared_memory/ring_buffer.py index 4526a6251..ec7f4c995 100644 --- a/bec_server/bec_server/shared_memory/ring_buffer.py +++ b/bec_server/bec_server/shared_memory/ring_buffer.py @@ -70,6 +70,10 @@ def copy_data(self, index: int, acquire_timeout: float = 0) -> np.ndarray: Raises: TimeoutError: If the lock cannot be acquired within the specified timeout. """ + if index < 0 or index >= self.slots: + raise IndexError( + f"Index {index} is out of bounds for ring buffer with {self.slots} slots." + ) with self._lock: try: self._semaphore_lock.acquire(timeout=acquire_timeout) @@ -110,6 +114,10 @@ def write_data(self, index: int, data: np.ndarray, acquire_timeout: float = 0) - payload descriptor. TimeoutError: If the lock cannot be acquired within the specified timeout. """ + if index < 0 or index >= self.slots: + raise IndexError( + f"Index {index} is out of bounds for ring buffer with {self.slots} slots." + ) descriptor = PayloadDescriptor.from_numpy(data) if descriptor != self.payload_descriptor: raise ValueError( From fccac9f8942c426693bc4b594ceff7230b3cf6b5 Mon Sep 17 00:00:00 2001 From: appel_c Date: Tue, 19 May 2026 10:03:58 +0200 Subject: [PATCH 11/16] wip --- bec_lib/bec_lib/endpoints.py | 4 +- bec_lib/bec_lib/messages.py | 3 +- bec_server/bec_server/shared_memory/client.py | 53 +++++++++---------- .../bec_server/shared_memory/manager.py | 32 +++++------ 4 files changed, 45 insertions(+), 47 deletions(-) diff --git a/bec_lib/bec_lib/endpoints.py b/bec_lib/bec_lib/endpoints.py index 58b46c307..3a5033713 100644 --- a/bec_lib/bec_lib/endpoints.py +++ b/bec_lib/bec_lib/endpoints.py @@ -84,7 +84,7 @@ class MessageEndpoints: """ @staticmethod - def shared_memory_info(client_id: str): + def shared_memory_info(): """ Endpoint for shared memory information. This endpoint is used to publish the shared memory information using a messages.SharedMemAllocationInfo message. @@ -92,7 +92,7 @@ def shared_memory_info(client_id: str): Returns: EndpointInfo: Endpoint for shared memory information. """ - endpoint = f"{EndpointType.INFO.value}/shared_memory/info/{client_id}" + endpoint = f"{EndpointType.INFO.value}/shared_memory/info/" return EndpointInfo( endpoint=endpoint, message_type=messages.SharedMemAllocationInfo, diff --git a/bec_lib/bec_lib/messages.py b/bec_lib/bec_lib/messages.py index 36e1a2f19..c181027a0 100644 --- a/bec_lib/bec_lib/messages.py +++ b/bec_lib/bec_lib/messages.py @@ -105,7 +105,8 @@ class SharedMemAllocationInfo(BECMessage): msg_type: ClassVar[str] = "shared_mem_allocation_info" - info: list[SharedMemInfo] + # Consider structure, nested dict with client_id as key, and dict with signal name and ShareMemInfo as value + info: dict[str, dict[str, SharedMemInfo]] class SharedMemAllocationRequest(BECMessage): diff --git a/bec_server/bec_server/shared_memory/client.py b/bec_server/bec_server/shared_memory/client.py index caac023b0..413b84233 100644 --- a/bec_server/bec_server/shared_memory/client.py +++ b/bec_server/bec_server/shared_memory/client.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING from bec_lib.endpoints import MessageEndpoints +from bec_lib.logger import bec_logger from bec_server.shared_memory.models import PayloadDescriptor from bec_server.shared_memory.ring_buffer import RingBufferView @@ -13,6 +14,8 @@ from bec_lib.messages import SharedMemAllocationInfo from bec_lib.redis_connector import RedisConnector +logger = bec_logger.logger + # TODO one per service, or N per service. class SharedMemoryClient: @@ -21,37 +24,35 @@ class SharedMemoryClient: def __init__(self, name: str, connector: RedisConnector): self.name = name self.connector = connector + # signal name to ring buffer view mapping self._ring_buffer_views: dict[str, RingBufferView] = {} - self._signal_to_buffer_mapping: dict[str, str] = ( - {} - ) # Mapping from signal names to buffer names self.start() def start(self): """Start the client by subscribing to the shared memory object.""" - self.connector.register( - MessageEndpoints.shared_memory_info(self.name), cb=self._handle_info_update - ) + self.connector.register(MessageEndpoints.shared_memory_info(), cb=self._handle_info_update) def _handle_info_update(self, info: MessageObject) -> None: """Handle updates to the shared memory information.""" info: SharedMemAllocationInfo = info.value # Any info update can potentially contain relevant information for creating or deleting ring buffer views. info_updates = [] - for buff_info in info.info: - info_updates.append(buff_info.buffer_desc.name) - if buff_info.buffer_desc.name not in self._ring_buffer_views: - self._ring_buffer_views[buff_info.buffer_desc.name] = RingBufferView( - descriptor=buff_info.buffer_desc + client_info = info.info.get(self.name, {}) + + for signal, buff_info in client_info.items(): + info_updates.append(signal) + if signal not in self._ring_buffer_views: # + self._ring_buffer_views[signal] = RingBufferView(descriptor=buff_info.buffer_desc) + else: + logger.error( + f"Ring buffer view for signal {signal} already exists, should not happend. Received info update: {buff_info}" ) - self._signal_to_buffer_mapping[buff_info.signal] = buff_info.buffer_desc.name - if len(info.info) < len(self._ring_buffer_views): + if len(client_info) < len(self._ring_buffer_views): # Some shared memory objects have been deallocated. Remove them from the local dictionary. to_be_removed = set(self._ring_buffer_views.keys()) - set(info_updates) for name in to_be_removed: view = self._ring_buffer_views.pop(name) view.destroy() - self._signal_to_buffer_mapping.pop(view.descriptor.signal_name, None) def request_allocation( self, signal_name: str, slots: int, payload_desc: PayloadDescriptor | dict @@ -87,12 +88,13 @@ def read_from_buffer( If timeout is provided, the method will wait for the specified time and raise a TimeoutError if it cannot read the data within that time frame. Please be aware, this is meant to block during write/read operations. """ - if signal_name not in self._signal_to_buffer_mapping: + # TODO add option to wait receiving an update on a specific signal in the buffer + # Also block until there is an update on the specific index in the buffer. + # Should there be a consume logic??? + buff = self._ring_buffer_views.get(signal_name) + if buff is None: raise ValueError(f"No buffer found for signal name: {signal_name}") - buffer_name = self._signal_to_buffer_mapping[signal_name] - if buffer_name not in self._ring_buffer_views: - raise ValueError(f"No ring buffer view found for buffer name: {buffer_name}") - return self._ring_buffer_views[buffer_name].copy_data(index, timeout) + return buff.copy_data(index, timeout) def write_to_buffer( self, signal_name: str, index: int, data: np.ndarray, timeout: float | None = None @@ -102,23 +104,18 @@ def write_to_buffer( If timeout is provided, the method will wait for the specified time and raise a TimeoutError if it cannot write the data within that time frame. Please be aware, this is meant to block during write/read operations. """ - if signal_name not in self._signal_to_buffer_mapping: + buff = self._ring_buffer_views.get(signal_name) + if buff is None: raise ValueError(f"No buffer found for signal name: {signal_name}") - buffer_name = self._signal_to_buffer_mapping[signal_name] - if buffer_name not in self._ring_buffer_views: - raise ValueError(f"No ring buffer view found for buffer name: {buffer_name}") - self._ring_buffer_views[buffer_name].write_data( - index=index, data=data, acquire_timeout=timeout - ) + buff.write_data(index=index, data=data, acquire_timeout=timeout) def shutdown(self) -> None: """Clean up resources and all shared memory views.""" for view in self._ring_buffer_views.values(): view.destroy() self._ring_buffer_views.clear() - self._signal_to_buffer_mapping.clear() self.connector.unregister( - MessageEndpoints.shared_memory_info(self.name), cb=self._handle_info_update + MessageEndpoints.shared_memory_info(), cb=self._handle_info_update ) diff --git a/bec_server/bec_server/shared_memory/manager.py b/bec_server/bec_server/shared_memory/manager.py index 335b9f83d..fad2412dc 100644 --- a/bec_server/bec_server/shared_memory/manager.py +++ b/bec_server/bec_server/shared_memory/manager.py @@ -1,6 +1,7 @@ from __future__ import annotations import threading +from collections import defaultdict from typing import TYPE_CHECKING, Literal, Tuple from bec_lib import messages @@ -30,6 +31,9 @@ def __init__(self, config, connector_cls: type[RedisConnector]) -> None: # Shared memory objects are stored in a dictionary with the client_id and signal name tuple as key # and the RingBuffer instance as value self._shared_memory_objects: dict[Tuple[str, str], RingBuffer] = {} + self._shared_memory_info: dict[str, dict[str, SharedMemInfo]] = defaultdict( + dict + ) # Nested dict with client_id as key, and dict with signal name and ShareMemInfo as value self.lock = threading.RLock() def _allocate_memory(self, request: messages.SharedMemAllocationRequest) -> None: @@ -41,7 +45,7 @@ def _allocate_memory(self, request: messages.SharedMemAllocationRequest) -> None f"Shared memory object for client {request.client_id} and signal {request.signal} already exists. Overwriting." ) # TODO should this republish the info? - self._publish_allocation_info(client_id=request.client_id) + self._publish_allocation_info(self._shared_memory_info) return buff = RingBuffer( @@ -49,7 +53,10 @@ def _allocate_memory(self, request: messages.SharedMemAllocationRequest) -> None ) with self.lock: self._shared_memory_objects[(request.client_id, request.signal)] = buff - self._publish_allocation_info(client_id=request.client_id) + self._shared_memory_info[request.client_id][request.signal] = SharedMemInfo( + client_id=request.client_id, buffer_desc=buff.descriptor, signal=request.signal + ) + self._publish_allocation_info(self._shared_memory_info) def _deallocate_memory(self, request: messages.SharedMemDeallocationRequest) -> None: """Callback function to handle shared memory deallocation requests.""" @@ -60,27 +67,19 @@ def _deallocate_memory(self, request: messages.SharedMemDeallocationRequest) -> f"Shared memory object for client {request.client_id} and signal {request.signal} does not exist. Cannot deallocate." ) # TODO should this republish the info? - self._publish_allocation_info(client_id=request.client_id) + self._publish_allocation_info(self._shared_memory_info) return with self.lock: buff = self._shared_memory_objects.pop((request.client_id, request.signal)) buff.destroy() - self._publish_allocation_info(client_id=request.client_id) + self._shared_memory_info[request.client_id].pop(request.signal, None) + self._publish_allocation_info(self._shared_memory_info) - def _publish_allocation_info(self, client_id: str = "*") -> None: + def _publish_allocation_info(self, info: dict[str, dict[str, SharedMemInfo]]) -> None: """Publish the updated list of allocated shared memory objects.""" - with self.lock: - info = [ - SharedMemInfo(client_id=client_id, buffer_desc=buff.descriptor, signal=signal_name) - for (client_id, signal_name), buff in self._shared_memory_objects.items() - ] - # Maybe use regex here.. - if client_id != "*": - info = [buff_info for buff_info in info if buff_info.client_id == client_id] self.connector.set_and_publish( - MessageEndpoints.shared_memory_info(client_id), - messages.SharedMemAllocationInfo(info=info), + MessageEndpoints.shared_memory_info(), messages.SharedMemAllocationInfo(info=info) ) def start(self) -> None: @@ -95,7 +94,8 @@ def stop(self) -> None: for buff in self._shared_memory_objects.values(): buff.destroy() self._shared_memory_objects.clear() - self._publish_allocation_info() + self._shared_memory_info.clear() + self._publish_allocation_info({}) # Cleanup bec service related resources def shutdown(self) -> None: From 685386bc7e544b4d72e4b70bf7b070b7f7ff5a1b Mon Sep 17 00:00:00 2001 From: appel_c Date: Tue, 19 May 2026 10:05:40 +0200 Subject: [PATCH 12/16] wip --- bec_lib/bec_lib/messages.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/bec_lib/bec_lib/messages.py b/bec_lib/bec_lib/messages.py index c181027a0..d242ff470 100644 --- a/bec_lib/bec_lib/messages.py +++ b/bec_lib/bec_lib/messages.py @@ -105,7 +105,8 @@ class SharedMemAllocationInfo(BECMessage): msg_type: ClassVar[str] = "shared_mem_allocation_info" - # Consider structure, nested dict with client_id as key, and dict with signal name and ShareMemInfo as value + # Consider structure with dict[str, SharedMemInfo]. signal dotted name as key, which allows to identify this directly + # Alternatively, dict[str, dict[str, SharedMemInfo]] with device name as key, and then signal name as nested key info: dict[str, dict[str, SharedMemInfo]] From 4fed647870e153d01e0c7b77677866b4d1583b05 Mon Sep 17 00:00:00 2001 From: appel_c Date: Wed, 27 May 2026 14:37:40 +0200 Subject: [PATCH 13/16] WIP agent code #1: characterize current ring buffer behavior --- .../tests_shared_memory/test_ring_buffer.py | 148 ++++++++++++++++-- 1 file changed, 133 insertions(+), 15 deletions(-) diff --git a/bec_server/tests/tests_shared_memory/test_ring_buffer.py b/bec_server/tests/tests_shared_memory/test_ring_buffer.py index 6e17e3fb0..a7a6ce043 100644 --- a/bec_server/tests/tests_shared_memory/test_ring_buffer.py +++ b/bec_server/tests/tests_shared_memory/test_ring_buffer.py @@ -1,27 +1,145 @@ +from multiprocessing import shared_memory + import numpy as np +import posix_ipc +import pytest + +from bec_server.shared_memory.models import PayloadDescriptor +from bec_server.shared_memory.ring_buffer import RingBuffer, RingBufferView + + +@pytest.fixture +def payload() -> PayloadDescriptor: + return PayloadDescriptor.from_numpy(np.zeros((4,), dtype=np.float64)) + -from bec_server.shared_memory.ring_buffer import PayloadDescriptor, SharedRingBuffer +@pytest.fixture +def ring_buffer(payload: PayloadDescriptor): + buffer = RingBuffer(slots=2, payload=payload) + yield buffer + _cleanup_owned_buffer(buffer) -def test_shutdown_after_slot_context_releases_exported_memoryview(): - payload = PayloadDescriptor.from_numpy(np.zeros((4,), dtype=np.float64)) - ring_buffer = SharedRingBuffer.create(slots=2, payload=payload) +def _close_view(view: RingBufferView) -> None: + view._shm.close() + view._semaphore_lock.close() + + +def _cleanup_owned_buffer(buffer: RingBuffer) -> None: + buffer._shm.close() + try: + buffer._shm.unlink() + except FileNotFoundError: + pass + try: + buffer._semaphore_lock.close() + except OSError: + pass + try: + posix_ipc.unlink_semaphore(buffer.descriptor.lock_id) + except posix_ipc.ExistentialError: + pass - with ring_buffer.write_slot(0) as view: - array = np.ndarray(payload.shape, dtype=payload.dtype.numpy_dtype, buffer=view) - array[:] = 1 - ring_buffer.shutdown() +def test_descriptor_describes_allocated_payload_slots( + ring_buffer: RingBuffer, payload: PayloadDescriptor +): + assert ring_buffer.descriptor.name == ring_buffer.name + assert ring_buffer.descriptor.slots == 2 + assert ring_buffer.descriptor.bytes_per_slot == payload.nbytes + assert ring_buffer.descriptor.payload == payload + assert ring_buffer.slots == 2 + assert ring_buffer.bytes_per_slot == payload.nbytes -def test_create_uses_fresh_shared_memory_and_lock_names(): - payload = PayloadDescriptor.from_numpy(np.zeros((4,), dtype=np.float64)) - first = SharedRingBuffer.create(slots=2, payload=payload) - second = SharedRingBuffer.create(slots=2, payload=payload) +def test_write_and_copy_data_round_trip(ring_buffer: RingBuffer): + data = np.arange(4, dtype=np.float64) + + ring_buffer.write_data(1, data) + + np.testing.assert_array_equal(ring_buffer.copy_data(1), data) + data[:] = -1 + np.testing.assert_array_equal(ring_buffer.copy_data(1), np.arange(4, dtype=np.float64)) + + +def test_attached_view_shares_payload_storage(ring_buffer: RingBuffer): + view = RingBufferView(ring_buffer.descriptor) + try: + written_from_view = np.array([1, 2, 3, 4], dtype=np.float64) + written_from_owner = np.array([5, 6, 7, 8], dtype=np.float64) + + view.write_data(0, written_from_view) + np.testing.assert_array_equal(ring_buffer.copy_data(0), written_from_view) + + ring_buffer.write_data(1, written_from_owner) + np.testing.assert_array_equal(view.copy_data(1), written_from_owner) + finally: + # Calling destroy() here currently unlinks the owner's resources; see the xfail below. + _close_view(view) + +def test_each_buffer_has_distinct_shared_memory_and_lock_names(payload: PayloadDescriptor): + first = RingBuffer(slots=2, payload=payload) + second = RingBuffer(slots=2, payload=payload) try: assert first.name != second.name - assert first.descriptor().lock_id != second.descriptor().lock_id + assert first.descriptor.lock_id != second.descriptor.lock_id + finally: + _cleanup_owned_buffer(first) + _cleanup_owned_buffer(second) + + +@pytest.mark.parametrize("index", [-1, 2]) +def test_copy_data_rejects_indices_outside_slots(ring_buffer: RingBuffer, index: int): + with pytest.raises(IndexError, match="out of bounds"): + ring_buffer.copy_data(index) + + +@pytest.mark.parametrize("index", [-1, 2]) +def test_write_data_rejects_indices_outside_slots(ring_buffer: RingBuffer, index: int): + with pytest.raises(IndexError, match="out of bounds"): + ring_buffer.write_data(index, np.zeros((4,), dtype=np.float64)) + + +@pytest.mark.parametrize( + "data", [np.zeros((2,), dtype=np.float64), np.zeros((4,), dtype=np.float32)] +) +def test_write_data_rejects_payload_shape_or_dtype_mismatch( + ring_buffer: RingBuffer, data: np.ndarray +): + with pytest.raises(ValueError, match="does not match expected"): + ring_buffer.write_data(0, data) + + +def test_destroy_is_idempotent_and_rejects_further_operations( + ring_buffer: RingBuffer, payload: PayloadDescriptor +): + ring_buffer.destroy() + ring_buffer.destroy() + + with pytest.raises(RuntimeError, match="destroyed"): + ring_buffer.write_data(0, np.zeros(payload.shape, dtype=payload.dtype.numpy_dtype)) + + +@pytest.mark.xfail(reason="RingBuffer and RingBufferView currently invert shared-memory ownership") +def test_only_creator_owns_shared_memory_resources(ring_buffer: RingBuffer): + view = RingBufferView(ring_buffer.descriptor) + try: + assert ring_buffer._owns_memory is True + assert view._owns_memory is False + finally: + _close_view(view) + + +@pytest.mark.xfail(reason="RingBufferView.destroy currently unlinks resources owned by RingBuffer") +def test_destroying_view_keeps_owner_resources_attachable(ring_buffer: RingBuffer): + view = RingBufferView(ring_buffer.descriptor) + view.destroy() + + attached = None + try: + attached = shared_memory.SharedMemory(name=ring_buffer.name) + assert attached.name == ring_buffer.name finally: - first.shutdown() - second.shutdown() + if attached is not None: + attached.close() From 7f14996f0f50e9a81bb9dfc2a72bd84707f52fa3 Mon Sep 17 00:00:00 2001 From: appel_c Date: Wed, 27 May 2026 15:05:41 +0200 Subject: [PATCH 14/16] WIP agent code #2: refactor ring buffer metadata and slot locking --- bec_server/bec_server/shared_memory/client.py | 15 +- bec_server/bec_server/shared_memory/models.py | 6 +- .../bec_server/shared_memory/ring_buffer.py | 404 +++++++++++------- .../tests_shared_memory/test_ring_buffer.py | 154 ++++--- 4 files changed, 358 insertions(+), 221 deletions(-) diff --git a/bec_server/bec_server/shared_memory/client.py b/bec_server/bec_server/shared_memory/client.py index 413b84233..674d86ba6 100644 --- a/bec_server/bec_server/shared_memory/client.py +++ b/bec_server/bec_server/shared_memory/client.py @@ -52,7 +52,7 @@ def _handle_info_update(self, info: MessageObject) -> None: to_be_removed = set(self._ring_buffer_views.keys()) - set(info_updates) for name in to_be_removed: view = self._ring_buffer_views.pop(name) - view.destroy() + view.close() def request_allocation( self, signal_name: str, slots: int, payload_desc: PayloadDescriptor | dict @@ -97,22 +97,25 @@ def read_from_buffer( return buff.copy_data(index, timeout) def write_to_buffer( - self, signal_name: str, index: int, data: np.ndarray, timeout: float | None = None - ) -> None: + self, signal_name: str, data: np.ndarray, timeout: float | None = None + ) -> int: """ - Write data to the shared memory buffer associated with the given signal name. + Write data to the next ring position associated with the given signal name. If timeout is provided, the method will wait for the specified time and raise a TimeoutError if it cannot write the data within that time frame. Please be aware, this is meant to block during write/read operations. + + Returns: + int: The slot index containing the newly written payload. """ buff = self._ring_buffer_views.get(signal_name) if buff is None: raise ValueError(f"No buffer found for signal name: {signal_name}") - buff.write_data(index=index, data=data, acquire_timeout=timeout) + return buff.write_data(data=data, acquire_timeout=timeout) def shutdown(self) -> None: """Clean up resources and all shared memory views.""" for view in self._ring_buffer_views.values(): - view.destroy() + view.close() self._ring_buffer_views.clear() self.connector.unregister( MessageEndpoints.shared_memory_info(), cb=self._handle_info_update diff --git a/bec_server/bec_server/shared_memory/models.py b/bec_server/bec_server/shared_memory/models.py index fe403a0c6..038afe78e 100644 --- a/bec_server/bec_server/shared_memory/models.py +++ b/bec_server/bec_server/shared_memory/models.py @@ -73,12 +73,12 @@ def from_numpy(cls, array: np.ndarray) -> PayloadDescriptor: class RingBufferDescriptor(BaseModel): - """Descriptor for SharedRingBuffer object.""" + """Information required to attach to a shared ring buffer.""" name: str - lock_id: str + metadata_lock_id: str + slot_lock_ids: Tuple[str, ...] slots: int - bytes_per_slot: int payload: PayloadDescriptor diff --git a/bec_server/bec_server/shared_memory/ring_buffer.py b/bec_server/bec_server/shared_memory/ring_buffer.py index ec7f4c995..b4669d94a 100644 --- a/bec_server/bec_server/shared_memory/ring_buffer.py +++ b/bec_server/bec_server/shared_memory/ring_buffer.py @@ -1,21 +1,33 @@ from __future__ import annotations +import struct +from contextlib import contextmanager from functools import wraps -from multiprocessing import resource_tracker, shared_memory +from multiprocessing import shared_memory from threading import RLock -from typing import Any, Callable +from typing import Any, Callable, Iterator from uuid import uuid4 import numpy as np import posix_ipc -from bec_server.shared_memory.models import PayloadDescriptor, RingBufferDescriptor +from bec_server.shared_memory.models import DTypeDescriptor, PayloadDescriptor, RingBufferDescriptor # pylint: disable=c-extension-no-member +MAGIC = b"BEC_RING" +MAX_NDIM = 8 +_HEADER = struct.Struct("<8sIQIBBBB8Q") +HEADER_SIZE = _HEADER.size + +_KIND_TO_CODE = {"uint": 1, "int": 2, "float": 3, "bool": 4} +_CODE_TO_KIND = {code: kind for kind, code in _KIND_TO_CODE.items()} +_BYTE_ORDER_TO_CODE = {"little": 1, "big": 2} +_CODE_TO_BYTE_ORDER = {code: order for order, code in _BYTE_ORDER_TO_CODE.items()} + def not_destroyed(method: Callable[..., Any]) -> Callable[..., Any]: - """Decorator to check if the RingBufferView has been destroyed before allowing method execution.""" + """Check that a shared-memory handle is still open before accessing it.""" @wraps(method) def wrapper(self: RingBufferView, *args: Any, **kwargs: Any) -> Any: @@ -29,202 +41,284 @@ def wrapper(self: RingBufferView, *args: Any, **kwargs: Any) -> Any: class RingBufferView: - """ - Class for handling shared RingBuffer objects from clients, which attach to an existing shared memory object - defined by a RingBufferDescriptor. The view can be used to read/write to the buffer without owning the memory. - It can not be used to create new shared memory objects, which is reserved for the RingBuffer class. - """ + """Attached handle for accessing a ring buffer without owning its resources.""" def __init__( - self, descriptor: RingBufferDescriptor, shm: shared_memory.SharedMemory | None = None + self, + descriptor: RingBufferDescriptor, + shm: shared_memory.SharedMemory | None = None, + *, + owns_memory: bool = False, ): + if len(descriptor.slot_lock_ids) != descriptor.slots: + raise ValueError("Ring buffer descriptor must provide exactly one lock per slot.") self._descriptor = descriptor self._shm = shm if shm is not None else shared_memory.SharedMemory(name=descriptor.name) - self._owns_memory = shm is None - self._semaphore_lock = posix_ipc.Semaphore(descriptor.lock_id, flags=0) + self._owns_memory = owns_memory + self._metadata_lock = posix_ipc.Semaphore(descriptor.metadata_lock_id, flags=0) + self._slot_locks = [ + posix_ipc.Semaphore(lock_id, flags=0) for lock_id in descriptor.slot_lock_ids + ] self.__destroyed = False - self._lock = RLock() - # # TODO: Check why this might be needed, but to be sure to lock is accidently kept. - # self._semaphore_lock.release() - - ############ - # API - ############ + self._lifecycle_lock = RLock() + try: + self._slots, self._payload_descriptor = self._read_header() + if self._slots != descriptor.slots or self._payload_descriptor != descriptor.payload: + raise ValueError("Ring buffer descriptor does not match shared-memory metadata.") + except Exception: + self._close_handles() + raise + + @staticmethod + def _encode_header( + slots: int, payload: PayloadDescriptor, next_write_position: int = 0 + ) -> bytes: + if not 0 < slots: + raise ValueError("Ring buffer must contain at least one slot.") + if len(payload.shape) > MAX_NDIM: + raise ValueError(f"Ring buffer payload supports at most {MAX_NDIM} dimensions.") + dimensions = (*payload.shape, *((0,) * (MAX_NDIM - len(payload.shape)))) + return _HEADER.pack( + MAGIC, + slots, + payload.nbytes, + next_write_position, + _KIND_TO_CODE[payload.dtype.kind], + payload.dtype.itemsize, + _BYTE_ORDER_TO_CODE[payload.dtype.byte_order], + len(payload.shape), + *dimensions, + ) - @not_destroyed - def copy_data(self, index: int, acquire_timeout: float = 0) -> np.ndarray: - """ - Returns a copy of the data at the given slot index as a numpy array. While the data is being copied, - the shared memory is locked to prevent concurrent modifications. Once copied, the shared memory is released. - NOTE: The additional argument acquire_timeout can be used to specify a timeout for acquiring the lock. The - default value of 0 means that it will wait indefinitely until the lock is acquired. If the lock cannot - be acquired within the specified timeout, a TimeoutError will be raised. Please NOTE that this feature - requires the underlying OS to support timeouts for posix semaphores, which is for example not the case for MAC OS. - - Args: - index (int): The slot index to copy data from. - acquire_timeout (float): The timeout in seconds to acquire the lock. If 0, it will wait indefinitely. - - Returns: - np.ndarray: A copy of the data at the specified slot index. - Raises: - TimeoutError: If the lock cannot be acquired within the specified timeout. - """ - if index < 0 or index >= self.slots: - raise IndexError( - f"Index {index} is out of bounds for ring buffer with {self.slots} slots." + def _read_header(self) -> tuple[int, PayloadDescriptor]: + ( + magic, + slots, + bytes_per_slot, + _next_write_position, + kind_code, + itemsize, + byte_order_code, + ndim, + *dimensions, + ) = _HEADER.unpack_from(self._shm.buf) + if magic != MAGIC: + raise ValueError("Shared memory does not contain a BEC ring buffer header.") + if ndim > MAX_NDIM: + raise ValueError("Shared memory contains an invalid ring buffer payload shape.") + try: + dtype = DTypeDescriptor( + kind=_CODE_TO_KIND[kind_code], + itemsize=itemsize, + byte_order=_CODE_TO_BYTE_ORDER[byte_order_code], ) - with self._lock: - try: - self._semaphore_lock.acquire(timeout=acquire_timeout) - array = np.ndarray( - shape=self.payload_descriptor.shape, - dtype=self.payload_descriptor.dtype.numpy_dtype, - buffer=self._shm.buf, - offset=index * self.bytes_per_slot, - ) - local_copy = array.copy() # Make a local copy of the data - except posix_ipc.BusyError: - # pylint: disable=raise-missing-from - raise TimeoutError( - f"Could not acquire lock for reading from buffer {self.name!r} within {acquire_timeout} seconds." - ) - finally: - self._semaphore_lock.release() - return local_copy - - @not_destroyed - def write_data(self, index: int, data: np.ndarray, acquire_timeout: float = 0) -> None: - """ - Writes the given numpy array data to the specified slot index in the shared memory. While the data is being - written, the shared memory is locked to prevent concurrent modifications. Once the data is written, the shared - memory is released. - NOTE: The additional argument acquire_timeout can be used to specify a timeout for acquiring the lock. The - default value of 0 means that it will wait indefinitely until the lock is acquired. If the lock cannot - be acquired within the specified timeout, a TimeoutError will be raised. Please NOTE that this feature - requires the underlying OS to support timeouts for posix semaphores, which is for example not the case for MAC OS. - - Args: - index (int): The slot index to write data to. - data (np.ndarray): The numpy array data to write to the shared memory. - acquire_timeout (float): The timeout in seconds to acquire the lock. If 0, it will wait indefinitely. - - Raises: - ValueError: If the size of the data does not match the expected size defined by the - payload descriptor. - TimeoutError: If the lock cannot be acquired within the specified timeout. - """ + except KeyError as exc: + raise ValueError( + "Shared memory contains an invalid ring buffer payload dtype." + ) from exc + payload = PayloadDescriptor( + nbytes=bytes_per_slot, shape=tuple(dimensions[:ndim]), dtype=dtype, layout="C" + ) + if ( + payload.nbytes + != int(np.prod(payload.shape, dtype=np.int64)) * dtype.numpy_dtype.itemsize + ): + raise ValueError("Shared memory contains inconsistent ring buffer payload metadata.") + return slots, payload + + def _write_next_position(self, position: int) -> None: + struct.pack_into(" int: + return struct.unpack_from(" Iterator[None]: + acquired = False + try: + semaphore.acquire(timeout=None if timeout in (None, 0) else timeout) + acquired = True + yield + except posix_ipc.BusyError: + raise TimeoutError( + f"Could not acquire lock for {operation} buffer {self.name!r} within {timeout} seconds." + ) from None + finally: + if acquired: + semaphore.release() + + def _validate_index(self, index: int) -> None: if index < 0 or index >= self.slots: raise IndexError( f"Index {index} is out of bounds for ring buffer with {self.slots} slots." ) + + def _validate_payload(self, data: np.ndarray) -> None: descriptor = PayloadDescriptor.from_numpy(data) if descriptor != self.payload_descriptor: raise ValueError( f"Data shape/dtype {descriptor.shape}/{descriptor.dtype} does not match expected " f"shape/dtype {self.payload_descriptor.shape}/{self.payload_descriptor.dtype}" ) - with self._lock: - try: - self._semaphore_lock.acquire(timeout=acquire_timeout) - array = np.ndarray( - shape=self.payload_descriptor.shape, - dtype=self.payload_descriptor.dtype.numpy_dtype, - buffer=self._shm.buf, - offset=index * self.bytes_per_slot, - ) - np.copyto(array, data) # Copy data into shared memory - except posix_ipc.BusyError: - # pylint: disable=raise-missing-from - raise TimeoutError( - f"Could not acquire lock for reading from buffer {self.name!r} within {acquire_timeout} seconds." - ) - finally: - self._semaphore_lock.release() - ############ - # Properties - ############ + def _array_for_slot(self, index: int) -> np.ndarray: + return np.ndarray( + shape=self.payload_descriptor.shape, + dtype=self.payload_descriptor.dtype.numpy_dtype, + buffer=self._shm.buf, + offset=HEADER_SIZE + index * self.bytes_per_slot, + ) + + def _claim_next_write_position(self, acquire_timeout: float | None = 0) -> int: + with self._acquire(self._metadata_lock, acquire_timeout, "selecting a write slot in"): + index = self._read_next_position() + self._write_next_position((index + 1) % self.slots) + return index + + @not_destroyed + def copy_data(self, index: int, acquire_timeout: float | None = 0) -> np.ndarray: + """Copy one identified payload slot under that slot's semaphore.""" + self._validate_index(index) + with self._acquire(self._slot_locks[index], acquire_timeout, "reading from"): + return self._array_for_slot(index).copy() + + @not_destroyed + def write_data(self, data: np.ndarray, acquire_timeout: float | None = 0) -> int: + """Write to the next circular slot and return the written slot index.""" + self._validate_payload(data) + index = self._claim_next_write_position(acquire_timeout) + self._write_data_at(index, data, acquire_timeout) + return index + + @not_destroyed + def write_data_at( + self, index: int, data: np.ndarray, acquire_timeout: float | None = 0 + ) -> None: + """Write directly to an identified slot without advancing the write cursor.""" + self._validate_index(index) + self._validate_payload(data) + self._write_data_at(index, data, acquire_timeout) + + def _write_data_at( + self, index: int, data: np.ndarray, acquire_timeout: float | None = 0 + ) -> None: + with self._acquire(self._slot_locks[index], acquire_timeout, "writing to"): + np.copyto(self._array_for_slot(index), data) @property def descriptor(self) -> RingBufferDescriptor: - """Return the descriptor for this RingBuffer.""" return self._descriptor @property def destroyed(self) -> bool: - """Indicates whether the view has been destroyed.""" return self.__destroyed @property - def name(self): - """Name of shared ring buffer""" + def name(self) -> str: return self._descriptor.name @property - def slots(self): - """Max Index of shared ring buffer""" - return self._descriptor.slots + def slots(self) -> int: + return self._slots @property - def bytes_per_slot(self): - """Bytes per index in shared ring buffer""" - return self._descriptor.bytes_per_slot + def bytes_per_slot(self) -> int: + return self._payload_descriptor.nbytes @property - def payload_descriptor(self): - """Payload descriptor for the data stored in the ring buffer.""" - return self._descriptor.payload - - def destroy(self): - """ - Destroy the shared memory object. The method can be called multiple times but only the first call will have an effect. - """ + def payload_descriptor(self) -> PayloadDescriptor: + return self._payload_descriptor + + @property + @not_destroyed + def next_write_position(self) -> int: + with self._acquire(self._metadata_lock, None, "reading metadata from"): + return self._read_next_position() + + def _close_handles(self) -> None: + for lock in self._slot_locks: + lock.close() + self._metadata_lock.close() + self._shm.close() + + def close(self) -> None: + """Close local handles without unlinking owner-managed resources.""" if self.destroyed: return - with self._lock: - # Semaphore lock - self._semaphore_lock.release() # Make sure to release upon closing to avoid deadlocks if the lock is still held by this process - # Shared memory - self._shm.close() - # Cleanup depends on whether the memory is owned by this view or not. - if self._owns_memory: - self._semaphore_lock.unlink() - self._shm.unlink() - else: - # NOTE: From Python 3.13 onwards, we can use the track=False option when creating the reference - # For views not owning the memory, we have to manually unregister it. - # pylint: disable=protected-access - resource_tracker.unregister(self._shm._name, "shared_memory") - self._semaphore_lock.close() - - # to avoid registering the shared memory with the resource tracker. - self.__destroyed = True + with self._lifecycle_lock: + if self.destroyed: + return + self._close_handles() + self.__destroyed = True + def destroy(self) -> None: + """Compatibility alias for attached clients; attached handles only close resources.""" + self.close() -class RingBuffer(RingBufferView): - """ - RingBuffer class that owns the shared memory. If created, it will create a new sharedMemory object together with a semaphore lock. - Args: - slots (int): The number of slots in the ring buffer. - payload (PayloadDescriptor): The descriptor for the data payload stored in each slot of the ring buffer. - name_suffix (str): An optional suffix to append to the shared memory and semaphore names for identification. - """ +class RingBuffer(RingBufferView): + """Owner of a shared ring buffer and its semaphore resources.""" def __init__(self, slots: int, payload: PayloadDescriptor, name_suffix: str = ""): name = f"bec_psm_{uuid4().hex[:6]}" - shm = shared_memory.SharedMemory(create=True, size=slots * payload.nbytes, name=name) - lock_name = f"{name}_lock" - semaphore_lock = posix_ipc.Semaphore(lock_name, flags=posix_ipc.O_CREAT, initial_value=1) - self._descriptor = RingBufferDescriptor( - name=shm.name, - lock_id=semaphore_lock.name, - slots=slots, - bytes_per_slot=payload.nbytes, - payload=payload, + header = self._encode_header(slots, payload) + shm = shared_memory.SharedMemory( + create=True, size=HEADER_SIZE + slots * payload.nbytes, name=name ) - super().__init__(descriptor=self._descriptor, shm=shm) + shm.buf[:HEADER_SIZE] = header + metadata_lock_name = f"{name}_metadata_lock" + slot_lock_names = tuple(f"{name}_slot_{index}_lock" for index in range(slots)) + created_locks: list[posix_ipc.Semaphore] = [] + try: + created_locks.append( + posix_ipc.Semaphore( + metadata_lock_name, flags=posix_ipc.O_CREAT | posix_ipc.O_EXCL, initial_value=1 + ) + ) + created_locks.extend( + posix_ipc.Semaphore( + lock_name, flags=posix_ipc.O_CREAT | posix_ipc.O_EXCL, initial_value=1 + ) + for lock_name in slot_lock_names + ) + for lock in created_locks: + lock.close() + descriptor = RingBufferDescriptor( + name=shm.name, + metadata_lock_id=metadata_lock_name, + slot_lock_ids=slot_lock_names, + slots=slots, + payload=payload, + ) + super().__init__(descriptor=descriptor, shm=shm, owns_memory=True) + except Exception: + for lock in created_locks: + try: + lock.close() + except OSError: + pass + try: + posix_ipc.unlink_semaphore(lock.name) + except posix_ipc.ExistentialError: + pass + shm.close() + shm.unlink() + raise + + def destroy(self) -> None: + """Close and unlink all resources created for this owned ring buffer.""" + if self.destroyed: + return + descriptor = self.descriptor + self.close() + self._shm.unlink() + for lock_id in (descriptor.metadata_lock_id, *descriptor.slot_lock_ids): + try: + posix_ipc.unlink_semaphore(lock_id) + except posix_ipc.ExistentialError: + pass @classmethod def _name_suffix(cls, name: str, suffix: str, max_length: int = 63) -> str: diff --git a/bec_server/tests/tests_shared_memory/test_ring_buffer.py b/bec_server/tests/tests_shared_memory/test_ring_buffer.py index a7a6ce043..29edc778e 100644 --- a/bec_server/tests/tests_shared_memory/test_ring_buffer.py +++ b/bec_server/tests/tests_shared_memory/test_ring_buffer.py @@ -5,7 +5,7 @@ import pytest from bec_server.shared_memory.models import PayloadDescriptor -from bec_server.shared_memory.ring_buffer import RingBuffer, RingBufferView +from bec_server.shared_memory.ring_buffer import HEADER_SIZE, RingBuffer, RingBufferView @pytest.fixture @@ -17,76 +17,109 @@ def payload() -> PayloadDescriptor: def ring_buffer(payload: PayloadDescriptor): buffer = RingBuffer(slots=2, payload=payload) yield buffer - _cleanup_owned_buffer(buffer) + buffer.destroy() -def _close_view(view: RingBufferView) -> None: - view._shm.close() - view._semaphore_lock.close() - - -def _cleanup_owned_buffer(buffer: RingBuffer) -> None: - buffer._shm.close() - try: - buffer._shm.unlink() - except FileNotFoundError: - pass - try: - buffer._semaphore_lock.close() - except OSError: - pass - try: - posix_ipc.unlink_semaphore(buffer.descriptor.lock_id) - except posix_ipc.ExistentialError: - pass - - -def test_descriptor_describes_allocated_payload_slots( +def test_descriptor_exposes_attachment_resources_and_payload( ring_buffer: RingBuffer, payload: PayloadDescriptor ): assert ring_buffer.descriptor.name == ring_buffer.name assert ring_buffer.descriptor.slots == 2 - assert ring_buffer.descriptor.bytes_per_slot == payload.nbytes assert ring_buffer.descriptor.payload == payload - assert ring_buffer.slots == 2 - assert ring_buffer.bytes_per_slot == payload.nbytes + assert ring_buffer.descriptor.metadata_lock_id + assert len(ring_buffer.descriptor.slot_lock_ids) == 2 + assert len(set(ring_buffer.descriptor.slot_lock_ids)) == 2 + + +def test_header_reconstructs_payload_description_for_attached_view( + ring_buffer: RingBuffer, payload: PayloadDescriptor +): + view = RingBufferView(ring_buffer.descriptor) + try: + assert view.slots == ring_buffer.descriptor.slots + assert view.bytes_per_slot == payload.nbytes + assert view.payload_descriptor == payload + finally: + view.close() + + +def test_attached_view_rejects_descriptor_payload_mismatch(ring_buffer: RingBuffer): + invalid_payload = PayloadDescriptor.from_numpy(np.zeros((2,), dtype=np.float64)) + descriptor = ring_buffer.descriptor.model_copy(update={"payload": invalid_payload}) + + with pytest.raises(ValueError, match="does not match shared-memory metadata"): + RingBufferView(descriptor) -def test_write_and_copy_data_round_trip(ring_buffer: RingBuffer): +def test_write_data_uses_circular_position_and_returns_written_slot(ring_buffer: RingBuffer): + first = np.array([1, 2, 3, 4], dtype=np.float64) + second = np.array([5, 6, 7, 8], dtype=np.float64) + third = np.array([9, 10, 11, 12], dtype=np.float64) + + assert ring_buffer.next_write_position == 0 + assert ring_buffer.write_data(first) == 0 + assert ring_buffer.next_write_position == 1 + assert ring_buffer.write_data(second) == 1 + assert ring_buffer.next_write_position == 0 + assert ring_buffer.write_data(third) == 0 + assert ring_buffer.next_write_position == 1 + np.testing.assert_array_equal(ring_buffer.copy_data(0), third) + np.testing.assert_array_equal(ring_buffer.copy_data(1), second) + + +def test_explicit_write_uses_header_payload_offset_without_advancing_cursor( + ring_buffer: RingBuffer, payload: PayloadDescriptor +): data = np.arange(4, dtype=np.float64) - ring_buffer.write_data(1, data) + ring_buffer.write_data_at(1, data) - np.testing.assert_array_equal(ring_buffer.copy_data(1), data) - data[:] = -1 - np.testing.assert_array_equal(ring_buffer.copy_data(1), np.arange(4, dtype=np.float64)) + raw_payload = np.ndarray( + payload.shape, + dtype=payload.dtype.numpy_dtype, + buffer=ring_buffer._shm.buf, + offset=HEADER_SIZE + payload.nbytes, + ) + np.testing.assert_array_equal(raw_payload, data) + assert ring_buffer.next_write_position == 0 -def test_attached_view_shares_payload_storage(ring_buffer: RingBuffer): +def test_attached_view_shares_cursor_and_payload_storage(ring_buffer: RingBuffer): view = RingBufferView(ring_buffer.descriptor) try: written_from_view = np.array([1, 2, 3, 4], dtype=np.float64) written_from_owner = np.array([5, 6, 7, 8], dtype=np.float64) - view.write_data(0, written_from_view) + assert view.write_data(written_from_view) == 0 np.testing.assert_array_equal(ring_buffer.copy_data(0), written_from_view) - ring_buffer.write_data(1, written_from_owner) + assert ring_buffer.write_data(written_from_owner) == 1 np.testing.assert_array_equal(view.copy_data(1), written_from_owner) + assert view.next_write_position == 0 finally: - # Calling destroy() here currently unlinks the owner's resources; see the xfail below. - _close_view(view) + view.close() -def test_each_buffer_has_distinct_shared_memory_and_lock_names(payload: PayloadDescriptor): +def test_each_buffer_has_distinct_shared_memory_and_semaphore_names(payload: PayloadDescriptor): first = RingBuffer(slots=2, payload=payload) second = RingBuffer(slots=2, payload=payload) try: assert first.name != second.name - assert first.descriptor.lock_id != second.descriptor.lock_id + assert first.descriptor.metadata_lock_id != second.descriptor.metadata_lock_id + assert first.descriptor.slot_lock_ids != second.descriptor.slot_lock_ids + finally: + first.destroy() + second.destroy() + + +def test_slot_locks_do_not_block_access_to_other_slots(ring_buffer: RingBuffer): + first_slot_lock = posix_ipc.Semaphore(ring_buffer.descriptor.slot_lock_ids[0]) + try: + first_slot_lock.acquire() + ring_buffer.write_data_at(1, np.arange(4, dtype=np.float64), acquire_timeout=0.01) finally: - _cleanup_owned_buffer(first) - _cleanup_owned_buffer(second) + first_slot_lock.release() + first_slot_lock.close() @pytest.mark.parametrize("index", [-1, 2]) @@ -96,9 +129,9 @@ def test_copy_data_rejects_indices_outside_slots(ring_buffer: RingBuffer, index: @pytest.mark.parametrize("index", [-1, 2]) -def test_write_data_rejects_indices_outside_slots(ring_buffer: RingBuffer, index: int): +def test_write_data_at_rejects_indices_outside_slots(ring_buffer: RingBuffer, index: int): with pytest.raises(IndexError, match="out of bounds"): - ring_buffer.write_data(index, np.zeros((4,), dtype=np.float64)) + ring_buffer.write_data_at(index, np.zeros((4,), dtype=np.float64)) @pytest.mark.parametrize( @@ -108,7 +141,7 @@ def test_write_data_rejects_payload_shape_or_dtype_mismatch( ring_buffer: RingBuffer, data: np.ndarray ): with pytest.raises(ValueError, match="does not match expected"): - ring_buffer.write_data(0, data) + ring_buffer.write_data(data) def test_destroy_is_idempotent_and_rejects_further_operations( @@ -118,28 +151,35 @@ def test_destroy_is_idempotent_and_rejects_further_operations( ring_buffer.destroy() with pytest.raises(RuntimeError, match="destroyed"): - ring_buffer.write_data(0, np.zeros(payload.shape, dtype=payload.dtype.numpy_dtype)) + ring_buffer.write_data(np.zeros(payload.shape, dtype=payload.dtype.numpy_dtype)) -@pytest.mark.xfail(reason="RingBuffer and RingBufferView currently invert shared-memory ownership") def test_only_creator_owns_shared_memory_resources(ring_buffer: RingBuffer): view = RingBufferView(ring_buffer.descriptor) try: assert ring_buffer._owns_memory is True assert view._owns_memory is False finally: - _close_view(view) + view.close() -@pytest.mark.xfail(reason="RingBufferView.destroy currently unlinks resources owned by RingBuffer") -def test_destroying_view_keeps_owner_resources_attachable(ring_buffer: RingBuffer): +def test_closing_view_keeps_owner_resources_attachable(ring_buffer: RingBuffer): view = RingBufferView(ring_buffer.descriptor) - view.destroy() + view.close() - attached = None - try: - attached = shared_memory.SharedMemory(name=ring_buffer.name) - assert attached.name == ring_buffer.name - finally: - if attached is not None: - attached.close() + attached = RingBufferView(ring_buffer.descriptor) + attached.close() + assert ring_buffer.next_write_position == 0 + + +def test_destroying_owner_unlinks_shared_memory_and_semaphores(ring_buffer: RingBuffer): + descriptor = ring_buffer.descriptor + + ring_buffer.destroy() + + with pytest.raises(FileNotFoundError): + shared_memory.SharedMemory(name=descriptor.name) + with pytest.raises(posix_ipc.ExistentialError): + posix_ipc.Semaphore(descriptor.metadata_lock_id) + with pytest.raises(posix_ipc.ExistentialError): + posix_ipc.Semaphore(descriptor.slot_lock_ids[0]) From db070a43a088b61bcd923164dd13f022551badd3 Mon Sep 17 00:00:00 2001 From: appel_c Date: Wed, 27 May 2026 15:22:09 +0200 Subject: [PATCH 15/16] chore: add posix_ipc to bec_server dependencies --- bec_server/pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/bec_server/pyproject.toml b/bec_server/pyproject.toml index 2598be6ef..f0cfa79b8 100644 --- a/bec_server/pyproject.toml +++ b/bec_server/pyproject.toml @@ -27,6 +27,7 @@ dependencies = [ "pyyaml~=6.0", "python-dotenv~=1.0", "rich>=13.7,<16.0", + "posix_ipc~=1.0", ] [project.optional-dependencies] From d783c0c4e2b05e476b0ce7730223e82872e1f8a6 Mon Sep 17 00:00:00 2001 From: appel_c Date: Wed, 27 May 2026 15:23:48 +0200 Subject: [PATCH 16/16] wip: fix semaphore lock naming convention --- bec_server/bec_server/shared_memory/ring_buffer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bec_server/bec_server/shared_memory/ring_buffer.py b/bec_server/bec_server/shared_memory/ring_buffer.py index b4669d94a..6d0072d3d 100644 --- a/bec_server/bec_server/shared_memory/ring_buffer.py +++ b/bec_server/bec_server/shared_memory/ring_buffer.py @@ -268,8 +268,8 @@ def __init__(self, slots: int, payload: PayloadDescriptor, name_suffix: str = "" create=True, size=HEADER_SIZE + slots * payload.nbytes, name=name ) shm.buf[:HEADER_SIZE] = header - metadata_lock_name = f"{name}_metadata_lock" - slot_lock_names = tuple(f"{name}_slot_{index}_lock" for index in range(slots)) + metadata_lock_name = f"{name}_md_lock" + slot_lock_names = tuple(f"{name}_slot__lock_{index}" for index in range(slots)) created_locks: list[posix_ipc.Semaphore] = [] try: created_locks.append(