Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 22 additions & 98 deletions comlrl/trainers/actor_critic/ac_base.py
Original file line number Diff line number Diff line change
@@ -1,75 +1,15 @@
from __future__ import annotations

from collections import defaultdict
import inspect
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Dict, List, Optional

import torch
import wandb
from tqdm import tqdm # type: ignore

Formatter = Callable[[Dict[str, Any]], str]
from torch.utils.data import DataLoader


class ActorCriticTrainerBase:
"""Shared training utilities for actor-critic style trainers."""

def _infer_model_name(self, source: Any) -> Optional[str]:
if source is None:
return None
if isinstance(source, str):
return source
base = getattr(source, "model", source)
config = getattr(base, "config", None)
if config is not None:
name = getattr(config, "_name_or_path", None) or getattr(
config, "model_type", None
)
if name:
return str(name)
return base.__class__.__name__

def _resolve_model_sources(
self,
*,
kind: str,
model: Optional[Any],
models: Optional[Sequence[Any]],
expected_count: int,
expected_label: Optional[str] = None,
) -> Tuple[List[Any], Optional[str]]:
if model is not None and models is not None:
is_name_list = (
isinstance(models, Sequence)
and not isinstance(models, (str, bytes))
and all(isinstance(src, str) for src in models)
)
if not is_name_list or len(models) != expected_count:
label = expected_label or f"num_agents ({expected_count})"
raise ValueError(
f"Cannot provide both model and {kind} unless {kind} is a list of {label} model names."
)
if model is None and models is None:
raise ValueError(f"Either model or {kind} must be provided.")
if expected_count < 1:
raise ValueError("expected_count must be >= 1.")

if models is not None:
if isinstance(models, (str, bytes)) or not isinstance(models, Sequence):
raise ValueError(f"{kind} must be a non-empty sequence.")
sources = list(models)
if len(sources) != expected_count:
label = expected_label or f"num_agents ({expected_count})"
raise ValueError(f"{kind} length ({len(sources)}) must match {label}.")
else:
sources = [model] * expected_count

if any(src is None for src in sources):
raise ValueError(f"{kind} entries must be non-null.")

model_name = self._infer_model_name(sources[0]) if sources else None
return sources, model_name

def _filter_model_kwargs(self, cfg: Optional[Dict[str, Any]]) -> Dict[str, Any]:
torch_dtype = None
if isinstance(cfg, dict):
Expand All @@ -80,42 +20,6 @@ def _filter_model_kwargs(self, cfg: Optional[Dict[str, Any]]) -> Dict[str, Any]:
torch_dtype = model_cfg.get("torch_dtype") or model_cfg.get("dtype")
return {"torch_dtype": torch_dtype} if torch_dtype is not None else {}

def _setup_formatters(
self, formatters: Optional[Union[Formatter, Sequence[Formatter]]]
) -> List[Formatter]:
def _default_formatter(item: Dict[str, Any], external_prompts=None) -> str:
if external_prompts is not None:
return external_prompts
return item.get("prompt", "")

def _wrap_formatter(fmt: Formatter) -> Formatter:
try:
sig = inspect.signature(fmt)
if "external_prompts" in sig.parameters:
return lambda x, external_prompts=None, f=fmt: f(
x, external_prompts=external_prompts
)
except (TypeError, ValueError):
pass
return lambda x, external_prompts=None, f=fmt: f(x)

num_agents = int(self.args.num_agents)
if formatters is None:
return [_default_formatter for _ in range(num_agents)]
if callable(formatters):
return [_wrap_formatter(formatters) for _ in range(num_agents)]
if isinstance(formatters, Sequence) and not isinstance(
formatters, (str, bytes)
):
if len(formatters) != num_agents:
raise ValueError(
"Number of formatters must match num_agents when providing a sequence."
)
return [_wrap_formatter(f) for f in list(formatters)]
raise ValueError(
"formatters must be None, a callable, or a sequence of callables."
)

def _format_prompt(
self,
item: Dict[str, Any],
Expand Down Expand Up @@ -325,6 +229,26 @@ def _flush_buffers(self, epoch_metrics: Dict[str, List[float]]) -> None:
continue
self._process_buffer(agent_idx, buffer, epoch_metrics)

def get_train_dataloader(self) -> DataLoader:
    """Create the training ``DataLoader``.

    The loader reads ``self.train_dataset`` one item at a time, in
    dataset order, and hands each batch back as the uncollated list of
    items.

    Raises:
        ValueError: If ``self.train_dataset`` is ``None``.
    """
    dataset = self.train_dataset
    if dataset is None:
        raise ValueError("Training requires a dataset.")

    def _keep_raw(batch):
        # Skip default tensor collation; return the items untouched.
        return batch

    loader_options = {
        "batch_size": 1,
        "shuffle": False,
        "collate_fn": _keep_raw,
    }
    return DataLoader(dataset, **loader_options)

def get_eval_dataloader(self) -> Optional[DataLoader]:
    """Create the evaluation ``DataLoader``, if an eval dataset exists.

    Returns:
        ``None`` when ``self.eval_dataset`` is ``None``; otherwise a
        loader over it in dataset order, with batches of
        ``self.args.eval_batch_size`` items returned as uncollated lists.
    """
    dataset = self.eval_dataset
    if dataset is None:
        return None

    def _keep_raw(batch):
        # Skip default tensor collation; return the items untouched.
        return batch

    loader_options = {
        "batch_size": self.args.eval_batch_size,
        "shuffle": False,
        "collate_fn": _keep_raw,
    }
    return DataLoader(dataset, **loader_options)

def evaluate(self) -> Dict[str, float]:
if self.eval_dataset is None:
return {}
Expand Down
Loading