From 225616e5ff62235bb5ead16683328b86044ccdd8 Mon Sep 17 00:00:00 2001 From: Andrey Goncharov Date: Tue, 10 Mar 2026 17:40:25 +0000 Subject: [PATCH 01/20] Add draft evaluator --- pyproject.toml | 1 + src/core/datasets/base_dataset.py | 4 + src/core/evaluation/__init__.py | 0 .../evaluation/continuous_batch_generator.py | 218 ++++++++++++++ src/core/evaluation/evaluator.py | 174 +++++++++++ .../evaluation/multi_checkpoint_evaluator.py | 133 +++++++++ tests/__init__.py | 0 tests/test_continuous_batch_generator.py | 275 ++++++++++++++++++ uv.lock | 36 +++ 9 files changed, 841 insertions(+) create mode 100644 src/core/evaluation/__init__.py create mode 100644 src/core/evaluation/continuous_batch_generator.py create mode 100644 src/core/evaluation/evaluator.py create mode 100644 src/core/evaluation/multi_checkpoint_evaluator.py create mode 100644 tests/__init__.py create mode 100644 tests/test_continuous_batch_generator.py diff --git a/pyproject.toml b/pyproject.toml index cb842a0..6d3aa61 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,4 +39,5 @@ packages = [ dev = [ "huggingface-hub[cli]>=0.29.3", "ipykernel>=6.29.5", + "pytest>=9.0.2", ] diff --git a/src/core/datasets/base_dataset.py b/src/core/datasets/base_dataset.py index 4521d04..8352df8 100644 --- a/src/core/datasets/base_dataset.py +++ b/src/core/datasets/base_dataset.py @@ -21,3 +21,7 @@ def user_prompt(self, row: dict) -> str: ... @abstractmethod def row_id(self, row: dict) -> str: ... + + @property + def dataset_id(self) -> str: + return self.__class__.__name__ diff --git a/src/core/evaluation/__init__.py b/src/core/evaluation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/core/evaluation/continuous_batch_generator.py b/src/core/evaluation/continuous_batch_generator.py new file mode 100644 index 0000000..0f9fce8 --- /dev/null +++ b/src/core/evaluation/continuous_batch_generator.py @@ -0,0 +1,218 @@ +from collections import deque +from dataclasses import dataclass, field + +import torch +import torch.nn.functional as F +from transformers import DynamicCache, PreTrainedModel, PreTrainedTokenizer + + +@dataclass +class _Slot: + index: int + prompt_len: int + generated_ids: list[int] = field(default_factory=list) + cache: DynamicCache = field(default_factory=DynamicCache) + seq_position: int = 0 # total tokens seen = prompt_len + len(generated_ids) + + +def _right_pad_cache(cache: DynamicCache, target_len: int) -> DynamicCache: + """Pad KV cache tensors along seq dim (dim=-2) with zeros to target_len.""" + current_len = cache.get_seq_length() + if current_len == target_len: + return cache + pad_len = target_len - current_len + padded = DynamicCache() + for layer_idx in range(len(cache)): + k = cache.key_cache[layer_idx] # [1, H, T, D] + v = cache.value_cache[layer_idx] + k_pad = F.pad(k, (0, 0, 0, pad_len)) # pad dim=-2 + v_pad = F.pad(v, (0, 0, 0, pad_len)) + padded.update(k_pad, v_pad, layer_idx) + return padded + + +def _trim_cache(cache: DynamicCache, valid_len: int) -> DynamicCache: + """Trim KV cache to only the first valid_len entries along seq dim.""" + trimmed = DynamicCache() + for layer_idx in range(len(cache)): + k = cache.key_cache[layer_idx][:, :, :valid_len, :] + v = cache.value_cache[layer_idx][:, :, :valid_len, :] + trimmed.update(k, v, layer_idx) + return trimmed + + +class ContinuousBatchGenerator: + """Token-by-token generation with continuous batching via model.forward(). + + Maintains a pool of active slots. Empty slots are filled from a queue of + pending prompts. 
Each slot holds its own DynamicCache (batch_size=1). + + Prefill runs individually per prompt (no padding waste). Decode batches + all active slots into a single forward() call by padding KV caches to + equal length and using an attention mask to ignore padded positions. + """ + + def __init__( + self, + model: PreTrainedModel, + tokenizer: PreTrainedTokenizer, + max_new_tokens: int, + max_batch_size: int = 8, + temperature: float = 0.0, + top_p: float = 1.0, + top_k: int = -1, + ): + self.model = model + self.tokenizer = tokenizer + self.max_new_tokens = max_new_tokens + self.max_batch_size = max_batch_size + self.temperature = temperature + self.top_p = top_p + self.top_k = top_k + + self.pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id + self.eos_token_id = tokenizer.eos_token_id + + @torch.no_grad() + def generate(self, prompts: list[list[int]]) -> list[list[int]]: + """Generate responses for a list of prompts using continuous batching. + + Args: + prompts: List of token ID sequences (one per sample). + + Returns: + List of generated token ID sequences (excluding prompt), same order + as input. + """ + results: list[list[int] | None] = [None] * len(prompts) + queue: deque[tuple[int, list[int]]] = deque((i, p) for i, p in enumerate(prompts)) + active_slots: list[_Slot | None] = [None] * self.max_batch_size + + while queue or any(s is not None for s in active_slots): + # FILL: prefill empty slots with new prompts (batch_size=1 each) + for slot_idx in range(self.max_batch_size): + if active_slots[slot_idx] is not None or not queue: + continue + prompt_idx, prompt_ids = queue.popleft() + active_slots[slot_idx] = self._prefill(prompt_idx, prompt_ids) + + # Collect occupied slots + occupied = [(i, s) for i, s in enumerate(active_slots) if s is not None] + if not occupied: + break + + # BATCHED DECODE: single forward() call for all active slots + slots_only = [s for _, s in occupied] + self._batched_decode(slots_only) + + # RETIRE: check for completed sequences + for slot_idx, slot in occupied: + last_token = slot.generated_ids[-1] + if last_token == self.eos_token_id or len(slot.generated_ids) >= self.max_new_tokens: + results[slot.index] = slot.generated_ids + active_slots[slot_idx] = None + + return [r if r is not None else [] for r in results] + + def _prefill(self, prompt_idx: int, prompt_ids: list[int]) -> _Slot: + """Run the prefill forward pass to build KV cache and sample the first token.""" + device = self.model.device + input_ids = torch.tensor([prompt_ids], device=device) + seq_len = len(prompt_ids) + position_ids = torch.arange(seq_len, device=device).unsqueeze(0) + cache_position = torch.arange(seq_len, device=device) + cache = DynamicCache() + + outputs = self.model( + input_ids=input_ids, + position_ids=position_ids, + cache_position=cache_position, + past_key_values=cache, + use_cache=True, + ) + + next_token = self._sample_token(outputs.logits[:, -1, :]) + + return _Slot( + index=prompt_idx, + prompt_len=seq_len, + generated_ids=[next_token.item()], + cache=outputs.past_key_values, + seq_position=seq_len + 1, + ) + + def _batched_decode(self, slots: list[_Slot]) -> None: + """Run a single batched decode step for all active slots.""" + device = self.model.device + num_slots = len(slots) + + # Cache lengths before padding (needed for trim after forward) + slot_cache_lens = [s.cache.get_seq_length() for s in slots] + max_cache_len = max(slot_cache_lens) + + # Pad each slot's KV cache to max_cache_len, then merge into 
batched cache + padded_caches = [_right_pad_cache(s.cache, max_cache_len) for s in slots] + batched_cache = DynamicCache.from_batch_splits(padded_caches) + + # input_ids: last generated token per slot [num_slots, 1] + input_ids = torch.tensor([[s.generated_ids[-1]] for s in slots], device=device) + + # attention_mask: [num_slots, max_cache_len + 1] (+1 for the new token) + attn_mask = torch.zeros(num_slots, max_cache_len + 1, dtype=torch.long, device=device) + for i, slot in enumerate(slots): + attn_mask[i, :slot_cache_lens[i]] = 1 # valid cached positions + attn_mask[i, max_cache_len] = 1 # the new token position (appended at end) + + # cache_position: shared across batch, points to where the new KV is appended + cache_position = torch.tensor([max_cache_len], device=device) + + # position_ids: each slot's actual position (prompt_len + generated so far) + position_ids = torch.tensor([[s.seq_position] for s in slots], device=device) + + # Single forward() call + outputs = self.model( + input_ids=input_ids, + attention_mask=attn_mask, + position_ids=position_ids, + cache_position=cache_position, + past_key_values=batched_cache, + use_cache=True, + ) + + # Sample next token per slot + for i, slot in enumerate(slots): + next_token = self._sample_token(outputs.logits[i : i + 1, -1, :]) + slot.generated_ids.append(next_token.item()) + slot.seq_position += 1 + + # Split updated cache back to per-slot, trim padding + updated_splits = outputs.past_key_values.batch_split(num_slots, split_size=1) + for i, slot in enumerate(slots): + valid_len = slot_cache_lens[i] + 1 # original cache len + 1 new token + slot.cache = _trim_cache(updated_splits[i], valid_len) + + def _sample_token(self, logits: torch.Tensor) -> torch.Tensor: + """Sample a single token from logits of shape [1, vocab_size].""" + if self.temperature == 0.0: + return logits.argmax(dim=-1).squeeze(0) + + logits = logits / self.temperature + + if self.top_k > 0: + top_k = min(self.top_k, logits.size(-1)) + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits = logits.masked_fill(indices_to_remove, float("-inf")) + + if self.top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) + sorted_indices_to_remove = cumulative_probs > self.top_p + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + indices_to_remove = sorted_indices_to_remove.scatter( + dim=-1, index=sorted_indices, src=sorted_indices_to_remove + ) + logits = logits.masked_fill(indices_to_remove, float("-inf")) + + probs = torch.softmax(logits, dim=-1) + return torch.multinomial(probs, num_samples=1).squeeze(-1).squeeze(0) diff --git a/src/core/evaluation/evaluator.py b/src/core/evaluation/evaluator.py new file mode 100644 index 0000000..a39ed35 --- /dev/null +++ b/src/core/evaluation/evaluator.py @@ -0,0 +1,174 @@ +import json +from pathlib import Path + +import pandas as pd +from pydantic import BaseModel +from pydraconf import PydraConfig +from transformers import AutoModelForCausalLM, AutoTokenizer + +from core.datasets.qa_dataset import QADataset +from core.datasets.qa_dataset_adapter import QADatasetAdapter +from core.evaluation.continuous_batch_generator import ContinuousBatchGenerator +from core.utils.logger import logger + + +class GenerationConfig(BaseModel): + max_new_tokens: int = 4096 + max_batch_size: int = 8 + temperature: float = 0.0 + top_p: float = 1.0 + top_k: int 
= -1 + + +class EvaluatorConfig(PydraConfig): + model_path: str + eval_dataset: QADatasetAdapter | list[QADatasetAdapter] + out_path: str | None = None + generation: GenerationConfig = GenerationConfig() + + +class EvaluationResult(BaseModel): + accuracy: float + total: int + correct: int + + +class Evaluator: + def __init__(self, config: EvaluatorConfig): + self.config = config + + @property + def _datasets(self) -> list[QADatasetAdapter]: + if isinstance(self.config.eval_dataset, list): + return self.config.eval_dataset + return [self.config.eval_dataset] + + def evaluate(self) -> list[EvaluationResult]: + model, tokenizer = self._load_model() + model.eval() + + results: list[EvaluationResult] = [] + for eval_dataset in self._datasets: + result = self._evaluate_single(eval_dataset, model, tokenizer) + results.append(result) + + return results + + def _evaluate_single(self, eval_dataset: QADatasetAdapter, model, tokenizer) -> EvaluationResult: + ds = eval_dataset.process_dataset() + + prompts = [row["input_ids"] for row in ds] + logger.info(f"Evaluating {len(prompts)} samples with model from {self.config.model_path}") + + generator = ContinuousBatchGenerator( + model=model, + tokenizer=tokenizer, + max_new_tokens=self.config.generation.max_new_tokens, + max_batch_size=self.config.generation.max_batch_size, + temperature=self.config.generation.temperature, + top_p=self.config.generation.top_p, + top_k=self.config.generation.top_k, + ) + + generated = generator.generate(prompts) + + correct = 0 + total = len(prompts) + all_results: list[dict] = [] + + qa_dataset: QADataset = eval_dataset.dataset + + for i, gen_ids in enumerate(generated): + row = ds[i] + response = tokenizer.decode(gen_ids, skip_special_tokens=True).strip() + + try: + parsed_answer, is_correct = qa_dataset.verify_assistant_response(row, response) + except Exception as ex: + logger.warning(f"Error verifying row {row['row_id']}: {ex}") + parsed_answer = response + is_correct = False + + if is_correct: + correct += 1 + + all_results.append( + { + "row_id": row["row_id"], + "response": response, + "parsed_answer": parsed_answer, + "is_correct": is_correct, + } + ) + + accuracy = correct / total if total > 0 else 0.0 + result = EvaluationResult(accuracy=accuracy, total=total, correct=correct) + + logger.info(f"Evaluation complete: accuracy={accuracy:.4f} ({correct}/{total})") + + self._save_results(eval_dataset, result, all_results) + + return result + + def _out_path_for(self, eval_dataset: QADatasetAdapter) -> Path: + dataset_id = eval_dataset.dataset.dataset_id + if self.config.out_path: + return Path(self.config.out_path) / dataset_id + + model_path = Path(self.config.model_path) + if not model_path.is_dir(): + raise ValueError(f"out_path must be set when model_path is not a local directory: {self.config.model_path}") + return model_path / "evals" / dataset_id + + def _eval_results_path_for(self, eval_dataset: QADatasetAdapter) -> Path: + return self._out_path_for(eval_dataset) / "results.json" + + def _load_model(self): + model_path = Path(self.config.model_path) + + if model_path.is_dir(): + adapter_config = model_path / "adapter_config.json" + if adapter_config.exists(): + return self._load_lora_model(model_path, adapter_config) + + logger.info(f"Loading model from {self.config.model_path}") + model = AutoModelForCausalLM.from_pretrained(self.config.model_path) + tokenizer = AutoTokenizer.from_pretrained(self.config.model_path) + return model, tokenizer + + def _load_lora_model(self, model_path: Path, adapter_config: 
Path): + from peft import PeftModel + + with open(adapter_config) as f: + config = json.load(f) + + base_model_id = config.get("base_model_name_or_path") + if not base_model_id: + raise ValueError(f"adapter_config.json at {adapter_config} missing 'base_model_name_or_path'") + + logger.info(f"Loading LoRA model: base={base_model_id}, adapter={model_path}") + base_model = AutoModelForCausalLM.from_pretrained(base_model_id) + model = PeftModel.from_pretrained(base_model, str(model_path)) + tokenizer = AutoTokenizer.from_pretrained(base_model_id) + return model, tokenizer + + def _save_results(self, eval_dataset: QADatasetAdapter, result: EvaluationResult, all_results: list[dict]) -> None: + out_path = self._out_path_for(eval_dataset) + out_path.mkdir(parents=True, exist_ok=True) + + # Save summary + summary_path = self._eval_results_path_for(eval_dataset) + summary = { + "accuracy": result.accuracy, + "total": result.total, + "correct": result.correct, + "model_path": self.config.model_path, + } + with open(summary_path, "w") as f: + json.dump(summary, f, indent=2) + logger.info(f"Summary saved to {summary_path}") + + # Save per-row results + results_path = out_path / "responses.parquet" + pd.DataFrame(all_results).to_parquet(results_path, index=False) + logger.info(f"Per-row results saved to {results_path}") diff --git a/src/core/evaluation/multi_checkpoint_evaluator.py b/src/core/evaluation/multi_checkpoint_evaluator.py new file mode 100644 index 0000000..9989e8a --- /dev/null +++ b/src/core/evaluation/multi_checkpoint_evaluator.py @@ -0,0 +1,133 @@ +import gc +import json +from pathlib import Path + +import torch +from pydraconf import PydraConfig + +from core.datasets.qa_dataset_adapter import QADatasetAdapter +from core.evaluation.evaluator import EvaluationResult, Evaluator, EvaluatorConfig, GenerationConfig +from core.utils.logger import logger + + +class MultiCheckpointEvaluatorConfig(PydraConfig): + checkpoints_dir: str + eval_dataset: QADatasetAdapter | list[QADatasetAdapter] + base_model_id: str | None = None + out_path: str | None = None + generation: GenerationConfig = GenerationConfig() + + +class MultiCheckpointEvaluator: + """Evaluates all checkpoints in a directory sequentially. + + For each checkpoint, loads the model, runs evaluation, then unloads and + frees GPU memory before proceeding to the next. 
+ """ + + def __init__(self, config: MultiCheckpointEvaluatorConfig): + self.config = config + + def _normalize_datasets(self) -> list[QADatasetAdapter]: + if isinstance(self.config.eval_dataset, list): + return self.config.eval_dataset + return [self.config.eval_dataset] + + def evaluate_all(self) -> list[tuple[str, list[EvaluationResult], float | None]]: + checkpoints_dir = Path(self.config.checkpoints_dir) + if not checkpoints_dir.is_dir(): + raise NotADirectoryError(f"{checkpoints_dir} is not a directory") + + checkpoint_dirs = sorted( + checkpoints_dir.glob("checkpoint-*"), + key=lambda p: int(p.name.split("-")[1]), + ) + + if not checkpoint_dirs: + logger.warning(f"No checkpoint-* dirs found in {checkpoints_dir}") + return [] + + logger.info(f"Found {len(checkpoint_dirs)} checkpoints in {checkpoints_dir}") + + results: list[tuple[str, list[EvaluationResult], float | None]] = [] + + if self.config.base_model_id: + logger.info(f"Evaluating base model {self.config.base_model_id} as epoch 0...") + base_out_path = str(Path(self.config.out_path) / "base_model") if self.config.out_path else None + base_config = EvaluatorConfig( + model_path=self.config.base_model_id, + eval_dataset=self.config.eval_dataset, + out_path=base_out_path, + generation=self.config.generation, + ) + base_results = Evaluator(base_config).evaluate() + results.append((self.config.base_model_id, base_results, 0.0)) + + for r in base_results: + logger.info(f"base_model: accuracy={r.accuracy:.4f} ({r.correct}/{r.total})") + + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + for ckpt_dir in checkpoint_dirs: + ckpt_name = ckpt_dir.name + ckpt_out_path = str(Path(self.config.out_path) / ckpt_name) + + logger.info(f"Evaluating {ckpt_name}...") + + config = EvaluatorConfig( + model_path=str(ckpt_dir), + eval_dataset=self.config.eval_dataset, + out_path=ckpt_out_path, + generation=self.config.generation, + ) + + eval_results = Evaluator(config).evaluate() + epoch = self._read_epoch(ckpt_dir) + results.append((ckpt_name, eval_results, epoch)) + + for r in eval_results: + logger.info(f"{ckpt_name}: accuracy={r.accuracy:.4f} ({r.correct}/{r.total})") + + # Free GPU memory between checkpoints + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + self._save_summary(results) + return results + + def _read_epoch(self, ckpt_dir: Path) -> float | None: + state_file = ckpt_dir / "trainer_state.json" + if not state_file.exists(): + return None + with open(state_file) as f: + state = json.load(f) + return state.get("epoch") + + def _save_summary(self, results: list[tuple[str, list[EvaluationResult], float | None]]) -> None: + out_path = Path(self.config.out_path) + out_path.mkdir(parents=True, exist_ok=True) + + datasets = self._normalize_datasets() + + summary: dict[str, list[dict]] = {} + for ds_idx, ds_adapter in enumerate(datasets): + dataset_id = ds_adapter.dataset.dataset_id + summary[dataset_id] = [ + { + "epoch": epoch, + "checkpoint": ckpt_name, + "accuracy": eval_results[ds_idx].accuracy, + "total": eval_results[ds_idx].total, + "correct": eval_results[ds_idx].correct, + } + for ckpt_name, eval_results, epoch in results + ] + + summary_path = out_path / "summary.json" + with open(summary_path, "w") as f: + json.dump(summary, f, indent=2) + + logger.info(f"Summary saved to {summary_path}") diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_continuous_batch_generator.py b/tests/test_continuous_batch_generator.py 
new file mode 100644 index 0000000..7c3574e --- /dev/null +++ b/tests/test_continuous_batch_generator.py @@ -0,0 +1,275 @@ +"""Tests for ContinuousBatchGenerator with proper batched decode. + +Uses sshleifer/tiny-gpt2 (~500KB) — a minimal GPT-2 model that supports +KV cache and all the HF generation APIs we rely on. + +Correctness tests compare batched continuous generation against HF's +model.generate() under greedy decoding (temperature=0). + +Performance tests measure wall-clock time: batched vs sequential decode. +""" + +import time + +import pytest +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +from core.evaluation.continuous_batch_generator import ContinuousBatchGenerator + +MODEL_ID = "sshleifer/tiny-gpt2" +MAX_NEW_TOKENS = 20 + + +@pytest.fixture(scope="module") +def model_and_tokenizer(): + model = AutoModelForCausalLM.from_pretrained(MODEL_ID) + tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = tokenizer.eos_token_id + model.eval() + return model, tokenizer + + +def _hf_generate_greedy(model, tokenizer, prompt_ids: list[int], max_new_tokens: int) -> list[int]: + """Reference: HF model.generate() with greedy decoding for a single prompt.""" + input_ids = torch.tensor([prompt_ids], device=model.device) + with torch.no_grad(): + output = model.generate( + input_ids, + max_new_tokens=max_new_tokens, + do_sample=False, + pad_token_id=tokenizer.pad_token_id, + ) + # Strip prompt from output + return output[0, len(prompt_ids) :].tolist() + + +# ────────────────────────────────────────────────────────────────────── +# Correctness tests +# ────────────────────────────────────────────────────────────────────── + + +class TestCorrectnessVsHFGenerate: + """Batched continuous generation must match HF generate() under greedy decoding.""" + + def test_single_prompt(self, model_and_tokenizer): + model, tokenizer = model_and_tokenizer + prompt = tokenizer.encode("Hello world") + + gen = ContinuousBatchGenerator( + model=model, + tokenizer=tokenizer, + max_new_tokens=MAX_NEW_TOKENS, + max_batch_size=1, + ) + [result] = gen.generate([prompt]) + + expected = _hf_generate_greedy(model, tokenizer, prompt, MAX_NEW_TOKENS) + assert result == expected, f"Mismatch:\n got: {result}\n expected: {expected}" + + def test_multiple_same_length_prompts(self, model_and_tokenizer): + model, tokenizer = model_and_tokenizer + prompts = [ + tokenizer.encode("The cat sat on"), + tokenizer.encode("A dog ran to"), + ] + + gen = ContinuousBatchGenerator( + model=model, + tokenizer=tokenizer, + max_new_tokens=MAX_NEW_TOKENS, + max_batch_size=4, + ) + results = gen.generate(prompts) + + for i, (prompt, result) in enumerate(zip(prompts, results)): + expected = _hf_generate_greedy(model, tokenizer, prompt, MAX_NEW_TOKENS) + assert result == expected, f"Prompt {i} mismatch:\n got: {result}\n expected: {expected}" + + def test_variable_length_prompts(self, model_and_tokenizer): + """Prompts of very different lengths in the same batch.""" + model, tokenizer = model_and_tokenizer + prompts = [ + tokenizer.encode("Hi"), # short + tokenizer.encode("The quick brown fox jumps over the lazy dog and then"), # long + tokenizer.encode("Once"), # short + ] + + gen = ContinuousBatchGenerator( + model=model, + tokenizer=tokenizer, + max_new_tokens=MAX_NEW_TOKENS, + max_batch_size=4, + ) + results = gen.generate(prompts) + + for i, (prompt, result) in enumerate(zip(prompts, results)): + expected = _hf_generate_greedy(model, tokenizer, prompt, 
MAX_NEW_TOKENS) + assert result == expected, f"Prompt {i} (len={len(prompt)}) mismatch:\n got: {result}\n expected: {expected}" + + def test_more_prompts_than_batch_size(self, model_and_tokenizer): + """Tests continuous batching: prompts > max_batch_size forces queuing.""" + model, tokenizer = model_and_tokenizer + prompts = [ + tokenizer.encode("One"), + tokenizer.encode("Two two"), + tokenizer.encode("Three three three"), + tokenizer.encode("Four"), + tokenizer.encode("Five five"), + ] + + gen = ContinuousBatchGenerator( + model=model, + tokenizer=tokenizer, + max_new_tokens=MAX_NEW_TOKENS, + max_batch_size=2, # only 2 slots, 5 prompts → continuous batching + ) + results = gen.generate(prompts) + + for i, (prompt, result) in enumerate(zip(prompts, results)): + expected = _hf_generate_greedy(model, tokenizer, prompt, MAX_NEW_TOKENS) + assert result == expected, f"Prompt {i} mismatch:\n got: {result}\n expected: {expected}" + + def test_max_new_tokens_cutoff(self, model_and_tokenizer): + model, tokenizer = model_and_tokenizer + prompt = tokenizer.encode("Hello") + max_tokens = 5 + + gen = ContinuousBatchGenerator( + model=model, + tokenizer=tokenizer, + max_new_tokens=max_tokens, + max_batch_size=1, + ) + [result] = gen.generate([prompt]) + + assert len(result) <= max_tokens, f"Generated {len(result)} tokens, max was {max_tokens}" + + def test_result_order_preserved(self, model_and_tokenizer): + """Results must be returned in the same order as input prompts.""" + model, tokenizer = model_and_tokenizer + prompts = [ + tokenizer.encode("Alpha"), + tokenizer.encode("Beta beta beta beta"), + tokenizer.encode("Gamma"), + ] + + gen = ContinuousBatchGenerator( + model=model, + tokenizer=tokenizer, + max_new_tokens=MAX_NEW_TOKENS, + max_batch_size=2, + ) + results = gen.generate(prompts) + + assert len(results) == len(prompts) + for i, (prompt, result) in enumerate(zip(prompts, results)): + expected = _hf_generate_greedy(model, tokenizer, prompt, MAX_NEW_TOKENS) + assert result == expected, f"Prompt {i} order/content mismatch" + + def test_batch_size_1_same_as_sequential(self, model_and_tokenizer): + """batch_size=1 should produce identical results (no batching, just prefill+decode).""" + model, tokenizer = model_and_tokenizer + prompts = [ + tokenizer.encode("First prompt here"), + tokenizer.encode("Second"), + ] + + gen = ContinuousBatchGenerator( + model=model, + tokenizer=tokenizer, + max_new_tokens=MAX_NEW_TOKENS, + max_batch_size=1, + ) + results = gen.generate(prompts) + + for i, (prompt, result) in enumerate(zip(prompts, results)): + expected = _hf_generate_greedy(model, tokenizer, prompt, MAX_NEW_TOKENS) + assert result == expected + + +# ────────────────────────────────────────────────────────────────────── +# Performance tests +# ────────────────────────────────────────────────────────────────────── + + +class TestPerformance: + """Compare batched vs sequential (batch_size=1) wall-clock time. + + These tests use a tiny model so the speedup may not be dramatic, + but we verify batching is not slower than sequential. 
+ """ + + @pytest.fixture + def many_prompts(self, model_and_tokenizer): + _, tokenizer = model_and_tokenizer + return [ + tokenizer.encode("The quick brown fox jumps over the lazy dog"), + tokenizer.encode("Once upon a time in a land far away"), + tokenizer.encode("To be or not to be that is the question"), + tokenizer.encode("In the beginning there was nothing"), + tokenizer.encode("A long time ago in a galaxy far far away"), + tokenizer.encode("It was the best of times it was the worst"), + tokenizer.encode("Call me Ishmael some years ago never mind how long"), + tokenizer.encode("All happy families are alike each unhappy family"), + ] + + def test_batched_vs_sequential_time(self, model_and_tokenizer, many_prompts): + model, tokenizer = model_and_tokenizer + max_tokens = 30 + + # Sequential: batch_size=1 + gen_seq = ContinuousBatchGenerator( + model=model, + tokenizer=tokenizer, + max_new_tokens=max_tokens, + max_batch_size=1, + ) + start = time.perf_counter() + results_seq = gen_seq.generate(many_prompts) + time_sequential = time.perf_counter() - start + + # Batched: batch_size=8 + gen_batch = ContinuousBatchGenerator( + model=model, + tokenizer=tokenizer, + max_new_tokens=max_tokens, + max_batch_size=8, + ) + start = time.perf_counter() + results_batch = gen_batch.generate(many_prompts) + time_batched = time.perf_counter() - start + + # Verify correctness: both must produce same results + assert results_seq == results_batch, "Batched and sequential results differ!" + + # Report timing + speedup = time_sequential / time_batched if time_batched > 0 else float("inf") + print(f"\n Sequential (bs=1): {time_sequential:.3f}s") + print(f" Batched (bs=8): {time_batched:.3f}s") + print(f" Speedup: {speedup:.2f}x") + + def test_varying_batch_sizes(self, model_and_tokenizer, many_prompts): + model, tokenizer = model_and_tokenizer + max_tokens = 20 + + timings = {} + for bs in [1, 2, 4, 8]: + gen = ContinuousBatchGenerator( + model=model, + tokenizer=tokenizer, + max_new_tokens=max_tokens, + max_batch_size=bs, + ) + start = time.perf_counter() + gen.generate(many_prompts) + elapsed = time.perf_counter() - start + timings[bs] = elapsed + + print("\n Batch size -> Time:") + for bs, t in timings.items(): + print(f" bs={bs}: {t:.3f}s") + + # Sanity: all should complete without error (no assertion on speed + # since tiny-gpt2 on CPU may not show clear scaling) diff --git a/uv.lock b/uv.lock index 8fe065d..ccdd795 100644 --- a/uv.lock +++ b/uv.lock @@ -308,6 +308,7 @@ dependencies = [ dev = [ { name = "huggingface-hub", extra = ["cli"] }, { name = "ipykernel" }, + { name = "pytest" }, ] [package.metadata] @@ -337,6 +338,7 @@ requires-dist = [ dev = [ { name = "huggingface-hub", extras = ["cli"], specifier = ">=0.29.3" }, { name = "ipykernel", specifier = ">=6.29.5" }, + { name = "pytest", specifier = ">=9.0.2" }, ] [[package]] @@ -767,6 +769,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fa/5e/f8e9a1d23b9c20a551a8a02ea3637b4642e22c2626e3a13a9a29cdea99eb/importlib_metadata-8.7.1-py3-none-any.whl", hash = "sha256:5a1f80bf1daa489495071efbb095d75a634cf28a8bc299581244063b53176151", size = 27865, upload-time = "2025-12-21T10:00:18.329Z" }, ] +[[package]] +name = "iniconfig" +version = "2.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/72/34/14ca021ce8e5dfedc35312d08ba8bf51fdd999c576889fc2c24cb97f4f10/iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730", size = 20503, 
upload-time = "2025-10-18T21:55:43.219Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" }, +] + [[package]] name = "inquirerpy" version = "0.3.4" @@ -1815,6 +1826,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/48/31/05e764397056194206169869b50cf2fee4dbbbc71b344705b9c0d878d4d8/platformdirs-4.9.2-py3-none-any.whl", hash = "sha256:9170634f126f8efdae22fb58ae8a0eaa86f38365bc57897a6c4f781d1f5875bd", size = 21168, upload-time = "2026-02-16T03:56:08.891Z" }, ] +[[package]] +name = "pluggy" +version = "1.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, +] + [[package]] name = "prompt-toolkit" version = "3.0.52" @@ -2127,6 +2147,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/10/bd/c038d7cc38edc1aa5bf91ab8068b63d4308c66c4c8bb3cbba7dfbc049f9c/pyparsing-3.3.2-py3-none-any.whl", hash = "sha256:850ba148bd908d7e2411587e247a1e4f0327839c40e2e5e6d05a007ecc69911d", size = 122781, upload-time = "2026-01-21T03:57:55.912Z" }, ] +[[package]] +name = "pytest" +version = "9.0.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "iniconfig" }, + { name = "packaging" }, + { name = "pluggy" }, + { name = "pygments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d1/db/7ef3487e0fb0049ddb5ce41d3a49c235bf9ad299b6a25d5780a89f19230f/pytest-9.0.2.tar.gz", hash = "sha256:75186651a92bd89611d1d9fc20f0b4345fd827c41ccd5c299a868a05d70edf11", size = 1568901, upload-time = "2025-12-06T21:30:51.014Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" }, +] + [[package]] name = "python-dateutil" version = "2.9.0.post0" From 0b483077a3f7b601e3994b758a8d4f464d0ed816 Mon Sep 17 00:00:00 2001 From: Andrey Goncharov Date: Tue, 10 Mar 2026 21:36:50 +0000 Subject: [PATCH 02/20] Minor path tweaks --- .../evaluation/multi_checkpoint_evaluator.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/core/evaluation/multi_checkpoint_evaluator.py b/src/core/evaluation/multi_checkpoint_evaluator.py index 9989e8a..08cfc87 100644 --- a/src/core/evaluation/multi_checkpoint_evaluator.py +++ b/src/core/evaluation/multi_checkpoint_evaluator.py @@ -53,7 +53,7 @@ def evaluate_all(self) -> list[tuple[str, list[EvaluationResult], float | None]] if self.config.base_model_id: logger.info(f"Evaluating base model {self.config.base_model_id} as epoch 0...") - base_out_path = str(Path(self.config.out_path) 
/ "base_model") if self.config.out_path else None + base_out_path = str(self._out_path / "base_model") base_config = EvaluatorConfig( model_path=self.config.base_model_id, eval_dataset=self.config.eval_dataset, @@ -72,7 +72,7 @@ def evaluate_all(self) -> list[tuple[str, list[EvaluationResult], float | None]] for ckpt_dir in checkpoint_dirs: ckpt_name = ckpt_dir.name - ckpt_out_path = str(Path(self.config.out_path) / ckpt_name) + ckpt_out_path = str(self._out_path / ckpt_name) logger.info(f"Evaluating {ckpt_name}...") @@ -107,8 +107,7 @@ def _read_epoch(self, ckpt_dir: Path) -> float | None: return state.get("epoch") def _save_summary(self, results: list[tuple[str, list[EvaluationResult], float | None]]) -> None: - out_path = Path(self.config.out_path) - out_path.mkdir(parents=True, exist_ok=True) + self._out_path.mkdir(parents=True, exist_ok=True) datasets = self._normalize_datasets() @@ -126,8 +125,15 @@ def _save_summary(self, results: list[tuple[str, list[EvaluationResult], float | for ckpt_name, eval_results, epoch in results ] - summary_path = out_path / "summary.json" + summary_path = self._out_path / "summary.json" with open(summary_path, "w") as f: json.dump(summary, f, indent=2) logger.info(f"Summary saved to {summary_path}") + + @property + def _out_path(self) -> Path: + if self.config.out_path: + return Path(self.config.out_path) + + return Path(self.config.checkpoints_dir) From 46c793068dab2b61022dfba5d72904289357d887 Mon Sep 17 00:00:00 2001 From: Andrey Goncharov Date: Tue, 10 Mar 2026 22:32:55 +0000 Subject: [PATCH 03/20] Update sft by complexity split --- src/core/evaluation/evaluator.py | 43 +++++++--- .../evaluation/multi_checkpoint_evaluator.py | 3 +- src/core/training/base_trainer.py | 10 +++ .../sft_by_complexity_splits/mmlu/llama_3b.py | 81 +++++++++++++------ 4 files changed, 101 insertions(+), 36 deletions(-) diff --git a/src/core/evaluation/evaluator.py b/src/core/evaluation/evaluator.py index a39ed35..7110747 100644 --- a/src/core/evaluation/evaluator.py +++ b/src/core/evaluation/evaluator.py @@ -4,7 +4,7 @@ import pandas as pd from pydantic import BaseModel from pydraconf import PydraConfig -from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer from core.datasets.qa_dataset import QADataset from core.datasets.qa_dataset_adapter import QADatasetAdapter @@ -13,8 +13,8 @@ class GenerationConfig(BaseModel): - max_new_tokens: int = 4096 - max_batch_size: int = 8 + max_new_tokens: int + max_batch_size: int temperature: float = 0.0 top_p: float = 1.0 top_k: int = -1 @@ -24,7 +24,7 @@ class EvaluatorConfig(PydraConfig): model_path: str eval_dataset: QADatasetAdapter | list[QADatasetAdapter] out_path: str | None = None - generation: GenerationConfig = GenerationConfig() + generation: GenerationConfig class EvaluationResult(BaseModel): @@ -34,8 +34,9 @@ class EvaluationResult(BaseModel): class Evaluator: - def __init__(self, config: EvaluatorConfig): + def __init__(self, config: EvaluatorConfig, tokenizer: PreTrainedTokenizer | None = None): self.config = config + self.tokenizer = tokenizer @property def _datasets(self) -> list[QADatasetAdapter]: @@ -44,13 +45,20 @@ def _datasets(self) -> list[QADatasetAdapter]: return [self.config.eval_dataset] def evaluate(self) -> list[EvaluationResult]: + cached_results: list[EvaluationResult | None] = [self._load_cached_result(ds) for ds in self._datasets] + + if all(r is not None for r in cached_results): + return cached_results # type: 
ignore[return-value] + model, tokenizer = self._load_model() model.eval() results: list[EvaluationResult] = [] - for eval_dataset in self._datasets: - result = self._evaluate_single(eval_dataset, model, tokenizer) - results.append(result) + for ds, cached in zip(self._datasets, cached_results): + if cached is not None: + results.append(cached) + else: + results.append(self._evaluate_single(ds, model, tokenizer)) return results @@ -110,6 +118,15 @@ def _evaluate_single(self, eval_dataset: QADatasetAdapter, model, tokenizer) -> return result + def _load_cached_result(self, eval_dataset: QADatasetAdapter) -> EvaluationResult | None: + results_path = self._eval_results_path_for(eval_dataset) + if not results_path.exists(): + return None + with open(results_path) as f: + data = json.load(f) + logger.info(f"Found cached results at {results_path}, skipping evaluation") + return EvaluationResult(accuracy=data["accuracy"], total=data["total"], correct=data["correct"]) + def _out_path_for(self, eval_dataset: QADatasetAdapter) -> Path: dataset_id = eval_dataset.dataset.dataset_id if self.config.out_path: @@ -133,8 +150,9 @@ def _load_model(self): logger.info(f"Loading model from {self.config.model_path}") model = AutoModelForCausalLM.from_pretrained(self.config.model_path) - tokenizer = AutoTokenizer.from_pretrained(self.config.model_path) - return model, tokenizer + if not self.tokenizer: + self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_path) + return model, self.tokenizer def _load_lora_model(self, model_path: Path, adapter_config: Path): from peft import PeftModel @@ -149,8 +167,9 @@ def _load_lora_model(self, model_path: Path, adapter_config: Path): logger.info(f"Loading LoRA model: base={base_model_id}, adapter={model_path}") base_model = AutoModelForCausalLM.from_pretrained(base_model_id) model = PeftModel.from_pretrained(base_model, str(model_path)) - tokenizer = AutoTokenizer.from_pretrained(base_model_id) - return model, tokenizer + if not self.tokenizer: + self.tokenizer = AutoTokenizer.from_pretrained(base_model_id) + return model, self.tokenizer def _save_results(self, eval_dataset: QADatasetAdapter, result: EvaluationResult, all_results: list[dict]) -> None: out_path = self._out_path_for(eval_dataset) diff --git a/src/core/evaluation/multi_checkpoint_evaluator.py b/src/core/evaluation/multi_checkpoint_evaluator.py index 08cfc87..4df85c1 100644 --- a/src/core/evaluation/multi_checkpoint_evaluator.py +++ b/src/core/evaluation/multi_checkpoint_evaluator.py @@ -4,6 +4,7 @@ import torch from pydraconf import PydraConfig +from pydantic import Field from core.datasets.qa_dataset_adapter import QADatasetAdapter from core.evaluation.evaluator import EvaluationResult, Evaluator, EvaluatorConfig, GenerationConfig @@ -15,7 +16,7 @@ class MultiCheckpointEvaluatorConfig(PydraConfig): eval_dataset: QADatasetAdapter | list[QADatasetAdapter] base_model_id: str | None = None out_path: str | None = None - generation: GenerationConfig = GenerationConfig() + generation: GenerationConfig class MultiCheckpointEvaluator: diff --git a/src/core/training/base_trainer.py b/src/core/training/base_trainer.py index 8a62fed..79e6033 100644 --- a/src/core/training/base_trainer.py +++ b/src/core/training/base_trainer.py @@ -1,8 +1,11 @@ +import gc import json import subprocess from pathlib import Path from typing import Any +import torch + from pydantic import BaseModel from pydraconf import PydraConfig from transformers import ( @@ -138,6 +141,13 @@ def _run_training(self, train_ds): logger.info(f"Has 
checkpoint: {has_checkpoint}") trainer.train(resume_from_checkpoint=has_checkpoint) + def unload(self): + self._model = None + self._tokenizer = None + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + def _directory_is_empty(self, directory: str, expected_epochs: int) -> bool: p = Path(directory) if not p.exists(): diff --git a/src/experiments/sft_by_complexity_splits/mmlu/llama_3b.py b/src/experiments/sft_by_complexity_splits/mmlu/llama_3b.py index 9c77b0b..7f8e508 100644 --- a/src/experiments/sft_by_complexity_splits/mmlu/llama_3b.py +++ b/src/experiments/sft_by_complexity_splits/mmlu/llama_3b.py @@ -4,34 +4,69 @@ from core.datasets.causal_dataset_adapter import CausalDatasetAdapter from core.datasets.mmlu.mmlu_single_token_response_dataset import MMLUSingleTokenResponseDataset, QADatasetConfig +from core.datasets.qa_dataset_adapter import QADatasetAdapter +from core.evaluation.multi_checkpoint_evaluator import ( + GenerationConfig, + MultiCheckpointEvaluator, + MultiCheckpointEvaluatorConfig, +) from core.training.lora_trainer import LoRATrainer, LoRATrainerConfig, LoRATrainingArgs +from core.utils.logger import logger MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct" tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) tokenizer.pad_token = tokenizer.eos_token -trainer = LoRATrainer( - config=LoRATrainerConfig( - out_path=Path(__file__) - .parent.joinpath("../../../../artifacts/sft_by_complexity_splits/mmlu/llama_3b/group0") - .as_posix(), - model_id=MODEL_NAME, - train_dataset=CausalDatasetAdapter( - dataset=MMLUSingleTokenResponseDataset( - config=QADatasetConfig( - path=Path(__file__) - .parent.joinpath( - "../../../../data/out/splits/single_token_entropy/mmlu/llama_3b/group0_train.parquet" - ) - .as_posix() - ), - tokenizer=tokenizer, - ) +for group in range(6): + logger.info(f"Training on group {group}...") + + trainer = LoRATrainer( + config=LoRATrainerConfig( + out_path=Path(__file__) + .parent.joinpath(f"../../../../artifacts/sft_by_complexity_splits/mmlu/llama_3b/group{group}") + .as_posix(), + model_id=MODEL_NAME, + train_dataset=CausalDatasetAdapter( + dataset=MMLUSingleTokenResponseDataset( + config=QADatasetConfig( + path=Path(__file__) + .parent.joinpath( + f"../../../../data/out/splits/single_token_entropy/mmlu/llama_3b/group{group}_train.parquet" + ) + .as_posix() + ), + tokenizer=tokenizer, + ) + ), + training_args=LoRATrainingArgs(num_train_epochs=20, per_device_train_batch_size=32), + save_schedule=[1, 2, 3, 4, 5, 6, 8, 10, 12, 15, 20], ), - training_args=LoRATrainingArgs(num_train_epochs=20, per_device_train_batch_size=32), - save_schedule=[1, 2, 3, 4, 5, 6, 8, 10, 12, 15, 20], - ), - tokenizer=tokenizer, -) -trainer.train() + tokenizer=tokenizer, + ) + trainer.train() + trainer.unload() + + single_token_evaluator = MultiCheckpointEvaluator( + config=MultiCheckpointEvaluatorConfig( + checkpoints_dir=trainer.config.out_path, + eval_dataset=[ + QADatasetAdapter( + dataset=MMLUSingleTokenResponseDataset( + config=QADatasetConfig( + path=Path(__file__) + .parent.joinpath( + f"../../../../data/out/splits/single_token_entropy/mmlu/llama_3b/group{j}_test.parquet" + ) + .as_posix() + ), + tokenizer=tokenizer, + ) + ) + for j in range(6) + ], + base_model_id=MODEL_NAME, + generation=GenerationConfig(max_new_tokens=1, max_batch_size=64), + ) + ) + single_token_evaluator.evaluate_all() From 9526bf0e2457172e75a95c1b3ef62fe245217569 Mon Sep 17 00:00:00 2001 From: Andrey Goncharov Date: Tue, 10 Mar 2026 22:40:36 +0000 Subject: 
[PATCH 04/20] Add progress bars --- src/core/evaluation/continuous_batch_generator.py | 6 +++++- src/core/evaluation/evaluator.py | 3 ++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/core/evaluation/continuous_batch_generator.py b/src/core/evaluation/continuous_batch_generator.py index 0f9fce8..65167f8 100644 --- a/src/core/evaluation/continuous_batch_generator.py +++ b/src/core/evaluation/continuous_batch_generator.py @@ -3,6 +3,7 @@ import torch import torch.nn.functional as F +from tqdm import tqdm from transformers import DynamicCache, PreTrainedModel, PreTrainedTokenizer @@ -87,6 +88,7 @@ def generate(self, prompts: list[list[int]]) -> list[list[int]]: results: list[list[int] | None] = [None] * len(prompts) queue: deque[tuple[int, list[int]]] = deque((i, p) for i, p in enumerate(prompts)) active_slots: list[_Slot | None] = [None] * self.max_batch_size + pbar = tqdm(total=len(prompts), desc="Generating") while queue or any(s is not None for s in active_slots): # FILL: prefill empty slots with new prompts (batch_size=1 each) @@ -111,7 +113,9 @@ def generate(self, prompts: list[list[int]]) -> list[list[int]]: if last_token == self.eos_token_id or len(slot.generated_ids) >= self.max_new_tokens: results[slot.index] = slot.generated_ids active_slots[slot_idx] = None + pbar.update(1) + pbar.close() return [r if r is not None else [] for r in results] def _prefill(self, prompt_idx: int, prompt_ids: list[int]) -> _Slot: @@ -160,7 +164,7 @@ def _batched_decode(self, slots: list[_Slot]) -> None: # attention_mask: [num_slots, max_cache_len + 1] (+1 for the new token) attn_mask = torch.zeros(num_slots, max_cache_len + 1, dtype=torch.long, device=device) for i, slot in enumerate(slots): - attn_mask[i, :slot_cache_lens[i]] = 1 # valid cached positions + attn_mask[i, : slot_cache_lens[i]] = 1 # valid cached positions attn_mask[i, max_cache_len] = 1 # the new token position (appended at end) # cache_position: shared across batch, points to where the new KV is appended diff --git a/src/core/evaluation/evaluator.py b/src/core/evaluation/evaluator.py index 7110747..562037b 100644 --- a/src/core/evaluation/evaluator.py +++ b/src/core/evaluation/evaluator.py @@ -4,6 +4,7 @@ import pandas as pd from pydantic import BaseModel from pydraconf import PydraConfig +from tqdm import tqdm from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer from core.datasets.qa_dataset import QADataset @@ -54,7 +55,7 @@ def evaluate(self) -> list[EvaluationResult]: model.eval() results: list[EvaluationResult] = [] - for ds, cached in zip(self._datasets, cached_results): + for ds, cached in tqdm(zip(self._datasets, cached_results), total=len(self._datasets), desc="Datasets"): if cached is not None: results.append(cached) else: From c0f54a5aa4c5b085ed9567d587aa87861d4bcb6a Mon Sep 17 00:00:00 2001 From: Andrey Goncharov Date: Wed, 11 Mar 2026 15:00:31 +0300 Subject: [PATCH 05/20] Add CoT evals --- .../mmlu/mmlu_cot_response_dataset.py | 46 +++++++++++++++++++ .../evaluation/multi_checkpoint_evaluator.py | 29 +++++++----- .../sft_by_complexity_splits/mmlu/llama_3b.py | 29 +++++++++++- 3 files changed, 91 insertions(+), 13 deletions(-) create mode 100644 src/core/datasets/mmlu/mmlu_cot_response_dataset.py diff --git a/src/core/datasets/mmlu/mmlu_cot_response_dataset.py b/src/core/datasets/mmlu/mmlu_cot_response_dataset.py new file mode 100644 index 0000000..947cf57 --- /dev/null +++ b/src/core/datasets/mmlu/mmlu_cot_response_dataset.py @@ -0,0 +1,46 @@ +from typing import override + 
+
+from transformers import PreTrainedTokenizer
+
+from core.datasets.mmlu.mmlu_single_token_response_dataset import MMLUSingleTokenResponseDataset
+from core.datasets.qa_dataset import QADatasetConfig
+
+
+class MMLUCoTResponseDataset(MMLUSingleTokenResponseDataset):
+    def __init__(self, tokenizer: PreTrainedTokenizer, config: QADatasetConfig):
+        super().__init__(tokenizer, config)
+
+        self.answer_marker = ("[[", "]]")
+
+    @override
+    def system_prompt(self, row: dict) -> str:
+        subject = row["base_cluster"]
+        return f"The following are multiple choice questions about {subject}. Explain your thinking process step-by-step. At the end, choose a correct option letter by strictly following this format: {self.answer_marker[0]}correct_option{self.answer_marker[1]}."
+
+    @override
+    def assistant_response(self, row: dict) -> str:
+        raise NotImplementedError(
+            "MMLUCoTResponseDataset does not implement assistant_response since it is not used for training. Use MMLUReasoningResponseDataset for evaluation instead."
+        )
+
+    @override
+    def verify_assistant_response(self, row: dict, assistant_response: str) -> tuple[str, bool]:
+        answer_start_token_position = assistant_response.find(self.answer_marker[0])
+        answer_end_token_position = assistant_response.find(self.answer_marker[1])
+        if (
+            answer_start_token_position == -1
+            or answer_end_token_position == -1
+            or answer_end_token_position < answer_start_token_position
+        ):
+            return "", False
+
+        extracted_answer = (
+            assistant_response[answer_start_token_position + len(self.answer_marker[0]) : answer_end_token_position]
+            .strip()
+            .lower()
+        )
+
+        try:
+            return extracted_answer, super().assistant_response(row).strip().lower() == extracted_answer
+        except Exception:
+            return extracted_answer, False
diff --git a/src/core/evaluation/multi_checkpoint_evaluator.py b/src/core/evaluation/multi_checkpoint_evaluator.py
index 4df85c1..aa1da2b 100644
--- a/src/core/evaluation/multi_checkpoint_evaluator.py
+++ b/src/core/evaluation/multi_checkpoint_evaluator.py
@@ -4,7 +4,7 @@
 
 import torch
 from pydraconf import PydraConfig
-from pydantic import Field
+from transformers import PreTrainedTokenizer
 
 from core.datasets.qa_dataset_adapter import QADatasetAdapter
 from core.evaluation.evaluator import EvaluationResult, Evaluator, EvaluatorConfig, GenerationConfig
@@ -16,6 +16,7 @@ class MultiCheckpointEvaluatorConfig(PydraConfig):
     eval_dataset: QADatasetAdapter | list[QADatasetAdapter]
     base_model_id: str | None = None
     out_path: str | None = None
+    summary_filename: str = "summary.json"
     generation: GenerationConfig
 
 
@@ -26,8 +27,9 @@ class MultiCheckpointEvaluator:
     frees GPU memory before proceeding to the next. 
""" - def __init__(self, config: MultiCheckpointEvaluatorConfig): + def __init__(self, config: MultiCheckpointEvaluatorConfig, tokenizer: PreTrainedTokenizer | None = None): self.config = config + self.tokenizer = tokenizer def _normalize_datasets(self) -> list[QADatasetAdapter]: if isinstance(self.config.eval_dataset, list): @@ -61,15 +63,13 @@ def evaluate_all(self) -> list[tuple[str, list[EvaluationResult], float | None]] out_path=base_out_path, generation=self.config.generation, ) - base_results = Evaluator(base_config).evaluate() + base_results = Evaluator(base_config, self.tokenizer).evaluate() results.append((self.config.base_model_id, base_results, 0.0)) for r in base_results: logger.info(f"base_model: accuracy={r.accuracy:.4f} ({r.correct}/{r.total})") - gc.collect() - if torch.cuda.is_available(): - torch.cuda.empty_cache() + self._free_vram() for ckpt_dir in checkpoint_dirs: ckpt_name = ckpt_dir.name @@ -84,19 +84,19 @@ def evaluate_all(self) -> list[tuple[str, list[EvaluationResult], float | None]] generation=self.config.generation, ) - eval_results = Evaluator(config).evaluate() + eval_results = Evaluator(config, self.tokenizer).evaluate() epoch = self._read_epoch(ckpt_dir) results.append((ckpt_name, eval_results, epoch)) for r in eval_results: logger.info(f"{ckpt_name}: accuracy={r.accuracy:.4f} ({r.correct}/{r.total})") - # Free GPU memory between checkpoints - gc.collect() - if torch.cuda.is_available(): - torch.cuda.empty_cache() + self._free_vram() self._save_summary(results) + + self._free_vram() + return results def _read_epoch(self, ckpt_dir: Path) -> float | None: @@ -126,7 +126,7 @@ def _save_summary(self, results: list[tuple[str, list[EvaluationResult], float | for ckpt_name, eval_results, epoch in results ] - summary_path = self._out_path / "summary.json" + summary_path = self._out_path / self.config.summary_filename with open(summary_path, "w") as f: json.dump(summary, f, indent=2) @@ -138,3 +138,8 @@ def _out_path(self) -> Path: return Path(self.config.out_path) return Path(self.config.checkpoints_dir) + + def _free_vram(self): + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() diff --git a/src/experiments/sft_by_complexity_splits/mmlu/llama_3b.py b/src/experiments/sft_by_complexity_splits/mmlu/llama_3b.py index 7f8e508..72e5278 100644 --- a/src/experiments/sft_by_complexity_splits/mmlu/llama_3b.py +++ b/src/experiments/sft_by_complexity_splits/mmlu/llama_3b.py @@ -3,6 +3,7 @@ from transformers import AutoTokenizer from core.datasets.causal_dataset_adapter import CausalDatasetAdapter +from core.datasets.mmlu.mmlu_cot_response_dataset import MMLUCoTResponseDataset from core.datasets.mmlu.mmlu_single_token_response_dataset import MMLUSingleTokenResponseDataset, QADatasetConfig from core.datasets.qa_dataset_adapter import QADatasetAdapter from core.evaluation.multi_checkpoint_evaluator import ( @@ -67,6 +68,32 @@ ], base_model_id=MODEL_NAME, generation=GenerationConfig(max_new_tokens=1, max_batch_size=64), - ) + ), + tokenizer=tokenizer, ) single_token_evaluator.evaluate_all() + + cot_evaluator = MultiCheckpointEvaluator( + config=MultiCheckpointEvaluatorConfig( + checkpoints_dir=trainer.config.out_path, + eval_dataset=[ + QADatasetAdapter( + dataset=MMLUCoTResponseDataset( + config=QADatasetConfig( + path=Path(__file__) + .parent.joinpath( + f"../../../../data/out/splits/single_token_entropy/mmlu/llama_3b/group{j}_test.parquet" + ) + .as_posix() + ), + tokenizer=tokenizer, + ) + ) + for j in range(6) + ], + base_model_id=MODEL_NAME, + 
generation=GenerationConfig(max_new_tokens=4096, max_batch_size=4), + ), + tokenizer=tokenizer, + ) + cot_evaluator.evaluate_all() From 32a641d310dcbfd7fed33068def96dc64a62f211 Mon Sep 17 00:00:00 2001 From: Andrey Goncharov Date: Wed, 11 Mar 2026 12:09:27 +0000 Subject: [PATCH 06/20] Add max_len_seq warning --- .../evaluation/continuous_batch_generator.py | 18 +++++++++++---- src/core/evaluation/evaluator.py | 23 ++++++++++++++++--- .../evaluation/multi_checkpoint_evaluator.py | 11 +++++++++ 3 files changed, 45 insertions(+), 7 deletions(-) diff --git a/src/core/evaluation/continuous_batch_generator.py b/src/core/evaluation/continuous_batch_generator.py index 65167f8..24fb0ac 100644 --- a/src/core/evaluation/continuous_batch_generator.py +++ b/src/core/evaluation/continuous_batch_generator.py @@ -7,6 +7,13 @@ from transformers import DynamicCache, PreTrainedModel, PreTrainedTokenizer +@dataclass +class GenerationResult: + sequences: list[list[int]] + num_truncated: int + total: int + + @dataclass class _Slot: index: int @@ -75,19 +82,19 @@ def __init__( self.eos_token_id = tokenizer.eos_token_id @torch.no_grad() - def generate(self, prompts: list[list[int]]) -> list[list[int]]: + def generate(self, prompts: list[list[int]]) -> GenerationResult: """Generate responses for a list of prompts using continuous batching. Args: prompts: List of token ID sequences (one per sample). Returns: - List of generated token ID sequences (excluding prompt), same order - as input. + GenerationResult with generated sequences and truncation stats. """ results: list[list[int] | None] = [None] * len(prompts) queue: deque[tuple[int, list[int]]] = deque((i, p) for i, p in enumerate(prompts)) active_slots: list[_Slot | None] = [None] * self.max_batch_size + num_truncated = 0 pbar = tqdm(total=len(prompts), desc="Generating") while queue or any(s is not None for s in active_slots): @@ -111,12 +118,15 @@ def generate(self, prompts: list[list[int]]) -> list[list[int]]: for slot_idx, slot in occupied: last_token = slot.generated_ids[-1] if last_token == self.eos_token_id or len(slot.generated_ids) >= self.max_new_tokens: + if last_token != self.eos_token_id: + num_truncated += 1 results[slot.index] = slot.generated_ids active_slots[slot_idx] = None pbar.update(1) pbar.close() - return [r if r is not None else [] for r in results] + sequences = [r if r is not None else [] for r in results] + return GenerationResult(sequences=sequences, num_truncated=num_truncated, total=len(prompts)) def _prefill(self, prompt_idx: int, prompt_ids: list[int]) -> _Slot: """Run the prefill forward pass to build KV cache and sample the first token.""" diff --git a/src/core/evaluation/evaluator.py b/src/core/evaluation/evaluator.py index 562037b..47b97e0 100644 --- a/src/core/evaluation/evaluator.py +++ b/src/core/evaluation/evaluator.py @@ -32,6 +32,7 @@ class EvaluationResult(BaseModel): accuracy: float total: int correct: int + num_truncated: int = 0 class Evaluator: @@ -79,7 +80,15 @@ def _evaluate_single(self, eval_dataset: QADatasetAdapter, model, tokenizer) -> top_k=self.config.generation.top_k, ) - generated = generator.generate(prompts) + gen_result = generator.generate(prompts) + generated = gen_result.sequences + + if gen_result.num_truncated > 0: + pct = gen_result.num_truncated / gen_result.total * 100 + logger.warning( + f"Generation reached max_new_tokens ({self.config.generation.max_new_tokens}) " + f"for {gen_result.num_truncated}/{gen_result.total} sequences ({pct:.1f}%)" + ) correct = 0 total = len(prompts) @@ -111,7 
+120,9 @@ def _evaluate_single(self, eval_dataset: QADatasetAdapter, model, tokenizer) -> ) accuracy = correct / total if total > 0 else 0.0 - result = EvaluationResult(accuracy=accuracy, total=total, correct=correct) + result = EvaluationResult( + accuracy=accuracy, total=total, correct=correct, num_truncated=gen_result.num_truncated + ) logger.info(f"Evaluation complete: accuracy={accuracy:.4f} ({correct}/{total})") @@ -126,7 +137,12 @@ def _load_cached_result(self, eval_dataset: QADatasetAdapter) -> EvaluationResul with open(results_path) as f: data = json.load(f) logger.info(f"Found cached results at {results_path}, skipping evaluation") - return EvaluationResult(accuracy=data["accuracy"], total=data["total"], correct=data["correct"]) + return EvaluationResult( + accuracy=data["accuracy"], + total=data["total"], + correct=data["correct"], + num_truncated=data.get("num_truncated", 0), + ) def _out_path_for(self, eval_dataset: QADatasetAdapter) -> Path: dataset_id = eval_dataset.dataset.dataset_id @@ -182,6 +198,7 @@ def _save_results(self, eval_dataset: QADatasetAdapter, result: EvaluationResult "accuracy": result.accuracy, "total": result.total, "correct": result.correct, + "num_truncated": result.num_truncated, "model_path": self.config.model_path, } with open(summary_path, "w") as f: diff --git a/src/core/evaluation/multi_checkpoint_evaluator.py b/src/core/evaluation/multi_checkpoint_evaluator.py index aa1da2b..6d6e0f4 100644 --- a/src/core/evaluation/multi_checkpoint_evaluator.py +++ b/src/core/evaluation/multi_checkpoint_evaluator.py @@ -68,6 +68,11 @@ def evaluate_all(self) -> list[tuple[str, list[EvaluationResult], float | None]] for r in base_results: logger.info(f"base_model: accuracy={r.accuracy:.4f} ({r.correct}/{r.total})") + if r.num_truncated > 0: + pct = r.num_truncated / r.total * 100 + logger.warning( + f"base_model: {r.num_truncated}/{r.total} ({pct:.1f}%) sequences reached max_new_tokens" + ) self._free_vram() @@ -90,6 +95,11 @@ def evaluate_all(self) -> list[tuple[str, list[EvaluationResult], float | None]] for r in eval_results: logger.info(f"{ckpt_name}: accuracy={r.accuracy:.4f} ({r.correct}/{r.total})") + if r.num_truncated > 0: + pct = r.num_truncated / r.total * 100 + logger.warning( + f"{ckpt_name}: {r.num_truncated}/{r.total} ({pct:.1f}%) sequences reached max_new_tokens" + ) self._free_vram() @@ -122,6 +132,7 @@ def _save_summary(self, results: list[tuple[str, list[EvaluationResult], float | "accuracy": eval_results[ds_idx].accuracy, "total": eval_results[ds_idx].total, "correct": eval_results[ds_idx].correct, + "num_truncated": eval_results[ds_idx].num_truncated, } for ckpt_name, eval_results, epoch in results ] From 0635e419965e483f4c9facc3d89a058abcdc699e Mon Sep 17 00:00:00 2001 From: Andrey Goncharov Date: Wed, 11 Mar 2026 15:12:16 +0300 Subject: [PATCH 07/20] Upd llama mmlu sft by complexity splits --- src/experiments/sft_by_complexity_splits/mmlu/llama_3b.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/experiments/sft_by_complexity_splits/mmlu/llama_3b.py b/src/experiments/sft_by_complexity_splits/mmlu/llama_3b.py index 72e5278..7cdcce3 100644 --- a/src/experiments/sft_by_complexity_splits/mmlu/llama_3b.py +++ b/src/experiments/sft_by_complexity_splits/mmlu/llama_3b.py @@ -68,6 +68,7 @@ ], base_model_id=MODEL_NAME, generation=GenerationConfig(max_new_tokens=1, max_batch_size=64), + summary_filename="summary_cot.json", ), tokenizer=tokenizer, ) @@ -93,6 +94,7 @@ ], base_model_id=MODEL_NAME, generation=GenerationConfig(max_new_tokens=4096, 
max_batch_size=4), + summary_filename="summary_cot.json", ), tokenizer=tokenizer, ) From 0f24cd5db00fc40995c78741342477b950b27c9c Mon Sep 17 00:00:00 2001 From: Andrey Goncharov Date: Wed, 11 Mar 2026 12:23:14 +0000 Subject: [PATCH 08/20] Add device mapping --- src/core/evaluation/evaluator.py | 5 +++-- src/core/evaluation/multi_checkpoint_evaluator.py | 2 ++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/core/evaluation/evaluator.py b/src/core/evaluation/evaluator.py index 47b97e0..69c5102 100644 --- a/src/core/evaluation/evaluator.py +++ b/src/core/evaluation/evaluator.py @@ -10,6 +10,7 @@ from core.datasets.qa_dataset import QADataset from core.datasets.qa_dataset_adapter import QADatasetAdapter from core.evaluation.continuous_batch_generator import ContinuousBatchGenerator +from core.utils.device import DEVICE_MAP from core.utils.logger import logger @@ -166,7 +167,7 @@ def _load_model(self): return self._load_lora_model(model_path, adapter_config) logger.info(f"Loading model from {self.config.model_path}") - model = AutoModelForCausalLM.from_pretrained(self.config.model_path) + model = AutoModelForCausalLM.from_pretrained(self.config.model_path, device_map=DEVICE_MAP) if not self.tokenizer: self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_path) return model, self.tokenizer @@ -182,7 +183,7 @@ def _load_lora_model(self, model_path: Path, adapter_config: Path): raise ValueError(f"adapter_config.json at {adapter_config} missing 'base_model_name_or_path'") logger.info(f"Loading LoRA model: base={base_model_id}, adapter={model_path}") - base_model = AutoModelForCausalLM.from_pretrained(base_model_id) + base_model = AutoModelForCausalLM.from_pretrained(base_model_id, device_map=DEVICE_MAP) model = PeftModel.from_pretrained(base_model, str(model_path)) if not self.tokenizer: self.tokenizer = AutoTokenizer.from_pretrained(base_model_id) diff --git a/src/core/evaluation/multi_checkpoint_evaluator.py b/src/core/evaluation/multi_checkpoint_evaluator.py index 6d6e0f4..3912ea1 100644 --- a/src/core/evaluation/multi_checkpoint_evaluator.py +++ b/src/core/evaluation/multi_checkpoint_evaluator.py @@ -154,3 +154,5 @@ def _free_vram(self): gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() + if torch.mps.is_available(): + torch.mps.empty_cache() From 8ccaabbf34f0aa9f0e28cb9695a3e32b07deb3ff Mon Sep 17 00:00:00 2001 From: Andrey Goncharov Date: Wed, 11 Mar 2026 12:30:29 +0000 Subject: [PATCH 09/20] Add dataset ID --- src/core/datasets/base_dataset.py | 3 ++- .../sft_by_complexity_splits/mmlu/llama_3b.py | 11 +++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/core/datasets/base_dataset.py b/src/core/datasets/base_dataset.py index 8352df8..1c3b4d0 100644 --- a/src/core/datasets/base_dataset.py +++ b/src/core/datasets/base_dataset.py @@ -6,6 +6,7 @@ class BaseDatasetConfig(PydraConfig): path: str + dataset_id: str class BaseDataset[C: BaseDatasetConfig](ABC): @@ -24,4 +25,4 @@ def row_id(self, row: dict) -> str: ... 
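Patch 08 imports DEVICE_MAP from core.utils.device, but that module is not part of this series. A minimal sketch of what it might contain, assuming DEVICE_MAP is simply a string accepted by from_pretrained(device_map=...); the module contents here are an assumption, not the repository's actual code:

    # Hypothetical core/utils/device.py (not shown in these patches)
    import torch

    if torch.cuda.is_available():
        DEVICE_MAP = "cuda"
    elif torch.backends.mps.is_available():
        DEVICE_MAP = "mps"
    else:
        DEVICE_MAP = "cpu"

A single-string device_map keeps the whole model on one device, which matches the generator's assumption that all KV tensors live on model.device.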
@property def dataset_id(self) -> str: - return self.__class__.__name__ + return self.config.dataset_id diff --git a/src/experiments/sft_by_complexity_splits/mmlu/llama_3b.py b/src/experiments/sft_by_complexity_splits/mmlu/llama_3b.py index 7cdcce3..315beff 100644 --- a/src/experiments/sft_by_complexity_splits/mmlu/llama_3b.py +++ b/src/experiments/sft_by_complexity_splits/mmlu/llama_3b.py @@ -35,7 +35,8 @@ .parent.joinpath( f"../../../../data/out/splits/single_token_entropy/mmlu/llama_3b/group{group}_train.parquet" ) - .as_posix() + .as_posix(), + dataset_id=f"mmlu_single_token_response_group{group}_train", ), tokenizer=tokenizer, ) @@ -59,7 +60,8 @@ .parent.joinpath( f"../../../../data/out/splits/single_token_entropy/mmlu/llama_3b/group{j}_test.parquet" ) - .as_posix() + .as_posix(), + dataset_id=f"mmlu_single_token_response_group{j}_test", ), tokenizer=tokenizer, ) @@ -67,7 +69,7 @@ for j in range(6) ], base_model_id=MODEL_NAME, - generation=GenerationConfig(max_new_tokens=1, max_batch_size=64), + generation=GenerationConfig(max_new_tokens=1, max_batch_size=32), summary_filename="summary_cot.json", ), tokenizer=tokenizer, @@ -85,7 +87,8 @@ .parent.joinpath( f"../../../../data/out/splits/single_token_entropy/mmlu/llama_3b/group{j}_test.parquet" ) - .as_posix() + .as_posix(), + dataset_id=f"mmlu_cot_response_group{j}_test", ), tokenizer=tokenizer, ) From 664f293b96ebc02bbbf0df45d4a4b37694b129a8 Mon Sep 17 00:00:00 2001 From: Andrey Goncharov Date: Wed, 11 Mar 2026 15:36:44 +0300 Subject: [PATCH 10/20] Fix paths --- src/core/evaluation/multi_checkpoint_evaluator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/core/evaluation/multi_checkpoint_evaluator.py b/src/core/evaluation/multi_checkpoint_evaluator.py index 3912ea1..63495aa 100644 --- a/src/core/evaluation/multi_checkpoint_evaluator.py +++ b/src/core/evaluation/multi_checkpoint_evaluator.py @@ -56,7 +56,7 @@ def evaluate_all(self) -> list[tuple[str, list[EvaluationResult], float | None]] if self.config.base_model_id: logger.info(f"Evaluating base model {self.config.base_model_id} as epoch 0...") - base_out_path = str(self._out_path / "base_model") + base_out_path = str(self._out_path / "base_model" / "evals") base_config = EvaluatorConfig( model_path=self.config.base_model_id, eval_dataset=self.config.eval_dataset, @@ -78,7 +78,7 @@ def evaluate_all(self) -> list[tuple[str, list[EvaluationResult], float | None]] for ckpt_dir in checkpoint_dirs: ckpt_name = ckpt_dir.name - ckpt_out_path = str(self._out_path / ckpt_name) + ckpt_out_path = str(self._out_path / ckpt_name / "evals") logger.info(f"Evaluating {ckpt_name}...") From 1f8d294d93ce1d320dc0866a6bb80640df0fec9f Mon Sep 17 00:00:00 2001 From: Andrey Goncharov Date: Wed, 11 Mar 2026 12:58:50 +0000 Subject: [PATCH 11/20] Add torch compile --- .../evaluation/continuous_batch_generator.py | 18 +++++++++++++++++- src/core/evaluation/evaluator.py | 9 +++++++++ tests/test_continuous_batch_generator.py | 18 +++++++++--------- 3 files changed, 35 insertions(+), 10 deletions(-) diff --git a/src/core/evaluation/continuous_batch_generator.py b/src/core/evaluation/continuous_batch_generator.py index 24fb0ac..a06cd13 100644 --- a/src/core/evaluation/continuous_batch_generator.py +++ b/src/core/evaluation/continuous_batch_generator.py @@ -80,6 +80,7 @@ def __init__( self.pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id self.eos_token_id = tokenizer.eos_token_id + self._is_compiled = isinstance(model, 
torch._dynamo.eval_frame.OptimizedModule) @torch.no_grad() def generate(self, prompts: list[list[int]]) -> GenerationResult: @@ -183,6 +184,20 @@ def _batched_decode(self, slots: list[_Slot]) -> None: # position_ids: each slot's actual position (prompt_len + generated so far) position_ids = torch.tensor([[s.seq_position] for s in slots], device=device) + # Pad batch dimension to max_batch_size to avoid recompilation when compiled + pad_batch = self.max_batch_size - num_slots if self._is_compiled else 0 + if pad_batch > 0: + input_ids = F.pad(input_ids, (0, 0, 0, pad_batch), value=float(self.pad_token_id)) + attn_mask = F.pad(attn_mask, (0, 0, 0, pad_batch), value=0) + position_ids = F.pad(position_ids, (0, 0, 0, pad_batch), value=0) + for layer_idx in range(len(batched_cache)): + batched_cache.key_cache[layer_idx] = F.pad( + batched_cache.key_cache[layer_idx], (0, 0, 0, 0, 0, 0, 0, pad_batch) + ) + batched_cache.value_cache[layer_idx] = F.pad( + batched_cache.value_cache[layer_idx], (0, 0, 0, 0, 0, 0, 0, pad_batch) + ) + # Single forward() call outputs = self.model( input_ids=input_ids, @@ -200,7 +215,8 @@ def _batched_decode(self, slots: list[_Slot]) -> None: slot.seq_position += 1 # Split updated cache back to per-slot, trim padding - updated_splits = outputs.past_key_values.batch_split(num_slots, split_size=1) + total_batch = num_slots + pad_batch + updated_splits = outputs.past_key_values.batch_split(total_batch, split_size=1) for i, slot in enumerate(slots): valid_len = slot_cache_lens[i] + 1 # original cache len + 1 new token slot.cache = _trim_cache(updated_splits[i], valid_len) diff --git a/src/core/evaluation/evaluator.py b/src/core/evaluation/evaluator.py index 69c5102..3ccb50f 100644 --- a/src/core/evaluation/evaluator.py +++ b/src/core/evaluation/evaluator.py @@ -2,6 +2,7 @@ from pathlib import Path import pandas as pd +import torch from pydantic import BaseModel from pydraconf import PydraConfig from tqdm import tqdm @@ -20,6 +21,7 @@ class GenerationConfig(BaseModel): temperature: float = 0.0 top_p: float = 1.0 top_k: int = -1 + torch_compile: bool = True class EvaluatorConfig(PydraConfig): @@ -56,6 +58,13 @@ def evaluate(self) -> list[EvaluationResult]: model, tokenizer = self._load_model() model.eval() + if self.config.generation.torch_compile: + if not torch.cuda.is_available(): + logger.warning("torch_compile=True but CUDA not available — skipping compilation.") + else: + logger.info("Compiling model with torch.compile(dynamic=True)... 
First forward call will be slow.") + model = torch.compile(model, dynamic=True) + results: list[EvaluationResult] = [] for ds, cached in tqdm(zip(self._datasets, cached_results), total=len(self._datasets), desc="Datasets"): if cached is not None: diff --git a/tests/test_continuous_batch_generator.py b/tests/test_continuous_batch_generator.py index 7c3574e..83e488e 100644 --- a/tests/test_continuous_batch_generator.py +++ b/tests/test_continuous_batch_generator.py @@ -63,7 +63,7 @@ def test_single_prompt(self, model_and_tokenizer): max_new_tokens=MAX_NEW_TOKENS, max_batch_size=1, ) - [result] = gen.generate([prompt]) + [result] = gen.generate([prompt]).sequences expected = _hf_generate_greedy(model, tokenizer, prompt, MAX_NEW_TOKENS) assert result == expected, f"Mismatch:\n got: {result}\n expected: {expected}" @@ -81,7 +81,7 @@ def test_multiple_same_length_prompts(self, model_and_tokenizer): max_new_tokens=MAX_NEW_TOKENS, max_batch_size=4, ) - results = gen.generate(prompts) + results = gen.generate(prompts).sequences for i, (prompt, result) in enumerate(zip(prompts, results)): expected = _hf_generate_greedy(model, tokenizer, prompt, MAX_NEW_TOKENS) @@ -102,7 +102,7 @@ def test_variable_length_prompts(self, model_and_tokenizer): max_new_tokens=MAX_NEW_TOKENS, max_batch_size=4, ) - results = gen.generate(prompts) + results = gen.generate(prompts).sequences for i, (prompt, result) in enumerate(zip(prompts, results)): expected = _hf_generate_greedy(model, tokenizer, prompt, MAX_NEW_TOKENS) @@ -125,7 +125,7 @@ def test_more_prompts_than_batch_size(self, model_and_tokenizer): max_new_tokens=MAX_NEW_TOKENS, max_batch_size=2, # only 2 slots, 5 prompts → continuous batching ) - results = gen.generate(prompts) + results = gen.generate(prompts).sequences for i, (prompt, result) in enumerate(zip(prompts, results)): expected = _hf_generate_greedy(model, tokenizer, prompt, MAX_NEW_TOKENS) @@ -142,7 +142,7 @@ def test_max_new_tokens_cutoff(self, model_and_tokenizer): max_new_tokens=max_tokens, max_batch_size=1, ) - [result] = gen.generate([prompt]) + [result] = gen.generate([prompt]).sequences assert len(result) <= max_tokens, f"Generated {len(result)} tokens, max was {max_tokens}" @@ -161,7 +161,7 @@ def test_result_order_preserved(self, model_and_tokenizer): max_new_tokens=MAX_NEW_TOKENS, max_batch_size=2, ) - results = gen.generate(prompts) + results = gen.generate(prompts).sequences assert len(results) == len(prompts) for i, (prompt, result) in enumerate(zip(prompts, results)): @@ -182,7 +182,7 @@ def test_batch_size_1_same_as_sequential(self, model_and_tokenizer): max_new_tokens=MAX_NEW_TOKENS, max_batch_size=1, ) - results = gen.generate(prompts) + results = gen.generate(prompts).sequences for i, (prompt, result) in enumerate(zip(prompts, results)): expected = _hf_generate_greedy(model, tokenizer, prompt, MAX_NEW_TOKENS) @@ -227,7 +227,7 @@ def test_batched_vs_sequential_time(self, model_and_tokenizer, many_prompts): max_batch_size=1, ) start = time.perf_counter() - results_seq = gen_seq.generate(many_prompts) + results_seq = gen_seq.generate(many_prompts).sequences time_sequential = time.perf_counter() - start # Batched: batch_size=8 @@ -238,7 +238,7 @@ def test_batched_vs_sequential_time(self, model_and_tokenizer, many_prompts): max_batch_size=8, ) start = time.perf_counter() - results_batch = gen_batch.generate(many_prompts) + results_batch = gen_batch.generate(many_prompts).sequences time_batched = time.perf_counter() - start # Verify correctness: both must produce same results From 
408053e0c5f41bad73fcd7cd504cfc6587ce9293 Mon Sep 17 00:00:00 2001 From: Andrey Goncharov Date: Wed, 11 Mar 2026 13:05:33 +0000 Subject: [PATCH 12/20] Speed up matmuls --- src/core/evaluation/evaluator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/core/evaluation/evaluator.py b/src/core/evaluation/evaluator.py index 3ccb50f..3a43d0b 100644 --- a/src/core/evaluation/evaluator.py +++ b/src/core/evaluation/evaluator.py @@ -63,6 +63,7 @@ def evaluate(self) -> list[EvaluationResult]: logger.warning("torch_compile=True but CUDA not available — skipping compilation.") else: logger.info("Compiling model with torch.compile(dynamic=True)... First forward call will be slow.") + torch.set_float32_matmul_precision("high") model = torch.compile(model, dynamic=True) results: list[EvaluationResult] = [] From 0c79a8208ac91d9f4a51742c86bbe5b88c6b0a5b Mon Sep 17 00:00:00 2001 From: Andrey Goncharov Date: Thu, 12 Mar 2026 14:27:21 +0300 Subject: [PATCH 13/20] Upd hyperparameters --- .../sft_by_complexity_splits/mmlu/llama_3b.py | 4 +- .../sft_by_complexity_splits/mmlu/phi4mini.py | 103 ++++++++++++++++++ .../sft_by_complexity_splits/mmlu/qwen_3b.py | 103 ++++++++++++++++++ 3 files changed, 208 insertions(+), 2 deletions(-) create mode 100644 src/experiments/sft_by_complexity_splits/mmlu/phi4mini.py create mode 100644 src/experiments/sft_by_complexity_splits/mmlu/qwen_3b.py diff --git a/src/experiments/sft_by_complexity_splits/mmlu/llama_3b.py b/src/experiments/sft_by_complexity_splits/mmlu/llama_3b.py index 315beff..05ef309 100644 --- a/src/experiments/sft_by_complexity_splits/mmlu/llama_3b.py +++ b/src/experiments/sft_by_complexity_splits/mmlu/llama_3b.py @@ -41,7 +41,7 @@ tokenizer=tokenizer, ) ), - training_args=LoRATrainingArgs(num_train_epochs=20, per_device_train_batch_size=32), + training_args=LoRATrainingArgs(num_train_epochs=20, per_device_train_batch_size=8), save_schedule=[1, 2, 3, 4, 5, 6, 8, 10, 12, 15, 20], ), tokenizer=tokenizer, @@ -70,7 +70,7 @@ ], base_model_id=MODEL_NAME, generation=GenerationConfig(max_new_tokens=1, max_batch_size=32), - summary_filename="summary_cot.json", + summary_filename="summary_single_token.json", ), tokenizer=tokenizer, ) diff --git a/src/experiments/sft_by_complexity_splits/mmlu/phi4mini.py b/src/experiments/sft_by_complexity_splits/mmlu/phi4mini.py new file mode 100644 index 0000000..4bffed0 --- /dev/null +++ b/src/experiments/sft_by_complexity_splits/mmlu/phi4mini.py @@ -0,0 +1,103 @@ +from pathlib import Path + +from transformers import AutoTokenizer + +from core.datasets.causal_dataset_adapter import CausalDatasetAdapter +from core.datasets.mmlu.mmlu_cot_response_dataset import MMLUCoTResponseDataset +from core.datasets.mmlu.mmlu_single_token_response_dataset import MMLUSingleTokenResponseDataset, QADatasetConfig +from core.datasets.qa_dataset_adapter import QADatasetAdapter +from core.evaluation.multi_checkpoint_evaluator import ( + GenerationConfig, + MultiCheckpointEvaluator, + MultiCheckpointEvaluatorConfig, +) +from core.training.lora_trainer import LoRATrainer, LoRATrainerConfig, LoRATrainingArgs +from core.utils.logger import logger + +MODEL_NAME = "microsoft/Phi-4-mini-instruct" + +tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) + +for group in range(6): + logger.info(f"Training on group {group}...") + + trainer = LoRATrainer( + config=LoRATrainerConfig( + out_path=Path(__file__) + .parent.joinpath(f"../../../../artifacts/sft_by_complexity_splits/mmlu/phi4mini/group{group}") + .as_posix(), + model_id=MODEL_NAME, 
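Patches 11 and 12 introduce the compile path: torch.compile(model, dynamic=True) guarded by a CUDA check, TF32 matmuls via set_float32_matmul_precision("high"), and batch-dimension padding so the compiled graph always sees max_batch_size rows during decode. A minimal, self-contained sketch of that pattern, with nn.Linear standing in for the causal LM (the real evaluator passes the loaded HF model instead):

    import torch
    import torch.nn as nn

    torch.set_float32_matmul_precision("high")   # allow TF32 matmuls (patch 12)

    model = nn.Linear(8, 8)                      # stand-in for the causal LM
    compiled = torch.compile(model, dynamic=True)

    # Patch 11 detects an already-compiled model exactly this way, then pads
    # decode inputs up to max_batch_size so shapes stay stable across steps.
    assert isinstance(compiled, torch._dynamo.eval_frame.OptimizedModule)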
+ train_dataset=CausalDatasetAdapter( + dataset=MMLUSingleTokenResponseDataset( + config=QADatasetConfig( + path=Path(__file__) + .parent.joinpath( + f"../../../../data/out/splits/single_token_entropy/mmlu/phi4mini/group{group}_train.parquet" + ) + .as_posix(), + dataset_id=f"mmlu_single_token_response_group{group}_train", + ), + tokenizer=tokenizer, + ) + ), + training_args=LoRATrainingArgs(num_train_epochs=20, per_device_train_batch_size=32), + save_schedule=[1, 2, 3, 4, 5, 6, 8, 10, 12, 15, 20], + ), + tokenizer=tokenizer, + ) + trainer.train() + trainer.unload() + + single_token_evaluator = MultiCheckpointEvaluator( + config=MultiCheckpointEvaluatorConfig( + checkpoints_dir=trainer.config.out_path, + eval_dataset=[ + QADatasetAdapter( + dataset=MMLUSingleTokenResponseDataset( + config=QADatasetConfig( + path=Path(__file__) + .parent.joinpath( + f"../../../../data/out/splits/single_token_entropy/mmlu/phi4mini/group{j}_test.parquet" + ) + .as_posix(), + dataset_id=f"mmlu_single_token_response_group{j}_test", + ), + tokenizer=tokenizer, + ) + ) + for j in range(6) + ], + base_model_id=MODEL_NAME, + generation=GenerationConfig(max_new_tokens=1, max_batch_size=64), + summary_filename="summary_single_token.json", + ), + tokenizer=tokenizer, + ) + single_token_evaluator.evaluate_all() + + cot_evaluator = MultiCheckpointEvaluator( + config=MultiCheckpointEvaluatorConfig( + checkpoints_dir=trainer.config.out_path, + eval_dataset=[ + QADatasetAdapter( + dataset=MMLUCoTResponseDataset( + config=QADatasetConfig( + path=Path(__file__) + .parent.joinpath( + f"../../../../data/out/splits/single_token_entropy/mmlu/phi4mini/group{j}_test.parquet" + ) + .as_posix(), + dataset_id=f"mmlu_cot_response_group{j}_test", + ), + tokenizer=tokenizer, + ) + ) + for j in range(6) + ], + base_model_id=MODEL_NAME, + generation=GenerationConfig(max_new_tokens=4096, max_batch_size=8), + summary_filename="summary_cot.json", + ), + tokenizer=tokenizer, + ) + cot_evaluator.evaluate_all() diff --git a/src/experiments/sft_by_complexity_splits/mmlu/qwen_3b.py b/src/experiments/sft_by_complexity_splits/mmlu/qwen_3b.py new file mode 100644 index 0000000..5968d74 --- /dev/null +++ b/src/experiments/sft_by_complexity_splits/mmlu/qwen_3b.py @@ -0,0 +1,103 @@ +from pathlib import Path + +from transformers import AutoTokenizer + +from core.datasets.causal_dataset_adapter import CausalDatasetAdapter +from core.datasets.mmlu.mmlu_cot_response_dataset import MMLUCoTResponseDataset +from core.datasets.mmlu.mmlu_single_token_response_dataset import MMLUSingleTokenResponseDataset, QADatasetConfig +from core.datasets.qa_dataset_adapter import QADatasetAdapter +from core.evaluation.multi_checkpoint_evaluator import ( + GenerationConfig, + MultiCheckpointEvaluator, + MultiCheckpointEvaluatorConfig, +) +from core.training.lora_trainer import LoRATrainer, LoRATrainerConfig, LoRATrainingArgs +from core.utils.logger import logger + +MODEL_NAME = "Qwen/Qwen2.5-3B-Instruct" + +tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) + +for group in range(6): + logger.info(f"Training on group {group}...") + + trainer = LoRATrainer( + config=LoRATrainerConfig( + out_path=Path(__file__) + .parent.joinpath(f"../../../../artifacts/sft_by_complexity_splits/mmlu/qwen_3b/group{group}") + .as_posix(), + model_id=MODEL_NAME, + train_dataset=CausalDatasetAdapter( + dataset=MMLUSingleTokenResponseDataset( + config=QADatasetConfig( + path=Path(__file__) + .parent.joinpath( + 
f"../../../../data/out/splits/single_token_entropy/mmlu/qwen_3b/group{group}_train.parquet" + ) + .as_posix(), + dataset_id=f"mmlu_single_token_response_group{group}_train", + ), + tokenizer=tokenizer, + ) + ), + training_args=LoRATrainingArgs(num_train_epochs=20, per_device_train_batch_size=32), + save_schedule=[1, 2, 3, 4, 5, 6, 8, 10, 12, 15, 20], + ), + tokenizer=tokenizer, + ) + trainer.train() + trainer.unload() + + single_token_evaluator = MultiCheckpointEvaluator( + config=MultiCheckpointEvaluatorConfig( + checkpoints_dir=trainer.config.out_path, + eval_dataset=[ + QADatasetAdapter( + dataset=MMLUSingleTokenResponseDataset( + config=QADatasetConfig( + path=Path(__file__) + .parent.joinpath( + f"../../../../data/out/splits/single_token_entropy/mmlu/qwen_3b/group{j}_test.parquet" + ) + .as_posix(), + dataset_id=f"mmlu_single_token_response_group{j}_test", + ), + tokenizer=tokenizer, + ) + ) + for j in range(6) + ], + base_model_id=MODEL_NAME, + generation=GenerationConfig(max_new_tokens=1, max_batch_size=64), + summary_filename="summary_single_token.json", + ), + tokenizer=tokenizer, + ) + single_token_evaluator.evaluate_all() + + cot_evaluator = MultiCheckpointEvaluator( + config=MultiCheckpointEvaluatorConfig( + checkpoints_dir=trainer.config.out_path, + eval_dataset=[ + QADatasetAdapter( + dataset=MMLUCoTResponseDataset( + config=QADatasetConfig( + path=Path(__file__) + .parent.joinpath( + f"../../../../data/out/splits/single_token_entropy/mmlu/qwen_3b/group{j}_test.parquet" + ) + .as_posix(), + dataset_id=f"mmlu_cot_response_group{j}_test", + ), + tokenizer=tokenizer, + ) + ) + for j in range(6) + ], + base_model_id=MODEL_NAME, + generation=GenerationConfig(max_new_tokens=4096, max_batch_size=16), + summary_filename="summary_cot.json", + ), + tokenizer=tokenizer, + ) + cot_evaluator.evaluate_all() From 96931ed2ebc6035bd76787d8ffbd7a909a97badf Mon Sep 17 00:00:00 2001 From: Andrey Goncharov Date: Thu, 12 Mar 2026 12:18:37 +0000 Subject: [PATCH 14/20] Fix repetitive generation --- .../evaluation/continuous_batch_generator.py | 26 +++++++++--------- tests/test_continuous_batch_generator.py | 27 +++++++++++++++++++ 2 files changed, 40 insertions(+), 13 deletions(-) diff --git a/src/core/evaluation/continuous_batch_generator.py b/src/core/evaluation/continuous_batch_generator.py index a06cd13..4be0a51 100644 --- a/src/core/evaluation/continuous_batch_generator.py +++ b/src/core/evaluation/continuous_batch_generator.py @@ -39,16 +39,6 @@ def _right_pad_cache(cache: DynamicCache, target_len: int) -> DynamicCache: return padded -def _trim_cache(cache: DynamicCache, valid_len: int) -> DynamicCache: - """Trim KV cache to only the first valid_len entries along seq dim.""" - trimmed = DynamicCache() - for layer_idx in range(len(cache)): - k = cache.key_cache[layer_idx][:, :, :valid_len, :] - v = cache.value_cache[layer_idx][:, :, :valid_len, :] - trimmed.update(k, v, layer_idx) - return trimmed - - class ContinuousBatchGenerator: """Token-by-token generation with continuous batching via model.forward(). 
@@ -214,12 +204,22 @@ def _batched_decode(self, slots: list[_Slot]) -> None: slot.generated_ids.append(next_token.item()) slot.seq_position += 1 - # Split updated cache back to per-slot, trim padding + # Split updated cache back to per-slot, keeping original valid entries + new token total_batch = num_slots + pad_batch updated_splits = outputs.past_key_values.batch_split(total_batch, split_size=1) for i, slot in enumerate(slots): - valid_len = slot_cache_lens[i] + 1 # original cache len + 1 new token - slot.cache = _trim_cache(updated_splits[i], valid_len) + split_cache = updated_splits[i] + rebuilt = DynamicCache() + for layer_idx in range(len(split_cache)): + k = split_cache.key_cache[layer_idx] + v = split_cache.value_cache[layer_idx] + # DynamicCache.update() appends new KV at the end (position max_cache_len), + # not at slot's actual cache length. Keep [0:slot_cache_len] (valid original) + # and [max_cache_len:max_cache_len+1] (the new token), skip padding in between. + k_new = torch.cat([k[:, :, :slot_cache_lens[i], :], k[:, :, max_cache_len:max_cache_len + 1, :]], dim=2) + v_new = torch.cat([v[:, :, :slot_cache_lens[i], :], v[:, :, max_cache_len:max_cache_len + 1, :]], dim=2) + rebuilt.update(k_new, v_new, layer_idx) + slot.cache = rebuilt def _sample_token(self, logits: torch.Tensor) -> torch.Tensor: """Sample a single token from logits of shape [1, vocab_size].""" diff --git a/tests/test_continuous_batch_generator.py b/tests/test_continuous_batch_generator.py index 83e488e..288af6b 100644 --- a/tests/test_continuous_batch_generator.py +++ b/tests/test_continuous_batch_generator.py @@ -108,6 +108,33 @@ def test_variable_length_prompts(self, model_and_tokenizer): expected = _hf_generate_greedy(model, tokenizer, prompt, MAX_NEW_TOKENS) assert result == expected, f"Prompt {i} (len={len(prompt)}) mismatch:\n got: {result}\n expected: {expected}" + def test_variable_length_prompts_batched_pairwise(self, model_and_tokenizer): + """Prompts with very different lengths forced into the same batch (batch_size=2). + + This specifically tests that the KV cache trim logic correctly preserves + the new token's KV entry for the shorter slot when slots have different + cache lengths in the same batch. 
+ """ + model, tokenizer = model_and_tokenizer + prompts = [ + tokenizer.encode("Hi"), # very short + tokenizer.encode("The quick brown fox jumps over the lazy dog and then keeps running"), # much longer + ] + + gen = ContinuousBatchGenerator( + model=model, + tokenizer=tokenizer, + max_new_tokens=MAX_NEW_TOKENS, + max_batch_size=2, # forces both into the same batch + ) + results = gen.generate(prompts).sequences + + for i, (prompt, result) in enumerate(zip(prompts, results)): + expected = _hf_generate_greedy(model, tokenizer, prompt, MAX_NEW_TOKENS) + assert result == expected, ( + f"Prompt {i} (len={len(prompt)}) mismatch:\n got: {result}\n expected: {expected}" + ) + def test_more_prompts_than_batch_size(self, model_and_tokenizer): """Tests continuous batching: prompts > max_batch_size forces queuing.""" model, tokenizer = model_and_tokenizer From 699ede78f48de9e7555450a0070ada42f6951548 Mon Sep 17 00:00:00 2001 From: Andrey Goncharov Date: Thu, 12 Mar 2026 21:47:24 +0300 Subject: [PATCH 15/20] Fix CoT resposne MMLU dataset --- src/core/datasets/mmlu/mmlu_cot_response_dataset.py | 3 ++- src/experiments/sft_by_complexity_splits/mmlu/llama_3b.py | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/core/datasets/mmlu/mmlu_cot_response_dataset.py b/src/core/datasets/mmlu/mmlu_cot_response_dataset.py index 947cf57..3e23ee8 100644 --- a/src/core/datasets/mmlu/mmlu_cot_response_dataset.py +++ b/src/core/datasets/mmlu/mmlu_cot_response_dataset.py @@ -40,7 +40,8 @@ def verify_assistant_response(self, row: dict, assistant_response: str) -> tuple .lower() ) + correct_answer = str(row["answer"]).strip().lower() try: - return extracted_answer, self.assistant_response(row) == extracted_answer + return extracted_answer, correct_answer == extracted_answer except: return extracted_answer, False diff --git a/src/experiments/sft_by_complexity_splits/mmlu/llama_3b.py b/src/experiments/sft_by_complexity_splits/mmlu/llama_3b.py index 05ef309..57510f3 100644 --- a/src/experiments/sft_by_complexity_splits/mmlu/llama_3b.py +++ b/src/experiments/sft_by_complexity_splits/mmlu/llama_3b.py @@ -41,7 +41,7 @@ tokenizer=tokenizer, ) ), - training_args=LoRATrainingArgs(num_train_epochs=20, per_device_train_batch_size=8), + training_args=LoRATrainingArgs(num_train_epochs=20, per_device_train_batch_size=32), save_schedule=[1, 2, 3, 4, 5, 6, 8, 10, 12, 15, 20], ), tokenizer=tokenizer, @@ -69,7 +69,7 @@ for j in range(6) ], base_model_id=MODEL_NAME, - generation=GenerationConfig(max_new_tokens=1, max_batch_size=32), + generation=GenerationConfig(max_new_tokens=1, max_batch_size=64), summary_filename="summary_single_token.json", ), tokenizer=tokenizer, @@ -96,7 +96,7 @@ for j in range(6) ], base_model_id=MODEL_NAME, - generation=GenerationConfig(max_new_tokens=4096, max_batch_size=4), + generation=GenerationConfig(max_new_tokens=4096, max_batch_size=16), summary_filename="summary_cot.json", ), tokenizer=tokenizer, From 88a5ee0e411a31bf3e32ed14bc9067e205da4abe Mon Sep 17 00:00:00 2001 From: Andrey Goncharov Date: Fri, 13 Mar 2026 00:28:04 +0300 Subject: [PATCH 16/20] Upd hyperparameters --- src/experiments/sft_by_complexity_splits/mmlu/llama_3b.py | 2 +- src/experiments/sft_by_complexity_splits/mmlu/phi4mini.py | 2 +- src/experiments/sft_by_complexity_splits/mmlu/qwen_3b.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/experiments/sft_by_complexity_splits/mmlu/llama_3b.py b/src/experiments/sft_by_complexity_splits/mmlu/llama_3b.py index 57510f3..ae9cbc0 100644 --- 
a/src/experiments/sft_by_complexity_splits/mmlu/llama_3b.py +++ b/src/experiments/sft_by_complexity_splits/mmlu/llama_3b.py @@ -96,7 +96,7 @@ for j in range(6) ], base_model_id=MODEL_NAME, - generation=GenerationConfig(max_new_tokens=4096, max_batch_size=16), + generation=GenerationConfig(max_new_tokens=6000, max_batch_size=12), summary_filename="summary_cot.json", ), tokenizer=tokenizer, diff --git a/src/experiments/sft_by_complexity_splits/mmlu/phi4mini.py b/src/experiments/sft_by_complexity_splits/mmlu/phi4mini.py index 4bffed0..aefe74c 100644 --- a/src/experiments/sft_by_complexity_splits/mmlu/phi4mini.py +++ b/src/experiments/sft_by_complexity_splits/mmlu/phi4mini.py @@ -95,7 +95,7 @@ for j in range(6) ], base_model_id=MODEL_NAME, - generation=GenerationConfig(max_new_tokens=4096, max_batch_size=8), + generation=GenerationConfig(max_new_tokens=6000, max_batch_size=8), summary_filename="summary_cot.json", ), tokenizer=tokenizer, diff --git a/src/experiments/sft_by_complexity_splits/mmlu/qwen_3b.py b/src/experiments/sft_by_complexity_splits/mmlu/qwen_3b.py index 5968d74..e8365e4 100644 --- a/src/experiments/sft_by_complexity_splits/mmlu/qwen_3b.py +++ b/src/experiments/sft_by_complexity_splits/mmlu/qwen_3b.py @@ -68,7 +68,7 @@ for j in range(6) ], base_model_id=MODEL_NAME, - generation=GenerationConfig(max_new_tokens=1, max_batch_size=64), + generation=GenerationConfig(max_new_tokens=1, max_batch_size=16), summary_filename="summary_single_token.json", ), tokenizer=tokenizer, @@ -95,7 +95,7 @@ for j in range(6) ], base_model_id=MODEL_NAME, - generation=GenerationConfig(max_new_tokens=4096, max_batch_size=16), + generation=GenerationConfig(max_new_tokens=6000, max_batch_size=4), summary_filename="summary_cot.json", ), tokenizer=tokenizer, From c2444f386cfb3f1b7ebeae389846e8b63d4440fe Mon Sep 17 00:00:00 2001 From: Andrey Goncharov Date: Thu, 12 Mar 2026 21:30:02 +0000 Subject: [PATCH 17/20] Enhance continuous batching --- .../evaluation/continuous_batch_generator.py | 246 ++++++++++++------ 1 file changed, 160 insertions(+), 86 deletions(-) diff --git a/src/core/evaluation/continuous_batch_generator.py b/src/core/evaluation/continuous_batch_generator.py index 4be0a51..fabfbb8 100644 --- a/src/core/evaluation/continuous_batch_generator.py +++ b/src/core/evaluation/continuous_batch_generator.py @@ -1,10 +1,11 @@ from collections import deque from dataclasses import dataclass, field +from typing import Optional import torch -import torch.nn.functional as F from tqdm import tqdm from transformers import DynamicCache, PreTrainedModel, PreTrainedTokenizer +from transformers.cache_utils import _static_cache_update @dataclass @@ -18,36 +19,87 @@ class GenerationResult: class _Slot: index: int prompt_len: int + batch_idx: int generated_ids: list[int] = field(default_factory=list) - cache: DynamicCache = field(default_factory=DynamicCache) seq_position: int = 0 # total tokens seen = prompt_len + len(generated_ids) -def _right_pad_cache(cache: DynamicCache, target_len: int) -> DynamicCache: - """Pad KV cache tensors along seq dim (dim=-2) with zeros to target_len.""" - current_len = cache.get_seq_length() - if current_len == target_len: - return cache - pad_len = target_len - current_len - padded = DynamicCache() - for layer_idx in range(len(cache)): - k = cache.key_cache[layer_idx] # [1, H, T, D] - v = cache.value_cache[layer_idx] - k_pad = F.pad(k, (0, 0, 0, pad_len)) # pad dim=-2 - v_pad = F.pad(v, (0, 0, 0, pad_len)) - padded.update(k_pad, v_pad, layer_idx) - return padded +class 
_PreAllocatedBatchCache(DynamicCache): + """Pre-allocated KV cache that updates in-place via index_copy_. + + Subclasses DynamicCache so models take the DynamicCache code path in + _update_causal_mask (target_length = attention_mask.shape[-1]). + """ + + def __init__( + self, + num_layers: int, + num_kv_heads: int, + head_dim: int, + max_batch_size: int, + max_cache_len: int, + device: torch.device, + dtype: torch.dtype, + ) -> None: + super().__init__() + self._seen_tokens = 0 + self._active_seq_len = 0 + cache_shape = (max_batch_size, num_kv_heads, max_cache_len, head_dim) + for _ in range(num_layers): + k = torch.zeros(cache_shape, dtype=dtype, device=device) + v = torch.zeros(cache_shape, dtype=dtype, device=device) + torch._dynamo.mark_static_address(k) + torch._dynamo.mark_static_address(v) + self.key_cache.append(k) + self.value_cache.append(v) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict] = None, + ): + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + cache_position = cache_kwargs.get("cache_position") if cache_kwargs else None + _static_cache_update( + self.key_cache[layer_idx], + self.value_cache[layer_idx], + key_states, + value_states, + cache_position, + ) + # Return a view trimmed to _active_seq_len + 1 so attention only sees + # valid positions, not the full pre-allocated length. + seq_end = self._active_seq_len + 1 + return ( + self.key_cache[layer_idx][:, :, :seq_end, :], + self.value_cache[layer_idx][:, :, :seq_end, :], + ) + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + return self._active_seq_len + + def get_max_cache_shape(self): + return None + + def reset_slot(self, slot_idx: int) -> None: + for layer_idx in range(len(self)): + self.key_cache[layer_idx][slot_idx].zero_() + self.value_cache[layer_idx][slot_idx].zero_() class ContinuousBatchGenerator: """Token-by-token generation with continuous batching via model.forward(). - Maintains a pool of active slots. Empty slots are filled from a queue of - pending prompts. Each slot holds its own DynamicCache (batch_size=1). + Maintains a pool of active slots backed by a single pre-allocated KV cache. + Empty slots are filled from a queue of pending prompts. Prefill runs individually per prompt (no padding waste). Decode batches - all active slots into a single forward() call by padding KV caches to - equal length and using an attention mask to ignore padded positions. + all active slots into a single forward() call using an attention mask to + ignore padded positions. The KV cache is updated in-place with zero + tensor allocations per decode step. 
""" def __init__( @@ -70,7 +122,23 @@ def __init__( self.pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id self.eos_token_id = tokenizer.eos_token_id - self._is_compiled = isinstance(model, torch._dynamo.eval_frame.OptimizedModule) + + def _init_cache(self, max_seq_len: int) -> _PreAllocatedBatchCache: + config = self.model.config + num_layers = config.num_hidden_layers + num_kv_heads = getattr(config, "num_key_value_heads", None) or config.num_attention_heads + head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + device = self.model.device + dtype = self.model.dtype + return _PreAllocatedBatchCache( + num_layers=num_layers, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + max_batch_size=self.max_batch_size, + max_cache_len=max_seq_len, + device=device, + dtype=dtype, + ) @torch.no_grad() def generate(self, prompts: list[list[int]]) -> GenerationResult: @@ -88,13 +156,19 @@ def generate(self, prompts: list[list[int]]) -> GenerationResult: num_truncated = 0 pbar = tqdm(total=len(prompts), desc="Generating") + # Pre-allocate the batched KV cache + max_prompt_len = max(len(p) for p in prompts) + max_seq_len = max_prompt_len + self.max_new_tokens + self._cache = self._init_cache(max_seq_len) + self._valid_lens = [0] * self.max_batch_size + while queue or any(s is not None for s in active_slots): # FILL: prefill empty slots with new prompts (batch_size=1 each) for slot_idx in range(self.max_batch_size): if active_slots[slot_idx] is not None or not queue: continue prompt_idx, prompt_ids = queue.popleft() - active_slots[slot_idx] = self._prefill(prompt_idx, prompt_ids) + active_slots[slot_idx] = self._prefill(prompt_idx, prompt_ids, slot_idx) # Collect occupied slots occupied = [(i, s) for i, s in enumerate(active_slots) if s is not None] @@ -119,107 +193,107 @@ def generate(self, prompts: list[list[int]]) -> GenerationResult: sequences = [r if r is not None else [] for r in results] return GenerationResult(sequences=sequences, num_truncated=num_truncated, total=len(prompts)) - def _prefill(self, prompt_idx: int, prompt_ids: list[int]) -> _Slot: + def _prefill(self, prompt_idx: int, prompt_ids: list[int], batch_idx: int) -> _Slot: """Run the prefill forward pass to build KV cache and sample the first token.""" device = self.model.device input_ids = torch.tensor([prompt_ids], device=device) seq_len = len(prompt_ids) position_ids = torch.arange(seq_len, device=device).unsqueeze(0) cache_position = torch.arange(seq_len, device=device) - cache = DynamicCache() + # Use a temporary DynamicCache for prefill (runs once per prompt) + tmp_cache = DynamicCache() outputs = self.model( input_ids=input_ids, position_ids=position_ids, cache_position=cache_position, - past_key_values=cache, + past_key_values=tmp_cache, use_cache=True, ) + # Copy prefill KV into the pre-allocated cache at batch_idx + prefill_cache = outputs.past_key_values + for layer_idx in range(len(prefill_cache)): + k = prefill_cache.key_cache[layer_idx] # [1, H, seq_len, D] + v = prefill_cache.value_cache[layer_idx] + self._cache.key_cache[layer_idx][batch_idx, :, :seq_len, :] = k[0] + self._cache.value_cache[layer_idx][batch_idx, :, :seq_len, :] = v[0] + self._valid_lens[batch_idx] = seq_len + next_token = self._sample_token(outputs.logits[:, -1, :]) return _Slot( index=prompt_idx, prompt_len=seq_len, + batch_idx=batch_idx, generated_ids=[next_token.item()], - cache=outputs.past_key_values, seq_position=seq_len + 1, ) def 
_batched_decode(self, slots: list[_Slot]) -> None: - """Run a single batched decode step for all active slots.""" + """Run a single batched decode step for all active slots. + + Uses the pre-allocated cache — zero tensor allocations per step. + """ device = self.model.device - num_slots = len(slots) - - # Cache lengths before padding (needed for trim after forward) - slot_cache_lens = [s.cache.get_seq_length() for s in slots] - max_cache_len = max(slot_cache_lens) - - # Pad each slot's KV cache to max_cache_len, then merge into batched cache - padded_caches = [_right_pad_cache(s.cache, max_cache_len) for s in slots] - batched_cache = DynamicCache.from_batch_splits(padded_caches) - - # input_ids: last generated token per slot [num_slots, 1] - input_ids = torch.tensor([[s.generated_ids[-1]] for s in slots], device=device) - - # attention_mask: [num_slots, max_cache_len + 1] (+1 for the new token) - attn_mask = torch.zeros(num_slots, max_cache_len + 1, dtype=torch.long, device=device) - for i, slot in enumerate(slots): - attn_mask[i, : slot_cache_lens[i]] = 1 # valid cached positions - attn_mask[i, max_cache_len] = 1 # the new token position (appended at end) - - # cache_position: shared across batch, points to where the new KV is appended - cache_position = torch.tensor([max_cache_len], device=device) - - # position_ids: each slot's actual position (prompt_len + generated so far) - position_ids = torch.tensor([[s.seq_position] for s in slots], device=device) - - # Pad batch dimension to max_batch_size to avoid recompilation when compiled - pad_batch = self.max_batch_size - num_slots if self._is_compiled else 0 - if pad_batch > 0: - input_ids = F.pad(input_ids, (0, 0, 0, pad_batch), value=float(self.pad_token_id)) - attn_mask = F.pad(attn_mask, (0, 0, 0, pad_batch), value=0) - position_ids = F.pad(position_ids, (0, 0, 0, pad_batch), value=0) - for layer_idx in range(len(batched_cache)): - batched_cache.key_cache[layer_idx] = F.pad( - batched_cache.key_cache[layer_idx], (0, 0, 0, 0, 0, 0, 0, pad_batch) - ) - batched_cache.value_cache[layer_idx] = F.pad( - batched_cache.value_cache[layer_idx], (0, 0, 0, 0, 0, 0, 0, pad_batch) - ) - - # Single forward() call + + # Determine shared write position + max_active_len = max(self._valid_lens[s.batch_idx] for s in slots) + self._cache._active_seq_len = max_active_len + + # Build input_ids [max_batch_size, 1] + input_ids = torch.full( + (self.max_batch_size, 1), self.pad_token_id, dtype=torch.long, device=device + ) + for slot in slots: + input_ids[slot.batch_idx, 0] = slot.generated_ids[-1] + + # Build attention_mask [max_batch_size, max_active_len + 1] + attn_mask = torch.zeros( + self.max_batch_size, max_active_len + 1, dtype=torch.long, device=device + ) + for slot in slots: + valid_len = self._valid_lens[slot.batch_idx] + attn_mask[slot.batch_idx, :valid_len] = 1 # valid cached positions + attn_mask[slot.batch_idx, max_active_len] = 1 # the new token position + + # cache_position: shared write position + cache_position = torch.tensor([max_active_len], device=device) + + # position_ids [max_batch_size, 1] + position_ids = torch.zeros(self.max_batch_size, 1, dtype=torch.long, device=device) + for slot in slots: + position_ids[slot.batch_idx, 0] = slot.seq_position + + # Single forward() call — model writes new KV in-place at cache_position outputs = self.model( input_ids=input_ids, attention_mask=attn_mask, position_ids=position_ids, cache_position=cache_position, - past_key_values=batched_cache, + past_key_values=self._cache, use_cache=True, ) # 
Sample next token per slot - for i, slot in enumerate(slots): - next_token = self._sample_token(outputs.logits[i : i + 1, -1, :]) + for slot in slots: + next_token = self._sample_token(outputs.logits[slot.batch_idx : slot.batch_idx + 1, -1, :]) slot.generated_ids.append(next_token.item()) slot.seq_position += 1 - # Split updated cache back to per-slot, keeping original valid entries + new token - total_batch = num_slots + pad_batch - updated_splits = outputs.past_key_values.batch_split(total_batch, split_size=1) - for i, slot in enumerate(slots): - split_cache = updated_splits[i] - rebuilt = DynamicCache() - for layer_idx in range(len(split_cache)): - k = split_cache.key_cache[layer_idx] - v = split_cache.value_cache[layer_idx] - # DynamicCache.update() appends new KV at the end (position max_cache_len), - # not at slot's actual cache length. Keep [0:slot_cache_len] (valid original) - # and [max_cache_len:max_cache_len+1] (the new token), skip padding in between. - k_new = torch.cat([k[:, :, :slot_cache_lens[i], :], k[:, :, max_cache_len:max_cache_len + 1, :]], dim=2) - v_new = torch.cat([v[:, :, :slot_cache_lens[i], :], v[:, :, max_cache_len:max_cache_len + 1, :]], dim=2) - rebuilt.update(k_new, v_new, layer_idx) - slot.cache = rebuilt + # Compact: move new KV from shared write position to each slot's actual position + num_layers = len(self._cache) + for slot in slots: + valid_len = self._valid_lens[slot.batch_idx] + if valid_len < max_active_len: + for layer_idx in range(num_layers): + self._cache.key_cache[layer_idx][slot.batch_idx, :, valid_len, :] = ( + self._cache.key_cache[layer_idx][slot.batch_idx, :, max_active_len, :] + ) + self._cache.value_cache[layer_idx][slot.batch_idx, :, valid_len, :] = ( + self._cache.value_cache[layer_idx][slot.batch_idx, :, max_active_len, :] + ) + self._valid_lens[slot.batch_idx] = valid_len + 1 def _sample_token(self, logits: torch.Tensor) -> torch.Tensor: """Sample a single token from logits of shape [1, vocab_size].""" From 9c0b0f6b5e65ffe211df13a5394cc27dbc88b7f8 Mon Sep 17 00:00:00 2001 From: Andrey Goncharov Date: Thu, 12 Mar 2026 21:31:14 +0000 Subject: [PATCH 18/20] Add logs --- src/core/evaluation/evaluator.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/core/evaluation/evaluator.py b/src/core/evaluation/evaluator.py index 3a43d0b..a5b2329 100644 --- a/src/core/evaluation/evaluator.py +++ b/src/core/evaluation/evaluator.py @@ -79,7 +79,9 @@ def _evaluate_single(self, eval_dataset: QADatasetAdapter, model, tokenizer) -> ds = eval_dataset.process_dataset() prompts = [row["input_ids"] for row in ds] - logger.info(f"Evaluating {len(prompts)} samples with model from {self.config.model_path}") + logger.info( + f"Evaluating {len(prompts)} samples with model from {self.config.model_path} for dataset {eval_dataset.dataset.dataset_id}..." 
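The key trick in patch 17 is that update() writes into a buffer pre-allocated at max_cache_len but returns a view trimmed to the currently valid length, so attention never attends over the unused tail. A toy illustration with plain tensors (no model), using the same [batch, heads, seq, dim] layout:

    import torch

    max_batch, heads, max_len, dim = 4, 2, 16, 8
    k_cache = torch.zeros(max_batch, heads, max_len, dim)   # allocated once

    # prefill slot 0 with a 5-token prompt, writing in place
    k_cache[0, :, :5, :] = torch.randn(2, 5, dim)

    # a decode step writes one token at the shared position; attention then
    # only sees the view trimmed to the valid prefix, not all 16 positions
    seq_end = 6
    view = k_cache[:, :, :seq_end, :]
    print(view.shape)        # torch.Size([4, 2, 6, 8])

The buffers are also registered with torch._dynamo.mark_static_address, which tells the compiler their storage address stays fixed between steps.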
+ ) generator = ContinuousBatchGenerator( model=model, From b68608b686f68b0200f18de4dd05cc4b8583f547 Mon Sep 17 00:00:00 2001 From: Andrey Goncharov Date: Thu, 12 Mar 2026 21:42:31 +0000 Subject: [PATCH 19/20] Fix pre-fill --- .../evaluation/continuous_batch_generator.py | 33 ++++++++++++------- 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/src/core/evaluation/continuous_batch_generator.py b/src/core/evaluation/continuous_batch_generator.py index fabfbb8..a6fc942 100644 --- a/src/core/evaluation/continuous_batch_generator.py +++ b/src/core/evaluation/continuous_batch_generator.py @@ -194,30 +194,39 @@ def generate(self, prompts: list[list[int]]) -> GenerationResult: return GenerationResult(sequences=sequences, num_truncated=num_truncated, total=len(prompts)) def _prefill(self, prompt_idx: int, prompt_ids: list[int], batch_idx: int) -> _Slot: - """Run the prefill forward pass to build KV cache and sample the first token.""" + """Run the prefill forward pass to build KV cache and sample the first token. + + Prefills directly into a slice of the pre-allocated cache to avoid + allocating a temporary DynamicCache. + """ device = self.model.device input_ids = torch.tensor([prompt_ids], device=device) seq_len = len(prompt_ids) position_ids = torch.arange(seq_len, device=device).unsqueeze(0) cache_position = torch.arange(seq_len, device=device) - # Use a temporary DynamicCache for prefill (runs once per prompt) - tmp_cache = DynamicCache() + # Build a single-slot view into the pre-allocated cache so the model + # writes KV directly into the right batch row. + prefill_cache = _PreAllocatedBatchCache.__new__(_PreAllocatedBatchCache) + DynamicCache.__init__(prefill_cache) + prefill_cache._seen_tokens = 0 + prefill_cache._active_seq_len = seq_len - 1 # update() will see seq_end = seq_len + prefill_cache.key_cache = [ + self._cache.key_cache[l][batch_idx : batch_idx + 1] + for l in range(len(self._cache)) + ] + prefill_cache.value_cache = [ + self._cache.value_cache[l][batch_idx : batch_idx + 1] + for l in range(len(self._cache)) + ] + outputs = self.model( input_ids=input_ids, position_ids=position_ids, cache_position=cache_position, - past_key_values=tmp_cache, + past_key_values=prefill_cache, use_cache=True, ) - - # Copy prefill KV into the pre-allocated cache at batch_idx - prefill_cache = outputs.past_key_values - for layer_idx in range(len(prefill_cache)): - k = prefill_cache.key_cache[layer_idx] # [1, H, seq_len, D] - v = prefill_cache.value_cache[layer_idx] - self._cache.key_cache[layer_idx][batch_idx, :, :seq_len, :] = k[0] - self._cache.value_cache[layer_idx][batch_idx, :, :seq_len, :] = v[0] self._valid_lens[batch_idx] = seq_len next_token = self._sample_token(outputs.logits[:, -1, :]) From 3be2a91209248abf7c8a3da7481a733d11196332 Mon Sep 17 00:00:00 2001 From: Andrey Goncharov Date: Thu, 12 Mar 2026 23:04:03 +0000 Subject: [PATCH 20/20] IMprove perf --- .../evaluation/continuous_batch_generator.py | 64 +++++++++++-------- src/core/evaluation/evaluator.py | 15 ++++- 2 files changed, 50 insertions(+), 29 deletions(-) diff --git a/src/core/evaluation/continuous_batch_generator.py b/src/core/evaluation/continuous_batch_generator.py index a6fc942..21727b1 100644 --- a/src/core/evaluation/continuous_batch_generator.py +++ b/src/core/evaluation/continuous_batch_generator.py @@ -44,6 +44,8 @@ def __init__( super().__init__() self._seen_tokens = 0 self._active_seq_len = 0 + self._per_row_cache_positions: Optional[torch.Tensor] = None + self._batch_indices = 
torch.arange(max_batch_size, device=device) cache_shape = (max_batch_size, num_kv_heads, max_cache_len, head_dim) for _ in range(num_layers): k = torch.zeros(cache_shape, dtype=dtype, device=device) @@ -62,14 +64,24 @@ def update( ): if layer_idx == 0: self._seen_tokens += key_states.shape[-2] - cache_position = cache_kwargs.get("cache_position") if cache_kwargs else None - _static_cache_update( - self.key_cache[layer_idx], - self.value_cache[layer_idx], - key_states, - value_states, - cache_position, - ) + + if self._per_row_cache_positions is not None: + # Per-slot decode: each batch row writes KV to its own position. + k_cache = self.key_cache[layer_idx] + v_cache = self.value_cache[layer_idx] + k_cache[self._batch_indices, :, self._per_row_cache_positions, :] = key_states[:, :, 0, :] + v_cache[self._batch_indices, :, self._per_row_cache_positions, :] = value_states[:, :, 0, :] + else: + # Standard path (prefill): use shared cache_position. + cache_position = cache_kwargs.get("cache_position") if cache_kwargs else None + _static_cache_update( + self.key_cache[layer_idx], + self.value_cache[layer_idx], + key_states, + value_states, + cache_position, + ) + # Return a view trimmed to _active_seq_len + 1 so attention only sees # valid positions, not the full pre-allocated length. seq_end = self._active_seq_len + 1 @@ -211,6 +223,7 @@ def _prefill(self, prompt_idx: int, prompt_ids: list[int], batch_idx: int) -> _S DynamicCache.__init__(prefill_cache) prefill_cache._seen_tokens = 0 prefill_cache._active_seq_len = seq_len - 1 # update() will see seq_end = seq_len + prefill_cache._per_row_cache_positions = None prefill_cache.key_cache = [ self._cache.key_cache[l][batch_idx : batch_idx + 1] for l in range(len(self._cache)) @@ -242,14 +255,21 @@ def _prefill(self, prompt_idx: int, prompt_ids: list[int], batch_idx: int) -> _S def _batched_decode(self, slots: list[_Slot]) -> None: """Run a single batched decode step for all active slots. + Each slot writes KV to its own cache position (no shared write + compact). Uses the pre-allocated cache — zero tensor allocations per step. """ device = self.model.device - # Determine shared write position + # max_active_len determines attention mask width and cache view size. max_active_len = max(self._valid_lens[s.batch_idx] for s in slots) self._cache._active_seq_len = max_active_len + # Set per-row cache positions so each slot writes KV at its own valid_len. 
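The per-row write in patch 20 relies on advanced indexing: pairing a batch-index tensor with a per-row position tensor addresses one (row, position) cell per slot, so every slot lands at its own valid_len in a single assignment and the old compaction loop disappears. A toy demo of just that indexing pattern (the _batched_decode hunk continues right after this aside):

    import torch

    max_batch, heads, max_len, dim = 3, 2, 8, 4
    k_cache = torch.zeros(max_batch, heads, max_len, dim)

    batch_indices = torch.arange(max_batch)              # [0, 1, 2]
    per_row_positions = torch.tensor([5, 2, 7])          # each slot's own valid_len
    key_states = torch.randn(max_batch, heads, 1, dim)   # one new token per row

    # identical indexing to _PreAllocatedBatchCache.update() in this patch
    k_cache[batch_indices, :, per_row_positions, :] = key_states[:, :, 0, :]

    print(k_cache[1, 0, 2].abs().sum() > 0)              # tensor(True): row 1 wrote at position 2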
+ per_row_positions = torch.zeros(self.max_batch_size, dtype=torch.long, device=device) + for slot in slots: + per_row_positions[slot.batch_idx] = self._valid_lens[slot.batch_idx] + self._cache._per_row_cache_positions = per_row_positions + # Build input_ids [max_batch_size, 1] input_ids = torch.full( (self.max_batch_size, 1), self.pad_token_id, dtype=torch.long, device=device @@ -264,9 +284,9 @@ def _batched_decode(self, slots: list[_Slot]) -> None: for slot in slots: valid_len = self._valid_lens[slot.batch_idx] attn_mask[slot.batch_idx, :valid_len] = 1 # valid cached positions - attn_mask[slot.batch_idx, max_active_len] = 1 # the new token position + attn_mask[slot.batch_idx, valid_len] = 1 # new token at slot's own position - # cache_position: shared write position + # cache_position: max_active_len for correct causal mask sizing cache_position = torch.tensor([max_active_len], device=device) # position_ids [max_batch_size, 1] @@ -274,7 +294,7 @@ def _batched_decode(self, slots: list[_Slot]) -> None: for slot in slots: position_ids[slot.batch_idx, 0] = slot.seq_position - # Single forward() call — model writes new KV in-place at cache_position + # Single forward() call — model writes new KV via per-row cache positions outputs = self.model( input_ids=input_ids, attention_mask=attn_mask, @@ -284,25 +304,15 @@ def _batched_decode(self, slots: list[_Slot]) -> None: use_cache=True, ) - # Sample next token per slot + # Clear per-row positions (revert to standard path for prefill) + self._cache._per_row_cache_positions = None + + # Sample next token per slot and advance valid_lens for slot in slots: next_token = self._sample_token(outputs.logits[slot.batch_idx : slot.batch_idx + 1, -1, :]) slot.generated_ids.append(next_token.item()) slot.seq_position += 1 - - # Compact: move new KV from shared write position to each slot's actual position - num_layers = len(self._cache) - for slot in slots: - valid_len = self._valid_lens[slot.batch_idx] - if valid_len < max_active_len: - for layer_idx in range(num_layers): - self._cache.key_cache[layer_idx][slot.batch_idx, :, valid_len, :] = ( - self._cache.key_cache[layer_idx][slot.batch_idx, :, max_active_len, :] - ) - self._cache.value_cache[layer_idx][slot.batch_idx, :, valid_len, :] = ( - self._cache.value_cache[layer_idx][slot.batch_idx, :, max_active_len, :] - ) - self._valid_lens[slot.batch_idx] = valid_len + 1 + self._valid_lens[slot.batch_idx] += 1 def _sample_token(self, logits: torch.Tensor) -> torch.Tensor: """Sample a single token from logits of shape [1, vocab_size].""" diff --git a/src/core/evaluation/evaluator.py b/src/core/evaluation/evaluator.py index a5b2329..3df891e 100644 --- a/src/core/evaluation/evaluator.py +++ b/src/core/evaluation/evaluator.py @@ -22,6 +22,7 @@ class GenerationConfig(BaseModel): top_p: float = 1.0 top_k: int = -1 torch_compile: bool = True + attn_implementation: str | None = "flash_attention_2" class EvaluatorConfig(PydraConfig): @@ -179,7 +180,12 @@ def _load_model(self): return self._load_lora_model(model_path, adapter_config) logger.info(f"Loading model from {self.config.model_path}") - model = AutoModelForCausalLM.from_pretrained(self.config.model_path, device_map=DEVICE_MAP) + model = AutoModelForCausalLM.from_pretrained( + self.config.model_path, + device_map=DEVICE_MAP, + torch_dtype=torch.bfloat16, + attn_implementation=self.config.generation.attn_implementation, + ) if not self.tokenizer: self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_path) return model, self.tokenizer @@ -195,7 +201,12 @@ 
def _load_lora_model(self, model_path: Path, adapter_config: Path): raise ValueError(f"adapter_config.json at {adapter_config} missing 'base_model_name_or_path'") logger.info(f"Loading LoRA model: base={base_model_id}, adapter={model_path}") - base_model = AutoModelForCausalLM.from_pretrained(base_model_id, device_map=DEVICE_MAP) + base_model = AutoModelForCausalLM.from_pretrained( + base_model_id, + device_map=DEVICE_MAP, + torch_dtype=torch.bfloat16, + attn_implementation=self.config.generation.attn_implementation, + ) model = PeftModel.from_pretrained(base_model, str(model_path)) if not self.tokenizer: self.tokenizer = AutoTokenizer.from_pretrained(base_model_id)
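Patch 20 also loads models in bfloat16 with attn_implementation defaulting to "flash_attention_2". That default only works when the flash-attn package is installed and the GPU supports it; below is a hedged sketch of a guarded loader that falls back to PyTorch's SDPA kernel otherwise. The helper name and the try/except fallback are assumptions, not part of the series; the patches themselves pass the config value straight through:

    import torch
    from transformers import AutoModelForCausalLM

    def load_eval_model(model_id: str, attn_implementation: str = "flash_attention_2"):
        try:
            return AutoModelForCausalLM.from_pretrained(
                model_id,
                device_map="auto",
                torch_dtype=torch.bfloat16,
                attn_implementation=attn_implementation,
            )
        except (ImportError, ValueError):
            # flash-attn missing or unsupported hardware: fall back to SDPA
            return AutoModelForCausalLM.from_pretrained(
                model_id,
                device_map="auto",
                torch_dtype=torch.bfloat16,
                attn_implementation="sdpa",
            )

bfloat16 also matters for the pre-allocated cache, since _init_cache sizes its buffers with model.dtype: holding max_batch_size * max_seq_len KV slots at half precision halves the cache footprint relative to float32.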