Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/altertable_flightsql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,14 +242,18 @@ def execute(

# Execute via DoPut
writer, reader = self._client.do_put(descriptor, pa.schema([]))
writer.close()
# Signal end of upload while keeping the read side open to receive the
# server's DoPutUpdateResult metadata. writer.close() would close both
# sides prematurely, causing reader.read() to return None.
writer.done_writing()

# Read result from metadata
result = sql_pb2.DoPutUpdateResult()
metadata = reader.read()
if metadata:
result.ParseFromString(metadata)
result.ParseFromString(bytes(metadata))

writer.close()
return result.record_count

def ingest(
Expand Down
44 changes: 44 additions & 0 deletions tests/test_execute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""
Integration tests for client.execute() DML row count reporting.

Verifies that execute() correctly returns the number of rows
affected by INSERT, UPDATE, and DELETE statements.
"""

from altertable_flightsql import Client
from tests.conftest import TableInfo


class TestExecuteRowCount:
"""Test that execute() returns the correct number of affected rows."""

def test_insert_returns_row_count(self, altertable_client: Client, test_table: TableInfo):
"""Test that INSERT returns the number of inserted rows."""
rows = altertable_client.execute(
f"INSERT INTO {test_table.full_name} (id, name, value) VALUES (4, 'Dave', 400), (5, 'Eve', 500)"
)
assert rows == 2

def test_update_returns_row_count(self, altertable_client: Client, test_table: TableInfo):
"""Test that UPDATE returns the number of updated rows."""
rows = altertable_client.execute(
f"UPDATE {test_table.full_name} SET value = 999 WHERE value >= 200"
)
assert rows == 2

def test_delete_returns_row_count(self, altertable_client: Client, test_table: TableInfo):
"""Test that DELETE returns the number of deleted rows."""
rows = altertable_client.execute(f"DELETE FROM {test_table.full_name} WHERE id IN (1, 2)")
assert rows == 2

def test_delete_no_match_returns_zero(self, altertable_client: Client, test_table: TableInfo):
"""Test that DELETE with no matching rows returns 0."""
rows = altertable_client.execute(f"DELETE FROM {test_table.full_name} WHERE id = 9999")
assert rows == 0

def test_update_no_match_returns_zero(self, altertable_client: Client, test_table: TableInfo):
"""Test that UPDATE with no matching rows returns 0."""
rows = altertable_client.execute(
f"UPDATE {test_table.full_name} SET value = 0 WHERE id = 9999"
)
assert rows == 0