Refactor DocumentAwareDataset for scale
TLDR: Add two new dataset and dataloader setups for lower overhead and faster streaming.
Work done thus far:
1. helix_lm/dataset.py — Three dataset strategies for different scales
HelixIterableDataset (NEW — for 2B+ token corpora)
- True streaming torch.utils.data.IterableDataset subclass
- Tokenizes and chunks on the fly during iteration; never materializes the full corpus
- __len__() raises TypeError so HF Trainer detects it as iterable (no random sampler, no length-based step estimation)
- set_epoch() for reshuffling between epochs
- Reservoir shuffle with configurable buffer size
- Preserves all HelixLM invariants identically to DocumentAwareDataset:
  - is_natural_stop per document boundary
  - -100 masking at overlap head positions (when stride < seq_len)
  - -100 masking at trailing padding positions (from exact pad_len)
  - attention_mask built from pad_len, never from pad_token_id comparison
  - No cross-document boundary crossings
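The invariants listed above can be illustrated with a small, dependency-free sketch. It is not the HelixIterableDataset code: the function name chunk_document and the exact rules (is_natural_stop set only on a document's final chunk, the first seq_len - stride positions of every non-first chunk treated as the overlap head) are assumptions made for illustration.

```python
from typing import Dict, Iterator, List, Union

IGNORE_INDEX = -100  # label value that the loss skips

Chunk = Dict[str, Union[List[int], bool]]


def chunk_document(
    token_ids: List[int], seq_len: int, stride: int, pad_token_id: int
) -> Iterator[Chunk]:
    """Yield fixed-length chunks from ONE tokenized document (hypothetical helper)."""
    if not 0 < stride <= seq_len:
        raise ValueError("stride must satisfy 0 < stride <= seq_len")
    if not token_ids:
        return
    overlap = seq_len - stride  # tokens repeated from the previous chunk
    start = 0
    while True:
        window = token_ids[start:start + seq_len]
        real_len = len(window)
        pad_len = seq_len - real_len  # exact padding count, tracked explicitly
        is_last = start + seq_len >= len(token_ids)

        input_ids = window + [pad_token_id] * pad_len
        labels = list(input_ids)
        if start > 0:
            # Overlap head: these tokens were already predicted in the previous chunk.
            for i in range(min(overlap, real_len)):
                labels[i] = IGNORE_INDEX
        for i in range(real_len, seq_len):
            labels[i] = IGNORE_INDEX  # trailing padding never contributes to the loss

        yield {
            "input_ids": input_ids,
            "labels": labels,
            # Built from pad_len, so real tokens equal to pad_token_id stay visible.
            "attention_mask": [1] * real_len + [0] * pad_len,
            "is_natural_stop": is_last,
        }
        if is_last:
            return  # chunks never continue into the next document
        start += stride
```

Because the function only ever sees one document, a chunk cannot cross a document boundary; a streaming dataset would call it once per incoming document.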
HelixPrechunkedDataset (NEW — for fast repeatable training)
- Map-style dataset from pre-chunked HF Dataset
- preprocess() classmethod: chunks once with Dataset.map(..., batched=True, num_proc=...) using parallel processing
- Arrow-backed, memory-mapped, fast random access
- from_disk() / save_to_disk() for persisting preprocessed data
preprocess_to_shards() (NEW — two-stage pipeline utility)
- Offline preprocessing into sharded files for streaming during training
- Implements the recommended "remux" pattern: pre-chunk offline → stream shards during training
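As a rough illustration of the two-stage pattern (pre-chunk offline, then stream shards during training), the sketch below writes chunks to JSON Lines shards and reads them back lazily. The shard format, the file naming, and the helper names write_shards and stream_shards are assumptions for this example; the on-disk format that preprocess_to_shards() actually uses is not described here.

```python
import json
from pathlib import Path
from typing import Dict, Iterable, Iterator, List


def write_shards(chunks: Iterable[Dict], out_dir: str, chunks_per_shard: int) -> List[Path]:
    """Stage 1 (offline): write pre-chunked samples into fixed-size JSONL shards."""
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    paths: List[Path] = []
    handle = None
    for i, chunk in enumerate(chunks):
        if i % chunks_per_shard == 0:
            if handle is not None:
                handle.close()
            path = out / f"shard-{len(paths):05d}.jsonl"
            paths.append(path)
            handle = path.open("w", encoding="utf-8")
        handle.write(json.dumps(chunk) + "\n")
    if handle is not None:
        handle.close()
    return paths


def stream_shards(paths: Iterable[Path]) -> Iterator[Dict]:
    """Stage 2 (training): yield samples one at a time, one shard open at a time."""
    for path in paths:
        with Path(path).open(encoding="utf-8") as handle:
            for line in handle:
                yield json.loads(line)
```

Only one line of one shard is held in memory at a time during streaming, which is the property the pattern is after; a real pipeline would more likely use Arrow or Parquet shards.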
helix_data_collator() (NEW — unified collator)
- Single collator function that works with all dataset types
- Stacks input_ids, labels, attention_mask, is_natural_stop into batched tensors
Factory functions: create_helix_streaming_loader(), create_helix_prechunked_loader()
All existing classes preserved unchanged: HelixDataset, DocumentAwareDataset, HelixDatasetFromTokens, HelixHFDataset, create_helix_dataloader(), create_document_loader()
2. helix_lm/trainer.py — Dual-mode Trainer
Iterable dataset auto-detection
- _is_iterable_dataset() checks for __iter__ without __getitem__, or __len__ raising TypeError
- Automatically detected at Trainer initialization
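The detection rule described above might look roughly like the following standalone function. This is a sketch of the stated rule, not the Trainer's actual method; note that it calls len(), so an object with no __len__ at all is also reported as iterable.

```python
def is_iterable_dataset(dataset) -> bool:
    """Return True if `dataset` looks like a streaming (iterable-style) dataset."""
    has_iter = hasattr(dataset, "__iter__")
    has_getitem = hasattr(dataset, "__getitem__")
    if has_iter and not has_getitem:
        return True
    # Second check covers classes that inherit a __getitem__ stub from a base
    # class but still refuse to report a length.
    try:
        len(dataset)
    except TypeError:
        return True
    return False
```

The second branch is what makes the deliberate TypeError from __len__() useful: it keeps detection working even when attribute checks alone are ambiguous.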
Scheduler modes
- Map-style (existing behavior): len(train_loader) used for exact step count
- Iterable with warmup_ratio: specify warmup as a ratio of total steps (e.g., warmup_ratio=0.1 for 10% warmup). Step count is estimated from estimated_total_tokens or a conservative default
- Iterable with warmup_steps: falls back to step-based warmup with estimated total steps
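To make the estimation concrete, here is one way the step and warmup counts could be derived for an iterable loader. The function estimate_schedule, its fallback_total_steps value, and the assumption that every sample carries seq_len fresh tokens over a single pass (ignoring overlap and padding) are all illustrative, not taken from the Trainer.

```python
from typing import Optional, Tuple


def estimate_schedule(
    seq_len: int,
    batch_size: int,
    grad_accum_steps: int = 1,
    estimated_total_tokens: Optional[int] = None,
    warmup_ratio: Optional[float] = None,
    warmup_steps: int = 0,
    fallback_total_steps: int = 10_000,  # placeholder for the "conservative default"
) -> Tuple[int, int]:
    """Return (total_optimizer_steps, warmup_steps) for a streaming run."""
    if estimated_total_tokens is None:
        total_steps = fallback_total_steps
    else:
        tokens_per_step = seq_len * batch_size * grad_accum_steps
        total_steps = max(1, estimated_total_tokens // tokens_per_step)

    if warmup_ratio is not None:
        if not 0.0 <= warmup_ratio <= 1.0:
            raise ValueError("warmup_ratio must be in [0, 1]")
        warmup = round(total_steps * warmup_ratio)  # ratio mode
    else:
        warmup = warmup_steps  # step mode
    return total_steps, min(warmup, total_steps)


# 2B tokens at seq_len=512, batch_size=32 -> 16,384 tokens per optimizer step
print(estimate_schedule(512, 32, estimated_total_tokens=2_000_000_000, warmup_ratio=0.1))
# prints (122070, 12207)
```

The arithmetic shows how directly the schedule depends on the token estimate: halving estimated_total_tokens halves both the total and the warmup length.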
Epoch iteration
- set_epoch() called on iterable datasets between epochs for reshuffling
- Gradient accumulation: for iterables, the optimizer steps every grad_accum_steps batches unconditionally (no len() dependency)
- tqdm works correctly with iterables (no total, counts as it goes)
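The interaction between set_epoch() and a bounded shuffle buffer can be sketched as follows. BufferedShuffle is a hypothetical stand-in that implements a buffer-based streaming shuffle (one common reading of "reservoir shuffle"); seeding with seed + epoch is an assumption about how reshuffling is made reproducible.

```python
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


class BufferedShuffle:
    """Approximately shuffle a re-iterable stream using a fixed-size buffer."""

    def __init__(self, source: Iterable[T], buffer_size: int, seed: int = 0):
        if buffer_size < 1:
            raise ValueError("buffer_size must be >= 1")
        self.source = source
        self.buffer_size = buffer_size
        self.seed = seed
        self.epoch = 0

    def set_epoch(self, epoch: int) -> None:
        # Called by the training loop between epochs to change the order.
        self.epoch = epoch

    def __iter__(self) -> Iterator[T]:
        rng = random.Random(self.seed + self.epoch)  # same epoch gives same order
        buffer: List[T] = []
        for item in self.source:
            if len(buffer) < self.buffer_size:
                buffer.append(item)  # fill phase
                continue
            idx = rng.randrange(self.buffer_size)
            yield buffer[idx]  # emit a random buffered item...
            buffer[idx] = item  # ...and put the new item in its slot
        rng.shuffle(buffer)  # drain whatever is left at end of stream
        yield from buffer
```

Memory use is bounded by buffer_size regardless of corpus size, at the cost of only local shuffling: an item can be emitted no earlier than buffer_size positions before its position in the source.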
New constructor args:
- warmup_ratio: overrides cfg.warmup_steps with a ratio of total steps (recommended for iterables)
- estimated_total_tokens: provide a token count for better step estimation
3. helix_lm/__init__.py — Updated exports
All new classes and functions exported.
Usage Examples
Streaming (for 2B+ tokens)
from datasets import load_dataset
from helix_lm import HelixForCausalLM, HelixConfig, Trainer
from helix_lm.dataset import create_helix_streaming_loader

# Stream from HF Hub; never materializes the full corpus
ds = load_dataset("your_dataset", split="train", streaming=True)
train_ds = create_helix_streaming_loader(
    ds, tokenizer, seq_len=512, batch_size=32,
    shuffle_buffer_size=10_000,  # reservoir shuffle
)
trainer = Trainer(
    model=model, cfg=cfg, tokenizer=tokenizer,
    train_loader=train_ds,  # Trainer auto-detects iterable
    warmup_ratio=0.1,  # ratio mode, no len() needed
    estimated_total_tokens=2_000_000_000,
)
trainer.train()
Pre-chunked (for fast repeatable training)
from helix_lm.dataset import create_helix_prechunked_loader

# One-time preprocess, then fast loading
loader = create_helix_prechunked_loader(
    "your_dataset", tokenizer, seq_len=512,
    output_dir="./prechunked",  # saves for reuse
    num_proc=8,  # parallel chunking
)
trainer = Trainer(model=model, cfg=cfg, train_loader=loader, ...)
trainer.train()
Existing code (unchanged — still works)
from helix_lm import Trainer
from helix_lm.dataset import create_document_loader
# DocumentAwareDataset path — fully backward compatible
loader = create_document_loader(texts, tokenizer, seq_len=512)
trainer = Trainer(model=model, cfg=cfg, train_loader=loader, ...)
trainer.train()
Feature Request: Replace estimated_total_tokens with Accurate Streaming Dataset Length Computation
Current Behavior
The Trainer class accepts an estimated_total_tokens parameter when working with iterable datasets (streaming mode). This estimate is used to calculate the total number of training steps for the learning rate scheduler.
trainer = Trainer(
    model=model, cfg=cfg, tokenizer=tokenizer,
    train_loader=train_ds,  # Trainer auto-detects iterable
    warmup_ratio=0.1,  # ratio mode, no len() needed
    estimated_total_tokens=2_000_000_000,  # Manual estimate
)
Problems with this approach:
- Manual estimates are error-prone — Users may over/under-estimate, causing incorrect warmup schedules and learning rate decay
- Hidden misconfiguration — Bad estimates silently produce suboptimal training dynamics
- User burden — Requires pre-calculating or guessing dataset statistics
- Poor reproducibility: different estimates produce different LR curves and therefore different results
Desired Behavior
Replace the manual estimated_total_tokens parameter with an internal, efficient streaming count that:
- Iterates through the dataset once (or a representative sample)
- Does not materialize the full dataset in memory
- Returns an accurate step count for the scheduler
- Caches the result for subsequent epochs
Proposed API Change
Option 1: Automatic Silent Count (Recommended)
trainer = Trainer(
    model=model, cfg=cfg, tokenizer=tokenizer,
    train_loader=train_ds,
    warmup_ratio=0.1,
    # No estimated_total_tokens needed; computed automatically
)
Implementation Sketch
def _count_iterable_dataset(self, loader):
    """
    Count the batches an iterable-backed loader yields, without storing them.

    A DataLoader starts a fresh pass on every iter() call, so no explicit
    reset is needed, but counting costs one full read of the stream.
    """
    count = 0
    for _ in loader:  # each batch is discarded as soon as it is counted
        count += 1
    return count

def _compute_total_steps(self):
    if self._is_iterable_dataset(self.train_loader):
        # Cache the count after first computation
        if not hasattr(self, '_cached_dataset_length'):
            # Count collated batches (not raw samples) so the units match
            # the gradient-accumulation divisor below
            self._cached_dataset_length = self._count_iterable_dataset(
                self.train_loader
            )
            # Reset epoch/shuffle state if needed
        total_steps = self._cached_dataset_length // self.cfg.gradient_accumulation_steps
    else:
        total_steps = len(self.train_loader)
    return total_steps
Edge Cases to Consider
- Resume from checkpoint: Cached count should be serialized/deserialized
- Shuffling between epochs: Count remains constant, only order changes
- Dynamic datasets: Some streaming sources may change size between epochs; document this limitation
- Performance: Counting pass adds overhead at start; could be parallelized or sampled for very large corpora
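For the resume-from-checkpoint case, one possible approach is to persist the count next to the checkpoint together with the settings that determine it, so a stale count is never reused after seq_len or batch_size changes. The helper load_or_count, the JSON file layout, and the choice of key fields are all hypothetical.

```python
import json
from pathlib import Path
from typing import Callable, Dict


def load_or_count(cache_path: str, key: Dict[str, int], count_fn: Callable[[], int]) -> int:
    """Return a cached batch count if it matches `key`, else recount and store it."""
    path = Path(cache_path)
    if path.exists():
        cached = json.loads(path.read_text(encoding="utf-8"))
        if cached.get("key") == key:
            return cached["num_batches"]  # settings unchanged: skip the counting pass
    num_batches = count_fn()  # the expensive full pass over the stream
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(
        json.dumps({"key": key, "num_batches": num_batches}), encoding="utf-8"
    )
    return num_batches
```

A caller would pass something like key={"seq_len": 512, "batch_size": 32} and a count_fn that wraps the counting pass; this does not help with sources whose size changes between runs, which remains a documented limitation.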
Backward Compatibility
- Remove the estimated_total_tokens parameter; callers that still pass it will need to drop the argument. Replace it with an internal pass over the final chunked, padded, collated, ... data that determines the step count deterministically, without holding the batches in memory.
Related Code
helix_lm/trainer.py: Trainer class, _is_iterable_dataset(), scheduler setup
helix_lm/dataset.py: HelixIterableDataset, create_helix_streaming_loader()
Discussion
The current estimated_total_tokens approach was likely implemented as a quick workaround for iterable datasets not having a __len__(). However, in production training scenarios, an accurate step count is critical for:
- Correct warmup scheduling
- Learning rate decay timing
- Training completion estimates
- Reproducibility
A counting pass at initialization is a small price to pay for correctness and eliminates a footgun for users.