Describe the bug
During mixed-precision training (BF16 and FP32, with FP8 optional), a RuntimeError (size mismatch) is raised when saving a checkpoint via get_parameter_state_dp_zero.
Based on our analysis, the hardcoded parameter reordering in _build_model_and_main_param_groups within DistributedOptimizer causes the self.model_param_group_index_map to become out of sync with the actual optimizer.param_groups.
Steps/Code to reproduce bug
- Before building the DDP model and optimizer, use Float16Module to wrap the model for BF16 training, but manually promote certain modules (both params and inputs) to FP32.
- Train for several steps.
- Call get_parameter_state_dp_zero in DistributedOptimizer to collect optimizer states, which triggers the size mismatch error.
Root Cause Analysis
- Initial Map Construction: In __init__, self.model_param_group_index_map is first constructed via _build_optimizer_group_ranges. This map records the position (group_index, group_order) of each param in param_groups.
- Hardcoded Reordering: Subsequently, _build_model_and_main_param_groups reorders the parameters within each group (placing native FP32 shards at the front and main parameter shards converted from FP16/BF16 at the back) and updates optimizer.param_groups accordingly.
- Index Invalidation: The model_param_group_index_map is not updated after this reordering. Consequently, downstream functions like _get_main_param_and_optimizer_states retrieve the wrong Tensors using stale group_order indices, leading to shape mismatches during buffer copy operations.
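The failure mode can be illustrated with a toy model in plain Python. This is only a sketch of our understanding, not the actual DistributedOptimizer code: the Param class and the helpers build_index_map and reorder_fp32_first are hypothetical stand-ins for the map construction and the reordering described above.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Param:
    """Hypothetical stand-in for a parameter shard."""
    name: str
    dtype: str  # "fp32" or "bf16"
    numel: int


def build_index_map(param_groups):
    """Record (group_index, group_order) per param, in discovery order."""
    return {
        p.name: (group_index, group_order)
        for group_index, group in enumerate(param_groups)
        for group_order, p in enumerate(group)
    }


def reorder_fp32_first(param_groups):
    """Mimic the reordering: native FP32 params first, BF16-derived ones last."""
    return [
        [p for p in group if p.dtype == "fp32"]
        + [p for p in group if p.dtype != "fp32"]
        for group in param_groups
    ]


# Discovery order: a BF16 weight followed by an FP32-promoted norm weight.
groups = [[Param("linear.weight", "bf16", 4096), Param("norm.weight", "fp32", 64)]]

index_map = build_index_map(groups)      # built before the reordering
reordered = reorder_fp32_first(groups)   # what the optimizer ends up holding

# Look up linear.weight with the stale (group_index, group_order).
group_index, group_order = index_map["linear.weight"]
fetched = reordered[group_index][group_order]
print(fetched.name, fetched.numel)       # norm.weight 64
```

The lookup for linear.weight returns norm.weight, whose element count differs, which corresponds to the size mismatch we see during the buffer copy.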
Additional questions
- What is the design motivation behind this specific reordering (grouping by DType)?
- What is the recommended best practice to fix this: disabling the reordering to maintain discovery order consistency, or explicitly updating the model_param_group_index_map after the reordering is performed?
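To make the second option concrete, here is a minimal sketch of rebuilding the index map from the reordered groups. It assumes the map can simply be recomputed from the final group contents; the helpers reorder_fp32_first and rebuild_index_map and the (name, dtype) tuples are hypothetical and do not correspond to existing Megatron functions.

```python
def reorder_fp32_first(group):
    """Mimic the reordering: FP32 entries first, all other dtypes after."""
    fp32 = [entry for entry in group if entry[1] == "fp32"]
    other = [entry for entry in group if entry[1] != "fp32"]
    return fp32 + other


def rebuild_index_map(param_groups):
    """Recompute (group_index, group_order) from the groups as they are now."""
    return {
        name: (group_index, group_order)
        for group_index, group in enumerate(param_groups)
        for group_order, (name, _dtype) in enumerate(group)
    }


# Each entry is a hypothetical (name, dtype) pair in discovery order.
param_groups = [[
    ("linear.weight", "bf16"),
    ("norm.weight", "fp32"),
    ("linear.bias", "bf16"),
]]

# Reorder first, then build the map, so both views agree.
param_groups = [reorder_fp32_first(group) for group in param_groups]
index_map = rebuild_index_map(param_groups)

# Every stored position now points back at the param it was recorded for.
for name, (group_index, group_order) in index_map.items():
    assert param_groups[group_index][group_order][0] == name
```

The first option (keeping discovery order) would make the rebuild unnecessary, so which one is preferable depends on whether the dtype grouping is relied on elsewhere.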