Skip to content

Commit 793c3ac

Browse files
authored
fix(extra): return ResponseFormatTypedDict from response_format_from_pydantic_model (#438)
* fix(extra): return ResponseFormatTypedDict from response_format_from_pydantic_model response_format_from_pydantic_model now returns a ResponseFormatTypedDict (plain dict) instead of a ResponseFormat model instance. This fixes the type mismatch when passing the result to the Azure SDK, which expects its own ResponseFormat class, and avoids the schema alias data loss issue. Fixes AIR-143 / GitHub #367 * fix(examples): use ChatCompletionRequestTools1 type for tools list Fix mypy list invariance error introduced by SDK 2.1.0 regen where list[Tool] no longer satisfies the broadened tools union type.
1 parent 86e592f commit 793c3ac

5 files changed

Lines changed: 36 additions & 28 deletions

File tree

examples/mistral/chat/function_calling.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from mistralai.client.models import (
88
AssistantMessage,
99
ChatCompletionRequestMessage,
10+
ChatCompletionRequestTools1,
1011
Function,
1112
Tool,
1213
ToolMessage,
@@ -48,7 +49,7 @@ def retrieve_payment_date(data: dict[str, list[Any]], transaction_id: str) -> st
4849
"retrieve_payment_date": functools.partial(retrieve_payment_date, data=data),
4950
}
5051

51-
tools: list[Tool] = [
52+
tools: list[ChatCompletionRequestTools1] = [
5253
Tool(
5354
function=Function(
5455
name="retrieve_payment_status",

examples/mistral/jobs/async_jobs_ocr_batch_annotation.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,7 @@ def create_ocr_batch_request(custom_id: str, document_url: str) -> dict:
2929
"custom_id": custom_id,
3030
"body": {
3131
"document": {"type": "document_url", "document_url": document_url},
32-
"document_annotation_format": response_format.model_dump(
33-
by_alias=True, exclude_none=True
34-
),
32+
"document_annotation_format": response_format,
3533
"pages": [0, 1, 2, 3, 4, 5, 6, 7],
3634
"include_image_base64": False,
3735
},

src/mistralai/extra/run/context.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,9 @@ async def prepare_model_request(
243243
def response_format(self) -> ResponseFormat:
244244
if not self.output_format:
245245
raise RunException("No response format exist for the current RunContext.")
246-
return response_format_from_pydantic_model(self.output_format)
246+
return ResponseFormat.model_validate(
247+
response_format_from_pydantic_model(self.output_format)
248+
)
247249

248250

249251
async def _validate_run(

src/mistralai/extra/tests/test_utils.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,6 @@
55
)
66
from pydantic import BaseModel, Field, ValidationError
77

8-
from mistralai.client.models import ResponseFormat, JSONSchema
9-
from mistralai.client.types.basemodel import Unset
10-
118
import unittest
129

1310

@@ -55,15 +52,14 @@ class MathDemonstration(BaseModel):
5552
mathdemo_strict_schema["$defs"]["Explanation"]["additionalProperties"] = False # type: ignore
5653
mathdemo_strict_schema["additionalProperties"] = False
5754

58-
mathdemo_response_format = ResponseFormat(
59-
type="json_schema",
60-
json_schema=JSONSchema(
61-
name="MathDemonstration",
62-
schema_definition=mathdemo_strict_schema,
63-
description=Unset(),
64-
strict=True,
65-
),
66-
)
55+
mathdemo_response_format = {
56+
"type": "json_schema",
57+
"json_schema": {
58+
"name": "MathDemonstration",
59+
"schema": mathdemo_strict_schema,
60+
"strict": True,
61+
},
62+
}
6763

6864

6965
class TestResponseFormat(unittest.TestCase):
@@ -220,10 +216,10 @@ class ModelWithConstraints(BaseModel):
220216
# Should not raise ValueError
221217
result = response_format_from_pydantic_model(ModelWithConstraints)
222218

223-
# Verify it returns a valid ResponseFormat
224-
self.assertIsInstance(result, ResponseFormat)
225-
self.assertEqual(result.type, "json_schema")
226-
self.assertIsNotNone(result.json_schema)
219+
# Verify it returns a valid response format dict
220+
self.assertIsInstance(result, dict)
221+
self.assertEqual(result.get("type"), "json_schema")
222+
self.assertIsNotNone(result.get("json_schema"))
227223

228224
def test_rec_strict_json_schema_with_invalid_type(self):
229225
"""Test that rec_strict_json_schema raises ValueError for truly invalid types."""

src/mistralai/extra/utils/response_format.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,32 @@
1-
from typing import Any, TypeVar
1+
from typing import Any, TypeVar, cast
22

33
from pydantic import BaseModel
4-
from mistralai.client.models import JSONSchema, ResponseFormat
4+
from mistralai.client.models import ResponseFormatTypedDict
55
from ._pydantic_helper import rec_strict_json_schema
66

77
CustomPydanticModel = TypeVar("CustomPydanticModel", bound=BaseModel)
88

99

1010
def response_format_from_pydantic_model(
1111
model: type[CustomPydanticModel],
12-
) -> ResponseFormat:
13-
"""Generate a strict JSON schema from a pydantic model."""
12+
) -> ResponseFormatTypedDict:
13+
"""Generate a strict JSON schema response format from a pydantic model.
14+
15+
Returns a TypedDict compatible with both the main SDK's and Azure SDK's
16+
ResponseFormat / ResponseFormatTypedDict.
17+
"""
1418
model_schema = rec_strict_json_schema(model.model_json_schema())
15-
json_schema = JSONSchema.model_validate(
16-
{"name": model.__name__, "schema": model_schema, "strict": True}
19+
return cast(
20+
ResponseFormatTypedDict,
21+
{
22+
"type": "json_schema",
23+
"json_schema": {
24+
"name": model.__name__,
25+
"schema": model_schema,
26+
"strict": True,
27+
},
28+
},
1729
)
18-
return ResponseFormat(type="json_schema", json_schema=json_schema)
1930

2031

2132
def pydantic_model_from_json(

0 commit comments

Comments
 (0)