Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions akd_ext/agents/_base/pydantic_ai/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class MyAgent(PydanticAIBaseAgent[MyIn, MyOut]):
from akd._base.structures import RunContext as AKDRunContext
from akd.agents._base import BaseAgentConfig

from ._capabilities import make_ratio_trimmer
from ._event_translator import pai_event_to_akd_event
from ._tool_adapter import akd_to_pai_tool

Expand All @@ -76,6 +77,24 @@ class PydanticAIBaseAgentConfig(BaseAgentConfig):
"Merged with any capabilities auto-derived from AKD scalar fields."
),
)
history_processors: list[Any] = Field(
default_factory=list,
description=(
"Pydantic AI history processors (callables that shape the message "
"history before each model request). Merged with any processors "
"auto-derived from AKD scalar fields."
),
)

# Override ``BaseAgentConfig.enable_trimming`` (default ``True``) to ``False``:
# the naive ratio-based trimmer we ship in ``_capabilities.make_ratio_trimmer``
# drops arbitrary message-history slices, which can strand a
# ``ToolReturnPart`` that was paired with an earlier ``ToolCallPart`` —
# OpenAI's API rejects histories shaped that way. Consumers who know their
# history doesn't contain tool calls (or who supply their own pairing-aware
# processor via ``history_processors``) can opt in explicitly with
# ``enable_trimming=True``.
enable_trimming: bool = Field(default=False)

# -- Silence AKD-core's litellm-based config validators --------------
# The following validator help for lookups that expect
Expand Down Expand Up @@ -169,6 +188,10 @@ def __init__(self, config: PydanticAIBaseAgentConfig | None = None) -> None:
*self._build_capabilities_from_scalars(),
*self.config.capabilities,
],
history_processors=[
*self._build_history_processors_from_scalars(),
*self.config.history_processors,
],
**extra_kwargs,
)

Expand Down Expand Up @@ -338,6 +361,27 @@ def _build_capabilities_from_scalars(self) -> list[AbstractCapability]:
"""
return []

def _build_history_processors_from_scalars(self) -> list:
    """Translate AKD scalar config fields into history processors.

    Current wiring:

    - ``enable_trimming`` + ``trim_ratio`` → ``make_ratio_trimmer``.
      ``trim_ratio`` is AKD's *retention* ratio (``0.75`` = keep 75%),
      whereas the trimmer expects the fraction to *drop*, so the value
      handed over is ``1 - trim_ratio``.

    ``PydanticAIBaseAgentConfig`` overrides ``enable_trimming`` to
    ``False`` (the comment on that field explains the tool-call pairing
    caveat), so out of the box this returns an empty list.

    Subclasses override to append their own; call ``super()`` first.

    Returns:
        The scalar-derived processors: empty when trimming is disabled,
        otherwise a single-element list holding the ratio trimmer.
    """
    cfg = self.config
    if not cfg.enable_trimming:
        return []
    # Invert retention ratio into the drop fraction the trimmer expects.
    drop_fraction = 1 - cfg.trim_ratio
    return [make_ratio_trimmer(drop_fraction)]

# ── Tool adaptation ──────────────────────────────────────────────────

def _adapt_tools(self, tools: list) -> list:
Expand Down
52 changes: 52 additions & 0 deletions akd_ext/agents/_base/pydantic_ai/_capabilities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""Custom pydantic_ai capabilities and history processors built from AKD
scalar config fields.

Concrete implementations behind ``PydanticAIBaseAgent``'s
``_build_capabilities_from_scalars`` and
``_build_history_processors_from_scalars`` hooks. Factories here return
ready-to-use pydantic_ai objects (``Hooks`` capabilities or plain
history-processor callables); the base agent constructs them at
``__init__`` time from the relevant ``BaseAgentConfig`` fields.
"""

from __future__ import annotations

from collections.abc import Callable
from typing import Any


def make_ratio_trimmer(trim_ratio: float) -> Callable[[list[Any]], list[Any]]:
    """Return a stateless history processor that drops the oldest messages.

    The returned processor always keeps the first message (so a
    system-prompt-style preamble survives trimming) and removes the oldest
    ``trim_ratio`` fraction of the *remaining* messages on every
    invocation. A ``trim_ratio`` of ``0.3`` applied to one head message
    plus ten others drops the three oldest of those ten.

    Args:
        trim_ratio: Fraction of the non-head messages to drop, in
            ``[0, 1)``. ``0.0`` is a no-op; ratios ≥ ``1.0`` are rejected
            because they would drop everything after the head.

    Returns:
        A callable that maps a message list to its trimmed version. The
        input list is never mutated. When the input is empty or
        ``trim_ratio`` is ``0`` the very same list object is returned;
        otherwise a new list is built.

    Raises:
        ValueError: If ``trim_ratio`` is outside ``[0, 1)``.

    .. warning::

        Dropping the oldest messages without regard to their content can
        strand pydantic_ai ``ToolReturnPart`` messages that were paired
        with an earlier ``ToolCallPart``. OpenAI's API rejects such
        histories. This trimmer is safe only when callers know their
        history shape (e.g. no tool calls, or a custom trimmer hook
        manages pairing). ``PydanticAIBaseAgentConfig`` defaults
        ``enable_trimming=False`` for exactly that reason — callers opt in.
    """
    if not 0 <= trim_ratio < 1:
        raise ValueError(f"trim_ratio must be in [0, 1), got {trim_ratio!r}")

    # Binary floats put products such as ``100 * 0.29`` or
    # ``10 * (1 - 0.9)`` a hair *below* the intended integer
    # (28.999999999999996, 0.9999999999999998). Truncating those directly
    # would drop one message too few, so nudge the product up by a
    # tolerance far larger than float error yet far smaller than any
    # meaningful fractional part before truncating.
    rounding_slack = 1e-9

    def processor(messages: list[Any]) -> list[Any]:
        if not messages or trim_ratio == 0:
            return messages
        head, rest = messages[:1], messages[1:]
        drop_count = int(len(rest) * trim_ratio + rounding_slack)
        return [*head, *rest[drop_count:]]

    return processor


