-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathreact_code_langchain.py
More file actions
454 lines (374 loc) · 17 KB
/
react_code_langchain.py
File metadata and controls
454 lines (374 loc) · 17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
import logging
import asyncio
import os
import subprocess
import json
import argparse
from typing import Type, Dict, Any, List
from pydantic import BaseModel, Field
from datetime import datetime
# --- LangChain Imports ---
from langchain.agents import AgentExecutor, create_react_agent
from langchain.prompts import ChatPromptTemplate
from langchain_core.tools import BaseTool
from langchain_community.utilities.tavily_search import TavilySearchAPIWrapper
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_openai import ChatOpenAI
# NEW - Expects JSON format
from langchain.agents import create_json_chat_agent
def load_data(problem_description_path_json, specific_indices=None):
    """Load a UTF-8 JSON file, optionally keeping only selected top-level items.

    Args:
        problem_description_path_json: Path to the JSON file to read.
        specific_indices: Optional iterable of integer indices into the decoded
            JSON list. Negative indices count from the end, as in normal Python
            indexing; indices outside ``[-len(data), len(data))`` are skipped
            silently. ``None`` returns the whole decoded document.

    Returns:
        The decoded JSON value, or a list of the selected items in the order
        given by ``specific_indices``.
    """
    with open(problem_description_path_json, mode="r", encoding="utf-8") as file:
        data = json.load(file)
    if specific_indices is None:
        return data
    size = len(data)
    # Check both bounds: out-of-range indices on either side are dropped
    # rather than raising IndexError.
    return [data[i] for i in specific_indices if -size <= i < size]
