add mlu rms_norm; optimize fa2 call; rename platform to cambricon_mlu #698
Conversation
Summary of Changes
Hello @Wq-dd, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request enhances the performance and clarity of Cambricon MLU operations: it integrates a dedicated MLU-optimized RMS normalization, fine-tunes Flash Attention 2 calls for better efficiency, and standardizes the platform's naming convention. Together, these changes make the codebase for MLU-accelerated workloads more optimized and maintainable.
Code Review
This pull request adds support for MLU RMS norm, optimizes FlashAttention calls for MLU, and renames the mlu platform to cambricon_mlu. The changes are generally moving in the right direction. However, I've identified a few areas for improvement. There's a significant code duplication issue with the new norm_template.py file, which should be addressed by reusing the existing common template. Additionally, I've pointed out several instances of duplicated logic that can be refactored to enhance code clarity and maintainability. Please see the detailed comments for specific suggestions.
    return DTYPE_MAP[RUNNING_FLAG]


class RMSWeightTemplate(metaclass=ABCMeta):
This RMSWeightTemplate class appears to be a duplicate of the one defined in lightx2v/common/ops/norm/rms_norm_weight.py. To improve maintainability and avoid code duplication, this new file should be removed. The MluRmsNormWeight class in lightx2v_platform/ops/norm/cambricon_mlu/mlu_rms_norm.py should instead import the template from the common location, for example: from lightx2v.common.ops.norm.rms_norm_weight import RMSWeightTemplate.
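As a rough sketch of the suggested reuse, the MLU weight class could import the shared template instead of redefining it. Only the import line below comes from the review suggestion; the surrounding MluRmsNormWeight structure is an assumption for illustration, not the actual file contents.

# Hypothetical shape of lightx2v_platform/ops/norm/cambricon_mlu/mlu_rms_norm.py
# after dropping the duplicated template (class body is illustrative only).
from lightx2v.common.ops.norm.rms_norm_weight import RMSWeightTemplate


class MluRmsNormWeight(RMSWeightTemplate):
    def apply(self, input_tensor):
        # MLU-specific tmo.fused_rms_norm call, as in the apply() diff further below.
        ...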
  if self.config.get("sf_config", False):
-     self.attn_rms_type = "self_forcing"
+     self.attn_rms_type = self.config.get("rms_type", "self_forcing")
  else:
-     self.attn_rms_type = "sgl-kernel"
+     self.attn_rms_type = self.config.get("rms_type", "sgl-kernel")
The logic for setting self.attn_rms_type is duplicated across two branches of the if/else statement. This can be simplified to improve readability and reduce redundancy.
default_rms_type = "self_forcing" if self.config.get("sf_config", False) else "sgl-kernel"
self.attn_rms_type = self.config.get("rms_type", default_rms_type)
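A quick, self-contained illustration of how the suggested pattern resolves; a plain dict stands in for self.config, and the "torch" value is made up purely for the example.

config = {"sf_config": True}  # no explicit "rms_type" key set

default_rms_type = "self_forcing" if config.get("sf_config", False) else "sgl-kernel"
attn_rms_type = config.get("rms_type", default_rms_type)
print(attn_rms_type)  # "self_forcing" -- falls back to the sf_config-based default

config["rms_type"] = "torch"  # an explicit rms_type wins in either branch
attn_rms_type = config.get("rms_type", default_rms_type)
print(attn_rms_type)  # "torch"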
  if self.config.get("sf_config", False):
-     self.attn_rms_type = "self_forcing"
+     self.attn_rms_type = self.config.get("rms_type", "self_forcing")
  else:
-     self.attn_rms_type = "sgl-kernel"
+     self.attn_rms_type = self.config.get("rms_type", "sgl-kernel")
The logic for setting self.attn_rms_type is duplicated in both the if and else blocks. This can be refactored to be more concise and easier to maintain by determining the default value first and then calling self.config.get once.
default_rms_type = "self_forcing" if self.config.get("sf_config", False) else "sgl-kernel"
self.attn_rms_type = self.config.get("rms_type", default_rms_type)
  if self.config.get("sf_config", False):
-     self.attn_rms_type = "self_forcing"
+     self.attn_rms_type = self.config.get("rms_type", "self_forcing")
  else:
-     self.attn_rms_type = "sgl-kernel"
+     self.attn_rms_type = self.config.get("rms_type", "sgl-kernel")
This if/else block contains duplicated logic for getting the rms_type from the configuration. You can simplify this by defining the default value based on sf_config and then calling config.get only once. This will make the code cleaner and more maintainable.
default_rms_type = "self_forcing" if self.config.get("sf_config", False) else "sgl-kernel"
self.attn_rms_type = self.config.get("rms_type", default_rms_type)
def apply(self, input_tensor):
    if self.sensitive_layer_dtype != self.infer_dtype:
        output = tmo.fused_rms_norm(
            input_tensor.float(),
            residual=None,
            gamma=self.weight,
            beta=None,
            bias=None,
            eps=self.eps,
            store_output_before_norm=False,
            quant_scale=None,
            out=None,
        ).to(self.infer_dtype)
    else:
        output = tmo.fused_rms_norm(
            input_tensor,
            residual=None,
            gamma=self.weight,
            beta=None,
            bias=None,
            eps=self.eps,
            store_output_before_norm=False,
            quant_scale=None,
            out=None,
        )
    return output
The apply method has significant code duplication between the if and else blocks. The call to tmo.fused_rms_norm is almost identical in both cases. This can be refactored to remove redundancy, making the code cleaner and easier to maintain.
def apply(self, input_tensor):
    input_data = input_tensor
    if self.sensitive_layer_dtype != self.infer_dtype:
        input_data = input_tensor.float()
    output = tmo.fused_rms_norm(
        input_data,
        residual=None,
        gamma=self.weight,
        beta=None,
        bias=None,
        eps=self.eps,
        store_output_before_norm=False,
        quant_scale=None,
        out=None,
    )
    if self.sensitive_layer_dtype != self.infer_dtype:
        output = output.to(self.infer_dtype)
    return output
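As a quick sanity check of the cast-up/cast-back pattern in this suggestion, here is a minimal, self-contained sketch using plain PyTorch. A hand-rolled rms_norm stands in for tmo.fused_rms_norm, which is MLU-specific and not assumed available here; the dtypes, shapes, and eps value are made up for the example.

import torch

def rms_norm(x, weight, eps):
    # Plain-PyTorch stand-in for the fused MLU kernel.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

infer_dtype = torch.bfloat16
sensitive_layer_dtype = torch.float32  # differs from infer_dtype, so inputs are upcast

x = torch.randn(4, 8, dtype=infer_dtype)
weight = torch.ones(8, dtype=torch.float32)

input_data = x.float() if sensitive_layer_dtype != infer_dtype else x
output = rms_norm(input_data, weight, eps=1e-6)
if sensitive_layer_dtype != infer_dtype:
    output = output.to(infer_dtype)

assert output.dtype == infer_dtype  # result is returned in the inference dtype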
No description provided.