Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ dataset = load_dataset("trl-lib/tldr", split="train").select(range(128))

# Initialize trainer and start training
trainer = MAGRPOTrainer(
model="Qwen/Qwen2.5-0.5B",
agent_model="Qwen/Qwen2.5-0.5B",
num_agents=2,
tokenizer=tokenizer,
train_dataset=dataset,
Expand Down
18 changes: 16 additions & 2 deletions comlrl/trainers/actor_critic/ac_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,22 @@ def _resolve_turn_prompt(
modified_item, agent_idx, external_prompts=external_prompt
)

def _encode_prompt(self, prompt: str) -> Dict[str, torch.Tensor]:
encoded = self.tokenizer(
def _get_tokenizer(self, agent_idx: Optional[int] = None):
    """Return the tokenizer to use for the given agent.

    Looks for a non-empty per-agent ``self.tokenizers`` sequence first and
    falls back to the shared ``self.tokenizer`` when none is available.

    Args:
        agent_idx: Index into ``self.tokenizers``. ``None`` selects the
            first per-agent tokenizer.

    Returns:
        The selected tokenizer object.

    Raises:
        IndexError: If ``agent_idx`` is outside the per-agent tokenizer
            sequence.
    """
    tokenizers = getattr(self, "tokenizers", None)
    # Accept tuples as well as lists: a tuple of per-agent tokenizers must
    # not silently degrade to the single shared tokenizer.
    if not (isinstance(tokenizers, (list, tuple)) and tokenizers):
        return self.tokenizer
    if agent_idx is None:
        return tokenizers[0]
    try:
        return tokenizers[agent_idx]
    except IndexError as err:
        # Same exception type as before, but say which index was requested.
        raise IndexError(
            f"agent_idx {agent_idx} is out of range for "
            f"{len(tokenizers)} per-agent tokenizer(s)."
        ) from err

def _encode_prompt(
self,
prompt: str,
agent_idx: Optional[int] = None,
tokenizer: Optional[Any] = None,
) -> Dict[str, torch.Tensor]:
tokenizer = tokenizer or self._get_tokenizer(agent_idx)
encoded = tokenizer(
prompt,
return_tensors="pt",
truncation=True,
Expand Down
124 changes: 73 additions & 51 deletions comlrl/trainers/actor_critic/iac.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from comlrl.utils.formatters import build_formatters
from comlrl.utils.model_loading import resolve_model_sources
from comlrl.utils.reward_utils import call_reward_function, normalize_reward_lengths
from comlrl.utils.tokenizer_utils import apply_tokenizer_specials, resolve_tokenizer
from comlrl.utils.tokenizer_utils import apply_tokenizer_specials, resolve_tokenizers
from .ac_base import ActorCriticTrainerBase
import wandb

Expand Down Expand Up @@ -109,8 +109,11 @@ class IACTrainer(ActorCriticTrainerBase):

def __init__(
self,
model: Optional[Union[str, PreTrainedModel]] = None,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
agent_model: Optional[Union[str, PreTrainedModel]] = None,
critic_model: Optional[Union[str, PreTrainedModel]] = None,
tokenizer: Optional[
Union[PreTrainedTokenizerBase, Sequence[PreTrainedTokenizerBase]]
] = None,
reward_func: Optional[RewardFunc] = None,
reward_processor: Optional[Callable[[float], float]] = None,
formatters: Optional[Union[Formatter, Sequence[Formatter]]] = None,
Expand All @@ -131,19 +134,21 @@ def __init__(
self.args = args if args is not None else IACConfig()
if reward_func is None or not callable(reward_func):
raise ValueError("reward_func must be a callable.")
if model is None and agents is None:
raise ValueError("Either model or agents must be provided.")
if not self.args.use_separate_critic and critics is not None:
if agent_model is None and agents is None:
raise ValueError("Either agent_model or agents must be provided.")
if not self.args.use_separate_critic and (
critics is not None or critic_model is not None
):
raise ValueError(
"critics can only be provided when use_separate_critic=True."
)
if (
agents is None
and self.args.num_agents > 1
and isinstance(model, PreTrainedModel)
and isinstance(agent_model, PreTrainedModel)
):
raise ValueError(
"Multi-agent IAC requires `model` to be a pretrained identifier string."
"Multi-agent IAC requires `agent_model` to be a pretrained identifier string."
)
if agents is not None and tokenizer is None:
raise ValueError("Tokenizer must be provided when using agents.")
Expand All @@ -159,22 +164,29 @@ def __init__(

self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

self.agent_models: List[CausalLMWithValueHead] = []
self.critic_models: List[Optional[CausalLMWithValueHead]] = []
self.agents: List[CausalLMWithValueHead] = []
self.critics: List[CausalLMWithValueHead] = []

self.tokenizer = resolve_tokenizer(model, tokenizer, agents)
tokenizers = resolve_tokenizers(agent_model, tokenizer, agents)
if isinstance(tokenizers, list):
self.tokenizers = tokenizers
self.tokenizer = tokenizers[0] if tokenizers else None
else:
self.tokenizers = [tokenizers] * self.args.num_agents
self.tokenizer = tokenizers
self.formatters = build_formatters(formatters, self.args.num_agents)
try:
self._reward_signature = inspect.signature(reward_func)
except (TypeError, ValueError):
self._reward_signature = None
self.external_transition = external_transition

actor_sources, self.agent_model_name = resolve_model_sources(
actor_sources, _agent_name = resolve_model_sources(
kind="agents",
model=model,
model=agent_model,
models=agents,
expected_count=self.args.num_agents,
model_label="agent_model",
)
for actor_source in actor_sources:
if actor_source is None:
Expand Down Expand Up @@ -208,19 +220,19 @@ def __init__(
attach_value_head=attach_value,
)
agent_model.to(self.device)
self.agent_models.append(agent_model)
self.agents.append(agent_model)

self.critic_model_name = None
if self.args.use_separate_critic:
if critics is None:
if critics is None and critic_model is None:
raise ValueError(
"critics must be provided when use_separate_critic=True."
"Either critic_model or critics must be provided when use_separate_critic=True."
)
critic_sources, self.critic_model_name = resolve_model_sources(
critic_sources, _critic_name = resolve_model_sources(
kind="critics",
model=None,
model=critic_model,
models=critics,
expected_count=self.args.num_agents,
model_label="critic_model",
)
for critic_source in critic_sources:
if isinstance(critic_source, CausalLMWithValueHead):
Expand Down Expand Up @@ -250,32 +262,35 @@ def __init__(
attach_value_head=True,
)
critic_model.to(self.device)
self.critic_models.append(critic_model)
self.critics.append(critic_model)
else:
self.critic_models = [None] * self.args.num_agents

apply_tokenizer_specials(
self.tokenizer, [*self.agent_models, *self.critic_models]
)
self.critics = []

if self.tokenizers and len(self.tokenizers) == len(self.agents):
for idx, tok in enumerate(self.tokenizers):
models = [self.agents[idx]]
if idx < len(self.critics):
models.append(self.critics[idx])
apply_tokenizer_specials(tok, models)
else:
apply_tokenizer_specials(self.tokenizer, [*self.agents, *self.critics])
self.agent_optimizers = []
self.critic_optimizers = []

for agent_model in self.agent_models:
for agent_model in self.agents:
optimizer = torch.optim.AdamW(
agent_model.parameters(),
lr=self.args.agent_learning_rate,
)
self.agent_optimizers.append(optimizer)

if self.args.use_separate_critic:
if any(critic is None for critic in self.critic_models):
raise RuntimeError("Critic model expected but missing.")
critic_model = self.critic_models[-1]
optimizer = torch.optim.AdamW(
critic_model.parameters(),
lr=self.args.critic_learning_rate,
)
self.critic_optimizers.append(optimizer)
for critic_model in self.critics:
optimizer = torch.optim.AdamW(
critic_model.parameters(),
lr=self.args.critic_learning_rate,
)
self.critic_optimizers.append(optimizer)

self.env_step = 0
self.rollout_buffers = [[] for _ in range(self.args.num_agents)]
Expand Down Expand Up @@ -331,15 +346,13 @@ def _init_wandb(self) -> None:
)
if isinstance(sections, dict):
dataset_section = sections.get("dataset") or {}
model_section = sections.get("model") or {}
output_section = sections.get("output") or {}
external_section = sections.get("external") or {}
trainer_section = sections.get("trainer") or {}

config_dict.update(
{
"dataset": dataset_section,
"model": model_section,
"output": output_section,
"external": external_section,
"trainer": trainer_section,
Expand Down Expand Up @@ -403,7 +416,7 @@ def _generate_rollout(
agent_idx: int,
num_ret: int,
) -> Dict[str, Any]:
encoded_prompt = self._encode_prompt(prompt)
encoded_prompt = self._encode_prompt(prompt, agent_idx=agent_idx)
prompt_input_ids = encoded_prompt["input_ids"]
prompt_attention_mask = encoded_prompt["attention_mask"]
prompt_len = prompt_input_ids.size(1)
Expand All @@ -426,7 +439,8 @@ def _generate_rollout(
raise RuntimeError("Model produced an empty completion during rollout.")

response_tokens = sequences[:, prompt_len:]
pad_id = self.tokenizer.pad_token_id
tokenizer = self._get_tokenizer(agent_idx)
pad_id = tokenizer.pad_token_id
response_lens: List[int] = []
completion_texts: List[str] = []
for seq in response_tokens:
Expand All @@ -436,14 +450,14 @@ def _generate_rollout(
)
response_lens.append(resp_len)
completion_texts.append(
self.tokenizer.decode(seq[:resp_len], skip_special_tokens=True)
tokenizer.decode(seq[:resp_len], skip_special_tokens=True)
)

full_attention_mask = torch.ones_like(sequences, device=self.device)

with torch.no_grad():
if self.args.use_separate_critic:
critic_model = self.critic_models[agent_idx]
critic_model = self.critics[agent_idx]
if critic_model is None:
raise RuntimeError("Critic model missing for agent.")
value = self._value_for_critic_type(
Expand Down Expand Up @@ -508,7 +522,7 @@ def _collect_rollouts(self, item: Dict[str, Any]) -> List[RolloutSample]:
rollout_data: List[Dict[str, Any]] = []
num_ret = int(getattr(self.args, "num_generations", 1))

for agent_idx, agent_model in enumerate(self.agent_models):
for agent_idx, agent_model in enumerate(self.agents):
prompt = self._resolve_turn_prompt(item, agent_idx)
gen = self._generate_rollout(agent_model, prompt, agent_idx, num_ret)
completions_per_agent.append(gen["completions"])
Expand Down Expand Up @@ -551,7 +565,7 @@ def _collect_rollouts(self, item: Dict[str, Any]) -> List[RolloutSample]:
RolloutSample(
agent_idx=agent_idx,
prompt=data["prompt"],
completion=self.tokenizer.decode(
completion=self._get_tokenizer(agent_idx).decode(
seq[data["prompt_len"] : data["prompt_len"] + resp_len],
skip_special_tokens=True,
),
Expand Down Expand Up @@ -621,7 +635,7 @@ def _collect_rollouts_multi_turn(

completions_per_agent: List[List[str]] = []
rollout_data: List[Dict[str, Any]] = []
for agent_idx, agent_model in enumerate(self.agent_models):
for agent_idx, agent_model in enumerate(self.agents):
prompt = turn_prompts[agent_idx]
gen = self._generate_rollout(agent_model, prompt, agent_idx, num_ret=1)
completions_per_agent.append(gen["completions"])
Expand Down Expand Up @@ -854,9 +868,9 @@ def _compute_sequence_stats(

# Actor-Critic update logic
def _ac_step(self, agent_idx: int, batch: List[RolloutSample]) -> Dict[str, float]:
agent_model = self.agent_models[agent_idx]
agent_model = self.agents[agent_idx]
critic_model = (
self.critic_models[agent_idx] if self.args.use_separate_critic else None
self.critics[agent_idx] if self.args.use_separate_critic else None
)
agent_optimizer = self.agent_optimizers[agent_idx]
critic_optimizer = (
Expand Down Expand Up @@ -1013,14 +1027,14 @@ def _update(
def save_model(self, output_dir: str) -> None:
os.makedirs(output_dir, exist_ok=True)
if self.args.num_agents == 1:
actor = self.agent_models[0]
actor = self.agents[0]
actor.model.save_pretrained(output_dir)
if actor.value_head is not None:
torch.save(
actor.value_head.state_dict(),
os.path.join(output_dir, "value_head.pt"),
)
critic = self.critic_models[0]
critic = self.critics[0] if self.critics else None
if critic is not None:
critic_dir = os.path.join(output_dir, "critic")
os.makedirs(critic_dir, exist_ok=True)
Expand All @@ -1031,7 +1045,7 @@ def save_model(self, output_dir: str) -> None:
os.path.join(critic_dir, "value_head.pt"),
)
else:
for agent_idx, actor in enumerate(self.agent_models):
for agent_idx, actor in enumerate(self.agents):
agent_dir = os.path.join(output_dir, f"agent_{agent_idx}")
os.makedirs(agent_dir, exist_ok=True)
actor.model.save_pretrained(agent_dir)
Expand All @@ -1040,9 +1054,9 @@ def save_model(self, output_dir: str) -> None:
actor.value_head.state_dict(),
os.path.join(agent_dir, "value_head.pt"),
)
critic = self.critic_models[agent_idx]
if critic is None:
if not self.critics or agent_idx >= len(self.critics):
continue
critic = self.critics[agent_idx]
critic_dir = os.path.join(agent_dir, "critic")
os.makedirs(critic_dir, exist_ok=True)
critic.model.save_pretrained(critic_dir)
Expand All @@ -1052,5 +1066,13 @@ def save_model(self, output_dir: str) -> None:
os.path.join(critic_dir, "value_head.pt"),
)

if self.tokenizer is not None:
if self.tokenizers:
if self.args.num_agents == 1:
self.tokenizers[0].save_pretrained(output_dir)
else:
for idx, tok in enumerate(self.tokenizers):
agent_dir = os.path.join(output_dir, f"agent_{idx}")
os.makedirs(agent_dir, exist_ok=True)
tok.save_pretrained(agent_dir)
elif self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)
Loading