Skip to content

Commit fa14ec9

Browse files
author
anna-singleton-resolver
committed
perf: streaming connection for e2e test cases
1 parent 541957a commit fa14ec9

2 files changed

Lines changed: 115 additions & 40 deletions

File tree

tests/functional/conftest.py

Lines changed: 88 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,31 @@
11
import os
22
import uuid
3+
from asyncio import Future, Queue, Task, create_task
4+
from collections.abc import AsyncIterator
5+
from copy import deepcopy
36

47
import cv2 as cv
58
import numpy as np
69
import pytest
710
import pytest_asyncio
811
from dotenv import load_dotenv
12+
from grpc.aio import Channel
913

14+
from resolver_athena_client.client.athena_client import AthenaClient
1015
from resolver_athena_client.client.athena_options import AthenaOptions
11-
from resolver_athena_client.client.channel import CredentialHelper
16+
from resolver_athena_client.client.channel import (
17+
CredentialHelper,
18+
create_channel_with_credentials,
19+
)
1220
from resolver_athena_client.client.consts import (
1321
EXPECTED_HEIGHT,
1422
EXPECTED_WIDTH,
1523
MAX_DEPLOYMENT_ID_LENGTH,
1624
)
25+
from resolver_athena_client.client.models.input_model import ImageData
26+
from resolver_athena_client.generated.athena.models_pb2 import (
27+
ClassificationOutput,
28+
)
1729

1830

1931
def _create_base_test_image_opencv(width: int, height: int) -> np.ndarray:
@@ -79,7 +91,7 @@ async def credential_helper() -> CredentialHelper:
7991
)
8092

8193

82-
@pytest.fixture
94+
@pytest.fixture(scope="session")
8395
def athena_options() -> AthenaOptions:
8496
_ = load_dotenv()
8597
host = os.getenv("ATHENA_HOST", "localhost")
@@ -99,6 +111,7 @@ def athena_options() -> AthenaOptions:
99111
timeout=120.0, # Maximum duration, not forced timeout
100112
keepalive_interval=30.0, # Longer intervals for persistent streams
101113
affiliate=affiliate,
114+
compression_quality=2,
102115
)
103116

104117

