Skip to content

Commit 20207ae

Browse files
fix: fix example types and enable mypy in CI
- Rename azure examples from .py.py to .py - Fix message types in azure and mistral examples - Add type annotations where needed for mypy - Enable mypy for examples in lint_custom_code.sh
1 parent fb15501 commit 20207ae

File tree

8 files changed

+75
-72
lines changed

8 files changed

+75
-72
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import os
2+
3+
from mistralai_azure import MistralAzure
4+
from mistralai_azure.models import ChatCompletionRequestMessages, UserMessage
5+
6+
client = MistralAzure(
7+
azure_api_key=os.environ["AZURE_API_KEY"],
8+
azure_endpoint=os.environ["AZURE_ENDPOINT"],
9+
)
10+
11+
messages: list[ChatCompletionRequestMessages] = [
12+
UserMessage(content="What is the capital of France?"),
13+
]
14+
res = client.chat.complete(messages=messages)
15+
print(res.choices[0].message.content)

examples/azure/az_chat_no_streaming.py.py

Lines changed: 0 additions & 16 deletions
This file was deleted.
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import os
2+
3+
from mistralai_azure import MistralAzure
4+
from mistralai_azure.models import ChatCompletionRequestMessages, UserMessage
5+
6+
client = MistralAzure(
7+
azure_api_key=os.environ["AZURE_API_KEY"],
8+
azure_endpoint=os.environ["AZURE_ENDPOINT"],
9+
)
10+
11+
messages: list[ChatCompletionRequestMessages] = [
12+
UserMessage(content="What is the capital of France?"),
13+
]
14+
res = client.chat.complete(messages=messages)
15+
print(res.choices[0].message.content)

examples/azure/chat_no_streaming.py.py

Lines changed: 0 additions & 16 deletions
This file was deleted.

examples/mistral/chat/chatbot_with_streaming.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import os
88
import readline
99
import sys
10+
from typing import Any
1011

