diff --git a/tests/coordination/test_coordination_client.py b/tests/coordination/test_coordination_client.py index 460c1139..91cdc881 100644 --- a/tests/coordination/test_coordination_client.py +++ b/tests/coordination/test_coordination_client.py @@ -1,7 +1,11 @@ +import asyncio +import threading + import pytest import ydb from ydb.aio.coordination import CoordinationClient as AioCoordinationClient +from ydb import StatusCode from ydb.coordination import ( NodeConfig, @@ -11,27 +15,63 @@ ) -class TestCoordination: - def test_coordination_node_lifecycle(self, driver_sync: ydb.Driver): - client = CoordinationClient(driver_sync) - node_path = "/local/test_node_lifecycle" +@pytest.fixture +def sync_coordination_node(driver_sync): + client = CoordinationClient(driver_sync) + node_path = "/local/test_node" - try: - client.delete_node(node_path) - except ydb.SchemeError: - pass + try: + client.delete_node(node_path) + except ydb.SchemeError: + pass - with pytest.raises(ydb.SchemeError): - client.describe_node(node_path) + config = NodeConfig( + session_grace_period_millis=1000, + attach_consistency_mode=ConsistencyMode.STRICT, + read_consistency_mode=ConsistencyMode.STRICT, + rate_limiter_counters_mode=RateLimiterCountersMode.UNSET, + self_check_period_millis=0, + ) + client.create_node(node_path, config) - initial_config = NodeConfig( - session_grace_period_millis=1000, - attach_consistency_mode=ConsistencyMode.STRICT, - read_consistency_mode=ConsistencyMode.STRICT, - rate_limiter_counters_mode=RateLimiterCountersMode.UNSET, - self_check_period_millis=0, - ) - client.create_node(node_path, initial_config) + yield client, node_path, config + + try: + client.delete_node(node_path) + except ydb.SchemeError: + pass + + +@pytest.fixture +async def async_coordination_node(aio_connection): + client = AioCoordinationClient(aio_connection) + node_path = "/local/test_node" + + try: + await client.delete_node(node_path) + except ydb.SchemeError: + pass + + config = NodeConfig( + 
session_grace_period_millis=1000, + attach_consistency_mode=ConsistencyMode.STRICT, + read_consistency_mode=ConsistencyMode.STRICT, + rate_limiter_counters_mode=RateLimiterCountersMode.UNSET, + self_check_period_millis=0, + ) + await client.create_node(node_path, config) + + yield client, node_path, config + + try: + await client.delete_node(node_path) + except ydb.SchemeError: + pass + + +class TestCoordination: + def test_coordination_node_lifecycle(self, sync_coordination_node): + client, node_path, initial_config = sync_coordination_node node_conf = client.describe_node(node_path) assert node_conf == initial_config @@ -53,26 +93,8 @@ def test_coordination_node_lifecycle(self, driver_sync: ydb.Driver): with pytest.raises(ydb.SchemeError): client.describe_node(node_path) - async def test_coordination_node_lifecycle_async(self, aio_connection): - client = AioCoordinationClient(aio_connection) - node_path = "/local/test_node_lifecycle" - - try: - await client.delete_node(node_path) - except ydb.SchemeError: - pass - - with pytest.raises(ydb.SchemeError): - await client.describe_node(node_path) - - initial_config = NodeConfig( - session_grace_period_millis=1000, - attach_consistency_mode=ConsistencyMode.STRICT, - read_consistency_mode=ConsistencyMode.STRICT, - rate_limiter_counters_mode=RateLimiterCountersMode.UNSET, - self_check_period_millis=0, - ) - await client.create_node(node_path, initial_config) + async def test_coordination_node_lifecycle_async(self, async_coordination_node): + client, node_path, initial_config = async_coordination_node node_conf = await client.describe_node(node_path) assert node_conf == initial_config @@ -93,3 +115,105 @@ async def test_coordination_node_lifecycle_async(self, aio_connection): with pytest.raises(ydb.SchemeError): await client.describe_node(node_path) + + async def test_coordination_lock_describe_full_async(self, async_coordination_node): + client, node_path, _ = async_coordination_node + + async with client.node(node_path) 
as node: + lock = node.lock("test_lock") + + desc = await lock.describe() + assert desc.status == StatusCode.NOT_FOUND + + async with lock: + pass + + desc = await lock.describe() + assert desc.data == b"" + + await lock.update(new_data=b"world") + + desc2 = await lock.describe() + assert desc2.data == b"world" + + def test_coordination_lock_describe_full(self, sync_coordination_node): + client, node_path, _ = sync_coordination_node + + with client.node(node_path) as node: + lock = node.lock("test_lock") + + desc = lock.describe() + assert desc.status == StatusCode.NOT_FOUND + + with lock: + pass + + desc = lock.describe() + assert desc.data == b"" + + lock.update(new_data=b"world") + + desc2 = lock.describe() + assert desc2.data == b"world" + + async def test_coordination_lock_racing_async(self, async_coordination_node): + client, node_path, _ = async_coordination_node + timeout = 5 + + async with client.node(node_path) as node: + lock2_started = asyncio.Event() + lock2_acquired = asyncio.Event() + lock2_release = asyncio.Event() + + async def second_lock_task(): + lock2_started.set() + async with node.lock("test_lock"): + lock2_acquired.set() + await lock2_release.wait() + + async with node.lock("test_lock"): + t2 = asyncio.create_task(second_lock_task()) + await asyncio.wait_for(lock2_started.wait(), timeout=timeout) + + await asyncio.wait_for(lock2_acquired.wait(), timeout=timeout) + lock2_release.set() + await asyncio.wait_for(t2, timeout=timeout) + + def test_coordination_lock_racing(self, sync_coordination_node): + client, node_path, _ = sync_coordination_node + timeout = 5 + + with client.node(node_path) as node: + lock2_started = threading.Event() + lock2_acquired = threading.Event() + lock2_release = threading.Event() + + def second_lock_task(): + lock2_started.set() + with node.lock("test_lock"): + lock2_acquired.set() + lock2_release.wait(timeout) + + with node.lock("test_lock"): + t2 = threading.Thread(target=second_lock_task) + t2.start() + + assert 
lock2_started.wait(timeout) + + assert lock2_acquired.wait(timeout) + lock2_release.set() + t2.join(timeout) + + async def test_coordination_reconnect_async(self, async_coordination_node): + client, node_path, _ = async_coordination_node + + async with client.node(node_path) as node: + lock = node.lock("test_lock") + + async with lock: + pass + + await node._reconnector.stop() + + async with lock: + pass diff --git a/ydb/_apis.py b/ydb/_apis.py index 97f64b90..595550b2 100644 --- a/ydb/_apis.py +++ b/ydb/_apis.py @@ -143,9 +143,9 @@ class QueryService(object): class CoordinationService(object): Stub = ydb_coordination_v1_pb2_grpc.CoordinationServiceStub - - Session = "Session" CreateNode = "CreateNode" AlterNode = "AlterNode" DropNode = "DropNode" DescribeNode = "DescribeNode" + SessionRequest = "SessionRequest" + Session = "Session" diff --git a/ydb/_grpc/grpcwrapper/common_utils.py b/ydb/_grpc/grpcwrapper/common_utils.py index 0fb960d6..cf91b9c9 100644 --- a/ydb/_grpc/grpcwrapper/common_utils.py +++ b/ydb/_grpc/grpcwrapper/common_utils.py @@ -220,7 +220,7 @@ async def _start_sync_driver(self, driver: Driver, stub, method): self._stream_call = stream_call self.from_server_grpc = SyncToAsyncIterator(stream_call.__iter__(), self._wait_executor) - async def receive(self, timeout: Optional[int] = None) -> Any: + async def receive(self, timeout: Optional[int] = None, is_coordination_calls=False) -> Any: # todo handle grpc exceptions and convert it to internal exceptions try: if timeout is None: @@ -235,7 +235,8 @@ async def get_response(): except (grpc.RpcError, grpc.aio.AioRpcError) as e: raise connection._rpc_error_handler(self._connection_state, e) - issues._process_response(grpc_message) + if not is_coordination_calls: + issues._process_response(grpc_message) if self._connection_state != "has_received_messages": self._connection_state = "has_received_messages" diff --git a/ydb/_grpc/grpcwrapper/ydb_coordination.py b/ydb/_grpc/grpcwrapper/ydb_coordination.py index 
176e4e02..8794b570 100644 --- a/ydb/_grpc/grpcwrapper/ydb_coordination.py +++ b/ydb/_grpc/grpcwrapper/ydb_coordination.py @@ -16,7 +16,7 @@ class CreateNodeRequest(IToProto): path: str config: typing.Optional[NodeConfig] - def to_proto(self) -> ydb_coordination_pb2.CreateNodeRequest: + def to_proto(self) -> "ydb_coordination_pb2.CreateNodeRequest": cfg_proto = self.config.to_proto() if self.config else None return ydb_coordination_pb2.CreateNodeRequest( path=self.path, @@ -29,7 +29,7 @@ class AlterNodeRequest(IToProto): path: str config: NodeConfig - def to_proto(self) -> ydb_coordination_pb2.AlterNodeRequest: + def to_proto(self) -> "ydb_coordination_pb2.AlterNodeRequest": cfg_proto = self.config.to_proto() if self.config else None return ydb_coordination_pb2.AlterNodeRequest( path=self.path, @@ -41,7 +41,7 @@ def to_proto(self) -> ydb_coordination_pb2.AlterNodeRequest: class DescribeNodeRequest(IToProto): path: str - def to_proto(self) -> ydb_coordination_pb2.DescribeNodeRequest: + def to_proto(self) -> "ydb_coordination_pb2.DescribeNodeRequest": return ydb_coordination_pb2.DescribeNodeRequest( path=self.path, ) @@ -51,7 +51,174 @@ def to_proto(self) -> ydb_coordination_pb2.DescribeNodeRequest: class DropNodeRequest(IToProto): path: str - def to_proto(self) -> ydb_coordination_pb2.DropNodeRequest: + def to_proto(self) -> "ydb_coordination_pb2.DropNodeRequest": return ydb_coordination_pb2.DropNodeRequest( path=self.path, ) + + +@dataclass +class SessionStart(IToProto): + path: str + timeout_millis: int + description: str = "" + session_id: int = 0 + seq_no: int = 0 + protection_key: bytes = b"" + + def to_proto(self) -> "ydb_coordination_pb2.SessionRequest": + return ydb_coordination_pb2.SessionRequest( + session_start=ydb_coordination_pb2.SessionRequest.SessionStart( + path=self.path, + session_id=self.session_id, + timeout_millis=self.timeout_millis, + description=self.description, + seq_no=self.seq_no, + protection_key=self.protection_key, + ) + ) + + 
+@dataclass
+class SessionStop(IToProto):  # request wrapper: serializes to SessionRequest(session_stop=...)
+    def to_proto(self) -> "ydb_coordination_pb2.SessionRequest":
+        return ydb_coordination_pb2.SessionRequest(session_stop=ydb_coordination_pb2.SessionRequest.SessionStop())
+
+
+@dataclass
+class Ping(IToProto):  # request wrapper: serializes to SessionRequest(ping=PingPong(opaque))
+    opaque: int = 0  # value copied verbatim into PingPong.opaque
+
+    def to_proto(self) -> "ydb_coordination_pb2.SessionRequest":
+        return ydb_coordination_pb2.SessionRequest(
+            ping=ydb_coordination_pb2.SessionRequest.PingPong(opaque=self.opaque)
+        )
+
+
+@dataclass
+class CreateSemaphore(IToProto):  # request wrapper for SessionRequest.CreateSemaphore
+    name: str
+    req_id: int  # echoed back in the matching *_result message; used to pair responses with requests
+    limit: int
+    data: bytes = b""
+
+    def to_proto(self) -> "ydb_coordination_pb2.SessionRequest":
+        return ydb_coordination_pb2.SessionRequest(
+            create_semaphore=ydb_coordination_pb2.SessionRequest.CreateSemaphore(
+                req_id=self.req_id, name=self.name, limit=self.limit, data=self.data
+            )
+        )
+
+
+@dataclass
+class AcquireSemaphore(IToProto):  # request wrapper for SessionRequest.AcquireSemaphore
+    name: str
+    req_id: int
+    count: int = 1
+    timeout_millis: int = 0
+    data: bytes = b""
+    ephemeral: bool = False
+
+    def to_proto(self) -> "ydb_coordination_pb2.SessionRequest":
+        return ydb_coordination_pb2.SessionRequest(
+            acquire_semaphore=ydb_coordination_pb2.SessionRequest.AcquireSemaphore(
+                req_id=self.req_id,
+                name=self.name,
+                timeout_millis=self.timeout_millis,
+                count=self.count,
+                data=self.data,
+                ephemeral=self.ephemeral,
+            )
+        )
+
+
+@dataclass
+class ReleaseSemaphore(IToProto):  # request wrapper for SessionRequest.ReleaseSemaphore
+    name: str
+    req_id: int
+
+    def to_proto(self) -> "ydb_coordination_pb2.SessionRequest":
+        return ydb_coordination_pb2.SessionRequest(
+            release_semaphore=ydb_coordination_pb2.SessionRequest.ReleaseSemaphore(req_id=self.req_id, name=self.name)
+        )
+
+
+@dataclass
+class DescribeSemaphore(IToProto):  # request wrapper for SessionRequest.DescribeSemaphore
+    include_owners: bool
+    include_waiters: bool
+    name: str
+    req_id: int
+    watch_data: bool
+    watch_owners: bool
+
+    def to_proto(self) -> "ydb_coordination_pb2.SessionRequest":
+        return ydb_coordination_pb2.SessionRequest(
+            describe_semaphore=ydb_coordination_pb2.SessionRequest.DescribeSemaphore(
+                include_owners=self.include_owners,
+                include_waiters=self.include_waiters,
+                name=self.name,
+                req_id=self.req_id,
+                watch_data=self.watch_data,
+                watch_owners=self.watch_owners,
+            )
+        )
+
+
+@dataclass
+class UpdateSemaphore(IToProto):  # request wrapper for SessionRequest.UpdateSemaphore
+    name: str
+    req_id: int
+    data: bytes
+
+    def to_proto(self) -> "ydb_coordination_pb2.SessionRequest":
+        return ydb_coordination_pb2.SessionRequest(
+            update_semaphore=ydb_coordination_pb2.SessionRequest.UpdateSemaphore(
+                req_id=self.req_id, name=self.name, data=self.data
+            )
+        )
+
+
+@dataclass
+class DeleteSemaphore(IToProto):  # request wrapper for SessionRequest.DeleteSemaphore
+    name: str
+    req_id: int
+    force: bool = False
+
+    def to_proto(self) -> "ydb_coordination_pb2.SessionRequest":
+        return ydb_coordination_pb2.SessionRequest(
+            delete_semaphore=ydb_coordination_pb2.SessionRequest.DeleteSemaphore(
+                req_id=self.req_id, name=self.name, force=self.force
+            )
+        )
+
+
+@dataclass
+class FromServer:  # thin view over a SessionResponse; unknown attributes are delegated to the raw message
+    raw: "ydb_coordination_pb2.SessionResponse"
+
+    @staticmethod
+    def from_proto(resp: "ydb_coordination_pb2.SessionResponse") -> "FromServer":
+        return FromServer(raw=resp)
+
+    def __getattr__(self, name: str):  # only reached when normal attribute lookup fails
+        return getattr(object.__getattribute__(self, "raw"), name)  # bypass __getattr__ so a missing `raw` raises AttributeError instead of recursing forever
+
+    @property
+    def session_started(self) -> typing.Optional["ydb_coordination_pb2.SessionResponse.SessionStarted"]:
+        s = self.raw.session_started
+        return s if s.session_id else None  # None while session_id is unset/zero
+
+    @property
+    def opaque(self) -> typing.Optional[int]:  # ping opaque value, or None when the message is not a ping
+        if self.raw.HasField("ping"):
+            return self.raw.ping.opaque
+        return None
+
+    @property
+    def acquire_semaphore_result(self):  # the sub-message, or None when that field is not set
+        return self.raw.acquire_semaphore_result if self.raw.HasField("acquire_semaphore_result") else None
+
+    @property
+    def create_semaphore_result(self):  # the sub-message, or None when that field is not set
+        return self.raw.create_semaphore_result if self.raw.HasField("create_semaphore_result") else None
diff --git a/ydb/_grpc/grpcwrapper/ydb_coordination_public_types.py b/ydb/_grpc/grpcwrapper/ydb_coordination_public_types.py
index a3580974..1112cd4b 100644
--- a/ydb/_grpc/grpcwrapper/ydb_coordination_public_types.py
+++ 
b/ydb/_grpc/grpcwrapper/ydb_coordination_public_types.py @@ -2,7 +2,6 @@ from enum import IntEnum import typing - if typing.TYPE_CHECKING: from ..v4.protos import ydb_coordination_pb2 else: @@ -55,3 +54,60 @@ def from_proto(msg: ydb_coordination_pb2.DescribeNodeResponse) -> "NodeConfig": result = ydb_coordination_pb2.DescribeNodeResult() msg.operation.result.Unpack(result) return NodeConfig.from_proto(result.config) + + +@dataclass +class AcquireSemaphoreResult: + req_id: int + acquired: bool + status: int + + @staticmethod + def from_proto(msg: ydb_coordination_pb2.SessionResponse.AcquireSemaphoreResult) -> "AcquireSemaphoreResult": + return AcquireSemaphoreResult( + req_id=msg.req_id, + acquired=msg.acquired, + status=msg.status, + ) + + +@dataclass +class CreateSemaphoreResult: + req_id: int + status: int + + @staticmethod + def from_proto(msg: ydb_coordination_pb2.SessionResponse.CreateSemaphoreResult) -> "CreateSemaphoreResult": + return CreateSemaphoreResult( + req_id=msg.req_id, + status=msg.status, + ) + + +@dataclass +class DescribeLockResult: + req_id: int + status: int + watch_added: bool + count: int + data: bytes + ephemeral: bool + limit: int + name: str + owners: list + waiters: list + + @staticmethod + def from_proto(msg: ydb_coordination_pb2.SessionResponse.DescribeSemaphoreResult) -> "DescribeLockResult": + return DescribeLockResult( + req_id=msg.req_id, + status=msg.status, + watch_added=msg.watch_added, + count=msg.semaphore_description.count, + data=msg.semaphore_description.data, + ephemeral=msg.semaphore_description.ephemeral, + limit=msg.semaphore_description.limit, + name=msg.semaphore_description.name, + owners=msg.semaphore_description.owners, + waiters=msg.semaphore_description.waiters, + ) diff --git a/ydb/aio/__init__.py b/ydb/aio/__init__.py index 4e4192a8..9747666f 100644 --- a/ydb/aio/__init__.py +++ b/ydb/aio/__init__.py @@ -1,5 +1,4 @@ from .driver import Driver # noqa from .table import SessionPool, retry_operation # noqa from 
.query import QuerySessionPool, QuerySession, QueryTxContext # noqa - -# from .coordination import CoordinationClient # noqa +from .coordination import CoordinationClient # noqa diff --git a/ydb/aio/coordination/client.py b/ydb/aio/coordination/client.py index b36b8950..83735368 100644 --- a/ydb/aio/coordination/client.py +++ b/ydb/aio/coordination/client.py @@ -8,9 +8,14 @@ ) from ..._grpc.grpcwrapper.ydb_coordination_public_types import NodeConfig from ...coordination.base import BaseCoordinationClient +from .node import CoordinationNode class CoordinationClient(BaseCoordinationClient): + def __init__(self, driver): + super().__init__(driver) + self._driver = driver + async def create_node(self, path: str, config: Optional[NodeConfig] = None, settings=None): return await self._call_create( CreateNodeRequest(path=path, config=config).to_proto(), @@ -35,5 +40,5 @@ async def delete_node(self, path: str, settings=None): settings=settings, ) - async def lock(self): - raise NotImplementedError("Will be implemented in future release") + def node(self, path: str) -> CoordinationNode: + return CoordinationNode(self._driver, path) diff --git a/ydb/aio/coordination/lock.py b/ydb/aio/coordination/lock.py new file mode 100644 index 00000000..e9268630 --- /dev/null +++ b/ydb/aio/coordination/lock.py @@ -0,0 +1,101 @@ +from ... 
import StatusCode, issues
+
+from ..._grpc.grpcwrapper.ydb_coordination import (
+    AcquireSemaphore,
+    ReleaseSemaphore,
+    UpdateSemaphore,
+    DescribeSemaphore,
+    CreateSemaphore,
+)
+from ..._grpc.grpcwrapper.ydb_coordination_public_types import (
+    DescribeLockResult,
+)
+
+
+class CoordinationLock:  # async lock built on a semaphore named `name`, driven through the node's reconnector
+    def __init__(self, node, name: str):
+        self._node = node
+        self._name = name
+
+        self._count: int = 1  # used both as the acquire count and as the limit when the semaphore is created
+        self._timeout_millis: int = node._timeout_millis
+
+    async def acquire(self):  # acquire, creating the semaphore on NOT_FOUND; raises issues.Error on failure
+        resp = await self._try_acquire()
+
+        if resp.status == StatusCode.NOT_FOUND:
+            await self._create_if_not_exists()
+            resp = await self._try_acquire()
+
+        if resp.status != StatusCode.SUCCESS or not resp.acquired:  # `acquired` is False when the wait timed out, so SUCCESS alone does not mean the lock is held
+            raise issues.Error(f"Failed to acquire lock {self._name}: {resp.status}")
+
+        return self
+
+    async def release(self):  # best-effort release: every Exception from the round trip is swallowed
+        req = ReleaseSemaphore(
+            req_id=await self._node.next_req_id(),
+            name=self._name,
+        )
+        try:
+            await self._node._reconnector.send_and_wait(req)
+        except Exception:
+            pass
+
+    async def describe(self) -> DescribeLockResult:  # status is returned inside the result, not raised
+        req = DescribeSemaphore(
+            req_id=await self._node.next_req_id(),
+            name=self._name,
+            include_owners=True,
+            include_waiters=True,
+            watch_data=False,
+            watch_owners=False,
+        )
+        resp = await self._node._reconnector.send_and_wait(req)
+        return DescribeLockResult.from_proto(resp)
+
+    async def update(self, new_data: bytes) -> None:  # replace the semaphore data; raises issues.Error unless SUCCESS
+        req = UpdateSemaphore(
+            req_id=await self._node.next_req_id(),
+            name=self._name,
+            data=new_data,
+        )
+        resp = await self._node._reconnector.send_and_wait(req)
+
+        if resp.status != StatusCode.SUCCESS:
+            raise issues.Error(f"Failed to update lock {self._name}: {resp.status}")
+
+    async def close(self):  # alias for release()
+        await self.release()
+
+    async def __aenter__(self):
+        await self.acquire()
+        return self
+
+    async def __aexit__(self, exc_type, exc, tb):
+        await self.release()
+
+    async def _try_acquire(self):  # one acquire round trip; returns the raw result payload
+        req = AcquireSemaphore(
+            req_id=await self._node.next_req_id(),
+            name=self._name,
+            count=self._count,
+            ephemeral=False,
+            timeout_millis=self._timeout_millis,
+        )
+        return await self._node._reconnector.send_and_wait(req)
+
+    async def _create_if_not_exists(self):  # ALREADY_EXISTS is accepted so a concurrent creator is not an error
+        req = CreateSemaphore(
+            req_id=await self._node.next_req_id(),
+            name=self._name,
+            limit=self._count,
+            data=b"",
+        )
+        resp = await self._node._reconnector.send_and_wait(req)
+
+        if resp.status not in (
+            StatusCode.SUCCESS,
+            StatusCode.ALREADY_EXISTS,
+        ):
+            raise issues.Error(f"Failed to create lock {self._name}: {resp.status}")
diff --git a/ydb/aio/coordination/node.py b/ydb/aio/coordination/node.py
new file mode 100644
index 00000000..2c498955
--- /dev/null
+++ b/ydb/aio/coordination/node.py
@@ -0,0 +1,43 @@
+import asyncio
+
+from .reconnector import CoordinationReconnector
+from .lock import CoordinationLock
+
+
+class CoordinationNode:  # owns one reconnector for `path` and hands out locks plus request ids
+    def __init__(self, driver, path: str, timeout_millis: int = 30000):
+        self._driver = driver
+        self._path = path
+        self._timeout_millis = timeout_millis
+
+        self._reconnector = CoordinationReconnector(
+            driver=driver,
+            node_path=path,
+            timeout_millis=timeout_millis,
+        )
+
+        self._req_id = 0
+        self._req_id_lock = asyncio.Lock()  # NOTE(review): created outside any running loop when built by the sync wrapper — confirm this is safe on the oldest supported Python
+        self._closed = False
+
+    async def next_req_id(self) -> int:  # monotonically increasing id, first value is 1
+        async with self._req_id_lock:
+            self._req_id += 1
+            return self._req_id
+
+    def lock(self, name: str) -> CoordinationLock:  # raises RuntimeError once the node is closed
+        if self._closed:
+            raise RuntimeError("CoordinationNode is closed")
+        return CoordinationLock(self, name)
+
+    async def close(self):  # idempotent; stops the reconnector on the first call only
+        if self._closed:
+            return
+        self._closed = True
+        await self._reconnector.stop()
+
+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self, exc_type, exc, tb):
+        await self.close()
diff --git a/ydb/aio/coordination/reconnector.py b/ydb/aio/coordination/reconnector.py
new file mode 100644
index 00000000..eefc9c01
--- /dev/null
+++ b/ydb/aio/coordination/reconnector.py
@@ -0,0 +1,153 @@
+import asyncio
+import contextlib
+from typing import Optional, Dict
+
+from .stream import CoordinationStream
+from 
..._grpc.grpcwrapper.ydb_coordination import FromServer
+from ... import issues
+
+
+class CoordinationReconnector:  # keeps one CoordinationStream alive and matches responses to requests by req_id
+    def __init__(self, driver, node_path: str, timeout_millis: int = 30000):
+        self._driver = driver
+        self._node_path = node_path
+        self._timeout_millis = timeout_millis
+
+        self._stream: Optional[CoordinationStream] = None
+        self._dispatch_task: Optional[asyncio.Task] = None
+
+        self._wait_timeout = timeout_millis / 1000.0  # seconds; used for response waits and dispatch-loop receives
+        self._pending_futures: Dict[int, asyncio.Future] = {}  # req_id -> future resolved by _dispatch_loop
+        self._ensure_lock = asyncio.Lock()  # serializes (re)connection in _ensure_stream
+
+    async def stop(self):  # cancel the dispatch task, close the stream, fail every pending request
+        if self._dispatch_task:
+            self._dispatch_task.cancel()
+            with contextlib.suppress(asyncio.CancelledError):
+                await self._dispatch_task
+            self._dispatch_task = None
+
+        if self._stream:
+            with contextlib.suppress(Exception):
+                await self._stream.close()
+            self._stream = None
+
+        if self._pending_futures:
+            err = asyncio.CancelledError()  # NOTE(review): waiters see this as a task cancellation — confirm that is intended
+            for fut in self._pending_futures.values():
+                if not fut.done():
+                    fut.set_exception(err)
+            self._pending_futures.clear()
+
+    async def send_and_wait(self, req):  # send `req`, return the payload whose req_id matches; raises on send failure or timeout
+        await self._ensure_stream()
+
+        loop = asyncio.get_running_loop()
+        fut = loop.create_future()
+        self._pending_futures[req.req_id] = fut
+
+        try:
+            await self._stream.send(req)
+        except Exception:
+            self._pending_futures.pop(req.req_id, None)  # dict.pop is not awaitable; awaiting its result hung on the still-pending future
+            if not fut.done():
+                fut.cancel()  # nobody awaits this future (the error is re-raised below), so cancel it rather than store an unretrieved exception
+            if self._stream:
+                with contextlib.suppress(Exception):
+                    await self._stream.close()
+                self._stream = None
+            raise
+
+        try:
+            return await asyncio.wait_for(
+                asyncio.shield(fut),
+                timeout=self._wait_timeout,
+            )
+        except BaseException:  # includes CancelledError, so a cancelled caller does not leave its entry behind
+            self._pending_futures.pop(req.req_id, None)  # plain pop: the result is a Future or None, never something to await
+            raise
+
+    async def _ensure_stream(self):  # no-op while the stream is open; otherwise tear down and start a new session
+        async with self._ensure_lock:
+            if self._stream is not None and not self._stream._closed:
+                return
+
+            if self._dispatch_task:
+                self._dispatch_task.cancel()
+                with contextlib.suppress(asyncio.CancelledError):
+                    await self._dispatch_task
+                self._dispatch_task = None
+
+            if self._stream:
+                with contextlib.suppress(Exception):
+                    await self._stream.close()
+                self._stream = None
+
+            if self._pending_futures:
+                err = issues.Error("Connection lost")  # requests sent on the old stream cannot be answered on the new one
+                for fut in self._pending_futures.values():
+                    if not fut.done():
+                        fut.set_exception(err)
+                self._pending_futures.clear()
+
+            stream = CoordinationStream(self._driver)
+            await stream.start_session(self._node_path, self._timeout_millis)
+
+            self._stream = stream
+
+            loop = asyncio.get_running_loop()
+            self._dispatch_task = loop.create_task(self._dispatch_loop(stream))
+
+    async def _dispatch_loop(self, stream: CoordinationStream):  # route each *_result message to the future registered for its req_id
+        try:
+            while True:
+                try:
+                    resp = await stream.receive(self._wait_timeout)
+                except asyncio.TimeoutError:
+                    continue
+                except Exception as exc:
+                    await self._on_stream_error(stream, exc)
+                    break
+
+                if resp is None:
+                    continue
+
+                fs = FromServer.from_proto(resp)
+                raw = fs.raw
+
+                payload = None
+                for field_name in (
+                    "acquire_semaphore_result",
+                    "release_semaphore_result",
+                    "describe_semaphore_result",
+                    "create_semaphore_result",
+                    "update_semaphore_result",
+                    "delete_semaphore_result",
+                ):
+                    if raw.HasField(field_name):
+                        payload = getattr(fs, field_name)
+                        break
+
+                if payload is None:  # message carries none of the result fields listed above
+                    continue
+
+                fut = self._pending_futures.pop(payload.req_id, None)
+                if fut and not fut.done():
+                    fut.set_result(payload)
+        finally:
+            if self._stream is stream:  # only tear down the stream this loop was started for
+                with contextlib.suppress(Exception):
+                    await stream.close()
+                self._stream = None
+
+    async def _on_stream_error(self, stream: CoordinationStream, exc: Exception):  # fail all pending requests with `exc` and drop the stream
+        if self._pending_futures:
+            for fut in self._pending_futures.values():
+                if not fut.done():
+                    fut.set_exception(exc)
+            self._pending_futures.clear()
+
+        if self._stream is stream:
+            with contextlib.suppress(Exception):
+                await stream.close()
+            self._stream = None
diff --git a/ydb/aio/coordination/stream.py b/ydb/aio/coordination/stream.py
new file mode 100644
index 00000000..a29897a6
--- /dev/null
+++ b/ydb/aio/coordination/stream.py
@@ -0,0 +1,153 @@
+import asyncio
+import contextlib
+
+from ... import issues, _apis
+from ..._grpc.grpcwrapper.common_utils import IToProto, GrpcWrapperAsyncIO
+from ..._grpc.grpcwrapper.ydb_coordination import FromServer, Ping, SessionStart
+
+
+class CoordinationStream:  # one coordination session over a gRPC stream, with a background reader feeding a queue
+    def __init__(self, driver):
+        self._driver = driver
+        self._stream = GrpcWrapperAsyncIO(FromServer.from_proto)
+        self._background_tasks = set()
+        self._incoming_queue: asyncio.Queue = asyncio.Queue()  # non-ping responses; None is queued as a wake-up on close
+        self._closed = False
+        self._started = False
+        self.session_id = None  # set by start_session(), reset to None by close()
+
+    async def start_session(self, path: str, timeout_millis: int):  # open the stream, send SessionStart, wait for session_started, then start the reader task
+        if self._started:
+            raise issues.Error("CoordinationStream already started")
+
+        self._started = True
+
+        await self._stream.start(
+            self._driver,
+            _apis.CoordinationService.Stub,
+            _apis.CoordinationService.Session,
+        )
+
+        self._stream.write(SessionStart(path=path, timeout_millis=timeout_millis))
+
+        try:
+            while True:
+                try:
+                    resp = await self._stream.receive(timeout=3, is_coordination_calls=True)
+                except asyncio.TimeoutError:
+                    raise issues.Error("Timeout waiting for SessionStart response")
+                except StopAsyncIteration:
+                    raise issues.Error("Stream closed while waiting for SessionStart response")
+                except asyncio.CancelledError:
+                    raise
+                except Exception as exc:
+                    raise issues.Error(f"Failed to start session: {exc}") from exc
+
+                if resp is None:
+                    continue
+
+                if getattr(resp, "session_started", None):
+                    self.session_id = resp.session_started.session_id  # keep the numeric id, not the whole SessionStarted message
+                    break
+
+        except Exception:
+            with contextlib.suppress(Exception):
+                await self._stream.close()
+            self._stream = None  # NOTE(review): with _started reset below, a retry on this object would call start() on None — callers appear to build a new stream instead; confirm
+            self._started = False
+            raise
+
+        loop = asyncio.get_running_loop()
+        task = loop.create_task(self._reader_loop())
+        self._background_tasks.add(task)
+
+        def _on_done(t: asyncio.Task) -> None:  # drop the finished task and retrieve its exception so it is not reported as unhandled
+            self._background_tasks.discard(t)
+            with contextlib.suppress(asyncio.CancelledError, Exception):
+                _ = t.exception()
+
+        task.add_done_callback(_on_done)
+
+    async def _reader_loop(self):  # answer pings inline, queue everything else; any non-timeout error ends the loop
+        try:
+            while True:
+                try:
+                    resp = await self._stream.receive(timeout=3, is_coordination_calls=True)
+                except asyncio.TimeoutError:
+                    continue
+                except asyncio.CancelledError:
+                    break
+                except StopAsyncIteration:
+                    break
+                except Exception:
+                    break
+
+                if self._closed:
+                    break
+
+                if resp is None:
+                    continue
+
+                fs = FromServer.from_proto(resp)
+                if fs.opaque is not None:  # opaque is None only for non-ping messages; a ping carrying opaque == 0 must still be answered
+                    try:
+                        self._stream.write(Ping(fs.opaque))  # NOTE(review): answers a server ping with a `ping` request — confirm the protocol does not expect `pong`
+                    except Exception:
+                        break
+                else:
+                    await self._incoming_queue.put(resp)
+        finally:
+            if not self._closed:
+                self._closed = True
+                with contextlib.suppress(asyncio.QueueFull):
+                    self._incoming_queue.put_nowait(None)  # wake a consumer blocked in receive()
+
+            if self._stream is not None:
+                with contextlib.suppress(Exception):
+                    await self._stream.close()
+                self._stream = None
+
+    async def send(self, req: IToProto):  # raises issues.Error once the stream is closed
+        if self._closed:
+            raise issues.Error("Stream closed")
+        self._stream.write(req)
+
+    async def receive(self, timeout=None):  # next queued response; None on timeout; raises issues.Error once closed
+        if self._closed:
+            raise issues.Error("Stream closed")
+
+        try:
+            if timeout is not None:
+                return await asyncio.wait_for(self._incoming_queue.get(), timeout)
+            return await self._incoming_queue.get()
+        except asyncio.TimeoutError:
+            return None
+
+    async def close(self):  # idempotent: close the gRPC stream, cancel reader tasks, drain the queue
+        if self._closed:
+            return
+
+        self._closed = True
+
+        if self._stream is not None:
+            with contextlib.suppress(Exception):
+                await self._stream.close()
+            self._stream = None
+
+        with contextlib.suppress(asyncio.QueueFull):
+            self._incoming_queue.put_nowait(None)
+
+        if self._background_tasks:
+            for task in list(self._background_tasks):  # copy: the done-callback mutates the set
+                task.cancel()
+
+            with contextlib.suppress(asyncio.CancelledError):
+                await asyncio.wait(self._background_tasks)
+
+            self._background_tasks.clear()
+
+        while not self._incoming_queue.empty():
+            with contextlib.suppress(asyncio.QueueEmpty):
+                self._incoming_queue.get_nowait()
+
+        self.session_id = None
diff --git a/ydb/coordination/__init__.py b/ydb/coordination/__init__.py
index fd994c56..b50bfa61 100644
--- a/ydb/coordination/__init__.py
+++ b/ydb/coordination/__init__.py
@@ -4,13 +4,18 @@
"ConsistencyMode", "RateLimiterCountersMode", "DescribeResult", + "CreateSemaphoreResult", + "DescribeLockResult", ] from .client import CoordinationClient + from .._grpc.grpcwrapper.ydb_coordination_public_types import ( NodeConfig, ConsistencyMode, RateLimiterCountersMode, DescribeResult, + CreateSemaphoreResult, + DescribeLockResult, ) diff --git a/ydb/coordination/client.py b/ydb/coordination/client.py index 549528d9..beeb5f66 100644 --- a/ydb/coordination/client.py +++ b/ydb/coordination/client.py @@ -8,6 +8,7 @@ ) from .._grpc.grpcwrapper.ydb_coordination_public_types import NodeConfig from .base import BaseCoordinationClient +from .node_sync import CoordinationNodeSync class CoordinationClient(BaseCoordinationClient): @@ -35,5 +36,5 @@ def delete_node(self, path: str, settings=None): settings=settings, ) - def lock(self): - raise NotImplementedError("Will be implemented in future release") + def node(self, path: str) -> CoordinationNodeSync: + return CoordinationNodeSync(self, path) diff --git a/ydb/coordination/lock_sync.py b/ydb/coordination/lock_sync.py new file mode 100644 index 00000000..6b9ee2e8 --- /dev/null +++ b/ydb/coordination/lock_sync.py @@ -0,0 +1,74 @@ +from typing import Optional + +from .. 
import issues
+from .._topic_common.common import _get_shared_event_loop, CallFromSyncToAsync
+from ..aio.coordination.lock import CoordinationLock
+
+
+class CoordinationLockSync:  # blocking facade over the async CoordinationLock, run on the shared event loop
+    def __init__(self, node_sync, name: str, timeout_sec: float = 5):
+        self._node_sync = node_sync
+        self._name = name
+        self._timeout_sec = timeout_sec  # default wait for every call that is not given an explicit timeout
+        self._closed = False
+        self._caller = CallFromSyncToAsync(_get_shared_event_loop())
+        self._async_lock: CoordinationLock = self._node_sync._async_node.lock(name)
+
+    def _check_closed(self):  # raises issues.Error after close()
+        if self._closed:
+            raise issues.Error(f"CoordinationLockSync {self._name} already closed")
+
+    def __enter__(self):
+        self.acquire()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):  # release is best-effort here: errors are swallowed
+        try:
+            self.release()
+        except Exception:
+            pass
+
+    def acquire(self, timeout: Optional[float] = None):  # NOTE(review): if this wait expires first, the async acquire may still be in flight — confirm it cannot leave the lock held
+        self._check_closed()
+        t = self._timeout_sec if timeout is None else timeout  # `timeout or default` would override an explicit 0
+        return self._caller.safe_call_with_result(
+            self._async_lock.acquire(),
+            t,
+        )
+
+    def release(self, timeout: Optional[float] = None):  # no-op once closed
+        if self._closed:
+            return
+        t = self._timeout_sec if timeout is None else timeout
+        return self._caller.safe_call_with_result(
+            self._async_lock.release(),
+            t,
+        )
+
+    def describe(self, timeout: Optional[float] = None):  # returns what the async describe() returns
+        self._check_closed()
+        t = self._timeout_sec if timeout is None else timeout
+        return self._caller.safe_call_with_result(
+            self._async_lock.describe(),
+            t,
+        )
+
+    def update(self, new_data: bytes, timeout: Optional[float] = None):
+        self._check_closed()
+        t = self._timeout_sec if timeout is None else timeout
+        return self._caller.safe_call_with_result(
+            self._async_lock.update(new_data),
+            t,
+        )
+
+    def close(self, timeout: Optional[float] = None):  # release once, then mark closed even if the release raised
+        if self._closed:
+            return
+        t = self._timeout_sec if timeout is None else timeout
+        try:
+            self._caller.safe_call_with_result(
+                self._async_lock.release(),
+                t,
+            )
+        finally:
+            self._closed = True
diff --git a/ydb/coordination/node_sync.py b/ydb/coordination/node_sync.py
new file mode 100644
index 00000000..1f6f5d19
--- /dev/null
+++ 
b/ydb/coordination/node_sync.py
@@ -0,0 +1,37 @@
+from .._topic_common.common import _get_shared_event_loop, CallFromSyncToAsync
+from ..aio.coordination.node import CoordinationNode
+from .lock_sync import CoordinationLockSync
+
+
+class CoordinationNodeSync:  # blocking facade over the async CoordinationNode, run on the shared event loop
+    def __init__(self, client, path: str, timeout_sec: float = 5):
+        self._client = client
+        self._path = path
+        self._timeout_sec = timeout_sec  # wait used by close() and passed on to locks created here
+
+        self._caller = CallFromSyncToAsync(_get_shared_event_loop())
+
+        self._async_node: CoordinationNode = CoordinationNode(
+            client._driver,  # NOTE(review): assumes the client exposes `_driver` — confirm against BaseCoordinationClient
+            path,
+        )
+
+        self._closed = False
+
+    def lock(self, name: str):  # new sync lock bound to this node
+        return CoordinationLockSync(self, name, self._timeout_sec)  # pass the node's timeout on instead of the lock's hard-coded default
+
+    def close(self):  # idempotent; marked closed only after the async close has finished
+        if self._closed:
+            return
+        self._caller.safe_call_with_result(
+            self._async_node.close(),
+            self._timeout_sec,
+        )
+        self._closed = True
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc, tb):
+        self.close()