From d5c7d249aa73e255702bcd2dfe5b20623a4df09b Mon Sep 17 00:00:00 2001 From: David Thrower Date: Tue, 16 Jun 2026 23:36:50 +0000 Subject: [PATCH] Optimize disk to model transit for streaming data path. --- 87M_param_production_trainer.py | 1 + helix_lm/dataset.py | 171 +++++++++++++++++++++++--------- helix_lm/trainer.py | 9 ++ 3 files changed, 134 insertions(+), 47 deletions(-) diff --git a/87M_param_production_trainer.py b/87M_param_production_trainer.py index 99f0c6f..1767f07 100644 --- a/87M_param_production_trainer.py +++ b/87M_param_production_trainer.py @@ -340,6 +340,7 @@ def main(): # Streaming-specific options preprocess_batch_size=PREPROCESS_BATCH_SIZE, cleanup_shards=CLEANUP_SHARDS, + num_workers=4, # DataLoader workers for prefetching ) # Constant LR: cosine with min_lr_ratio=1.0 = flat after warmup diff --git a/helix_lm/dataset.py b/helix_lm/dataset.py index b41771e..22d7b96 100644 --- a/helix_lm/dataset.py +++ b/helix_lm/dataset.py @@ -13,13 +13,29 @@ Compatible with both eager and lazy loading. """ import random +import threading from typing import List, Optional, Iterator, Dict, Any, Union, Tuple +from collections import OrderedDict import torch from torch.utils.data import IterableDataset, Dataset, DataLoader from tqdm import tqdm +def _collate_batch(batch): + """Module-level collate function for pickling with multiprocessing.""" + input_ids = torch.stack([b["input_ids"] for b in batch]) + labels = torch.stack([b["labels"] for b in batch]) + attention_mask = torch.stack([b["attention_mask"] for b in batch]) + is_natural_stop = torch.stack([b["is_natural_stop"] for b in batch]) + return { + "input_ids": input_ids, + "labels": labels, + "attention_mask": attention_mask, + "is_natural_stop": is_natural_stop, + } + + class HelixDataset(Dataset): """ Index-based Dataset with rolling chunking for language model pretraining. 
@@ -481,17 +497,8 @@ def create_helix_dataloader( ) -> torch.utils.data.DataLoader: dataset = HelixDataset(texts, tokenizer, seq_len, stride, lazy=lazy, **kwargs) - def collate_fn(batch): - input_ids = torch.stack([b["input_ids"] for b in batch]) - labels = torch.stack([b["labels"] for b in batch]) - attention_mask = torch.stack([b["attention_mask"] for b in batch]) - is_natural_stop = torch.stack([b["is_natural_stop"] for b in batch]) - return { - "input_ids": input_ids, - "labels": labels, - "attention_mask": attention_mask, - "is_natural_stop": is_natural_stop, - } + # Use module-level collate_fn for pickling with multiprocessing + collate_fn = _collate_batch return torch.utils.data.DataLoader( dataset, @@ -531,17 +538,8 @@ def create_document_loader( min_tail_len=min_tail_len, add_eos=add_eos, lazy=lazy, stride=stride, ) - def collate_fn(batch): - input_ids = torch.stack([b["input_ids"] for b in batch]) - labels = torch.stack([b["labels"] for b in batch]) - attention_mask = torch.stack([b["attention_mask"] for b in batch]) - is_natural_stop = torch.stack([b["is_natural_stop"] for b in batch]) - return { - "input_ids": input_ids, - "labels": labels, - "attention_mask": attention_mask, - "is_natural_stop": is_natural_stop, - } + # Use module-level collate_fn for pickling with multiprocessing + collate_fn = _collate_batch # Use torch.Generator for deterministic shuffling with given seed if shuffle: @@ -733,6 +731,8 @@ class HelixShardedDataset(Dataset): The shuffle implementation uses a Torch Generator (like List[str] path) for deterministic reproducibility and identical behavior to non-streaming datasets. 
+ + OPTIMIZED: Multi-shard LRU cache for high-throughput training """ def __init__( self, @@ -740,10 +740,13 @@ def __init__( seq_len: int, shuffle: bool = False, seed: int = 42, + cache_size: int = 8, # Number of shards to keep in memory ): super().__init__() self.shard_paths = shard_paths self.seq_len = seq_len + self.cache_size = cache_size + self._cache_lock = threading.RLock() # Build shard index: cumulative offsets for O(1) __getitem__ self.shard_sizes = [] @@ -765,6 +768,7 @@ def __init__( self.shard_offsets.append(total) self.total_size = total + self.num_shards = len(shard_paths) # Apply deterministic shuffle to index_map (only 2*int per sample, not full data) if shuffle: @@ -775,23 +779,77 @@ def __init__( # Reorder index_map according to permutation self._index_map = [self._index_map[i] for i in perm] - # Cache for current shard to avoid repeated disk reads - self._cache_shard_idx: Optional[int] = None - self._cache_shard_data: Optional[List] = None + # Multi-shard LRU cache: OrderedDict for O(1) move_to_end + # Keys: shard_idx, Values: shard_data (List of chunks) + self._shard_cache: OrderedDict[int, List] = OrderedDict() + self._cache_hits = 0 + self._cache_misses = 0 def __len__(self) -> int: return self.total_size + def get_cache_stats(self) -> Dict[str, Any]: + """Return cache hit/miss statistics.""" + total = self._cache_hits + self._cache_misses + hit_rate = self._cache_hits / total if total > 0 else 0 + return { + "hits": self._cache_hits, + "misses": self._cache_misses, + "hit_rate": hit_rate, + "cached_shards": len(self._shard_cache), + "total_shards": self.num_shards, + } + def _load_shard(self, shard_idx: int) -> List: - """Load shard data with caching.""" - if self._cache_shard_idx == shard_idx: - return self._cache_shard_data + """Load shard data with multi-shard LRU caching.""" + with self._cache_lock: + # Check cache first + if shard_idx in self._shard_cache: + # Move to end (most recently used) + self._shard_cache.move_to_end(shard_idx) 
+ self._cache_hits += 1 + return self._shard_cache[shard_idx] + + self._cache_misses += 1 + # Load from disk (outside lock to allow concurrent loads) import pickle with open(self.shard_paths[shard_idx], 'rb') as f: - self._cache_shard_data = pickle.load(f) - self._cache_shard_idx = shard_idx - return self._cache_shard_data + shard_data = pickle.load(f) + + # Add to cache with LRU eviction + with self._cache_lock: + # Evict oldest if at capacity + while len(self._shard_cache) >= self.cache_size: + self._shard_cache.popitem(last=False) + + self._shard_cache[shard_idx] = shard_data + self._shard_cache.move_to_end(shard_idx) + + return shard_data + + def _prefetch_shard(self, shard_idx: int): + """Background prefetch hint - loads shard if not in cache.""" + if shard_idx < 0 or shard_idx >= self.num_shards: + return + + with self._cache_lock: + if shard_idx in self._shard_cache: + return # Already cached + + # Load without blocking current access + try: + import pickle + with open(self.shard_paths[shard_idx], 'rb') as f: + shard_data = pickle.load(f) + + with self._cache_lock: + if shard_idx not in self._shard_cache: + while len(self._shard_cache) >= self.cache_size: + self._shard_cache.popitem(last=False) + self._shard_cache[shard_idx] = shard_data + except Exception: + pass # Silently fail prefetch def _item_from_chunk(self, chunk_data: Tuple) -> Dict[str, torch.Tensor]: """Convert chunk tuple to sample dict (identical to HelixPrechunkedDataset and List[str] path).""" @@ -830,6 +888,18 @@ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: shard_data = self._load_shard(shard_idx) chunk_data = shard_data[local_idx] + # Prefetch next shard (sequential access pattern optimization) + # In shuffled mode, next access is random, but prefetch may help + next_shard_idx = shard_idx + 1 + if next_shard_idx < self.num_shards: + # Use a thread for non-blocking prefetch + try: + t = threading.Thread(target=self._prefetch_shard, args=(next_shard_idx,)) + t.daemon = True + 
t.start() + except Exception: + pass # Silently fail + return self._item_from_chunk(chunk_data) @@ -1072,40 +1142,47 @@ def drain_completed_futures(): # to match the List[str] path behavior exactly dataset = HelixShardedDataset(shard_paths, seq_len, shuffle=False, seed=seed) - def collate_fn(batch): - input_ids = torch.stack([b["input_ids"] for b in batch]) - labels = torch.stack([b["labels"] for b in batch]) - attention_mask = torch.stack([b["attention_mask"] for b in batch]) - is_natural_stop = torch.stack([b["is_natural_stop"] for b in batch]) - return { - "input_ids": input_ids, - "labels": labels, - "attention_mask": attention_mask, - "is_natural_stop": is_natural_stop, - } + # Use module-level collate_fn for pickling with multiprocessing + collate_fn = _collate_batch + + # DataLoader with prefetching and persistent workers for high throughput + prefetch_factor = 4 if num_workers > 0 else None + persistent_workers = num_workers > 0 - # Create DataLoader with proper shuffle if shuffle: generator = torch.Generator() generator.manual_seed(seed) - loader = DataLoader( - dataset, + loader_kwargs = dict( + dataset=dataset, batch_size=batch_size, shuffle=True, generator=generator, collate_fn=collate_fn, num_workers=num_workers, drop_last=drop_last, + prefetch_factor=prefetch_factor, + persistent_workers=persistent_workers, ) + # Linux optimization: use fork for faster worker spawning + import sys + if sys.platform == 'linux' and num_workers > 0: + loader_kwargs['multiprocessing_context'] = 'fork' + loader = DataLoader(**loader_kwargs) else: - loader = DataLoader( - dataset, + loader_kwargs = dict( + dataset=dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn, num_workers=num_workers, drop_last=drop_last, + prefetch_factor=prefetch_factor, + persistent_workers=persistent_workers, ) + import sys + if sys.platform == 'linux' and num_workers > 0: + loader_kwargs['multiprocessing_context'] = 'fork' + loader = DataLoader(**loader_kwargs) return loader, 
shard_cache_dir diff --git a/helix_lm/trainer.py b/helix_lm/trainer.py index bb5e20c..a336f39 100644 --- a/helix_lm/trainer.py +++ b/helix_lm/trainer.py @@ -92,6 +92,8 @@ def __init__( preprocess_num_proc: int = 5, preprocess_batch_size: int = 1000, cleanup_shards: bool = True, + # DataLoader performance options + num_workers: int = 4, # Number of DataLoader workers for prefetching ): """ Initialize Trainer. @@ -121,6 +123,9 @@ def __init__( preprocess_num_proc: Number of processes for preprocessing (streaming only). preprocess_batch_size: Batch size for streaming preprocessing. cleanup_shards: Whether to auto-cleanup shards after training. + num_workers: Number of DataLoader worker processes for background data loading. + Higher values enable more prefetching but use more CPU/RAM. + Set to 0 for single-threaded loading (default: 4). """ # Apply intelligent stride default based on seq_len if stride is None: @@ -181,6 +186,7 @@ def __init__( stride=stride, shuffle=True, drop_last=True, + num_workers=num_workers, min_tail_len=min_tail_len, seed=getattr(cfg, 'seed', 42), # Use cfg.seed for determinism shard_cache_dir=shard_cache_dir, @@ -200,6 +206,7 @@ def __init__( cfg.seq_len, cfg.batch_size, shuffle=True, + num_workers=num_workers, min_tail_len=min_tail_len, seed=getattr(cfg, 'seed', 42), # Use cfg.seed for determinism lazy=True, @@ -222,6 +229,7 @@ def __init__( stride=stride, shuffle=False, drop_last=False, + num_workers=num_workers, min_tail_len=min_tail_len, seed=getattr(cfg, 'seed', 42), # Use cfg.seed for determinism shard_cache_dir=shard_cache_dir, @@ -242,6 +250,7 @@ def __init__( cfg.batch_size, shuffle=False, drop_last=False, + num_workers=num_workers, min_tail_len=min_tail_len, seed=getattr(cfg, 'seed', 42), # Use cfg.seed for determinism lazy=True,