__all__ = [
"make_ratio_trimmer",
]
87 changes: 87 additions & 0 deletions tests/agents/test_base_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,3 +438,90 @@ async def test_existing_akd_tool_is_pai_compatible():
assert isinstance(akd_tool, AKDTool)
# And the input schema reference is intact.
assert akd_tool.input_schema is DummyInputSchema


# ---------------------------------------------------------------------------
# make_ratio_trimmer history processor + wiring
# ---------------------------------------------------------------------------


def test_ratio_trimmer_is_noop_at_zero():
    """A zero ratio acts as an explicit "don't trim" setting: the
    processor hands back a list equal to its input."""
    from akd_ext.agents._base.pydantic_ai._capabilities import make_ratio_trimmer

    history = [f"msg-{i}" for i in range(5)]
    passthrough = make_ratio_trimmer(0)
    assert passthrough(history) == history


def test_ratio_trimmer_drops_oldest_fraction_and_preserves_head():
    """With a non-zero ratio the first message is kept as a
    system-prompt-like head and the oldest ``ratio`` fraction of the
    messages that follow it is removed."""
    from akd_ext.agents._base.pydantic_ai._capabilities import make_ratio_trimmer

    history = ["system", *(f"msg-{i}" for i in range(10))]
    # 0.5 → the oldest half of the ten tail messages goes away.
    head, *tail = make_ratio_trimmer(0.5)(history)
    # The head survives untouched.
    assert head == "system"
    # Only the newest five tail messages remain, in order.
    assert tail == [f"msg-{i}" for i in range(5, 10)]


def test_ratio_trimmer_empty_input_passthrough():
    """An empty history stays empty even when the ratio is non-zero."""
    from akd_ext.agents._base.pydantic_ai._capabilities import make_ratio_trimmer

    trimmer = make_ratio_trimmer(0.5)
    assert trimmer([]) == []


def test_ratio_trimmer_rejects_invalid_ratio():
    """The factory validates eagerly: a ratio outside ``[0, 1)`` —
    negative or ≥1 — raises ``ValueError`` before any processor exists."""
    from akd_ext.agents._base.pydantic_ai._capabilities import make_ratio_trimmer

    # One value below the range, one at the excluded upper bound.
    for bad_ratio in (-0.1, 1.0):
        with pytest.raises(ValueError, match="trim_ratio"):
            make_ratio_trimmer(bad_ratio)


def test_build_history_processors_from_scalars_off_by_default():
    """A default config leaves ``enable_trimming`` at ``False``, so no
    scalar-derived history processors are produced."""
    default_agent = _EchoAgent(_EchoConfig())
    processors = default_agent._build_history_processors_from_scalars()
    assert processors == []


def test_build_history_processors_from_scalars_includes_trimmer_when_enabled():
    """Opting in with ``enable_trimming=True`` yields exactly one callable:
    the trimmer built from ``make_ratio_trimmer(1 - trim_ratio)``."""
    # 0.5 is exactly representable, so ``1 - trim_ratio`` carries no
    # floating-point drift into the assertion below.
    config = _EchoConfig(enable_trimming=True, trim_ratio=0.5)
    processors = _EchoAgent(config)._build_history_processors_from_scalars()
    assert len(processors) == 1
    trimmer = processors[0]
    assert callable(trimmer)
    # Behavioural check: on 1 head + 10 tail messages a 0.5 drop fraction
    # removes five of the tail messages and keeps the head.
    trimmed = trimmer(["sys", *(f"m-{i}" for i in range(10))])
    assert trimmed[0] == "sys"
    assert len(trimmed) == 1 + 5  # head + (10 - int(10 * 0.5)) kept


def test_config_history_processors_pass_through():
    """Processors supplied through ``config.history_processors`` are
    retained on the constructed agent's config.

    ``_build_history_processors_from_scalars`` only returns
    scalar-derived processors; caller-supplied ones are merged in
    separately when the agent is constructed. This test checks the
    config side of that hand-off, which is observable without reaching
    into pydantic_ai internals."""

    def custom_processor(messages):
        return messages  # no-op

    config = _EchoConfig(history_processors=[custom_processor])
    agent = _EchoAgent(config)
    # The custom callable must still be registered after construction.
    assert custom_processor in agent.config.history_processors