diff --git a/google/cloud/spanner_v1/_async/database_sessions_manager.py b/google/cloud/spanner_v1/_async/database_sessions_manager.py index 344cbd7d7c..e3f54dbebd 100644 --- a/google/cloud/spanner_v1/_async/database_sessions_manager.py +++ b/google/cloud/spanner_v1/_async/database_sessions_manager.py @@ -71,9 +71,9 @@ def __init__(self, database, pool): self._pool = pool self._multiplexed_session: Optional[Session] = None self._multiplexed_session_thread: Optional[CrossSync.Task] = None - # Use threading.Lock because this is accessed in a synchronous maintenance thread - self._multiplexed_session_lock: threading.Lock = threading.Lock() - self._multiplexed_session_terminate_event: CrossSync.Event = CrossSync.Event() + self._init_lock = threading.Lock() + self._multiplexed_session_lock: Optional[CrossSync.Lock] = None + self._multiplexed_session_terminate_event: Optional[CrossSync.Event] = None @CrossSync.convert async def get_session(self, transaction_type: TransactionType) -> Session: @@ -119,7 +119,13 @@ async def _get_multiplexed_session(self) -> Session: :rtype: :class:`~google.cloud.spanner_v1.session.Session` :returns: a multiplexed session.""" - with CrossSync.rm_aio(self._multiplexed_session_lock): + with self._init_lock: + if self._multiplexed_session_lock is None: + self._multiplexed_session_lock = CrossSync.Lock() + if self._multiplexed_session_terminate_event is None: + self._multiplexed_session_terminate_event = CrossSync.Event() + + async with self._multiplexed_session_lock: if self._multiplexed_session is None: self._multiplexed_session = await self._build_multiplexed_session() self._multiplexed_session_thread = self._build_maintenance_thread() @@ -193,7 +199,7 @@ async def _maintain_multiplexed_session(session_manager_ref) -> None: if time() - session_created_time < refresh_interval_seconds: await CrossSync.sleep(polling_interval_seconds) continue - with manager._multiplexed_session_lock: + async with manager._multiplexed_session_lock: await CrossSync.run_if_async(manager._multiplexed_session.delete) manager._multiplexed_session = ( await manager._build_multiplexed_session() @@ -220,7 +226,8 @@ def _getenv(cls, env_var_name: str) -> bool: @CrossSync.convert async def close(self) -> None: """Closes the database session manager and stops all background tasks.""" - self._multiplexed_session_terminate_event.set() + if self._multiplexed_session_terminate_event is not None: + self._multiplexed_session_terminate_event.set() if self._multiplexed_session_thread is not None: if CrossSync.is_async: self._multiplexed_session_thread.cancel() diff --git a/google/cloud/spanner_v1/batch.py b/google/cloud/spanner_v1/batch.py index b783495766..ebec36bc20 100644 --- a/google/cloud/spanner_v1/batch.py +++ b/google/cloud/spanner_v1/batch.py @@ -243,7 +243,7 @@ def wrapped_method(): max_commit_delay=max_commit_delay, request_options=request_options, ) - (call_metadata, error_augmenter) = database.with_error_augmentation( + call_metadata, error_augmenter = database.with_error_augmentation( getattr(database, "_next_nth_request", 0), 1, metadata, span ) commit_method = functools.partial( diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 4098ae2078..c4be870609 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -82,7 +82,6 @@ trace_call, ) from google.cloud.spanner_v1.metrics.metrics_capture import MetricsCapture - from google.cloud.spanner_v1.table import Table SPANNER_DATA_SCOPE = "https://www.googleapis.com/auth/spanner.data" @@ -211,11 +210,9 @@ def __init__( def _resource_info(self): """Resource information for metrics labels.""" return { - "project": ( - self._instance._client.project - if self._instance and self._instance._client - else None - ), + "project": self._instance._client.project + if self._instance and self._instance._client + else None, "instance": self._instance.instance_id if self._instance else None, "database": self.database_id, } @@ -533,7 +530,7 @@ def with_error_augmentation( tuple: (metadata_list, context_manager)""" if span is None: span = get_current_span() - (metadata, request_id) = _metadata_with_request_id_and_req_id( + metadata, request_id = _metadata_with_request_id_and_req_id( self._nth_client_id, self._channel_id, nth_request, @@ -810,7 +807,7 @@ def execute_pdml(): session = self._sessions_manager.get_session(transaction_type) try: add_span_event(span, "Starting BeginTransaction") - (call_metadata, error_augmenter) = self.with_error_augmentation( + call_metadata, error_augmenter = self.with_error_augmentation( self._next_nth_request, 1, metadata, span ) with error_augmenter: diff --git a/google/cloud/spanner_v1/database_sessions_manager.py b/google/cloud/spanner_v1/database_sessions_manager.py index 27b0c116de..a492c497a2 100644 --- a/google/cloud/spanner_v1/database_sessions_manager.py +++ b/google/cloud/spanner_v1/database_sessions_manager.py @@ -69,10 +69,11 @@ def __init__(self, database, pool): self._pool = pool self._multiplexed_session: Optional[Session] = None self._multiplexed_session_thread: Optional[CrossSync._Sync_Impl.Task] = None - self._multiplexed_session_lock: threading.Lock = threading.Lock() - self._multiplexed_session_terminate_event: CrossSync._Sync_Impl.Event = ( - CrossSync._Sync_Impl.Event() - ) + self._init_lock = threading.Lock() + self._multiplexed_session_lock: Optional[CrossSync._Sync_Impl.Lock] = None + self._multiplexed_session_terminate_event: Optional[ + CrossSync._Sync_Impl.Event + ] = None def get_session(self, transaction_type: TransactionType) -> Session: """Returns a session for the given transaction type from the database session manager. @@ -115,6 +116,11 @@ def _get_multiplexed_session(self) -> Session: :rtype: :class:`~google.cloud.spanner_v1.session.Session` :returns: a multiplexed session.""" + with self._init_lock: + if self._multiplexed_session_lock is None: + self._multiplexed_session_lock = CrossSync._Sync_Impl.Lock() + if self._multiplexed_session_terminate_event is None: + self._multiplexed_session_terminate_event = CrossSync._Sync_Impl.Event() with self._multiplexed_session_lock: if self._multiplexed_session is None: self._multiplexed_session = self._build_multiplexed_session() @@ -205,7 +211,8 @@ def _getenv(cls, env_var_name: str) -> bool: def close(self) -> None: """Closes the database session manager and stops all background tasks.""" - self._multiplexed_session_terminate_event.set() + if self._multiplexed_session_terminate_event is not None: + self._multiplexed_session_terminate_event.set() if self._multiplexed_session_thread is not None: self._multiplexed_session_thread.join() if self._multiplexed_session is not None: diff --git a/google/cloud/spanner_v1/instance.py b/google/cloud/spanner_v1/instance.py index 8a86fd3503..fca1303d75 100644 --- a/google/cloud/spanner_v1/instance.py +++ b/google/cloud/spanner_v1/instance.py @@ -479,7 +479,9 @@ def database( database_role=database_role, enable_drop_protection=enable_drop_protection, ) - db._pool.bind(db) + res = db._pool.bind(db) + if res is not None: + res return db def list_databases(self, page_size=None): diff --git a/google/cloud/spanner_v1/pool.py b/google/cloud/spanner_v1/pool.py index 543db56e76..4d17911da3 100644 --- a/google/cloud/spanner_v1/pool.py +++ b/google/cloud/spanner_v1/pool.py @@ -304,7 +304,7 @@ def _fill_pool(self): f"Creating {request.session_count} sessions", span_event_attributes, ) - (call_metadata, error_augmenter) = database.with_error_augmentation( + call_metadata, error_augmenter = database.with_error_augmentation( database._next_nth_request, 1, metadata, span ) with error_augmenter: @@ -612,7 +612,7 @@ def bind(self, database): ) as span, MetricsCapture(self._resource_info): returned_session_count = 0 while returned_session_count < self.size: - (call_metadata, error_augmenter) = database.with_error_augmentation( + call_metadata, error_augmenter = database.with_error_augmentation( database._next_nth_request, 1, metadata, span ) with error_augmenter: @@ -654,7 +654,7 @@ def get(self, timeout=None): ping_after = None session = None try: - (ping_after, session) = CrossSync._Sync_Impl.queue_get( + ping_after, session = CrossSync._Sync_Impl.queue_get( self._sessions, block=True, timeout=timeout ) except CrossSync._Sync_Impl.QueueEmpty as e: @@ -698,9 +698,7 @@ def clear(self): """Delete all sessions in the pool.""" while True: try: - (_, session) = CrossSync._Sync_Impl.queue_get( - self._sessions, block=False - ) + _, session = CrossSync._Sync_Impl.queue_get(self._sessions, block=False) except CrossSync._Sync_Impl.QueueEmpty: break else: @@ -713,7 +711,7 @@ def ping(self): or during the "idle" phase of an event loop.""" while True: try: - (ping_after, session) = CrossSync._Sync_Impl.queue_get( + ping_after, session = CrossSync._Sync_Impl.queue_get( self._sessions, block=False ) except CrossSync._Sync_Impl.QueueEmpty: diff --git a/google/cloud/spanner_v1/session.py b/google/cloud/spanner_v1/session.py index 336c659669..82cb71fb9e 100644 --- a/google/cloud/spanner_v1/session.py +++ b/google/cloud/spanner_v1/session.py @@ -188,7 +188,7 @@ def create(self): observability_options=observability_options, metadata=metadata, ) as span, MetricsCapture(self._resource_info): - (call_metadata, error_augmenter) = database.with_error_augmentation( + call_metadata, error_augmenter = database.with_error_augmentation( nth_request, 1, metadata, span ) with error_augmenter: @@ -232,7 +232,7 @@ def exists(self): observability_options=observability_options, metadata=metadata, ) as span, MetricsCapture(self._resource_info): - (call_metadata, error_augmenter) = database.with_error_augmentation( + call_metadata, error_augmenter = database.with_error_augmentation( nth_request, 1, metadata, span ) with error_augmenter: @@ -283,7 +283,7 @@ def delete(self): observability_options=observability_options, metadata=metadata, ) as span, MetricsCapture(self._resource_info): - (call_metadata, error_augmenter) = database.with_error_augmentation( + call_metadata, error_augmenter = database.with_error_augmentation( nth_request, 1, metadata, span ) with error_augmenter: @@ -300,7 +300,7 @@ def ping(self): metadata = _metadata_with_prefix(database.name) nth_request = database._next_nth_request with trace_call("CloudSpanner.Session.ping", self) as span: - (call_metadata, error_augmenter) = database.with_error_augmentation( + call_metadata, error_augmenter = database.with_error_augmentation( nth_request, 1, metadata, span ) with error_augmenter: diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index 533fb0aaa3..d7d7451775 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -322,7 +322,7 @@ def execute_sql( raise ValueError("Transaction has not begun.") if params is not None: params_pb = Struct( - fields={key: _make_value_pb(value) for (key, value) in params.items()} + fields={key: _make_value_pb(value) for key, value in params.items()} ) else: params_pb = {} @@ -513,7 +513,7 @@ def partition_query( raise ValueError("Cannot partition a single-use transaction.") if params is not None: params_pb = Struct( - fields={key: _make_value_pb(value) for (key, value) in params.items()} + fields={key: _make_value_pb(value) for key, value in params.items()} ) else: params_pb = Struct() @@ -614,7 +614,7 @@ def wrapped_method(): begin_transaction_request = BeginTransactionRequest( **begin_request_kwargs ) - (call_metadata, error_augmenter) = database.with_error_augmentation( + call_metadata, error_augmenter = database.with_error_augmentation( nth_request, attempt.increment(), metadata, span ) begin_transaction_method = functools.partial( diff --git a/google/cloud/spanner_v1/streamed.py b/google/cloud/spanner_v1/streamed.py index 21bdc889ac..48630ef574 100644 --- a/google/cloud/spanner_v1/streamed.py +++ b/google/cloud/spanner_v1/streamed.py @@ -147,7 +147,7 @@ def _consume_next(self): def __iter__(self): while True: - (iter_rows, self._rows[:]) = (self._rows[:], ()) + iter_rows, self._rows[:] = (self._rows[:], ()) while iter_rows: yield iter_rows.pop(0) if self._done: @@ -230,7 +230,7 @@ def to_dict_list(self): rows.append( { column: value - for (column, value) in zip( + for column, value in zip( [column.name for column in self._metadata.row_type.fields], row ) } @@ -291,7 +291,7 @@ def _merge_array(lhs, rhs, type_): if element_type.code in _UNMERGEABLE_TYPES: lhs.list_value.values.extend(rhs.list_value.values) return lhs - (lhs, rhs) = (list(lhs.list_value.values), list(rhs.list_value.values)) + lhs, rhs = (list(lhs.list_value.values), list(rhs.list_value.values)) if not len(lhs) or not len(rhs): return Value(list_value=ListValue(values=lhs + rhs)) first = rhs.pop(0) @@ -316,7 +316,7 @@ def _merge_array(lhs, rhs, type_): def _merge_struct(lhs, rhs, type_): """Helper for '_merge_by_type'.""" fields = type_.struct_type.fields - (lhs, rhs) = (list(lhs.list_value.values), list(rhs.list_value.values)) + lhs, rhs = (list(lhs.list_value.values), list(rhs.list_value.values)) if not len(lhs) or not len(rhs): return Value(list_value=ListValue(values=lhs + rhs)) candidate_type = fields[len(lhs) - 1].type_ diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py index 9e14387ea8..254604dfa6 100644 --- a/google/cloud/spanner_v1/transaction.py +++ b/google/cloud/spanner_v1/transaction.py @@ -162,7 +162,7 @@ def rollback(self) -> None: def wrapped_method(*args, **kwargs): attempt.increment() - (call_metadata, error_augmenter) = database.with_error_augmentation( + call_metadata, error_augmenter = database.with_error_augmentation( nth_request, attempt.value, metadata, span ) rollback_method = functools.partial( @@ -269,7 +269,7 @@ def wrapped_method(*args, **kwargs): is_multiplexed = getattr(self._session, "is_multiplexed", False) if is_multiplexed and self._precommit_token is not None: commit_request_args["precommit_token"] = self._precommit_token - (call_metadata, error_augmenter) = database.with_error_augmentation( + call_metadata, error_augmenter = database.with_error_augmentation( nth_request, attempt.value, metadata, span ) commit_method = functools.partial( @@ -300,7 +300,7 @@ def before_next_retry(nth_retry, delay_in_seconds): if commit_response_pb._pb.HasField("precommit_token"): add_span_event(span, commit_retry_event_name) nth_request = database._next_nth_request - (call_metadata, error_augmenter) = database.with_error_augmentation( + call_metadata, error_augmenter = database.with_error_augmentation( nth_request, 1, metadata, span ) with error_augmenter: @@ -338,7 +338,7 @@ def _make_params_pb(params, param_types): If ``params`` is None but ``param_types`` is not None.""" if params: return Struct( - fields={key: _make_value_pb(value) for (key, value) in params.items()} + fields={key: _make_value_pb(value) for key, value in params.items()} ) return {} @@ -417,7 +417,7 @@ def execute_update( metadata.append( _metadata_with_leader_aware_routing(database._route_to_leader_enabled) ) - (seqno, self._execute_sql_request_count) = ( + seqno, self._execute_sql_request_count = ( self._execute_sql_request_count, self._execute_sql_request_count + 1, ) @@ -454,7 +454,7 @@ def execute_update( def wrapped_method(*args, **kwargs): attempt.increment() - (call_metadata, error_augmenter) = database.with_error_augmentation( + call_metadata, error_augmenter = database.with_error_augmentation( nth_request, attempt.value, metadata ) execute_sql_method = functools.partial( @@ -544,7 +544,7 @@ def batch_update( if isinstance(statement, str): parsed.append(ExecuteBatchDmlRequest.Statement(sql=statement)) else: - (dml, params, param_types) = statement + dml, params, param_types = statement params_pb = self._make_params_pb(params, param_types) parsed.append( ExecuteBatchDmlRequest.Statement( @@ -556,7 +556,7 @@ def batch_update( metadata.append( _metadata_with_leader_aware_routing(database._route_to_leader_enabled) ) - (seqno, self._execute_sql_request_count) = ( + seqno, self._execute_sql_request_count = ( self._execute_sql_request_count, self._execute_sql_request_count + 1, ) @@ -590,7 +590,7 @@ def batch_update( def wrapped_method(*args, **kwargs): attempt.increment() - (call_metadata, error_augmenter) = database.with_error_augmentation( + call_metadata, error_augmenter = database.with_error_augmentation( nth_request, attempt.value, metadata ) execute_batch_dml_method = functools.partial( diff --git a/tests/unit/_async/test_sessions_manager_extra.py b/tests/unit/_async/test_sessions_manager_extra.py index c860965f7c..56e07f970b 100644 --- a/tests/unit/_async/test_sessions_manager_extra.py +++ b/tests/unit/_async/test_sessions_manager_extra.py @@ -69,6 +69,7 @@ async def test_maintenance_thread_sync_branch(self): async def test_maintain_multiplexed_session_terminate(self): # coverage for line 191-193 manager = DatabaseSessionsManager(self.database, self.pool) + manager._multiplexed_session_terminate_event = asyncio.Event() manager._multiplexed_session_terminate_event.set() from weakref import ref @@ -127,6 +128,8 @@ async def fake_coro(): async def test_maintain_multiplexed_session_refresh(self): # coverage for line 196-202 manager = DatabaseSessionsManager(self.database, self.pool) + manager._multiplexed_session_lock = asyncio.Lock() + manager._multiplexed_session_terminate_event = asyncio.Event() manager._multiplexed_session = mock.AsyncMock() # We need to simulate time passing and then terminating @@ -190,6 +193,8 @@ def mock_ref(): async def test_maintain_multiplexed_session_loop_sleep(self): # coverage for line 196 manager = DatabaseSessionsManager(self.database, self.pool) + manager._multiplexed_session_lock = asyncio.Lock() + manager._multiplexed_session_terminate_event = asyncio.Event() call_count = 0 def mock_time(): diff --git a/tests/unit/test_database_session_manager.py b/tests/unit/test_database_session_manager.py index 87b7fdd972..3aa1534f69 100644 --- a/tests/unit/test_database_session_manager.py +++ b/tests/unit/test_database_session_manager.py @@ -199,6 +199,55 @@ def test_multiplexed_maintenance(self): self.assertTrue(session_2.is_multiplexed) self.assertNotEqual(session_1, session_2) + def test_concurrent_get_multiplexed_session_no_deadlock(self): + """Verify that concurrent _get_multiplexed_session calls do not deadlock. + This tests that holding the lock across suspension points (like asyncio.sleep) + doesn't freeze the event loop for subsequent lock seekers using CrossSync.Lock. + """ + import asyncio + from google.cloud.spanner_v1._async.database_sessions_manager import ( + DatabaseSessionsManager, + ) + from os import environ + + # Build fresh async manager decoupling from test suite setup + db = Mock() + db._experimental_host = None + db.database_role = "reader" + pool = Mock() + + manager = DatabaseSessionsManager(db, pool) + + # Mock maintenance thread creation to avoid spawning background tasks + manager._build_maintenance_thread = Mock(return_value=Mock()) + + # Mock _build_multiplexed_session to include a suspension point + async def slow_build(): + await asyncio.sleep(0.5) + return Mock() + + manager._build_multiplexed_session = slow_build + + # Enable multiplexed sessions in environment for verification + environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED] = "true" + + async def run_concurrent(): + # Trigger Coroutine 1 + task1 = asyncio.create_task(manager._get_multiplexed_session()) + await asyncio.sleep(0.1) # Allow Coroutine 1 to acquire lock and suspend + + # Trigger Coroutine 2 - this would previously block the main thread + task2 = asyncio.create_task(manager._get_multiplexed_session()) + + await asyncio.gather(task1, task2) + + try: + asyncio.run(asyncio.wait_for(run_concurrent(), timeout=5.0)) + except asyncio.TimeoutError: + self.fail( + "test_concurrent_get_multiplexed_session_no_deadlock timed out (DEADLOCK)!" + ) + def test_exception_bad_request(self): manager = self._manager api = manager._database.spanner_api