# --- Setup Logging ---
# Log records go both to a timestamped file under ``log_dir`` and to the console.
def setup_logging(log_dir, log_name):
    """Configure the root logger to write to a timestamped file and the console.

    Args:
        log_dir: Directory for the log file; created if it does not exist.
        log_name: Base name of the log file; ``_YYYYmmdd_HHMMSS.log`` is appended.

    Returns:
        The root ``logging.Logger`` with its level set to INFO.

    Handlers installed by an earlier call to this function are removed and
    closed first, so calling it more than once does not duplicate log lines.
    Handlers added by other code are left alone.
    """
    os.makedirs(log_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file_path = os.path.join(log_dir, f"{log_name}_{timestamp}.log")

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Drop only the handlers this function installed previously (tagged below).
    for handler in list(logger.handlers):
        if getattr(handler, "_setup_logging_owned", False):
            logger.removeHandler(handler)
            handler.close()

    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    # Explicit UTF-8: logged tool output can contain non-ASCII text such as
    # "✅ PASS", which the locale default encoding may not be able to write.
    file_handler = logging.FileHandler(log_file_path, encoding="utf-8")
    console_handler = logging.StreamHandler()
    for handler in (file_handler, console_handler):
        handler.setFormatter(formatter)
        handler._setup_logging_owned = True
        logger.addHandler(handler)

    logger.info(f"Logging initialized. Log file: {log_file_path}")
    return logger
# --- Tool Definition ---
# Pydantic schema describing the arguments the agent may pass to the
# code-testing tool: a single required string holding the full script.
class CodeTestInput(BaseModel):
    code: str = Field(
        ...,
        description="The complete Python code to be written to a file and tested.",
    )
# --- 2. Logging callback and code-testing tool classes ---
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.agents import AgentAction, AgentFinish
class AgentLoggingCallback(BaseCallbackHandler):
    """
    Callback handler that mirrors agent decisions, tool results and errors
    into a standard Python ``logging.Logger``.
    """

    def __init__(self, logger):
        # Every record produced by this handler is sent to this logger.
        self.logger = logger

    def on_agent_action(self, action: AgentAction, **kwargs):
        """Log the chosen tool, its input and the raw LLM text before the tool runs."""
        message = (
            "\n[AGENT ACTION]\n"
            f"Tool: {action.tool}\n"
            f"Input: {action.tool_input}\n"
            f"Log: {action.log}\n"
        )
        self.logger.info(message)

    def on_tool_end(self, output: str, **kwargs):
        """Log a tool's output, cut to 1000 characters plus "..." when longer."""
        limit = 1000
        text = str(output)
        if len(text) > limit:
            text = f"{text[:limit]}..."
        self.logger.info(f"\n[TOOL OUTPUT]\n{text}\n")

    def on_agent_finish(self, finish: AgentFinish, **kwargs):
        """Log the agent's final return values and its last LLM text."""
        message = (
            "\n[AGENT FINISH]\n"
            f"Return: {finish.return_values}\n"
            f"Log: {finish.log}\n"
        )
        self.logger.info(message)

    def on_llm_error(self, error: BaseException, **kwargs):
        """Log an exception raised by the LLM call at ERROR level."""
        self.logger.error(f"[LLM ERROR] {error}")

    def on_tool_error(self, error: BaseException, **kwargs):
        """Log an exception raised by a tool at ERROR level."""
        self.logger.error(f"[TOOL ERROR] {error}")
class CodeWriterAndTesterTool(BaseTool):
    """
    A tool that writes Python code to a file and then runs a test command
    (an argv list, executed without a shell) in a given working directory.

    Every call returns a dict:
      * ``{"status": "SUCCESS", "error": None, "output": <stdout>}`` when the
        command exits with code 0 and its stdout contains "✅ PASS";
      * ``{"status": "FAIL", "error": <details>}`` otherwise (non-zero exit,
        missing PASS marker, timeout, or an exception while running).

    The working file is truncated to empty after every call, whatever the outcome.
    """
    name: str = "code_writer_and_tester"
    description: str = (
        "Writes Python code to a file and runs a test command. "
        "Use this to write code and immediately test its execution. "
        "Returns a dictionary with 'status' ('SUCCESS' or 'FAIL') and 'error' (any error message)."
    )
    args_schema: Type[BaseModel] = CodeTestInput

    # File the submitted code is written to (overwritten on each call).
    working_file_path: str
    # Test command as an argv list, e.g. ["python", "eval.py", ...].
    command: List[str]
    # Directory the command runs in.
    working_folder_location: str
    # Seconds before the command is killed and the call reported as FAIL.
    timeout_seconds: float = 300

    def _parse_exec_result(self, exec_result: subprocess.CompletedProcess) -> Dict[str, Any]:
        """Map a finished process to the tool's SUCCESS/FAIL result dict."""
        stdout = exec_result.stdout or ""
        stderr = exec_result.stderr or ""
        # Success requires both a clean exit and the "✅ PASS" marker.
        if exec_result.returncode == 0 and "✅ PASS" in stdout:
            return {"status": "SUCCESS", "error": None, "output": stdout}
        # Report both streams on every failure: test runners such as pytest
        # print failure details to stdout, so stderr alone is often empty.
        error_output = f"STDOUT:\n{stdout}\n\nSTDERR:\n{stderr}"
        if exec_result.returncode != 0:
            error_output = f"Exit code: {exec_result.returncode}\n{error_output}"
        return {"status": "FAIL", "error": error_output}

    def _timeout_result(self) -> Dict[str, Any]:
        """Result dict used when the command exceeds ``timeout_seconds``."""
        return {
            "status": "FAIL",
            "error": f"Execution timed out after {self.timeout_seconds:g} seconds.",
        }

    def _run(
        self,
        code_to_write: str = "",
        code: str = "",
    ) -> Dict[str, Any]:
        """Write the code to ``working_file_path`` and run the command synchronously.

        A plain-string tool input arrives positionally as ``code_to_write``; a
        dict input arrives as keyword arguments named after ``CodeTestInput``'s
        fields, i.e. ``code``. Both forms are accepted.
        """
        source = code_to_write or code
        try:
            self._write_file_sync(self.working_file_path, source)
            exec_result = subprocess.run(
                self.command,
                timeout=self.timeout_seconds,
                capture_output=True,
                text=True,
                encoding='utf-8',
                # Undecodable bytes must not hide the actual test result.
                errors='replace',
                cwd=self.working_folder_location,
            )
            result = self._parse_exec_result(exec_result)
        except subprocess.TimeoutExpired:
            # subprocess.run has already killed and reaped the child here.
            result = self._timeout_result()
        except Exception as e:
            result = {"status": "FAIL", "error": f"Tool execution failed: {str(e)}"}
        finally:
            self._clear_working_file()
        return result

    async def _arun(
        self,
        code_to_write: str = "",
        code: str = "",
    ) -> Dict[str, Any]:
        """Asynchronous twin of ``_run``; accepts the code under either name."""
        source = code_to_write or code
        try:
            await asyncio.to_thread(self._write_file_sync, self.working_file_path, source)
            proc = await asyncio.create_subprocess_exec(
                *self.command,
                cwd=self.working_folder_location,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
            try:
                stdout_bytes, stderr_bytes = await asyncio.wait_for(
                    proc.communicate(), timeout=self.timeout_seconds
                )
            except asyncio.TimeoutError:
                proc.kill()
                await proc.wait()
                return self._timeout_result()  # the finally below still clears the file
            exec_result = subprocess.CompletedProcess(
                args=self.command,
                returncode=proc.returncode,
                stdout=stdout_bytes.decode('utf-8', errors='replace'),
                stderr=stderr_bytes.decode('utf-8', errors='replace'),
            )
            result = self._parse_exec_result(exec_result)
        except Exception as e:
            result = {"status": "FAIL", "error": f"Tool execution failed: {str(e)}"}
        finally:
            try:
                await asyncio.to_thread(self._write_file_sync, self.working_file_path, '')
            except OSError:
                pass  # best effort: the file or its directory may be gone
        return result

    def _clear_working_file(self) -> None:
        """Truncate the working file; failures are deliberately ignored (best effort)."""
        try:
            self._write_file_sync(self.working_file_path, '')
        except OSError:
            pass

    def _write_file_sync(self, file_path: str, content: str):
        """Synchronous helper for file writing, to be used with asyncio.to_thread."""
        with open(file_path, 'w', encoding='utf-8') as f:
            f.write(content)
# --- Agent Setup ---
def setup_agent(args: argparse.Namespace, logger: logging.Logger) -> AgentExecutor:
    """
    Initializes and returns the LangChain JSON chat (ReAct-style) agent executor.

    Args:
        args: Parsed CLI arguments; reads ``working_file_path``, ``command``,
            ``working_folder_location`` and ``max_rounds``.
        logger: Logger that receives agent actions, tool outputs and errors
            through ``AgentLoggingCallback``.

    Returns:
        An ``AgentExecutor`` wired with a Tavily web-search tool and
        ``CodeWriterAndTesterTool``.

    Raises:
        EnvironmentError: If ``TAVILY_API_KEY`` is not set, so callers get a
            clear message instead of a validation error from the Tavily wrapper.
    """
    # Local import keeps this block self-contained; langchain_core is already
    # a dependency of this file.
    from langchain_core.prompts import MessagesPlaceholder

    # 1. Check required API keys up front.
    if not os.environ.get("TAVILY_API_KEY"):
        raise EnvironmentError("TAVILY_API_KEY environment variable is not set.")

    # Attached to the executor (agent action/finish), the LLM (LLM errors) and
    # each tool (tool output/errors): handlers passed only to the executor's
    # constructor are local to it and are not propagated to tool or LLM runs.
    logging_callback = AgentLoggingCallback(logger)

    # 2. The LLM: an OpenAI-compatible endpoint; api_key is a placeholder value.
    llm = ChatOpenAI(
        model="Qwen/Qwen3-32B",
        temperature=0,
        base_url="http://172.31.51.159:8001/v1",
        api_key="not-needed",
        callbacks=[logging_callback],
    )

    # 3. Tools: web search plus the custom write-and-test tool.
    search_wrapper = TavilySearchAPIWrapper()
    tavily_tool = TavilySearchResults(
        api_wrapper=search_wrapper,
        max_results=5,
        callbacks=[logging_callback],
    )
    code_tool = CodeWriterAndTesterTool(
        working_file_path=args.working_file_path,
        command=args.command,  # already an argv list (argparse nargs='*')
        working_folder_location=args.working_folder_location,
        callbacks=[logging_callback],
    )
    tools = [tavily_tool, code_tool]

    # 4. The prompt. create_json_chat_agent fills {tools} and {tool_names}, and
    # supplies agent_scratchpad as a list of messages, so it needs a
    # MessagesPlaceholder rather than a string slot. Its output parser ends the
    # run only on an action named "Final Answer", so the prompt must say so.
    CI_AGENT_PROMPT = """
You are an expert in computational imaging and a world-class Python programmer.
Your goal is to solve a coding task.
Answer the user's request. You have access to the following tools with these names:
{tool_names}
You have access to the following tools:
{tools}
To use a tool, use the following JSON format:
Thought:
I need to use a tool to do something.
```json
{{
"action": "tool_name",
"action_input": "input to the tool"
}}
```
Observation:
The tool will return some output.
... (this Thought/Action/Observation loop can repeat N times)
When you are done, give your final answer in the same JSON format, using the
special action name "Final Answer":
```json
{{
"action": "Final Answer",
"action_input": "your final answer to the user"
}}
```
Every response must contain exactly one such JSON blob.
When you have gathered enough information and are ready to write the final,
runnable code, you MUST use the 'code_writer_and_tester' tool.
Your reasoning process should be:
1. **Analyze Request:** Understand the problem.
2. **Search (if needed):** If the problem involves unfamiliar libraries (like 'ehtim')
or specific scientific concepts, use the 'tavily_search_results_json' tool
to find documentation, examples, and formulas.
3. **Plan Code:** Formulate a plan for the Python code.
4. **Write & Test Code:** Use the 'code_writer_and_tester' tool to submit
your *complete* Python script.
5. **Analyze Results:**
- If the 'Observation' (tool output) shows 'status': 'SUCCESS',
your job is done. Respond with the "Final Answer" action.
- If the 'Observation' shows 'status': 'FAIL', analyze the 'error' field.
6. **Debug & Retry:** Formulate a new 'Thought' based on the error.
Decide if you need to search again (e.g., to fix the error) or
if you can just modify the code. Then, use 'code_writer_and_tester'
again with the *revised* code.
**IMPORTANT:** The 'code_writer_and_tester' tool expects the *entire*
runnable Python script as its input. Do not send partial code.
Begin!
"""
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", CI_AGENT_PROMPT),
            ("human", "Problem:\n{input}"),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ]
    )

    # 5. Bind the LLM, tools and prompt into the agent runnable.
    agent = create_json_chat_agent(llm, tools, prompt)

    # 6. The executor runs the Thought/Action/Observation loop.
    agent_executor = AgentExecutor(
        agent=agent,
        tools=tools,
        verbose=True,  # echo the agent's thoughts to stdout
        max_iterations=args.max_rounds,
        handle_parsing_errors=True,  # feed malformed JSON back to the LLM instead of raising
        return_intermediate_steps=True,  # keep the full loop in the response for debugging
        callbacks=[logging_callback],
    )
    return agent_executor
