11import functools
22import json
33import os
4- from typing import Dict , List
4+ from typing import Any
55
66from mistralai .client import Mistral
7- from mistralai .client .models .assistantmessage import AssistantMessage
8- from mistralai .client .models .function import Function
9- from mistralai .client .models .toolmessage import ToolMessage
10- from mistralai .client .models .usermessage import UserMessage
7+ from mistralai .client .models import (
8+ AssistantMessage ,
9+ ChatCompletionRequestMessage ,
10+ Function ,
11+ Tool ,
12+ ToolMessage ,
13+ UserMessage ,
14+ )
1115
1216# Assuming we have the following data
13- data = {
17+ data : dict [ str , list [ Any ]] = {
1418 "transaction_id" : ["T1001" , "T1002" , "T1003" , "T1004" , "T1005" ],
1519 "customer_id" : ["C001" , "C002" , "C003" , "C002" , "C001" ],
1620 "payment_amount" : [125.50 , 89.99 , 120.00 , 54.30 , 210.20 ],
2529}
2630
2731
28- def retrieve_payment_status (data : Dict [str , List ], transaction_id : str ) -> str :
32+ def retrieve_payment_status (data : dict [str , list [ Any ] ], transaction_id : str ) -> str :
2933 for i , r in enumerate (data ["transaction_id" ]):
3034 if r == transaction_id :
3135 return json .dumps ({"status" : data ["payment_status" ][i ]})
32- else :
33- return json .dumps ({"status" : "Error - transaction id not found" })
36+ return json .dumps ({"status" : "Error - transaction id not found" })
3437
3538
36- def retrieve_payment_date (data : Dict [str , List ], transaction_id : str ) -> str :
39+ def retrieve_payment_date (data : dict [str , list [ Any ] ], transaction_id : str ) -> str :
3740 for i , r in enumerate (data ["transaction_id" ]):
3841 if r == transaction_id :
3942 return json .dumps ({"date" : data ["payment_date" ][i ]})
40- else :
41- return json .dumps ({"status" : "Error - transaction id not found" })
43+ return json .dumps ({"status" : "Error - transaction id not found" })
4244
4345
4446names_to_functions = {
4547 "retrieve_payment_status" : functools .partial (retrieve_payment_status , data = data ),
4648 "retrieve_payment_date" : functools .partial (retrieve_payment_date , data = data ),
4749}
4850
49- tools = [
50- {
51- "type" : "function" ,
52- "function" : Function (
51+ tools : list [Tool ] = [
52+ Tool (
53+ function = Function (
5354 name = "retrieve_payment_status" ,
5455 description = "Get payment status of a transaction id" ,
5556 parameters = {
@@ -63,10 +64,9 @@ def retrieve_payment_date(data: Dict[str, List], transaction_id: str) -> str:
6364 },
6465 },
6566 ),
66- },
67- {
68- "type" : "function" ,
69- "function" : Function (
67+ ),
68+ Tool (
69+ function = Function (
7070 name = "retrieve_payment_date" ,
7171 description = "Get payment date of a transaction id" ,
7272 parameters = {
@@ -80,36 +80,35 @@ def retrieve_payment_date(data: Dict[str, List], transaction_id: str) -> str:
8080 },
8181 },
8282 ),
83- } ,
83+ ) ,
8484]
8585
8686api_key = os .environ ["MISTRAL_API_KEY" ]
8787model = "mistral-small-latest"
8888
8989client = Mistral (api_key = api_key )
9090
91- messages = [UserMessage (content = "What's the status of my transaction?" )]
91+ messages : list [ChatCompletionRequestMessage ] = [
92+ UserMessage (content = "What's the status of my transaction?" )
93+ ]
9294
93- response = client .chat .complete (
94- model = model , messages = messages , tools = tools , temperature = 0
95- )
95+ response = client .chat .complete (model = model , messages = messages , tools = tools , temperature = 0 )
9696
9797print (response .choices [0 ].message .content )
9898
9999messages .append (AssistantMessage (content = response .choices [0 ].message .content ))
100100messages .append (UserMessage (content = "My transaction ID is T1001." ))
101101
102- response = client .chat .complete (
103- model = model , messages = messages , tools = tools , temperature = 0
104- )
102+ response = client .chat .complete (model = model , messages = messages , tools = tools , temperature = 0 )
105103
106- tool_call = response .choices [0 ].message .tool_calls [0 ]
104+ tool_calls = response .choices [0 ].message .tool_calls
105+ if not tool_calls :
106+ raise RuntimeError ("Expected tool calls" )
107+ tool_call = tool_calls [0 ]
107108function_name = tool_call .function .name
108- function_params = json .loads (tool_call .function .arguments )
109+ function_params = json .loads (str ( tool_call .function .arguments ) )
109110
110- print (
111- f"calling function_name: { function_name } , with function_params: { function_params } "
112- )
111+ print (f"calling function_name: { function_name } , with function_params: { function_params } " )
113112
114113function_result = names_to_functions [function_name ](** function_params )
115114
@@ -128,8 +127,6 @@ def retrieve_payment_date(data: Dict[str, List], transaction_id: str) -> str:
128127)
129128print (messages )
130129
131- response = client .chat .complete (
132- model = model , messages = messages , tools = tools , temperature = 0
133- )
130+ response = client .chat .complete (model = model , messages = messages , tools = tools , temperature = 0 )
134131
135132print (f"{ response .choices [0 ].message .content } " )
0 commit comments