Feat/megatron/support sft native #701
Draft
wenxie-amd wants to merge 36 commits into main from
Conversation
…nd field extraction

Co-authored-by: Xiaoming-AMD <198007710+Xiaoming-AMD@users.noreply.github.com>
…r registration

Co-authored-by: Xiaoming-AMD <198007710+Xiaoming-AMD@users.noreply.github.com>
- Add load_jsonl_file() function to load data from JSONL files
- Update SFTDataset to detect and load local files (JSONL/JSON)
- Support both .jsonl and .json file formats
- Maintain backward compatibility with HuggingFace datasets
- Update documentation with offline dataset examples
- Add comprehensive unit tests for offline loading

Co-authored-by: Xiaoming-AMD <198007710+Xiaoming-AMD@users.noreply.github.com>
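For context, each line of such a JSONL file is a single JSON object. A minimal single-turn record might look like this (field names follow the Alpaca convention referenced in the example configs; the exact schema depends on the formatter in use):

```json
{"instruction": "Summarize the text below.", "input": "Primus adds native Megatron SFT support.", "output": "Primus now supports supervised fine-tuning directly on Megatron."}
```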
- Add comprehensive offline dataset guide
- Add script to convert HuggingFace datasets to JSONL
- Add CSV to JSONL conversion support
- Provide examples and troubleshooting tips

Co-authored-by: Xiaoming-AMD <198007710+Xiaoming-AMD@users.noreply.github.com>
- Add OpenAIMessagesFormatter class for multi-turn conversations
- Support messages field with role-content pairs (system/user/assistant)
- Implement specialized loss masking for multi-turn (only on assistant)
- Update SFTDataset to detect and handle messages format
- Add formatter options: "openai" and "messages"
- Create comprehensive tests for messages format
- Update documentation with examples and usage

Co-authored-by: Xiaoming-AMD <198007710+Xiaoming-AMD@users.noreply.github.com>
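A multi-turn record in this format carries a `messages` list of role/content pairs; a minimal example (content values are illustrative):

```json
{"messages": [
  {"role": "system", "content": "You are a helpful assistant."},
  {"role": "user", "content": "What is LoRA?"},
  {"role": "assistant", "content": "LoRA adds small low-rank adapter matrices to frozen weights."}
]}
```

With the specialized loss masking described above, only the assistant turns contribute to the training loss.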
…ions
- Create MULTI_TURN_CONVERSATIONS.md guide with examples and best practices
- Add example multi-turn conversation data file
- Create example config for multi-turn training
- Include conversion scripts and troubleshooting tips
- Document loss masking behavior for multi-turn dialogues

Co-authored-by: Xiaoming-AMD <198007710+Xiaoming-AMD@users.noreply.github.com>
- Add _tokenize_text() helper method to handle different tokenizer interfaces
- Support tokenizers that return IDs directly from tokenize() (Megatron style)
- Support standard tokenize() + convert_tokens_to_ids() pattern
- Support encode() method as fallback
- Update _tokenize_and_mask() and _tokenize_and_mask_messages() to use new helper
- Fixes AttributeError: '_HuggingFaceTokenizer' object has no attribute 'convert_tokens_to_ids'

Co-authored-by: Xiaoming-AMD <198007710+Xiaoming-AMD@users.noreply.github.com>
- Create TOKENIZER_FIX.md with detailed explanation
- Update IMPLEMENTATION_SUMMARY.md with fix information
- Document the problem, solution, and implementation details
- Explain backward compatibility and testing

Co-authored-by: Xiaoming-AMD <198007710+Xiaoming-AMD@users.noreply.github.com>
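A minimal sketch of the fallback chain these two commits describe (the helper name comes from the commit message; the branch ordering is an assumption):

```python
def _tokenize_text(tokenizer, text: str) -> list[int]:
    """Return token IDs across the tokenizer interfaces named above (sketch)."""
    if hasattr(tokenizer, "tokenize"):
        tokens = tokenizer.tokenize(text)
        # Megatron-style tokenizers return IDs directly from tokenize().
        if tokens and isinstance(tokens[0], int):
            return tokens
        # HF-style tokenizers return string tokens; convert them separately.
        if hasattr(tokenizer, "convert_tokens_to_ids"):
            return tokenizer.convert_tokens_to_ids(tokens)
    # Last resort: encode() is available on most tokenizer implementations.
    return tokenizer.encode(text)
```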
- Add position_ids generation before model forward call
- Create position_ids tensor with shape [batch_size, seq_len]
- Use torch.arange to create standard position encoding [0, 1, 2, ..., seq_len-1]
- Ensure position_ids is on same device as tokens
- Pass position_ids as positional argument to model
- Fixes TypeError: GPTModel.forward() missing 1 required positional argument: 'position_ids'

Co-authored-by: Xiaoming-AMD <198007710+Xiaoming-AMD@users.noreply.github.com>
- Create POSITION_IDS_FIX.md with detailed explanation
- Update IMPLEMENTATION_SUMMARY.md with latest fix
- Document the problem, solution, and implementation details
- Explain position_ids structure and why it's needed

Co-authored-by: Xiaoming-AMD <198007710+Xiaoming-AMD@users.noreply.github.com>
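The shape logic is simple enough to show inline; a minimal sketch, assuming `tokens` is a `[batch_size, seq_len]` tensor:

```python
import torch

def build_position_ids(tokens: torch.Tensor) -> torch.Tensor:
    """Standard positions [0, 1, ..., seq_len-1], broadcast across the batch."""
    batch_size, seq_len = tokens.shape
    positions = torch.arange(seq_len, dtype=torch.long, device=tokens.device)
    return positions.unsqueeze(0).expand(batch_size, seq_len)

# Passed positionally, per the fix: model(tokens, position_ids, attention_mask)
```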
…call
- Remove labels parameter from model forward call to get logits instead of loss
- Model now returns logits tensor which loss_func can process correctly
- Custom loss_func applies proper masking to compute SFT loss
- Fixes type mismatch error where loss_func expected logits but got scalar loss
- Aligns with standard Megatron SFT pattern of separating forward pass and loss computation

Co-authored-by: Xiaoming-AMD <198007710+Xiaoming-AMD@users.noreply.github.com>
- Create LOSS_COMPUTATION_FIX.md with detailed explanation
- Update IMPLEMENTATION_SUMMARY.md with latest fix
- Document the problem, solution, and benefits
- Explain why removing labels parameter is correct

Co-authored-by: Xiaoming-AMD <198007710+Xiaoming-AMD@users.noreply.github.com>
…l for data parallel ops
- Change import from tensor_parallel to parallel_state at line 159
- Update get_data_parallel_world_size() call to use parallel_state
- Update get_data_parallel_group() call to use parallel_state
- Remove unused tensor_parallel import from loss_func
- Fixes RuntimeError: module 'megatron.core.tensor_parallel' has no attribute 'get_data_parallel_world_size'

Co-authored-by: Xiaoming-AMD <198007710+Xiaoming-AMD@users.noreply.github.com>
- Create PARALLEL_STATE_FIX.md with detailed explanation
- Update IMPLEMENTATION_SUMMARY.md with latest fix
- Document correct Megatron-LM API usage
- Explain tensor_parallel vs parallel_state modules

Co-authored-by: Xiaoming-AMD <198007710+Xiaoming-AMD@users.noreply.github.com>
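The corrected imports, for reference (calling these requires an initialized Megatron-LM process-group setup):

```python
# tensor_parallel does not expose data-parallel helpers; parallel_state does.
from megatron.core import parallel_state

dp_world_size = parallel_state.get_data_parallel_world_size()
dp_group = parallel_state.get_data_parallel_group()
```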
- Create FORWARD_STEP_DESIGN_DECISION.md with comprehensive explanation
- Explain why we don't use megatron.bridge.training.vlm_step.forward_step
- Document original requirement: no Megatron-Bridge dependency
- List benefits of custom implementation vs external dependency
- Compare SFT-specific needs vs general VLM forward_step
- Provide guidance on when to use each approach
- Update README_SFT.md to reference design decision
- Update IMPLEMENTATION_SUMMARY.md with design note

Co-authored-by: Xiaoming-AMD <198007710+Xiaoming-AMD@users.noreply.github.com>
- Port best practices from Megatron-Bridge forward_step patterns
- Add comprehensive error handling for data iteration and batch extraction
- Add proper shape validation and assertion checks
- Track num_tokens for accurate logging
- Add causal attention mask comment (handled internally by model)
- Fix data parallel loss averaging (divide by world_size after all_reduce)
- Add ignore_index=-100 to cross_entropy for standard padding handling
- Improve docstrings with detailed parameter and return value descriptions
- Add comments explaining each step of the process
- Maintain SFT-specific loss masking on response tokens only

Co-authored-by: Xiaoming-AMD <198007710+Xiaoming-AMD@users.noreply.github.com>
- Create FORWARD_STEP_ENHANCEMENTS.md with comprehensive documentation
- Explain what was ported from Megatron-Bridge patterns
- Document all enhancements: error handling, validation, loss averaging fix
- Highlight critical data parallel loss averaging bug fix
- Show before/after code comparisons
- Document benefits of each enhancement
- Update IMPLEMENTATION_SUMMARY.md to reference enhancements document
- Provide testing guidance and future enhancement suggestions

Co-authored-by: Xiaoming-AMD <198007710+Xiaoming-AMD@users.noreply.github.com>
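A minimal sketch of the masked-loss pattern these commits describe (function name and return shape are illustrative; the real loss_func follows Megatron's contract):

```python
import torch
import torch.nn.functional as F

def masked_sft_loss(logits: torch.Tensor, labels: torch.Tensor, loss_mask: torch.Tensor):
    """Cross-entropy over response tokens only.

    logits: [batch, seq, vocab]; labels and loss_mask: [batch, seq].
    Only positions with loss_mask == 1 contribute to the loss.
    """
    per_token = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        labels.reshape(-1),
        reduction="none",
        ignore_index=-100,  # standard padding label, as noted above
    )
    mask = loss_mask.reshape(-1).float()
    num_tokens = mask.sum().clamp(min=1.0)
    loss = (per_token * mask).sum() / num_tokens
    # For reporting under data parallelism, the loss is then all-reduced and
    # divided by the DP world size (the averaging fix called out above).
    return loss, num_tokens
```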
- Create FINAL_SUMMARY.md with complete project overview
- Document all requirements, implementations, and achievements
- List all features: direct integration, datasets, multi-turn, robustness
- Detail all 5 critical fixes applied during development
- Show implementation statistics and code metrics
- Document design evolution through 5 phases
- Provide documentation structure and usage quick start
- Highlight what makes this implementation special
- List future enhancement possibilities
- Complete project ready for production use

Co-authored-by: Xiaoming-AMD <198007710+Xiaoming-AMD@users.noreply.github.com>
…or issue
- Create new sft_forward_step.py module with create_sft_forward_step() factory function
- Extract forward_step implementation from megatron_sft_trainer.py for better modularity
- Add None check for data_iterator to handle evaluation without validation dataset
- Add comprehensive documentation and type hints
- Improve code organization and maintainability
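The factory shape implied by the commit, as a hedged sketch (the batch keys and loss binding are assumptions; the real module also validates shapes and builds position_ids):

```python
def create_sft_forward_step(loss_func):
    """Bind SFT specifics once; return the forward_step Megatron's loop expects."""
    def forward_step(data_iterator, model):
        # Handle evaluation runs that have no validation dataset.
        if data_iterator is None:
            return None, None
        batch = next(data_iterator)
        output_tensor = model(
            batch["tokens"], batch["position_ids"], batch["attention_mask"]
        )
        # Megatron consumes (output_tensor, fn) and later calls fn(output_tensor).
        return output_tensor, lambda output: loss_func(
            output, batch["labels"], batch["loss_mask"]
        )
    return forward_step
```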
Fix AssertionError when loading HuggingFace checkpoints converted to Megatron format.

Problem:
- Megatron-Bridge's AutoBridge.import_ckpt() creates iteration 0 in metadata
- Megatron-LM requires iteration > 0 OR 'release' string
- This caused: AssertionError: error parsing metadata file

Solution:
- Automatically detect and fix metadata files with iteration 0
- Replace '0' with 'release' to indicate pretrained checkpoint
- Fix applied in checkpoint conversion hook

Changes:
- Added automatic metadata fix in 01_convert_checkpoints.py
- Create docs/CHECKPOINT_METADATA_FIX.md explaining the issue
- Manual fix already applied to existing checkpoint

This allows SFT training to proceed without assertion errors.
Extend previous fix to handle both metadata AND directory structure.
Problem:
- Previous fix changed metadata from '0' to 'release' ✓
- But checkpoint was still in 'iter_0000000/' directory ✗
- Megatron expects 'release/' directory when metadata says 'release'
- This caused: FileNotFoundError: .../release
Root Cause:
Megatron's get_checkpoint_name() function:

    if release:
        directory = 'release'  # <-- Must match metadata
    else:
        directory = 'iter_{:07d}'.format(iteration)
Solution:
1. Update metadata: '0' -> 'release' (already done)
2. Rename directory: 'iter_0000000/' -> 'release/' (NEW)
Changes:
- Enhanced 01_convert_checkpoints.py to rename directory
- Updated CHECKPOINT_METADATA_FIX.md with complete explanation
- Manual fix applied to existing checkpoint
After Fix:

    checkpoint/
    ├── latest_checkpointed_iteration.txt (contains 'release')
    └── release/ (renamed from iter_0000000/)
        └── ... (checkpoint files)
This ensures metadata and directory structure are consistent.
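Both steps together, as a sketch (the function name is hypothetical; the tracker filename matches the tree above):

```python
from pathlib import Path

def normalize_converted_checkpoint(ckpt_dir: str) -> None:
    """Make an imported HF checkpoint pass Megatron-LM's metadata checks (sketch)."""
    root = Path(ckpt_dir)
    tracker = root / "latest_checkpointed_iteration.txt"
    # Step 1: iteration 0 fails Megatron's metadata parsing; mark as a release.
    if tracker.read_text().strip() == "0":
        tracker.write_text("release")
    # Step 2: the weights directory must match what the tracker announces.
    iter_dir = root / "iter_0000000"
    if iter_dir.is_dir() and not (root / "release").exists():
        iter_dir.rename(root / "release")
```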
- Add 'args' to common.pt for Megatron-LM checkpoint loading
- Patch load_checkpoint to handle missing 'model' key in torch_dist format
- HuggingFace converted checkpoints use TP=1, PP=1 defaults
- Port PEFT module from Megatron-Bridge for LoRA support
- Add LoRA configuration options to sft_trainer.yaml
- Integrate LoRA into MegatronSFTTrainer with model provider wrapper
- Add detailed logging for trainable/frozen parameter statistics
- Support configurable: rank, alpha, dropout, target_modules, init methods

New files:
- primus/backends/megatron/peft/ - Complete PEFT module
- examples/megatron/configs/MI355X/llama3_8B-BF16-lora-sft.yaml

Usage: Set lora_enabled: true in config to enable LoRA fine-tuning
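To enable it, an experiment config overrides the lora block (values below are illustrative; the full set of keys appears in the sft_trainer.yaml excerpt quoted in the review further down):

```yaml
lora:
  enabled: true      # turn on parameter-efficient fine-tuning
  dim: 16            # low-rank dimension
  alpha: 16          # scaling factor, usually equal to dim
  target_modules:
    - linear_qkv
    - linear_proj
```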
- Move forward_step.py to sft/ directory
- Add sft/__init__.py
- Add gpt_sft_chat_dataset.py to sft/
- Simplify docstrings in megatron_sft_trainer.py
Preserve the main-branch runtime/backend adapter flow while keeping stage-aware trainer registration for the current backend integrations.
Modularize the native Megatron SFT flow into dedicated dataset, formatting, runtime, and forward-step components while keeping the legacy dataset import path compatible.
```python
return (loss, num_tokens, {"lm loss": reporting_loss})
```

```python
if getattr(args, "use_legacy_models", False):
    output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
```
Comment on lines +251 to +259
```python
# if tensor_parallel_size > 1:
#     kwargs["tensor_parallel_group"] = parallel_state.get_tensor_model_parallel_group()
# if isinstance(self.to_wrap, (te.Linear, te.LayerNormLinear)):
#     kwargs["tensor_parallel_mode"] = self.to_wrap.parallel_mode
#     kwargs["sequence_parallel"] = self.to_wrap.sequence_parallel
#     if kwargs["tensor_parallel_mode"] == "row":
#         kwargs["in_features"] *= tensor_parallel_size
#     elif kwargs["tensor_parallel_mode"] == "column":
#         kwargs["out_features"] *= tensor_parallel_size
```
```python
seq_lengths = []
loss_ratios = []
context_lengths = []
answer_lengths = []
```
```python
from primus.backends.megatron.peft.walk_utils import walk

logger: logging.Logger = logging.getLogger(__name__)
```
```python
class HasBool(Protocol):
    """Protocol for objects that can be evaluated as boolean."""

    def __bool__(self) -> bool: ...
```
Contributor
Pull request overview
Adds Megatron-native SFT (supervised fine-tuning) support to Primus, including stage-based trainer registration, dataset/formatter abstractions (offline + multi-turn), and runner hooks for HF→Megatron checkpoint conversion.
Changes:
- Introduce a stage-based trainer class registry and update backend adapters to load trainers by (backend, stage); a minimal sketch of the lookup pattern follows this list.
- Add Megatron-native SFT stack (trainer + runtime wiring + dataset/formatters/preprocessing/forward_step), plus PEFT (LoRA) utilities.
- Add runner hooks, example configs/scripts, and tests/docs for offline JSON/JSONL and OpenAI-messages multi-turn SFT.
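A minimal sketch of the (backend, stage) lookup pattern (names are hypothetical, not the actual BackendRegistry API):

```python
# Hypothetical shape of a stage-aware trainer registry.
_TRAINERS: dict[tuple[str, str], type] = {}

def register_trainer(backend: str, stage: str, trainer_cls: type) -> None:
    _TRAINERS[(backend, stage)] = trainer_cls

def get_trainer(backend: str, stage: str) -> type:
    try:
        return _TRAINERS[(backend, stage)]
    except KeyError:
        raise ValueError(f"No trainer registered for ({backend!r}, {stage!r})")
```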
Reviewed changes
Copilot reviewed 60 out of 60 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| tests/unit_tests/backends/megatron/test_sft_dataset_offline.py | Unit tests for offline JSON/JSONL dataset loading |
| tests/unit_tests/backends/megatron/test_sft_abstractions.py | Tests for SFT schema/formatting/tokenization + forward_step behaviors |
| tests/unit_tests/backends/megatron/test_messages_format.py | Tests for OpenAI messages format + dataset integration |
| tests/unit_tests/backends/megatron/test_megatron_sft_trainer.py | Tests for MegatronSFTTrainer wiring/runtime behavior |
| tests/unit_tests/backends/megatron/test_megatron_registration.py | Update registration test for stage-based trainers |
| tests/unit_tests/backends/megatron/test_megatron_adapter.py | Adapter tests updated to support stage-based trainer loading |
| runner/helpers/hooks/train/posttrain/megatron/01_convert_checkpoints.py | HF→Megatron conversion hook + checkpoint layout normalization |
| runner/helpers/hooks/train/posttrain/megatron/00_install_requirements.sh | Install deps for conversion hook |
| run.sh | Convenience invocation snippet for SFT run |
| primus/core/backend/backend_registry.py | Add (backend, stage) trainer registry APIs |
| primus/core/backend/backend_adapter.py | Doc update for stage-based trainer lookup |
| primus/configs/modules/megatron/sft_trainer.yaml | Base Megatron SFT module config |
| primus/configs/models/megatron/llama3_8B.yaml | Switch tokenizer type for Llama3 Megatron config |
| primus/backends/torchtitan/torchtitan_adapter.py | Use BackendRegistry trainer lookup by stage |
| primus/backends/torchtitan/__init__.py | Register torchtitan pretrain trainer class |
| primus/backends/megatron_bridge/megatron_bridge_adapter.py | Use BackendRegistry trainer lookup by stage |
| primus/backends/megatron_bridge/__init__.py | Register megatron_bridge SFT trainer class |
| primus/backends/megatron/sft/schema.py | Normalized SFT sample/message schema |
| primus/backends/megatron/sft/runtime.py | Datasets provider + Megatron pretrain entrypoint compatibility |
| primus/backends/megatron/sft/preprocessing.py | Local record loading + tokenization + loss mask construction |
| primus/backends/megatron/sft/forward_step.py | SFT forward_step with masked loss + legacy/new Megatron support |
| primus/backends/megatron/sft/formatters.py | Alpaca/ChatML/OpenAI-messages formatters |
| primus/backends/megatron/sft/dataset.py | SFTDataset supporting HF hub + local files + formatter-driven masks |
| primus/backends/megatron/sft/__init__.py | Public exports for Megatron SFT package |
| primus/backends/megatron/peft/walk_utils.py | Module walking utilities (ported) for PEFT transforms |
| primus/backends/megatron/peft/recompute.py | Recompute-related patching for adapter-only training |
| primus/backends/megatron/peft/module_matcher.py | Module matching for LoRA targeting |
| primus/backends/megatron/peft/lora.py | LoRA implementation/patching |
| primus/backends/megatron/peft/import_utils.py | Safe import helpers for optional deps |
| primus/backends/megatron/peft/base.py | PEFT base class and application mechanics |
| primus/backends/megatron/peft/adapter_wrapper.py | Adapter wrapper helpers for state dict behaviors |
| primus/backends/megatron/peft/__init__.py | PEFT package exports |
| primus/backends/megatron/patches/checkpoint_patches.py | Patch for distributed checkpoint loading edge-case |
| primus/backends/megatron/megatron_sft_trainer.py | New MegatronSFTTrainer implementation |
| primus/backends/megatron/megatron_adapter.py | Load trainer via stage-based BackendRegistry lookup |
| primus/backends/megatron/core/datasets/sft_dataset.py | Compatibility shim re-exporting new SFT APIs |
| primus/backends/megatron/__init__.py | Register megatron pretrain + sft trainers |
| primus/backends/megatron/README_SFT.md | Megatron SFT user documentation |
| examples/megatron/convert_to_jsonl.py | Utility to convert datasets/CSV to JSONL |
| examples/megatron/configs/MI355X/llama3_8B-BF16-sft.yaml | Example single-turn SFT experiment config |
| examples/megatron/configs/MI355X/llama3_8B-BF16-multiturn-sft.yaml | Example multi-turn (OpenAI messages) SFT config |
| examples/megatron/configs/MI355X/llama3_8B-BF16-lora-sft.yaml | Example LoRA SFT config |
| docs/OFFLINE_DATASET_GUIDE.md | Offline dataset usage guide |
| docs/MULTI_TURN_CONVERSATIONS.md | Multi-turn messages format usage guide |
| docs/MEGATRON_BRIDGE_LOSS_PORT.md | Notes on loss function porting alignment |
| docs/CHECKPOINT_METADATA_FIX.md | Notes on checkpoint metadata/layout fixes |
| TOKENIZER_FIX.md | Tokenizer interface fix notes |
| REBASE_SUMMARY.md | Summary of stage-based registry rebase |
| POSITION_IDS_FIX.md | Position IDs fix notes |
| PARALLEL_STATE_FIX.md | Parallel state API fix notes |
| OFFLINE_DATASET_IMPLEMENTATION.md | Offline dataset implementation notes |
| MULTI_TURN_IMPLEMENTATION.md | Multi-turn implementation notes |
| LOSS_COMPUTATION_FIX.md | Loss computation fix notes |
| FORWARD_STEP_ENHANCEMENTS.md | Forward-step enhancements notes |
| FORWARD_STEP_DESIGN_DECISION.md | Design decision notes for forward_step implementation |
| FINAL_SUMMARY.md | Overall implementation summary |
Comment on lines +79 to +86
```python
raw_messages = sample.get("messages")
if isinstance(raw_messages, Sequence) and not isinstance(raw_messages, (str, bytes)):
    messages = tuple(
        SFTMessage.from_mapping(message)
        for message in raw_messages
        if isinstance(message, Mapping)
    )
    return cls(messages=messages)
```
Comment on lines +123 to +135
```python
loss_mask = np.zeros(len(token_ids), dtype=np.int64)
prefix_text = ""
prefix_token_count = 0
for segment in formatted_sample.segments:
    start = prefix_token_count
    prefix_text += segment.text
    prefix_token_count = len(tokenize_text(tokenizer, prefix_text))
    end = prefix_token_count

    if segment.supervise and start < len(token_ids):
        loss_mask[start:min(end, len(token_ids))] = 1
    if start >= len(token_ids):
        break
```
Comment on lines +40 to +83
```yaml
auto_detect_ckpt_format: true
ckpt_format: torch_dist
fully_parallel_save: true

# Learning rate (lower than pretraining for fine-tuning)
lr: 1.0e-5
min_lr: 0.0
lr_decay_iters: null
lr_warmup_iters: 50
lr_decay_style: cosine

# Optimizer settings
weight_decay: 0.1
adam_beta1: 0.9
adam_beta2: 0.95
override_opt_param_scheduler: true

# Attention and distributed settings
use_flash_attn: true
distributed_timeout_minutes: 60
use_distributed_optimizer: true

# SFT-specific settings
# Note: These are typically overridden in the experiment config
sft_dataset_name: "tatsu-lab/alpaca"
sft_conversation_format: "alpaca"

# LoRA (Low-Rank Adaptation) settings
lora:
  enabled: false              # Set true for parameter-efficient fine-tuning
  dim: 32                     # Low-rank dimension (8-64 typical)
  alpha: 32                   # Scaling factor (usually = dim)
  dropout: 0.0                # Dropout rate (0.0-0.1)
  dropout_position: pre       # "pre" or "post"
  lora_A_init_method: xavier  # "xavier" or "kaiming"
  lora_B_init_method: zero    # "zero" (standard)
  target_modules:             # Modules to apply LoRA
    - linear_qkv
    - linear_proj
    - linear_fc1
    - linear_fc2

ckpt_format: torch_dist
fully_parallel_save: true
```
Comment on lines +145 to +149
```python
AutoBridge.import_ckpt(
    hf_model_id=hf_path,
    megatron_path=megatron_path,
    trust_remote_code=True,
)
```
Comment on lines +154 to +163
```python
def wait_for_conversion(done_file: Path, lock_file: Path, timeout: int = 600):
    """Wait for rank 0 to complete checkpoint conversion."""
    elapsed = 0
    while not done_file.exists() and elapsed < timeout:
        if not lock_file.exists() and not done_file.exists():
            time.sleep(2)
        else:
            time.sleep(5)
        elapsed += 5
```
Comment on lines +64 to +76
| """Applies a function to a PyTorch module or a collection of modules. | ||
|
|
||
| This function can be used to modify modules in place, such as changing their attributes, | ||
| applying normalization, or any other custom transformations. It supports individual modules, | ||
| lists of modules, and dictionaries of modules. The function can be applied selectively to | ||
| modules that do not have parameters if `leaf_only` is set to True. | ||
|
|
||
| Args: | ||
| module: The module or collection of modules to which the function will be applied. | ||
| func: A callable that takes a module (and optionally additional keyword arguments) and | ||
| returns a transformed module. The signature should be `func(module, **kwargs)`. | ||
| leaf_only: If True, the function will only be applied to modules that | ||
| do not have any parameters. Defaults to False. |
Comment on lines +1 to +4
```bash
PRIMUS_TRAIN_RUNTIME=core ./primus-cli direct -- train posttrain --config ./examples/megatron/configs/MI355X/llama3_8B-BF16-sft.yaml
```