auto-count-iterable-dataset #33

Description

@david-thrower

Feature request (patch from #32): Replace estimated_total_tokens with an accurate streaming dataset length computation

Current Behavior

The Trainer class accepts an estimated_total_tokens parameter when working with iterable datasets (streaming mode). This estimate is used to calculate the total number of training steps for the learning rate scheduler.

trainer = Trainer(
    model=model, cfg=cfg, tokenizer=tokenizer,
    train_loader=train_ds,  # Trainer auto-detects iterable
    warmup_ratio=0.1,       # ratio mode — no len() needed
    estimated_total_tokens=2_000_000_000,  # Manual estimate
)

Problems with this approach:

  1. Manual estimates are error-prone — Users may over/under-estimate, causing incorrect warmup schedules and learning rate decay
  2. Hidden misconfiguration — Bad estimates silently produce suboptimal training dynamics
  3. User burden — Requires pre-calculating or guessing dataset statistics
  4. Non-deterministic behavior — Different estimates → different LR curves → different results
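
To make the first two problems concrete, here is a small worked example. It assumes the trainer converts tokens to steps by dividing by the tokens consumed per optimizer step; that conversion, the helper name schedule_from_estimate, and the batch shape are illustrative assumptions, not taken from helix_lm.

```python
def schedule_from_estimate(total_tokens, tokens_per_step, warmup_ratio):
    """Hypothetical helper: derive (total_steps, warmup_steps) from a token count."""
    total_steps = total_tokens // tokens_per_step
    warmup_steps = int(total_steps * warmup_ratio)
    return total_steps, warmup_steps


# Assumed batch shape: 32 sequences of 2048 tokens per optimizer step.
tokens_per_step = 32 * 2048

# Suppose the corpus really holds 1B tokens but the user passes 2B.
true_total, true_warmup = schedule_from_estimate(1_000_000_000, tokens_per_step, 0.1)
est_total, est_warmup = schedule_from_estimate(2_000_000_000, tokens_per_step, 0.1)

print(true_total, true_warmup)  # schedule the run actually needs
print(est_total, est_warmup)    # schedule the scheduler is given
```

Under these assumptions the scheduler plans for about twice as many steps as the data can supply, so warmup lasts twice as long as intended and the run ends while the learning rate is only partway through its decay. Nothing raises an error, which is what makes the misconfiguration hidden.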

Desired Behavior

Replace the manual estimated_total_tokens parameter with an internal, efficient streaming count that:

  • Iterates through the dataset once (or a representative sample)
  • Does not materialize the full dataset in memory
  • Returns an accurate step count for the scheduler
  • Caches the result for subsequent epochs

Proposed API Change

Option 1: Automatic Silent Count (Recommended)

trainer = Trainer(
    model=model, cfg=cfg, tokenizer=tokenizer,
    train_loader=train_ds,
    warmup_ratio=0.1,
    # No estimated_total_tokens needed — computed automatically
)

Implementation Sketch

def _count_iterable_dataset(self, iterable_ds):
    """
    Count the items yielded by an iterable in a single pass, without
    storing them (batches when given a loader, samples when given a dataset).
    Nothing is reset here: the caller must be able to re-iterate iterable_ds.
    """
    count = 0
    # Iterate without storing results
    for _ in iterable_ds:
        count += 1
    return count

def _compute_total_steps(self):
    if self._is_iterable_dataset(self.train_loader):
        # Count once and cache; later calls reuse the cached value
        if not hasattr(self, '_cached_dataset_length'):
            # Iterate the loader (not .dataset) so the count is in batches,
            # the same unit as len(self.train_loader) in the other branch
            self._cached_dataset_length = self._count_iterable_dataset(
                self.train_loader
            )
            # Reset epoch/shuffle state if needed
        num_batches = self._cached_dataset_length
    else:
        num_batches = len(self.train_loader)
    # Optimizer steps per pass over the data, applied to both branches
    return num_batches // self.cfg.gradient_accumulation_steps
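
The sketch above depends on Trainer internals, so here is a minimal standalone version that can be run as-is to check the counting and caching behavior. ToyStream and StepCounter are hypothetical stand-ins (plain Python, no PyTorch), and floor division by gradient_accumulation_steps follows the sketch rather than any confirmed trainer behavior.

```python
class ToyStream:
    """Hypothetical streaming loader: iterable, but deliberately has no __len__."""

    def __init__(self, num_batches):
        self.num_batches = num_batches
        self.passes = 0  # number of times iteration was started

    def __iter__(self):
        self.passes += 1
        for i in range(self.num_batches):
            yield {"batch_index": i}


class StepCounter:
    """Hypothetical helper mirroring the count-once-then-cache logic."""

    def __init__(self, loader, gradient_accumulation_steps=1):
        self.loader = loader
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self._cached_num_batches = None

    def num_batches(self):
        if self._cached_num_batches is None:
            if hasattr(self.loader, "__len__"):
                self._cached_num_batches = len(self.loader)
            else:
                # One pass; each batch is dropped as soon as it is counted.
                self._cached_num_batches = sum(1 for _ in self.loader)
        return self._cached_num_batches

    def total_steps(self):
        return self.num_batches() // self.gradient_accumulation_steps


stream = ToyStream(num_batches=1000)
counter = StepCounter(stream, gradient_accumulation_steps=4)
first = counter.total_steps()
second = counter.total_steps()  # served from the cache, no second pass
```

Calling total_steps() twice walks the stream only once, which is the property the cache is meant to guarantee. With a real streaming source the first call costs a full pass over the data, as noted under the performance edge case.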

Edge Cases to Consider

  1. Resume from checkpoint: Cached count should be serialized/deserialized
  2. Shuffling between epochs: Count remains constant, only order changes
  3. Dynamic datasets: Some streaming sources may change size between epochs; document this limitation
  4. Performance: Counting pass adds overhead at start; could be parallelized or sampled for very large corpora
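
For the checkpoint case (item 1), one possible approach is to store the count next to the checkpoint with an identifier for the data it was computed from, so a resumed run skips the counting pass and a changed dataset invalidates the cache. The file layout, function names, and dataset_id field below are assumptions for illustration, not an existing helix_lm format.

```python
import json
import os
import tempfile


def save_count(path, num_batches, dataset_id):
    """Write the cached batch count alongside a checkpoint (hypothetical format)."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump({"dataset_id": dataset_id, "num_batches": num_batches}, f)


def load_count(path, dataset_id):
    """Return the cached count, or None if it is missing or belongs to other data."""
    if not os.path.exists(path):
        return None
    with open(path, encoding="utf-8") as f:
        state = json.load(f)
    if state.get("dataset_id") != dataset_id:
        return None  # stale cache: caller should recount
    return state["num_batches"]


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "dataset_count.json")
    save_count(path, 1000, "corpus-v1")
    restored = load_count(path, "corpus-v1")  # same data: reuse the count
    stale = load_count(path, "corpus-v2")     # different data: forces a recount
```

A None result tells the caller to run the counting pass again, which also gives dynamic datasets (item 3) a defined fallback as long as the identifier changes when the data does.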

Backward Compatibility

  • This is a breaking change: remove the estimated_total_tokens parameter. Replace it with an internal pass over the final chunked, padded, collated, ... data that determines the step count deterministically, without holding the iterated data in memory.

Related Code

  • helix_lm/trainer.py: Trainer class, _is_iterable_dataset(), scheduler setup
  • helix_lm/dataset.py: HelixIterableDataset, create_helix_streaming_loader()

Discussion

The current estimated_total_tokens approach was likely implemented as a quick workaround for iterable datasets not having a __len__(). However, in production training scenarios, an accurate step count is critical for:

  • Correct warmup scheduling
  • Learning rate decay timing
  • Training completion estimates
  • Reproducibility

A counting pass at initialization is a small price to pay for correctness and eliminates a footgun for users.
