Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions google/cloud/spanner_v1/_async/database_sessions_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ def __init__(self, database, pool):
self._pool = pool
self._multiplexed_session: Optional[Session] = None
self._multiplexed_session_thread: Optional[CrossSync.Task] = None
# Use threading.Lock because this is accessed in a synchronous maintenance thread
self._multiplexed_session_lock: threading.Lock = threading.Lock()
self._multiplexed_session_terminate_event: CrossSync.Event = CrossSync.Event()
self._init_lock = threading.Lock()
self._multiplexed_session_lock: Optional[CrossSync.Lock] = None
self._multiplexed_session_terminate_event: Optional[CrossSync.Event] = None

@CrossSync.convert
async def get_session(self, transaction_type: TransactionType) -> Session:
Expand Down Expand Up @@ -119,7 +119,13 @@ async def _get_multiplexed_session(self) -> Session:

:rtype: :class:`~google.cloud.spanner_v1.session.Session`
:returns: a multiplexed session."""
with CrossSync.rm_aio(self._multiplexed_session_lock):
with self._init_lock:
if self._multiplexed_session_lock is None:
self._multiplexed_session_lock = CrossSync.Lock()
if self._multiplexed_session_terminate_event is None:
self._multiplexed_session_terminate_event = CrossSync.Event()

async with self._multiplexed_session_lock:
if self._multiplexed_session is None:
self._multiplexed_session = await self._build_multiplexed_session()
self._multiplexed_session_thread = self._build_maintenance_thread()
Expand Down Expand Up @@ -193,7 +199,7 @@ async def _maintain_multiplexed_session(session_manager_ref) -> None:
if time() - session_created_time < refresh_interval_seconds:
await CrossSync.sleep(polling_interval_seconds)
continue
with manager._multiplexed_session_lock:
async with manager._multiplexed_session_lock:
await CrossSync.run_if_async(manager._multiplexed_session.delete)
manager._multiplexed_session = (
await manager._build_multiplexed_session()
Expand All @@ -220,7 +226,8 @@ def _getenv(cls, env_var_name: str) -> bool:
@CrossSync.convert
async def close(self) -> None:
"""Closes the database session manager and stops all background tasks."""
self._multiplexed_session_terminate_event.set()
if self._multiplexed_session_terminate_event is not None:
self._multiplexed_session_terminate_event.set()
if self._multiplexed_session_thread is not None:
if CrossSync.is_async:
self._multiplexed_session_thread.cancel()
Expand Down
2 changes: 1 addition & 1 deletion google/cloud/spanner_v1/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def wrapped_method():
max_commit_delay=max_commit_delay,
request_options=request_options,
)
(call_metadata, error_augmenter) = database.with_error_augmentation(
call_metadata, error_augmenter = database.with_error_augmentation(
getattr(database, "_next_nth_request", 0), 1, metadata, span
)
commit_method = functools.partial(
Expand Down
13 changes: 5 additions & 8 deletions google/cloud/spanner_v1/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@
trace_call,
)
from google.cloud.spanner_v1.metrics.metrics_capture import MetricsCapture

from google.cloud.spanner_v1.table import Table

SPANNER_DATA_SCOPE = "https://www.googleapis.com/auth/spanner.data"
Expand Down Expand Up @@ -211,11 +210,9 @@ def __init__(
def _resource_info(self):
"""Resource information for metrics labels."""
return {
"project": (
self._instance._client.project
if self._instance and self._instance._client
else None
),
"project": self._instance._client.project
if self._instance and self._instance._client
else None,
"instance": self._instance.instance_id if self._instance else None,
"database": self.database_id,
}
Expand Down Expand Up @@ -533,7 +530,7 @@ def with_error_augmentation(
tuple: (metadata_list, context_manager)"""
if span is None:
span = get_current_span()
(metadata, request_id) = _metadata_with_request_id_and_req_id(
metadata, request_id = _metadata_with_request_id_and_req_id(
self._nth_client_id,
self._channel_id,
nth_request,
Expand Down Expand Up @@ -810,7 +807,7 @@ def execute_pdml():
session = self._sessions_manager.get_session(transaction_type)
try:
add_span_event(span, "Starting BeginTransaction")
(call_metadata, error_augmenter) = self.with_error_augmentation(
call_metadata, error_augmenter = self.with_error_augmentation(
self._next_nth_request, 1, metadata, span
)
with error_augmenter:
Expand Down
17 changes: 12 additions & 5 deletions google/cloud/spanner_v1/database_sessions_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,11 @@ def __init__(self, database, pool):
self._pool = pool
self._multiplexed_session: Optional[Session] = None
self._multiplexed_session_thread: Optional[CrossSync._Sync_Impl.Task] = None
self._multiplexed_session_lock: threading.Lock = threading.Lock()
self._multiplexed_session_terminate_event: CrossSync._Sync_Impl.Event = (
CrossSync._Sync_Impl.Event()
)
self._init_lock = threading.Lock()
self._multiplexed_session_lock: Optional[CrossSync._Sync_Impl.Lock] = None
self._multiplexed_session_terminate_event: Optional[
CrossSync._Sync_Impl.Event
] = None

def get_session(self, transaction_type: TransactionType) -> Session:
"""Returns a session for the given transaction type from the database session manager.
Expand Down Expand Up @@ -115,6 +116,11 @@ def _get_multiplexed_session(self) -> Session:

:rtype: :class:`~google.cloud.spanner_v1.session.Session`
:returns: a multiplexed session."""
with self._init_lock:
if self._multiplexed_session_lock is None:
self._multiplexed_session_lock = CrossSync._Sync_Impl.Lock()
if self._multiplexed_session_terminate_event is None:
self._multiplexed_session_terminate_event = CrossSync._Sync_Impl.Event()
with self._multiplexed_session_lock:
if self._multiplexed_session is None:
self._multiplexed_session = self._build_multiplexed_session()
Expand Down Expand Up @@ -205,7 +211,8 @@ def _getenv(cls, env_var_name: str) -> bool:

def close(self) -> None:
"""Closes the database session manager and stops all background tasks."""
self._multiplexed_session_terminate_event.set()
if self._multiplexed_session_terminate_event is not None:
self._multiplexed_session_terminate_event.set()
if self._multiplexed_session_thread is not None:
self._multiplexed_session_thread.join()
if self._multiplexed_session is not None:
Expand Down
4 changes: 3 additions & 1 deletion google/cloud/spanner_v1/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,9 @@ def database(
database_role=database_role,
enable_drop_protection=enable_drop_protection,
)
db._pool.bind(db)
res = db._pool.bind(db)
if res is not None:
res
return db

def list_databases(self, page_size=None):
Expand Down
12 changes: 5 additions & 7 deletions google/cloud/spanner_v1/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ def _fill_pool(self):
f"Creating {request.session_count} sessions",
span_event_attributes,
)
(call_metadata, error_augmenter) = database.with_error_augmentation(
call_metadata, error_augmenter = database.with_error_augmentation(
database._next_nth_request, 1, metadata, span
)
with error_augmenter:
Expand Down Expand Up @@ -612,7 +612,7 @@ def bind(self, database):
) as span, MetricsCapture(self._resource_info):
returned_session_count = 0
while returned_session_count < self.size:
(call_metadata, error_augmenter) = database.with_error_augmentation(
call_metadata, error_augmenter = database.with_error_augmentation(
database._next_nth_request, 1, metadata, span
)
with error_augmenter:
Expand Down Expand Up @@ -654,7 +654,7 @@ def get(self, timeout=None):
ping_after = None
session = None
try:
(ping_after, session) = CrossSync._Sync_Impl.queue_get(
ping_after, session = CrossSync._Sync_Impl.queue_get(
self._sessions, block=True, timeout=timeout
)
except CrossSync._Sync_Impl.QueueEmpty as e:
Expand Down Expand Up @@ -698,9 +698,7 @@ def clear(self):
"""Delete all sessions in the pool."""
while True:
try:
(_, session) = CrossSync._Sync_Impl.queue_get(
self._sessions, block=False
)
_, session = CrossSync._Sync_Impl.queue_get(self._sessions, block=False)
except CrossSync._Sync_Impl.QueueEmpty:
break
else:
Expand All @@ -713,7 +711,7 @@ def ping(self):
or during the "idle" phase of an event loop."""
while True:
try:
(ping_after, session) = CrossSync._Sync_Impl.queue_get(
ping_after, session = CrossSync._Sync_Impl.queue_get(
self._sessions, block=False
)
except CrossSync._Sync_Impl.QueueEmpty:
Expand Down
8 changes: 4 additions & 4 deletions google/cloud/spanner_v1/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def create(self):
observability_options=observability_options,
metadata=metadata,
) as span, MetricsCapture(self._resource_info):
(call_metadata, error_augmenter) = database.with_error_augmentation(
call_metadata, error_augmenter = database.with_error_augmentation(
nth_request, 1, metadata, span
)
with error_augmenter:
Expand Down Expand Up @@ -232,7 +232,7 @@ def exists(self):
observability_options=observability_options,
metadata=metadata,
) as span, MetricsCapture(self._resource_info):
(call_metadata, error_augmenter) = database.with_error_augmentation(
call_metadata, error_augmenter = database.with_error_augmentation(
nth_request, 1, metadata, span
)
with error_augmenter:
Expand Down Expand Up @@ -283,7 +283,7 @@ def delete(self):
observability_options=observability_options,
metadata=metadata,
) as span, MetricsCapture(self._resource_info):
(call_metadata, error_augmenter) = database.with_error_augmentation(
call_metadata, error_augmenter = database.with_error_augmentation(
nth_request, 1, metadata, span
)
with error_augmenter:
Expand All @@ -300,7 +300,7 @@ def ping(self):
metadata = _metadata_with_prefix(database.name)
nth_request = database._next_nth_request
with trace_call("CloudSpanner.Session.ping", self) as span:
(call_metadata, error_augmenter) = database.with_error_augmentation(
call_metadata, error_augmenter = database.with_error_augmentation(
nth_request, 1, metadata, span
)
with error_augmenter:
Expand Down
6 changes: 3 additions & 3 deletions google/cloud/spanner_v1/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def execute_sql(
raise ValueError("Transaction has not begun.")
if params is not None:
params_pb = Struct(
fields={key: _make_value_pb(value) for (key, value) in params.items()}
fields={key: _make_value_pb(value) for key, value in params.items()}
)
else:
params_pb = {}
Expand Down Expand Up @@ -513,7 +513,7 @@ def partition_query(
raise ValueError("Cannot partition a single-use transaction.")
if params is not None:
params_pb = Struct(
fields={key: _make_value_pb(value) for (key, value) in params.items()}
fields={key: _make_value_pb(value) for key, value in params.items()}
)
else:
params_pb = Struct()
Expand Down Expand Up @@ -614,7 +614,7 @@ def wrapped_method():
begin_transaction_request = BeginTransactionRequest(
**begin_request_kwargs
)
(call_metadata, error_augmenter) = database.with_error_augmentation(
call_metadata, error_augmenter = database.with_error_augmentation(
nth_request, attempt.increment(), metadata, span
)
begin_transaction_method = functools.partial(
Expand Down
8 changes: 4 additions & 4 deletions google/cloud/spanner_v1/streamed.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def _consume_next(self):

def __iter__(self):
while True:
(iter_rows, self._rows[:]) = (self._rows[:], ())
iter_rows, self._rows[:] = (self._rows[:], ())
while iter_rows:
yield iter_rows.pop(0)
if self._done:
Expand Down Expand Up @@ -230,7 +230,7 @@ def to_dict_list(self):
rows.append(
{
column: value
for (column, value) in zip(
for column, value in zip(
[column.name for column in self._metadata.row_type.fields], row
)
}
Expand Down Expand Up @@ -291,7 +291,7 @@ def _merge_array(lhs, rhs, type_):
if element_type.code in _UNMERGEABLE_TYPES:
lhs.list_value.values.extend(rhs.list_value.values)
return lhs
(lhs, rhs) = (list(lhs.list_value.values), list(rhs.list_value.values))
lhs, rhs = (list(lhs.list_value.values), list(rhs.list_value.values))
if not len(lhs) or not len(rhs):
return Value(list_value=ListValue(values=lhs + rhs))
first = rhs.pop(0)
Expand All @@ -316,7 +316,7 @@ def _merge_array(lhs, rhs, type_):
def _merge_struct(lhs, rhs, type_):
"""Helper for '_merge_by_type'."""
fields = type_.struct_type.fields
(lhs, rhs) = (list(lhs.list_value.values), list(rhs.list_value.values))
lhs, rhs = (list(lhs.list_value.values), list(rhs.list_value.values))
if not len(lhs) or not len(rhs):
return Value(list_value=ListValue(values=lhs + rhs))
candidate_type = fields[len(lhs) - 1].type_
Expand Down
18 changes: 9 additions & 9 deletions google/cloud/spanner_v1/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def rollback(self) -> None:

def wrapped_method(*args, **kwargs):
attempt.increment()
(call_metadata, error_augmenter) = database.with_error_augmentation(
call_metadata, error_augmenter = database.with_error_augmentation(
nth_request, attempt.value, metadata, span
)
rollback_method = functools.partial(
Expand Down Expand Up @@ -269,7 +269,7 @@ def wrapped_method(*args, **kwargs):
is_multiplexed = getattr(self._session, "is_multiplexed", False)
if is_multiplexed and self._precommit_token is not None:
commit_request_args["precommit_token"] = self._precommit_token
(call_metadata, error_augmenter) = database.with_error_augmentation(
call_metadata, error_augmenter = database.with_error_augmentation(
nth_request, attempt.value, metadata, span
)
commit_method = functools.partial(
Expand Down Expand Up @@ -300,7 +300,7 @@ def before_next_retry(nth_retry, delay_in_seconds):
if commit_response_pb._pb.HasField("precommit_token"):
add_span_event(span, commit_retry_event_name)
nth_request = database._next_nth_request
(call_metadata, error_augmenter) = database.with_error_augmentation(
call_metadata, error_augmenter = database.with_error_augmentation(
nth_request, 1, metadata, span
)
with error_augmenter:
Expand Down Expand Up @@ -338,7 +338,7 @@ def _make_params_pb(params, param_types):
If ``params`` is None but ``param_types`` is not None."""
if params:
return Struct(
fields={key: _make_value_pb(value) for (key, value) in params.items()}
fields={key: _make_value_pb(value) for key, value in params.items()}
)
return {}

Expand Down Expand Up @@ -417,7 +417,7 @@ def execute_update(
metadata.append(
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
)
(seqno, self._execute_sql_request_count) = (
seqno, self._execute_sql_request_count = (
self._execute_sql_request_count,
self._execute_sql_request_count + 1,
)
Expand Down Expand Up @@ -454,7 +454,7 @@ def execute_update(

def wrapped_method(*args, **kwargs):
attempt.increment()
(call_metadata, error_augmenter) = database.with_error_augmentation(
call_metadata, error_augmenter = database.with_error_augmentation(
nth_request, attempt.value, metadata
)
execute_sql_method = functools.partial(
Expand Down Expand Up @@ -544,7 +544,7 @@ def batch_update(
if isinstance(statement, str):
parsed.append(ExecuteBatchDmlRequest.Statement(sql=statement))
else:
(dml, params, param_types) = statement
dml, params, param_types = statement
params_pb = self._make_params_pb(params, param_types)
parsed.append(
ExecuteBatchDmlRequest.Statement(
Expand All @@ -556,7 +556,7 @@ def batch_update(
metadata.append(
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
)
(seqno, self._execute_sql_request_count) = (
seqno, self._execute_sql_request_count = (
self._execute_sql_request_count,
self._execute_sql_request_count + 1,
)
Expand Down Expand Up @@ -590,7 +590,7 @@ def batch_update(

def wrapped_method(*args, **kwargs):
attempt.increment()
(call_metadata, error_augmenter) = database.with_error_augmentation(
call_metadata, error_augmenter = database.with_error_augmentation(
nth_request, attempt.value, metadata
)
execute_batch_dml_method = functools.partial(
Expand Down
5 changes: 5 additions & 0 deletions tests/unit/_async/test_sessions_manager_extra.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ async def test_maintenance_thread_sync_branch(self):
async def test_maintain_multiplexed_session_terminate(self):
# coverage for line 191-193
manager = DatabaseSessionsManager(self.database, self.pool)
manager._multiplexed_session_terminate_event = asyncio.Event()
manager._multiplexed_session_terminate_event.set()

from weakref import ref
Expand Down Expand Up @@ -127,6 +128,8 @@ async def fake_coro():
async def test_maintain_multiplexed_session_refresh(self):
# coverage for line 196-202
manager = DatabaseSessionsManager(self.database, self.pool)
manager._multiplexed_session_lock = asyncio.Lock()
manager._multiplexed_session_terminate_event = asyncio.Event()
manager._multiplexed_session = mock.AsyncMock()

# We need to simulate time passing and then terminating
Expand Down Expand Up @@ -190,6 +193,8 @@ def mock_ref():
async def test_maintain_multiplexed_session_loop_sleep(self):
# coverage for line 196
manager = DatabaseSessionsManager(self.database, self.pool)
manager._multiplexed_session_lock = asyncio.Lock()
manager._multiplexed_session_terminate_event = asyncio.Event()
call_count = 0

def mock_time():
Expand Down
Loading
Loading