Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 60 additions & 1 deletion mock_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import weaviate
from mock_tests.mock_data import mock_class
from weaviate.config import AdditionalConfig, RetryConfig
from weaviate.connect.base import ConnectionParams, ProtocolParams
from weaviate.proto.v1 import (
batch_delete_pb2,
Expand Down Expand Up @@ -139,6 +140,20 @@ def weaviate_client(
client.close()


@pytest.fixture(scope="function")
def weaviate_client_retry_timeout(
weaviate_mock: HTTPServer, start_grpc_server: grpc.Server
) -> Generator[weaviate.WeaviateClient, None, None]:
client = weaviate.connect_to_local(
port=MOCK_PORT,
host=MOCK_IP,
grpc_port=MOCK_PORT_GRPC,
additional_config=AdditionalConfig(retry=RetryConfig(timeout_ms=500)),
)
yield client
client.close()


@pytest.fixture(scope="function")
def weaviate_timeouts_client(
weaviate_timeouts_mock: HTTPServer, start_grpc_server: grpc.Server
Expand All @@ -148,7 +163,8 @@ def weaviate_timeouts_client(
port=MOCK_PORT,
grpc_port=MOCK_PORT_GRPC,
additional_config=weaviate.classes.init.AdditionalConfig(
timeout=weaviate.classes.init.Timeout(query=0.5, insert=1.5)
timeout=weaviate.classes.init.Timeout(query=0.5, insert=1.5),
retry=weaviate.config.RetryConfig(request_retry_count=5, request_retry_backoff_ms=0),
),
)
yield client
Expand Down Expand Up @@ -253,6 +269,40 @@ def BatchObjects(
class MockRetriesWeaviateService(weaviate_pb2_grpc.WeaviateServicer):
search_count = 0
tenants_count = 0
delete_count = 0
batch_count = 0

def BatchObjects(
self, request: batch_pb2.BatchObjectsRequest, context: grpc.ServicerContext
) -> batch_pb2.BatchObjectsReply:
if self.batch_count == 0:
self.batch_count += 1
context.set_code(grpc.StatusCode.ABORTED)
context.set_details("Aborted")
return batch_pb2.BatchObjectsReply()
if self.batch_count == 1:
self.batch_count += 1
context.set_code(grpc.StatusCode.CANCELLED)
context.set_details("Cancelled")
return batch_pb2.BatchObjectsReply()
return batch_pb2.BatchObjectsReply(
errors=[],
)

def BatchDelete(
self, request: batch_delete_pb2.BatchDeleteRequest, context: grpc.ServicerContext
) -> batch_delete_pb2.BatchDeleteReply:
if self.delete_count == 0:
self.delete_count += 1
context.set_code(grpc.StatusCode.DEADLINE_EXCEEDED)
context.set_details("Deadline Exceeded")
return batch_delete_pb2.BatchDeleteReply()
if self.delete_count == 1:
self.delete_count += 1
context.set_code(grpc.StatusCode.UNAVAILABLE)
context.set_details("Service is unavailable")
return batch_delete_pb2.BatchDeleteReply()
return batch_delete_pb2.BatchDeleteReply(matches=1, failed=0, successful=1, objects=[])

def Search(
self, request: search_get_pb2.SearchRequest, context: grpc.ServicerContext
Expand Down Expand Up @@ -310,6 +360,15 @@ def retries(
return weaviate_client.collections.use("RetriesCollection"), service


@pytest.fixture(scope="function")
def no_retries(
weaviate_client_retry_timeout: weaviate.WeaviateClient, start_grpc_server: grpc.Server
) -> tuple[weaviate.collections.Collection, MockRetriesWeaviateService]:
service = MockRetriesWeaviateService()
weaviate_pb2_grpc.add_WeaviateServicer_to_server(service, start_grpc_server)
return weaviate_client_retry_timeout.collections.use("RetriesCollection"), service


class MockForbiddenWeaviateService(weaviate_pb2_grpc.WeaviateServicer):
def Search(
self, request: search_get_pb2.SearchRequest, context: grpc.ServicerContext
Expand Down
37 changes: 29 additions & 8 deletions mock_tests/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,17 @@
VectorIndexType,
Vectorizers,
)
from weaviate.collections.classes.filters import Filter
from weaviate.connect.base import ConnectionParams, ProtocolParams
from weaviate.connect.integrations import _IntegrationConfig
from weaviate.exceptions import (
BackupCanceledError,
InsufficientPermissionsError,
UnexpectedStatusCodeError,
WeaviateStartUpError,
WeaviateQueryError,
WeaviateBatchError,
WeaviateDeleteManyError,
)

ACCESS_TOKEN = "HELLO!IamAnAccessToken"
Expand Down Expand Up @@ -372,26 +376,43 @@ def test_grpc_retry_logic(
collection = retries[0]
service = retries[1]

with pytest.raises(weaviate.exceptions.WeaviateQueryError):
# checks first call correctly handles INTERNAL error
collection.query.fetch_objects()

# should perform one retry and then succeed subsequently
objs = collection.query.fetch_objects().objects
assert len(objs) == 1
assert objs[0].properties["name"] == "test"
assert service.search_count == 2

with pytest.raises(weaviate.exceptions.WeaviateTenantGetError):
# checks first call correctly handles error that isn't UNAVAILABLE
collection.tenants.get()

# should perform one retry and then succeed subsequently
tenants = list(collection.tenants.get().values())
assert len(tenants) == 1
assert tenants[0].name == "tenant1"
assert service.tenants_count == 2

# Should perform two retry and then succeed subsequently
collection.data.insert_many(objects=[{"Hello": "World"}])

# should perform two retries and then succeed subsequently
deleted = collection.data.delete_many(where=Filter.by_id().equal(objs[0].uuid))
assert deleted.matches == 1


def test_grpc_retry_timeout_logic(
no_retries: tuple[weaviate.collections.Collection, MockRetriesWeaviateService],
) -> None:
collection, _ = no_retries[0], no_retries[1]

# timeout after 1 retry
with pytest.raises(WeaviateQueryError):
collection.query.fetch_objects().objects

# timeout after 1 retry
with pytest.raises(WeaviateBatchError):
collection.data.insert_many(objects=[{"Hello": "World"}])

# timeout after 1 retry
with pytest.raises(WeaviateDeleteManyError):
collection.data.delete_many(where=Filter.by_property("Hello").equal("World"))


def test_grpc_forbidden_exception(forbidden: weaviate.collections.Collection) -> None:
with pytest.raises(weaviate.exceptions.InsufficientPermissionsError):
Expand Down
4 changes: 2 additions & 2 deletions mock_tests/test_timeouts.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest

import weaviate
from weaviate.exceptions import WeaviateQueryError, WeaviateTimeoutError
from weaviate.exceptions import WeaviateQueryError, WeaviateTimeoutError, WeaviateBatchError


def test_timeout_rest_query(timeouts_collection: weaviate.collections.Collection):
Expand All @@ -21,6 +21,6 @@ def test_timeout_grpc_query(timeouts_collection: weaviate.collections.Collection


def test_timeout_grpc_insert(timeouts_collection: weaviate.collections.Collection):
with pytest.raises(WeaviateQueryError) as recwarn:
with pytest.raises(WeaviateBatchError) as recwarn:
timeouts_collection.data.insert_many([{"what": "ever"}])
assert "DEADLINE_EXCEEDED" in str(recwarn)
1 change: 1 addition & 0 deletions weaviate/client_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def __init__(
additional_headers=additional_headers,
embedded_db=embedded_db,
connection_config=config.connection,
retry_config=config.retry,
proxies=config.proxies,
trust_env=config.trust_env,
skip_init_checks=skip_init_checks,
Expand Down
21 changes: 21 additions & 0 deletions weaviate/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,26 @@ class Proxies(BaseModel):
grpc: Optional[str] = Field(default=None)


@dataclass
class RetryConfig:
request_retry_count: int = 20
request_retry_backoff_ms: int = 100
timeout_ms: int = 30000

def __post_init__(self) -> None:
if not isinstance(self.request_retry_count, int):
raise TypeError(
f"request_retry_count must be {int}, received {type(self.request_retry_count)}"
)
if not isinstance(self.request_retry_backoff_ms, int):
raise TypeError(
f"request_retry_backoff_ms must be {int}, received {type(self.request_retry_backoff_ms)}"
)

if not isinstance(self.timeout_ms, int):
raise TypeError(f"timeout_ms must be {int}, received {type(self.timeout_ms)}")


class AdditionalConfig(BaseModel):
"""Use this class to specify the connection and proxy settings for your client when connecting to Weaviate.

Expand All @@ -80,6 +100,7 @@ class AdditionalConfig(BaseModel):
connection: ConnectionConfig = Field(default_factory=ConnectionConfig)
proxies: Union[str, Proxies, None] = Field(default=None)
timeout_: Union[Tuple[int, int], Timeout] = Field(default_factory=Timeout, alias="timeout")
retry: RetryConfig = Field(default_factory=RetryConfig)
trust_env: bool = Field(default=False)

@property
Expand Down
Loading