
Refactor-DocumentAwareDataset-for-scale #32

Description

@david-thrower

Refactor DocumentAwareDataset for scale

TL;DR: Adds two proposed Dataset and DataLoader setups for lower overhead and faster streaming.

Work done thus far:

1. helix_lm/dataset.py — Three dataset strategies for different scales

HelixIterableDataset (NEW — for 2B+ token corpora)

  • True streaming torch.utils.data.IterableDataset subclass
  • Tokenizes and chunks on-the-fly during iteration — never materializes the full corpus
  • __len__() raises TypeError so HF Trainer detects it as iterable (no random sampler, no length-based step estimation)
  • set_epoch() for reshuffling between epochs
  • Reservoir shuffle with configurable buffer size
  • Preserves all HelixLM invariants identically to DocumentAwareDataset:
    • is_natural_stop per document boundary
    • -100 masking at overlap head positions (when stride < seq_len)
    • -100 masking at trailing padding positions (from exact pad_len)
    • attention_mask built from pad_len, never from pad_token_id comparison
    • No cross-document boundary crossings
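Here is a minimal sketch of how these invariants fit together for a single document. It is not the HelixLM implementation. `chunk_document` is a hypothetical helper. It assumes windows are `seq_len` tokens long and start every `stride` tokens. It also assumes `labels` are aligned with `input_ids`, with the model shifting them internally as Hugging Face causal LMs do.

```python
from typing import Iterator, List, TypedDict


class Chunk(TypedDict):
    input_ids: List[int]
    labels: List[int]
    attention_mask: List[int]
    is_natural_stop: bool


def chunk_document(tokens: List[int], seq_len: int, stride: int,
                   pad_id: int = 0) -> Iterator[Chunk]:
    """Split ONE tokenized document into seq_len windows; never mixes documents."""
    if not 0 < stride <= seq_len:
        raise ValueError("need 0 < stride <= seq_len")
    if not tokens:
        return  # empty documents produce no chunks
    overlap = seq_len - stride
    start = 0
    while True:
        window = tokens[start:start + seq_len]
        pad_len = seq_len - len(window)
        is_last = start + seq_len >= len(tokens)
        input_ids = window + [pad_id] * pad_len
        labels = list(input_ids)
        if start > 0:
            # Overlap head: these positions were already predicted in the previous window
            for i in range(min(overlap, len(window))):
                labels[i] = -100
        # Trailing padding, located by pad_len (not by comparing with pad_id)
        for i in range(seq_len - pad_len, seq_len):
            labels[i] = -100
        attention_mask = [1] * (seq_len - pad_len) + [0] * pad_len
        yield Chunk(input_ids=input_ids, labels=labels,
                    attention_mask=attention_mask,
                    is_natural_stop=is_last)  # True only at the document's end
        if is_last:
            break
        start += stride
```

Because the mask is derived from `pad_len`, a real token that happens to equal `pad_id` is still attended to.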

HelixPrechunkedDataset (NEW — for fast repeatable training)

  • Map-style dataset from pre-chunked HF Dataset
  • preprocess() classmethod: chunks once with Dataset.map(..., batched=True, num_proc=...) using parallel processing
  • Arrow-backed, memory-mapped, fast random access
  • from_disk() / save_to_disk() for persisting preprocessed data
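With `Dataset.map(..., batched=True)`, the mapped function may return a different number of rows than it received. One-document-to-many-chunks preprocessing relies on exactly that. Below is a sketch of such a function. It assumes the input column is already tokenized into `input_ids`. The names `chunk_batch`, `SEQ_LEN`, `STRIDE` and `PAD_ID` are hypothetical. Storing `pad_len` per row is an assumption, chosen so the attention mask can later be rebuilt from it.

```python
from typing import Dict, List

SEQ_LEN = 4   # hypothetical; e.g. 512 in practice
STRIDE = 4    # equal to SEQ_LEN means no overlap
PAD_ID = 0


def chunk_batch(batch: Dict[str, List[List[int]]]) -> Dict[str, List]:
    """Batched map fn: N tokenized documents in, any number of chunk rows out."""
    out: Dict[str, List] = {"input_ids": [], "pad_len": [], "is_natural_stop": []}
    for tokens in batch["input_ids"]:
        if not tokens:
            continue  # empty documents yield no rows
        start = 0
        while True:
            window = tokens[start:start + SEQ_LEN]
            pad_len = SEQ_LEN - len(window)
            is_last = start + SEQ_LEN >= len(tokens)
            out["input_ids"].append(window + [PAD_ID] * pad_len)
            out["pad_len"].append(pad_len)  # attention mask is rebuilt from this
            out["is_natural_stop"].append(is_last)
            if is_last:
                break
            start += STRIDE
    return out

# Usage (needs the `datasets` library and a tokenized `ds`):
# ds = ds.map(chunk_batch, batched=True, num_proc=8,
#             remove_columns=ds.column_names)
```

The `remove_columns` argument is needed here. The output row count no longer matches the original columns, so they cannot be kept alongside the new ones.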

preprocess_to_shards() (NEW — two-stage pipeline utility)

  • Offline preprocessing into sharded files for streaming during training
  • Implements the recommended "remux" pattern: pre-chunk offline → stream shards during training
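The two-stage remux pattern can be sketched with the standard library alone. This is illustrative only. JSON-lines files stand in for whatever shard format `preprocess_to_shards()` actually writes, and `write_shards` and `stream_shards` are hypothetical names. Shuffling the shard order per epoch is shown as one cheap, coarse way to vary the stream between epochs.

```python
import json
import os
import random
from typing import Iterable, Iterator, List, Optional


def write_shards(chunks: Iterable[dict], out_dir: str, chunks_per_shard: int) -> List[str]:
    """Stage 1 (offline): stream chunks into numbered JSON-lines shard files."""
    os.makedirs(out_dir, exist_ok=True)
    paths: List[str] = []
    f = None
    try:
        for i, chunk in enumerate(chunks):
            if i % chunks_per_shard == 0:  # start a new shard
                if f is not None:
                    f.close()
                paths.append(os.path.join(out_dir, f"shard-{len(paths):05d}.jsonl"))
                f = open(paths[-1], "w", encoding="utf-8")
            f.write(json.dumps(chunk) + "\n")
    finally:
        if f is not None:
            f.close()
    return paths


def stream_shards(paths: List[str], epoch: Optional[int] = None,
                  seed: int = 0) -> Iterator[dict]:
    """Stage 2 (training): read shards lazily, one line at a time."""
    order = list(paths)
    if epoch is not None:
        random.Random(seed + epoch).shuffle(order)  # coarse, per-epoch shuffle
    for path in order:
        with open(path, encoding="utf-8") as f:
            for line in f:
                yield json.loads(line)
```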

helix_data_collator() (NEW — unified collator)

  • Single collator function that works with all dataset types
  • Stacks input_ids, labels, attention_mask, is_natural_stop into batched tensors
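Every dataset type above yields the same four fields, so one collator can serve them all. Here is a sketch under that assumption, where `collate` is a hypothetical name. It stops at batch-major Python lists so it runs without PyTorch. The comment marks where `torch.tensor(...)` would produce the actual tensors.

```python
from typing import Any, Dict, List, Sequence

FIELDS = ("input_ids", "labels", "attention_mask", "is_natural_stop")


def collate(samples: Sequence[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Stack per-sample fields into batch-major lists, one row per sample."""
    if not samples:
        raise ValueError("empty batch")
    seq_len = len(samples[0]["input_ids"])
    for s in samples:
        for key in ("input_ids", "labels", "attention_mask"):
            if len(s[key]) != seq_len:
                raise ValueError(f"{key}: length {len(s[key])}, expected {seq_len}")
    # With PyTorch, {k: torch.tensor(v) for k, v in batch.items()} would give
    # [batch, seq_len] tensors plus a [batch] bool tensor for is_natural_stop.
    return {key: [s[key] for s in samples] for key in FIELDS}
```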

Factory functions: create_helix_streaming_loader(), create_helix_prechunked_loader()

All existing classes preserved unchanged: HelixDataset, DocumentAwareDataset, HelixDatasetFromTokens, HelixHFDataset, create_helix_dataloader(), create_document_loader()

2. helix_lm/trainer.py — Dual-mode Trainer

Iterable dataset auto-detection

  • _is_iterable_dataset() checks for __iter__ without __getitem__, or __len__ raising TypeError
  • Automatically detected at Trainer initialization
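The detection rule described above can be written as a standalone sketch, with `is_iterable_dataset` as a hypothetical stand-in for the private method. The `len()` check matters for PyTorch. `torch.utils.data.Dataset` defines `__getitem__`, so an `IterableDataset` subclass still has that attribute. The rule is also meant for the dataset rather than the `DataLoader`. A `DataLoader` defines `__iter__` but not `__getitem__`, so it would always be classified as iterable.

```python
def is_iterable_dataset(ds: object) -> bool:
    """Heuristic: iterable if it has __iter__ but no __getitem__, or len() raises TypeError."""
    has_iter = hasattr(ds, "__iter__")
    has_getitem = hasattr(ds, "__getitem__")
    if not (has_iter or has_getitem):
        return False  # not a dataset-like object at all
    if has_iter and not has_getitem:
        return True   # generators, plain iterables, streaming wrappers
    try:
        len(ds)
    except TypeError:
        return True   # e.g. a __len__ that deliberately raises TypeError
    return False


# Illustrative stand-ins
class MapStyle:
    def __getitem__(self, i): return i
    def __len__(self): return 3


class Streaming:
    def __iter__(self): yield from range(3)


class LenRaises:  # mimics an IterableDataset subclass that inherits __getitem__
    def __getitem__(self, i): raise NotImplementedError
    def __iter__(self): yield from range(3)
    def __len__(self): raise TypeError("length unknown for a streaming dataset")
```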

Scheduler modes

  • Map-style (existing behavior): len(train_loader) used for exact step count
  • Iterable with warmup_ratio: specify warmup as ratio of total steps (e.g., warmup_ratio=0.1 for 10% warmup). Step count estimated from estimated_total_tokens or conservative default
  • Iterable with warmup_steps: falls back to step-based warmup with estimated total steps
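Both iterable modes reduce to the same arithmetic. Tokens become chunks, chunks become batches, batches become optimizer steps, and warmup is a fraction of that. The formula below is an assumption for illustration, not necessarily what the Trainer uses as its conservative default. It treats every chunk as contributing `stride` new tokens and ignores per-document padding and rounding. For corpora with many short documents it can therefore be off in either direction. `estimate_total_steps` and `warmup_steps_from_ratio` are hypothetical names.

```python
from typing import Optional


def estimate_total_steps(estimated_total_tokens: int, seq_len: int, batch_size: int,
                         grad_accum_steps: int, stride: Optional[int] = None,
                         epochs: int = 1) -> int:
    step = stride or seq_len                               # new tokens per chunk
    chunks = -(-estimated_total_tokens // step)            # ceiling division
    batches = -(-chunks // batch_size)                     # last partial batch kept
    return max(1, (batches // grad_accum_steps) * epochs)  # leftover batches never step


def warmup_steps_from_ratio(total_steps: int, warmup_ratio: float) -> int:
    return int(total_steps * warmup_ratio)
```

Take the streaming example's numbers: 2B tokens, `seq_len=512`, `batch_size=32`, no overlap and no accumulation. This gives 122,071 total steps and 12,207 warmup steps at `warmup_ratio=0.1`.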

Epoch iteration

  • set_epoch() called on iterable datasets between epochs for reshuffling
  • Gradient accumulation: for iterables, steps every grad_accum_steps batches unconditionally (no len() dependency)
  • tqdm works correctly with iterables (no total, counts as it goes)
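Here is a sketch of the iterable epoch loop described above. The model and optimizer calls are reduced to comments and a callback, and `train_iterable` and `on_step` are hypothetical names. The loop calls `set_epoch()` only when the dataset provides it, and it steps every `grad_accum_steps` batches without ever calling `len()`. Stepping unconditionally has one consequence: batches left over at the end of an epoch never trigger a step. Their gradients should therefore be zeroed at the start of the next epoch.

```python
from typing import Callable, Iterable


def train_iterable(loader: Iterable, dataset: object, num_epochs: int,
                   grad_accum_steps: int, on_step: Callable[[int], None]) -> int:
    """Epoch loop for a loader of unknown length; returns optimizer steps taken."""
    global_step = 0
    for epoch in range(num_epochs):
        if hasattr(dataset, "set_epoch"):
            dataset.set_epoch(epoch)  # reseed the shuffle for this epoch
        # optimizer.zero_grad() here discards any leftover partial accumulation
        for i, batch in enumerate(loader, start=1):  # no len(): tqdm gets no total
            # loss = model(**batch).loss / grad_accum_steps; loss.backward()
            if i % grad_accum_steps == 0:  # unconditional, every N batches
                global_step += 1
                on_step(global_step)  # optimizer.step(); scheduler.step(); optimizer.zero_grad()
    return global_step
```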

New constructor args:

  • warmup_ratio: overrides cfg.warmup_steps with a fraction of the total steps (recommended for iterables)
  • estimated_total_tokens: provide token count for better step estimation

3. helix_lm/__init__.py — Updated exports

All new classes and functions are exported.


Usage Examples

Streaming (for 2B+ tokens)

```python
from datasets import load_dataset
from helix_lm import HelixForCausalLM, HelixConfig, Trainer
from helix_lm.dataset import create_helix_streaming_loader

# Assumes `tokenizer`, `cfg` (a HelixConfig) and `model` (a HelixForCausalLM)
# have already been constructed.

# Stream from HF Hub — never materializes full corpus
ds = load_dataset("your_dataset", split="train", streaming=True)
train_ds = create_helix_streaming_loader(
    ds, tokenizer, seq_len=512, batch_size=32,
    shuffle_buffer_size=10_000,  # reservoir shuffle
)

trainer = Trainer(
    model=model, cfg=cfg, tokenizer=tokenizer,
    train_loader=train_ds,  # Trainer auto-detects iterable
    warmup_ratio=0.1,       # ratio mode — no len() needed
    estimated_total_tokens=2_000_000_000,
)
trainer.train()
```

Pre-chunked (for fast repeatable training)

```python
from helix_lm import Trainer
from helix_lm.dataset import create_helix_prechunked_loader

# One-time preprocess, then fast loading
loader = create_helix_prechunked_loader(
    "your_dataset", tokenizer, seq_len=512,
    output_dir="./prechunked",  # saves for reuse
    num_proc=8,                 # parallel chunking
)

# `model`, `cfg` and `tokenizer` as in the streaming example;
# add any other Trainer arguments as needed
trainer = Trainer(model=model, cfg=cfg, train_loader=loader)
trainer.train()
```

Existing code (unchanged — still works)

```python
from helix_lm import Trainer
from helix_lm.dataset import create_document_loader

# DocumentAwareDataset path — fully backward compatible
# `texts` is the document collection; `model`, `cfg`, `tokenizer` as above
loader = create_document_loader(texts, tokenizer, seq_len=512)
# add any other Trainer arguments as needed
trainer = Trainer(model=model, cfg=cfg, train_loader=loader)
trainer.train()
```

Feature Request: Replace estimated_total_tokens with Accurate Streaming Dataset Length Computation

Current Behavior

The Trainer class accepts an estimated_total_tokens parameter when working with iterable datasets (streaming mode). This estimate is used to calculate the total number of training steps for the learning rate scheduler.

```python
trainer = Trainer(
    model=model, cfg=cfg, tokenizer=tokenizer,
    train_loader=train_ds,  # Trainer auto-detects iterable
    warmup_ratio=0.1,       # ratio mode — no len() needed
    estimated_total_tokens=2_000_000_000,  # Manual estimate
)
```

Problems with this approach:

  1. Manual estimates are error-prone — Users may over/under-estimate, causing incorrect warmup schedules and learning rate decay
  2. Hidden misconfiguration — Bad estimates silently produce suboptimal training dynamics
  3. User burden — Requires pre-calculating or guessing dataset statistics
  4. Estimate-dependent results: different estimates produce different LR curves, and therefore different results, even on identical data

Desired Behavior

Replace the manual estimated_total_tokens parameter with an internal, efficient streaming count that:

  • Iterates through the dataset once (or a representative sample)
  • Does not materialize the full dataset in memory
  • Returns an accurate step count for the scheduler
  • Caches the result for subsequent epochs

Proposed API Change

Option 1: Automatic Silent Count (Recommended)

```python
trainer = Trainer(
    model=model, cfg=cfg, tokenizer=tokenizer,
    train_loader=train_ds,
    warmup_ratio=0.1,
    # No estimated_total_tokens needed — computed automatically
)
```

Implementation Sketch

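```python
def _count_iterable_dataset(self, iterable_ds):
    """
    Count the samples (chunks) an iterable dataset yields in one streaming
    pass, discarding each item immediately so nothing is materialized.
    No explicit reset is needed: each `for` loop starts a fresh iterator.
    """
    count = 0
    for _ in iterable_ds:
        count += 1
    return count

def _compute_total_steps(self):
    """Optimizer steps per epoch for the LR scheduler (multiply by epochs if needed)."""
    accum = self.cfg.gradient_accumulation_steps
    dataset = self.train_loader.dataset
    # Check the dataset, not the loader: a DataLoader has __iter__ but no
    # __getitem__, so it would always look iterable.
    if self._is_iterable_dataset(dataset):
        # Cache the count after the first (expensive) pass
        if not hasattr(self, "_cached_dataset_length"):
            self._cached_dataset_length = self._count_iterable_dataset(dataset)
        n_samples = self._cached_dataset_length
        batch_size = self.train_loader.batch_size
        # Samples -> batches (a final partial batch is kept unless drop_last)
        if self.train_loader.drop_last:
            batches = n_samples // batch_size
        else:
            batches = -(-n_samples // batch_size)  # ceiling division
    else:
        batches = len(self.train_loader)
    # Batches -> optimizer steps: one step per `accum` batches
    return batches // accum
```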

Edge Cases to Consider

  1. Resume from checkpoint: Cached count should be serialized/deserialized
  2. Shuffling between epochs: Count remains constant, only order changes
  3. Dynamic datasets: Some streaming sources may change size between epochs; document this limitation
  4. Performance: Counting pass adds overhead at start; could be parallelized or sampled for very large corpora
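On the performance point, a counting pass may not need to build chunks at all. Under a fixed windowing rule, the chunk count per document follows from its token count. The sketch below assumes windows of `seq_len` tokens starting every `stride` tokens until the document end is covered, a padded final window, and empty documents skipped. `chunks_for_document` and `count_optimizer_steps` are hypothetical helpers. The formula is only valid if it matches the real chunker in `helix_lm/dataset.py` exactly. The per-document token counts still require tokenizing the corpus once, but that step parallelizes across documents.

```python
from typing import Iterable


def chunks_for_document(num_tokens: int, seq_len: int, stride: int) -> int:
    """Chunks one document yields: windows of seq_len starting every `stride` tokens."""
    if num_tokens <= 0:
        return 0                 # empty documents are skipped
    if num_tokens <= seq_len:
        return 1                 # a single (possibly padded) window
    # first window + ceil((num_tokens - seq_len) / stride) further windows
    return 1 + (num_tokens - seq_len + stride - 1) // stride


def count_optimizer_steps(doc_lengths: Iterable[int], seq_len: int, stride: int,
                          batch_size: int, grad_accum_steps: int,
                          drop_last: bool = False) -> int:
    """Streams over token counts only; no chunk is ever built."""
    chunks = sum(chunks_for_document(n, seq_len, stride) for n in doc_lengths)
    batches = chunks // batch_size if drop_last else -(-chunks // batch_size)
    return batches // grad_accum_steps
```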

Backward Compatibility

  • Remove the estimated_total_tokens parameter and replace it with an internal iteration over the final (chunked, padded, collated, ...) data. That pass determines the step count deterministically without holding the data in memory. This is a breaking change for callers that currently pass estimated_total_tokens.

Related Code

  • helix_lm/trainer.py: Trainer class, _is_iterable_dataset(), scheduler setup
  • helix_lm/dataset.py: HelixIterableDataset, create_helix_streaming_loader()

Discussion

The current estimated_total_tokens approach was likely implemented as a quick workaround for iterable datasets not having a __len__(). However, in production training scenarios, an accurate step count is critical for:

  • Correct warmup scheduling
  • Learning rate decay timing
  • Training completion estimates
  • Reproducibility

A counting pass at initialization is a small price to pay for correctness and eliminates a footgun for users.
