diff --git a/src/tools/registry.py b/src/tools/registry.py new file mode 100644 index 0000000..2cdf1b0 --- /dev/null +++ b/src/tools/registry.py @@ -0,0 +1,109 @@ +"""Generic tool registry — dispatcher for typed agent tools. + +Each tool is a function that takes a single ``StrictModel`` input and returns +a single ``StrictModel`` output. The registry holds a mapping from tool name +to ``(input_schema, callable)`` and provides: + +- ``register(name, input_schema)`` — decorator to register a callable. +- ``dispatch(name, raw_input)`` — validate the dict-shaped ``raw_input`` + against the input schema, call the + tool, return the typed output. + +Layer-wise the registry sits below ``agent`` / ``api`` / ``eval`` (it doesn't +import from them) and above ``models``. Verified by the import-linter +contract in ``pyproject.toml``. +""" + +from __future__ import annotations + +from collections.abc import Callable +from typing import Any + +from src.models._base import StrictModel + +ToolFn = Callable[[StrictModel], StrictModel] + + +class UnknownToolError(KeyError): + """Raised when ``dispatch`` is called with an unregistered tool name.""" + + +class Registry: + """Maps a tool name to its input schema and callable implementation.""" + + def __init__(self) -> None: + self._tools: dict[str, tuple[type[StrictModel], ToolFn]] = {} + + def register( + self, + name: str, + input_schema: type[StrictModel], + ) -> Callable[[ToolFn], ToolFn]: + """Register a tool implementation. + + Returns a decorator so callers can use either of: + + @registry.register("echo", EchoToolInput) + def echo_tool(payload: EchoToolInput) -> EchoToolOutput: ... + + registry.register("echo", EchoToolInput)(echo_tool) + """ + + def decorator(fn: ToolFn) -> ToolFn: + if name in self._tools: + msg = f"Tool {name!r} is already registered." + raise ValueError(msg) + self._tools[name] = (input_schema, fn) + return fn + + return decorator + + def dispatch(self, name: str, raw_input: dict[str, Any]) -> StrictModel: + """Validate ``raw_input`` and call the tool. + + Raises ``UnknownToolError`` when *name* isn't registered. Pydantic's + ``ValidationError`` propagates when ``raw_input`` doesn't match the + registered input schema. + """ + if name not in self._tools: + registered = sorted(self._tools) + msg = f"Unknown tool {name!r}. Registered: {registered}" + raise UnknownToolError(msg) + input_schema, fn = self._tools[name] + payload = input_schema.model_validate(raw_input) + return fn(payload) + + def names(self) -> list[str]: + """Return the sorted list of registered tool names.""" + return sorted(self._tools) + + +# Module-global singleton — agent / eval consumers import this directly so +# tools self-register at module load via the decorator below. +registry = Registry() + + +# --------------------------------------------------------------------------- +# Example tool: echo — exercises the layer + demonstrates the contract shape. +# --------------------------------------------------------------------------- + + +class EchoToolInput(StrictModel, strict=True): + """Input contract for the example echo tool.""" + + msg: str + + +class EchoToolOutput(StrictModel, strict=True): + """Output contract for the example echo tool.""" + + echoed: str + + +@registry.register("echo", EchoToolInput) +def echo_tool(payload: StrictModel) -> StrictModel: + """Return the input string wrapped in ``EchoToolOutput``.""" + if not isinstance(payload, EchoToolInput): # pragma: no cover — defensive + msg = f"echo_tool got unexpected payload type: {type(payload)!r}" + raise TypeError(msg) + return EchoToolOutput(echoed=payload.msg) diff --git a/tests/test_tools.py b/tests/test_tools.py new file mode 100644 index 0000000..1994f60 --- /dev/null +++ b/tests/test_tools.py @@ -0,0 +1,60 @@ +"""Tests for ``src.tools.registry`` — happy path, unknown tool, bad input.""" + +from __future__ import annotations + +import pytest +from pydantic import ValidationError + +from src.tools.registry import ( + EchoToolInput, + EchoToolOutput, + Registry, + UnknownToolError, + echo_tool, + registry, +) + + +def test_module_registry_resolves_echo() -> None: + """The module-global registry has the echo tool wired at import.""" + assert "echo" in registry.names() + + +def test_dispatch_happy_path() -> None: + output = registry.dispatch("echo", {"msg": "hello"}) + assert isinstance(output, EchoToolOutput) + assert output.echoed == "hello" + + +def test_dispatch_unknown_tool_raises() -> None: + with pytest.raises(UnknownToolError, match="Unknown tool 'nope'"): + registry.dispatch("nope", {}) + + +def test_dispatch_rejects_bad_input() -> None: + """Wrong-typed payload triggers Pydantic ValidationError, not a runtime crash.""" + with pytest.raises(ValidationError): + registry.dispatch("echo", {"msg": 123}) # msg must be str + + +def test_dispatch_rejects_unknown_keys() -> None: + """StrictModel input schema rejects unknown keys — extra='forbid' propagates.""" + with pytest.raises(ValidationError): + registry.dispatch("echo", {"msg": "hi", "boom": True}) + + +def test_register_rejects_duplicate_names() -> None: + """Registering twice under the same name is a programmer error.""" + local = Registry() + local.register("echo", EchoToolInput)(echo_tool) + with pytest.raises(ValueError, match="already registered"): + local.register("echo", EchoToolInput)(echo_tool) + + +def test_local_registry_isolation() -> None: + """Multiple registries don't share state.""" + a = Registry() + b = Registry() + a.register("echo", EchoToolInput)(echo_tool) + assert a.names() == ["echo"] + assert b.names() == []