Bug: Validation stalls at 0% when CUDA is enabled
Observed behavior
When running on CUDA, the training phase completes normally, but validation stalls indefinitely after printing `validation 0% ...`.
Root cause
This is a known PyTorch issue when using multiprocessing with `fork` after CUDA has been initialized. `fork` copies the parent process, including its already-initialized CUDA context, and the CUDA runtime does not support being used from a forked child. The child DataLoader workers inherit this invalid CUDA state and deadlock.
Additional problems this would have caused
- Epoch-to-epoch shuffling: re-instantiating the model or DataLoader between epochs (as the production script does) would hit the same deadlock.
- Any post-CUDA worker creation: creating new `Trainer` instances or DataLoader workers after `torch.cuda` is initialized would be affected.
Changes to make:
1. `helix_lm/dataset.py`: Changed multiprocessing context from `fork` to `spawn`
   - Fixes potential deadlocks when:
     - Creating new `Trainer` instances across epochs (your production script does this)
     - Any scenario where DataLoader workers are created after CUDA initialization
Why `spawn` instead of `fork`?

| Method | Behavior | CUDA safety |
|---|---|---|
| `fork` | Copies the parent process, including its CUDA context | ❌ Unsafe after CUDA init |
| `spawn` | Creates fresh Python processes | ✅ CUDA-safe (slightly slower startup) |
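In Python's standard library each start method in the table is a separate context object, and `torch.utils.data.DataLoader` accepts either the method name as a string or such a context object through its `multiprocessing_context` argument. A minimal sketch using only the standard library; `cuda_safe_context` is a hypothetical helper name, not part of `helix_lm`:

```python
import multiprocessing
from multiprocessing.context import BaseContext


def cuda_safe_context() -> BaseContext:
    """Return a context that starts workers without forking the parent.

    Hypothetical helper: 'spawn' launches each worker as a fresh
    interpreter, so no CUDA state is inherited from the parent process.
    'spawn' is available on every platform Python supports.
    """
    return multiprocessing.get_context("spawn")


ctx = cuda_safe_context()
print(ctx.get_start_method())  # spawn

# The process-wide default depends on platform and Python version
# (historically 'fork' on Linux), so relying on it is fragile once
# CUDA has been initialized in the parent.
print(multiprocessing.get_all_start_methods())
```

Requesting the context explicitly, rather than changing the global default with `multiprocessing.set_start_method`, keeps the choice local to the DataLoaders that need it.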
2. `helix_lm/trainer.py`: Added `val_num_workers` parameter
   - Added `val_num_workers: Optional[int] = None` (`None` falls back to `0`)
   - Updated the docstring to document both `num_workers` (training) and `val_num_workers` (validation)
   - Validation defaults to `0` workers for safety and minimal overhead
   - Users can override this with `val_num_workers=4` (or any non-negative value) for large validation sets
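The `None`-falls-back-to-`0` behavior described above can be sketched as follows. This is a hypothetical stand-in, not the actual `Trainer` code: the class name `WorkerSettings`, the method `resolved_val_num_workers`, and the training default of `0` are assumptions for illustration only.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class WorkerSettings:
    """Hypothetical stand-in for the Trainer's DataLoader worker options."""

    num_workers: int = 0  # training workers (assumed default, illustration only)
    val_num_workers: Optional[int] = None  # validation workers; None -> 0

    def resolved_val_num_workers(self) -> int:
        # Unset means "load validation batches in the main process".
        if self.val_num_workers is None:
            return 0
        if self.val_num_workers < 0:
            raise ValueError("val_num_workers must be >= 0")
        return self.val_num_workers


print(WorkerSettings(num_workers=8).resolved_val_num_workers())  # 0
print(WorkerSettings(val_num_workers=4).resolved_val_num_workers())  # 4
```

With `0` workers, validation batches are loaded in the main process, so no worker processes are started for validation at all and there is nothing to fork or spawn.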
3. `helix_lm/dataset.py`: Applied `multiprocessing_context='spawn'` globally
   - Changed the multiprocessing context from `fork` to `spawn` in:
     - `create_document_loader()`: for non-streaming data
     - `_handle_streaming_iterable()`: for streaming data
   - This makes workers CUDA-safe regardless of when they are created.
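One detail worth handling when applying the context to every loader: `torch.utils.data.DataLoader` raises a `ValueError` if `multiprocessing_context` is set while `num_workers=0`, so the context should only be passed when worker processes will actually start. A minimal sketch of a shared keyword-argument builder; `worker_kwargs` is a hypothetical helper, not a function that exists in `helix_lm/dataset.py`:

```python
from typing import Any, Dict


def worker_kwargs(num_workers: int) -> Dict[str, Any]:
    """Build the worker-related DataLoader arguments (hypothetical helper).

    DataLoader rejects multiprocessing_context when num_workers == 0, so
    'spawn' is only requested when separate worker processes will run.
    """
    if num_workers < 0:
        raise ValueError("num_workers must be >= 0")
    kwargs: Dict[str, Any] = {"num_workers": num_workers}
    if num_workers > 0:
        kwargs["multiprocessing_context"] = "spawn"
    return kwargs


print(worker_kwargs(0))  # {'num_workers': 0}
print(worker_kwargs(4))  # {'num_workers': 4, 'multiprocessing_context': 'spawn'}
```

Both loader factories could then call something like `DataLoader(dataset, batch_size=batch_size, **worker_kwargs(num_workers))`. Under `spawn`, the dataset and `collate_fn` are pickled and sent to each worker, so they must be picklable (no lambdas or open file handles), and the training script's entry point should be guarded with `if __name__ == "__main__":` because each spawned worker re-imports the main module.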
Files modified

| File | Change |
|---|---|
| `helix_lm/trainer.py` | Uses `num_workers=0` for validation by default; adds `val_num_workers` parameter |
| `helix_lm/dataset.py` | Uses `multiprocessing_context='spawn'` for all DataLoaders with workers |
Issue 2 (AI hallucination overlooked during code review):
- In `87M_param_production_trainer.py` and `400M_production_trainer.py`: errors in tokenizer instantiation.
When loading tokenizers for HelixLM, replace:

```python
from transformers import AutoTokenizer
...
tokenizer = AutoTokenizer.from_pretrained("gpt2")
```

with:

```python
from helix_lm import HelixTokenizer
...
tokenizer = HelixTokenizer("gpt2")
```
`AutoTokenizer.from_pretrained("gpt2")` does not include the extensions we added for forward compatibility with instruct fine-tuning (e.g., the Qwen3-compatible chat template). Using `HelixTokenizer` ensures these templates are available in downstream training stages that reuse the same tokenizer.