From 928b6ee5dbc2e828feaa9fefbef333af71bcf6e8 Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 26 Jan 2026 06:13:55 +0000 Subject: [PATCH 1/2] Add request preprocessing and response postprocessing to MainLoop Add two overridable methods for transforming requests before agent processing and responses before returning to callers: - preprocess_request(request) - override to transform incoming requests - postprocess_response(response) - override to transform outgoing responses Both methods return their input unchanged by default. Subclass TriviaAgentLoop and override these methods to implement custom preprocessing/postprocessing logic. --- src/trivia_agent/worker.py | 88 +++++++++++++- tests/trivia_agent/test_worker.py | 183 ++++++++++++++++++++++++++++++ 2 files changed, 269 insertions(+), 2 deletions(-) diff --git a/src/trivia_agent/worker.py b/src/trivia_agent/worker.py index 9cb43b2..7192a1d 100644 --- a/src/trivia_agent/worker.py +++ b/src/trivia_agent/worker.py @@ -14,13 +14,13 @@ import os import sys -from collections.abc import Sequence +from collections.abc import Mapping, Sequence from dataclasses import field from datetime import UTC, datetime, timedelta from pathlib import Path from typing import TYPE_CHECKING, TextIO -from weakincentives import FrozenDataclass, Prompt +from weakincentives import Budget, FrozenDataclass, Prompt, PromptResponse from weakincentives.adapters import ProviderAdapter from weakincentives.adapters.claude_agent_sdk import ClaudeAgentWorkspaceSection, HostMount from weakincentives.deadlines import Deadline @@ -28,6 +28,7 @@ from weakincentives.prompt import PromptTemplate from weakincentives.prompt.overrides import LocalPromptOverridesStore, PromptOverridesStore from weakincentives.runtime import ( + Heartbeat, LoopGroup, MainLoop, MainLoopConfig, @@ -204,6 +205,10 @@ class TriviaAgentLoop(MainLoop[TriviaRequest, TriviaResponse]): customization of prompt content without code changes - **Evaluation integration**: Compatible with EvalLoop for running evaluations with session-aware scoring + - **Request preprocessing**: Override preprocess_request() to transform + requests before agent processing + - **Response postprocessing**: Override postprocess_response() to transform + responses before returning to callers To use this loop, instantiate it with an adapter and mailbox, then either: 1. Call run() directly for single-threaded processing @@ -265,6 +270,34 @@ def __init__( self._base_template = build_prompt_template() self._overrides_store = overrides_store + def preprocess_request(self, request: TriviaRequest) -> TriviaRequest: + """Transform request before agent processing. + + Override this method in subclasses to implement custom preprocessing + logic such as validation, normalization, or enrichment. + + Args: + request: The incoming TriviaRequest. + + Returns: + TriviaRequest: The preprocessed request. + """ + return request + + def postprocess_response(self, response: TriviaResponse) -> TriviaResponse: + """Transform response before returning to caller. + + Override this method in subclasses to implement custom postprocessing + logic such as formatting, cleanup, or validation. + + Args: + response: The TriviaResponse from the agent. + + Returns: + TriviaResponse: The postprocessed response. + """ + return response + def prepare( self, request: TriviaRequest, @@ -277,6 +310,10 @@ def prepare( for isolation, builds the complete PromptTemplate with workspace section, binds request parameters, and optionally applies experiment overrides. + Note: Request preprocessing is applied in the execute() method before + prepare() is called. This ensures the preprocessed request is used + consistently throughout the execution flow. + This method demonstrates key WINK patterns: - **Session per request**: Each request gets its own Session for proper @@ -348,6 +385,53 @@ def prepare( return prompt, session + def execute( + self, + request: TriviaRequest, + *, + budget: Budget | None = None, + deadline: Deadline | None = None, + resources: Mapping[type[object], object] | None = None, + heartbeat: Heartbeat | None = None, + experiment: Experiment | None = None, + ) -> tuple[PromptResponse[TriviaResponse], Session]: + """Execute a trivia request with preprocessing and postprocessing. + + Overrides the parent MainLoop.execute() to apply preprocess_request() + before execution and postprocess_response() after execution. + + Args: + request: TriviaRequest containing the question to process. + budget: Optional Budget for token/cost limits. + deadline: Optional Deadline for time limits. + resources: Optional mapping of resource types to instances. + heartbeat: Optional Heartbeat for progress reporting. + experiment: Optional Experiment for evaluation runs. + + Returns: + tuple[PromptResponse[TriviaResponse], Session]: Response and session. + """ + # Apply preprocessing to the request + preprocessed_request = self.preprocess_request(request) + + # Execute with the preprocessed request + prompt_response, session = super().execute( + preprocessed_request, + budget=budget, + deadline=deadline, + resources=resources, + heartbeat=heartbeat, + experiment=experiment, + ) + + # Apply postprocessing to the response output (if present) + output = prompt_response.output + if output is not None: + postprocessed_output = self.postprocess_response(output) + return prompt_response.update(output=postprocessed_output), session # type: ignore[return-value] + + return prompt_response, session + @FrozenDataclass() class TriviaRuntime: diff --git a/tests/trivia_agent/test_worker.py b/tests/trivia_agent/test_worker.py index df4e981..3f2bb8c 100644 --- a/tests/trivia_agent/test_worker.py +++ b/tests/trivia_agent/test_worker.py @@ -238,6 +238,189 @@ def test_prepare_seeds_overrides_store( assert call_args.kwargs.get("tag") == "latest" +class TestTriviaAgentLoopPreprocessing: + """Tests for TriviaAgentLoop.preprocess_request() method.""" + + def test_preprocess_request_returns_unchanged_by_default( + self, + fake_mailboxes: TriviaMailboxes, + ) -> None: + """Test that preprocess_request returns request unchanged by default.""" + mock_adapter: ProviderAdapter[TriviaResponse] = MagicMock() + + loop = TriviaAgentLoop( + adapter=mock_adapter, + requests=fake_mailboxes.requests, + ) + + request = TriviaRequest(question="test question") + result = loop.preprocess_request(request) + + assert result is request + + def test_preprocess_request_can_be_overridden( + self, + fake_mailboxes: TriviaMailboxes, + ) -> None: + """Test that preprocess_request can be overridden in subclass.""" + + class CustomLoop(TriviaAgentLoop): + def preprocess_request(self, request: TriviaRequest) -> TriviaRequest: + return TriviaRequest(question=request.question.upper()) + + mock_adapter: ProviderAdapter[TriviaResponse] = MagicMock() + + loop = CustomLoop( + adapter=mock_adapter, + requests=fake_mailboxes.requests, + ) + + request = TriviaRequest(question="hello world") + result = loop.preprocess_request(request) + + assert result.question == "HELLO WORLD" + + +class TestTriviaAgentLoopPostprocessing: + """Tests for TriviaAgentLoop.postprocess_response() method.""" + + def test_postprocess_response_returns_unchanged_by_default( + self, + fake_mailboxes: TriviaMailboxes, + ) -> None: + """Test that postprocess_response returns response unchanged by default.""" + mock_adapter: ProviderAdapter[TriviaResponse] = MagicMock() + + loop = TriviaAgentLoop( + adapter=mock_adapter, + requests=fake_mailboxes.requests, + ) + + response = TriviaResponse(answer="42") + result = loop.postprocess_response(response) + + assert result is response + + def test_postprocess_response_can_be_overridden( + self, + fake_mailboxes: TriviaMailboxes, + ) -> None: + """Test that postprocess_response can be overridden in subclass.""" + + class CustomLoop(TriviaAgentLoop): + def postprocess_response(self, response: TriviaResponse) -> TriviaResponse: + return TriviaResponse(answer=f"Answer: {response.answer}") + + mock_adapter: ProviderAdapter[TriviaResponse] = MagicMock() + + loop = CustomLoop( + adapter=mock_adapter, + requests=fake_mailboxes.requests, + ) + + response = TriviaResponse(answer="42") + result = loop.postprocess_response(response) + + assert result.answer == "Answer: 42" + + +class TestTriviaAgentLoopExecute: + """Tests for TriviaAgentLoop.execute() with preprocessing/postprocessing.""" + + def test_execute_calls_preprocess_request( + self, + fake_mailboxes: TriviaMailboxes, + ) -> None: + """Test that execute() calls preprocess_request.""" + from weakincentives.runtime import MainLoop + + class CustomLoop(TriviaAgentLoop): + def preprocess_request(self, request: TriviaRequest) -> TriviaRequest: + return TriviaRequest(question=request.question.strip()) + + mock_adapter: ProviderAdapter[TriviaResponse] = MagicMock() + + loop = CustomLoop( + adapter=mock_adapter, + requests=fake_mailboxes.requests, + ) + + mock_response = MagicMock() + mock_response.output = TriviaResponse(answer="42") + mock_session = MagicMock() + + captured_requests: list[TriviaRequest] = [] + + def capture_execute(self_arg, request, **kwargs): + captured_requests.append(request) + return (mock_response, mock_session) + + with patch.object(MainLoop, "execute", capture_execute): + request = TriviaRequest(question=" What is the answer? ") + loop.execute(request) + + assert len(captured_requests) == 1 + assert captured_requests[0].question == "What is the answer?" + + def test_execute_calls_postprocess_response( + self, + fake_mailboxes: TriviaMailboxes, + ) -> None: + """Test that execute() calls postprocess_response.""" + from weakincentives.runtime import MainLoop + + class CustomLoop(TriviaAgentLoop): + def postprocess_response(self, response: TriviaResponse) -> TriviaResponse: + return TriviaResponse(answer=response.answer.strip()) + + mock_adapter: ProviderAdapter[TriviaResponse] = MagicMock() + + loop = CustomLoop( + adapter=mock_adapter, + requests=fake_mailboxes.requests, + ) + + mock_response = MagicMock() + mock_response.output = TriviaResponse(answer=" 42 ") + mock_response.update = MagicMock(return_value=mock_response) + mock_session = MagicMock() + + with patch.object(MainLoop, "execute", return_value=(mock_response, mock_session)): + request = TriviaRequest(question="What is the answer?") + loop.execute(request) + + mock_response.update.assert_called_once() + call_kwargs = mock_response.update.call_args.kwargs + assert call_kwargs["output"].answer == "42" + + def test_execute_skips_postprocessing_when_output_is_none( + self, + fake_mailboxes: TriviaMailboxes, + ) -> None: + """Test that execute() skips postprocessing when output is None.""" + from weakincentives.runtime import MainLoop + + mock_adapter: ProviderAdapter[TriviaResponse] = MagicMock() + + loop = TriviaAgentLoop( + adapter=mock_adapter, + requests=fake_mailboxes.requests, + ) + + mock_response = MagicMock() + mock_response.output = None + mock_response.update = MagicMock() + mock_session = MagicMock() + + with patch.object(MainLoop, "execute", return_value=(mock_response, mock_session)): + request = TriviaRequest(question="What is the answer?") + result_response, result_session = loop.execute(request) + + mock_response.update.assert_not_called() + assert result_response is mock_response + assert result_session is mock_session + + class TestTriviaRuntime: """Tests for TriviaRuntime dataclass.""" From 54da0f1d84e5f5315c0a3f3405a5de70531b9100 Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 26 Jan 2026 08:26:54 +0000 Subject: [PATCH 2/2] Extend preprocessing/postprocessing hooks with Session and Prompt access - preprocess_request now receives Session for context-aware preprocessing - postprocess_response now receives both Prompt and Session - Preprocessing moved inside prepare() for proper request transformation - Use instance variables to pass context from prepare() to postprocess_response() --- src/trivia_agent/worker.py | 77 +++++++++++++-------------- tests/trivia_agent/test_worker.py | 87 ++++++++++++++++++++----------- 2 files changed, 96 insertions(+), 68 deletions(-) diff --git a/src/trivia_agent/worker.py b/src/trivia_agent/worker.py index 7192a1d..08c409d 100644 --- a/src/trivia_agent/worker.py +++ b/src/trivia_agent/worker.py @@ -269,29 +269,45 @@ def __init__( self._workspace_dir = workspace_dir or DEFAULT_WORKSPACE_DIR self._base_template = build_prompt_template() self._overrides_store = overrides_store + self._last_prompt: Prompt[TriviaResponse] | None = None + self._last_session: Session | None = None - def preprocess_request(self, request: TriviaRequest) -> TriviaRequest: + def preprocess_request( + self, + request: TriviaRequest, + session: Session, + ) -> TriviaRequest: """Transform request before agent processing. Override this method in subclasses to implement custom preprocessing - logic such as validation, normalization, or enrichment. + logic such as validation, normalization, or enrichment. Has access to + the session for context-aware preprocessing. Args: request: The incoming TriviaRequest. + session: The session for this execution. Returns: TriviaRequest: The preprocessed request. """ return request - def postprocess_response(self, response: TriviaResponse) -> TriviaResponse: + def postprocess_response( + self, + response: TriviaResponse, + prompt: Prompt[TriviaResponse], + session: Session, + ) -> TriviaResponse: """Transform response before returning to caller. Override this method in subclasses to implement custom postprocessing - logic such as formatting, cleanup, or validation. + logic such as formatting, cleanup, or validation. Has access to the + prompt and session for context-aware postprocessing. Args: response: The TriviaResponse from the agent. + prompt: The prompt used for this execution. + session: The session used for this execution. Returns: TriviaResponse: The postprocessed response. @@ -310,38 +326,20 @@ def prepare( for isolation, builds the complete PromptTemplate with workspace section, binds request parameters, and optionally applies experiment overrides. - Note: Request preprocessing is applied in the execute() method before - prepare() is called. This ensures the preprocessed request is used - consistently throughout the execution flow. - - This method demonstrates key WINK patterns: - - - **Session per request**: Each request gets its own Session for proper - isolation between concurrent requests - - **Dynamic workspace section**: ClaudeAgentWorkspaceSection is created - here (not in build_prompt_template) because it needs the Session - - **Parameter binding**: Binds QuestionParams with the user's question - and EmptyParams for parameterless sections - - **Experiment support**: Uses experiment.overrides_tag to select prompt - variants for A/B testing; defaults to "latest" - - **Override seeding**: Automatically seeds the overrides store with - current prompt state, creating editable files for customization + Calls preprocess_request() to allow request transformation before binding. Args: - request: TriviaRequest containing the question field. The question - is bound to QuestionSection via QuestionParams. - experiment: Optional Experiment instance for evaluation runs. When - provided, uses experiment.overrides_tag to select prompt variant. - Pass None for production requests. + request: TriviaRequest containing the question field. + experiment: Optional Experiment instance for evaluation runs. Returns: - tuple[Prompt[TriviaResponse], Session]: A 2-tuple containing: - - Prompt: Fully configured prompt with all sections and bound - parameters, ready for adapter.run() - - Session: Fresh session instance for this request's execution + tuple[Prompt[TriviaResponse], Session]: Prompt and session for execution. """ session = Session() + # Allow subclasses to preprocess the request + request = self.preprocess_request(request, session) + # Create workspace section with seeded files # This needs to be per-request because it references the session workspace_section = create_workspace_section( @@ -383,6 +381,10 @@ def prepare( prompt.bind(QuestionParams(question=request.question)) prompt.bind(EmptyParams()) # For sections without params (GameRules, Hints) + # Store for postprocess_response access + self._last_prompt = prompt + self._last_session = session + return prompt, session def execute( @@ -397,8 +399,8 @@ def execute( ) -> tuple[PromptResponse[TriviaResponse], Session]: """Execute a trivia request with preprocessing and postprocessing. - Overrides the parent MainLoop.execute() to apply preprocess_request() - before execution and postprocess_response() after execution. + Preprocessing happens in prepare() via preprocess_request(). + Postprocessing happens after execution via postprocess_response(). Args: request: TriviaRequest containing the question to process. @@ -411,12 +413,9 @@ def execute( Returns: tuple[PromptResponse[TriviaResponse], Session]: Response and session. """ - # Apply preprocessing to the request - preprocessed_request = self.preprocess_request(request) - - # Execute with the preprocessed request + # Execute (prepare() is called internally by parent, which calls preprocess_request) prompt_response, session = super().execute( - preprocessed_request, + request, budget=budget, deadline=deadline, resources=resources, @@ -426,8 +425,10 @@ def execute( # Apply postprocessing to the response output (if present) output = prompt_response.output - if output is not None: - postprocessed_output = self.postprocess_response(output) + if output is not None and self._last_prompt is not None: + postprocessed_output = self.postprocess_response( + output, self._last_prompt, self._last_session or session + ) return prompt_response.update(output=postprocessed_output), session # type: ignore[return-value] return prompt_response, session diff --git a/tests/trivia_agent/test_worker.py b/tests/trivia_agent/test_worker.py index 3f2bb8c..71c70fa 100644 --- a/tests/trivia_agent/test_worker.py +++ b/tests/trivia_agent/test_worker.py @@ -246,6 +246,8 @@ def test_preprocess_request_returns_unchanged_by_default( fake_mailboxes: TriviaMailboxes, ) -> None: """Test that preprocess_request returns request unchanged by default.""" + from weakincentives.runtime import Session + mock_adapter: ProviderAdapter[TriviaResponse] = MagicMock() loop = TriviaAgentLoop( @@ -254,7 +256,8 @@ def test_preprocess_request_returns_unchanged_by_default( ) request = TriviaRequest(question="test question") - result = loop.preprocess_request(request) + session = Session() + result = loop.preprocess_request(request, session) assert result is request @@ -263,9 +266,10 @@ def test_preprocess_request_can_be_overridden( fake_mailboxes: TriviaMailboxes, ) -> None: """Test that preprocess_request can be overridden in subclass.""" + from weakincentives.runtime import Session class CustomLoop(TriviaAgentLoop): - def preprocess_request(self, request: TriviaRequest) -> TriviaRequest: + def preprocess_request(self, request: TriviaRequest, session: Session) -> TriviaRequest: return TriviaRequest(question=request.question.upper()) mock_adapter: ProviderAdapter[TriviaResponse] = MagicMock() @@ -276,7 +280,8 @@ def preprocess_request(self, request: TriviaRequest) -> TriviaRequest: ) request = TriviaRequest(question="hello world") - result = loop.preprocess_request(request) + session = Session() + result = loop.preprocess_request(request, session) assert result.question == "HELLO WORLD" @@ -289,6 +294,8 @@ def test_postprocess_response_returns_unchanged_by_default( fake_mailboxes: TriviaMailboxes, ) -> None: """Test that postprocess_response returns response unchanged by default.""" + from weakincentives.runtime import Session + mock_adapter: ProviderAdapter[TriviaResponse] = MagicMock() loop = TriviaAgentLoop( @@ -297,7 +304,9 @@ def test_postprocess_response_returns_unchanged_by_default( ) response = TriviaResponse(answer="42") - result = loop.postprocess_response(response) + mock_prompt = MagicMock() + session = Session() + result = loop.postprocess_response(response, mock_prompt, session) assert result is response @@ -306,9 +315,16 @@ def test_postprocess_response_can_be_overridden( fake_mailboxes: TriviaMailboxes, ) -> None: """Test that postprocess_response can be overridden in subclass.""" + from weakincentives import Prompt + from weakincentives.runtime import Session class CustomLoop(TriviaAgentLoop): - def postprocess_response(self, response: TriviaResponse) -> TriviaResponse: + def postprocess_response( + self, + response: TriviaResponse, + prompt: Prompt[TriviaResponse], + session: Session, + ) -> TriviaResponse: return TriviaResponse(answer=f"Answer: {response.answer}") mock_adapter: ProviderAdapter[TriviaResponse] = MagicMock() @@ -319,7 +335,9 @@ def postprocess_response(self, response: TriviaResponse) -> TriviaResponse: ) response = TriviaResponse(answer="42") - result = loop.postprocess_response(response) + mock_prompt = MagicMock() + session = Session() + result = loop.postprocess_response(response, mock_prompt, session) assert result.answer == "Answer: 42" @@ -327,15 +345,18 @@ def postprocess_response(self, response: TriviaResponse) -> TriviaResponse: class TestTriviaAgentLoopExecute: """Tests for TriviaAgentLoop.execute() with preprocessing/postprocessing.""" - def test_execute_calls_preprocess_request( + def test_execute_calls_preprocess_request_via_prepare( self, fake_mailboxes: TriviaMailboxes, ) -> None: - """Test that execute() calls preprocess_request.""" - from weakincentives.runtime import MainLoop + """Test that prepare() calls preprocess_request during execution.""" + from weakincentives.runtime import Session + + preprocess_calls: list[TriviaRequest] = [] class CustomLoop(TriviaAgentLoop): - def preprocess_request(self, request: TriviaRequest) -> TriviaRequest: + def preprocess_request(self, request: TriviaRequest, session: Session) -> TriviaRequest: + preprocess_calls.append(request) return TriviaRequest(question=request.question.strip()) mock_adapter: ProviderAdapter[TriviaResponse] = MagicMock() @@ -345,32 +366,33 @@ def preprocess_request(self, request: TriviaRequest) -> TriviaRequest: requests=fake_mailboxes.requests, ) - mock_response = MagicMock() - mock_response.output = TriviaResponse(answer="42") - mock_session = MagicMock() - - captured_requests: list[TriviaRequest] = [] - - def capture_execute(self_arg, request, **kwargs): - captured_requests.append(request) - return (mock_response, mock_session) - - with patch.object(MainLoop, "execute", capture_execute): - request = TriviaRequest(question=" What is the answer? ") - loop.execute(request) + request = TriviaRequest(question=" What is the answer? ") + prompt, session = loop.prepare(request) - assert len(captured_requests) == 1 - assert captured_requests[0].question == "What is the answer?" + assert len(preprocess_calls) == 1 + assert preprocess_calls[0].question == " What is the answer? " + # The prompt should have the preprocessed question bound + rendered = str(prompt.render()) + assert "What is the answer?" in rendered def test_execute_calls_postprocess_response( self, fake_mailboxes: TriviaMailboxes, ) -> None: """Test that execute() calls postprocess_response.""" - from weakincentives.runtime import MainLoop + from weakincentives import Prompt + from weakincentives.runtime import MainLoop, Session + + postprocess_calls: list[tuple[TriviaResponse, bool, bool]] = [] class CustomLoop(TriviaAgentLoop): - def postprocess_response(self, response: TriviaResponse) -> TriviaResponse: + def postprocess_response( + self, + response: TriviaResponse, + prompt: Prompt[TriviaResponse], + session: Session, + ) -> TriviaResponse: + postprocess_calls.append((response, prompt is not None, session is not None)) return TriviaResponse(answer=response.answer.strip()) mock_adapter: ProviderAdapter[TriviaResponse] = MagicMock() @@ -385,13 +407,18 @@ def postprocess_response(self, response: TriviaResponse) -> TriviaResponse: mock_response.update = MagicMock(return_value=mock_response) mock_session = MagicMock() + # Call prepare() first to set up _last_prompt/_last_session + request = TriviaRequest(question="What is the answer?") + loop.prepare(request) + with patch.object(MainLoop, "execute", return_value=(mock_response, mock_session)): - request = TriviaRequest(question="What is the answer?") loop.execute(request) + assert len(postprocess_calls) == 1 + assert postprocess_calls[0][0].answer == " 42 " + assert postprocess_calls[0][1] is True # prompt was passed + assert postprocess_calls[0][2] is True # session was passed mock_response.update.assert_called_once() - call_kwargs = mock_response.update.call_args.kwargs - assert call_kwargs["output"].answer == "42" def test_execute_skips_postprocessing_when_output_is_none( self,