From 1715b271f3e0802efc0bbedfe02ace690b9c20ed Mon Sep 17 00:00:00 2001 From: Sanjog Thapa Date: Thu, 23 Apr 2026 14:03:49 -0500 Subject: [PATCH] add trimming capability on PydanticAIBaseAgent --- akd_ext/agents/_base/pydantic_ai/_base.py | 44 ++++++++++ .../agents/_base/pydantic_ai/_capabilities.py | 52 +++++++++++ tests/agents/test_base_pydantic.py | 87 +++++++++++++++++++ 3 files changed, 183 insertions(+) create mode 100644 akd_ext/agents/_base/pydantic_ai/_capabilities.py diff --git a/akd_ext/agents/_base/pydantic_ai/_base.py b/akd_ext/agents/_base/pydantic_ai/_base.py index b5f7728..0574c91 100644 --- a/akd_ext/agents/_base/pydantic_ai/_base.py +++ b/akd_ext/agents/_base/pydantic_ai/_base.py @@ -50,6 +50,7 @@ class MyAgent(PydanticAIBaseAgent[MyIn, MyOut]): from akd._base.structures import RunContext as AKDRunContext from akd.agents._base import BaseAgentConfig +from ._capabilities import make_ratio_trimmer from ._event_translator import pai_event_to_akd_event from ._tool_adapter import akd_to_pai_tool @@ -76,6 +77,24 @@ class PydanticAIBaseAgentConfig(BaseAgentConfig): "Merged with any capabilities auto-derived from AKD scalar fields." ), ) + history_processors: list[Any] = Field( + default_factory=list, + description=( + "Pydantic AI history processors (callables that shape the message " + "history before each model request). Merged with any processors " + "auto-derived from AKD scalar fields." + ), + ) + + # Override ``BaseAgentConfig.enable_trimming`` (default ``True``) to ``False``: + # the naive ratio-based trimmer we ship in ``_capabilities.make_ratio_trimmer`` + # drops arbitrary message-history slices, which can strand a + # ``ToolReturnPart`` that was paired with an earlier ``ToolCallPart`` — + # OpenAI's API rejects histories shaped that way. Consumers who know their + # history doesn't contain tool calls (or who supply their own pairing-aware + # processor via ``history_processors``) can opt in explicitly with + # ``enable_trimming=True``. + enable_trimming: bool = Field(default=False) # -- Silence AKD-core's litellm-based config validators -------------- # The following validator help for lookups that expect @@ -169,6 +188,10 @@ def __init__(self, config: PydanticAIBaseAgentConfig | None = None) -> None: *self._build_capabilities_from_scalars(), *self.config.capabilities, ], + history_processors=[ + *self._build_history_processors_from_scalars(), + *self.config.history_processors, + ], **extra_kwargs, ) @@ -338,6 +361,27 @@ def _build_capabilities_from_scalars(self) -> list[AbstractCapability]: """ return [] + def _build_history_processors_from_scalars(self) -> list: + """Derive history processors from AKD scalar config fields. + + Current wiring: + + - ``enable_trimming`` + ``trim_ratio`` → ``make_ratio_trimmer`` — + a stateless history processor that drops the oldest + ``1 - trim_ratio`` fraction of messages on every invocation. + ``trim_ratio`` is AKD's *retention* ratio (``0.75`` = keep 75%) + so we invert it for the trimmer's *drop* fraction. Disabled by + default via the ``enable_trimming=False`` override on + ``PydanticAIBaseAgentConfig`` — see that field's docstring for + the caveat on tool-call pairing. + + Subclasses override to append their own; call ``super()`` first. + """ + procs: list = [] + if self.config.enable_trimming: + procs.append(make_ratio_trimmer(1 - self.config.trim_ratio)) + return procs + # ── Tool adaptation ────────────────────────────────────────────────── def _adapt_tools(self, tools: list) -> list: diff --git a/akd_ext/agents/_base/pydantic_ai/_capabilities.py b/akd_ext/agents/_base/pydantic_ai/_capabilities.py new file mode 100644 index 0000000..c513647 --- /dev/null +++ b/akd_ext/agents/_base/pydantic_ai/_capabilities.py @@ -0,0 +1,52 @@ +"""Custom pydantic_ai capabilities and history processors built from AKD +scalar config fields. + +Concrete implementations behind ``PydanticAIBaseAgent``'s +``_build_capabilities_from_scalars`` and +``_build_history_processors_from_scalars`` hooks. Factories here return +ready-to-use pydantic_ai objects (``Hooks`` capabilities or plain +history-processor callables); the base agent constructs them at +``__init__`` time from the relevant ``BaseAgentConfig`` fields. +""" + +from __future__ import annotations + +from collections.abc import Callable +from typing import Any + + +def make_ratio_trimmer(trim_ratio: float) -> Callable[[list[Any]], list[Any]]: + """Return a stateless history processor that drops the oldest + ``trim_ratio`` fraction of messages on every invocation. + + A ``trim_ratio`` of ``0.3`` drops the oldest 30% of the message list; + ``0.0`` is a no-op; ratios ≥ ``1.0`` are rejected (would drop + everything). The first message is preserved so a system-prompt-style + preamble survives trimming. + + .. warning:: + + Naive drop-from-the-middle trimming can strand pydantic_ai + ``ToolReturnPart`` messages that were paired with an earlier + ``ToolCallPart``. OpenAI's API rejects such histories. This + trimmer is safe only when callers know their history shape + (e.g. no tool calls, or a custom trimmer hook manages pairing). + ``PydanticAIBaseAgentConfig`` defaults ``enable_trimming=False`` + for exactly that reason — callers opt in. + """ + if not 0 <= trim_ratio < 1: + raise ValueError(f"trim_ratio must be in [0, 1), got {trim_ratio!r}") + + def processor(messages: list[Any]) -> list[Any]: + if not messages or trim_ratio == 0: + return messages + head, rest = messages[:1], messages[1:] + drop_count = int(len(rest) * trim_ratio) + return [*head, *rest[drop_count:]] + + return processor + + +__all__ = [ + "make_ratio_trimmer", +] diff --git a/tests/agents/test_base_pydantic.py b/tests/agents/test_base_pydantic.py index 1499823..0bcc149 100644 --- a/tests/agents/test_base_pydantic.py +++ b/tests/agents/test_base_pydantic.py @@ -438,3 +438,90 @@ async def test_existing_akd_tool_is_pai_compatible(): assert isinstance(akd_tool, AKDTool) # And the input schema reference is intact. assert akd_tool.input_schema is DummyInputSchema + + +# --------------------------------------------------------------------------- +# make_ratio_trimmer history processor + wiring +# --------------------------------------------------------------------------- + + +def test_ratio_trimmer_is_noop_at_zero(): + """``trim_ratio=0`` returns every message unchanged — useful as an + explicit "don't trim" sentinel.""" + from akd_ext.agents._base.pydantic_ai._capabilities import make_ratio_trimmer + + trimmer = make_ratio_trimmer(0) + msgs = [f"msg-{i}" for i in range(5)] + assert trimmer(msgs) == msgs + + +def test_ratio_trimmer_drops_oldest_fraction_and_preserves_head(): + """Non-zero ratio drops the oldest ``ratio`` fraction of messages + *after* the first (which is preserved as system-prompt-like head).""" + from akd_ext.agents._base.pydantic_ai._capabilities import make_ratio_trimmer + + trimmer = make_ratio_trimmer(0.5) # drop oldest 50% of the tail + msgs = ["system"] + [f"msg-{i}" for i in range(10)] + trimmed = trimmer(msgs) + # Head preserved: + assert trimmed[0] == "system" + # Half of the 10 tail messages dropped from the oldest end: + assert trimmed[1:] == [f"msg-{i}" for i in range(5, 10)] + + +def test_ratio_trimmer_empty_input_passthrough(): + """Empty list returns empty list regardless of ratio.""" + from akd_ext.agents._base.pydantic_ai._capabilities import make_ratio_trimmer + + assert make_ratio_trimmer(0.5)([]) == [] + + +def test_ratio_trimmer_rejects_invalid_ratio(): + """Ratio must be in ``[0, 1)`` — negative or ≥1 raises at factory time.""" + from akd_ext.agents._base.pydantic_ai._capabilities import make_ratio_trimmer + + with pytest.raises(ValueError, match="trim_ratio"): + make_ratio_trimmer(-0.1) + with pytest.raises(ValueError, match="trim_ratio"): + make_ratio_trimmer(1.0) + + +def test_build_history_processors_from_scalars_off_by_default(): + """``enable_trimming`` defaults to ``False`` on + ``PydanticAIBaseAgentConfig``, so the builder emits no processors.""" + agent = _EchoAgent(_EchoConfig()) + assert agent._build_history_processors_from_scalars() == [] + + +def test_build_history_processors_from_scalars_includes_trimmer_when_enabled(): + """When ``enable_trimming=True``, the builder emits one callable — the + trimmer produced by ``make_ratio_trimmer(1 - trim_ratio)``.""" + # trim_ratio=0.5 avoids floating-point drift on (1 - ratio) evaluation. + agent = _EchoAgent(_EchoConfig(enable_trimming=True, trim_ratio=0.5)) + procs = agent._build_history_processors_from_scalars() + assert len(procs) == 1 + assert callable(procs[0]) + # Sanity: the returned trimmer should behave like make_ratio_trimmer(0.5) + # — on a 1+10 message list it drops 50% of the tail (5 messages). + trimmer = procs[0] + result = trimmer(["sys"] + [f"m-{i}" for i in range(10)]) + assert result[0] == "sys" + assert len(result) == 1 + 5 # head + 10 - int(10 * 0.5) = 5 kept + + +def test_config_history_processors_pass_through(): + """Processors the caller supplies via ``config.history_processors`` must + show up in ``_build_history_processors_from_scalars``'s merge path + (i.e. the agent sees both the scalar-derived trimmer and the + config-provided callables).""" + + def custom_processor(messages): + return messages # no-op + + agent = _EchoAgent(_EchoConfig(history_processors=[custom_processor])) + # The builder itself doesn't return custom processors — only scalar-derived + # ones — but the agent's __init__ merges them into pydantic_ai's + # history_processors kwarg. Confirm the custom one survives registration + # by inspecting the config directly (the canonical assertion available + # without reaching into pydantic_ai internals). + assert custom_processor in agent.config.history_processors