"""Tests for retry configuration and error handling."""

from unittest.mock import patch

from aieng.agent_evals.knowledge_qa.retry import (
    API_RETRY_INITIAL_WAIT,
    API_RETRY_JITTER,
    API_RETRY_MAX_ATTEMPTS,
    API_RETRY_MAX_WAIT,
    MAX_EMPTY_RESPONSE_RETRIES,
    is_context_overflow_error,
    is_retryable_api_error,
)


class FakeClientError(Exception):
    """Fake ClientError for testing isinstance checks without API credentials."""


class TestRetryConstants:
    """Pin the retry configuration constants so silent changes are caught."""

    def test_max_empty_response_retries(self):
        """MAX_EMPTY_RESPONSE_RETRIES should be 2."""
        assert MAX_EMPTY_RESPONSE_RETRIES == 2

    def test_api_retry_max_attempts(self):
        """API_RETRY_MAX_ATTEMPTS should be 5."""
        assert API_RETRY_MAX_ATTEMPTS == 5

    def test_api_retry_initial_wait(self):
        """API_RETRY_INITIAL_WAIT should be 1 second."""
        assert API_RETRY_INITIAL_WAIT == 1

    def test_api_retry_max_wait(self):
        """API_RETRY_MAX_WAIT should be 60 seconds."""
        assert API_RETRY_MAX_WAIT == 60

    def test_api_retry_jitter(self):
        """API_RETRY_JITTER should be 5 seconds."""
        assert API_RETRY_JITTER == 5


