Skip to content

Commit c9c1b31

Browse files
authored
fix: propagate parameter descriptions in create_tool_call (#472)
_get_function_parameters mutates field_info.description after FieldInfo construction. Pydantic v2 ignores this because _attributes_set is not updated. All docstring-derived parameter descriptions are silently dropped from the generated tool schema. For fresh FieldInfos, pass description to Field() at construction. For existing FieldInfos, override via Annotated stacking (public API).
1 parent eee98b9 commit c9c1b31

2 files changed

Lines changed: 307 additions & 8 deletions

File tree

src/mistralai/extra/run/tools.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import json
44
import logging
55
from dataclasses import dataclass
6-
from typing import Any, Callable, ForwardRef, Sequence, cast, get_type_hints
6+
from typing import Annotated, Any, Callable, ForwardRef, Sequence, cast, get_type_hints
77

88
import opentelemetry.semconv._incubating.attributes.gen_ai_attributes as gen_ai_attributes
99
from griffe import (
@@ -17,6 +17,7 @@
1717
from opentelemetry.trace import Status, StatusCode
1818
from pydantic import Field, create_model
1919
from pydantic.fields import FieldInfo
20+
from pydantic_core import PydanticUndefined as _PYDANTIC_UNDEFINED
2021

2122
from mistralai.client.models import (
2223
Function,
@@ -95,7 +96,7 @@ def _get_function_parameters(
9596
param_annotations[param.name] = type_hints.get(param.name)
9697

9798
# resolve all params into Field and create the parameters schema
98-
fields: dict[str, tuple[type, FieldInfo]] = {}
99+
fields: dict[str, Any] = {}
99100
for p in params_from_sig:
100101
default = p.default if p.default is not inspect.Parameter.empty else ...
101102
annotation = (
@@ -127,15 +128,25 @@ def _get_function_parameters(
127128
if isinstance(annotation, ForwardRef):
128129
annotation = param_annotations[p.name]
129130

130-
# no Field
131+
description = param_descriptions[p.name] or None
132+
131133
if field_info is None:
132134
if default is ...:
133-
field_info = Field()
135+
field_info = Field(description=description)
134136
else:
135-
field_info = Field(default=default)
136-
137-
field_info.description = param_descriptions[p.name]
138-
fields[p.name] = (cast(type, annotation), field_info)
137+
field_info = Field(default=default, description=description)
138+
fields[p.name] = (cast(type, annotation), field_info)
139+
elif description:
140+
typed = Annotated[ # type: ignore[valid-type]
141+
cast(type, annotation), field_info, Field(description=description)
142+
]
143+
raw_default = field_info.default
144+
if raw_default is not _PYDANTIC_UNDEFINED:
145+
fields[p.name] = (typed, raw_default)
146+
else:
147+
fields[p.name] = (typed, ...)
148+
else:
149+
fields[p.name] = (cast(type, annotation), field_info)
139150

140151
schema = create_model("_", **fields).model_json_schema() # type: ignore[call-overload]
141152
schema.pop("title", None)
Lines changed: 288 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,288 @@
1+
"""Unit tests for create_tool_call parameter description propagation.
2+
3+
Validates that parameter descriptions from docstrings and Annotated[T, Field(...)]
4+
annotations correctly appear in the JSON schema produced by create_tool_call().
5+
6+
This is a regression test for a Pydantic v2 bug where post-construction mutation
7+
of FieldInfo.description is silently ignored by model_json_schema().
8+
9+
Fixtures are defined inline so each test is self-contained.
10+
"""
11+
12+
import unittest
13+
from typing import Annotated, Optional
14+
15+
from pydantic import Field
16+
17+
from ..run.tools import create_tool_call
18+
19+
20+
def _props(func):
21+
"""Shorthand: create a tool call and return its parameter properties."""
22+
return create_tool_call(func).function.parameters["properties"]
23+
24+
25+
class TestCreateToolCallDescriptions(unittest.TestCase):
26+
"""Descriptions from docstrings must appear in the generated JSON schema."""
27+
28+
# -- Docstring descriptions (Path 3: no existing FieldInfo) ----------------
29+
30+
def test_required_param_gets_docstring_description(self):
31+
def search(query: str) -> str:
32+
"""Search the web.
33+
34+
Args:
35+
query: The search query to execute.
36+
"""
37+
return ""
38+
39+
props = _props(search)
40+
self.assertEqual(props["query"]["description"], "The search query to execute.")
41+
42+
def test_optional_param_with_default_gets_docstring_description(self):
43+
def search(query: str, limit: int = 10) -> str:
44+
"""Search the web.
45+
46+
Args:
47+
query: The search query.
48+
limit: Maximum number of results.
49+
"""
50+
return ""
51+
52+
props = _props(search)
53+
self.assertEqual(props["limit"]["description"], "Maximum number of results.")
54+
self.assertEqual(props["limit"]["default"], 10)
55+
56+
def test_multiple_params_all_get_descriptions(self):
57+
def fetch(url: str, timeout: int = 30, verbose: bool = False) -> str:
58+
"""Fetch a URL.
59+
60+
Args:
61+
url: The URL to fetch.
62+
timeout: Request timeout in seconds.
63+
verbose: Enable verbose logging.
64+
"""
65+
return ""
66+
67+
props = _props(fetch)
68+
self.assertEqual(props["url"]["description"], "The URL to fetch.")
69+
self.assertEqual(props["timeout"]["description"], "Request timeout in seconds.")
70+
self.assertEqual(props["verbose"]["description"], "Enable verbose logging.")
71+
72+
# -- Annotated + docstring (Path 2: existing FieldInfo) --------------------
73+
74+
def test_annotated_field_description_overridden_by_docstring(self):
75+
def search(query: Annotated[str, Field(description="original")]) -> str:
76+
"""Search.
77+
78+
Args:
79+
query: From docstring.
80+
"""
81+
return ""
82+
83+
props = _props(search)
84+
self.assertEqual(props["query"]["description"], "From docstring.")
85+
86+
def test_annotated_field_description_preserved_when_no_docstring_entry(self):
87+
"""When the docstring has no Args entry for a param, the Field(description=...)
88+
from Annotated must be preserved, not clobbered with empty string."""
89+
90+
def search(query: Annotated[str, Field(description="keep me")]) -> str:
91+
"""Search the web."""
92+
return ""
93+
94+
props = _props(search)
95+
self.assertEqual(props["query"]["description"], "keep me")
96+
97+
def test_annotated_field_constraints_preserved_with_docstring(self):
98+
def count(n: Annotated[int, Field(ge=0, le=100)]) -> str:
99+
"""Count items.
100+
101+
Args:
102+
n: Number of items.
103+
"""
104+
return ""
105+
106+
props = _props(count)
107+
self.assertEqual(props["n"]["description"], "Number of items.")
108+
self.assertEqual(props["n"]["minimum"], 0)
109+
self.assertEqual(props["n"]["maximum"], 100)
110+
111+
def test_annotated_field_constraints_preserved_without_docstring_entry(self):
112+
def count(
113+
n: Annotated[int, Field(ge=0, le=100, description="original")],
114+
) -> str:
115+
"""Count items."""
116+
return ""
117+
118+
props = _props(count)
119+
self.assertEqual(props["n"]["description"], "original")
120+
self.assertEqual(props["n"]["minimum"], 0)
121+
self.assertEqual(props["n"]["maximum"], 100)
122+
123+
# -- Field as default value (Path 1: isinstance(default, FieldInfo)) -------
124+
125+
def test_field_default_value_with_docstring(self):
126+
def search(query: str, limit: int = Field(default=10, ge=1)) -> str:
127+
"""Search.
128+
129+
Args:
130+
query: The query.
131+
limit: Max results.
132+
"""
133+
return ""
134+
135+
props = _props(search)
136+
self.assertEqual(props["limit"]["description"], "Max results.")
137+
self.assertEqual(props["limit"]["default"], 10)
138+
self.assertEqual(props["limit"]["minimum"], 1)
139+
140+
def test_field_default_value_without_docstring_entry(self):
141+
"""Field(default=..., ge=...) without a docstring entry should preserve
142+
constraints and not inject a spurious empty description."""
143+
144+
def search(query: str, limit: int = Field(default=10, ge=1)) -> str:
145+
"""Search.
146+
147+
Args:
148+
query: The query.
149+
"""
150+
return ""
151+
152+
props = _props(search)
153+
self.assertEqual(props["limit"]["default"], 10)
154+
self.assertEqual(props["limit"]["minimum"], 1)
155+
156+
# -- Edge cases ------------------------------------------------------------
157+
158+
def test_undocumented_param_has_no_description_key(self):
159+
"""Params without any docstring entry or Field description should not
160+
have a description key in the schema (not even an empty string)."""
161+
162+
def search(query: str) -> str:
163+
"""Search the web."""
164+
return ""
165+
166+
props = _props(search)
167+
self.assertIn("query", props)
168+
self.assertNotIn("description", props["query"])
169+
170+
def test_required_params_in_required_list(self):
171+
def search(query: str, limit: int = 10) -> str:
172+
"""Search.
173+
174+
Args:
175+
query: The query.
176+
limit: Max results.
177+
"""
178+
return ""
179+
180+
tool = create_tool_call(search)
181+
required = tool.function.parameters.get("required", [])
182+
self.assertIn("query", required)
183+
self.assertNotIn("limit", required)
184+
185+
def test_optional_type_annotation(self):
186+
def search(query: str, tag: Optional[str] = None) -> str:
187+
"""Search.
188+
189+
Args:
190+
query: The query.
191+
tag: Optional tag filter.
192+
"""
193+
return ""
194+
195+
props = _props(search)
196+
self.assertEqual(props["tag"]["description"], "Optional tag filter.")
197+
198+
def test_list_type_annotation(self):
199+
def search(queries: list[str]) -> str:
200+
"""Batch search.
201+
202+
Args:
203+
queries: List of search queries.
204+
"""
205+
return ""
206+
207+
props = _props(search)
208+
self.assertEqual(props["queries"]["description"], "List of search queries.")
209+
210+
def test_function_level_description(self):
211+
def search(query: str) -> str:
212+
"""Search the web for information.
213+
214+
Args:
215+
query: The search query.
216+
"""
217+
return ""
218+
219+
tool = create_tool_call(search)
220+
self.assertEqual(tool.function.description, "Search the web for information.")
221+
222+
def test_no_docstring_at_all(self):
223+
def search(query: str) -> str:
224+
return ""
225+
226+
tool = create_tool_call(search)
227+
self.assertIsNotNone(tool.function.parameters)
228+
self.assertIn("query", tool.function.parameters["properties"])
229+
230+
def test_shared_field_info_no_cross_contamination(self):
231+
"""Two functions sharing the same FieldInfo instance via Annotated must
232+
not cross-contaminate descriptions."""
233+
234+
shared_field = Field(ge=0)
235+
236+
def func_a(n: Annotated[int, shared_field]) -> str:
237+
"""A.
238+
239+
Args:
240+
n: Description A.
241+
"""
242+
return ""
243+
244+
def func_b(n: Annotated[int, shared_field]) -> str:
245+
"""B.
246+
247+
Args:
248+
n: Description B.
249+
"""
250+
return ""
251+
252+
props_a = _props(func_a)
253+
props_b = _props(func_b)
254+
self.assertEqual(props_a["n"]["description"], "Description A.")
255+
self.assertEqual(props_b["n"]["description"], "Description B.")
256+
# Calling func_a again after func_b must still produce "Description A."
257+
props_a_again = _props(func_a)
258+
self.assertEqual(props_a_again["n"]["description"], "Description A.")
259+
# Original shared instance must be unmodified
260+
self.assertIsNone(shared_field.description)
261+
262+
263+
class TestCreateToolCallRegressionPydanticV2(unittest.TestCase):
264+
"""Regression: post-construction FieldInfo.description mutation is broken in Pydantic v2."""
265+
266+
def test_description_appears_in_schema_not_silently_dropped(self):
267+
"""The original bug: docstring descriptions were silently dropped from the
268+
JSON schema because FieldInfo.description was mutated after construction,
269+
which Pydantic v2 ignores in model_json_schema()."""
270+
271+
def get_weather(city: str, units: str = "celsius") -> str:
272+
"""Get weather for a city.
273+
274+
Args:
275+
city: The city name.
276+
units: Temperature units.
277+
"""
278+
return ""
279+
280+
tool = create_tool_call(get_weather)
281+
props = tool.function.parameters["properties"]
282+
self.assertEqual(props["city"]["description"], "The city name.")
283+
self.assertEqual(props["units"]["description"], "Temperature units.")
284+
self.assertEqual(props["units"]["default"], "celsius")
285+
286+
287+
if __name__ == "__main__":
288+
unittest.main()

0 commit comments

Comments
 (0)