11import os
22import uuid
3+ from asyncio import Future , Queue , Task , create_task
4+ from collections .abc import AsyncIterator
5+ from copy import deepcopy
36
47import cv2 as cv
58import numpy as np
69import pytest
710import pytest_asyncio
811from dotenv import load_dotenv
12+ from grpc .aio import Channel
913
14+ from resolver_athena_client .client .athena_client import AthenaClient
1015from resolver_athena_client .client .athena_options import AthenaOptions
11- from resolver_athena_client .client .channel import CredentialHelper
16+ from resolver_athena_client .client .channel import (
17+ CredentialHelper ,
18+ create_channel_with_credentials ,
19+ )
1220from resolver_athena_client .client .consts import (
1321 EXPECTED_HEIGHT ,
1422 EXPECTED_WIDTH ,
1523 MAX_DEPLOYMENT_ID_LENGTH ,
1624)
25+ from resolver_athena_client .client .models .input_model import ImageData
26+ from resolver_athena_client .generated .athena .models_pb2 import (
27+ ClassificationOutput ,
28+ )
1729
1830
1931def _create_base_test_image_opencv (width : int , height : int ) -> np .ndarray :
@@ -79,7 +91,7 @@ async def credential_helper() -> CredentialHelper:
7991 )
8092
8193
82- @pytest .fixture
94+ @pytest .fixture ( scope = "session" )
8395def athena_options () -> AthenaOptions :
8496 _ = load_dotenv ()
8597 host = os .getenv ("ATHENA_HOST" , "localhost" )
@@ -99,6 +111,7 @@ def athena_options() -> AthenaOptions:
99111 timeout = 120.0 , # Maximum duration, not forced timeout
100112 keepalive_interval = 30.0 , # Longer intervals for persistent streams
101113 affiliate = affiliate ,
114+ compression_quality = 2 ,
102115 )
103116
104117
@@ -144,3 +157,76 @@ def valid_formatted_image(
144157 _ = f .write (image_bytes )
145158
146159 return image_bytes
160+
161+
162+ class StreamingSender :
163+ """Helper class to provide a single-send-like interface with speed
164+
165+ The class provides a 'send' method that can be passed an imagedata and will
166+ send it along a stream, and collect all results into an internal buffer.
167+
168+ The 'send' method will asynchronously wait for the result and return it,
169+ providing an interface that mimics a single request-response call, while
170+ under the hood it is using a streaming connection for speed.
171+ """
172+
173+ def __init__ (self , grpc_channel : Channel , options : AthenaOptions ) -> None :
174+ self ._results : list [ClassificationOutput ] = []
175+ self ._request_queue : Queue [ImageData ] = Queue ()
176+ self ._pending_results : dict [str , Future [ClassificationOutput ]] = {}
177+
178+ # tests are run in series, so we gain nothing here from waiting for a
179+ # batch that will never fill, so just send it immediately for better
180+ # latency
181+ streaming_options = deepcopy (options )
182+ streaming_options .max_batch_size = 1
183+
184+ self ._run_task : Task [None ] = create_task (
185+ self ._run (grpc_channel , streaming_options )
186+ )
187+
188+ async def _run (self , grpc_channel : Channel , options : AthenaOptions ) -> None :
189+ async with AthenaClient (grpc_channel , options ) as client :
190+ generator = self ._send_from_queue ()
191+ responses = client .classify_images (generator )
192+ async for response in responses :
193+ for output in response .outputs :
194+ if output .correlation_id in self ._pending_results :
195+ future = self ._pending_results .pop (
196+ output .correlation_id
197+ )
198+ future .set_result (output )
199+ self ._results .append (output )
200+
201+ async def _send_from_queue (self ) -> AsyncIterator [ImageData ]:
202+ """Async generator to yield requests from the queue."""
203+ while True :
204+ if image_data := await self ._request_queue .get ():
205+ yield image_data
206+ self ._request_queue .task_done ()
207+
208+ async def send (self , image_data : ImageData ) -> ClassificationOutput :
209+ """Send an image and wait for the corresponding result."""
210+ if self ._run_task .done ():
211+ self ._run_task .result ()
212+
213+ if image_data .correlation_id is None :
214+ image_data .correlation_id = str (uuid .uuid4 ())
215+ future : Future [ClassificationOutput ] = Future ()
216+ self ._pending_results [image_data .correlation_id ] = future
217+
218+ await self ._request_queue .put (image_data )
219+
220+ return await future
221+
222+
223+ @pytest_asyncio .fixture (scope = "session" , loop_scope = "session" )
224+ async def streaming_sender (
225+ athena_options : AthenaOptions , credential_helper : CredentialHelper
226+ ) -> StreamingSender :
227+ """Fixture to provide a helper for sending over a streaming connection."""
228+ # Create gRPC channel with credentials
229+ channel = await create_channel_with_credentials (
230+ athena_options .host , credential_helper
231+ )
232+ return StreamingSender (channel , athena_options )
0 commit comments