# --- Main Execution ---
def _build_arg_parser():
    """Return the argument parser for this script's command-line options."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--max_rounds", type=int, default=3)
    parser.add_argument("--working_file_path", type=str, default="./temp_work/agent_code.py")
    # argv list of the test command run after each code submission, e.g. a
    # pytest or `bash test.sh` invocation.
    parser.add_argument(
        "--command",
        nargs='*',
        default=[
            "python",
            "/fs-computility-new/UPDZ02_sunhe/chensiyi.p/AFlow/code_development/obs_arg/eval_v3.py",
            "--npix", "32",
            "--obspath", "/fs-computility-new/UPDZ02_sunhe/shared/eht_imaging/DPI/dataset/interferometry1/obs.uvfits",
        ],
        help="the command used to run the test env",
    )
    parser.add_argument("--working_folder_location", type=str, default=".")
    parser.add_argument(
        "--problem_description_path_json",
        type=str,
        default="workspace",
        help="Path to the JSON file containing the problem description(s).",
    )
    parser.add_argument(
        "--log_dir",
        type=str,
        default="./logs",
        help="Directory to save log files.",
    )
    parser.add_argument(
        "--log_name",
        type=str,
        default="react_run",
        help="Base name for the log file. A timestamp will be appended.",
    )
    return parser


def main():
    """Parse arguments, load the problem description and run the agent once.

    Failures while loading the problem, building the agent or running it are
    logged and end the run early instead of propagating.
    """
    args = _build_arg_parser().parse_args()
    logger = setup_logging(args.log_dir, args.log_name)

    # The code tool writes to this file on every call; its directory must exist.
    working_dir = os.path.dirname(args.working_file_path)
    if working_dir:
        os.makedirs(working_dir, exist_ok=True)

    try:
        problem_description = load_data(args.problem_description_path_json)
    except (OSError, ValueError) as e:  # ValueError covers json.JSONDecodeError
        logger.error(
            f"Failed to load problem description from "
            f"{args.problem_description_path_json!r}: {e}"
        )
        return

    logger.info("--- Starting new problem ---")
    logger.info(f"Problem Description: {problem_description}")

    try:
        agent_executor = setup_agent(args, logger)
    except EnvironmentError as e:
        logger.error(f"Failed to setup agent: {e}")
        logger.error("Please set TAVILY_API_KEY in your environment.")
        return

    # The whole Thought/Action/Observation loop runs inside one executor call.
    initial_input = {"input": problem_description}

    logger.info("--- Invoking Agent Executor ---")
    try:
        response = asyncio.run(agent_executor.ainvoke(initial_input))
    except Exception as e:
        # logger.exception keeps the traceback, unlike logger.error.
        logger.exception(f"An error occurred while running the agent: {e}")
        return
    logger.info("--- Agent run finished ---")
    logger.info(f"Final Output: {response.get('output')}")
    # response.get('intermediate_steps') holds the full loop when the executor
    # was built with return_intermediate_steps=True.
if __name__ == "__main__":
    # To run this:
    # 1. pip install langchain langchain-openai langchain-community tavily-python
    # 2. Set your Tavily API key in your terminal (never commit a real key here;
    #    any key previously written in this comment must be revoked). The LLM in
    #    setup_agent is configured with a placeholder api_key value:
    #    export TAVILY_API_KEY="tvly-..."
    # 3. Run the script:
    #    python react_code_langchain.py --problem_description_path_json <problems.json>
    main()