-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathreact_code_web_search.py
More file actions
304 lines (257 loc) · 12.4 KB
/
react_code_web_search.py
File metadata and controls
304 lines (257 loc) · 12.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
import argparse
import json
import logging
import os
import subprocess
import asyncio
from datetime import datetime
from typing import List, Dict, Any, Optional
from openai import AsyncOpenAI, OpenAI
from httpx import Client
# --- External Dependencies ---
# Make sure to install: pip install tavily-python
from tavily import TavilyClient
# --- Your existing imports (Adjust paths as necessary) ---
from scripts.async_llm import LLMsConfig
from scripts.async_llm import create_llm_instance
# We don't need 'operator' anymore because we are manually handling the chat loop below
# import workspace.InverseProb.workflows.template.operator as operator
# =============================================================================
# SECTION 1: TOOL DEFINITIONS
# =============================================================================
class ToolRegistry:
    """Async implementations of the tools the LLM may call.

    Every tool returns a JSON-encoded string so it can be handed back to the
    model verbatim as the content of a ``"tool"`` message.

    Args:
        working_file_path: File overwritten with the model's code on every
            ``code_writer_and_tester`` call.
        command: argv list executed (without a shell) after the file is written.
        working_folder: Working directory in which ``command`` runs.
        tavily_api_key: Tavily API key; when falsy, no Tavily client is created
            and ``web_search`` returns an error payload.
        test_timeout: Seconds ``command`` may run before it is killed.
            Defaults to 300, the previously hard-coded limit.
    """

    def __init__(self, working_file_path, command, working_folder, tavily_api_key, test_timeout=300):
        self.working_file_path = working_file_path
        self.command = command
        self.working_folder = working_folder
        self.test_timeout = test_timeout
        # Only build the client when a key exists, so the registry remains usable
        # (with search disabled) without Tavily credentials.
        self.tavily_client = TavilyClient(api_key=tavily_api_key) if tavily_api_key else None

    @staticmethod
    async def _kill_and_reap(proc):
        """Kill *proc* if it is still running, then wait for it to exit.

        Awaiting ``wait()`` after ``kill()`` reaps the child so it is not left
        behind (with its pipe transport open) when the event loop shuts down.
        """
        try:
            proc.kill()
        except ProcessLookupError:
            # The process exited on its own between the timeout and the kill.
            pass
        await proc.wait()

    async def code_writer_and_tester(self, code_to_write: str) -> str:
        """Write *code_to_write* to ``working_file_path`` and run ``command``.

        Returns:
            A JSON string:
            ``{"status": "SUCCESS", "output": stdout}`` when the command exits
            with code 0 and its stdout contains "✅ PASS";
            ``{"status": "FAIL", "error": stderr, "stdout": stdout}`` on a
            non-zero exit code; otherwise ``{"status": "FAIL", "error": ...}``
            (no PASS marker, timeout, invalid argument, or any exception).
            Failures are reported in the payload instead of raised so the model
            can react to them.
        """
        logging.info(f"TOOL USE: Writing code to {self.working_file_path}...")
        if not isinstance(code_to_write, str):
            # Reject before opening the file so a malformed tool call does not
            # truncate the existing working file.
            return json.dumps({"status": "FAIL", "error": "Missing or non-string 'code_to_write' argument."})
        try:
            # 1. Write the code
            with open(self.working_file_path, 'w', encoding='utf-8') as f:
                f.write(code_to_write)

            # 2. Run the command (argv list, no shell)
            logging.info(f"TOOL USE: Running command: {' '.join(self.command)}")
            proc = await asyncio.create_subprocess_exec(
                *self.command,
                cwd=self.working_folder,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
            try:
                stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=self.test_timeout)
            except asyncio.TimeoutError:
                await self._kill_and_reap(proc)
                return json.dumps({"status": "FAIL", "error": f"Execution timed out after {self.test_timeout}s."})

            stdout_str = stdout.decode('utf-8', errors='replace')
            stderr_str = stderr.decode('utf-8', errors='replace')

            # 3. Parse result
            if proc.returncode != 0:
                result = {"status": "FAIL", "error": stderr_str, "stdout": stdout_str}
            elif "✅ PASS" in stdout_str:  # Adjust this string based on your actual test runner output
                result = {"status": "SUCCESS", "output": stdout_str}
            else:
                # Exit code 0 without an explicit PASS marker is treated as a logical failure.
                result = {"status": "FAIL", "error": f"Command finished but 'PASS' not found.\nSTDOUT:\n{stdout_str}\nSTDERR:\n{stderr_str}"}
            return json.dumps(result)
        except Exception as e:
            # Deliberately report any failure to the model rather than crash the agent loop.
            return json.dumps({"status": "FAIL", "error": f"Tool execution exception: {str(e)}"})

    async def web_search(self, query: str) -> str:
        """Search the web through Tavily and return the raw response as JSON.

        Returns a JSON ``{"error": ...}`` payload when Tavily is not configured,
        the query is missing/empty, or the search call (or serialization) raises.
        """
        if not self.tavily_client:
            return json.dumps({"error": "Tavily API key not configured."})
        if not isinstance(query, str) or not query.strip():
            return json.dumps({"error": "Missing or empty 'query' argument."})
        logging.info(f"TOOL USE: Searching web for: {query}")
        try:
            # The Tavily client call is synchronous; run it in a worker thread so
            # the event loop is not blocked.
            response = await asyncio.to_thread(
                self.tavily_client.search,
                query=query,
                search_depth="advanced",
                max_results=5,
            )
            return json.dumps(response)
        except Exception as e:
            return json.dumps({"error": f"Search failed: {str(e)}"})
