Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "toolguard"
version = "0.2.13"
version = "0.2.14"
description = "Policy adherence code generation for guarding AI agent tools"
readme = "README.md"

Expand All @@ -20,7 +20,6 @@ dependencies = [
"langchain-core>=0.3.72",
"litellm<=1.82.6", # https://github.com/BerriAI/litellm/issues/24512
"markdown>=3.7",
"mellea<0.4.0", # mellea 0.4.0 requires python >=3.11
"pydantic>=2.11.0",
"pytest>=8.3.3",
"pytest-asyncio>=1.3.0",
Expand Down
9 changes: 1 addition & 8 deletions src/toolguard/buildtime/gen_py/gen_toolguards.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@
from pathlib import Path
from typing import Callable, List, Optional

import mellea
from loguru import logger

from toolguard.buildtime.gen_py.domain_from_funcs import generate_domain_from_functions
from toolguard.buildtime.gen_py.domain_from_openapi import generate_domain_from_openapi
from toolguard.buildtime.gen_py.mellea_simple import SimpleBackend
from toolguard.buildtime.gen_py.tool_guard_generator import ToolGuardGenerator
from toolguard.buildtime.llm import I_TG_LLM
from toolguard.buildtime.utils import py, pyright, pytest
Expand Down Expand Up @@ -85,13 +83,8 @@ async def generate_toolguards_from_domain(
if len(spec.policy_items) > 0
]

# mellea_workaround = {"model_options": {"reasoning_effort": "medium"}}#FIXME https://github.com/generative-computing/mellea/issues/270
# kw_args = llm.kw_args
# kw_args.update(mellea_workaround)
mellea_backend = SimpleBackend(llm)
m = mellea.MelleaSession(mellea_backend)
tools_generator = [
ToolGuardGenerator(app_name, tool_policy, py_root, domain, m)
ToolGuardGenerator(app_name, tool_policy, py_root, domain, llm)
for tool_policy in not_empty_specs
]
with py.temp_python_path(py_root):
Expand Down
50 changes: 0 additions & 50 deletions src/toolguard/buildtime/gen_py/mellea_simple.py

This file was deleted.

65 changes: 65 additions & 0 deletions src/toolguard/buildtime/gen_py/prompt_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""Lightweight replacement for mellea's @generative decorator.

Builds a prompt from a function's name, signature, docstring, and bound
keyword arguments, then sends it to an I_TG_LLM backend and returns the
raw text response.
"""

import inspect
from typing import Any, Callable, Dict, List

from toolguard.buildtime.llm import I_TG_LLM


def _format_arg(func: Callable, key: str, val: Any) -> str:
"""Format a single argument line like mellea's Arguments component."""
sig = inspect.signature(func)
param = sig.parameters.get(key)
if param and param.annotation is not inspect.Parameter.empty:
param_type = param.annotation
else:
param_type = type(val)

if param_type is str:
display_val = f'"{val!s}"'
else:
display_val = str(val)

return f"- {key}: {display_val} (type: {param_type})"


def build_prompt(func: Callable, **kwargs: Any) -> str:
    """Compose the imitation prompt for *func*.

    The prompt shows the function's name, signature, and docstring, followed
    by an optional "Arguments:" section listing each bound keyword argument
    (one line per argument via ``_format_arg``). The text matches what
    mellea's GenerativeSlot + TemplateFormatter produces.
    """
    signature_text = str(inspect.signature(func))
    doc = inspect.getdoc(func) or "No documentation provided."

    prompt_parts = [
        "Your task is to imitate the output of the following function for the given arguments.",
        "Reply Nothing else but the output of the function.",
        "",
        "Function:",
        f"def {func.__name__}{signature_text}:",
        f'    """{doc}"""',
    ]

    if kwargs:
        prompt_parts += ["", "Arguments:"]
        prompt_parts += [_format_arg(func, name, value) for name, value in kwargs.items()]

    return "\n".join(prompt_parts)


async def run_prompt(
    llm: I_TG_LLM,
    func: Callable,
    **kwargs: Any,
) -> str:
    """Render *func* + *kwargs* into a prompt, send it to *llm*, return the reply.

    The prompt is wrapped as a single user message in the multi-part content
    format expected by the backend; the raw text response is returned as-is.
    """
    message: Dict = {
        "role": "user",
        "content": [{"type": "text", "text": build_prompt(func, **kwargs)}],
    }
    return await llm.generate([message])
48 changes: 42 additions & 6 deletions src/toolguard/buildtime/gen_py/prompts/gen_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@

from typing import List

from mellea import generative

from toolguard.buildtime.gen_py.prompt_runner import run_prompt
from toolguard.buildtime.llm import I_TG_LLM
from toolguard.runtime.data_types import Domain, FileTwin, ToolGuardSpecItem


@generative
async def generate_init_tests(
async def _generate_init_tests_template(
fn_src: FileTwin,
policy_item: ToolGuardSpecItem,
domain: Domain,
Expand Down Expand Up @@ -113,8 +112,7 @@ async def test_violation_book_room_in_the_past():
...


@generative
async def improve_tests(
async def _improve_tests_template(
prev_impl: str,
domain: Domain,
policy_item: ToolGuardSpecItem,
Expand All @@ -139,3 +137,41 @@ async def improve_tests(
- You can add import statements, but dont remove them.
"""
...


async def generate_init_tests(
    llm: I_TG_LLM,
    *,
    fn_src: FileTwin,
    policy_item: ToolGuardSpecItem,
    domain: Domain,
    dependent_tool_names: List[str],
) -> str:
    """Run the ``_generate_init_tests_template`` prompt against *llm*.

    All keyword arguments are forwarded verbatim into the prompt; the raw
    text response from the backend is returned unchanged.
    """
    template_args = dict(
        fn_src=fn_src,
        policy_item=policy_item,
        domain=domain,
        dependent_tool_names=dependent_tool_names,
    )
    return await run_prompt(llm, _generate_init_tests_template, **template_args)


async def improve_tests(
    llm: I_TG_LLM,
    *,
    prev_impl: str,
    domain: Domain,
    policy_item: ToolGuardSpecItem,
    review_comments: List[str],
    dependent_tool_names: List[str],
) -> str:
    """Run the ``_improve_tests_template`` prompt against *llm*.

    All keyword arguments are forwarded verbatim into the prompt; the raw
    text response from the backend is returned unchanged.
    """
    template_args = dict(
        prev_impl=prev_impl,
        domain=domain,
        policy_item=policy_item,
        review_comments=review_comments,
        dependent_tool_names=dependent_tool_names,
    )
    return await run_prompt(llm, _improve_tests_template, **template_args)
29 changes: 25 additions & 4 deletions src/toolguard/buildtime/gen_py/prompts/improve_guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@

from typing import List

from mellea import generative

from toolguard.buildtime.gen_py.prompt_runner import run_prompt
from toolguard.buildtime.llm import I_TG_LLM
from toolguard.runtime.data_types import FileTwin


@generative
async def improve_tool_guard(
async def _improve_tool_guard_template(
policy_txt: str,
dependent_tool_names: List[str],
prev_impl: str,
Expand Down Expand Up @@ -168,3 +167,25 @@ async def airline_cancelled():
```
"""
...


async def improve_tool_guard(
    llm: I_TG_LLM,
    *,
    policy_txt: str,
    dependent_tool_names: List[str],
    prev_impl: str,
    review_comments: List[str],
    api: FileTwin,
    data_types: FileTwin,
) -> str:
    """Run the ``_improve_tool_guard_template`` prompt against *llm*.

    All keyword arguments are forwarded verbatim into the prompt; the raw
    text response from the backend is returned unchanged.
    """
    template_args = dict(
        policy_txt=policy_txt,
        dependent_tool_names=dependent_tool_names,
        prev_impl=prev_impl,
        review_comments=review_comments,
        api=api,
        data_types=data_types,
    )
    return await run_prompt(llm, _improve_tool_guard_template, **template_args)
26 changes: 22 additions & 4 deletions src/toolguard/buildtime/gen_py/prompts/pseudo_code.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
# mypy: ignore-errors


from mellea import generative

from toolguard.buildtime.gen_py.prompt_runner import run_prompt
from toolguard.buildtime.llm import I_TG_LLM
from toolguard.runtime.data_types import FileTwin


@generative
async def tool_policy_pseudo_code(
async def _pseudo_code_template(
policy_txt: str, fn_to_analyze: str, data_types: FileTwin, api: FileTwin
) -> str:
"""
Expand Down Expand Up @@ -169,3 +168,22 @@ def are_relatives(self, person1_id: str, person2_id: str) -> bool: pass
```
"""
...


async def tool_policy_pseudo_code(
    llm: I_TG_LLM,
    *,
    policy_txt: str,
    fn_to_analyze: str,
    data_types: FileTwin,
    api: FileTwin,
    model_options: dict | None = None,
) -> str:
    """Run the ``_pseudo_code_template`` prompt against *llm*.

    The policy text, function signature, and domain files are forwarded into
    the prompt; the raw text response from the backend is returned unchanged.

    NOTE(review): ``model_options`` is accepted for caller compatibility but
    is currently ignored — ``run_prompt`` provides no way to forward it to
    the backend. Confirm whether it should be plumbed through or removed.
    """
    template_args = dict(
        policy_txt=policy_txt,
        fn_to_analyze=fn_to_analyze,
        data_types=data_types,
        api=api,
    )
    return await run_prompt(llm, _pseudo_code_template, **template_args)
9 changes: 4 additions & 5 deletions src/toolguard/buildtime/gen_py/tool_dependencies.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import re
from typing import Any, Dict, Set

from mellea import MelleaSession

from toolguard.buildtime.gen_py import prompts
from toolguard.buildtime.llm import I_TG_LLM
from toolguard.runtime.data_types import Domain

MAX_TRIALS = 3
Expand All @@ -13,12 +12,12 @@ async def tool_dependencies(
policy_txt: str,
tool_signature: str,
domain: Domain,
m: MelleaSession,
llm: I_TG_LLM,
trial=0,
) -> Set[str]:
model_options: Dict[str, Any] = {} # {ModelOption.TEMPERATURE: 0.8}
pseudo_code = await prompts.tool_policy_pseudo_code(
m,
llm,
policy_txt=policy_txt,
fn_to_analyze=tool_signature,
data_types=domain.app_types,
Expand All @@ -30,7 +29,7 @@ async def tool_dependencies(
return fn_names
if trial <= MAX_TRIALS:
# as tool_policy_pseudo_code has some temperature, we retry hoping next time the pseudo code will be correct
return await tool_dependencies(policy_txt, tool_signature, domain, m, trial + 1)
return await tool_dependencies(policy_txt, tool_signature, domain, llm, trial + 1)
raise Exception("Failed to analyze api dependencies")


Expand Down
Loading