Validation Perplexity Skewed by Unweighted Batch Averaging + attention_mask Overlap Bugs
Labels: bug, training, evaluation, dataset
Priority: High
Summary
Validation perplexity is not computed as a true corpus-level metric. Three related bugs cause the evaluation pipeline to (1) weight every batch loss equally regardless of valid token count, (2) pass incorrect attention_mask values to the recurrent layers for overlap-masked chunks, and (3) build masks in the streaming/prechunked path that diverge from the map-style path. Together these issues can skew (inflate or deflate) or destabilize reported validation PPL and create hidden-state inconsistency during evaluation.
Bug 1: evaluate() Uses Unweighted Batch-Mean Average
File: trainer.py (evaluate(), ~lines 583–629)
Problem:
outputs["loss"] is already a mean over valid tokens within each batch. The current code sums these means and divides by num_batches, giving every batch equal weight even if one batch is 90% padding and another is 10% padding. This is not corpus-level perplexity.
Current (incorrect):
    total_loss = 0.0
    num_batches = 0
    for batch in pbar:
        ...
        loss = outputs["loss"]
        if not (torch.isnan(loss) or torch.isinf(loss)):
            total_loss += loss.item()
            num_batches += 1
        avg = total_loss / max(num_batches, 1)
        pbar.set_postfix({"loss": f"{avg:.4f}", "ppl": f"{compute_perplexity(avg):.2f}"})
    avg_loss = total_loss / max(num_batches, 1)
    return {"loss": avg_loss, "perplexity": compute_perplexity(avg_loss)}
Fix: Convert each batch mean back to a sum, weight by valid token count, and divide by total tokens:
    total_loss = 0.0
    total_tokens = 0
    for batch in pbar:
        ...
        loss = outputs["loss"]
        if not (torch.isnan(loss) or torch.isinf(loss)):
            # labels is the batch's label tensor from the elided code above.
            # Count the positions that contributed to the batch mean; if the
            # model shifts labels internally, count labels[..., 1:] instead.
            valid_tokens = (labels != -100).sum().item()
            total_loss += loss.item() * valid_tokens
            total_tokens += valid_tokens
        avg = total_loss / max(total_tokens, 1)
        pbar.set_postfix({
            "loss": f"{avg:.4f}",
            "ppl": f"{compute_perplexity(avg):.2f}",
            "tok": f"{total_tokens:,}",
        })
    avg_loss = total_loss / max(total_tokens, 1)
    return {"loss": avg_loss, "perplexity": compute_perplexity(avg_loss)}
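As a rough illustration of how far the two aggregations can drift apart, the sketch below compares them on two made-up batches. The (mean_loss, valid_tokens) pairs and the helper names are hypothetical, and perplexity is assumed to be exp(loss), which may differ from what compute_perplexity() actually does.

```python
import math


def batch_mean_average(batches):
    """Average of per-batch mean losses; every batch gets equal weight."""
    return sum(mean for mean, _ in batches) / max(len(batches), 1)


def token_weighted_average(batches):
    """Corpus-level mean: turn each batch mean back into a sum, then divide by total tokens."""
    total_loss = sum(mean * n_tokens for mean, n_tokens in batches)
    total_tokens = sum(n_tokens for _, n_tokens in batches)
    return total_loss / max(total_tokens, 1)


# Hypothetical (mean_loss, valid_tokens) pairs: one mostly-padding batch
# with a high loss, one nearly full batch with a lower loss.
batches = [(4.0, 100), (2.0, 900)]

unweighted = batch_mean_average(batches)
weighted = token_weighted_average(batches)

# Assumption: perplexity = exp(mean loss).
print(f"unweighted loss={unweighted:.2f} ppl={math.exp(unweighted):.2f}")
print(f"weighted   loss={weighted:.2f} ppl={math.exp(weighted):.2f}")
```

With these made-up numbers the heavily padded batch pulls the unweighted loss up to 3.0, while the token-weighted loss is 2.2; the gap widens further once the losses are exponentiated into perplexity.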
Bug 2: HelixDataset._make_sample() Only Masks Trailing Padding
File: dataset.py (_make_sample())
Problem:
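The method is meant to count trailing -100 values to build attention_mask, but the generator expression as written counts every -100 in labels and then places that many zeros at the tail. HelixDataset.__getitem__() also masks head overlap in sliding-window chunks (labels[:warmup_len] = -100), so for those chunks the mask keeps the overlap positions at 1 while zeroing valid positions at the end. The mask tells the recurrent model to attend to overlap tokens that the loss ignores, creating hidden-state inconsistency.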
Current (incorrect):
    def _make_sample(self, chunk, labels, is_natural_stop):
        ...
        # Counts every -100 in labels (head overlap included), not only trailing ones.
        pad_len = sum(1 for tok in reversed(labels_t.tolist()) if tok == -100)
        attention_mask = torch.cat([
            torch.ones(self.seq_len - pad_len, dtype=torch.long),
            torch.zeros(pad_len, dtype=torch.long),
        ])
        ...
Fix: Derive attention_mask directly from labels so any -100 position is masked:
    def _make_sample(self, chunk, labels, is_natural_stop):
        input_ids = torch.tensor(chunk[:self.seq_len], dtype=torch.long)
        labels_t = torch.tensor(labels[:self.seq_len], dtype=torch.long)
        # Mask every ignored position: head overlap and trailing padding alike.
        attention_mask = (labels_t != -100).long()
        return {
            "input_ids": input_ids,
            "labels": labels_t,
            "attention_mask": attention_mask,
            "is_natural_stop": torch.tensor(is_natural_stop, dtype=torch.bool),
        }
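A small sketch of the difference between the two mask constructions, using plain Python lists instead of tensors so it runs without torch. The sample labels and both helper names are hypothetical; tail_count_mask only mirrors the pad_len expression shown above, not the rest of _make_sample().

```python
IGNORE_INDEX = -100


def tail_count_mask(labels):
    """Mirror of the current logic: count -100 labels, then zero that many tail positions."""
    pad_len = sum(1 for tok in reversed(labels) if tok == IGNORE_INDEX)
    return [1] * (len(labels) - pad_len) + [0] * pad_len


def label_derived_mask(labels):
    """Proposed logic: a position is attended to only if its label is not ignored."""
    return [int(tok != IGNORE_INDEX) for tok in labels]


# Hypothetical sliding-window chunk: 2 head-overlap positions, 4 valid
# positions, 2 trailing padding positions.
labels = [-100, -100, 11, 12, 13, 14, -100, -100]

old_mask = tail_count_mask(labels)
new_mask = label_derived_mask(labels)

print("old:", old_mask)
print("new:", new_mask)
```

For this sample the old construction leaves the two overlap positions unmasked and zeroes two valid positions instead, while the label-derived mask lines up with the labels position by position.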
Bug 3: _process_and_shard_batch() Omits Overlap from attention_mask
File: dataset.py (_process_and_shard_batch() / _chunk_batch)
Problem:
The streaming/prechunked path sets attn_mask = [1] * seq_len for sliding-window chunks, ignoring the head overlap that is masked to -100 in labels. This causes the streaming validation path to diverge from the List[str] map-style path.
Current (incorrect):
    # For sliding window chunks:
    pad_len = 0
    attn_mask = [1] * seq_len
Fix: Compute and apply overlap masking consistently:
    # Number of head positions shared with the previous window.
    overlap_mask = 0
    if start > 0 and stride < seq_len:
        overlap_mask = seq_len - stride
    attn_mask = [1] * seq_len
    # Zero the head overlap, then the trailing padding.
    if overlap_mask > 0:
        attn_mask[:overlap_mask] = [0] * overlap_mask
    if pad_len > 0:
        attn_mask[-pad_len:] = [0] * pad_len
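The fix can be wrapped in a small helper to check the mask layout for a few chunk positions. This is a sketch under assumptions: build_attn_mask is a hypothetical name, the overlap is taken to be seq_len - stride as in the snippet above, and pad_len is assumed not to exceed seq_len.

```python
def build_attn_mask(seq_len, stride, start, pad_len):
    """Return a 0/1 mask that zeroes head overlap and trailing padding.

    Assumes 0 <= pad_len <= seq_len; start is the chunk's offset in the
    document, so start == 0 marks the first window (no overlap).
    """
    overlap = seq_len - stride if (start > 0 and stride < seq_len) else 0
    mask = [1] * seq_len
    if overlap > 0:
        mask[:overlap] = [0] * overlap
    if pad_len > 0:
        mask[-pad_len:] = [0] * pad_len
    return mask


# First window: nothing masked.
print(build_attn_mask(seq_len=8, stride=6, start=0, pad_len=0))
# Later window: 2 overlap positions masked at the head.
print(build_attn_mask(seq_len=8, stride=6, start=6, pad_len=0))
# Final window: overlap at the head plus 3 padding positions at the tail.
print(build_attn_mask(seq_len=8, stride=6, start=12, pad_len=3))
```

Running the same label sequence through this helper and through the map-style path is one way to confirm that both produce identical masks.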
Impact
- Validation PPL may be systematically off (inflated or deflated) depending on document-length distribution and batch padding ratios.
- Hidden states for sliding-window chunks are computed over tokens that do not participate in loss, which can subtly corrupt perplexity and generation quality.
- Streaming vs. List[str] evaluation may produce numerically different results for the same underlying data.
Reproduction / Verification
- Train a checkpoint that reports validation PPL ≈ 63.
- Apply the three fixes above.
- Re-run evaluation on the same checkpoint and data.
- Compare PPL before/after. If the gap does not close, inspect model.py to confirm recurrent() actually consumes attention_mask at every layer.
Checklist
- evaluate() to use token-weighted averaging (trainer.py)
- _make_sample() to derive mask from all -100 positions (dataset.py)
- _process_and_shard_batch() to include overlap in attn_mask (dataset.py)
- train_epoch() progress bar for consistency
Context
These bugs were identified during a code review of validation path divergence between streaming and List[str] data sources. The model forward pass (nn.CrossEntropyLoss(ignore_index=-100)) and label shifting logic are confirmed correct and not part of this issue. The root cause is in metric aggregation and mask construction, not in the loss function itself.