patch-fix-multiprocessing-deadlock-after-cuda- #44

Description

@david-thrower

Bug: Validation stalls at 0% when CUDA is enabled

Observed behavior

When training with CUDA, training completes normally, but validation stalls indefinitely after printing validation 0% ....

Root cause

This is a known PyTorch issue that occurs when DataLoader workers are started with the fork method after CUDA has been initialized. fork copies the parent process, including its CUDA context, and the CUDA runtime does not support using that context in a forked child. The child DataLoader workers inherit an unusable CUDA state and deadlock.

Additional problems this would have caused

  • Epoch-to-epoch shuffling: Re-instantiating the model or DataLoader between epochs (as the production script does) would hit the same deadlock.
  • Any post-CUDA worker creation: Creating new Trainer instances or DataLoader workers after torch.cuda is initialized would be affected.

Changes to make:

1. helix_lm/dataset.py: Changed multiprocessing context from fork to spawn

  • Fixes potential deadlocks when:
    • Creating new Trainers across epochs (your production script does this)
    • Any scenario where DataLoader workers are created after CUDA initialization

Why spawn instead of fork?

| Method | Behavior | CUDA safety |
| --- | --- | --- |
| fork | Copies parent process including CUDA context | ❌ Unsafe after CUDA init |
| spawn | Creates fresh Python processes | ✅ CUDA-safe (slightly slower startup) |
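As a minimal sketch of the spawn option in the table, the standard library lets you request a spawn context explicitly instead of changing the process-wide default start method. The helper name `spawn_context` is hypothetical and is not part of helix_lm.

```python
import multiprocessing as mp


def spawn_context():
    """Return a context whose child processes start as fresh interpreters.

    Hypothetical helper for illustration. get_context() returns an
    independent context object and leaves the global default untouched.
    """
    return mp.get_context("spawn")


ctx = spawn_context()
print(ctx.get_start_method())  # spawn
```

Because the context is a separate object, other code in the same process that relies on the platform's default start method is not affected.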

2. helix_lm/trainer.py: Added val_num_workers parameter

  • Added val_num_workers: Optional[int] = None (None is treated as 0)
  • Updated docstring to document both num_workers (training) and val_num_workers (validation)
  • Validation defaults to 0 workers for safety and minimal overhead
  • Users can override with val_num_workers=4 (or any value) for large validation sets
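The defaulting rule described above could be sketched roughly as follows. `WorkerSettings` and `resolved_val_workers` are hypothetical names used only for illustration, and the training default of 4 is an assumption. The actual Trainer signature in helix_lm/trainer.py may differ.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class WorkerSettings:
    """Hypothetical stand-in for the Trainer's worker-related arguments."""

    num_workers: int = 4                   # training workers (assumed default)
    val_num_workers: Optional[int] = None  # None means "use the safe default"

    def resolved_val_workers(self) -> int:
        """Return the worker count to use for the validation DataLoader."""
        if self.val_num_workers is None:
            return 0  # load validation batches in the main process
        if self.val_num_workers < 0:
            raise ValueError("val_num_workers must be >= 0")
        return self.val_num_workers


print(WorkerSettings().resolved_val_workers())                   # 0
print(WorkerSettings(val_num_workers=4).resolved_val_workers())  # 4
```

Keeping None as the sentinel lets the Trainer distinguish "not specified" from an explicit 0.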

3. helix_lm/dataset.py: Applied multiprocessing_context='spawn' globally

Changed multiprocessing context from fork to spawn in:

  1. create_document_loader() — for non-streaming data
  2. _handle_streaming_iterable() — for streaming data

This makes workers CUDA-safe regardless of when they're created.
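One detail worth noting when applying the context everywhere: PyTorch's DataLoader raises a ValueError if multiprocessing_context is set while num_workers is 0, so the argument should only be passed when workers are actually used. A minimal sketch of that guard, where `loader_kwargs` is a hypothetical helper and not the actual code in helix_lm/dataset.py:

```python
from typing import Any, Dict


def loader_kwargs(num_workers: int, batch_size: int = 8) -> Dict[str, Any]:
    """Build DataLoader keyword arguments that are safe after CUDA init.

    Hypothetical helper for illustration. The spawn context is added only
    when worker processes will be created.
    """
    kwargs: Dict[str, Any] = {
        "batch_size": batch_size,
        "num_workers": num_workers,
    }
    if num_workers > 0:
        # Workers start as fresh interpreters and do not inherit CUDA state.
        kwargs["multiprocessing_context"] = "spawn"
    return kwargs


print(loader_kwargs(0))
print(loader_kwargs(4))
```

The result would be passed as `torch.utils.data.DataLoader(dataset, **loader_kwargs(4))`. With spawn, the dataset and any collate function must be picklable.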


Files modified

| File | Change |
| --- | --- |
| helix_lm/trainer.py | Uses num_workers=0 for validation internally; adds val_num_workers parameter |
| helix_lm/dataset.py | Uses multiprocessing_context='spawn' for all DataLoaders with workers |

Issue 2: Incorrect tokenizer instantiation (an AI hallucination overlooked during code review)

  • In 87M_param_production_trainer.py and 400M_production_trainer.py: Errors in tokenizer instantiation:

When loading tokenizers for HelixLM, replace:

from transformers import AutoTokenizer
...
tokenizer = AutoTokenizer.from_pretrained("gpt2")

with:

from helix_lm import HelixTokenizer
...
tokenizer = HelixTokenizer("gpt2")

AutoTokenizer.from_pretrained("gpt2") does not include the extensions we added for forward-compatibility with instruct fine-tuning (e.g., the Qwen3-compatible chat template). Using HelixTokenizer ensures these templates are available in downstream training stages using the same tokenizer.
