Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 109 additions & 0 deletions src/tools/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
"""Generic tool registry — dispatcher for typed agent tools.

Each tool is a function that takes a single ``StrictModel`` input and returns
a single ``StrictModel`` output. The registry holds a mapping from tool name
to ``(input_schema, callable)`` and provides:

- ``register(name, input_schema)`` — decorator to register a callable.
- ``dispatch(name, raw_input)`` — validate the dict-shaped ``raw_input``
against the input schema, call the
tool, return the typed output.

Layer-wise the registry sits below ``agent`` / ``api`` / ``eval`` (it doesn't
import from them) and above ``models``. Verified by the import-linter
contract in ``pyproject.toml``.
"""

from __future__ import annotations

from collections.abc import Callable
from typing import Any

from src.models._base import StrictModel

ToolFn = Callable[[StrictModel], StrictModel]


class UnknownToolError(KeyError):
"""Raised when ``dispatch`` is called with an unregistered tool name."""


class Registry:
"""Maps a tool name to its input schema and callable implementation."""

def __init__(self) -> None:
self._tools: dict[str, tuple[type[StrictModel], ToolFn]] = {}

def register(
self,
name: str,
input_schema: type[StrictModel],
) -> Callable[[ToolFn], ToolFn]:
"""Register a tool implementation.

Returns a decorator so callers can use either of:

@registry.register("echo", EchoToolInput)
def echo_tool(payload: EchoToolInput) -> EchoToolOutput: ...

registry.register("echo", EchoToolInput)(echo_tool)
"""

def decorator(fn: ToolFn) -> ToolFn:
if name in self._tools:
msg = f"Tool {name!r} is already registered."
raise ValueError(msg)
self._tools[name] = (input_schema, fn)
return fn

return decorator

def dispatch(self, name: str, raw_input: dict[str, Any]) -> StrictModel:
"""Validate ``raw_input`` and call the tool.

Raises ``UnknownToolError`` when *name* isn't registered. Pydantic's
``ValidationError`` propagates when ``raw_input`` doesn't match the
registered input schema.
"""
if name not in self._tools:
registered = sorted(self._tools)
msg = f"Unknown tool {name!r}. Registered: {registered}"
raise UnknownToolError(msg)
input_schema, fn = self._tools[name]
payload = input_schema.model_validate(raw_input)
return fn(payload)

def names(self) -> list[str]:
"""Return the sorted list of registered tool names."""
return sorted(self._tools)


# Module-global singleton — agent / eval consumers import this directly so
# tools self-register at module load via the decorator below.
registry = Registry()


# ---------------------------------------------------------------------------
# Example tool: echo — exercises the layer + demonstrates the contract shape.
# ---------------------------------------------------------------------------


class EchoToolInput(StrictModel, strict=True):
"""Input contract for the example echo tool."""

msg: str


class EchoToolOutput(StrictModel, strict=True):
"""Output contract for the example echo tool."""

echoed: str


@registry.register("echo", EchoToolInput)
def echo_tool(payload: StrictModel) -> StrictModel:
"""Return the input string wrapped in ``EchoToolOutput``."""
if not isinstance(payload, EchoToolInput): # pragma: no cover — defensive
msg = f"echo_tool got unexpected payload type: {type(payload)!r}"
raise TypeError(msg)
return EchoToolOutput(echoed=payload.msg)
60 changes: 60 additions & 0 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""Tests for ``src.tools.registry`` — happy path, unknown tool, bad input."""

from __future__ import annotations

import pytest
from pydantic import ValidationError

from src.tools.registry import (
EchoToolInput,
EchoToolOutput,
Registry,
UnknownToolError,
echo_tool,
registry,
)


def test_module_registry_resolves_echo() -> None:
"""The module-global registry has the echo tool wired at import."""
assert "echo" in registry.names()


def test_dispatch_happy_path() -> None:
output = registry.dispatch("echo", {"msg": "hello"})
assert isinstance(output, EchoToolOutput)
assert output.echoed == "hello"


def test_dispatch_unknown_tool_raises() -> None:
with pytest.raises(UnknownToolError, match="Unknown tool 'nope'"):
registry.dispatch("nope", {})


def test_dispatch_rejects_bad_input() -> None:
"""Wrong-typed payload triggers Pydantic ValidationError, not a runtime crash."""
with pytest.raises(ValidationError):
registry.dispatch("echo", {"msg": 123}) # msg must be str


def test_dispatch_rejects_unknown_keys() -> None:
"""StrictModel input schema rejects unknown keys — extra='forbid' propagates."""
with pytest.raises(ValidationError):
registry.dispatch("echo", {"msg": "hi", "boom": True})


def test_register_rejects_duplicate_names() -> None:
"""Registering twice under the same name is a programmer error."""
local = Registry()
local.register("echo", EchoToolInput)(echo_tool)
with pytest.raises(ValueError, match="already registered"):
local.register("echo", EchoToolInput)(echo_tool)


def test_local_registry_isolation() -> None:
"""Multiple registries don't share state."""
a = Registry()
b = Registry()
a.register("echo", EchoToolInput)(echo_tool)
assert a.names() == ["echo"]
assert b.names() == []
Loading