diff --git a/src/trivia_agent/worker.py b/src/trivia_agent/worker.py index 9cb43b2..08c409d 100644 --- a/src/trivia_agent/worker.py +++ b/src/trivia_agent/worker.py @@ -14,13 +14,13 @@ import os import sys -from collections.abc import Sequence +from collections.abc import Mapping, Sequence from dataclasses import field from datetime import UTC, datetime, timedelta from pathlib import Path from typing import TYPE_CHECKING, TextIO -from weakincentives import FrozenDataclass, Prompt +from weakincentives import Budget, FrozenDataclass, Prompt, PromptResponse from weakincentives.adapters import ProviderAdapter from weakincentives.adapters.claude_agent_sdk import ClaudeAgentWorkspaceSection, HostMount from weakincentives.deadlines import Deadline @@ -28,6 +28,7 @@ from weakincentives.prompt import PromptTemplate from weakincentives.prompt.overrides import LocalPromptOverridesStore, PromptOverridesStore from weakincentives.runtime import ( + Heartbeat, LoopGroup, MainLoop, MainLoopConfig, @@ -204,6 +205,10 @@ class TriviaAgentLoop(MainLoop[TriviaRequest, TriviaResponse]): customization of prompt content without code changes - **Evaluation integration**: Compatible with EvalLoop for running evaluations with session-aware scoring + - **Request preprocessing**: Override preprocess_request() to transform + requests before agent processing + - **Response postprocessing**: Override postprocess_response() to transform + responses before returning to callers To use this loop, instantiate it with an adapter and mailbox, then either: 1. 
def preprocess_request(
    self,
    request: TriviaRequest,
    session: Session,
) -> TriviaRequest:
    """Hook applied to the incoming request before the agent runs.

    The base implementation is a pass-through: it hands back the very same
    request object it was given. Subclasses override it to validate,
    normalize, or enrich requests; the session created for this execution
    is supplied so the override can be context-aware.

    Args:
        request: The incoming TriviaRequest.
        session: The session for this execution.

    Returns:
        TriviaRequest: The request that should be bound into the prompt.
    """
    passthrough = request
    return passthrough


def postprocess_response(
    self,
    response: TriviaResponse,
    prompt: Prompt[TriviaResponse],
    session: Session,
) -> TriviaResponse:
    """Hook applied to the agent's response before it reaches the caller.

    The base implementation is a pass-through: it hands back the very same
    response object it was given. Subclasses override it to format, clean
    up, or validate responses; the prompt and session of the execution are
    supplied so the override can be context-aware.

    Args:
        response: The TriviaResponse from the agent.
        prompt: The prompt used for this execution.
        session: The session used for this execution.

    Returns:
        TriviaResponse: The response that should be returned to the caller.
    """
    passthrough = response
    return passthrough
- This method demonstrates key WINK patterns: - - - **Session per request**: Each request gets its own Session for proper - isolation between concurrent requests - - **Dynamic workspace section**: ClaudeAgentWorkspaceSection is created - here (not in build_prompt_template) because it needs the Session - - **Parameter binding**: Binds QuestionParams with the user's question - and EmptyParams for parameterless sections - - **Experiment support**: Uses experiment.overrides_tag to select prompt - variants for A/B testing; defaults to "latest" - - **Override seeding**: Automatically seeds the overrides store with - current prompt state, creating editable files for customization + Calls preprocess_request() to allow request transformation before binding. Args: - request: TriviaRequest containing the question field. The question - is bound to QuestionSection via QuestionParams. - experiment: Optional Experiment instance for evaluation runs. When - provided, uses experiment.overrides_tag to select prompt variant. - Pass None for production requests. + request: TriviaRequest containing the question field. + experiment: Optional Experiment instance for evaluation runs. Returns: - tuple[Prompt[TriviaResponse], Session]: A 2-tuple containing: - - Prompt: Fully configured prompt with all sections and bound - parameters, ready for adapter.run() - - Session: Fresh session instance for this request's execution + tuple[Prompt[TriviaResponse], Session]: Prompt and session for execution. 
def execute(
    self,
    request: TriviaRequest,
    *,
    budget: Budget | None = None,
    deadline: Deadline | None = None,
    resources: Mapping[type[object], object] | None = None,
    heartbeat: Heartbeat | None = None,
    experiment: Experiment | None = None,
) -> tuple[PromptResponse[TriviaResponse], Session]:
    """Execute a trivia request with preprocessing and postprocessing.

    Preprocessing happens in prepare() via preprocess_request(); prepare()
    also records the prompt and session it built on the instance.
    Postprocessing happens here, after the parent execute() returns, by
    passing the response output through postprocess_response().

    The recorded prompt/session pair is consumed exactly once: it is read
    after the parent call and then cleared, so a later call can never be
    postprocessed with a prompt left over from an earlier request, and the
    loop does not keep a finished request's prompt and session alive.

    NOTE(review): the pair lives on the loop instance, so this assumes a
    single loop instance does not run execute() concurrently from several
    threads - confirm against how the loop is scheduled.

    Args:
        request: TriviaRequest containing the question to process.
        budget: Optional Budget for token/cost limits.
        deadline: Optional Deadline for time limits.
        resources: Optional mapping of resource types to instances.
        heartbeat: Optional Heartbeat for progress reporting.
        experiment: Optional Experiment for evaluation runs.

    Returns:
        tuple[PromptResponse[TriviaResponse], Session]: Response and session.
        The response is returned untouched when it has no output or when no
        prompt was recorded for this call.
    """
    try:
        # prepare() is expected to be invoked by the parent implementation,
        # which in turn calls preprocess_request().
        prompt_response, session = super().execute(
            request,
            budget=budget,
            deadline=deadline,
            resources=resources,
            heartbeat=heartbeat,
            experiment=experiment,
        )
        # Snapshot what prepare() recorded for this request.
        prompt = self._last_prompt
        prepared_session = self._last_session
    finally:
        # Always drop the per-request references, including on failure, so
        # they cannot leak into the next request.
        self._last_prompt = None
        self._last_session = None

    output = prompt_response.output
    if output is None or prompt is None:
        # Nothing to postprocess (no structured output, or no prompt known).
        return prompt_response, session

    # Explicit None check: choosing with `or` would silently fall back to
    # the returned session if the recorded one ever evaluated as falsy.
    hook_session = prepared_session if prepared_session is not None else session
    postprocessed_output = self.postprocess_response(output, prompt, hook_session)
    return prompt_response.update(output=postprocessed_output), session  # type: ignore[return-value]
class TestTriviaAgentLoopPostprocessing:
    """Exercise the postprocess_response() hook of TriviaAgentLoop."""

    def test_postprocess_response_returns_unchanged_by_default(
        self,
        fake_mailboxes: TriviaMailboxes,
    ) -> None:
        """The stock hook hands back the very same response object."""
        from weakincentives.runtime import Session

        adapter: ProviderAdapter[TriviaResponse] = MagicMock()
        agent_loop = TriviaAgentLoop(
            adapter=adapter,
            requests=fake_mailboxes.requests,
        )
        original = TriviaResponse(answer="42")

        # The prompt is only forwarded to the hook, so a mock is sufficient.
        returned = agent_loop.postprocess_response(original, MagicMock(), Session())

        assert returned is original

    def test_postprocess_response_can_be_overridden(
        self,
        fake_mailboxes: TriviaMailboxes,
    ) -> None:
        """A subclass override decides what response is returned."""
        from weakincentives import Prompt
        from weakincentives.runtime import Session

        class CustomLoop(TriviaAgentLoop):
            def postprocess_response(
                self,
                response: TriviaResponse,
                prompt: Prompt[TriviaResponse],
                session: Session,
            ) -> TriviaResponse:
                prefixed = f"Answer: {response.answer}"
                return TriviaResponse(answer=prefixed)

        adapter: ProviderAdapter[TriviaResponse] = MagicMock()
        agent_loop = CustomLoop(
            adapter=adapter,
            requests=fake_mailboxes.requests,
        )

        returned = agent_loop.postprocess_response(
            TriviaResponse(answer="42"),
            MagicMock(),
            Session(),
        )

        assert returned.answer == "Answer: 42"
class TestTriviaAgentLoopExecute:
    """Tests for TriviaAgentLoop.execute() with preprocessing/postprocessing."""

    def test_execute_calls_preprocess_request_via_prepare(
        self,
        fake_mailboxes: TriviaMailboxes,
    ) -> None:
        """Test that prepare() calls preprocess_request and binds its result."""
        from weakincentives.runtime import Session

        preprocess_calls: list[TriviaRequest] = []

        class CustomLoop(TriviaAgentLoop):
            def preprocess_request(self, request: TriviaRequest, session: Session) -> TriviaRequest:
                preprocess_calls.append(request)
                return TriviaRequest(question=request.question.strip())

        mock_adapter: ProviderAdapter[TriviaResponse] = MagicMock()

        loop = CustomLoop(
            adapter=mock_adapter,
            requests=fake_mailboxes.requests,
        )

        request = TriviaRequest(question=" What is the answer? ")
        prompt, _session = loop.prepare(request)

        # The hook runs exactly once and receives the raw, unstripped request.
        assert len(preprocess_calls) == 1
        assert preprocess_calls[0].question == request.question

        # The prompt should have the preprocessed question bound.
        rendered = str(prompt.render())
        assert "What is the answer?" in rendered
        # The stripped text is a substring of the padded text, so the check
        # above alone would also pass if the raw request had been bound.
        # Requiring the padded form to be absent pins the preprocessing.
        assert request.question not in rendered

    def test_execute_calls_postprocess_response(
        self,
        fake_mailboxes: TriviaMailboxes,
    ) -> None:
        """Test that execute() calls postprocess_response and returns its result."""
        from weakincentives import Prompt
        from weakincentives.runtime import MainLoop, Session

        postprocess_calls: list[tuple[TriviaResponse, bool, bool]] = []

        class CustomLoop(TriviaAgentLoop):
            def postprocess_response(
                self,
                response: TriviaResponse,
                prompt: Prompt[TriviaResponse],
                session: Session,
            ) -> TriviaResponse:
                postprocess_calls.append((response, prompt is not None, session is not None))
                return TriviaResponse(answer=response.answer.strip())

        mock_adapter: ProviderAdapter[TriviaResponse] = MagicMock()

        loop = CustomLoop(
            adapter=mock_adapter,
            requests=fake_mailboxes.requests,
        )

        mock_response = MagicMock()
        mock_response.output = TriviaResponse(answer=" 42 ")
        mock_response.update = MagicMock(return_value=mock_response)
        mock_session = MagicMock()

        # MainLoop.execute is patched out below, so nothing invokes prepare()
        # on our behalf; call it by hand to set up _last_prompt/_last_session.
        request = TriviaRequest(question="What is the answer?")
        loop.prepare(request)

        with patch.object(MainLoop, "execute", return_value=(mock_response, mock_session)):
            loop.execute(request)

        assert len(postprocess_calls) == 1
        assert postprocess_calls[0][0].answer == " 42 "
        assert postprocess_calls[0][1] is True  # prompt was passed
        assert postprocess_calls[0][2] is True  # session was passed
        mock_response.update.assert_called_once()
        # A bare call-count check would pass even if the raw output were
        # written back; verify the postprocessed value is what gets stored.
        assert mock_response.update.call_args.kwargs["output"].answer == "42"

    def test_execute_skips_postprocessing_when_output_is_none(
        self,
        fake_mailboxes: TriviaMailboxes,
    ) -> None:
        """Test that execute() skips postprocessing when output is None."""
        from weakincentives.runtime import MainLoop

        mock_adapter: ProviderAdapter[TriviaResponse] = MagicMock()

        loop = TriviaAgentLoop(
            adapter=mock_adapter,
            requests=fake_mailboxes.requests,
        )

        mock_response = MagicMock()
        mock_response.output = None
        mock_response.update = MagicMock()
        mock_session = MagicMock()

        with patch.object(MainLoop, "execute", return_value=(mock_response, mock_session)):
            request = TriviaRequest(question="What is the answer?")
            result_response, result_session = loop.execute(request)

        # With no output, the parent's result must be passed through untouched.
        mock_response.update.assert_not_called()
        assert result_response is mock_response
        assert result_session is mock_session