class TestIsRetryableApiError:
    """Tests for the is_retryable_api_error function."""

    def test_returns_true_for_429_error(self):
        """A ClientError whose message contains '429' is retryable."""
        with patch("aieng.agent_evals.knowledge_qa.retry.ClientError", FakeClientError):
            assert is_retryable_api_error(FakeClientError("Error 429: Too Many Requests")) is True

    def test_returns_true_for_resource_exhausted(self):
        """A ClientError whose message contains 'resource_exhausted' is retryable."""
        with patch("aieng.agent_evals.knowledge_qa.retry.ClientError", FakeClientError):
            assert is_retryable_api_error(FakeClientError("RESOURCE_EXHAUSTED: API limit hit")) is True

    def test_returns_true_for_quota_error(self):
        """A ClientError whose message contains 'quota' is retryable."""
        with patch("aieng.agent_evals.knowledge_qa.retry.ClientError", FakeClientError):
            assert is_retryable_api_error(FakeClientError("quota limit reached for this project")) is True

    def test_returns_true_for_mixed_case_429(self):
        """Rate-limit detection matches '429' regardless of surrounding text case."""
        with patch("aieng.agent_evals.knowledge_qa.retry.ClientError", FakeClientError):
            assert is_retryable_api_error(FakeClientError("Rate limit error: status=429")) is True

    def test_returns_true_for_mixed_case_resource_exhausted(self):
        """Case-insensitive match for RESOURCE_EXHAUSTED errors."""
        with patch("aieng.agent_evals.knowledge_qa.retry.ClientError", FakeClientError):
            assert is_retryable_api_error(FakeClientError("resource_exhausted quota for gemini")) is True

    def test_returns_true_for_mixed_case_quota(self):
        """Case-insensitive match for QUOTA errors."""
        with patch("aieng.agent_evals.knowledge_qa.retry.ClientError", FakeClientError):
            assert is_retryable_api_error(FakeClientError("QUOTA_EXCEEDED for this project")) is True

    def test_returns_false_for_token_count_exceeds(self):
        """Context overflow ('token count exceeds') must NOT be retried."""
        with patch("aieng.agent_evals.knowledge_qa.retry.ClientError", FakeClientError):
            assert is_retryable_api_error(FakeClientError("Token count exceeds the context window")) is False

    def test_returns_false_for_invalid_argument_with_token(self):
        """invalid_argument errors involving tokens must NOT be retried."""
        with patch("aieng.agent_evals.knowledge_qa.retry.ClientError", FakeClientError):
            assert is_retryable_api_error(FakeClientError("INVALID_ARGUMENT: token limit exceeded")) is False

    def test_returns_false_for_cache_expired(self):
        """Cache expiration errors must NOT be retried."""
        with patch("aieng.agent_evals.knowledge_qa.retry.ClientError", FakeClientError):
            assert is_retryable_api_error(FakeClientError("Cache has expired for this request")) is False

    def test_returns_false_for_non_client_error_with_rate_limit_text(self):
        """Non-ClientError exceptions are never retryable, even with rate-limit keywords."""
        assert is_retryable_api_error(ValueError("rate limit 429")) is False

    def test_returns_false_for_base_exception(self):
        """A plain Exception is never retryable."""
        assert is_retryable_api_error(Exception("quota resource_exhausted")) is False

    def test_returns_false_for_runtime_error(self):
        """A RuntimeError with rate-limit text is never retryable."""
        assert is_retryable_api_error(RuntimeError("resource_exhausted quota exceeded")) is False

    def test_returns_false_for_other_client_error(self):
        """A ClientError without any retryable keywords is not retried."""
        with patch("aieng.agent_evals.knowledge_qa.retry.ClientError", FakeClientError):
            assert is_retryable_api_error(FakeClientError("Bad request: unknown field")) is False

    def test_returns_false_for_client_error_with_token_no_rate_limit(self):
        """A ClientError mentioning 'token' but no rate-limit keyword is not retried."""
        with patch("aieng.agent_evals.knowledge_qa.retry.ClientError", FakeClientError):
            assert is_retryable_api_error(FakeClientError("token refresh failed")) is False

    def test_context_overflow_takes_precedence_over_rate_limit(self):
        """The context-overflow early-exit wins even when rate-limit keywords are present."""
        with patch("aieng.agent_evals.knowledge_qa.retry.ClientError", FakeClientError):
            # Message matches both context overflow and rate limit patterns
            error = FakeClientError("token count exceeds limit, status 429 quota")
            assert is_retryable_api_error(error) is False

    def test_cache_expired_takes_precedence_over_rate_limit(self):
        """The cache-expiration early-exit wins even when rate-limit keywords are present."""
        with patch("aieng.agent_evals.knowledge_qa.retry.ClientError", FakeClientError):
            # Message matches both cache expiration and rate limit patterns
            error = FakeClientError("cache expired and quota resource_exhausted")
            assert is_retryable_api_error(error) is False

    def test_invalid_argument_without_token_does_not_block_rate_limit(self):
        """invalid_argument without 'token' does not suppress the rate-limit retry path."""
        with patch("aieng.agent_evals.knowledge_qa.retry.ClientError", FakeClientError):
            # no "token" → not context overflow → falls through to rate limit
            error = FakeClientError("INVALID_ARGUMENT: bad request quota 429")
            assert is_retryable_api_error(error) is True


