From ecd407c8f75c44b8426d8e9e25622f3d651b6429 Mon Sep 17 00:00:00 2001 From: !link Date: Sun, 8 Mar 2026 17:27:10 -0700 Subject: [PATCH 01/11] Add cache token tracking fields to LLMResponse --- evoforge/llm/client.py | 2 ++ tests/test_core/test_llm_client.py | 23 +++++++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/evoforge/llm/client.py b/evoforge/llm/client.py index 9c0f868..2a7e563 100644 --- a/evoforge/llm/client.py +++ b/evoforge/llm/client.py @@ -40,6 +40,8 @@ class LLMResponse: input_tokens: int output_tokens: int model: str + cache_read_tokens: int = 0 + cache_creation_tokens: int = 0 class LLMClient: diff --git a/tests/test_core/test_llm_client.py b/tests/test_core/test_llm_client.py index ef13179..4b73c47 100644 --- a/tests/test_core/test_llm_client.py +++ b/tests/test_core/test_llm_client.py @@ -80,3 +80,26 @@ def test_haiku_pricing(self) -> None: def test_unknown_model_uses_sonnet_default(self) -> None: cost = LLMClient.estimate_cost(1_000_000, 1_000_000, "unknown-model") assert cost == pytest.approx(3.0 + 15.0) + + +class TestLLMResponseCacheFields: + def test_default_cache_fields_are_zero(self) -> None: + from evoforge.llm.client import LLMResponse + + r = LLMResponse(text="hi", input_tokens=10, output_tokens=5, model="test") + assert r.cache_read_tokens == 0 + assert r.cache_creation_tokens == 0 + + def test_cache_fields_can_be_set(self) -> None: + from evoforge.llm.client import LLMResponse + + r = LLMResponse( + text="hi", + input_tokens=10, + output_tokens=5, + model="test", + cache_read_tokens=100, + cache_creation_tokens=50, + ) + assert r.cache_read_tokens == 100 + assert r.cache_creation_tokens == 50 From 7dde4be52b451c573f19ee259a9b181af3f0b181 Mon Sep 17 00:00:00 2001 From: !link Date: Sun, 8 Mar 2026 17:28:22 -0700 Subject: [PATCH 02/11] Update cost estimation to account for cache read/write pricing --- evoforge/llm/client.py | 16 +++++++++++-- tests/test_core/test_llm_client.py | 36 ++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 2 deletions(-) diff --git a/evoforge/llm/client.py b/evoforge/llm/client.py index 2a7e563..1cd8e8c 100644 --- a/evoforge/llm/client.py +++ b/evoforge/llm/client.py @@ -164,7 +164,19 @@ async def async_generate( raise RuntimeError(msg) from last_exc @staticmethod - def estimate_cost(input_tokens: int, output_tokens: int, model: str) -> float: + def estimate_cost( + input_tokens: int, + output_tokens: int, + model: str, + *, + cache_read_tokens: int = 0, + cache_creation_tokens: int = 0, + ) -> float: """Estimate USD cost for the given token counts and model.""" input_rate, output_rate = _pricing_for_model(model) - return (input_tokens * input_rate + output_tokens * output_rate) / 1_000_000 + return ( + input_tokens * input_rate + + output_tokens * output_rate + + cache_read_tokens * input_rate * 0.1 + + cache_creation_tokens * input_rate * 1.25 + ) / 1_000_000 diff --git a/tests/test_core/test_llm_client.py b/tests/test_core/test_llm_client.py index 4b73c47..9167cfe 100644 --- a/tests/test_core/test_llm_client.py +++ b/tests/test_core/test_llm_client.py @@ -82,6 +82,42 @@ def test_unknown_model_uses_sonnet_default(self) -> None: assert cost == pytest.approx(3.0 + 15.0) +class TestCostEstimationWithCache: + def test_cache_read_tokens_at_10_percent(self) -> None: + cost = LLMClient.estimate_cost( + input_tokens=0, + output_tokens=0, + model="claude-sonnet-4-5-20250929", + cache_read_tokens=1_000_000, + cache_creation_tokens=0, + ) + assert cost == pytest.approx(0.30) + + def 
test_cache_creation_tokens_at_125_percent(self) -> None: + cost = LLMClient.estimate_cost( + input_tokens=0, + output_tokens=0, + model="claude-sonnet-4-5-20250929", + cache_read_tokens=0, + cache_creation_tokens=1_000_000, + ) + assert cost == pytest.approx(3.75) + + def test_mixed_cache_and_regular_tokens(self) -> None: + cost = LLMClient.estimate_cost( + input_tokens=100_000, + output_tokens=50_000, + model="claude-sonnet-4-5-20250929", + cache_read_tokens=500_000, + cache_creation_tokens=200_000, + ) + assert cost == pytest.approx(0.30 + 0.75 + 0.15 + 0.75) + + def test_existing_cost_estimation_unchanged(self) -> None: + cost = LLMClient.estimate_cost(1_000_000, 1_000_000, "claude-haiku-4-5-20251001") + assert cost == pytest.approx(0.25 + 1.25) + + class TestLLMResponseCacheFields: def test_default_cache_fields_are_zero(self) -> None: from evoforge.llm.client import LLMResponse From de185033bb02757ac6d21f8595b80442920237e7 Mon Sep 17 00:00:00 2001 From: !link Date: Sun, 8 Mar 2026 18:19:47 -0700 Subject: [PATCH 03/11] Add prompt caching support to LLMClient --- evoforge/llm/client.py | 33 +++++++++++- tests/test_core/test_llm_client.py | 84 ++++++++++++++++++++++++++++++ 2 files changed, 115 insertions(+), 2 deletions(-) diff --git a/evoforge/llm/client.py b/evoforge/llm/client.py index 1cd8e8c..0780e66 100644 --- a/evoforge/llm/client.py +++ b/evoforge/llm/client.py @@ -8,8 +8,10 @@ import random import time from dataclasses import dataclass +from typing import Any import anthropic +from anthropic.types import TextBlockParam logger = logging.getLogger(__name__) @@ -53,14 +55,35 @@ def __init__( max_retries: int = 6, base_delay: float = 2.0, max_delay: float = 120.0, + prompt_caching: bool = True, ) -> None: self._api_key = api_key self._max_retries = max_retries self._base_delay = base_delay self._max_delay = max_delay + self._prompt_caching = prompt_caching self._sync_client: anthropic.Anthropic | None = None self._async_client: anthropic.AsyncAnthropic | None = None + def _format_system(self, system: str) -> str | list[TextBlockParam]: + """Format system prompt, optionally adding cache_control for prompt caching.""" + if not self._prompt_caching: + return system + return [ + TextBlockParam( + type="text", + text=system, + cache_control={"type": "ephemeral"}, + ) + ] + + @staticmethod + def _extract_cache_tokens(usage: Any) -> tuple[int, int]: + """Extract cache token counts from API usage, defaulting to 0.""" + cache_read = getattr(usage, "cache_read_input_tokens", None) or 0 + cache_creation = getattr(usage, "cache_creation_input_tokens", None) or 0 + return cache_read, cache_creation + def _compute_delay(self, attempt: int) -> float: """Compute retry delay with exponential backoff, jitter, and cap.""" delay = self._base_delay * (2**attempt) @@ -95,15 +118,18 @@ def generate( model=model, max_tokens=max_tokens, temperature=temperature, - system=system, + system=self._format_system(system), messages=[{"role": "user", "content": prompt}], ) text = response.content[0].text # type: ignore[union-attr] + cache_read, cache_creation = self._extract_cache_tokens(response.usage) return LLMResponse( text=text, input_tokens=response.usage.input_tokens, output_tokens=response.usage.output_tokens, model=model, + cache_read_tokens=cache_read, + cache_creation_tokens=cache_creation, ) except (anthropic.RateLimitError, anthropic.APIError) as exc: last_exc = exc @@ -138,15 +164,18 @@ async def async_generate( model=model, max_tokens=max_tokens, temperature=temperature, - system=system, + 
system=self._format_system(system), messages=[{"role": "user", "content": prompt}], ) text = response.content[0].text # type: ignore[union-attr] + cache_read, cache_creation = self._extract_cache_tokens(response.usage) return LLMResponse( text=text, input_tokens=response.usage.input_tokens, output_tokens=response.usage.output_tokens, model=model, + cache_read_tokens=cache_read, + cache_creation_tokens=cache_creation, ) except (anthropic.RateLimitError, anthropic.APIError) as exc: last_exc = exc diff --git a/tests/test_core/test_llm_client.py b/tests/test_core/test_llm_client.py index 9167cfe..1a9bdee 100644 --- a/tests/test_core/test_llm_client.py +++ b/tests/test_core/test_llm_client.py @@ -139,3 +139,87 @@ def test_cache_fields_can_be_set(self) -> None: ) assert r.cache_read_tokens == 100 assert r.cache_creation_tokens == 50 + + +class TestPromptCaching: + async def test_async_generate_sends_cache_control_when_enabled(self) -> None: + client = LLMClient(api_key="test", prompt_caching=True) + + mock_response = MagicMock() + mock_response.content = [MagicMock(text="result")] + mock_response.usage.input_tokens = 10 + mock_response.usage.output_tokens = 5 + mock_response.usage.cache_read_input_tokens = 100 + mock_response.usage.cache_creation_input_tokens = 0 + + with patch("anthropic.AsyncAnthropic") as mock_cls: + mock_instance = AsyncMock() + mock_instance.messages.create = AsyncMock(return_value=mock_response) + mock_cls.return_value = mock_instance + + result = await client.async_generate("prompt", "system text", "haiku", 0.7) + + call_kwargs = mock_instance.messages.create.call_args[1] + assert call_kwargs["system"] == [ + { + "type": "text", + "text": "system text", + "cache_control": {"type": "ephemeral"}, + } + ] + assert result.cache_read_tokens == 100 + assert result.cache_creation_tokens == 0 + + async def test_async_generate_no_cache_control_when_disabled(self) -> None: + client = LLMClient(api_key="test", prompt_caching=False) + + mock_response = MagicMock() + mock_response.content = [MagicMock(text="result")] + mock_response.usage.input_tokens = 10 + mock_response.usage.output_tokens = 5 + mock_response.usage.cache_read_input_tokens = None + mock_response.usage.cache_creation_input_tokens = None + + with patch("anthropic.AsyncAnthropic") as mock_cls: + mock_instance = AsyncMock() + mock_instance.messages.create = AsyncMock(return_value=mock_response) + mock_cls.return_value = mock_instance + + result = await client.async_generate("prompt", "system text", "haiku", 0.7) + + call_kwargs = mock_instance.messages.create.call_args[1] + assert call_kwargs["system"] == "system text" + assert result.cache_read_tokens == 0 + assert result.cache_creation_tokens == 0 + + def test_sync_generate_sends_cache_control_when_enabled(self) -> None: + client = LLMClient(api_key="test", prompt_caching=True) + + mock_response = MagicMock() + mock_response.content = [MagicMock(text="result")] + mock_response.usage.input_tokens = 10 + mock_response.usage.output_tokens = 5 + mock_response.usage.cache_read_input_tokens = 50 + mock_response.usage.cache_creation_input_tokens = 200 + + with patch("anthropic.Anthropic") as mock_cls: + mock_instance = MagicMock() + mock_instance.messages.create = MagicMock(return_value=mock_response) + mock_cls.return_value = mock_instance + + result = client.generate("prompt", "system text", "haiku", 0.7) + + call_kwargs = mock_instance.messages.create.call_args[1] + assert call_kwargs["system"] == [ + { + "type": "text", + "text": "system text", + "cache_control": {"type": 
"ephemeral"}, + } + ] + assert result.cache_read_tokens == 50 + assert result.cache_creation_tokens == 200 + + async def test_default_prompt_caching_is_true(self) -> None: + client = LLMClient(api_key="test") + assert client._prompt_caching is True From 11b8fca0fdfdd8595c50380ee53f8de86a4c8e58 Mon Sep 17 00:00:00 2001 From: !link Date: Sun, 8 Mar 2026 18:21:26 -0700 Subject: [PATCH 04/11] Add prompt_caching toggle to LLMConfig --- evoforge/core/config.py | 1 + tests/test_core/test_llm_client.py | 14 ++++++++++++++ 2 files changed, 15 insertions(+) diff --git a/evoforge/core/config.py b/evoforge/core/config.py index b240591..e9837a3 100644 --- a/evoforge/core/config.py +++ b/evoforge/core/config.py @@ -55,6 +55,7 @@ class LLMConfig(BaseModel): max_calls: int = 1000 max_cost_usd: float = 50.0 max_attempts: int = 3 + prompt_caching: bool = True class EvalConfig(BaseModel): diff --git a/tests/test_core/test_llm_client.py b/tests/test_core/test_llm_client.py index 1a9bdee..6b791c5 100644 --- a/tests/test_core/test_llm_client.py +++ b/tests/test_core/test_llm_client.py @@ -223,3 +223,17 @@ def test_sync_generate_sends_cache_control_when_enabled(self) -> None: async def test_default_prompt_caching_is_true(self) -> None: client = LLMClient(api_key="test") assert client._prompt_caching is True + + +class TestLLMConfigCacheFields: + def test_prompt_caching_defaults_true(self) -> None: + from evoforge.core.config import LLMConfig + + cfg = LLMConfig() + assert cfg.prompt_caching is True + + def test_prompt_caching_can_be_disabled(self) -> None: + from evoforge.core.config import LLMConfig + + cfg = LLMConfig(prompt_caching=False) + assert cfg.prompt_caching is False From 014caf146f8518de8207c20a831a9dc267c0f93e Mon Sep 17 00:00:00 2001 From: !link Date: Sun, 8 Mar 2026 18:28:57 -0700 Subject: [PATCH 05/11] Add BatchCollector for batching LLM requests via Message Batch API Introduces BatchCollector async context manager that collects LLM requests during a generation and submits them as a single Anthropic Message Batch. Falls back to individual async_generate calls on batch submission failure. Uses ContextVar for transparent access. --- evoforge/llm/batch.py | 127 +++++++++++++++++++++++ tests/test_core/test_batch.py | 184 ++++++++++++++++++++++++++++++++++ 2 files changed, 311 insertions(+) create mode 100644 evoforge/llm/batch.py create mode 100644 tests/test_core/test_batch.py diff --git a/evoforge/llm/batch.py b/evoforge/llm/batch.py new file mode 100644 index 0000000..b8b1f40 --- /dev/null +++ b/evoforge/llm/batch.py @@ -0,0 +1,127 @@ +# Copyright (c) 2026 evocode contributors. MIT License. See LICENSE. 
+"""Batch API support for collecting and submitting LLM requests as a Message Batch.""" + +from __future__ import annotations + +import asyncio +import logging +from contextvars import ContextVar +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from evoforge.llm.client import LLMClient, LLMResponse + +logger = logging.getLogger(__name__) + +_active_collector: ContextVar[BatchCollector | None] = ContextVar( + "_active_collector", default=None +) + + +def get_batch_collector() -> BatchCollector | None: + """Return the active BatchCollector, or None if not inside a batch context.""" + return _active_collector.get() + + +class BatchCollector: + """Async context manager that collects LLM requests and submits them as a batch.""" + + def __init__(self, client: LLMClient, poll_interval: float = 2.0) -> None: + self._client = client + self._poll_interval = poll_interval + self._requests: list[tuple[str, str, str, float, int]] = [] + self._futures: list[asyncio.Future[LLMResponse | None]] = [] + self._token: Any = None + + def register( + self, prompt: str, system: str, model: str, temperature: float, max_tokens: int + ) -> asyncio.Future[LLMResponse | None]: + """Register a request and return a future resolved with the result.""" + loop = asyncio.get_running_loop() + future: asyncio.Future[LLMResponse | None] = loop.create_future() + self._requests.append((prompt, system, model, temperature, max_tokens)) + self._futures.append(future) + return future + + async def __aenter__(self) -> BatchCollector: + self._token = _active_collector.set(self) + return self + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + _active_collector.reset(self._token) + if not self._requests: + return + try: + await self._submit_and_resolve() + except Exception: + logger.warning( + "Batch submission failed, falling back to individual calls", + exc_info=True, + ) + await self._fallback_individual() + + async def _submit_and_resolve(self) -> None: + from anthropic.types.message_create_params import MessageCreateParamsNonStreaming + from anthropic.types.messages.batch_create_params import Request + + from evoforge.llm.client import LLMResponse + + async_client = self._client._get_async_client() + + batch_requests = [] + for i, (prompt, system, model, temperature, max_tokens) in enumerate(self._requests): + system_param = self._client._format_system(system) + batch_requests.append( + Request( + custom_id=f"req-{i}", + params=MessageCreateParamsNonStreaming( + model=model, + max_tokens=max_tokens, + temperature=temperature, + system=system_param, + messages=[{"role": "user", "content": prompt}], + ), + ) + ) + + batch = await async_client.messages.batches.create(requests=batch_requests) + batch_id = batch.id + + while batch.processing_status != "ended": + await asyncio.sleep(self._poll_interval) + batch = await async_client.messages.batches.retrieve(batch_id) + + results_by_id: dict[str, LLMResponse | None] = {} + async for result in await async_client.messages.batches.results(batch_id): + if result.result.type == "succeeded": + msg = result.result.message + text = msg.content[0].text # type: ignore[union-attr] + cache_read = getattr(msg.usage, "cache_read_input_tokens", None) or 0 + cache_creation = getattr(msg.usage, "cache_creation_input_tokens", None) or 0 + results_by_id[result.custom_id] = LLMResponse( + text=text, + input_tokens=msg.usage.input_tokens, + output_tokens=msg.usage.output_tokens, + model=msg.model, + cache_read_tokens=cache_read, + 
cache_creation_tokens=cache_creation, + ) + else: + logger.warning("Batch request %s: %s", result.custom_id, result.result.type) + results_by_id[result.custom_id] = None + + for i, future in enumerate(self._futures): + req_id = f"req-{i}" + future.set_result(results_by_id.get(req_id)) + + async def _fallback_individual(self) -> None: + for i, (prompt, system, model, temperature, max_tokens) in enumerate(self._requests): + try: + result = await self._client.async_generate( + prompt, system, model, temperature, max_tokens + ) + self._futures[i].set_result(result) + except Exception: + logger.warning("Fallback call %d failed", i, exc_info=True) + if not self._futures[i].done(): + self._futures[i].set_result(None) diff --git a/tests/test_core/test_batch.py b/tests/test_core/test_batch.py new file mode 100644 index 0000000..198e716 --- /dev/null +++ b/tests/test_core/test_batch.py @@ -0,0 +1,184 @@ +# Copyright (c) 2026 evocode contributors. MIT License. See LICENSE. +"""Tests for BatchCollector async context manager.""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from evoforge.llm.batch import BatchCollector, get_batch_collector +from evoforge.llm.client import LLMClient, LLMResponse + + +def _make_succeeded_result(custom_id: str, text: str, model: str = "test-model") -> MagicMock: + """Build a mock batch result with result.type == 'succeeded'.""" + result = MagicMock() + result.custom_id = custom_id + result.result.type = "succeeded" + msg = result.result.message + msg.content = [MagicMock(text=text)] + msg.usage.input_tokens = 10 + msg.usage.output_tokens = 5 + msg.usage.cache_read_input_tokens = 2 + msg.usage.cache_creation_input_tokens = 1 + msg.model = model + return result + + +def _make_errored_result(custom_id: str) -> MagicMock: + """Build a mock batch result with result.type == 'errored'.""" + result = MagicMock() + result.custom_id = custom_id + result.result.type = "errored" + return result + + +class _AsyncIterList: + """Wrap a list as an async iterable for mocking batch results.""" + + def __init__(self, items: list[Any]) -> None: + self._items = items + + def __aiter__(self) -> _AsyncIterList: + self._index = 0 + return self + + async def __anext__(self) -> Any: + if self._index >= len(self._items): + raise StopAsyncIteration + item = self._items[self._index] + self._index += 1 + return item + + +def _make_batch_mock(processing_status: str = "ended") -> MagicMock: + """Build a mock batch object.""" + batch = MagicMock() + batch.id = "batch_test_123" + batch.processing_status = processing_status + return batch + + +def _make_client_with_mock() -> tuple[LLMClient, MagicMock]: + """Create an LLMClient with a mocked async client.""" + client = LLMClient(api_key="test-key", prompt_caching=False) + mock_async = AsyncMock() + client._async_client = mock_async + return client, mock_async + + +class TestBatchCollectorCollectsRequests: + """Register 2 requests, mock batch create/retrieve/results, verify futures resolve.""" + + @pytest.mark.asyncio + async def test_two_requests_resolve(self) -> None: + client, mock_async = _make_client_with_mock() + + # Batch create returns an already-ended batch + batch = _make_batch_mock("ended") + mock_async.messages.batches.create = AsyncMock(return_value=batch) + + # Results: two succeeded + r0 = _make_succeeded_result("req-0", "response zero") + r1 = _make_succeeded_result("req-1", "response one") + mock_async.messages.batches.results = 
AsyncMock(return_value=_AsyncIterList([r0, r1])) + + async with BatchCollector(client) as collector: + f0 = collector.register("prompt0", "system", "test-model", 0.7, 1024) + f1 = collector.register("prompt1", "system", "test-model", 0.5, 2048) + + # Futures should be resolved + result0 = f0.result() + result1 = f1.result() + + assert isinstance(result0, LLMResponse) + assert result0.text == "response zero" + assert result0.input_tokens == 10 + assert result0.output_tokens == 5 + assert result0.cache_read_tokens == 2 + assert result0.cache_creation_tokens == 1 + + assert isinstance(result1, LLMResponse) + assert result1.text == "response one" + + # Verify batch create was called with 2 requests + mock_async.messages.batches.create.assert_awaited_once() + call_kwargs = mock_async.messages.batches.create.call_args + assert len(call_kwargs.kwargs["requests"]) == 2 + + +class TestBatchCollectorEmpty: + """Enter/exit context with no registrations -> create never called.""" + + @pytest.mark.asyncio + async def test_no_requests_no_create(self) -> None: + client, mock_async = _make_client_with_mock() + + async with BatchCollector(client): + pass + + mock_async.messages.batches.create.assert_not_awaited() + + +class TestBatchCollectorFallback: + """Mock create to raise -> verify async_generate called as fallback.""" + + @pytest.mark.asyncio + async def test_fallback_on_create_failure(self) -> None: + client, mock_async = _make_client_with_mock() + + # Batch create raises + mock_async.messages.batches.create = AsyncMock(side_effect=RuntimeError("API down")) + + # Mock async_generate for fallback + fallback_response = LLMResponse( + text="fallback", input_tokens=5, output_tokens=3, model="test-model" + ) + with patch.object(client, "async_generate", new=AsyncMock(return_value=fallback_response)): + async with BatchCollector(client) as collector: + f0 = collector.register("prompt0", "system", "test-model", 0.7, 1024) + f1 = collector.register("prompt1", "system", "test-model", 0.5, 2048) + + assert f0.result() == fallback_response + assert f1.result() == fallback_response + assert client.async_generate.call_count == 2 # type: ignore[union-attr] + + +class TestBatchCollectorPerRequestError: + """One request errored -> that future resolves to None, other succeeds.""" + + @pytest.mark.asyncio + async def test_errored_request_resolves_none(self) -> None: + client, mock_async = _make_client_with_mock() + + batch = _make_batch_mock("ended") + mock_async.messages.batches.create = AsyncMock(return_value=batch) + + r0 = _make_succeeded_result("req-0", "good response") + r1 = _make_errored_result("req-1") + mock_async.messages.batches.results = AsyncMock(return_value=_AsyncIterList([r0, r1])) + + async with BatchCollector(client) as collector: + f0 = collector.register("prompt0", "system", "test-model", 0.7, 1024) + f1 = collector.register("prompt1", "system", "test-model", 0.5, 2048) + + assert isinstance(f0.result(), LLMResponse) + assert f0.result().text == "good response" + assert f1.result() is None + + +class TestGetBatchCollector: + """Returns None outside context, collector inside, None after exit.""" + + @pytest.mark.asyncio + async def test_context_var_lifecycle(self) -> None: + client, mock_async = _make_client_with_mock() + + assert get_batch_collector() is None + + async with BatchCollector(client) as collector: + assert get_batch_collector() is collector + + assert get_batch_collector() is None From a5739ecbf23a27ef79cbe52bfbc76754fe91f440 Mon Sep 17 00:00:00 2001 From: !link Date: Sun, 8 Mar 2026 
18:32:48 -0700 Subject: [PATCH 06/11] Wire LLM operators to use BatchCollector when active --- evoforge/llm/operators.py | 51 ++++++++--- tests/test_core/test_llm_operators.py | 126 +++++++++++++++++++++++++- 2 files changed, 162 insertions(+), 15 deletions(-) diff --git a/evoforge/llm/operators.py b/evoforge/llm/operators.py index 0b02958..0cef63d 100644 --- a/evoforge/llm/operators.py +++ b/evoforge/llm/operators.py @@ -13,6 +13,7 @@ from evoforge.core.mutation import MutationContext, MutationOperator from evoforge.core.types import Individual +from evoforge.llm.batch import get_batch_collector logger = logging.getLogger(__name__) @@ -43,13 +44,24 @@ async def apply(self, parent: Individual, context: MutationContext) -> str: prompt = context.backend.format_mutation_prompt(parent, context) system = context.backend.system_prompt() - response = await self._client.async_generate( - prompt, - system, - self._model, - context.temperature, - self._max_tokens, - ) + collector = get_batch_collector() + if collector is not None: + future = collector.register( + prompt, system, self._model, context.temperature, self._max_tokens + ) + response = await future + else: + response = await self._client.async_generate( + prompt, + system, + self._model, + context.temperature, + self._max_tokens, + ) + + if response is None: + logger.warning("LLMMutate: batch request failed, falling back to parent") + return parent.genome genome: str | None = context.backend.extract_genome(response.text) if genome is not None: @@ -91,13 +103,24 @@ async def apply(self, parent: Individual, context: MutationContext) -> str: system = context.backend.system_prompt() - response = await self._client.async_generate( - prompt, - system, - self._model, - context.temperature, - self._max_tokens, - ) + collector = get_batch_collector() + if collector is not None: + future = collector.register( + prompt, system, self._model, context.temperature, self._max_tokens + ) + response = await future + else: + response = await self._client.async_generate( + prompt, + system, + self._model, + context.temperature, + self._max_tokens, + ) + + if response is None: + logger.warning("LLMCrossover: batch request failed, falling back to parent") + return parent.genome genome: str | None = context.backend.extract_genome(response.text) if genome is not None: diff --git a/tests/test_core/test_llm_operators.py b/tests/test_core/test_llm_operators.py index f5c0286..6bc81d4 100644 --- a/tests/test_core/test_llm_operators.py +++ b/tests/test_core/test_llm_operators.py @@ -4,12 +4,13 @@ from __future__ import annotations from typing import Any -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock, MagicMock, patch import pytest from evoforge.core.mutation import MutationContext from evoforge.core.types import Credit +from evoforge.llm.client import LLMResponse from evoforge.llm.operators import LLMCrossover, LLMMutate from tests.conftest import FakeLLMResponse, make_individual @@ -145,3 +146,126 @@ async def test_falls_back_to_parent_genome(self) -> None: result = await op.apply(parent, ctx) assert result == "original_genome" # falls back to parent + + +# ------------------------------------------------------------------ # +# LLMMutate batch context tests +# ------------------------------------------------------------------ # + + +class TestLLMMutateWithBatchContext: + @pytest.mark.asyncio + async def test_registers_with_batch_when_active(self) -> None: + import asyncio + + client = MagicMock() + client.async_generate = 
AsyncMock() + + mock_collector = MagicMock() + future: asyncio.Future[LLMResponse | None] = asyncio.get_event_loop().create_future() + future.set_result(FakeLLMResponse(text="```lean\nbatched_genome\n```")) + mock_collector.register = MagicMock(return_value=future) + + backend = MagicMock() + backend.format_mutation_prompt.return_value = "mutation prompt" + backend.system_prompt.return_value = "system" + backend.extract_genome.return_value = "batched_genome" + ctx = _make_context(backend=backend) + parent = make_individual("original") + + op = LLMMutate(client=client, model="test-model") + + with patch("evoforge.llm.operators.get_batch_collector", return_value=mock_collector): + result = await op.apply(parent, ctx) + + mock_collector.register.assert_called_once_with( + "mutation prompt", + "system", + "test-model", + 0.7, + 4096, + ) + client.async_generate.assert_not_called() + assert result == "batched_genome" + + @pytest.mark.asyncio + async def test_direct_call_when_no_batch(self) -> None: + client = MagicMock() + client.async_generate = AsyncMock( + return_value=FakeLLMResponse(text="```lean\ndirect_genome\n```") + ) + backend = MagicMock() + backend.format_mutation_prompt.return_value = "mutation prompt" + backend.system_prompt.return_value = "system" + backend.extract_genome.return_value = "direct_genome" + ctx = _make_context(backend=backend) + parent = make_individual("original") + + op = LLMMutate(client=client, model="test-model") + + with patch("evoforge.llm.operators.get_batch_collector", return_value=None): + result = await op.apply(parent, ctx) + + client.async_generate.assert_called_once() + assert result == "direct_genome" + + @pytest.mark.asyncio + async def test_batch_returns_none_falls_back_to_parent(self) -> None: + """When batch request fails (None result), fall back to parent genome.""" + import asyncio + + client = MagicMock() + mock_collector = MagicMock() + future: asyncio.Future[LLMResponse | None] = asyncio.get_event_loop().create_future() + future.set_result(None) + mock_collector.register = MagicMock(return_value=future) + + backend = MagicMock() + backend.format_mutation_prompt.return_value = "mutation prompt" + backend.system_prompt.return_value = "system" + ctx = _make_context(backend=backend) + parent = make_individual("original_genome") + + op = LLMMutate(client=client, model="test-model") + + with patch("evoforge.llm.operators.get_batch_collector", return_value=mock_collector): + result = await op.apply(parent, ctx) + + assert result == "original_genome" + + +# ------------------------------------------------------------------ # +# LLMCrossover batch context tests +# ------------------------------------------------------------------ # + + +class TestLLMCrossoverWithBatchContext: + @pytest.mark.asyncio + async def test_registers_with_batch_when_active(self) -> None: + import asyncio + + client = MagicMock() + client.async_generate = AsyncMock() + + mock_collector = MagicMock() + future: asyncio.Future[LLMResponse | None] = asyncio.get_event_loop().create_future() + future.set_result(FakeLLMResponse(text="batched")) + mock_collector.register = MagicMock(return_value=future) + + backend = MagicMock() + backend.format_crossover_prompt.return_value = "crossover prompt" + backend.system_prompt.return_value = "system" + backend.extract_genome.return_value = "batched_genome" + + parent_a = make_individual("genome_a") + parent_b = make_individual("genome_b") + ctx = _make_context(backend=backend, guidance_individual=parent_b) + + op = LLMCrossover(client=client, 
model="test-model") + + with patch("evoforge.llm.operators.get_batch_collector", return_value=mock_collector): + result = await op.apply(parent_a, ctx) + + mock_collector.register.assert_called_once() + client.async_generate.assert_not_called() + assert result == "batched_genome" From 98143ac6076962e4cdd8b0ba8149fc59b3a04f4a Mon Sep 17 00:00:00 2001 From: !link Date: Sun, 8 Mar 2026 18:36:15 -0700 Subject: [PATCH 07/11] Wire batch context into engine mutation phase with config toggle --- evoforge/core/config.py | 2 ++ evoforge/core/engine.py | 10 +++++++++- tests/test_core/test_batch.py | 18 ++++++++++++++++++ 3 files changed, 29 insertions(+), 1 deletion(-) diff --git a/evoforge/core/config.py b/evoforge/core/config.py index e9837a3..46503bb 100644 --- a/evoforge/core/config.py +++ b/evoforge/core/config.py @@ -56,6 +56,8 @@ class LLMConfig(BaseModel): max_cost_usd: float = 50.0 max_attempts: int = 3 prompt_caching: bool = True + batch_enabled: bool = False + batch_poll_interval: float = 2.0 class EvalConfig(BaseModel): diff --git a/evoforge/core/engine.py b/evoforge/core/engine.py index 5fa9a40..3094980 100644 --- a/evoforge/core/engine.py +++ b/evoforge/core/engine.py @@ -41,6 +41,7 @@ SelectionStrategy, ) from evoforge.core.types import Fitness, Individual +from evoforge.llm.batch import BatchCollector logger = logging.getLogger(__name__) @@ -312,7 +313,14 @@ async def run(self) -> ExperimentResult: tasks.append( asyncio.create_task(self._mutate_one(parent, operator, context)) ) - results = await asyncio.gather(*tasks) + if self.config.llm.batch_enabled: + async with BatchCollector( + self.llm_client, + poll_interval=self.config.llm.batch_poll_interval, + ): + results = await asyncio.gather(*tasks) + else: + results = await asyncio.gather(*tasks) for r in results: if r is not None: offspring_genomes.append(r) diff --git a/tests/test_core/test_batch.py b/tests/test_core/test_batch.py index 198e716..e1b63c1 100644 --- a/tests/test_core/test_batch.py +++ b/tests/test_core/test_batch.py @@ -182,3 +182,21 @@ async def test_context_var_lifecycle(self) -> None: assert get_batch_collector() is collector assert get_batch_collector() is None + + +class TestBatchConfigFields: + """Verify LLMConfig batch fields have correct defaults and are settable.""" + + def test_batch_defaults(self) -> None: + from evoforge.core.config import LLMConfig + + cfg = LLMConfig() + assert cfg.batch_enabled is False + assert cfg.batch_poll_interval == 2.0 + + def test_batch_can_be_enabled(self) -> None: + from evoforge.core.config import LLMConfig + + cfg = LLMConfig(batch_enabled=True, batch_poll_interval=5.0) + assert cfg.batch_enabled is True + assert cfg.batch_poll_interval == 5.0 From 3868c58cf430aef96891c16cf6d71ce4012beb1f Mon Sep 17 00:00:00 2001 From: !link Date: Sun, 8 Mar 2026 18:37:58 -0700 Subject: [PATCH 08/11] Make tactic generator batch-aware via BatchCollector --- evoforge/backends/lean/tactic_generator.py | 30 +++++++--- tests/test_core/test_batch.py | 70 ++++++++++++++++++++++ 2 files changed, 93 insertions(+), 7 deletions(-) diff --git a/evoforge/backends/lean/tactic_generator.py b/evoforge/backends/lean/tactic_generator.py index 3a495ae..a282aaf 100644 --- a/evoforge/backends/lean/tactic_generator.py +++ b/evoforge/backends/lean/tactic_generator.py @@ -10,6 +10,8 @@ import jinja2 +from evoforge.llm.batch import get_batch_collector + logger = logging.getLogger(__name__) _TEMPLATES_DIR = Path(__file__).parent / "templates" @@ -51,13 +53,27 @@ async def suggest_tactics(self, goal_state: str, 
proof_so_far: list[str], n: int ) logger.debug("Requesting %d tactics for goal: %s", n, goal_state[:80]) try: - response = await self._client.async_generate( - prompt, - self._system_prompt, - self._model, - self._temperature, - self._max_tokens, - ) + collector = get_batch_collector() + if collector is not None: + future = collector.register( + prompt, + self._system_prompt, + self._model, + self._temperature, + self._max_tokens, + ) + response = await future + if response is None: + logger.warning("Tactic generation batch request failed") + return [] + else: + response = await self._client.async_generate( + prompt, + self._system_prompt, + self._model, + self._temperature, + self._max_tokens, + ) except (RuntimeError, TimeoutError): logger.warning("Tactic generation LLM call failed", exc_info=True) return [] diff --git a/tests/test_core/test_batch.py b/tests/test_core/test_batch.py index e1b63c1..40c030f 100644 --- a/tests/test_core/test_batch.py +++ b/tests/test_core/test_batch.py @@ -200,3 +200,73 @@ def test_batch_can_be_enabled(self) -> None: cfg = LLMConfig(batch_enabled=True, batch_poll_interval=5.0) assert cfg.batch_enabled is True assert cfg.batch_poll_interval == 5.0 + + +class TestTacticGeneratorBatchAware: + @pytest.mark.asyncio + async def test_registers_with_batch_when_active(self) -> None: + import asyncio + + from evoforge.backends.lean.tactic_generator import LLMTacticGenerator + from tests.conftest import FakeLLMResponse + + client = MagicMock() + client.async_generate = AsyncMock() + + mock_collector = MagicMock() + future: asyncio.Future[Any] = asyncio.get_event_loop().create_future() + future.set_result(FakeLLMResponse(text="1. simp\n2. ring\n3. linarith")) + mock_collector.register = MagicMock(return_value=future) + + gen = LLMTacticGenerator(client, "test-model", "system prompt") + + with patch( + "evoforge.backends.lean.tactic_generator.get_batch_collector", + return_value=mock_collector, + ): + tactics = await gen.suggest_tactics("⊢ x = x", [], 3) + + mock_collector.register.assert_called_once() + client.async_generate.assert_not_called() + assert len(tactics) > 0 + + @pytest.mark.asyncio + async def test_direct_call_when_no_batch(self) -> None: + from evoforge.backends.lean.tactic_generator import LLMTacticGenerator + from tests.conftest import FakeLLMResponse + + client = MagicMock() + client.async_generate = AsyncMock(return_value=FakeLLMResponse(text="1. simp\n2. 
ring")) + + gen = LLMTacticGenerator(client, "test-model", "system prompt") + + with patch( + "evoforge.backends.lean.tactic_generator.get_batch_collector", + return_value=None, + ): + tactics = await gen.suggest_tactics("⊢ x = x", [], 3) + + client.async_generate.assert_called_once() + assert len(tactics) > 0 + + @pytest.mark.asyncio + async def test_batch_returns_none_returns_empty(self) -> None: + import asyncio + + from evoforge.backends.lean.tactic_generator import LLMTacticGenerator + + client = MagicMock() + mock_collector = MagicMock() + future: asyncio.Future[Any] = asyncio.get_event_loop().create_future() + future.set_result(None) + mock_collector.register = MagicMock(return_value=future) + + gen = LLMTacticGenerator(client, "test-model", "system prompt") + + with patch( + "evoforge.backends.lean.tactic_generator.get_batch_collector", + return_value=mock_collector, + ): + tactics = await gen.suggest_tactics("⊢ x = x", [], 3) + + assert tactics == [] From 23e7f8c4b6522ade42ed63ad152764923c94b2bd Mon Sep 17 00:00:00 2001 From: !link Date: Sun, 8 Mar 2026 18:52:40 -0700 Subject: [PATCH 09/11] Update README with prompt caching and batch API documentation --- README.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 1a9d19d..2d5e3c4 100644 --- a/README.md +++ b/README.md @@ -67,7 +67,8 @@ graph TD - **Selection** (`evoforge/core/selection.py`) — four strategies. Lexicase is the default and tends to maintain more diversity than tournament. - **Search memory** (`evoforge/core/memory.py`) — tracks patterns that led to fitness gains and dead ends to avoid. Fed into LLM prompts so the model learns from the population's history. - **Tree search** (`evoforge/backends/lean/tree_search.py`) — best-first search over REPL proof states. Used as a refinement step on promising partial proofs found by evolution. -- **LLM client** (`evoforge/llm/client.py`) — Anthropic API wrapper with exponential backoff, budget tracking, and graceful degradation (if calls fail, cheap operators fill in). +- **LLM client** (`evoforge/llm/client.py`) — Anthropic API wrapper with exponential backoff, prompt caching (90% input cost reduction on repeated system prompts), budget tracking, and graceful degradation (if calls fail, cheap operators fill in). +- **Batch collector** (`evoforge/llm/batch.py`) — optional Message Batch API integration that collects per-generation LLM requests into a single batch for 50% cost savings (stacks with prompt caching for up to 95% savings on cached input tokens). Falls back to individual calls on failure. ### Proof verification @@ -89,8 +90,8 @@ evoforge/ cfd/ — CFD turbulence closure optimization: SymPy IR, solver adapter, ablation credit, expression mutation operators llm/ — Anthropic client, LLM mutation/crossover operators, - Jinja2 prompt templates -tests/ — 652 tests, strict mypy, ruff + Jinja2 prompt templates, batch API collector +tests/ — 681 tests, strict mypy, ruff configs/ — TOML experiment configs scripts/ — CLI entry point (run.py) ``` @@ -144,7 +145,7 @@ Experiments are configured via TOML files. 
See `configs/lean_default.toml` for a | `[population]` | Size, elite count | | `[selection]` | Strategy (lexicase, tournament, pareto, map_elites), parameters | | `[mutation]` | LLM vs cheap operator weights, crossover weight | -| `[llm]` | Model, temperature schedule, token/cost budgets | +| `[llm]` | Model, temperature schedule, token/cost budgets, prompt caching, batch API | | `[evolution]` | Max generations, stagnation window, tree search settings, checkpointing | | `[backend]` | Theorem statement, project dir, imports, seed proofs | | `[ablation]` | Flags to disable individual components for experiments | @@ -167,6 +168,5 @@ Research software. The core evolutionary engine, LLM integration, and Lean 4 bac Known limitations: - Two backends (Lean 4 on hold, CFD active) -- LLM mutations are expensive and the search space is vast - Tree search helps but is limited by the quality of tactic suggestions - `greenlet` pinned to 3.1.0 due to a macOS compiler crash on newer versions From a8cc90d96085b2d764afc49b1a33a367ccfe7b97 Mon Sep 17 00:00:00 2001 From: !link Date: Sun, 8 Mar 2026 19:13:07 -0700 Subject: [PATCH 10/11] Fix batch race condition and simplify LLM call dispatch - Move task creation inside batch context to fix ContextVar propagation - Extract batch_aware_generate() to eliminate 3-site dispatch duplication - Add polling timeout (max_wait) to prevent infinite batch hang - Promote client internals to public API for batch module access - Use concurrent fallback via asyncio.gather instead of sequential - Update all tests to patch batch_aware_generate consistently --- evoforge/backends/lean/tactic_generator.py | 34 +++------ evoforge/core/engine.py | 71 ++++++++--------- evoforge/llm/batch.py | 82 ++++++++++++++------ evoforge/llm/client.py | 20 ++--- evoforge/llm/operators.py | 36 ++------- tests/test_core/test_batch.py | 52 ++----------- tests/test_core/test_llm_operators.py | 88 ++++++---------------- 7 files changed, 156 insertions(+), 227 deletions(-) diff --git a/evoforge/backends/lean/tactic_generator.py b/evoforge/backends/lean/tactic_generator.py index a282aaf..446910a 100644 --- a/evoforge/backends/lean/tactic_generator.py +++ b/evoforge/backends/lean/tactic_generator.py @@ -10,7 +10,7 @@ import jinja2 -from evoforge.llm.batch import get_batch_collector +from evoforge.llm.batch import batch_aware_generate logger = logging.getLogger(__name__) @@ -53,27 +53,17 @@ async def suggest_tactics(self, goal_state: str, proof_so_far: list[str], n: int ) logger.debug("Requesting %d tactics for goal: %s", n, goal_state[:80]) try: - collector = get_batch_collector() - if collector is not None: - future = collector.register( - prompt, - self._system_prompt, - self._model, - self._temperature, - self._max_tokens, - ) - response = await future - if response is None: - logger.warning("Tactic generation batch request failed") - return [] - else: - response = await self._client.async_generate( - prompt, - self._system_prompt, - self._model, - self._temperature, - self._max_tokens, - ) + response = await batch_aware_generate( + self._client, + prompt, + self._system_prompt, + self._model, + self._temperature, + self._max_tokens, + ) + if response is None: + logger.warning("Tactic generation batch request failed") + return [] except (RuntimeError, TimeoutError): logger.warning("Tactic generation LLM call failed", exc_info=True) return [] diff --git a/evoforge/core/engine.py b/evoforge/core/engine.py index 3094980..cdd2008 100644 --- a/evoforge/core/engine.py +++ b/evoforge/core/engine.py @@ -4,6 
+4,7 @@ from __future__ import annotations import asyncio +import contextlib import logging import random from dataclasses import dataclass @@ -284,42 +285,44 @@ async def run(self) -> ExperimentResult: # Mutate (concurrently) offspring_genomes: list[tuple[str, str, str]] = [] guidance = self._memory.prompt_section(max_tokens=200) - tasks: list[asyncio.Task[tuple[str, str, str] | None]] = [] - for parent in parents: - operator = self._ensemble.select_operator() - - # Skip LLM operators when per-gen budget exhausted - if operator.cost == "llm" and not self._scheduler.can_use_llm(): - operator = self._ensemble.cheapest_operator() - - # Pick a second parent for crossover operators - guidance_ind = None - if "crossover" in operator.name: - other_parents = [p for p in parents if p.ir_hash != parent.ir_hash] - if other_parents: - guidance_ind = random.choice(other_parents) - else: - guidance_ind = random.choice(parents) - - context = MutationContext( - generation=gen, - memory=self._memory, - guidance=guidance, - temperature=self._temperature, - backend=self.backend, - credits=parent.credits, - guidance_individual=guidance_ind, - ) - tasks.append( - asyncio.create_task(self._mutate_one(parent, operator, context)) - ) - if self.config.llm.batch_enabled: - async with BatchCollector( + batch_cm: Any = ( + BatchCollector( self.llm_client, poll_interval=self.config.llm.batch_poll_interval, - ): - results = await asyncio.gather(*tasks) - else: + ) + if self.config.llm.batch_enabled + else contextlib.nullcontext() + ) + async with batch_cm: + tasks: list[asyncio.Task[tuple[str, str, str] | None]] = [] + for parent in parents: + operator = self._ensemble.select_operator() + + # Skip LLM operators when per-gen budget exhausted + if operator.cost == "llm" and not self._scheduler.can_use_llm(): + operator = self._ensemble.cheapest_operator() + + # Pick a second parent for crossover operators + guidance_ind = None + if "crossover" in operator.name: + other_parents = [p for p in parents if p.ir_hash != parent.ir_hash] + if other_parents: + guidance_ind = random.choice(other_parents) + else: + guidance_ind = random.choice(parents) + + context = MutationContext( + generation=gen, + memory=self._memory, + guidance=guidance, + temperature=self._temperature, + backend=self.backend, + credits=parent.credits, + guidance_individual=guidance_ind, + ) + tasks.append( + asyncio.create_task(self._mutate_one(parent, operator, context)) + ) results = await asyncio.gather(*tasks) for r in results: if r is not None: diff --git a/evoforge/llm/batch.py b/evoforge/llm/batch.py index b8b1f40..a474739 100644 --- a/evoforge/llm/batch.py +++ b/evoforge/llm/batch.py @@ -6,7 +6,7 @@ import asyncio import logging from contextvars import ContextVar -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, NamedTuple if TYPE_CHECKING: from evoforge.llm.client import LLMClient, LLMResponse @@ -17,19 +17,53 @@ "_active_collector", default=None ) +_DEFAULT_MAX_WAIT: float = 1800.0 # 30 minutes + def get_batch_collector() -> BatchCollector | None: """Return the active BatchCollector, or None if not inside a batch context.""" return _active_collector.get() +async def batch_aware_generate( + client: LLMClient, + prompt: str, + system: str, + model: str, + temperature: float, + max_tokens: int, +) -> LLMResponse | None: + """Generate via batch collector if active, otherwise via direct async call. + + Returns None if the batch request failed; raises on direct-call failure. 
+ """ + collector = get_batch_collector() + if collector is not None: + return await collector.register(prompt, system, model, temperature, max_tokens) + return await client.async_generate(prompt, system, model, temperature, max_tokens) + + +class _BatchRequest(NamedTuple): + prompt: str + system: str + model: str + temperature: float + max_tokens: int + + class BatchCollector: """Async context manager that collects LLM requests and submits them as a batch.""" - def __init__(self, client: LLMClient, poll_interval: float = 2.0) -> None: + def __init__( + self, + client: LLMClient, + poll_interval: float = 2.0, + max_wait: float = _DEFAULT_MAX_WAIT, + ) -> None: self._client = client self._poll_interval = poll_interval - self._requests: list[tuple[str, str, str, float, int]] = [] + self._max_wait = max_wait + self._requests: list[_BatchRequest] = [] self._futures: list[asyncio.Future[LLMResponse | None]] = [] self._token: Any = None @@ -39,7 +73,7 @@ def register( """Register a request and return a future resolved with the result.""" loop = asyncio.get_running_loop() future: asyncio.Future[LLMResponse | None] = loop.create_future() - self._requests.append((prompt, system, model, temperature, max_tokens)) + self._requests.append(_BatchRequest(prompt, system, model, temperature, max_tokens)) self._futures.append(future) return future @@ -64,22 +98,22 @@ async def _submit_and_resolve(self) -> None: from anthropic.types.message_create_params import MessageCreateParamsNonStreaming from anthropic.types.messages.batch_create_params import Request - from evoforge.llm.client import LLMResponse + from evoforge.llm.client import LLMClient, LLMResponse - async_client = self._client._get_async_client() + async_client = self._client.get_async_client() batch_requests = [] - for i, (prompt, system, model, temperature, max_tokens) in enumerate(self._requests): - system_param = self._client._format_system(system) + for i, req in enumerate(self._requests): + system_param = self._client.format_system(req.system) batch_requests.append( Request( custom_id=f"req-{i}", params=MessageCreateParamsNonStreaming( - model=model, - max_tokens=max_tokens, - temperature=temperature, + model=req.model, + max_tokens=req.max_tokens, + temperature=req.temperature, system=system_param, - messages=[{"role": "user", "content": prompt}], + messages=[{"role": "user", "content": req.prompt}], ), ) ) @@ -87,22 +121,26 @@ async def _submit_and_resolve(self) -> None: batch = await async_client.messages.batches.create(requests=batch_requests) batch_id = batch.id + elapsed = 0.0 while batch.processing_status != "ended": + if elapsed >= self._max_wait: + msg = f"Batch {batch_id} not ended after {elapsed:.0f}s" + raise TimeoutError(msg) await asyncio.sleep(self._poll_interval) + elapsed += self._poll_interval batch = await async_client.messages.batches.retrieve(batch_id) results_by_id: dict[str, LLMResponse | None] = {} async for result in await async_client.messages.batches.results(batch_id): if result.result.type == "succeeded": - msg = result.result.message - text = msg.content[0].text # type: ignore[union-attr] - cache_read = getattr(msg.usage, "cache_read_input_tokens", None) or 0 - cache_creation = getattr(msg.usage, "cache_creation_input_tokens", None) or 0 + msg_obj = result.result.message + text = msg_obj.content[0].text # type: ignore[union-attr] + cache_read, cache_creation = LLMClient.extract_cache_tokens(msg_obj.usage) results_by_id[result.custom_id] = LLMResponse( text=text, - input_tokens=msg.usage.input_tokens, - 
output_tokens=msg.usage.output_tokens, - model=msg.model, + input_tokens=msg_obj.usage.input_tokens, + output_tokens=msg_obj.usage.output_tokens, + model=msg_obj.model, cache_read_tokens=cache_read, cache_creation_tokens=cache_creation, ) @@ -115,13 +153,15 @@ async def _submit_and_resolve(self) -> None: future.set_result(results_by_id.get(req_id)) async def _fallback_individual(self) -> None: - for i, (prompt, system, model, temperature, max_tokens) in enumerate(self._requests): + async def _do_one(i: int, req: _BatchRequest) -> None: try: result = await self._client.async_generate( - prompt, system, model, temperature, max_tokens + req.prompt, req.system, req.model, req.temperature, req.max_tokens ) self._futures[i].set_result(result) except Exception: logger.warning("Fallback call %d failed", i, exc_info=True) if not self._futures[i].done(): self._futures[i].set_result(None) + + await asyncio.gather(*(_do_one(i, req) for i, req in enumerate(self._requests))) diff --git a/evoforge/llm/client.py b/evoforge/llm/client.py index 0780e66..c582a0c 100644 --- a/evoforge/llm/client.py +++ b/evoforge/llm/client.py @@ -65,7 +65,7 @@ def __init__( self._sync_client: anthropic.Anthropic | None = None self._async_client: anthropic.AsyncAnthropic | None = None - def _format_system(self, system: str) -> str | list[TextBlockParam]: + def format_system(self, system: str) -> str | list[TextBlockParam]: """Format system prompt, optionally adding cache_control for prompt caching.""" if not self._prompt_caching: return system @@ -78,7 +78,7 @@ def _format_system(self, system: str) -> str | list[TextBlockParam]: ] @staticmethod - def _extract_cache_tokens(usage: Any) -> tuple[int, int]: + def extract_cache_tokens(usage: Any) -> tuple[int, int]: """Extract cache token counts from API usage, defaulting to 0.""" cache_read = getattr(usage, "cache_read_input_tokens", None) or 0 cache_creation = getattr(usage, "cache_creation_input_tokens", None) or 0 @@ -90,12 +90,12 @@ def _compute_delay(self, attempt: int) -> float: jitter = random.uniform(0, self._base_delay) return float(min(delay + jitter, self._max_delay)) - def _get_sync_client(self) -> anthropic.Anthropic: + def get_sync_client(self) -> anthropic.Anthropic: if self._sync_client is None: self._sync_client = anthropic.Anthropic(api_key=self._api_key) return self._sync_client - def _get_async_client(self) -> anthropic.AsyncAnthropic: + def get_async_client(self) -> anthropic.AsyncAnthropic: if self._async_client is None: self._async_client = anthropic.AsyncAnthropic(api_key=self._api_key) return self._async_client @@ -109,7 +109,7 @@ def generate( max_tokens: int = 4096, ) -> LLMResponse: """Call the Anthropic API synchronously with exponential-backoff retry.""" - client = self._get_sync_client() + client = self.get_sync_client() last_exc: Exception | None = None for attempt in range(self._max_retries): @@ -118,11 +118,11 @@ def generate( model=model, max_tokens=max_tokens, temperature=temperature, - system=self._format_system(system), + system=self.format_system(system), messages=[{"role": "user", "content": prompt}], ) text = response.content[0].text # type: ignore[union-attr] - cache_read, cache_creation = self._extract_cache_tokens(response.usage) + cache_read, cache_creation = self.extract_cache_tokens(response.usage) return LLMResponse( text=text, input_tokens=response.usage.input_tokens, @@ -155,7 +155,7 @@ async def async_generate( max_tokens: int = 4096, ) -> LLMResponse: """Call the Anthropic API asynchronously with exponential-backoff retry.""" - 
client = self._get_async_client() + client = self.get_async_client() last_exc: Exception | None = None for attempt in range(self._max_retries): @@ -164,11 +164,11 @@ async def async_generate( model=model, max_tokens=max_tokens, temperature=temperature, - system=self._format_system(system), + system=self.format_system(system), messages=[{"role": "user", "content": prompt}], ) text = response.content[0].text # type: ignore[union-attr] - cache_read, cache_creation = self._extract_cache_tokens(response.usage) + cache_read, cache_creation = self.extract_cache_tokens(response.usage) return LLMResponse( text=text, input_tokens=response.usage.input_tokens, diff --git a/evoforge/llm/operators.py b/evoforge/llm/operators.py index 0cef63d..9c2c8b4 100644 --- a/evoforge/llm/operators.py +++ b/evoforge/llm/operators.py @@ -13,7 +13,7 @@ from evoforge.core.mutation import MutationContext, MutationOperator from evoforge.core.types import Individual -from evoforge.llm.batch import get_batch_collector +from evoforge.llm.batch import batch_aware_generate logger = logging.getLogger(__name__) @@ -44,20 +44,9 @@ async def apply(self, parent: Individual, context: MutationContext) -> str: prompt = context.backend.format_mutation_prompt(parent, context) system = context.backend.system_prompt() - collector = get_batch_collector() - if collector is not None: - future = collector.register( - prompt, system, self._model, context.temperature, self._max_tokens - ) - response = await future - else: - response = await self._client.async_generate( - prompt, - system, - self._model, - context.temperature, - self._max_tokens, - ) + response = await batch_aware_generate( + self._client, prompt, system, self._model, context.temperature, self._max_tokens + ) if response is None: logger.warning("LLMMutate: batch request failed, falling back to parent") @@ -103,20 +92,9 @@ async def apply(self, parent: Individual, context: MutationContext) -> str: system = context.backend.system_prompt() - collector = get_batch_collector() - if collector is not None: - future = collector.register( - prompt, system, self._model, context.temperature, self._max_tokens - ) - response = await future - else: - response = await self._client.async_generate( - prompt, - system, - self._model, - context.temperature, - self._max_tokens, - ) + response = await batch_aware_generate( + self._client, prompt, system, self._model, context.temperature, self._max_tokens + ) if response is None: logger.warning("LLMCrossover: batch request failed, falling back to parent") diff --git a/tests/test_core/test_batch.py b/tests/test_core/test_batch.py index 40c030f..10ef567 100644 --- a/tests/test_core/test_batch.py +++ b/tests/test_core/test_batch.py @@ -204,68 +204,32 @@ def test_batch_can_be_enabled(self) -> None: class TestTacticGeneratorBatchAware: @pytest.mark.asyncio - async def test_registers_with_batch_when_active(self) -> None: - import asyncio - + async def test_uses_batch_aware_generate(self) -> None: from evoforge.backends.lean.tactic_generator import LLMTacticGenerator from tests.conftest import FakeLLMResponse client = MagicMock() - client.async_generate = AsyncMock() - - mock_collector = MagicMock() - future: asyncio.Future[Any] = asyncio.get_event_loop().create_future() - future.set_result(FakeLLMResponse(text="1. simp\n2. ring\n3. 
linarith")) - mock_collector.register = MagicMock(return_value=future) - gen = LLMTacticGenerator(client, "test-model", "system prompt") with patch( - "evoforge.backends.lean.tactic_generator.get_batch_collector", - return_value=mock_collector, - ): + "evoforge.backends.lean.tactic_generator.batch_aware_generate", + new=AsyncMock(return_value=FakeLLMResponse(text="1. simp\n2. ring\n3. linarith")), + ) as mock_gen: tactics = await gen.suggest_tactics("⊢ x = x", [], 3) - mock_collector.register.assert_called_once() - client.async_generate.assert_not_called() - assert len(tactics) > 0 - - @pytest.mark.asyncio - async def test_direct_call_when_no_batch(self) -> None: - from evoforge.backends.lean.tactic_generator import LLMTacticGenerator - from tests.conftest import FakeLLMResponse - - client = MagicMock() - client.async_generate = AsyncMock(return_value=FakeLLMResponse(text="1. simp\n2. ring")) - - gen = LLMTacticGenerator(client, "test-model", "system prompt") - - with patch( - "evoforge.backends.lean.tactic_generator.get_batch_collector", - return_value=None, - ): - tactics = await gen.suggest_tactics("⊢ x = x", [], 3) - - client.async_generate.assert_called_once() - assert len(tactics) > 0 + mock_gen.assert_called_once() + assert tactics == ["simp", "ring", "linarith"] @pytest.mark.asyncio async def test_batch_returns_none_returns_empty(self) -> None: - import asyncio - from evoforge.backends.lean.tactic_generator import LLMTacticGenerator client = MagicMock() - mock_collector = MagicMock() - future: asyncio.Future[Any] = asyncio.get_event_loop().create_future() - future.set_result(None) - mock_collector.register = MagicMock(return_value=future) - gen = LLMTacticGenerator(client, "test-model", "system prompt") with patch( - "evoforge.backends.lean.tactic_generator.get_batch_collector", - return_value=mock_collector, + "evoforge.backends.lean.tactic_generator.batch_aware_generate", + new=AsyncMock(return_value=None), ): tactics = await gen.suggest_tactics("⊢ x = x", [], 3) diff --git a/tests/test_core/test_llm_operators.py b/tests/test_core/test_llm_operators.py index 6bc81d4..2b0befa 100644 --- a/tests/test_core/test_llm_operators.py +++ b/tests/test_core/test_llm_operators.py @@ -10,7 +10,6 @@ from evoforge.core.mutation import MutationContext from evoforge.core.types import Credit -from evoforge.llm.client import LLMResponse from evoforge.llm.operators import LLMCrossover, LLMMutate from tests.conftest import FakeLLMResponse, make_individual @@ -149,23 +148,15 @@ async def test_falls_back_to_parent_genome(self) -> None: # ------------------------------------------------------------------ # -# LLMMutate batch context tests +# LLMMutate batch-aware generate tests # ------------------------------------------------------------------ # -class TestLLMMutateWithBatchContext: +class TestLLMMutateWithBatchAwareGenerate: @pytest.mark.asyncio - async def test_registers_with_batch_when_active(self) -> None: - import asyncio - + async def test_uses_batch_aware_generate(self) -> None: + """LLMMutate delegates to batch_aware_generate.""" client = MagicMock() - client.async_generate = AsyncMock() - - mock_collector = MagicMock() - future: asyncio.Future[LLMResponse | None] = asyncio.get_event_loop().create_future() - future.set_result(FakeLLMResponse(text="```lean\nbatched_genome\n```")) - mock_collector.register = MagicMock(return_value=future) - backend = MagicMock() backend.format_mutation_prompt.return_value = "mutation prompt" backend.system_prompt.return_value = "system" @@ -175,51 +166,18 @@ 
async def test_registers_with_batch_when_active(self) -> None: op = LLMMutate(client=client, model="test-model") - with patch("evoforge.llm.operators.get_batch_collector", return_value=mock_collector): + with patch( + "evoforge.llm.operators.batch_aware_generate", + new=AsyncMock(return_value=FakeLLMResponse(text="batched")), + ): result = await op.apply(parent, ctx) - mock_collector.register.assert_called_once_with( - "mutation prompt", - "system", - "test-model", - 0.7, - 4096, - ) - client.async_generate.assert_not_called() assert result == "batched_genome" - @pytest.mark.asyncio - async def test_direct_call_when_no_batch(self) -> None: - client = MagicMock() - client.async_generate = AsyncMock( - return_value=FakeLLMResponse(text="```lean\ndirect_genome\n```") - ) - backend = MagicMock() - backend.format_mutation_prompt.return_value = "mutation prompt" - backend.system_prompt.return_value = "system" - backend.extract_genome.return_value = "direct_genome" - ctx = _make_context(backend=backend) - parent = make_individual("original") - - op = LLMMutate(client=client, model="test-model") - - with patch("evoforge.llm.operators.get_batch_collector", return_value=None): - result = await op.apply(parent, ctx) - - client.async_generate.assert_called_once() - assert result == "direct_genome" - @pytest.mark.asyncio async def test_batch_returns_none_falls_back_to_parent(self) -> None: - """When batch request fails (None result), fall back to parent genome.""" - import asyncio - + """When batch_aware_generate returns None, fall back to parent genome.""" client = MagicMock() - mock_collector = MagicMock() - future: asyncio.Future[LLMResponse | None] = asyncio.get_event_loop().create_future() - future.set_result(None) - mock_collector.register = MagicMock(return_value=future) - backend = MagicMock() backend.format_mutation_prompt.return_value = "mutation prompt" backend.system_prompt.return_value = "system" @@ -228,30 +186,25 @@ async def test_batch_returns_none_falls_back_to_parent(self) -> None: op = LLMMutate(client=client, model="test-model") - with patch("evoforge.llm.operators.get_batch_collector", return_value=mock_collector): + with patch( + "evoforge.llm.operators.batch_aware_generate", + new=AsyncMock(return_value=None), + ): result = await op.apply(parent, ctx) assert result == "original_genome" # ------------------------------------------------------------------ # -# LLMCrossover batch context tests +# LLMCrossover batch-aware generate tests # ------------------------------------------------------------------ # -class TestLLMCrossoverWithBatchContext: +class TestLLMCrossoverWithBatchAwareGenerate: @pytest.mark.asyncio - async def test_registers_with_batch_when_active(self) -> None: - import asyncio - + async def test_uses_batch_aware_generate(self) -> None: + """LLMCrossover delegates to batch_aware_generate.""" client = MagicMock() - client.async_generate = AsyncMock() - - mock_collector = MagicMock() - future: asyncio.Future[LLMResponse | None] = asyncio.get_event_loop().create_future() - future.set_result(FakeLLMResponse(text="batched")) - mock_collector.register = MagicMock(return_value=future) - backend = MagicMock() backend.format_crossover_prompt.return_value = "crossover prompt" backend.system_prompt.return_value = "system" @@ -263,9 +216,10 @@ async def test_registers_with_batch_when_active(self) -> None: op = LLMCrossover(client=client, model="test-model") - with patch("evoforge.llm.operators.get_batch_collector", return_value=mock_collector): + with patch( + 
"evoforge.llm.operators.batch_aware_generate", + new=AsyncMock(return_value=FakeLLMResponse(text="batched")), + ): result = await op.apply(parent_a, ctx) - mock_collector.register.assert_called_once() - client.async_generate.assert_not_called() assert result == "batched_genome" From 7dc854bf52b46ff41a19861d0e6ed71412316902 Mon Sep 17 00:00:00 2001 From: !link Date: Sun, 8 Mar 2026 19:56:25 -0700 Subject: [PATCH 11/11] Fix README test count to 679 --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 2d5e3c4..24df9d3 100644 --- a/README.md +++ b/README.md @@ -91,7 +91,7 @@ evoforge/ ablation credit, expression mutation operators llm/ — Anthropic client, LLM mutation/crossover operators, Jinja2 prompt templates, batch API collector -tests/ — 681 tests, strict mypy, ruff +tests/ — 679 tests, strict mypy, ruff configs/ — TOML experiment configs scripts/ — CLI entry point (run.py) ```