Feature Request patch from #32: Replace estimated_total_tokens with Accurate Streaming Dataset Length Computation
Current Behavior
The Trainer class accepts an estimated_total_tokens parameter when working with iterable datasets (streaming mode). This estimate is used to calculate the total number of training steps for the learning rate scheduler.
trainer = Trainer(
    model=model, cfg=cfg, tokenizer=tokenizer,
    train_loader=train_ds,                 # Trainer auto-detects iterable
    warmup_ratio=0.1,                      # ratio mode, no len() needed
    estimated_total_tokens=2_000_000_000,  # Manual estimate
)
Problems with this approach:
- Manual estimates are error-prone: users may over- or under-estimate, which produces incorrect warmup schedules and learning rate decay
- Hidden misconfiguration: bad estimates silently produce suboptimal training dynamics
- User burden: requires pre-calculating or guessing dataset statistics
- Poor reproducibility: different estimates give different LR curves and therefore different results
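To make the first two problems concrete, the following sketch shows how an estimate that is off by a factor of two moves the warmup boundary. The helper name warmup_steps_from_tokens, the hyperparameter values, and the tokens-per-step formula (batch size × sequence length × accumulation steps) are assumptions for illustration, not the Trainer's actual logic.

```python
def warmup_steps_from_tokens(total_tokens, batch_size, seq_len,
                             grad_accum_steps, warmup_ratio):
    """Hypothetical helper: derive (total_steps, warmup_steps) from a token count."""
    tokens_per_step = batch_size * seq_len * grad_accum_steps
    total_steps = total_tokens // tokens_per_step
    warmup_steps = int(total_steps * warmup_ratio)
    return total_steps, warmup_steps


# Assumed hyperparameters, chosen only to make the arithmetic visible.
kwargs = dict(batch_size=32, seq_len=1024, grad_accum_steps=4, warmup_ratio=0.1)

# The user guesses 2B tokens, but the stream really holds 1B.
estimated = warmup_steps_from_tokens(2_000_000_000, **kwargs)  # (15258, 1525)
actual = warmup_steps_from_tokens(1_000_000_000, **kwargs)     # (7629, 762)
```

With these numbers the scheduler would warm up for roughly twice as long as intended, and about half of the planned decay would never run, without any error being raised.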
Desired Behavior
Replace the manual estimated_total_tokens parameter with an internal, efficient streaming count that:
- Iterates through the dataset once (or a representative sample)
- Does not materialize the full dataset in memory
- Returns an accurate step count for the scheduler
- Caches the result for subsequent epochs
Proposed API Change
Option 1: Automatic Silent Count (Recommended)
trainer = Trainer(
    model=model, cfg=cfg, tokenizer=tokenizer,
    train_loader=train_ds,
    warmup_ratio=0.1,
    # No estimated_total_tokens needed: computed automatically
)
Implementation Sketch
def _count_iterable_dataset(self, loader):
    """
    Count the batches an iterable loader yields without materializing them.
    Each batch is discarded as soon as it is counted. Assumes the loader is
    re-iterable, so the next iter() call starts a fresh pass.
    """
    count = 0
    # Iterate without storing results
    for _ in loader:
        count += 1
    return count

def _compute_total_steps(self):
    if self._is_iterable_dataset(self.train_loader):
        # Cache the count after the first computation
        if getattr(self, '_cached_dataset_length', None) is None:
            # Count collated batches from the loader, not raw samples
            self._cached_dataset_length = self._count_iterable_dataset(
                self.train_loader
            )
            # Reset epoch/shuffle state here if needed
        num_batches = self._cached_dataset_length
    else:
        num_batches = len(self.train_loader)
    # One optimizer step per gradient_accumulation_steps batches
    return num_batches // self.cfg.gradient_accumulation_steps
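The sketch above can be exercised outside the Trainer. The following stand-alone version uses a hypothetical StepCounter class and a toy re-iterable stream in place of the real loader; both are assumptions made so the example runs without helix_lm.

```python
class ToyStream:
    """Hypothetical re-iterable stream: each iter() call starts a fresh pass."""

    def __init__(self, num_batches):
        self.num_batches = num_batches
        self.passes = 0  # how many times the stream has been iterated

    def __iter__(self):
        self.passes += 1
        for i in range(self.num_batches):
            yield i  # stands in for one collated batch


class StepCounter:
    """Hypothetical stand-in for the Trainer's step computation."""

    def __init__(self, loader, grad_accum_steps):
        self.loader = loader
        self.grad_accum_steps = grad_accum_steps
        self._cached_num_batches = None  # None means "not counted yet"

    def total_steps(self):
        if self._cached_num_batches is None:
            # Summing over a generator keeps only the running count in memory.
            self._cached_num_batches = sum(1 for _ in self.loader)
        return self._cached_num_batches // self.grad_accum_steps


stream = ToyStream(num_batches=1000)
counter = StepCounter(stream, grad_accum_steps=4)
first = counter.total_steps()
second = counter.total_steps()  # served from the cache, no second pass
```

The passes attribute exists only to show that the second call does not re-read the stream; a real loader would not need it.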
Edge Cases to Consider
- Resume from checkpoint: Cached count should be serialized/deserialized
- Shuffling between epochs: Count remains constant, only order changes
- Dynamic datasets: Some streaming sources may change size between epochs; document this limitation
- Performance: Counting pass adds overhead at start; could be parallelized or sampled for very large corpora
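For the checkpoint case, one option is to store the cached count alongside the rest of the trainer state. The sketch below assumes hypothetical state_dict / load_state_dict methods and a num_batches key; the real Trainer's checkpoint format may differ.

```python
import json


class CountCache:
    """Hypothetical holder for the cached batch count."""

    def __init__(self):
        self.num_batches = None  # None means "not counted yet"

    def state_dict(self):
        return {"num_batches": self.num_batches}

    def load_state_dict(self, state):
        # Tolerate checkpoints written before the count was stored.
        self.num_batches = state.get("num_batches")


cache = CountCache()
cache.num_batches = 15258

# Round-trip through JSON to mimic writing and reading a checkpoint file.
payload = json.dumps(cache.state_dict())
restored = CountCache()
restored.load_state_dict(json.loads(payload))

# An older checkpoint without the key leaves the count unset,
# so the counting pass would simply run again on resume.
legacy = CountCache()
legacy.load_state_dict({})
```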
Backward Compatibility
- Remove the estimated_total_tokens parameter. Replace it with an internal iteration over the final chunked, padded, collated, ... data that determines the step count deterministically, without holding the iterated batches in memory.
Related Code
helix_lm/trainer.py: Trainer class, _is_iterable_dataset(), scheduler setup
helix_lm/dataset.py: HelixIterableDataset, create_helix_streaming_loader()
Discussion
The current estimated_total_tokens approach was likely implemented as a quick workaround for iterable datasets not having a __len__(). However, in production training scenarios, an accurate step count is critical for:
- Correct warmup scheduling
- Learning rate decay timing
- Training completion estimates
- Reproducibility
A counting pass at initialization is a small price to pay for correctness and eliminates a footgun for users.