-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathuser_input
More file actions
130 lines (103 loc) · 3.95 KB
/
user_input
File metadata and controls
130 lines (103 loc) · 3.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import os, getpass
from langchain_openai import ChatOpenAI
from pprint import pprint
from langchain_core.messages import SystemMessage, HumanMessage, RemoveMessage
from IPython.display import Image, display
from langgraph.graph import MessagesState
from langgraph.graph import StateGraph, START, END
import sqlite3
from IPython.display import Image, display
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, START
from dotenv import load_dotenv
from langgraph.prebuilt import tools_condition, ToolNode
# Load settings from the .env file before reading the API key.
load_dotenv()

openai_api_key = os.environ.get("OPENAI_API_KEY")
if not openai_api_key:
    # Stop immediately: every later step needs a configured model.
    raise ValueError("API key not found. Ensure your .env file is set correctly.")

# Chat model shared by the rest of the script.
llm = ChatOpenAI(
    model_name="gpt-4",
    temperature=0.7,
    openai_api_key=openai_api_key,
)
# Breakpoints
# Add tools
def multiply(a: int, b: int) -> int:
    """Multiply a and b.

    Args:
        a: first int
        b: second int
    """
    # NOTE(review): this function is passed to llm.bind_tools, which
    # presumably derives the tool description from the docstring above --
    # keep its wording stable.
    product = a * b
    return product
# This will be a tool
def add(a: int, b: int) -> int:
    """Adds a and b.

    Args:
        a: first int
        b: second int
    """
    # NOTE(review): this function is passed to llm.bind_tools, which
    # presumably derives the tool description from the docstring above --
    # keep its wording stable.
    total = a + b
    return total
# Functions the model is allowed to call.
tools = [multiply, add]

# Model handle with the tool definitions bound to it.
llm_with_tools = llm.bind_tools(tools)

# System prompt placed ahead of the conversation on every assistant call.
sys_msg = SystemMessage(
    content="You are a helpful assistant tasked with performing arithmetic on a set of inputs."
)
#node
def assistant(state: MessagesState):
    """Invoke the tool-enabled model on the conversation so far.

    Args:
        state: graph state; its "messages" entry holds the conversation.

    Returns:
        A dict whose "messages" entry is a one-element list containing the
        model's response.
    """
    # The system prompt always goes first, followed by the stored messages.
    prompt = [sys_msg] + state["messages"]
    reply = llm_with_tools.invoke(prompt)
    return {"messages": [reply]}
# Assemble the graph: an assistant node and a tool-execution node.
builder = StateGraph(MessagesState)
builder.add_node("assistant", assistant)
builder.add_node("tools", ToolNode(tools))

builder.add_edge(START, "assistant")
# If the latest message (result) from assistant is a tool call,
# tools_condition routes to tools; otherwise it routes to END.
builder.add_conditional_edges("assistant", tools_condition)
# Tool results are handed back to the assistant.
builder.add_edge("tools", "assistant")

memory = MemorySaver()
# Previous variant, interrupting before the tools node:
# graph = builder.compile(interrupt_before=["tools"], checkpointer=memory)
# Now the interrupt happens before the assistant node instead.
graph = builder.compile(checkpointer=memory, interrupt_before=["assistant"])
initial_input = {"messages": HumanMessage(content="Multiply 2 and 3")}
thread = {"configurable": {"thread_id": "1"}}

# Earlier experiment on thread "1", kept for reference:
# for event in graph.stream(initial_input, thread, stream_mode="values"):
#     event['messages'][-1].pretty_print()
# state = graph.get_state(thread)
# print(state.next)
# for event in graph.stream(None, thread, stream_mode="values"):
#     event['messages'][-1].pretty_print()

thread2 = {"configurable": {"thread_id": "2"}}

# Run the graph until the first interruption.
for event in graph.stream(initial_input, thread2, stream_mode="values"):
    latest = event["messages"][-1]
    latest.pretty_print()

# Manual approval flow, kept for reference:
# user_approval = input("Do you want to call the tool?")
# Check approval
# if user_approval.lower() == "yes":
#     for event in graph.stream(None, thread2, stream_mode="values"):
#         event['messages'][-1].pretty_print()
# else:
#     print("Operation cancelled by user.")

# Interrupted before the assistant: update the thread state with a new
# human message.
correction = HumanMessage(content="No, actually multiply 3 and 3!")
graph.update_state(thread2, {"messages": [correction]})

snapshot = graph.get_state(thread2)
new_state = snapshot.values

# Print every message currently stored for the thread.
for msg in new_state["messages"]:
    msg.pretty_print()

# Resume the graph (input None) and print the output.
for event in graph.stream(None, thread2, stream_mode="values"):
    latest = event["messages"][-1]
    latest.pretty_print()
# Replaying states
# get_state_history yields snapshots newest-first, so the oldest checkpoints
# sit at the end of the list.
all_states = list(graph.get_state_history(thread2))
print("length:", len(all_states))

# The final entry is the initial input checkpoint, which is expected to hold
# no messages yet; indexing into its message list would raise IndexError.
# Pick the oldest checkpoint that actually contains at least one message.
to_replay = next(
    (
        snapshot
        for snapshot in reversed(all_states)
        if snapshot.values.get("messages")
    ),
    None,
)
if to_replay is None:
    raise RuntimeError("No checkpoint containing messages was found for this thread.")

print(to_replay.values)
print(to_replay.next)

# Passing None as input together with a checkpoint config re-runs the graph
# from that checkpoint.
for event in graph.stream(None, to_replay.config, stream_mode="values"):
    event["messages"][-1].pretty_print()

# Forking
# Fork from the same checkpoint that was replayed above.
to_fork = to_replay
print(to_fork.config)

# Modify the state at this checkpoint. The new message reuses the id of the
# checkpoint's first message.
# NOTE(review): reusing the id presumably makes the messages reducer replace
# that message rather than append a new one -- confirm against add_messages.
first_message_id = to_fork.values["messages"][0].id
fork_config = graph.update_state(
    to_fork.config,
    {"messages": [HumanMessage(content="Multiply 2 and 3", id=first_message_id)]},
)
print(fork_config)