class TestIsContextOverflowError:
    """Tests for the is_context_overflow_error function."""

    def test_returns_true_for_token_count_exceeds(self):
        """'token count exceeds' in the message marks a context overflow."""
        with patch("aieng.agent_evals.knowledge_qa.retry.ClientError", FakeClientError):
            assert is_context_overflow_error(FakeClientError("Token count exceeds the context window")) is True

    def test_returns_true_for_invalid_argument_with_token(self):
        """invalid_argument errors mentioning 'token' mark a context overflow."""
        with patch("aieng.agent_evals.knowledge_qa.retry.ClientError", FakeClientError):
            assert is_context_overflow_error(FakeClientError("INVALID_ARGUMENT: token limit exceeded")) is True

    def test_returns_true_for_mixed_case_token_count_exceeds(self):
        """Case-insensitive match for 'TOKEN COUNT EXCEEDS'."""
        with patch("aieng.agent_evals.knowledge_qa.retry.ClientError", FakeClientError):
            assert is_context_overflow_error(FakeClientError("TOKEN COUNT EXCEEDS maximum limit")) is True

    def test_returns_true_for_mixed_case_invalid_argument_token(self):
        """Case-insensitive match for INVALID_ARGUMENT + TOKEN."""
        with patch("aieng.agent_evals.knowledge_qa.retry.ClientError", FakeClientError):
            assert is_context_overflow_error(FakeClientError("INVALID_ARGUMENT: TOKEN limit exceeded")) is True

    def test_returns_false_for_rate_limit_429(self):
        """A 429 rate-limit error is not a context overflow."""
        with patch("aieng.agent_evals.knowledge_qa.retry.ClientError", FakeClientError):
            assert is_context_overflow_error(FakeClientError("429 Too Many Requests")) is False

    def test_returns_false_for_resource_exhausted(self):
        """A resource-exhausted error is not a context overflow."""
        with patch("aieng.agent_evals.knowledge_qa.retry.ClientError", FakeClientError):
            assert is_context_overflow_error(FakeClientError("RESOURCE_EXHAUSTED: quota exceeded")) is False

    def test_returns_false_for_cache_expired(self):
        """A cache-expiration error is not a context overflow."""
        with patch("aieng.agent_evals.knowledge_qa.retry.ClientError", FakeClientError):
            assert is_context_overflow_error(FakeClientError("cache has expired")) is False

    def test_returns_false_for_invalid_argument_without_token(self):
        """'invalid_argument' alone (no 'token') is not a context overflow."""
        with patch("aieng.agent_evals.knowledge_qa.retry.ClientError", FakeClientError):
            assert is_context_overflow_error(FakeClientError("INVALID_ARGUMENT: bad field value")) is False

    def test_returns_false_for_token_without_invalid_argument_or_token_count_exceeds(self):
        """'token' alone, without either matching pattern, is not a context overflow."""
        with patch("aieng.agent_evals.knowledge_qa.retry.ClientError", FakeClientError):
            assert is_context_overflow_error(FakeClientError("token refresh failed")) is False

    def test_returns_false_for_non_client_error_with_overflow_text(self):
        """Non-ClientError exceptions are never context overflows, even with overflow keywords."""
        assert is_context_overflow_error(ValueError("token count exceeds limit")) is False
