Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions 87M_param_production_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@ def main():
# Streaming-specific options
preprocess_batch_size=PREPROCESS_BATCH_SIZE,
cleanup_shards=CLEANUP_SHARDS,
num_workers=4, # DataLoader workers for prefetching
)

# Constant LR: cosine with min_lr_ratio=1.0 = flat after warmup
Expand Down
171 changes: 124 additions & 47 deletions helix_lm/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,29 @@
Compatible with both eager and lazy loading.
"""
import random
import threading
from typing import List, Optional, Iterator, Dict, Any, Union, Tuple
from collections import OrderedDict

import torch
from torch.utils.data import IterableDataset, Dataset, DataLoader
from tqdm import tqdm


def _collate_batch(batch):
    """Stack per-sample tensors into a single batch dictionary.

    Defined at module level (rather than as a closure inside a loader
    factory) so it can be pickled when the DataLoader uses multiprocessing.

    Args:
        batch: Sequence of sample dicts, each holding tensors under the keys
            "input_ids", "labels", "attention_mask" and "is_natural_stop".

    Returns:
        Dict with the same four keys, each mapped to the ``torch.stack`` of
        the corresponding per-sample tensors.
    """
    # Key order matters: it fixes both the stacking order and the order of
    # the keys in the returned dict.
    field_names = ("input_ids", "labels", "attention_mask", "is_natural_stop")
    return {
        name: torch.stack([sample[name] for sample in batch])
        for name in field_names
    }


class HelixDataset(Dataset):
"""
Index-based Dataset with rolling chunking for language model pretraining.
Expand Down Expand Up @@ -481,17 +497,8 @@ def create_helix_dataloader(
) -> torch.utils.data.DataLoader:
dataset = HelixDataset(texts, tokenizer, seq_len, stride, lazy=lazy, **kwargs)

def collate_fn(batch):
input_ids = torch.stack([b["input_ids"] for b in batch])
labels = torch.stack([b["labels"] for b in batch])
attention_mask = torch.stack([b["attention_mask"] for b in batch])
is_natural_stop = torch.stack([b["is_natural_stop"] for b in batch])
return {
"input_ids": input_ids,
"labels": labels,
"attention_mask": attention_mask,
"is_natural_stop": is_natural_stop,
}
# Use module-level collate_fn for pickling with multiprocessing
collate_fn = _collate_batch

