From e09588e2d2fc3e7edc89d4877a22db9548de9ca6 Mon Sep 17 00:00:00 2001 From: AymanBx Date: Wed, 22 Apr 2026 02:21:52 +0000 Subject: [PATCH] agent_langchain: Modify to accept local models. Logging doesn't work well. Once the while loop starts in AgentExecutor there's no control and we only get the final answer --- fairnessBench/agents/agent_langchain.py | 78 +++++++++++++++++++------ 1 file changed, 60 insertions(+), 18 deletions(-) diff --git a/fairnessBench/agents/agent_langchain.py b/fairnessBench/agents/agent_langchain.py index 439c28a..efa3828 100644 --- a/fairnessBench/agents/agent_langchain.py +++ b/fairnessBench/agents/agent_langchain.py @@ -6,12 +6,15 @@ from typing import Dict, List, Optional, Tuple, Union, Any -from langchain.agents import AgentExecutor -from langchain.agents import initialize_agent -from langchain.agents.tools import tool -from langchain_anthropic import ChatAnthropic +from langchain_classic.agents import AgentExecutor +from langchain_classic.agents import initialize_agent + +# from langchain_classic.agents.tools import tool +from langchain_core.tools import StructuredTool # Replacing tool + +# from langchain_anthropic import ChatAnthropic from langchain.chat_models.base import BaseChatModel -from langchain.schema import ( +from langchain_classic.schema import ( AgentAction, AgentFinish, AIMessage, @@ -23,17 +26,25 @@ ChatResult, ChatGeneration ) -from langchain.callbacks.manager import CallbackManagerForChainRun -from langchain.input import get_color_mapping -from langchain.callbacks import FileCallbackHandler -from langchain.agents.mrkl.output_parser import MRKLOutputParser + +import torch +from langchain_huggingface import ChatHuggingFace, HuggingFacePipeline +from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig + +from langchain_classic.callbacks.manager import CallbackManagerForChainRun +from langchain_classic.input import get_color_mapping +from langchain_classic.callbacks import 
FileCallbackHandler +from langchain_classic.agents.mrkl.output_parser import MRKLOutputParser from fairnessBench.schema import Action -from fairnessBench.LLM import complete_text_fast as complete_text_crfm # AS: Changed it to complete_text_fast +from fairnessBench.LLM import complete_text_crfm from .agent import Agent class AgentExecutorWithState(AgentExecutor): - """ A modified version of the AgentExecutor class that allows us to keep track of the agent's state. """ + """ + A modified version of the AgentExecutor class that allows us to keep track of the agent's state. + An object of this class is used in the LangChainAgent run method to execute actions + """ def _call( self, @@ -92,7 +103,10 @@ def _call( class AnthropicOutputParser(MRKLOutputParser): - """ Modified version of the MRKLOutputParser that allows us to parse the output of an anthropic models. """ + """ + Modified version of the MRKLOutputParser that allows us to parse the output of Anthropic models. + An object of this class is used in the LangChainAgent run method to parse the LLM response + """ def parse(self, text: str) -> Union[AgentAction, AgentFinish]: text = text.split("Thought:")[-1] return super().parse(text) @@ -103,7 +117,10 @@ def _type(self) -> str: class EnvTool: - """ A wrapper class to wrap actions as tools for the LangChain agent. """ + """ + A wrapper class to wrap actions as tools for the LangChain agent. 
+ An object of this class is used in the LangChainAgent run method to set up the tools + """ def __init__(self, action_info, env): self.action_info = action_info self.env = env @@ -274,14 +291,39 @@ def run(self, env): agent_kwargs = {"output_parser": AnthropicOutputParser()} else: # TODO: add support for other agents - raise NotImplementedError + # Treat any other llm_name as a local HF model ('llama' or 'qwen'; anything else leaves model_id unset) + match self.args.llm_name: + case 'llama': model_id = "meta-llama/Llama-3.3-70B-Instruct" + case 'qwen': model_id = "Qwen/Qwen2.5-72B-Instruct" + + tokenizer = AutoTokenizer.from_pretrained(model_id) + quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16) + hf_model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config = quant_config, device_map=f"cuda:{self.args.device}",torch_dtype=torch.float16) + + # Set up stopping sequences (not attempted...) + # stop_sequence_ids = tokenizer(["Observation:"], return_token_type_ids=False, add_special_tokens=False) + + pipe = pipeline( + "text-generation", + model = hf_model, + tokenizer=tokenizer, + temperature=0.5, + max_new_tokens=2500, + do_sample=True, + # stopping_criteria = stopping_criteria, + ) + + llm = HuggingFacePipeline(pipeline=pipe) + # chat_model = ChatHuggingFace(llm=llm) # ??? + agent_kwargs = {"output_parser": AnthropicOutputParser()} + tools = [] for tool_name in self.prompt_tool_names: - tools.append(tool( - tool_name, - EnvTool(self.action_infos[tool_name], env).run, - self.construct_tool_prompt(tool_name, self.action_infos[tool_name]).replace("{", "{{").replace("}", "}}") + tools.append(StructuredTool.from_function( + name=tool_name, + func=EnvTool(self.action_infos[tool_name], env).run, + description=self.construct_tool_prompt(tool_name, self.action_infos[tool_name]).replace("{", "{{").replace("}", "}}") ) )