
Fix sharding rank mismatch for LoRA weights and optimizer states in grpo_gemma.ipynb notebook#1180

Open
rajasekharporeddy wants to merge 1 commit into google:main from rajasekharporeddy:grpo_gemma

Conversation

@rajasekharporeddy
Collaborator

Issue:
Applying 3D PartitionSpec annotations (inherited from the base model's scanned layers) to the 2D LoRA weights and their optimizer states raises a ValueError during GRPO training, because the spec has more entries than the arrays have dimensions.

Fix:

  • Added a dimensionality check (`len(spec) <= x.ndim`) in get_lora_model that keeps a spec only when it fits the array's rank, falling back to None for the 2D LoRA weights so they are replicated instead of raising.
  • Applied the same check to the optimizer states by patching _shard_optimizer in the PeftTrainer.
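The check described above can be sketched as follows. This is a minimal illustration, not the actual patch: `safe_spec` is a hypothetical helper name, and plain tuples stand in for jax.sharding.PartitionSpec (which is tuple-like, so `len()` behaves the same way) to keep the snippet dependency-light.

```python
import numpy as np

def safe_spec(spec, x):
    # Hypothetical helper mirroring the PR's dimensionality check:
    # keep the spec only when it has no more entries than the array
    # has dimensions; otherwise fall back to None (replicated).
    if spec is not None and len(spec) > x.ndim:
        return None
    return spec

lora_a = np.zeros((128, 8))          # 2D LoRA weight
layer_spec = (None, "fsdp", "tp")    # 3D spec inherited from scanned layers

print(safe_spec(layer_spec, lora_a))                 # -> None (replicated)
print(safe_spec(("fsdp", "tp"), lora_a))             # -> ('fsdp', 'tp')
```

Falling back to None replicates the LoRA weights across devices, which is cheap since the adapters are small relative to the base model, and it avoids the rank-mismatch ValueError when the sharding annotations are applied.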

Checklist

  • I have verified that my change does not break existing code and all unit tests pass.
  • My PR is based on the latest changes of the main branch (if unsure, rebase the code).
  • I have signed the Contributor License Agreement.
  • I have followed Contribution Guidelines.

@gemini-code-assist

Note

Gemini is unable to generate a summary for this pull request due to the file types involved not being currently supported.
