
from_pretrained distributed refactor (FSDP2 + TP) #44996

Merged
3outeille merged 13 commits into fsdp-core-model-loading from distributed-refactor on Mar 26, 2026

Conversation

@3outeille (Member) commented Mar 25, 2026

  • Introduce DistributedConfig
    • DistributedConfig(tp_size=2, fsdp_size=2) replaces passing separate tp_plan, tp_size, and fsdp_plan kwargs. Sizes auto-fill: specify one and the other defaults to 1. Plans default to "auto" when a size is given.
  • init_device_mesh + distribute_model
    • init_device_mesh(distributed_config) → builds a 2D DeviceMesh("fsdp", "tp"), inits process group if needed
    • distribute_model(model, distributed_config, device_mesh) → attaches TP hooks before weight loading
    • apply_fsdp2(model, mesh, plan) → wraps with FSDP2 after weight loading
  • Made sure accelerate and transformers.distributed don't overlap; each has its own code path:
                   distributed_config path          accelerate path
    Device setup   init_device_mesh()               check_and_set_device_map()
    Pre-loading    distribute_model() (TP hooks)    distribute_model() + _get_device_map()
    Post-loading   apply_fsdp2()                    accelerate_dispatch()

Note: apply_fsdp2 will be moved to pre-loading soon, since we don't want to apply it to full weights. A minimal sketch of how these pieces compose follows.
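
For orientation, here is a minimal sketch of how the pieces above are described to compose on the distributed_config path. The helper names and signatures are the ones listed in the bullets; the exact import paths and return values are assumptions, and from_pretrained runs these steps for you, so user code never calls them directly.

# Hypothetical sketch of the distributed_config loading flow summarized above.
config = DistributedConfig(tp_size=2, fsdp_size=2)   # plans default to "auto"

mesh = init_device_mesh(config)             # builds the 2D DeviceMesh ("fsdp", "tp"), inits the process group if needed
model = ...                                 # model is instantiated, weights not yet loaded
distribute_model(model, config, mesh)       # attaches TP hooks before weight loading
# ... checkpoint weights are loaded onto the TP-sharded model ...
apply_fsdp2(model, mesh, config.fsdp_plan)  # wraps with FSDP2 after loading (to move to pre-loading later)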

We can now train like this:

# torchrun --nproc_per_node=4 train_fsdp_tp.py

import os
import torch
import torch.distributed.checkpoint as dcp
from torch.utils.data import DataLoader, DistributedSampler
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.distributed import DistributedConfig

if __name__ == "__main__":

    model_name = "meta-llama/Llama-3.2-1B"
    num_steps, lr = 50, 3e-4
    save_dir = "./checkpoints"

    torch.distributed.init_process_group(backend="nccl")
    rank, local_rank = int(os.environ["RANK"]), int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
    tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token

    dataset = load_dataset("roneneldan/TinyStories", split="train[:1000]")
    dataset = dataset.map(lambda x: tokenizer(x["text"], truncation=True, padding="max_length", max_length=512), batched=True, remove_columns=dataset.column_names)
    dataset.set_format("torch")
    dataloader = DataLoader(dataset, batch_size=4, sampler=DistributedSampler(dataset))

    # 2-way TP x 2-way FSDP across the 4 ranks launched by torchrun
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        distributed_config=DistributedConfig(tp_size=2, tp_plan="auto", fsdp_size=2, fsdp_plan="auto"),
        torch_dtype=torch.bfloat16,
    )

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    model.train()
    for step, batch in enumerate(dataloader):
        if step >= num_steps:
            break
        input_ids = batch["input_ids"].to(f"cuda:{local_rank}")
        labels = input_ids.clone()
        labels[labels == tokenizer.pad_token_id] = -100

        loss = model(input_ids, labels=labels).loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if rank == 0 and step % 10 == 0:
            print(f"Step {step:>4d} | Loss: {loss.item():.4f}")

    # Save model (HF format) and optimizer (DCP)
    model.save_pretrained(save_dir)
    dcp.save({"optimizer": optimizer.state_dict()}, checkpoint_id=os.path.join(save_dir, "optimizer"))

    if rank == 0:
        tokenizer.save_pretrained(save_dir)
        print(f"Saved to {save_dir}")

    torch.distributed.destroy_process_group()
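
Not part of the PR or the script above, but as a follow-up illustration: restoring the saved state on the same 4-GPU layout could look roughly like this, assuming dcp.load mirrors the dcp.save call used above for the optimizer shards.

# Sketch: reload the model and optimizer on the same 2x2 (fsdp, tp) layout.
model = AutoModelForCausalLM.from_pretrained(
    save_dir,
    distributed_config=DistributedConfig(tp_size=2, fsdp_size=2),
    torch_dtype=torch.bfloat16,
)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

# torch.distributed.checkpoint loads in place into the provided state_dict.
state = {"optimizer": optimizer.state_dict()}
dcp.load(state, checkpoint_id=os.path.join(save_dir, "optimizer"))
optimizer.load_state_dict(state["optimizer"])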

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@3outeille force-pushed the distributed-refactor branch from 1436e31 to 4972eae on March 25, 2026 15:25
@3outeille changed the base branch from fsdp-vs-ddp to fsdp-core-model-loading on March 25, 2026 15:32
@3outeille force-pushed the distributed-refactor branch from 4972eae to 091eaa2 on March 25, 2026 15:33
@3outeille marked this pull request as ready for review on March 25, 2026 15:46
…dConfig)

- Expand DistributedConfig with tp_size, tp_plan, fsdp_size, fsdp_plan
- Add init_device_mesh() for building 2D DeviceMesh from DistributedConfig
- Reuse apply_fsdp2() from PR #44083 for FSDP2 fully_shard wrapping
- Rewire from_pretrained with two cleanly separated paths:
  1. distributed_config → native torch.distributed (no accelerate)
  2. Everything else → accelerate (unchanged)
- Export DistributedConfig from top-level transformers package
- Add unit tests for DistributedConfig
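
As an illustration of the last bullet, the auto-fill behavior could be covered by a test along these lines (attribute names and defaults are assumptions based on the PR summary, not the actual test file):

import unittest
from transformers.distributed import DistributedConfig

class DistributedConfigTest(unittest.TestCase):
    def test_size_and_plan_autofill(self):
        # Assumption: giving only tp_size leaves fsdp_size at its default of 1
        # and switches tp_plan to "auto", per the "sizes auto-fill" description.
        config = DistributedConfig(tp_size=2)
        self.assertEqual(config.tp_size, 2)
        self.assertEqual(config.fsdp_size, 1)
        self.assertEqual(config.tp_plan, "auto")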
@3outeille force-pushed the distributed-refactor branch from 502a09f to c758c59 on March 25, 2026 15:53
@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: clap, deit

@huggingface deleted a comment from amitmodi on Mar 26, 2026
@github-actions
Contributor

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=44996&sha=0b8b77

@3outeille merged commit 187ee5d into fsdp-core-model-loading on Mar 26, 2026
29 of 31 checks passed
@3outeille deleted the distributed-refactor branch on March 26, 2026 15:32