# =============================================================================
# SECTION 2: JSON SCHEMAS FOR LLM
# =============================================================================
def _function_tool(name, description, param_name, param_description):
    """Build one OpenAI-style function-tool spec that takes a single required string argument."""
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": {
                    param_name: {"type": "string", "description": param_description},
                },
                "required": [param_name],
            },
        },
    }


# Tool specs advertised to the LLM; names must match the ToolRegistry methods
# the agent loop dispatches to.
TOOLS_SCHEMA = [
    _function_tool(
        "code_writer_and_tester",
        "Write Python code to solve the problem and immediately test it. Returns SUCCESS or FAIL.",
        "code_to_write",
        "The complete python code to be written to the file.",
    ),
    _function_tool(
        "web_search",
        "Search the internet for documentation, algorithms, or library usage examples.",
        "query",
        "The search query.",
    ),
]
# =============================================================================
# SECTION 3: UTILS & SETUP
# =============================================================================
def load_data(problem_description_path_json):
    """Load and return the parsed JSON document stored at *problem_description_path_json*.

    The file is decoded as UTF-8. ``OSError`` and ``json.JSONDecodeError``
    propagate to the caller unchanged.
    """
    with open(problem_description_path_json, encoding="utf-8") as handle:
        parsed = json.load(handle)
    return parsed
