Conversation
Force-pushed from 339ae1c to a855456 (Compare)
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: a855456cd4
uint tid = simd_group * 32 + simd_lane;
if (tid < LORA_RANK) {
    h_lora[tid] = static_cast<float>(h_input[tid]);
}
Offset LoRA intermediate by batch in Metal QmvFast
The fused LoRA path loads h_input into threadgroup memory without applying a batch offset, so every batch_idx > 0 reuses batch 0’s h vector. This produces incorrect LoRA deltas whenever fused QmvFast is used with batch_size > 1 (for example small prefill batches that stay on the matrix-vector path), causing wrong outputs for all nonzero batches.
let (down_projection, down_input_hadamard_factors, _) = <dyn Linear<B>>::new_extracting_input_fusions(
    &dense_config.linear_config,
    false,
    hidden_dimension,
Avoid fusing A_down for MLP down projection
This call switches the MLP down projection to new_extracting_input_fusions but discards the returned LoRA fusion payload, while new_extracting_input_fusions still enables rms_norm_fuses_a_down for RHT+QLoRA linears. In that mode QLoRALinearWrapper::encode skips computing adapter_down and reads state.common_aux.lora_intermediate instead, but the down projection has no preceding RMSNorm fusion site to populate h from MlpHidden, so it can consume a stale h and apply an incorrect LoRA update.
pub h_buffer: Option<&'a B::Buffer>,
pub adapter_up: Option<&'a B::Buffer>,
pub lora_scale: f32,
Are these always used together / not used together? If so, they should be another struct, held in a single Option of that structure. As many invariants as possible (like h_buffer.is_some() == adapter_up.is_some(), lora_scale not being needed or used when LoRA isn't enabled, etc.) should be expressed via the type system, e.g. one top-level Option holding an inner struct.
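A minimal sketch of that shape (FusedLora, KernelArgs, and the Backend stub below are hypothetical names for illustration, not this PR's types):

```rust
// Stand-in for the crate's real Backend trait, only so the sketch compiles on its own.
pub trait Backend {
    type Buffer;
}

/// All-or-nothing bundle of the fused-LoRA arguments. If this value exists, every
/// field is present, so h_buffer.is_some() == adapter_up.is_some() cannot be
/// violated, and lora_scale simply doesn't exist when LoRA is disabled.
pub struct FusedLora<'a, B: Backend> {
    pub h_buffer: &'a B::Buffer,
    pub adapter_up: &'a B::Buffer,
    pub lora_scale: f32,
}

/// The kernel arguments then carry one optional field instead of three loosely
/// coupled ones; call sites match on fused_lora once instead of checking each field.
pub struct KernelArgs<'a, B: Backend> {
    pub fused_lora: Option<FusedLora<'a, B>>,
    // ...the other existing fields...
}
```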
    input_array_id: ArrayId,
    output_array_id: ArrayId,
-) -> Result<(Box<dyn Linear<B>>, Option<B::Buffer>), LinearBlockError<B>> {
+) -> Result<(Box<dyn Linear<B>>, Option<B::Buffer>, Option<LoraFusion<B>>), LinearBlockError<B>> {
Big scary tuple; let's make the extracted fusions a struct. The struct can have nice helpers for extracting fusions and for erroring if something was silently dropped. We can also do the same for declaring fusion capabilities (a struct of bools describing which extracted fusions we can handle).
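Roughly what that could look like, as a sketch; ExtractedFusions, FusionCapabilities, and the stub types below are hypothetical, not this PR's API:

```rust
// Minimal stubs so the sketch is self-contained; the real trait and payload
// types live in the crate.
pub trait Backend {
    type Buffer;
}
pub struct LoraFusion<B: Backend> {
    pub adapter_down: B::Buffer, // placeholder field
}

/// Everything a constructor extracted from the linear, instead of a bare tuple.
pub struct ExtractedFusions<B: Backend> {
    pub input_hadamard_factors: Option<B::Buffer>,
    pub lora_fusion: Option<LoraFusion<B>>,
}

/// What a caller declares it is able to consume.
pub struct FusionCapabilities {
    pub input_hadamard: bool,
    pub lora: bool,
}

impl<B: Backend> ExtractedFusions<B> {
    /// Error out instead of silently dropping a fusion the caller can't handle.
    pub fn ensure_consumable(&self, caps: &FusionCapabilities) -> Result<(), String> {
        if self.input_hadamard_factors.is_some() && !caps.input_hadamard {
            return Err("input Hadamard factors would be silently dropped".into());
        }
        if self.lora_fusion.is_some() && !caps.lora {
            return Err("LoRA fusion payload would be silently dropped".into());
        }
        Ok(())
    }
}
```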
@@ -46,34 +63,13 @@ pub struct QLoRALinearWrapper<B: Backend> {
    base_linear: QuantizedLinear<B>,
    adapter_kernel: RefCell<<B::Kernels as ManualKernels>::MatmulKernel>,
    adapter_down: B::Buffer,
    adapter_up: B::Buffer,
    input_dim: usize,
    output_dim: usize,
    lora_rank: usize,
    lora_scale: f32,
    input_array_id: ArrayId,
    output_array_id: ArrayId,
}

// TODO: figure out how to make this generic over QLoRAWrapperError::InvalidTensor or make one global "Invalid Tensor" error and make this a common helper
fn validate_tensor<'file, 'context, 'leaf, B: Backend>(
    weights_leaf: &ParameterLeaf<'file, 'context, 'leaf, B::Context>,
Noticed a potential issue with the sign/Hadamard ordering in the A_down precomposition. The precomposition needs the Hadamard butterfly applied to A_down before the sign multiplication. Since the sign flips and the Hadamard butterfly don't commute, the order matters, but the current code applies the signs first and then the butterfly. A quick JAX repro (same math) confirms the mismatch.
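For illustration (not the original repro), the 2x2 case with signs s = (+1, -1) already shows the two orders disagree:

$$
H\,\mathrm{diag}(s)=\begin{pmatrix}1&1\\ 1&-1\end{pmatrix}\begin{pmatrix}1&0\\ 0&-1\end{pmatrix}=\begin{pmatrix}1&-1\\ 1&1\end{pmatrix},
\qquad
\mathrm{diag}(s)\,H=\begin{pmatrix}1&0\\ 0&-1\end{pmatrix}\begin{pmatrix}1&1\\ 1&-1\end{pmatrix}=\begin{pmatrix}1&1\\ -1&1\end{pmatrix},
$$

so precomposing A_down with the signs before the butterfly bakes a different matrix into the weights than applying the butterfly first.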
Fix: apply the Hadamard butterfly to the A_down rows first, then multiply by the signs, instead of the current order... lmk if this is intended though
Not intended at all, thanks for catching that!
@CC-Yeh what's the status of this PR? We definitely want it after the review fixes.
Forgot to compute A_down for the layers that can't be fused (not RMSNorm); performance is the same as main after patching that. Still trying to figure out a way to speed this up. Will spend 1-2 more days on this.
3.3% faster decode on LFM2.5-1.2B-RHT-QLoRA