@@ -144,3 +157,76 @@ def valid_formatted_image(
144157
_ = f.write(image_bytes)
145158

146159
return image_bytes
160+
161+
162+
class StreamingSender:
163+
"""Helper class to provide a single-send-like interface with speed
164+
165+
The class provides a 'send' method that can be passed an imagedata and will
166+
send it along a stream, and collect all results into an internal buffer.
167+
168+
The 'send' method will asynchronously wait for the result and return it,
169+
providing an interface that mimics a single request-response call, while
170+
under the hood it is using a streaming connection for speed.
171+
"""
172+
173+
def __init__(self, grpc_channel: Channel, options: AthenaOptions) -> None:
174+
self._results: list[ClassificationOutput] = []
175+
self._request_queue: Queue[ImageData] = Queue()
176+
self._pending_results: dict[str, Future[ClassificationOutput]] = {}
177+
178+
# tests are run in series, so we gain nothing here from waiting for a
179+
# batch that will never fill, so just send it immediately for better
180+
# latency
181+
streaming_options = deepcopy(options)
182+
streaming_options.max_batch_size = 1
183+
184+
self._run_task: Task[None] = create_task(
185+
self._run(grpc_channel, streaming_options)
186+
)
187+
188+
async def _run(self, grpc_channel: Channel, options: AthenaOptions) -> None:
189+
async with AthenaClient(grpc_channel, options) as client:
190+
generator = self._send_from_queue()
191+
responses = client.classify_images(generator)
192+
async for response in responses:
193+
for output in response.outputs:
194+
if output.correlation_id in self._pending_results:
195+
future = self._pending_results.pop(
196+
output.correlation_id
197+
)
198+
future.set_result(output)
199+
self._results.append(output)
200+
201+
async def _send_from_queue(self) -> AsyncIterator[ImageData]:
202+
"""Async generator to yield requests from the queue."""
203+
while True:
204+
if image_data := await self._request_queue.get():
205+
yield image_data
206+
self._request_queue.task_done()
207+
208+
async def send(self, image_data: ImageData) -> ClassificationOutput:
209+
"""Send an image and wait for the corresponding result."""
210+
if self._run_task.done():
211+
self._run_task.result()
212+
213+
if image_data.correlation_id is None:
214+
image_data.correlation_id = str(uuid.uuid4())
215+
future: Future[ClassificationOutput] = Future()
216+
self._pending_results[image_data.correlation_id] = future
217+
218+
await self._request_queue.put(image_data)
219+
220+
return await future
221+
222+
223+
@pytest_asyncio.fixture(scope="session", loop_scope="session")
224+
async def streaming_sender(
225+
athena_options: AthenaOptions, credential_helper: CredentialHelper
226+
) -> StreamingSender:
227+
"""Fixture to provide a helper for sending over a streaming connection."""
228+
# Create gRPC channel with credentials
229+
channel = await create_channel_with_credentials(
230+
athena_options.host, credential_helper
231+
)
232+
return StreamingSender(channel, athena_options)

tests/functional/e2e/test_classify_single.py

Lines changed: 27 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,8 @@
22

33
import pytest
44

5-
from resolver_athena_client.client.athena_client import AthenaClient
6-
from resolver_athena_client.client.athena_options import AthenaOptions
7-
from resolver_athena_client.client.channel import (
8-
CredentialHelper,
9-
create_channel_with_credentials,
10-
)
115
from resolver_athena_client.client.models import ImageData
6+
from tests.functional.conftest import StreamingSender
127
from tests.functional.e2e.testcases.parser import (
138
AthenaTestCase,
149
load_test_cases_by_env,
@@ -19,13 +14,12 @@
1914
FP_ERROR_TOLERANCE = 1e-4
2015

2116

22-
@pytest.mark.asyncio
17+
@pytest.mark.asyncio(loop_scope="session")
2318
@pytest.mark.functional
2419
@pytest.mark.e2e
2520
@pytest.mark.parametrize("test_case", TEST_CASES, ids=lambda tc: tc.id)
26-
async def test_classify_single(
27-
athena_options: AthenaOptions,
28-
credential_helper: CredentialHelper,
21+
async def test_e2e_case(
22+
streaming_sender: StreamingSender,
2923
test_case: AthenaTestCase,
3024
) -> None:
3125
"""Functional test for ClassifySingle endpoint and API methods.
@@ -34,38 +28,33 @@ async def test_classify_single(
3428
3529
"""
3630

37-
# Create gRPC channel with credentials
38-
channel = await create_channel_with_credentials(
39-
athena_options.host, credential_helper
40-
)
4131
with Path.open(Path(test_case.filepath), "rb") as f:
4232
image_bytes = f.read()
4333

44-
async with AthenaClient(channel, athena_options) as client:
45-
image_data = ImageData(image_bytes)
34+
image_data = ImageData(image_bytes)
4635

47-
# Classify with auto-generated correlation ID
48-
result = await client.classify_single(image_data)
36+
# Classify with auto-generated correlation ID
37+
result = await streaming_sender.send(image_data)
4938

50-
if result.error.code:
51-
msg = f"Image Result Error: {result.error.message}"
52-
pytest.fail(msg)
39+
if result.error.code:
40+
msg = f"Image Result Error: {result.error.message}"
41+
pytest.fail(msg)
5342

54-
actual_output = {c.label: c.weight for c in result.classifications}
55-
assert set(test_case.expected_output.keys()).issubset(
56-
set(actual_output.keys())
57-
), (
58-
"Expected output to contain labels: ",
59-
f"{test_case.expected_output.keys() - actual_output.keys()}",
43+
actual_output = {c.label: c.weight for c in result.classifications}
44+
assert set(test_case.expected_output.keys()).issubset(
45+
set(actual_output.keys())
46+
), (
47+
"Expected output to contain labels: ",
48+
f"{test_case.expected_output.keys() - actual_output.keys()}",
49+
)
50+
actual_output = {k: actual_output[k] for k in test_case.expected_output}
51+
52+
for label in test_case.expected_output:
53+
expected = test_case.expected_output[label]
54+
actual = actual_output[label]
55+
diff = abs(expected - actual)
56+
assert diff < FP_ERROR_TOLERANCE, (
57+
f"Weight for label '{label}' differs by more than "
58+
f"{FP_ERROR_TOLERANCE}: expected={expected}, actual={actual}, "
59+
f"diff={diff}"
6060
)
61-
actual_output = {k: actual_output[k] for k in test_case.expected_output}
62-
63-
for label in test_case.expected_output:
64-
expected = test_case.expected_output[label]
65-
actual = actual_output[label]
66-
diff = abs(expected - actual)
67-
assert diff < FP_ERROR_TOLERANCE, (
68-
f"Weight for label '{label}' differs by more than "
69-
f"{FP_ERROR_TOLERANCE}: expected={expected}, actual={actual}, "
70-
f"diff={diff}"
71-
)

0 commit comments

Comments
 (0)