1112
from mistralai.client import Mistral
1213
from mistralai.client.models import AssistantMessage, SystemMessage, UserMessage
@@ -21,7 +22,7 @@
2122
DEFAULT_TEMPERATURE = 0.7
2223
LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
2324
# A dictionary of all commands and their arguments, used for tab completion.
24-
COMMAND_LIST = {
25+
COMMAND_LIST: dict[str, Any] = {
2526
"/new": {},
2627
"/help": {},
2728
"/model": {model: {} for model in MODEL_LIST}, # Nested completions for models
Lines changed: 33 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
11
import functools
22
import json
33
import os
4-
from typing import Dict, List
4+
from typing import Any
55

66
from mistralai.client import Mistral
7-
from mistralai.client.models.assistantmessage import AssistantMessage
8-
from mistralai.client.models.function import Function
9-
from mistralai.client.models.toolmessage import ToolMessage
10-
from mistralai.client.models.usermessage import UserMessage
7+
from mistralai.client.models import (
8+
AssistantMessage,
9+
ChatCompletionRequestMessage,
10+
Function,
11+
Tool,
12+
ToolMessage,
13+
UserMessage,
14+
)
1115

1216
# Assuming we have the following data
13-
data = {
17+
data: dict[str, list[Any]] = {
1418
"transaction_id": ["T1001", "T1002", "T1003", "T1004", "T1005"],
1519
"customer_id": ["C001", "C002", "C003", "C002", "C001"],
1620
"payment_amount": [125.50, 89.99, 120.00, 54.30, 210.20],
@@ -25,31 +29,28 @@
2529
}
2630

2731

28-
def retrieve_payment_status(data: Dict[str, List], transaction_id: str) -> str:
32+
def retrieve_payment_status(data: dict[str, list[Any]], transaction_id: str) -> str:
2933
for i, r in enumerate(data["transaction_id"]):
3034
if r == transaction_id:
3135
return json.dumps({"status": data["payment_status"][i]})
32-
else:
33-
return json.dumps({"status": "Error - transaction id not found"})
36+
return json.dumps({"status": "Error - transaction id not found"})
3437

3538

36-
def retrieve_payment_date(data: Dict[str, List], transaction_id: str) -> str:
39+
def retrieve_payment_date(data: dict[str, list[Any]], transaction_id: str) -> str:
3740
for i, r in enumerate(data["transaction_id"]):
3841
if r == transaction_id:
3942
return json.dumps({"date": data["payment_date"][i]})
40-
else:
41-
return json.dumps({"status": "Error - transaction id not found"})
43+
return json.dumps({"status": "Error - transaction id not found"})
4244

4345

4446
names_to_functions = {
4547
"retrieve_payment_status": functools.partial(retrieve_payment_status, data=data),
4648
"retrieve_payment_date": functools.partial(retrieve_payment_date, data=data),
4749
}
4850

49-
tools = [
50-
{
51-
"type": "function",
52-
"function": Function(
51+
tools: list[Tool] = [
52+
Tool(
53+
function=Function(
5354
name="retrieve_payment_status",
5455
description="Get payment status of a transaction id",
5556
parameters={
@@ -63,10 +64,9 @@ def retrieve_payment_date(data: Dict[str, List], transaction_id: str) -> str:
6364
},
6465
},
6566
),
66-
},
67-
{
68-
"type": "function",
69-
"function": Function(
67+
),
68+
Tool(
69+
function=Function(
7070
name="retrieve_payment_date",
7171
description="Get payment date of a transaction id",
7272
parameters={
@@ -80,36 +80,35 @@ def retrieve_payment_date(data: Dict[str, List], transaction_id: str) -> str:
8080
},
8181
},
8282
),
83-
},
83+
),
8484
]
8585

8686
api_key = os.environ["MISTRAL_API_KEY"]
8787
model = "mistral-small-latest"
8888

8989
client = Mistral(api_key=api_key)
9090

91-
messages = [UserMessage(content="What's the status of my transaction?")]
91+
messages: list[ChatCompletionRequestMessage] = [
92+
UserMessage(content="What's the status of my transaction?")
93+
]
9294

93-
response = client.chat.complete(
94-
model=model, messages=messages, tools=tools, temperature=0
95-
)
95+
response = client.chat.complete(model=model, messages=messages, tools=tools, temperature=0)
9696

9797
print(response.choices[0].message.content)
9898

9999
messages.append(AssistantMessage(content=response.choices[0].message.content))
100100
messages.append(UserMessage(content="My transaction ID is T1001."))
101101

102-
response = client.chat.complete(
103-
model=model, messages=messages, tools=tools, temperature=0
104-
)
102+
response = client.chat.complete(model=model, messages=messages, tools=tools, temperature=0)
105103

106-
tool_call = response.choices[0].message.tool_calls[0]
104+
tool_calls = response.choices[0].message.tool_calls
105+
if not tool_calls:
106+
raise RuntimeError("Expected tool calls")
107+
tool_call = tool_calls[0]
107108
function_name = tool_call.function.name
108-
function_params = json.loads(tool_call.function.arguments)
109+
function_params = json.loads(str(tool_call.function.arguments))
109110

110-
print(
111-
f"calling function_name: {function_name}, with function_params: {function_params}"
112-
)
111+
print(f"calling function_name: {function_name}, with function_params: {function_params}")
113112

114113
function_result = names_to_functions[function_name](**function_params)
115114

@@ -128,8 +127,6 @@ def retrieve_payment_date(data: Dict[str, List], transaction_id: str) -> str:
128127
)
129128
print(messages)
130129

131-
response = client.chat.complete(
132-
model=model, messages=messages, tools=tools, temperature=0
133-
)
130+
response = client.chat.complete(model=model, messages=messages, tools=tools, temperature=0)
134131

135132
print(f"{response.choices[0].message.content}")

examples/mistral/classifier/async_classifier.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from pprint import pprint
44
import asyncio
55
from mistralai.client import Mistral, TrainingFile, ClassifierTrainingParametersIn
6+
from mistralai.client.models import ClassifierJobOut
67

78
import os
89

@@ -26,7 +27,7 @@ async def upload_files(client: Mistral, file_names: list[str]) -> list[str]:
2627
return file_ids
2728

2829

29-
async def train_classifier(client: Mistral,training_file_ids: list[str]) -> str:
30+
async def train_classifier(client: Mistral, training_file_ids: list[str]) -> str | None:
3031
print("Creating job...")
3132
job = await client.fine_tuning.jobs.create_async(
3233
model="ministral-3b-latest",
@@ -40,6 +41,9 @@ async def train_classifier(client: Mistral,training_file_ids: list[str]) -> str:
4041
),
4142
auto_start=True,
4243
)
44+
if not isinstance(job, ClassifierJobOut):
45+
print("Unexpected job type returned")
46+
return None
4347

4448
print(f"Job created ({job.id})")
4549

@@ -62,6 +66,9 @@ async def train_classifier(client: Mistral,training_file_ids: list[str]) -> str:
6266
print("Training failed")
6367
raise Exception(f"Job failed {detailed_job.status}")
6468

69+
if not detailed_job.fine_tuned_model:
70+
print("No fine-tuned model returned")
71+
return None
6572
print(f"Training succeed: {detailed_job.fine_tuned_model}")
6673

6774
return detailed_job.fine_tuned_model

scripts/lint_custom_code.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ else
1111
fi
1212

1313
echo "Running mypy..."
14-
# TODO: Uncomment once the examples are fixed
15-
# uv run mypy examples/ || ERRORS=1
14+
echo "-> running on examples"
15+
uv run mypy examples/ --ignore-missing-imports || ERRORS=1
1616
echo "-> running on extra"
1717
uv run mypy src/mistralai/extra/ || ERRORS=1
1818
echo "-> running on hooks"

0 commit comments

Comments
 (0)