test_returns_false_for_base_exception(self): + """Test returns False for plain Exception with context overflow text.""" + assert is_context_overflow_error(Exception("token count exceeds")) is False + + def test_returns_false_for_other_client_error(self): + """Test returns False for ClientError without context overflow indicators.""" + with patch("aieng.agent_evals.knowledge_qa.retry.ClientError", FakeClientError): + assert is_context_overflow_error(FakeClientError("Internal server error occurred")) is False diff --git a/aieng-eval-agents/tests/aieng/agent_evals/knowledge_qa/test_token_tracker.py b/aieng-eval-agents/tests/aieng/agent_evals/knowledge_qa/test_token_tracker.py new file mode 100644 index 0000000..8dea363 --- /dev/null +++ b/aieng-eval-agents/tests/aieng/agent_evals/knowledge_qa/test_token_tracker.py @@ -0,0 +1,378 @@ +"""Tests for token usage tracking.""" + +from unittest.mock import MagicMock, patch + +import pytest +from aieng.agent_evals.knowledge_qa.agent import KnowledgeGroundedAgent +from aieng.agent_evals.knowledge_qa.token_tracker import ( + DEFAULT_MODEL, + KNOWN_MODEL_LIMITS, + TokenTracker, + TokenUsage, +) +from google.adk.events import Event +from google.genai import types + + +def _make_event( + prompt: int = 0, + completion: int = 0, + total: int = 0, + cached: int = 0, +) -> Event: + """Build an ADK Event carrying usage_metadata.""" + return Event( + author="model", + usageMetadata=types.GenerateContentResponseUsageMetadata( + prompt_token_count=prompt, + candidates_token_count=completion, + total_token_count=total, + cached_content_token_count=cached, + ), + ) + + +def _make_tracker(model: str = "gemini-2.5-flash", context_limit: int = 1_000_000) -> TokenTracker: + """Build a TokenTracker with a mocked API call so no network I/O occurs.""" + mock_model_info = MagicMock() + mock_model_info.input_token_limit = context_limit + mock_client = MagicMock() + mock_client.models.get.return_value = mock_model_info + + with 
patch("aieng.agent_evals.knowledge_qa.token_tracker.Client", return_value=mock_client): + return TokenTracker(model=model) + + +# ============================================================================= +# TokenUsage model +# ============================================================================= + + +class TestTokenUsage: + """Tests for the TokenUsage Pydantic model.""" + + def test_defaults(self): + """Test all fields start at zero with a sensible context_limit default.""" + usage = TokenUsage() + assert usage.latest_prompt_tokens == 0 + assert usage.latest_cached_tokens == 0 + assert usage.total_prompt_tokens == 0 + assert usage.total_completion_tokens == 0 + assert usage.total_tokens == 0 + assert usage.context_limit == 1_000_000 + + def test_context_used_percent_proportional(self): + """Test context_used_percent equals prompt / limit * 100.""" + usage = TokenUsage(latest_prompt_tokens=250_000, context_limit=1_000_000) + assert usage.context_used_percent == pytest.approx(25.0) + + def test_context_used_percent_full(self): + """Test context_used_percent is 100 when prompt equals limit.""" + usage = TokenUsage(latest_prompt_tokens=500_000, context_limit=500_000) + assert usage.context_used_percent == pytest.approx(100.0) + + def test_context_used_percent_zero_limit(self): + """Test context_used_percent is 0.0 when context_limit is zero.""" + usage = TokenUsage(latest_prompt_tokens=1000, context_limit=0) + assert usage.context_used_percent == 0.0 + + def test_context_remaining_percent_complements_used(self): + """Test context_remaining_percent sums to 100 with context_used_percent.""" + usage = TokenUsage(latest_prompt_tokens=300_000, context_limit=1_000_000) + assert usage.context_used_percent + usage.context_remaining_percent == pytest.approx(100.0) + + def test_context_remaining_percent_clamps_at_zero(self): + """Test context_remaining_percent never goes negative when over limit.""" + usage = TokenUsage(latest_prompt_tokens=2_000_000, 
context_limit=1_000_000) + assert usage.context_remaining_percent == 0.0 + + def test_context_remaining_percent_zero_limit(self): + """Test context_remaining_percent is 100 when limit is zero (used = 0%).""" + usage = TokenUsage(context_limit=0) + assert usage.context_remaining_percent == pytest.approx(100.0) + + +# ============================================================================= +# TokenTracker initialisation & model limit fetching +# ============================================================================= + + +class TestTokenTrackerInit: + """Tests for TokenTracker initialisation and model limit resolution.""" + + def test_uses_api_limit_when_available(self): + """Test the context limit is taken from the API when it succeeds.""" + tracker = _make_tracker(model="gemini-2.5-pro", context_limit=2_000_000) + assert tracker.usage.context_limit == 2_000_000 + + def test_falls_back_to_known_limit_on_api_error(self): + """Test falls back to KNOWN_MODEL_LIMITS when the API call raises.""" + with patch("aieng.agent_evals.knowledge_qa.token_tracker.Client", side_effect=Exception("network error")): + tracker = TokenTracker(model="gemini-2.5-flash") + assert tracker.usage.context_limit == KNOWN_MODEL_LIMITS["gemini-2.5-flash"] + + def test_falls_back_to_default_for_unknown_model(self): + """Test uses TokenUsage default limit for a model not in KNOWN_MODEL_LIMITS.""" + with patch("aieng.agent_evals.knowledge_qa.token_tracker.Client", side_effect=Exception("network error")): + tracker = TokenTracker(model="gemini-unknown-model") + assert tracker.usage.context_limit == 1_000_000 # TokenUsage default + + def test_api_client_is_closed_after_successful_fetch(self): + """Test the Google API client is always closed after a successful fetch.""" + mock_model_info = MagicMock() + mock_model_info.input_token_limit = 500_000 + mock_client = MagicMock() + mock_client.models.get.return_value = mock_model_info + + with 
patch("aieng.agent_evals.knowledge_qa.token_tracker.Client", return_value=mock_client): + TokenTracker(model="gemini-2.5-flash") + + mock_client.close.assert_called_once() + + def test_api_client_is_closed_after_failed_fetch(self): + """Test the Google API client is closed even when models.get raises.""" + mock_client = MagicMock() + mock_client.models.get.side_effect = Exception("timeout") + + with patch("aieng.agent_evals.knowledge_qa.token_tracker.Client", return_value=mock_client): + TokenTracker(model="gemini-2.5-flash") + + mock_client.close.assert_called_once() + + def test_uses_default_model_when_none_given(self): + """Test the model defaults to DEFAULT_MODEL when none is provided.""" + mock_model_info = MagicMock() + mock_model_info.input_token_limit = 1_000_000 + mock_client = MagicMock() + mock_client.models.get.return_value = mock_model_info + + with patch("aieng.agent_evals.knowledge_qa.token_tracker.Client", return_value=mock_client): + tracker = TokenTracker() + + assert tracker._model == DEFAULT_MODEL + + def test_api_none_input_token_limit_falls_back_to_known(self): + """Test falls back to KNOWN_MODEL_LIMITS if API returns None for token limit.""" + mock_model_info = MagicMock() + mock_model_info.input_token_limit = None + mock_client = MagicMock() + mock_client.models.get.return_value = mock_model_info + + with patch("aieng.agent_evals.knowledge_qa.token_tracker.Client", return_value=mock_client): + tracker = TokenTracker(model="gemini-2.5-flash") + + assert tracker.usage.context_limit == KNOWN_MODEL_LIMITS["gemini-2.5-flash"] + + def test_initial_usage_all_zero(self): + """Test all token counts start at zero after initialisation.""" + tracker = _make_tracker() + u = tracker.usage + assert u.latest_prompt_tokens == 0 + assert u.latest_cached_tokens == 0 + assert u.total_prompt_tokens == 0 + assert u.total_completion_tokens == 0 + assert u.total_tokens == 0 + + +# ============================================================================= +# 
add_from_event +# ============================================================================= + + +class TestAddFromEvent: + """Tests for TokenTracker.add_from_event.""" + + def test_ignores_event_without_usage_metadata(self): + """Test that events without usage_metadata leave all counts at zero.""" + tracker = _make_tracker() + event = Event(author="model") + tracker.add_from_event(event) + assert tracker.usage.total_tokens == 0 + + def test_ignores_event_with_none_usage_metadata(self): + """Test that an event with usage_metadata=None leaves all counts at zero.""" + tracker = _make_tracker() + event = MagicMock(spec=[]) # no usage_metadata attribute at all + tracker.add_from_event(event) + assert tracker.usage.total_tokens == 0 + + def test_single_event_sets_latest_and_totals(self): + """Test a single event correctly populates both latest and cumulative fields.""" + tracker = _make_tracker() + tracker.add_from_event(_make_event(prompt=100, completion=50, total=150, cached=10)) + u = tracker.usage + assert u.latest_prompt_tokens == 100 + assert u.latest_cached_tokens == 10 + assert u.total_prompt_tokens == 100 + assert u.total_completion_tokens == 50 + assert u.total_tokens == 150 + + def test_latest_tokens_reflect_most_recent_event(self): + """Test latest_* fields are overwritten (not accumulated) on each event.""" + tracker = _make_tracker() + tracker.add_from_event(_make_event(prompt=100, completion=40, total=140, cached=5)) + tracker.add_from_event(_make_event(prompt=200, completion=60, total=260, cached=20)) + u = tracker.usage + # latest should reflect only the second event + assert u.latest_prompt_tokens == 200 + assert u.latest_cached_tokens == 20 + + def test_totals_accumulate_across_events(self): + """Test total_* fields accumulate across multiple events.""" + tracker = _make_tracker() + tracker.add_from_event(_make_event(prompt=100, completion=40, total=140)) + tracker.add_from_event(_make_event(prompt=200, completion=60, total=260)) + u = 
tracker.usage + assert u.total_prompt_tokens == 300 + assert u.total_completion_tokens == 100 + assert u.total_tokens == 400 + + def test_none_field_values_treated_as_zero(self): + """Test that None values in usage_metadata fields are treated as zero.""" + tracker = _make_tracker() + # Build an event whose metadata returns None for every attribute + mock_meta = MagicMock() + mock_meta.prompt_token_count = None + mock_meta.cached_content_token_count = None + mock_meta.candidates_token_count = None + mock_meta.total_token_count = None + + event = MagicMock() + event.usage_metadata = mock_meta + tracker.add_from_event(event) + + u = tracker.usage + assert u.latest_prompt_tokens == 0 + assert u.latest_cached_tokens == 0 + assert u.total_completion_tokens == 0 + assert u.total_tokens == 0 + + def test_context_used_percent_updates_after_event(self): + """Test context_used_percent reflects latest prompt tokens after an event.""" + tracker = _make_tracker(context_limit=1_000_000) + tracker.add_from_event(_make_event(prompt=500_000, total=500_000)) + assert tracker.usage.context_used_percent == pytest.approx(50.0) + + +# ============================================================================= +# reset +# ============================================================================= + + +class TestTokenTrackerReset: + """Tests for TokenTracker.reset.""" + + def test_reset_clears_all_counts(self): + """Test reset brings all token counts back to zero.""" + tracker = _make_tracker(context_limit=1_048_576) + tracker.add_from_event(_make_event(prompt=100, completion=50, total=150, cached=5)) + tracker.reset() + u = tracker.usage + assert u.latest_prompt_tokens == 0 + assert u.latest_cached_tokens == 0 + assert u.total_prompt_tokens == 0 + assert u.total_completion_tokens == 0 + assert u.total_tokens == 0 + + def test_reset_preserves_context_limit(self): + """Test reset keeps the context_limit that was fetched at initialisation.""" + tracker = 
_make_tracker(context_limit=1_048_576) + tracker.add_from_event(_make_event(prompt=100, total=100)) + tracker.reset() + assert tracker.usage.context_limit == 1_048_576 + + def test_accumulation_continues_correctly_after_reset(self): + """Test token counts accumulate normally after a reset.""" + tracker = _make_tracker() + tracker.add_from_event(_make_event(prompt=100, completion=40, total=140)) + tracker.reset() + tracker.add_from_event(_make_event(prompt=200, completion=60, total=260)) + u = tracker.usage + assert u.total_prompt_tokens == 200 + assert u.total_completion_tokens == 60 + assert u.total_tokens == 260 + + +# ============================================================================= +# Integration test — real Gemini model +# ============================================================================= + + +@pytest.mark.integration_test +class TestTokenTrackerIntegration: + """Integration tests that validate token tracking against the live Gemini API. + + Requires GOOGLE_API_KEY to be set in the environment / .env file. + Run with: cd aieng-eval-agents && + uv run --env-file ../.env pytest -m integration_test tests -v + """ + + def test_fetch_model_limits_from_real_api(self): + """Test that _fetch_model_limits contacts the real API and returns a limit.""" + tracker = TokenTracker(model="gemini-2.5-flash") + assert tracker.usage.context_limit > 0 + + @pytest.mark.asyncio + async def test_agent_populates_token_tracker_after_answer(self): + """Test that running a real agent query results in non-zero token counts. + + This end-to-end test exercises the full path: + Agent.answer_async() + -> Runner emits Event(usageMetadata=...) 
+ -> _process_event() + -> TokenTracker.add_from_event() + -> usage fields updated + """ + agent = KnowledgeGroundedAgent(enable_planning=False, enable_caching=False, enable_compaction=False) + await agent.answer_async("What is the capital of France?") + + usage = agent.token_tracker.usage + # The API must have returned prompt tokens for at least one call + assert usage.total_prompt_tokens > 0, "expected prompt tokens to be tracked" + assert usage.total_completion_tokens > 0, "expected completion tokens to be tracked" + assert usage.total_tokens > 0, "expected total tokens to be tracked" + + # latest_prompt_tokens should equal the last event's prompt count, which is + # non-zero for any model call; it should also be <= total + assert usage.latest_prompt_tokens > 0 + assert usage.latest_prompt_tokens <= usage.total_prompt_tokens + + @pytest.mark.asyncio + async def test_context_used_percent_is_sensible_after_answer(self): + """Test that context_used_percent is a small positive fraction after a query.""" + agent = KnowledgeGroundedAgent(enable_planning=False, enable_caching=False, enable_compaction=False) + await agent.answer_async("What is 2 + 2?") + + usage = agent.token_tracker.usage + # A short query must use some context but nowhere near the full window + assert 0.0 < usage.context_used_percent < 50.0 + assert usage.context_remaining_percent == pytest.approx(100.0 - usage.context_used_percent) + + @pytest.mark.asyncio + async def test_reset_clears_tracking_between_agent_calls(self): + """Test that agent.reset() zeroes the token tracker before a conversation.""" + agent = KnowledgeGroundedAgent(enable_planning=False, enable_caching=False, enable_compaction=False) + await agent.answer_async("What is the capital of France?") + + assert agent.token_tracker.usage.total_tokens > 0 + + agent.reset() + + assert agent.token_tracker.usage.total_prompt_tokens == 0 + assert agent.token_tracker.usage.total_completion_tokens == 0 + assert 
agent.token_tracker.usage.total_tokens == 0 + # Context limit must survive the reset + assert agent.token_tracker.usage.context_limit > 0 + + @pytest.mark.asyncio + async def test_second_call_accumulates_on_top_of_first(self): + """Test that totals accumulate across two successive answer_async calls.""" + agent = KnowledgeGroundedAgent(enable_planning=False, enable_caching=False, enable_compaction=False) + + await agent.answer_async("What is the capital of France?") + tokens_after_first = agent.token_tracker.usage.total_tokens + + await agent.answer_async("What is the capital of Germany?") + tokens_after_second = agent.token_tracker.usage.total_tokens + + assert tokens_after_second > tokens_after_first, "total_tokens should grow after a second answer_async call" diff --git a/pyproject.toml b/pyproject.toml index c4eab2a..938b101 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ dependencies = [ "virtualenv>=20.36.1", "tenacity>=9.1.2", "certifi>=2026.1.4", + "pypdf>=6.7.3", ] [dependency-groups] diff --git a/uv.lock b/uv.lock index db2527a..083009f 100644 --- a/uv.lock +++ b/uv.lock @@ -40,6 +40,7 @@ dependencies = [ { name = "pydantic" }, { name = "pydantic-ai-slim", extra = ["logfire"] }, { name = "pydantic-settings" }, + { name = "pypdf" }, { name = "scikit-learn" }, { name = "tenacity" }, { name = "urllib3" }, @@ -104,6 +105,7 @@ requires-dist = [ { name = "pydantic", specifier = ">=2.12.4" }, { name = "pydantic-ai-slim", extras = ["logfire"], specifier = ">=1.26.0" }, { name = "pydantic-settings", specifier = ">=2.7.0" }, + { name = "pypdf", specifier = ">=6.7.3" }, { name = "scikit-learn", specifier = ">=1.7.0" }, { name = "tenacity", specifier = ">=9.1.2" }, { name = "urllib3", specifier = ">=2.6.3" }, @@ -5021,11 +5023,11 @@ wheels = [ [[package]] name = "pypdf" -version = "6.7.2" +version = "6.7.4" source = { registry = "https://pypi.org/simple" } -sdist = { url = 
"https://files.pythonhosted.org/packages/fe/b2/335465d6cff28a772ace8a58beb168f125c2e1d8f7a31527da180f4d89a1/pypdf-6.7.2.tar.gz", hash = "sha256:82a1a48de500ceea59a52a7d979f5095927ef802e4e4fac25ab862a73468acbb", size = 5302986 } +sdist = { url = "https://files.pythonhosted.org/packages/09/dc/f52deef12797ad58b88e4663f097a343f53b9361338aef6573f135ac302f/pypdf-6.7.4.tar.gz", hash = "sha256:9edd1cd47938bb35ec87795f61225fd58a07cfaf0c5699018ae1a47d6f8ab0e3", size = 5304821 } wheels = [ - { url = "https://files.pythonhosted.org/packages/df/df/38b06d6e74646a4281856920a11efb431559bdeb643bf1e192bff5e29082/pypdf-6.7.2-py3-none-any.whl", hash = "sha256:331b63cd66f63138f152a700565b3e0cebdf4ec8bec3b7594b2522418782f1f3", size = 331245 }, + { url = "https://files.pythonhosted.org/packages/c1/be/cded021305f5c81b47265b8c5292b99388615a4391c21ff00fd538d34a56/pypdf-6.7.4-py3-none-any.whl", hash = "sha256:527d6da23274a6c70a9cb59d1986d93946ba8e36a6bc17f3f7cce86331492dda", size = 331496 }, ] [[package]]