
Validation Perplexity Skewed by Unweighted Batch Averaging + attention_mask Overlap Bugs #36

Description

@david-thrower


Labels: bug, training, evaluation, dataset
Priority: High


Summary

Validation perplexity is not computed as a true corpus-level metric. Three related bugs cause the evaluation and data pipelines to (1) average batch losses equally regardless of valid token count, (2) feed incorrect attention_mask values to the recurrent layers for overlap-masked chunks, and (3) build divergent masks in the streaming/prechunked path. These issues can inflate or destabilize reported validation PPL and create hidden-state inconsistency during evaluation.


Bug 1: evaluate() Uses Unweighted Batch-Mean Average

File: trainer.py (evaluate(), ~lines 583–629)

Problem:
outputs["loss"] is already a mean over valid tokens within each batch. The current code sums these means and divides by num_batches, so every batch gets equal weight: a batch that is 90% padding counts as much as one that is 10% padding, and each token in the sparse batch is over-weighted relative to tokens in the full one. The result is not corpus-level perplexity.
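A toy calculation makes the skew concrete. The numbers are invented for illustration (one batch with 900 valid tokens at mean loss 4.0, one with 100 valid tokens at mean loss 6.0), and perplexity is assumed to be exp of the mean natural-log loss, which is what compute_perplexity is presumed to do.

```python
import math

# Hypothetical per-batch results: (mean loss over valid tokens, number of valid tokens).
batches = [(4.0, 900), (6.0, 100)]

# Current behaviour: every batch counts equally, regardless of size.
unweighted = sum(loss for loss, _ in batches) / len(batches)

# Corpus-level: turn each batch mean back into a sum, then divide by total tokens.
total_loss = sum(loss * n for loss, n in batches)
total_tokens = sum(n for _, n in batches)
weighted = total_loss / total_tokens

print(f"unweighted loss={unweighted:.2f} ppl={math.exp(unweighted):.1f}")
print(f"weighted   loss={weighted:.2f} ppl={math.exp(weighted):.1f}")
```

Here the unweighted loss is 5.0 (PPL about 148) while the token-weighted loss is 4.2 (PPL about 67), so a small, sparse batch can dominate the reported number.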

Current (incorrect):

total_loss = 0.0
num_batches = 0
for batch in pbar:
    ...
    loss = outputs["loss"]
    if not (torch.isnan(loss) or torch.isinf(loss)):
        total_loss += loss.item()
        num_batches += 1
        avg = total_loss / max(num_batches, 1)
        pbar.set_postfix({"loss": f"{avg:.4f}", "ppl": f"{compute_perplexity(avg):.2f}"})

avg_loss = total_loss / max(num_batches, 1)
return {"loss": avg_loss, "perplexity": compute_perplexity(avg_loss)}

Fix: Convert each batch mean back to a sum by multiplying it by the batch's valid-token count, accumulate those sums, and divide by the total number of valid tokens:

total_loss = 0.0
total_tokens = 0
for batch in pbar:
    ...
    loss = outputs["loss"]
    if not (torch.isnan(loss) or torch.isinf(loss)):
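        # Count only positions that actually enter the loss. Assumes the model shifts
        # labels internally (logits[:, :-1] vs labels[:, 1:]), so column 0 is never scored.
        valid_tokens = (labels[..., 1:] != -100).sum().item()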
        total_loss += loss.item() * valid_tokens
        total_tokens += valid_tokens

        avg = total_loss / max(total_tokens, 1)
        pbar.set_postfix({
            "loss": f"{avg:.4f}",
            "ppl": f"{compute_perplexity(avg):.2f}",
            "tok": f"{total_tokens:,}",
        })

avg_loss = total_loss / max(total_tokens, 1)
return {"loss": avg_loss, "perplexity": compute_perplexity(avg_loss)}
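The same accumulation could be factored into a small helper shared by evaluate() and, per the optional checklist item, train_epoch(). This is a sketch under stated assumptions: TokenWeightedMean, count_loss_tokens and IGNORE_INDEX are hypothetical names, perplexity is taken as exp of the mean loss, and count_loss_tokens assumes the model shifts labels internally, so the first label of each row is never scored and is not counted.

```python
import math
from dataclasses import dataclass
from typing import Sequence

IGNORE_INDEX = -100


def count_loss_tokens(labels: Sequence[Sequence[int]], shifted: bool = True) -> int:
    """Count label positions that actually enter the loss.

    With shifted=True (assumed: model compares logits[:, :-1] with labels[:, 1:]),
    column 0 of each row can never be scored, so it is skipped.
    """
    start = 1 if shifted else 0
    return sum(1 for row in labels for tok in row[start:] if tok != IGNORE_INDEX)


@dataclass
class TokenWeightedMean:
    """Running corpus-level mean of a per-token loss."""

    total_loss: float = 0.0
    total_tokens: int = 0

    def update(self, batch_mean_loss: float, n_tokens: int) -> None:
        # Skip empty batches and NaN/inf losses, mirroring the guard in evaluate().
        if n_tokens <= 0 or not math.isfinite(batch_mean_loss):
            return
        self.total_loss += batch_mean_loss * n_tokens  # batch mean -> batch sum
        self.total_tokens += n_tokens

    @property
    def mean(self) -> float:
        return self.total_loss / max(self.total_tokens, 1)

    @property
    def perplexity(self) -> float:
        return math.exp(self.mean)


# Toy usage: row 0 starts with overlap (-100), row 1 ends with padding (-100).
labels = [[-100, -100, 5, 6], [7, 8, -100, -100]]
meter = TokenWeightedMean()
meter.update(batch_mean_loss=2.0, n_tokens=count_loss_tokens(labels))  # 3 tokens
```

Keeping the sum and the token count separate (rather than a running average) also makes it straightforward to all-reduce both values across workers before dividing, if evaluation is ever distributed.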

Bug 2: HelixDataset._make_sample() Only Masks Trailing Padding

File: dataset.py (_make_sample())

Problem:
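The method is meant to count trailing -100 values (padding) to build attention_mask, but the generator expression below counts every -100 in labels, including the head overlap that HelixDataset.__getitem__() masks in sliding-window chunks (labels[:warmup_len] = -100), and then places all of those zeros at the tail. For an overlapped chunk the mask therefore tells the recurrent model to attend to overlap tokens that the loss ignores, and also masks trailing tokens that do contribute to the loss, creating hidden-state inconsistency.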

Current (incorrect):

def _make_sample(self, chunk, labels, is_natural_stop):
    ...
    pad_len = sum(1 for tok in reversed(labels_t.tolist()) if tok == -100)
    attention_mask = torch.cat([
        torch.ones(self.seq_len - pad_len, dtype=torch.long),
        torch.zeros(pad_len, dtype=torch.long),
    ])
    ...
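The counting expression can be checked in isolation. The sketch below reproduces its logic on plain lists; current_mask, trailing_pad_len and the toy chunk (seq_len 8, three warmup positions, two padding positions) are hypothetical and only for illustration.

```python
IGNORE_INDEX = -100


def current_mask(labels):
    # Same counting logic as the snippet above: the generator counts *every*
    # -100 in the sequence, not only the trailing run.
    pad_len = sum(1 for tok in reversed(labels) if tok == IGNORE_INDEX)
    return [1] * (len(labels) - pad_len) + [0] * pad_len


def trailing_pad_len(labels):
    # What "count trailing -100 values" would require: stop at the first real label.
    n = 0
    for tok in reversed(labels):
        if tok != IGNORE_INDEX:
            break
        n += 1
    return n


# Hypothetical overlapped chunk: 3 warmup positions at the head, 2 pad positions at the tail.
labels = [-100, -100, -100, 11, 12, 13, -100, -100]
print(current_mask(labels))                       # [1, 1, 1, 0, 0, 0, 0, 0]
print(trailing_pad_len(labels))                   # 2
print([int(t != IGNORE_INDEX) for t in labels])   # [0, 0, 0, 1, 1, 1, 0, 0]
```

The current logic attends to the three overlap tokens and masks the three tokens that are actually scored, which is the inverse of the label-derived mask.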

Fix: Derive attention_mask directly from labels so any -100 position is masked:

def _make_sample(self, chunk, labels, is_natural_stop):
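    # Assumes chunk and labels arrive already padded to seq_len, with -100 on padded label positions.
    input_ids = torch.tensor(chunk[:self.seq_len], dtype=torch.long)
    labels_t = torch.tensor(labels[:self.seq_len], dtype=torch.long)
    # 0 wherever the loss ignores the position (head overlap or tail padding).
    attention_mask = (labels_t != -100).long()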
    return {
        "input_ids": input_ids,
        "labels": labels_t,
        "attention_mask": attention_mask,
        "is_natural_stop": torch.tensor(is_natural_stop, dtype=torch.bool),
    }
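A cheap regression test could enforce the invariant this fix establishes (attention_mask equals labels != -100 at every position) on any path that produces samples. mask_label_mismatches is a hypothetical helper; it only assumes each sample is a mapping with "labels" and "attention_mask" sequences, so it works on plain lists and should also work element by element on 1-D tensors.

```python
from typing import Iterable, List, Mapping, Sequence, Tuple

IGNORE_INDEX = -100


def mask_label_mismatches(
    samples: Iterable[Mapping[str, Sequence[int]]],
) -> List[Tuple[int, int]]:
    """Return (sample_index, position) pairs where attention_mask disagrees with labels.

    A length mismatch between the two sequences is reported as position -1.
    """
    bad = []
    for i, sample in enumerate(samples):
        labels = list(sample["labels"])
        mask = list(sample["attention_mask"])
        if len(mask) != len(labels):
            bad.append((i, -1))
            continue
        for pos, (lab, m) in enumerate(zip(labels, mask)):
            # Expected mask value: 1 if the label is scored, 0 if it is ignored.
            if int(m) != int(lab != IGNORE_INDEX):
                bad.append((i, pos))
    return bad
```

Run over both the map-style and the streaming dataset, it should return an empty list once Bugs 2 and 3 are fixed.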

Bug 3: _process_and_shard_batch() Omits Overlap from attention_mask

File: dataset.py (_process_and_shard_batch() / _chunk_batch)

Problem:
The streaming/prechunked path sets attn_mask = [1] * seq_len for sliding-window chunks, ignoring the head overlap that is masked to -100 in labels. This causes the streaming validation path to diverge from the List[str] map-style path.

Current (incorrect):

# For sliding window chunks:
pad_len = 0
attn_mask = [1] * seq_len

Fix: Compute and apply overlap masking consistently:

# Head overlap: positions already scored by the previous window (labels set to -100).
overlap_len = 0
if start > 0 and stride < seq_len:
    overlap_len = seq_len - stride

attn_mask = [1] * seq_len
if overlap_len > 0:
    attn_mask[:overlap_len] = [0] * overlap_len
# Tail padding, as before.
if pad_len > 0:
    attn_mask[-pad_len:] = [0] * pad_len
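To see the overlap and padding rules working together, here is a self-contained sketch of a sliding-window chunker that applies the same masking to labels and attn_mask, so its output agrees position by position with a mask derived from labels. sliding_window_chunks, pad_id and the stopping rule are assumptions for illustration, not the actual _chunk_batch implementation; labels start as a copy of input_ids on the assumption that the model shifts them internally.

```python
from typing import Dict, List

IGNORE_INDEX = -100


def sliding_window_chunks(
    tokens: List[int], seq_len: int, stride: int, pad_id: int = 0
) -> List[Dict[str, List[int]]]:
    """Split tokens into overlapping windows, masking head overlap and tail padding."""
    if not 0 < stride <= seq_len:
        raise ValueError("stride must be in (0, seq_len]")
    chunks = []
    start = 0
    while True:
        window = tokens[start:start + seq_len]
        pad_len = seq_len - len(window)
        input_ids = window + [pad_id] * pad_len
        labels = list(input_ids)

        # Head overlap: positions already scored by the previous window.
        overlap_len = seq_len - stride if start > 0 and stride < seq_len else 0

        attn_mask = [1] * seq_len
        if overlap_len > 0:
            attn_mask[:overlap_len] = [0] * overlap_len
            labels[:overlap_len] = [IGNORE_INDEX] * overlap_len
        # Tail padding is masked in both lists as well.
        if pad_len > 0:
            attn_mask[-pad_len:] = [0] * pad_len
            labels[-pad_len:] = [IGNORE_INDEX] * pad_len

        chunks.append({"input_ids": input_ids, "labels": labels, "attention_mask": attn_mask})
        if start + seq_len >= len(tokens):
            break  # this window reached the end of the document
        start += stride
    return chunks


chunks = sliding_window_chunks(list(range(1, 10)), seq_len=4, stride=2)
for c in chunks:
    print(c["labels"], c["attention_mask"])
```

With this rule each real token is scored in exactly one window, which is easy to verify by collecting the unmasked labels across all chunks.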

Impact

  • Validation PPL may be systematically off (inflated or deflated) depending on document-length distribution and batch padding ratios.
  • Hidden states for sliding-window chunks are computed over tokens that do not participate in the loss, which can subtly distort perplexity and degrade generation quality.
  • Streaming vs. List[str] evaluation may produce numerically different results for the same underlying data.

Reproduction / Verification

  1. Take a trained checkpoint that reports validation PPL ≈ 63 under the current code.
  2. Apply the three fixes above.
  3. Re-run evaluation on the same checkpoint and data.
  4. Compare PPL before/after. If the gap does not close, inspect model.py to confirm recurrent() actually consumes attention_mask at every layer.
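Before digging into model.py, it may help to rule out the data side by checking that the streaming and List[str] paths emit identical samples for the same input. first_divergence is a hypothetical helper; it assumes both paths yield samples in the same order (for example, with shuffling disabled) as mappings containing the three keys shown.

```python
from itertools import zip_longest
from typing import Iterable, Mapping, Optional, Sequence, Tuple

KEYS = ("input_ids", "labels", "attention_mask")


def first_divergence(
    map_samples: Iterable[Mapping[str, Sequence[int]]],
    stream_samples: Iterable[Mapping[str, Sequence[int]]],
    keys: Sequence[str] = KEYS,
) -> Optional[Tuple[int, str]]:
    """Return (sample_index, key) of the first difference, or None if both paths agree."""
    sentinel = object()
    pairs = zip_longest(map_samples, stream_samples, fillvalue=sentinel)
    for i, (a, b) in enumerate(pairs):
        if a is sentinel or b is sentinel:
            return (i, "<length>")  # one path produced more samples than the other
        for key in keys:
            # int() normalises plain ints and 0-d tensor elements alike.
            if [int(x) for x in a[key]] != [int(x) for x in b[key]]:
                return (i, key)
    return None
```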

Checklist

  • Fix evaluate() to use token-weighted averaging (trainer.py)
  • Fix _make_sample() to derive mask from all -100 positions (dataset.py)
  • Fix _process_and_shard_batch() to include overlap in attn_mask (dataset.py)
  • Re-run validation on affected checkpoint and report delta
  • (Optional) Apply the same token-weighted logic to the train_epoch() progress bar for consistency

Context

These bugs were identified during a code review of validation path divergence between streaming and List[str] data sources. The model forward pass (nn.CrossEntropyLoss(ignore_index=-100)) and label shifting logic are confirmed correct and not part of this issue. The root cause is in metric aggregation and mask construction, not in the loss function itself.