def parse_args():
    """Parse the agent's command-line options from ``sys.argv``.

    Only ``--working_file_path`` is mandatory. ``--command`` takes zero or more
    argv tokens, and ``--tavily_api_key`` falls back to the ``TAVILY_API_KEY``
    environment variable read at call time.
    """
    cli = argparse.ArgumentParser(description="react_agent_integrated")
    option_specs = (
        ("--working_file_path", {"type": str, "required": True}),
        ("--command", {"nargs": '*', "default": ["python", "test_script.py"]}),
        ("--working_folder_location", {"type": str, "default": "./"}),
        ("--problem_description_path_json", {"type": str, "default": "problems.json"}),
        ("--max_rounds", {"type": int, "default": 20}),
        ("--model_name", {"type": str, "default": "deepseek-r1-250528"}),
        ("--log_dir", {"type": str, "default": "./logs"}),
        ("--log_name", {"type": str, "default": "react_run"}),
        ("--tavily_api_key", {"type": str, "default": os.getenv("TAVILY_API_KEY"), "help": "API Key for Tavily Search"}),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
def setup_logging(log_dir, log_name):
    """Attach a timestamped log file and a console handler to the root logger.

    Creates *log_dir* if needed, sets the root logger level to INFO, and adds a
    ``FileHandler`` writing to ``<log_dir>/<log_name>_<YYYYmmdd_HHMMSS>.log``
    plus a ``StreamHandler`` (stderr), both using the same formatter.

    Args:
        log_dir: Directory for the log file.
        log_name: Prefix of the log file name.

    Returns:
        The root logger.
    """
    os.makedirs(log_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file_path = os.path.join(log_dir, f"{log_name}_{timestamp}.log")

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

    # Explicit UTF-8: the log records model output and test stdout (e.g. "✅ PASS"),
    # which the platform default encoding (e.g. cp1252) may be unable to encode.
    file_handler = logging.FileHandler(log_file_path, encoding="utf-8")
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)
    return logger
# =============================================================================
# SECTION 4: MAIN REACT LOOP
# =============================================================================
def _assistant_message_to_dict(message) -> Dict[str, Any]:
    """Convert an SDK assistant message into a plain chat-history dict.

    Only the documented ``role``/``content``/``tool_calls`` fields are echoed
    back. Provider-specific extras (e.g. a ``reasoning_content`` field returned
    by some reasoning models) are dropped; some OpenAI-compatible providers may
    reject them on input. This also keeps the logged history readable.
    """
    entry: Dict[str, Any] = {"role": "assistant", "content": message.content}
    if message.tool_calls:
        entry["tool_calls"] = [
            {
                "id": call.id,
                "type": "function",
                "function": {"name": call.function.name, "arguments": call.function.arguments},
            }
            for call in message.tool_calls
        ]
    return entry


async def _dispatch_tool_call(tools_registry, func_name, raw_arguments) -> str:
    """Run the registry tool named *func_name* and return its JSON string output.

    Malformed argument JSON and unknown tool names are reported back to the
    model as a FAIL payload instead of raising, so one bad tool call cannot
    crash the whole agent run.
    """
    try:
        func_args = json.loads(raw_arguments or "{}")
    except json.JSONDecodeError as e:
        return json.dumps({"status": "FAIL", "error": f"Invalid JSON arguments for tool '{func_name}': {e}"})
    if not isinstance(func_args, dict):
        return json.dumps({"status": "FAIL", "error": f"Arguments for tool '{func_name}' must be a JSON object."})

    if func_name == "web_search":
        return await tools_registry.web_search(func_args.get("query"))
    if func_name == "code_writer_and_tester":
        return await tools_registry.code_writer_and_tester(func_args.get("code_to_write"))
    return json.dumps({"status": "FAIL", "error": f"Unknown tool: {func_name}"})


def _tool_reported_success(tool_output: str) -> bool:
    """Return True when *tool_output* is a JSON object whose "status" is "SUCCESS"."""
    try:
        payload = json.loads(tool_output)
    except (TypeError, ValueError):
        return False
    return isinstance(payload, dict) and payload.get("status") == "SUCCESS"


async def run_agent_loop(llm_client, model_name, tools_registry, question, max_rounds):
    """Drive a ReAct-style tool-calling conversation until the tests pass.

    Each round sends the chat history (with ``TOOLS_SCHEMA``) to the model,
    executes any requested tool calls through *tools_registry*, and appends
    their outputs as ``"tool"`` messages. If the model replies without a tool
    call, it is nudged to test its code.

    Args:
        llm_client: OpenAI-compatible client exposing ``chat.completions.create``;
            both sync clients and async clients (whose ``create`` returns a
            coroutine) are supported.
        model_name: Model identifier passed to ``create``.
        tools_registry: Object providing ``web_search`` and
            ``code_writer_and_tester`` coroutines.
        question: Problem statement embedded in the first user message.
        max_rounds: Maximum number of LLM calls.

    Returns:
        True as soon as ``code_writer_and_tester`` reports SUCCESS; False if the
        LLM call fails or *max_rounds* is exhausted.
    """
    system_prompt = (
        "You are an expert Computational Imaging Assistant.\n"
        "You have access to two tools:\n"
        "1. 'web_search': Use this to find documentation, math formulas, or existing implementations.\n"
        "2. 'code_writer_and_tester': Use this to write code and run the local test environment.\n"
        "Your Process:\n"
        "- Analyze the problem.\n"
        "- If you are unsure about the algorithm or library, use 'web_search'.\n"
        "- It is recommended to first search concerning papers on the problem and learn from their methods.\n"
        "- Once you have a plan, use 'code_writer_and_tester' to implement it.\n"
        "- If the test FAILS, analyze the error message, fix the code, and try again.\n"
        "- If the same error are happening repeatedly, you should use 'web_search' for certain solution to resolve the error rather than aimlessly trial.\n"
        "- If the test SUCCESS, you are done.\n"
        "Always output the tool call strictly in the format requested by the system."
    )
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"Solve this problem: {question}"},
    ]

    for i in range(max_rounds):
        logging.info(f"--- Round {i+1} ---")
        logging.info("the model input")
        logging.info(messages)

        # 1. Call the LLM
        try:
            response = llm_client.chat.completions.create(
                model=model_name,
                messages=messages,
                tools=TOOLS_SCHEMA,
                tool_choice="auto",
            )
            # An async client returns a coroutine here; a sync client returns the response.
            if asyncio.iscoroutine(response):
                response = await response
            response_message = response.choices[0].message
            tool_calls = response_message.tool_calls
        except Exception as e:
            logging.error(f"LLM Call Failed: {e}")
            return False

        # Record the assistant turn (including its tool calls) in the history.
        messages.append(_assistant_message_to_dict(response_message))

        if not tool_calls:
            # Model replied with text only; nudge it to verify via the test tool.
            logging.info(f"Model message: {response_message.content}")
            messages.append({
                "role": "user",
                "content": "Please continue. If you have written code, make sure to use the 'code_writer_and_tester' tool to verify it."
            })
            continue

        # 2. Execute each requested tool and feed its output back to the model.
        for tool_call in tool_calls:
            func_name = tool_call.function.name
            logging.info(f"Model requested tool: {func_name}")
            tool_output = await _dispatch_tool_call(tools_registry, func_name, tool_call.function.arguments)

            if _tool_reported_success(tool_output):
                logging.info("Problem Solved Successfully!")
                return True

            messages.append({
                "role": "tool",
                "tool_call_id": tool_call.id,
                "name": func_name,
                "content": tool_output,
            })

    logging.warning("Max rounds reached without success.")
    return False