return torch.utils.data.DataLoader(
dataset,
Expand Down Expand Up @@ -531,17 +538,8 @@ def create_document_loader(
min_tail_len=min_tail_len, add_eos=add_eos, lazy=lazy, stride=stride,
)

def collate_fn(batch):
input_ids = torch.stack([b["input_ids"] for b in batch])
labels = torch.stack([b["labels"] for b in batch])
attention_mask = torch.stack([b["attention_mask"] for b in batch])
is_natural_stop = torch.stack([b["is_natural_stop"] for b in batch])
return {
"input_ids": input_ids,
"labels": labels,
"attention_mask": attention_mask,
"is_natural_stop": is_natural_stop,
}
# Use module-level collate_fn for pickling with multiprocessing
collate_fn = _collate_batch

# Use torch.Generator for deterministic shuffling with given seed
if shuffle:
Expand Down Expand Up @@ -733,17 +731,22 @@ class HelixShardedDataset(Dataset):

The shuffle implementation uses a Torch Generator (like List[str] path) for deterministic
reproducibility and identical behavior to non-streaming datasets.

OPTIMIZED: Multi-shard LRU cache for high-throughput training
"""
def __init__(
self,
shard_paths: List[str],
seq_len: int,
shuffle: bool = False,
seed: int = 42,
cache_size: int = 8, # Number of shards to keep in memory
):
super().__init__()
self.shard_paths = shard_paths
self.seq_len = seq_len
self.cache_size = cache_size
self._cache_lock = threading.RLock()

# Build shard index: cumulative offsets for O(1) __getitem__
self.shard_sizes = []
Expand All @@ -765,6 +768,7 @@ def __init__(
self.shard_offsets.append(total)

self.total_size = total
self.num_shards = len(shard_paths)

# Apply deterministic shuffle to index_map (only 2*int per sample, not full data)
if shuffle:
Expand All @@ -775,23 +779,77 @@ def __init__(
# Reorder index_map according to permutation
self._index_map = [self._index_map[i] for i in perm]

# Cache for current shard to avoid repeated disk reads
self._cache_shard_idx: Optional[int] = None
self._cache_shard_data: Optional[List] = None
# Multi-shard LRU cache: OrderedDict for O(1) move_to_end
# Keys: shard_idx, Values: shard_data (List of chunks)
self._shard_cache: OrderedDict[int, List] = OrderedDict()
self._cache_hits = 0
self._cache_misses = 0

def __len__(self) -> int:
return self.total_size

def get_cache_stats(self) -> Dict[str, Any]:
"""Return cache hit/miss statistics."""
total = self._cache_hits + self._cache_misses
hit_rate = self._cache_hits / total if total > 0 else 0
return {
"hits": self._cache_hits,
"misses": self._cache_misses,
"hit_rate": hit_rate,
"cached_shards": len(self._shard_cache),
"total_shards": self.num_shards,
}

def _load_shard(self, shard_idx: int) -> List:
"""Load shard data with caching."""
if self._cache_shard_idx == shard_idx:
return self._cache_shard_data
"""Load shard data with multi-shard LRU caching."""
with self._cache_lock:
# Check cache first
if shard_idx in self._shard_cache:
# Move to end (most recently used)
self._shard_cache.move_to_end(shard_idx)
self._cache_hits += 1
return self._shard_cache[shard_idx]

self._cache_misses += 1

# Load from disk (outside lock to allow concurrent loads)
import pickle
with open(self.shard_paths[shard_idx], 'rb') as f:
self._cache_shard_data = pickle.load(f)
self._cache_shard_idx = shard_idx
return self._cache_shard_data
shard_data = pickle.load(f)

# Add to cache with LRU eviction
with self._cache_lock:
# Evict oldest if at capacity
while len(self._shard_cache) >= self.cache_size:
self._shard_cache.popitem(last=False)

self._shard_cache[shard_idx] = shard_data
self._shard_cache.move_to_end(shard_idx)

return shard_data

def _prefetch_shard(self, shard_idx: int):
"""Background prefetch hint - loads shard if not in cache."""
if shard_idx < 0 or shard_idx >= self.num_shards:
return

with self._cache_lock:
if shard_idx in self._shard_cache:
return # Already cached

# Load without blocking current access
try:
import pickle
with open(self.shard_paths[shard_idx], 'rb') as f:
shard_data = pickle.load(f)

with self._cache_lock:
if shard_idx not in self._shard_cache:
while len(self._shard_cache) >= self.cache_size:
self._shard_cache.popitem(last=False)
self._shard_cache[shard_idx] = shard_data
except Exception:
pass # Silently fail prefetch

def _item_from_chunk(self, chunk_data: Tuple) -> Dict[str, torch.Tensor]:
"""Convert chunk tuple to sample dict (identical to HelixPrechunkedDataset and List[str] path)."""
Expand Down Expand Up @@ -830,6 +888,18 @@ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
shard_data = self._load_shard(shard_idx)
chunk_data = shard_data[local_idx]

# Prefetch next shard (sequential access pattern optimization)
# In shuffled mode, next access is random, but prefetch may help
next_shard_idx = shard_idx + 1
if next_shard_idx < self.num_shards:
# Use a thread for non-blocking prefetch
try:
t = threading.Thread(target=self._prefetch_shard, args=(next_shard_idx,))
t.daemon = True
t.start()
except Exception:
pass # Silently fail

return self._item_from_chunk(chunk_data)


Expand Down Expand Up @@ -1072,40 +1142,47 @@ def drain_completed_futures():
# to match the List[str] path behavior exactly
dataset = HelixShardedDataset(shard_paths, seq_len, shuffle=False, seed=seed)

def collate_fn(batch):
input_ids = torch.stack([b["input_ids"] for b in batch])
labels = torch.stack([b["labels"] for b in batch])
attention_mask = torch.stack([b["attention_mask"] for b in batch])
is_natural_stop = torch.stack([b["is_natural_stop"] for b in batch])
return {
"input_ids": input_ids,
"labels": labels,
"attention_mask": attention_mask,
"is_natural_stop": is_natural_stop,
}
# Use module-level collate_fn for pickling with multiprocessing
collate_fn = _collate_batch

# DataLoader with prefetching and persistent workers for high throughput
prefetch_factor = 4 if num_workers > 0 else None
persistent_workers = num_workers > 0

# Create DataLoader with proper shuffle
if shuffle:
generator = torch.Generator()
generator.manual_seed(seed)
loader = DataLoader(
dataset,
loader_kwargs = dict(
dataset=dataset,
batch_size=batch_size,
shuffle=True,
generator=generator,
collate_fn=collate_fn,
num_workers=num_workers,
drop_last=drop_last,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
)
# Linux optimization: use fork for faster worker spawning
import sys
if sys.platform == 'linux' and num_workers > 0:
loader_kwargs['multiprocessing_context'] = 'fork'
loader = DataLoader(**loader_kwargs)
else:
loader = DataLoader(
dataset,
loader_kwargs = dict(
dataset=dataset,
batch_size=batch_size,
shuffle=False,
collate_fn=collate_fn,
num_workers=num_workers,
drop_last=drop_last,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
)
import sys
if sys.platform == 'linux' and num_workers > 0:
loader_kwargs['multiprocessing_context'] = 'fork'
loader = DataLoader(**loader_kwargs)

return loader, shard_cache_dir

Expand Down
9 changes: 9 additions & 0 deletions helix_lm/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ def __init__(
preprocess_num_proc: int = 5,
preprocess_batch_size: int = 1000,
cleanup_shards: bool = True,
# DataLoader performance options
num_workers: int = 4, # Number of DataLoader workers for prefetching
):
"""
Initialize Trainer.
Expand Down Expand Up @@ -121,6 +123,9 @@ def __init__(
preprocess_num_proc: Number of processes for preprocessing (streaming only).
preprocess_batch_size: Batch size for streaming preprocessing.
cleanup_shards: Whether to auto-cleanup shards after training.
num_workers: Number of DataLoader worker processes for background data loading.
Higher values enable more prefetching but use more CPU/RAM.
Set to 0 to load data in the main process without worker processes (default: 4).
"""
# Apply intelligent stride default based on seq_len
if stride is None:
Expand Down Expand Up @@ -181,6 +186,7 @@ def __init__(
stride=stride,
shuffle=True,
drop_last=True,
num_workers=num_workers,
min_tail_len=min_tail_len,
seed=getattr(cfg, 'seed', 42), # Use cfg.seed for determinism
shard_cache_dir=shard_cache_dir,
Expand All @@ -200,6 +206,7 @@ def __init__(
cfg.seq_len,
cfg.batch_size,
shuffle=True,
num_workers=num_workers,
min_tail_len=min_tail_len,
seed=getattr(cfg, 'seed', 42), # Use cfg.seed for determinism
lazy=True,
Expand All @@ -222,6 +229,7 @@ def __init__(
stride=stride,
shuffle=False,
drop_last=False,
num_workers=num_workers,
min_tail_len=min_tail_len,
seed=getattr(cfg, 'seed', 42), # Use cfg.seed for determinism
shard_cache_dir=shard_cache_dir,
Expand All @@ -242,6 +250,7 @@ def __init__(
cfg.batch_size,
shuffle=False,
drop_last=False,
num_workers=num_workers,
min_tail_len=min_tail_len,
seed=getattr(cfg, 'seed', 42), # Use cfg.seed for determinism
lazy=True,
Expand Down
Loading