Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 60 additions & 18 deletions fairnessBench/agents/agent_langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@
from typing import Dict, List, Optional, Tuple, Union, Any


from langchain.agents import AgentExecutor
from langchain.agents import initialize_agent
from langchain.agents.tools import tool
from langchain_anthropic import ChatAnthropic
from langchain_classic.agents import AgentExecutor
from langchain_classic.agents import initialize_agent

# from langchain_classic.agents.tools import tool
from langchain_core.tools import StructuredTool # Replacing tool

# from langchain_anthropic import ChatAnthropic
from langchain.chat_models.base import BaseChatModel
from langchain.schema import (
from langchain_classic.schema import (
AgentAction,
AgentFinish,
AIMessage,
Expand All @@ -23,17 +26,25 @@
ChatResult,
ChatGeneration
)
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.input import get_color_mapping
from langchain.callbacks import FileCallbackHandler
from langchain.agents.mrkl.output_parser import MRKLOutputParser

import torch
from langchain_huggingface import ChatHuggingFace, HuggingFacePipeline
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig

from langchain_classic.callbacks.manager import CallbackManagerForChainRun
from langchain_classic.input import get_color_mapping
from langchain_classic.callbacks import FileCallbackHandler
from langchain_classic.agents.mrkl.output_parser import MRKLOutputParser
from fairnessBench.schema import Action
from fairnessBench.LLM import complete_text_fast as complete_text_crfm # AS: Changed it to complete_text_fast
from fairnessBench.LLM import complete_text_crfm
from .agent import Agent


class AgentExecutorWithState(AgentExecutor):
""" A modified version of the AgentExecutor class that allows us to keep track of the agent's state. """
"""
A modified version of the AgentExecutor class that allows us to keep track of the agent's state.
An object of this class is used in the LangChainAgent run method to execute actions
"""

def _call(
self,
Expand Down Expand Up @@ -92,7 +103,10 @@ def _call(


class AnthropicOutputParser(MRKLOutputParser):
""" Modified version of the MRKLOutputParser that allows us to parse the output of an anthropic models. """
"""
Modified version of the MRKLOutputParser that allows us to parse the output of an anthropic models.
An object of this class is used in the LangChainAgent run method to parse LLM response
"""
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
text = text.split("Thought:")[-1]
return super().parse(text)
Expand All @@ -103,7 +117,10 @@ def _type(self) -> str:


class EnvTool:
""" A wrapper class to wrap actions as tools for the LangChain agent. """
"""
A wrapper class to wrap actions as tools for the LangChain agent.
An object of this class is used in the LangChainAgent run method to setup the tools
"""
def __init__(self, action_info, env):
self.action_info = action_info
self.env = env
Expand Down Expand Up @@ -274,14 +291,39 @@ def run(self, env):
agent_kwargs = {"output_parser": AnthropicOutputParser()}
else:
# TODO: add support for other agents
raise NotImplementedError
# Consider other options are local hf llms
match self.args.llm_name:
case 'llama': model_id = "meta-llama/Llama-3.3-70B-Instruct"
case 'qwen': model_id = "Qwen/Qwen2.5-72B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_id)
quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16)
hf_model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config = quant_config, device_map=f"cuda:{self.args.device}",torch_dtype=torch.float16)

# Set up stopping sequences (not attempted...)
# stop_sequence_ids = tokenizer(["Observation:"], return_token_type_ids=False, add_special_tokens=False)

pipe = pipeline(
"text-generation",
model = hf_model,
tokenizer=tokenizer,
temperature=0.5,
max_new_tokens=2500,
do_sample=True,
# stopping_criteria = stopping_criteria,
)

llm = HuggingFacePipeline(pipeline=pipe)
# chat_model = ChatHuggingFace(llm=llm) # ???
agent_kwargs = {"output_parser": AnthropicOutputParser()}


tools = []
for tool_name in self.prompt_tool_names:
tools.append(tool(
tool_name,
EnvTool(self.action_infos[tool_name], env).run,
self.construct_tool_prompt(tool_name, self.action_infos[tool_name]).replace("{", "{{").replace("}", "}}")
tools.append(StructuredTool.from_function(
name=tool_name,
func=EnvTool(self.action_infos[tool_name], env).run,
description=self.construct_tool_prompt(tool_name, self.action_infos[tool_name]).replace("{", "{{").replace("}", "}}")
)
)

Expand Down