def iter_questions(problem_description):
    """Return ``(question_id, question)`` pairs from the loaded problem file.

    A dict yields its ``(key, value)`` items; a list yields ``(index, item)``.

    Raises:
        TypeError: if *problem_description* is neither a dict nor a list.
    """
    if isinstance(problem_description, dict):
        return list(problem_description.items())
    if isinstance(problem_description, list):
        return list(enumerate(problem_description))
    raise TypeError(
        f"Problem description must be a JSON object or array, got {type(problem_description).__name__}."
    )


def main():
    """Script entry point: parse CLI args, set up tools and the LLM client, and solve each problem in turn."""
    args = parse_args()
    setup_logging(args.log_dir, args.log_name)

    # Initialize tools
    registry = ToolRegistry(
        working_file_path=args.working_file_path,
        command=args.command,
        working_folder=args.working_folder_location,
        tavily_api_key=args.tavily_api_key,
    )

    # Initialize the LLM client (a synchronous OpenAI-compatible client).
    models_config = LLMsConfig.default()
    config_entry = models_config.get(args.model_name)
    # NOTE(review): assumes LLMsConfig.get returns None for an unknown model — confirm.
    if config_entry is None:
        raise SystemExit(f"Model '{args.model_name}' not found in LLMsConfig.")
    client = OpenAI(api_key=config_entry.key, base_url=config_entry.base_url)

    problem_description = load_data(args.problem_description_path_json)

    # Problems may be a dict {id: question} or a list [question, ...]; both are
    # normalized so the question value (not its key/element) reaches the agent.
    for question_id, question in iter_questions(problem_description):
        logging.info(f"=== Question {question_id} ===")
        asyncio.run(run_agent_loop(client, args.model_name, registry, question, args.max_rounds))


if __name__ == "__main__":
    main()