Skip to content

Commit 37ef0e8

Browse files
fix: update code for renamed model classes
- Filter UnknownAgentTool in agent update (context.py) - Update examples for renamed classes: - BatchRequest, UserMessage: import from models module - ClassifierTrainingParametersIn -> ClassifierTrainingParameters - ClassifierJobOut -> ClassifierFineTuningJob - Add type narrowing for ClassifierFineTuningJobDetails
1 parent c74b056 commit 37ef0e8

File tree

7 files changed

+22
-11
lines changed

7 files changed

+22
-11
lines changed

examples/mistral/audio/chat_streaming.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
import os
44

5-
from mistralai.client import Mistral, File
5+
from mistralai.client import Mistral
6+
from mistralai.client.models import File
67
from mistralai.client.models import UserMessage
78

89

examples/mistral/audio/transcription_async.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
import os
44
import asyncio
5-
from mistralai.client import Mistral, File
5+
from mistralai.client import Mistral
6+
from mistralai.client.models import File
67

78

89
async def main():

examples/mistral/audio/transcription_diarize_async.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import os
44
import asyncio
55
import pathlib
6-
from mistralai.client import Mistral, File
6+
from mistralai.client import Mistral
7+
from mistralai.client.models import File
78

89
fixture_dir = pathlib.Path(__file__).parents[2] / "fixtures"
910

examples/mistral/audio/transcription_stream_async.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
import asyncio
33
import os
44

5-
from mistralai.client import Mistral, File
5+
from mistralai.client import Mistral
6+
from mistralai.client.models import File
67

78

89
async def main():

examples/mistral/classifier/async_classifier.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
from pprint import pprint
44
import asyncio
5-
from mistralai.client import Mistral, TrainingFile, ClassifierTrainingParametersIn
6-
from mistralai.client.models import ClassifierJobOut
5+
from mistralai.client import Mistral
6+
from mistralai.client.models import ClassifierFineTuningJob, ClassifierFineTuningJobDetails, ClassifierTrainingParameters, TrainingFile
77

88
import os
99

@@ -36,12 +36,12 @@ async def train_classifier(client: Mistral, training_file_ids: list[str]) -> str
3636
TrainingFile(file_id=training_file_id)
3737
for training_file_id in training_file_ids
3838
],
39-
hyperparameters=ClassifierTrainingParametersIn(
39+
hyperparameters=ClassifierTrainingParameters(
4040
learning_rate=0.0001,
4141
),
4242
auto_start=True,
4343
)
44-
if not isinstance(job, ClassifierJobOut):
44+
if not isinstance(job, ClassifierFineTuningJob):
4545
print("Unexpected job type returned")
4646
return None
4747

@@ -51,6 +51,8 @@ async def train_classifier(client: Mistral, training_file_ids: list[str]) -> str
5151
while True:
5252
await asyncio.sleep(10)
5353
detailed_job = await client.fine_tuning.jobs.get_async(job_id=job.id)
54+
if not isinstance(detailed_job, ClassifierFineTuningJobDetails):
55+
raise Exception(f"Unexpected job type: {type(detailed_job)}")
5456
if detailed_job.status not in [
5557
"QUEUED",
5658
"STARTED",

examples/mistral/jobs/async_batch_job_chat_completion_inline.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from mistralai.client import Mistral, BatchRequest, UserMessage
1+
from mistralai.client import Mistral
2+
from mistralai.client.models import BatchRequest, UserMessage
23
import os
34
import asyncio
45

src/mistralai/extra/run/context.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
create_tool_call,
2323
)
2424
from mistralai.client.models import (
25-
AgentTool,
2625
CompletionArgs,
2726
CompletionArgsTypedDict,
2827
ConversationInputs,
@@ -35,6 +34,8 @@
3534
InputEntries,
3635
MessageInputEntry,
3736
ResponseFormat,
37+
UnknownAgentTool,
38+
UpdateAgentRequestTool,
3839
)
3940
from mistralai.client.types.basemodel import BaseModel, OptionalNullable, UNSET
4041

@@ -187,8 +188,11 @@ async def prepare_agent_request(self, beta_client: "Beta") -> AgentRequestKwargs
187188
)
188189
agent = await beta_client.agents.get_async(agent_id=self.agent_id)
189190
agent_tools = agent.tools or []
190-
updated_tools: list[AgentTool] = []
191+
updated_tools: list[UpdateAgentRequestTool] = []
191192
for tool in agent_tools:
193+
if isinstance(tool, UnknownAgentTool):
194+
# Skip unknown tools - can't include them in update request
195+
continue
192196
if not isinstance(tool, FunctionTool):
193197
updated_tools.append(tool)
194198
elif tool.function.name in self._callable_tools:

0 commit comments

Comments
 (0)