
Feat/megatron/support sft native #701

Draft
wenxie-amd wants to merge 36 commits into main from feat/megatron/support-sft-native

Conversation

@wenxie-amd
Collaborator

No description provided.

Copilot AI and others added 30 commits January 30, 2026 04:38
Co-authored-by: Xiaoming-AMD <198007710+Xiaoming-AMD@users.noreply.github.com>
Co-authored-by: Xiaoming-AMD <198007710+Xiaoming-AMD@users.noreply.github.com>
Co-authored-by: Xiaoming-AMD <198007710+Xiaoming-AMD@users.noreply.github.com>
…nd field extraction

Co-authored-by: Xiaoming-AMD <198007710+Xiaoming-AMD@users.noreply.github.com>
Co-authored-by: Xiaoming-AMD <198007710+Xiaoming-AMD@users.noreply.github.com>
…r registration

Co-authored-by: Xiaoming-AMD <198007710+Xiaoming-AMD@users.noreply.github.com>
Co-authored-by: Xiaoming-AMD <198007710+Xiaoming-AMD@users.noreply.github.com>
Co-authored-by: Xiaoming-AMD <198007710+Xiaoming-AMD@users.noreply.github.com>
- Add load_jsonl_file() function to load data from JSONL files
- Update SFTDataset to detect and load local files (JSONL/JSON)
- Support both .jsonl and .json file formats
- Maintain backward compatibility with HuggingFace datasets
- Update documentation with offline dataset examples
- Add comprehensive unit tests for offline loading

Co-authored-by: Xiaoming-AMD <198007710+Xiaoming-AMD@users.noreply.github.com>
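For reference, a minimal sketch of what the JSONL loader described above could look like; `load_jsonl_file` is the name given in the commit message, but the signature and error handling shown here are illustrative rather than the PR's exact code.

```python
import json
from pathlib import Path
from typing import Any, Dict, List


def load_jsonl_file(path: str) -> List[Dict[str, Any]]:
    """Load one JSON object per line from a local .jsonl file."""
    records: List[Dict[str, Any]] = []
    with Path(path).open("r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue  # tolerate blank lines
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError as exc:
                raise ValueError(f"Invalid JSON on line {line_no} of {path}") from exc
    return records
```

A dataset class can then branch on whether the configured name points at an existing local path (load the file) or a HuggingFace hub id (download as before), which is how backward compatibility is preserved.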
- Add comprehensive offline dataset guide
- Add script to convert HuggingFace datasets to JSONL
- Add CSV to JSONL conversion support
- Provide examples and troubleshooting tips

Co-authored-by: Xiaoming-AMD <198007710+Xiaoming-AMD@users.noreply.github.com>
Co-authored-by: Xiaoming-AMD <198007710+Xiaoming-AMD@users.noreply.github.com>
- Add OpenAIMessagesFormatter class for multi-turn conversations
- Support messages field with role-content pairs (system/user/assistant)
- Implement specialized loss masking for multi-turn (only on assistant)
- Update SFTDataset to detect and handle messages format
- Add formatter options: "openai" and "messages"
- Create comprehensive tests for messages format
- Update documentation with examples and usage

Co-authored-by: Xiaoming-AMD <198007710+Xiaoming-AMD@users.noreply.github.com>
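To make the assistant-only loss masking concrete, here is a small sketch of the idea under assumed names (`build_segments` and the chat template string are illustrative, not the formatter API added by this PR): each message is rendered to text and only assistant segments are flagged for supervision, so the loss mask later covers assistant tokens only.

```python
from typing import Dict, List, Tuple


def build_segments(messages: List[Dict[str, str]]) -> List[Tuple[str, bool]]:
    """Render a conversation into (text, supervise) segments."""
    segments: List[Tuple[str, bool]] = []
    for message in messages:
        role, content = message["role"], message["content"]
        text = f"<|{role}|>\n{content}\n"  # illustrative template, not the PR's
        segments.append((text, role == "assistant"))  # supervise assistant turns only
    return segments


segments = build_segments([
    {"role": "system", "content": "You are helpful."},
    {"role": "user", "content": "Hi!"},
    {"role": "assistant", "content": "Hello! How can I help?"},
])
# Only the final (assistant) segment is marked supervise=True.
```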
…ions

- Create MULTI_TURN_CONVERSATIONS.md guide with examples and best practices
- Add example multi-turn conversation data file
- Create example config for multi-turn training
- Include conversion scripts and troubleshooting tips
- Document loss masking behavior for multi-turn dialogues

Co-authored-by: Xiaoming-AMD <198007710+Xiaoming-AMD@users.noreply.github.com>
Co-authored-by: Xiaoming-AMD <198007710+Xiaoming-AMD@users.noreply.github.com>
- Add _tokenize_text() helper method to handle different tokenizer interfaces
- Support tokenizers that return IDs directly from tokenize() (Megatron style)
- Support standard tokenize() + convert_tokens_to_ids() pattern
- Support encode() method as fallback
- Update _tokenize_and_mask() and _tokenize_and_mask_messages() to use new helper
- Fixes AttributeError: '_HuggingFaceTokenizer' object has no attribute 'convert_tokens_to_ids'

Co-authored-by: Xiaoming-AMD <198007710+Xiaoming-AMD@users.noreply.github.com>
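A sketch of the dispatch order described in this commit, assuming a free function of this shape (the actual `_tokenize_text()` is a method on the dataset and may differ in detail): try a Megatron-style `tokenize()` that returns IDs, then the HuggingFace `tokenize()` + `convert_tokens_to_ids()` pair, then `encode()` as a fallback.

```python
from typing import List


def tokenize_text(tokenizer, text: str) -> List[int]:
    """Return token IDs for text, tolerating different tokenizer interfaces."""
    if hasattr(tokenizer, "tokenize"):
        tokens = tokenizer.tokenize(text)
        if tokens and isinstance(tokens[0], int):
            return list(tokens)  # Megatron-style: tokenize() already returns IDs
        if hasattr(tokenizer, "convert_tokens_to_ids"):
            return list(tokenizer.convert_tokens_to_ids(tokens))  # HF-style pair
    if hasattr(tokenizer, "encode"):
        return list(tokenizer.encode(text))  # generic fallback
    raise TypeError(f"Unsupported tokenizer type: {type(tokenizer)!r}")
```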
- Create TOKENIZER_FIX.md with detailed explanation
- Update IMPLEMENTATION_SUMMARY.md with fix information
- Document the problem, solution, and implementation details
- Explain backward compatibility and testing

Co-authored-by: Xiaoming-AMD <198007710+Xiaoming-AMD@users.noreply.github.com>
- Add position_ids generation before model forward call
- Create position_ids tensor with shape [batch_size, seq_len]
- Use torch.arange to create standard position encoding [0, 1, 2, ..., seq_len-1]
- Ensure position_ids is on same device as tokens
- Pass position_ids as positional argument to model
- Fixes TypeError: GPTModel.forward() missing 1 required positional argument: 'position_ids'

Co-authored-by: Xiaoming-AMD <198007710+Xiaoming-AMD@users.noreply.github.com>
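For illustration, a minimal sketch of the position_ids construction described above (the batch shapes here are placeholders; in the trainer they come from the data iterator):

```python
import torch

batch_size, seq_len = 2, 8  # placeholder shapes
tokens = torch.randint(0, 32000, (batch_size, seq_len))

# Standard positions [0, 1, ..., seq_len-1], one row per sample,
# created on the same device as the tokens.
position_ids = (
    torch.arange(seq_len, dtype=torch.long, device=tokens.device)
    .unsqueeze(0)
    .expand(batch_size, -1)
)

# The model call then becomes, e.g.:
# output_tensor = model(tokens, position_ids, attention_mask)
```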
- Create POSITION_IDS_FIX.md with detailed explanation
- Update IMPLEMENTATION_SUMMARY.md with latest fix
- Document the problem, solution, and implementation details
- Explain position_ids structure and why it's needed

Co-authored-by: Xiaoming-AMD <198007710+Xiaoming-AMD@users.noreply.github.com>
…call

- Remove labels parameter from model forward call to get logits instead of loss
- Model now returns logits tensor which loss_func can process correctly
- Custom loss_func applies proper masking to compute SFT loss
- Fixes type mismatch error where loss_func expected logits but got scalar loss
- Aligns with standard Megatron SFT pattern of separating forward pass and loss computation

Co-authored-by: Xiaoming-AMD <198007710+Xiaoming-AMD@users.noreply.github.com>
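The resulting split between forward pass and loss computation can be sketched as follows; this is illustrative, not the PR's exact `loss_func` (which also returns a token count and a reporting dict for logging):

```python
import torch
import torch.nn.functional as F


def sft_loss_from_logits(logits: torch.Tensor,
                         labels: torch.Tensor,
                         loss_mask: torch.Tensor) -> torch.Tensor:
    """Cross-entropy over supervised (response) tokens only.

    logits:    [batch, seq, vocab] returned by the model (no labels passed to forward)
    labels:    [batch, seq] target token IDs
    loss_mask: [batch, seq] with 1 on response positions, 0 elsewhere
    """
    per_token = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        labels.reshape(-1),
        reduction="none",
    )
    mask = loss_mask.reshape(-1).float()
    return (per_token * mask).sum() / mask.sum().clamp(min=1.0)
```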
- Create LOSS_COMPUTATION_FIX.md with detailed explanation
- Update IMPLEMENTATION_SUMMARY.md with latest fix
- Document the problem, solution, and benefits
- Explain why removing labels parameter is correct

Co-authored-by: Xiaoming-AMD <198007710+Xiaoming-AMD@users.noreply.github.com>
…l for data parallel ops

- Change import from tensor_parallel to parallel_state at line 159
- Update get_data_parallel_world_size() call to use parallel_state
- Update get_data_parallel_group() call to use parallel_state
- Remove unused tensor_parallel import from loss_func
- Fixes RuntimeError: module 'megatron.core.tensor_parallel' has no attribute 'get_data_parallel_world_size'

Co-authored-by: Xiaoming-AMD <198007710+Xiaoming-AMD@users.noreply.github.com>
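In short, data-parallel topology queries live on `megatron.core.parallel_state`, not on `megatron.core.tensor_parallel`. A minimal usage sketch (assumes Megatron-LM is installed and its parallel state has been initialized):

```python
from megatron.core import parallel_state

# Both calls exist on parallel_state; looking them up on tensor_parallel
# raises the AttributeError quoted above.
dp_world_size = parallel_state.get_data_parallel_world_size()
dp_group = parallel_state.get_data_parallel_group()
```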
- Create PARALLEL_STATE_FIX.md with detailed explanation
- Update IMPLEMENTATION_SUMMARY.md with latest fix
- Document correct Megatron-LM API usage
- Explain tensor_parallel vs parallel_state modules

Co-authored-by: Xiaoming-AMD <198007710+Xiaoming-AMD@users.noreply.github.com>
- Create FORWARD_STEP_DESIGN_DECISION.md with comprehensive explanation
- Explain why we don't use megatron.bridge.training.vlm_step.forward_step
- Document original requirement: no Megatron-Bridge dependency
- List benefits of custom implementation vs external dependency
- Compare SFT-specific needs vs general VLM forward_step
- Provide guidance on when to use each approach
- Update README_SFT.md to reference design decision
- Update IMPLEMENTATION_SUMMARY.md with design note

Co-authored-by: Xiaoming-AMD <198007710+Xiaoming-AMD@users.noreply.github.com>
- Port best practices from Megatron-Bridge forward_step patterns
- Add comprehensive error handling for data iteration and batch extraction
- Add proper shape validation and assertion checks
- Track num_tokens for accurate logging
- Add causal attention mask comment (handled internally by model)
- Fix data parallel loss averaging (divide by world_size after all_reduce)
- Add ignore_index=-100 to cross_entropy for standard padding handling
- Improve docstrings with detailed parameter and return value descriptions
- Add comments explaining each step of the process
- Maintain SFT-specific loss masking on response tokens only

Co-authored-by: Xiaoming-AMD <198007710+Xiaoming-AMD@users.noreply.github.com>
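The loss-averaging fix amounts to all-reducing the loss across the data-parallel group and then dividing by the data-parallel world size, so every rank reports the true mean rather than the sum. A hedged sketch (the function name is illustrative; assumes torch.distributed and Megatron's parallel state are initialized):

```python
import torch
import torch.distributed as dist
from megatron.core import parallel_state


def average_loss_across_data_parallel(local_loss: torch.Tensor) -> torch.Tensor:
    """All-reduce the loss over the DP group, then divide by the DP world size."""
    reporting_loss = local_loss.detach().clone()
    dist.all_reduce(reporting_loss, group=parallel_state.get_data_parallel_group())
    return reporting_loss / parallel_state.get_data_parallel_world_size()
```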
- Create FORWARD_STEP_ENHANCEMENTS.md with comprehensive documentation
- Explain what was ported from Megatron-Bridge patterns
- Document all enhancements: error handling, validation, loss averaging fix
- Highlight critical data parallel loss averaging bug fix
- Show before/after code comparisons
- Document benefits of each enhancement
- Update IMPLEMENTATION_SUMMARY.md to reference enhancements document
- Provide testing guidance and future enhancement suggestions

Co-authored-by: Xiaoming-AMD <198007710+Xiaoming-AMD@users.noreply.github.com>
- Create FINAL_SUMMARY.md with complete project overview
- Document all requirements, implementations, and achievements
- List all features: direct integration, datasets, multi-turn, robustness
- Detail all 5 critical fixes applied during development
- Show implementation statistics and code metrics
- Document design evolution through 5 phases
- Provide documentation structure and usage quick start
- Highlight what makes this implementation special
- List future enhancement possibilities
- Complete project ready for production use

Co-authored-by: Xiaoming-AMD <198007710+Xiaoming-AMD@users.noreply.github.com>
…or issue

- Create new sft_forward_step.py module with create_sft_forward_step() factory function
- Extract forward_step implementation from megatron_sft_trainer.py for better modularity
- Add None check for data_iterator to handle evaluation without validation dataset
- Add comprehensive documentation and type hints
- Improve code organization and maintainability
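A rough sketch of the factory pattern this refactor introduces; the real `create_sft_forward_step()` also wires in the SFT loss masking and the None-iterator handling mentioned above, which are only hinted at here:

```python
from functools import partial


def create_sft_forward_step(loss_func):
    """Return a forward_step callable that Megatron's training loop can drive."""

    def forward_step(data_iterator, model):
        # The PR adds a None check here for evaluation without a validation
        # dataset; the exact fallback behaviour is not reproduced in this sketch.
        assert data_iterator is not None, "no data iterator provided"
        batch = next(data_iterator)
        output_tensor = model(batch["tokens"], batch["position_ids"], batch["attention_mask"])
        # Megatron convention: return the output tensor plus a loss callable
        # that is applied after pipeline communication.
        return output_tensor, partial(loss_func, batch["labels"], batch["loss_mask"])

    return forward_step
```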
Fix AssertionError when loading HuggingFace checkpoints converted to Megatron format.

Problem:
- Megatron-Bridge's AutoBridge.import_ckpt() creates iteration 0 in metadata
- Megatron-LM requires iteration > 0 OR 'release' string
- This caused: AssertionError: error parsing metadata file

Solution:
- Automatically detect and fix metadata files with iteration 0
- Replace '0' with 'release' to indicate pretrained checkpoint
- Fix applied in checkpoint conversion hook

Changes:
- Added automatic metadata fix in 01_convert_checkpoints.py
- Creates docs/CHECKPOINT_METADATA_FIX.md explaining the issue
- Manual fix already applied to existing checkpoint

This allows SFT training to proceed without assertion errors.
Extend previous fix to handle both metadata AND directory structure.

Problem:
- Previous fix changed metadata from '0' to 'release' ✓
- But checkpoint was still in 'iter_0000000/' directory ✗
- Megatron expects 'release/' directory when metadata says 'release'
- This caused: FileNotFoundError: .../release

Root Cause:
Megatron's get_checkpoint_name() function:
  if release:
      directory = 'release'  # <-- Must match metadata
  else:
      directory = 'iter_{:07d}'.format(iteration)

Solution:
1. Update metadata: '0' -> 'release' (already done)
2. Rename directory: 'iter_0000000/' -> 'release/' (NEW)

Changes:
- Enhanced 01_convert_checkpoints.py to rename directory
- Updated CHECKPOINT_METADATA_FIX.md with complete explanation
- Manual fix applied to existing checkpoint

After Fix:
checkpoint/
├── latest_checkpointed_iteration.txt  (contains 'release')
└── release/                           (renamed from iter_0000000/)
    └── ... (checkpoint files)

This ensures metadata and directory structure are consistent.
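Both steps of the fix are small file-system operations; a hedged sketch (the function name is illustrative; the actual logic lives in 01_convert_checkpoints.py):

```python
from pathlib import Path


def normalize_release_layout(checkpoint_dir: str) -> None:
    """Make metadata and directory layout agree for a converted HF checkpoint."""
    root = Path(checkpoint_dir)
    tracker = root / "latest_checkpointed_iteration.txt"

    # Step 1: iteration 0 is invalid for Megatron-LM; mark as a release checkpoint.
    if tracker.exists() and tracker.read_text().strip() == "0":
        tracker.write_text("release")

    # Step 2: the weights must live under release/ when the tracker says 'release'.
    iter_dir, release_dir = root / "iter_0000000", root / "release"
    if iter_dir.is_dir() and not release_dir.exists():
        iter_dir.rename(release_dir)
```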
Xiaoming and others added 6 commits February 2, 2026 03:15
- Add 'args' to common.pt for Megatron-LM checkpoint loading
- Patch load_checkpoint to handle missing 'model' key in torch_dist format
- HuggingFace converted checkpoints use TP=1, PP=1 defaults
- Port PEFT module from Megatron-Bridge for LoRA support
- Add LoRA configuration options to sft_trainer.yaml
- Integrate LoRA into MegatronSFTTrainer with model provider wrapper
- Add detailed logging for trainable/frozen parameter statistics
- Support configurable: rank, alpha, dropout, target_modules, init methods

New files:
- primus/backends/megatron/peft/ - Complete PEFT module
- examples/megatron/configs/MI355X/llama3_8B-BF16-lora-sft.yaml

Usage: Set lora_enabled: true in config to enable LoRA fine-tuning
- Move forward_step.py to sft/ directory
- Add sft/__init__.py
- Add gpt_sft_chat_dataset.py to sft/
- Simplify docstrings in megatron_sft_trainer.py
Preserve the main-branch runtime/backend adapter flow while keeping stage-aware trainer registration for the current backend integrations.
Modularize the native Megatron SFT flow into dedicated dataset, formatting,
runtime, and forward-step components while keeping the legacy dataset import
path compatible.
Copilot AI review requested due to automatic review settings April 29, 2026 06:47
return (loss, num_tokens, {"lm loss": reporting_loss})

if getattr(args, "use_legacy_models", False):
    output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
Comment on lines +251 to +259
# if tensor_parallel_size > 1:
# kwargs["tensor_parallel_group"] = parallel_state.get_tensor_model_parallel_group()
# if isinstance(self.to_wrap, (te.Linear, te.LayerNormLinear)):
# kwargs["tensor_parallel_mode"] = self.to_wrap.parallel_mode
# kwargs["sequence_parallel"] = self.to_wrap.sequence_parallel
# if kwargs["tensor_parallel_mode"] == "row":
# kwargs["in_features"] *= tensor_parallel_size
# elif kwargs["tensor_parallel_mode"] == "column":
# kwargs["out_features"] *= tensor_parallel_size
seq_lengths = []
loss_ratios = []
context_lengths = []
answer_lengths = []
from primus.backends.megatron.peft.walk_utils import walk


logger: logging.Logger = logging.getLogger(__name__)
class HasBool(Protocol):
    """Protocol for objects that can be evaluated as boolean."""

    def __bool__(self) -> bool: ...
Contributor

Copilot AI left a comment

Pull request overview

Adds Megatron-native SFT (supervised fine-tuning) support to Primus, including stage-based trainer registration, dataset/formatter abstractions (offline + multi-turn), and runner hooks for HF→Megatron checkpoint conversion.

Changes:

  • Introduce a stage-based trainer class registry and update backend adapters to load trainers by (backend, stage).
  • Add Megatron-native SFT stack (trainer + runtime wiring + dataset/formatters/preprocessing/forward_step), plus PEFT (LoRA) utilities.
  • Add runner hooks, example configs/scripts, and tests/docs for offline JSON/JSONL and OpenAI-messages multi-turn SFT.

Reviewed changes

Copilot reviewed 60 out of 60 changed files in this pull request and generated 7 comments.

File Description
tests/unit_tests/backends/megatron/test_sft_dataset_offline.py Unit tests for offline JSON/JSONL dataset loading
tests/unit_tests/backends/megatron/test_sft_abstractions.py Tests for SFT schema/formatting/tokenization + forward_step behaviors
tests/unit_tests/backends/megatron/test_messages_format.py Tests for OpenAI messages format + dataset integration
tests/unit_tests/backends/megatron/test_megatron_sft_trainer.py Tests for MegatronSFTTrainer wiring/runtime behavior
tests/unit_tests/backends/megatron/test_megatron_registration.py Update registration test for stage-based trainers
tests/unit_tests/backends/megatron/test_megatron_adapter.py Adapter tests updated to support stage-based trainer loading
runner/helpers/hooks/train/posttrain/megatron/01_convert_checkpoints.py HF→Megatron conversion hook + checkpoint layout normalization
runner/helpers/hooks/train/posttrain/megatron/00_install_requirements.sh Install deps for conversion hook
run.sh Convenience invocation snippet for SFT run
primus/core/backend/backend_registry.py Add (backend, stage) trainer registry APIs
primus/core/backend/backend_adapter.py Doc update for stage-based trainer lookup
primus/configs/modules/megatron/sft_trainer.yaml Base Megatron SFT module config
primus/configs/models/megatron/llama3_8B.yaml Switch tokenizer type for Llama3 Megatron config
primus/backends/torchtitan/torchtitan_adapter.py Use BackendRegistry trainer lookup by stage
primus/backends/torchtitan/__init__.py Register torchtitan pretrain trainer class
primus/backends/megatron_bridge/megatron_bridge_adapter.py Use BackendRegistry trainer lookup by stage
primus/backends/megatron_bridge/__init__.py Register megatron_bridge SFT trainer class
primus/backends/megatron/sft/schema.py Normalized SFT sample/message schema
primus/backends/megatron/sft/runtime.py Datasets provider + Megatron pretrain entrypoint compatibility
primus/backends/megatron/sft/preprocessing.py Local record loading + tokenization + loss mask construction
primus/backends/megatron/sft/forward_step.py SFT forward_step with masked loss + legacy/new Megatron support
primus/backends/megatron/sft/formatters.py Alpaca/ChatML/OpenAI-messages formatters
primus/backends/megatron/sft/dataset.py SFTDataset supporting HF hub + local files + formatter-driven masks
primus/backends/megatron/sft/__init__.py Public exports for Megatron SFT package
primus/backends/megatron/peft/walk_utils.py Module walking utilities (ported) for PEFT transforms
primus/backends/megatron/peft/recompute.py Recompute-related patching for adapter-only training
primus/backends/megatron/peft/module_matcher.py Module matching for LoRA targeting
primus/backends/megatron/peft/lora.py LoRA implementation/patching
primus/backends/megatron/peft/import_utils.py Safe import helpers for optional deps
primus/backends/megatron/peft/base.py PEFT base class and application mechanics
primus/backends/megatron/peft/adapter_wrapper.py Adapter wrapper helpers for state dict behaviors
primus/backends/megatron/peft/__init__.py PEFT package exports
primus/backends/megatron/patches/checkpoint_patches.py Patch for distributed checkpoint loading edge-case
primus/backends/megatron/megatron_sft_trainer.py New MegatronSFTTrainer implementation
primus/backends/megatron/megatron_adapter.py Load trainer via stage-based BackendRegistry lookup
primus/backends/megatron/core/datasets/sft_dataset.py Compatibility shim re-exporting new SFT APIs
primus/backends/megatron/__init__.py Register megatron pretrain + sft trainers
primus/backends/megatron/README_SFT.md Megatron SFT user documentation
examples/megatron/convert_to_jsonl.py Utility to convert datasets/CSV to JSONL
examples/megatron/configs/MI355X/llama3_8B-BF16-sft.yaml Example single-turn SFT experiment config
examples/megatron/configs/MI355X/llama3_8B-BF16-multiturn-sft.yaml Example multi-turn (OpenAI messages) SFT config
examples/megatron/configs/MI355X/llama3_8B-BF16-lora-sft.yaml Example LoRA SFT config
docs/OFFLINE_DATASET_GUIDE.md Offline dataset usage guide
docs/MULTI_TURN_CONVERSATIONS.md Multi-turn messages format usage guide
docs/MEGATRON_BRIDGE_LOSS_PORT.md Notes on loss function porting alignment
docs/CHECKPOINT_METADATA_FIX.md Notes on checkpoint metadata/layout fixes
TOKENIZER_FIX.md Tokenizer interface fix notes
REBASE_SUMMARY.md Summary of stage-based registry rebase
POSITION_IDS_FIX.md Position IDs fix notes
PARALLEL_STATE_FIX.md Parallel state API fix notes
OFFLINE_DATASET_IMPLEMENTATION.md Offline dataset implementation notes
MULTI_TURN_IMPLEMENTATION.md Multi-turn implementation notes
LOSS_COMPUTATION_FIX.md Loss computation fix notes
FORWARD_STEP_ENHANCEMENTS.md Forward-step enhancements notes
FORWARD_STEP_DESIGN_DECISION.md Design decision notes for forward_step implementation
FINAL_SUMMARY.md Overall implementation summary

Comment on lines +79 to +86
raw_messages = sample.get("messages")
if isinstance(raw_messages, Sequence) and not isinstance(raw_messages, (str, bytes)):
    messages = tuple(
        SFTMessage.from_mapping(message)
        for message in raw_messages
        if isinstance(message, Mapping)
    )
    return cls(messages=messages)
Comment on lines +123 to +135
loss_mask = np.zeros(len(token_ids), dtype=np.int64)
prefix_text = ""
prefix_token_count = 0
for segment in formatted_sample.segments:
    start = prefix_token_count
    prefix_text += segment.text
    prefix_token_count = len(tokenize_text(tokenizer, prefix_text))
    end = prefix_token_count

    if segment.supervise and start < len(token_ids):
        loss_mask[start:min(end, len(token_ids))] = 1
    if start >= len(token_ids):
        break
Comment on lines +40 to +83
auto_detect_ckpt_format: true
ckpt_format: torch_dist
fully_parallel_save: true

# Learning rate (lower than pretraining for fine-tuning)
lr: 1.0e-5
min_lr: 0.0
lr_decay_iters: null
lr_warmup_iters: 50
lr_decay_style: cosine

# Optimizer settings
weight_decay: 0.1
adam_beta1: 0.9
adam_beta2: 0.95
override_opt_param_scheduler: true

# Attention and distributed settings
use_flash_attn: true
distributed_timeout_minutes: 60
use_distributed_optimizer: true

# SFT-specific settings
# Note: These are typically overridden in the experiment config
sft_dataset_name: "tatsu-lab/alpaca"
sft_conversation_format: "alpaca"

# LoRA (Low-Rank Adaptation) settings
lora:
  enabled: false              # Set true for parameter-efficient fine-tuning
  dim: 32                     # Low-rank dimension (8-64 typical)
  alpha: 32                   # Scaling factor (usually = dim)
  dropout: 0.0                # Dropout rate (0.0-0.1)
  dropout_position: pre       # "pre" or "post"
  lora_A_init_method: xavier  # "xavier" or "kaiming"
  lora_B_init_method: zero    # "zero" (standard)
  target_modules:             # Modules to apply LoRA
    - linear_qkv
    - linear_proj
    - linear_fc1
    - linear_fc2

ckpt_format: torch_dist
fully_parallel_save: true
Comment on lines +145 to +149
AutoBridge.import_ckpt(
    hf_model_id=hf_path,
    megatron_path=megatron_path,
    trust_remote_code=True,
)
Comment on lines +154 to +163
def wait_for_conversion(done_file: Path, lock_file: Path, timeout: int = 600):
    """Wait for rank 0 to complete checkpoint conversion."""
    elapsed = 0
    while not done_file.exists() and elapsed < timeout:
        if not lock_file.exists() and not done_file.exists():
            time.sleep(2)
        else:
            time.sleep(5)
        elapsed += 5

Comment on lines +64 to +76
"""Applies a function to a PyTorch module or a collection of modules.

This function can be used to modify modules in place, such as changing their attributes,
applying normalization, or any other custom transformations. It supports individual modules,
lists of modules, and dictionaries of modules. The function can be applied selectively to
modules that do not have parameters if `leaf_only` is set to True.

Args:
    module: The module or collection of modules to which the function will be applied.
    func: A callable that takes a module (and optionally additional keyword arguments) and
        returns a transformed module. The signature should be `func(module, **kwargs)`.
    leaf_only: If True, the function will only be applied to modules that
        do not have any parameters. Defaults to False.
Comment thread run.sh
Comment on lines +1 to +4



PRIMUS_TRAIN_RUNTIME=core ./primus-cli direct -- train posttrain --config ./examples/megatron/configs/MI355X/llama3_8B-BF16-sft.yaml
(No newline at end of file)