@HAOCHENYE
Collaborator

  • Extract model input preparation logic into _prepare_model_input method
  • Move loss_log update logic from trainer to train_engine
  • Simplify _log_step method signature by using instance variables
  • Fix type hints: consumed_tokens and consumed_img_tokens should be int
  • Adjust consumed_samples calculation position for better logic flow

Contributor

Copilot AI left a comment

Pull request overview

This PR refactors the trainer fit loop to improve code organization by extracting model input preparation logic, relocating loss_log update logic, simplifying method signatures, and fixing type hints.

  • Extracted model input preparation into a dedicated _prepare_model_input method for better code modularity
  • Moved loss_log update logic from trainer to train_engine for better separation of concerns
  • Simplified _log_step method signature by using instance variables instead of passing them as parameters
  • Fixed type hints for consumed_tokens and consumed_img_tokens from float to int with appropriate conversions

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| xtuner/v1/train/trainer.py | Refactored fit loop by extracting _prepare_model_input method, removed loss_log update logic (moved to engine), simplified _log_step signature, adjusted consumed_samples calculation timing, updated _reduce_number_across_rank type hints, and removed unused ModelForwardExtraLogInfo import |
| xtuner/v1/engine/train_engine.py | Updated type hints for consumed_tokens and consumed_img_tokens to int, added loss_log update logic (moved from trainer), and added int conversion for consumed_tokens |
| xtuner/v1/engine/vision_compose_train_engine.py | Added int conversions for consumed_tokens and consumed_img_tokens to match updated type hints |

  other_log["extra_info"] = train_engine_extra_info # type: ignore[assignment]
  other_log["efficient_attn_ratio"] = (efficient_forward_tokens / total_forward_tokens).item()
- other_log["consumed_img_tokens"] = step_consumed_img_tokens
+ other_log["consumed_img_tokens"] = int(step_consumed_img_tokens)

Copilot AI Dec 24, 2025

The variable step_consumed_img_tokens is initialized as a float (0.0) on line 148 and may contain a fractional value after division on line 163. Converting to int here will truncate any fractional part. Consider using 0 instead of 0.0 on line 148 and using integer division (//) on line 163 if integer values are required, or document that truncation is intentional.
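
To make the concern above concrete, here is a minimal sketch comparing a float accumulator that is truncated with int() at the end against a count that stays integral throughout. The token count and the divisor are made-up values for illustration only; the actual computation on lines 148 and 163 is not shown in this thread.

```python
# Hypothetical values: a summed image-token count and a divisor.
# Neither is taken from the PR; they only illustrate the rounding behaviour.
summed_img_tokens = 1001
divisor = 4

# Float path: start from 0.0, use true division, truncate at the end.
float_total = 0.0
float_total += summed_img_tokens / divisor  # holds a fractional value
truncated = int(float_total)                # fractional part is dropped

# Integer path: start from 0 and use floor division, so the value
# never leaves the integers.
int_total = 0
int_total += summed_img_tokens // divisor

print(truncated, int_total)

# int() truncates toward zero while // floors, so the two differ only
# for negative values, which token counts should never be.
negative_case = (int(-7 / 2), -7 // 2)
print(negative_case)
```

For non-negative counts both paths give the same number; the integer path simply makes the intent explicit and keeps the value consistent with an int type hint.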

@HAOCHENYE force-pushed the yehc/beautify-trainer-fit branch from 1ef1c72 to 4f6412f on December 24, 2025, 08:37
else:
    extra_info_updated = ModelForwardExtraLogInfo(extra_info)
    extra_info_dict = extra_info_updated.get()
    loss_log.update(extra_info_dict)
Collaborator

If extra_info is not updated here, sft/pretrain would presumably no longer print the per-GPU loss. Is that expected?

Collaborator Author

This logic has been moved to 'TrainEngine'; 'Trainer' should not need to be aware of it.
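
The separation described in this reply can be sketched roughly as follows. SketchEngine, SketchTrainer, and the keys in the log are hypothetical stand-ins invented for illustration; they are not xtuner's TrainEngine, Trainer, or ModelForwardExtraLogInfo, whose real interfaces are not shown in this thread.

```python
from typing import Any


class SketchEngine:
    """Hypothetical stand-in for an engine; not xtuner's TrainEngine."""

    def train_step(self, batch: list[float]) -> dict[str, Any]:
        loss = sum(batch) / len(batch)
        loss_log: dict[str, Any] = {"reduced_loss": loss}
        # Extra per-rank details are merged here, inside the engine,
        # so callers receive a loss_log that is already complete.
        extra_info = {"local_loss": loss, "num_samples": len(batch)}
        loss_log.update(extra_info)
        return loss_log


class SketchTrainer:
    """Hypothetical trainer that only consumes the finished loss_log."""

    def __init__(self, engine: SketchEngine) -> None:
        self.engine = engine

    def fit_one_step(self, batch: list[float]) -> dict[str, Any]:
        # No knowledge of extra_info is needed at this level.
        return self.engine.train_step(batch)


log = SketchTrainer(SketchEngine()).fit_one_step([1.0, 2.0, 3.0])
print(log)
```

Under this assumed layout the per-rank entries still reach the log; the merge just happens one layer lower, so the trainer no longer needs the import or the update call.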

Collaborator

OK

Collaborator

@YanhuiDua YanhuiDua left a comment

LGTM

@HAOCHENYE
Collaborator Author

@gemini-code-assist

@HAOCHENYE
Collaborator Author

/gemini review

Collaborator

@jayhenry jayhenry left a comment

LGTM

- consumed_tokens: float
- consumed_img_tokens: NotRequired[float]
+ consumed_tokens: int
+ consumed_img_tokens: NotRequired[int]
Collaborator

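This has already been renamed to step_consumed_tokens in the PR below, so keep that in mind when rebasing:
See the rename PR
The prefix rules for statistics variables are:

  • Spatial (per DP rank vs. reduced sum): per-rank values use the local_ prefix; reduced values are the default and take no prefix.
  • Temporal (per step vs. cumulative): use step_ and total_.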
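
As a rough illustration of the prefix rules above, the sketch below tracks token counts for two steps on three data-parallel ranks. The numbers are made up, a plain sum() stands in for the cross-rank reduction, and combining the two prefixes as local_step_ is an assumption about ordering that the comment does not spell out.

```python
# Hypothetical per-rank token counts for two steps on three DP ranks.
tokens_per_rank_per_step = [
    [100, 120, 90],   # step 0
    [110, 95, 105],   # step 1
]

total_consumed_tokens = 0  # reduced across ranks, accumulated over steps
history = []

for per_rank in tokens_per_rank_per_step:
    # local_ prefix: the value seen by a single rank (rank 0 here).
    local_step_consumed_tokens = per_rank[0]
    # No spatial prefix: summed over ranks (stands in for an all-reduce).
    step_consumed_tokens = sum(per_rank)
    # total_ prefix: running sum over steps.
    total_consumed_tokens += step_consumed_tokens
    history.append(
        (local_step_consumed_tokens, step_consumed_tokens, total_consumed_tokens)
    )

print(history)
```

Reading a name left to right then tells you where the value lives (one rank or all ranks) and over what time span it was collected (one step or the whole run).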
