From f93bed64edb94cc67f6fb9363e8e610529a698df Mon Sep 17 00:00:00 2001 From: waiho-gumloop Date: Thu, 26 Mar 2026 18:47:37 -0700 Subject: [PATCH] feat(dbapi): wire timeout parameter through Connection to execute_sql The DBAPI layer calls _SnapshotBase.execute_sql() in three code paths (snapshot reads, transaction reads/writes, autocommit DML) but never passes the timeout= argument. This causes all queries to use the gRPC default timeout of 3600 seconds. Add a timeout property to Connection and pass it through to execute_sql() in cursor._handle_DQL_with_snapshot(), cursor._do_execute_update_in_autocommit(), and connection.run_statement(). Fixes #1534 --- google/cloud/spanner_dbapi/connection.py | 34 +++++++++++- google/cloud/spanner_dbapi/cursor.py | 18 +++++-- tests/unit/spanner_dbapi/test_connection.py | 50 +++++++++++++++++ tests/unit/spanner_dbapi/test_cursor.py | 60 ++++++++++++++++++++- 4 files changed, 155 insertions(+), 7 deletions(-) diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index fdfce994fd..8a47394bc2 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -111,6 +111,7 @@ def __init__(self, instance, database=None, read_only=False, **kwargs): self._read_only = read_only self._staleness = None self.request_priority = None + self._timeout = None self._transaction_begin_marked = False self._transaction_isolation_level = None # whether transaction started at Spanner. This means that we had @@ -347,6 +348,30 @@ def staleness(self, value): self._staleness = value + @property + def timeout(self): + """Timeout in seconds for the next SQL operation on this connection. + + When set, this value is passed as the ``timeout`` argument to + ``execute_sql`` calls on the underlying Spanner client, controlling + the gRPC deadline for those calls. + + Returns: + Optional[float]: The timeout in seconds, or None to use the + default gRPC timeout (3600s). + """ + return self._timeout + + @timeout.setter + def timeout(self, value): + """Set the timeout for subsequent SQL operations. + + Args: + value (Optional[float]): Timeout in seconds. Set to None to + revert to the default gRPC timeout. + """ + self._timeout = value + def _session_checkout(self): """Get a Cloud Spanner session from the pool. @@ -559,11 +584,16 @@ def run_statement( checksum of this statement results. """ transaction = self.transaction_checkout() + kwargs = dict( + param_types=statement.param_types, + request_options=request_options or self.request_options, + ) + if self._timeout is not None: + kwargs["timeout"] = self._timeout return transaction.execute_sql( statement.sql, statement.params, - param_types=statement.param_types, - request_options=request_options or self.request_options, + **kwargs, ) @check_not_closed diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index d40ad7ed07..503e51027a 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -227,12 +227,17 @@ def _do_execute_update_in_autocommit(self, transaction, sql, params): """This function should only be used in autocommit mode.""" self.connection._transaction = transaction self.connection._snapshot = None - self._result_set = transaction.execute_sql( - sql, + kwargs = dict( params=params, param_types=get_param_types(params), last_statement=True, ) + if self.connection._timeout is not None: + kwargs["timeout"] = self.connection._timeout + self._result_set = transaction.execute_sql( + sql, + **kwargs, + ) self._itr = PeekIterator(self._result_set) self._row_count = None @@ -541,11 +546,16 @@ def _fetch(self, cursor_statement_type, size=None): return rows def _handle_DQL_with_snapshot(self, snapshot, sql, params): + kwargs = dict( + param_types=get_param_types(params), + request_options=self.request_options, + ) + if self.connection._timeout is not None: + kwargs["timeout"] = self.connection._timeout self._result_set = snapshot.execute_sql( sql, params, - get_param_types(params), - request_options=self.request_options, + **kwargs, ) # Read the first element so that the StreamedResultSet can # return the metadata after a DQL statement. diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index 5fc7164ced..66a7057243 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -838,6 +838,56 @@ def test_request_priority(self): sql, params, param_types=param_types, request_options=None ) + def test_timeout_default_none(self): + connection = self._make_connection() + self.assertIsNone(connection.timeout) + + def test_timeout_property(self): + connection = self._make_connection() + connection.timeout = 60 + self.assertEqual(connection.timeout, 60) + + connection.timeout = None + self.assertIsNone(connection.timeout) + + def test_timeout_passed_to_run_statement(self): + from google.cloud.spanner_dbapi.parsed_statement import Statement + + sql = "SELECT 1" + params = [] + param_types = {} + + connection = self._make_connection() + connection._spanner_transaction_started = True + connection._transaction = mock.Mock() + connection._transaction.execute_sql = mock.Mock() + + connection.timeout = 60 + + connection.run_statement(Statement(sql, params, param_types)) + + connection._transaction.execute_sql.assert_called_with( + sql, params, param_types=param_types, request_options=None, timeout=60 + ) + + def test_timeout_not_passed_when_none(self): + from google.cloud.spanner_dbapi.parsed_statement import Statement + + sql = "SELECT 1" + params = [] + param_types = {} + + connection = self._make_connection() + connection._spanner_transaction_started = True + connection._transaction = mock.Mock() + connection._transaction.execute_sql = mock.Mock() + + connection.run_statement(Statement(sql, params, param_types)) + + connection._transaction.execute_sql.assert_called_with( + sql, params, param_types=param_types, request_options=None + ) + def test_custom_client_connection(self): from google.cloud.spanner_dbapi import connect diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index 4366d2c519..ce0f5f59ef 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -122,6 +122,64 @@ def test_do_execute_update(self): self.assertEqual(cursor._result_set, result_set) self.assertEqual(cursor.rowcount, 1234) + def test_do_execute_update_with_timeout(self): + connection = self._make_connection(self.INSTANCE, self.DATABASE) + connection._timeout = 30 + cursor = self._make_one(connection) + transaction = mock.MagicMock() + + cursor._do_execute_update_in_autocommit( + transaction=transaction, + sql="UPDATE t SET x=1 WHERE true", + params={}, + ) + + transaction.execute_sql.assert_called_once_with( + "UPDATE t SET x=1 WHERE true", + params={}, + param_types={}, + last_statement=True, + timeout=30, + ) + + def test_handle_DQL_with_snapshot_timeout(self): + connection = self._make_connection(self.INSTANCE, self.DATABASE) + connection._timeout = 45 + cursor = self._make_one(connection) + + snapshot = mock.MagicMock() + result_set = mock.MagicMock() + result_set.metadata.transaction.read_timestamp = None + snapshot.execute_sql.return_value = result_set + + cursor._handle_DQL_with_snapshot(snapshot, "SELECT 1", None) + + snapshot.execute_sql.assert_called_once_with( + "SELECT 1", + None, + param_types=None, + request_options=None, + timeout=45, + ) + + def test_handle_DQL_with_snapshot_no_timeout(self): + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) + + snapshot = mock.MagicMock() + result_set = mock.MagicMock() + result_set.metadata.transaction.read_timestamp = None + snapshot.execute_sql.return_value = result_set + + cursor._handle_DQL_with_snapshot(snapshot, "SELECT 1", None) + + snapshot.execute_sql.assert_called_once_with( + "SELECT 1", + None, + param_types=None, + request_options=None, + ) + def test_do_batch_update(self): from google.cloud.spanner_dbapi import connect from google.cloud.spanner_v1.param_types import INT64 @@ -952,7 +1010,7 @@ def test_handle_dql_priority(self, MockedPeekIterator): self.assertEqual(cursor._itr, MockedPeekIterator()) self.assertEqual(cursor._row_count, None) mock_snapshot.execute_sql.assert_called_with( - sql, None, None, request_options=RequestOptions(priority=1) + sql, None, param_types=None, request_options=RequestOptions(priority=1) ) def test_handle_dql_database_error(self):