
Improve RHT + LoRA decode #360

Open
CC-Yeh wants to merge 4 commits into main from improve_QLoRA_rebased

Conversation

Contributor

CC-Yeh commented Apr 21, 2026

3.3% faster decode on LFM2.5-1.2B-RHT-QLoRA

  • Offline math trick: A_down' = A_down · H at load time (see the sketch below)
  • Fused A_down' into RMSNorm kernel (decode)
  • Fused A_up SG0-tail into QmvFast kernel (decode)
  • Per-rank dispatch (fused at r=16, unfused fallback otherwise)
  • One-line prefill recovery
  • LORA_RANK plumbed as kernel VARIANT
  • CPU LoRA reference + cross-backend tests
  • Shared adapter_up buffer (−22 MB)
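
The "offline math trick" in the first bullet, as a minimal standalone sketch (plain Rust with a made-up 4-point transform, rank-2 adapter, and sign vector; not the repo's code): precomposing A_down' = A_down · H_rht at load time makes A_down' · x equal the original A_down · H_rht(x), so decode no longer has to rotate the input before applying the adapter.

// Hypothetical toy example; H_rht = H · diag(s), i.e. signs applied first, then the butterfly.
fn matvec(m: &[Vec<f32>], x: &[f32]) -> Vec<f32> {
    m.iter()
        .map(|row| row.iter().zip(x).map(|(a, b)| a * b).sum())
        .collect()
}

fn matmul(a: &[Vec<f32>], b: &[Vec<f32>]) -> Vec<Vec<f32>> {
    (0..a.len())
        .map(|i| {
            (0..b[0].len())
                .map(|j| (0..b.len()).map(|k| a[i][k] * b[k][j]).sum())
                .collect()
        })
        .collect()
}

fn main() {
    // 4-point Hadamard matrix H (unnormalized) and a made-up sign vector s.
    let h = vec![
        vec![1.0, 1.0, 1.0, 1.0],
        vec![1.0, -1.0, 1.0, -1.0],
        vec![1.0, 1.0, -1.0, -1.0],
        vec![1.0, -1.0, -1.0, 1.0f32],
    ];
    let s = [1.0, -1.0, -1.0, 1.0f32];
    let h_rht: Vec<Vec<f32>> = h
        .iter()
        .map(|row| row.iter().zip(&s).map(|(a, b)| a * b).collect())
        .collect();

    // Hypothetical rank-2 adapter A_down (2 x 4) and an input x.
    let a_down = vec![vec![0.3, -1.2, 0.5, 0.7], vec![-0.8, 0.1, 1.4, -0.6f32]];
    let x = [0.9, -0.4, 0.2, 1.1f32];

    // Offline (load time): A_down' = A_down · H_rht.
    let a_down_prime = matmul(&a_down, &h_rht);

    // Online (decode): A_down' · x must equal A_down · H_rht(x).
    let fused = matvec(&a_down_prime, &x);
    let reference = matvec(&a_down, &matvec(&h_rht, &x));
    for (f, r) in fused.iter().zip(&reference) {
        assert!((f - r).abs() < 1e-4, "fused {f} vs reference {r}");
    }
}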

CC-Yeh force-pushed the improve_QLoRA_rebased branch from 339ae1c to a855456 on April 21, 2026 16:08
CC-Yeh requested review from eugenebokhan and uuuvn on April 21, 2026 16:09

chatgpt-codex-connector (Bot) left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: a855456cd4


Comment on lines +80 to +83
uint tid = simd_group * 32 + simd_lane;
if (tid < LORA_RANK) {
    h_lora[tid] = static_cast<float>(h_input[tid]);
}


P1: Offset LoRA intermediate by batch in Metal QmvFast

The fused LoRA path loads h_input into threadgroup memory without applying a batch offset, so every batch_idx > 0 reuses batch 0’s h vector. This produces incorrect LoRA deltas whenever fused QmvFast is used with batch_size > 1 (for example small prefill batches that stay on the matrix-vector path), causing wrong outputs for all nonzero batches.
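
A toy illustration of the requested fix (plain Rust with made-up sizes; not the Metal kernel): h is laid out as [batch, LORA_RANK] in one flat buffer, so the fused load needs a batch_idx * LORA_RANK offset, otherwise every batch reads batch 0's vector.

// Hypothetical sketch of the indexing only.
const LORA_RANK: usize = 4;

fn load_h(h_input: &[f32], batch_idx: usize, apply_offset: bool) -> &[f32] {
    // Without the offset (current fused path), base is always 0.
    let base = if apply_offset { batch_idx * LORA_RANK } else { 0 };
    &h_input[base..base + LORA_RANK]
}

fn main() {
    // Two batches with different intermediate vectors h.
    let h_input = [1.0, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0, 40.0f32];
    // Missing offset: batch 1 silently reuses batch 0's h.
    assert_eq!(load_h(&h_input, 1, false), &h_input[..LORA_RANK]);
    // With the offset: batch 1 reads its own h.
    assert_eq!(load_h(&h_input, 1, true), &h_input[LORA_RANK..]);
}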


Comment on lines +67 to 70
let (down_projection, down_input_hadamard_factors, _) = <dyn Linear<B>>::new_extracting_input_fusions(
    &dense_config.linear_config,
    false,
    hidden_dimension,


P1: Avoid fusing A_down for MLP down projection

This call switches the MLP down projection to new_extracting_input_fusions but discards the returned LoRA fusion payload, while new_extracting_input_fusions still enables rms_norm_fuses_a_down for RHT+QLoRA linears. In that mode, QLoRALinearWrapper::encode skips computing adapter_down and reads state.common_aux.lora_intermediate instead; but the down projection has no preceding RMSNorm fusion site to populate h from MlpHidden, so it can consume stale h and apply an incorrect LoRA update.


Comment on lines +49 to +51
pub h_buffer: Option<&'a B::Buffer>,
pub adapter_up: Option<&'a B::Buffer>,
pub lora_scale: f32,
Contributor


Are these always used together / not used together? If so, they should be another struct, with one Option of that struct. As many invariants as possible (like h_buffer.is_some() == adapter_up.is_some(), or lora_scale not being needed or used when LoRA isn't enabled) should be expressed via the type system (e.g. one top-level Option wrapping an inner struct).
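
A minimal sketch of that suggestion (hypothetical struct and field grouping; the stand-in Backend trait is only there so the snippet compiles on its own):

// If the three values only ever appear together, one Option of a struct makes
// the invariant "all present or all absent" hold by construction.
pub trait Backend {
    type Buffer;
}

pub struct LoraFusionArgs<'a, B: Backend> {
    pub h_buffer: &'a B::Buffer,
    pub adapter_up: &'a B::Buffer,
    pub lora_scale: f32,
}

pub struct KernelArgs<'a, B: Backend> {
    // ... other fields ...
    // None = no fused LoRA; Some = h_buffer, adapter_up, and lora_scale are all present and valid.
    pub lora_fusion: Option<LoraFusionArgs<'a, B>>,
}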

input_array_id: ArrayId,
output_array_id: ArrayId,
-) -> Result<(Box<dyn Linear<B>>, Option<B::Buffer>), LinearBlockError<B>> {
+) -> Result<(Box<dyn Linear<B>>, Option<B::Buffer>, Option<LoraFusion<B>>), LinearBlockError<B>> {
Contributor


Big scary tuple; let's make extracted fusions a struct. The struct can have nice helpers for extracting fusions and erroring if something was silently dropped. We could do the same for declaring fusion capabilities (a struct of bools describing which extracted fusions we can handle).
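
Rough sketch of what that could look like (names are hypothetical apart from LoraFusion and rms_norm_fuses_a_down, which appear in this PR; the stand-in Backend trait and the LoraFusion field are made up so the snippet compiles on its own):

pub trait Backend {
    type Buffer;
}

// Assumed shape of the extracted LoRA fusion payload.
pub struct LoraFusion<B: Backend> {
    pub rotated_adapter_down: B::Buffer,
}

// What a linear hands back instead of a growing tuple.
pub struct ExtractedFusions<B: Backend> {
    pub input_hadamard_factors: Option<B::Buffer>,
    pub lora: Option<LoraFusion<B>>,
}

// What a caller declares it can consume (hypothetical flags).
#[derive(Default)]
pub struct FusionCapabilities {
    pub rms_norm_fuses_a_down: bool,
    pub qmv_fuses_a_up_tail: bool,
}

impl<B: Backend> ExtractedFusions<B> {
    /// Take one fusion out; the caller decides how to wire it up.
    pub fn take_lora(&mut self) -> Option<LoraFusion<B>> {
        self.lora.take()
    }

    /// Call after all supported fusions were taken, so nothing is silently dropped.
    pub fn finish(self) -> Result<(), &'static str> {
        if self.input_hadamard_factors.is_some() || self.lora.is_some() {
            Err("extracted fusion was dropped without being consumed")
        } else {
            Ok(())
        }
    }
}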

Comment on lines 27 to -60
@@ -46,34 +63,13 @@ pub struct QLoRALinearWrapper<B: Backend> {
base_linear: QuantizedLinear<B>,
adapter_kernel: RefCell<<B::Kernels as ManualKernels>::MatmulKernel>,
adapter_down: B::Buffer,
adapter_up: B::Buffer,
input_dim: usize,
output_dim: usize,
lora_rank: usize,
lora_scale: f32,
input_array_id: ArrayId,
output_array_id: ArrayId,
}

// TODO: figure out how to make this generic over QLoRAWrapperError::InvalidTensor or make one global "Invalid Tensor" error and make this a common helper
fn validate_tensor<'file, 'context, 'leaf, B: Backend>(
weights_leaf: &ParameterLeaf<'file, 'context, 'leaf, B::Context>,
Contributor


The move ate the todo


ry2009 commented Apr 22, 2026

Noticed a potential issue with the sign/Hadamard ordering in compose_rotated_adapter_down.

The precomposition needs A_down' @ x == A_down @ H_rht(x), which requires applying H_rht^T to the rows of A_down, not H_rht.

Since H_rht = H @ diag(s) (signs first, then butterfly -- matching simdgroup_random_hadamard_transform), the transpose is H_rht^T = diag(s) @ H (butterfly first, then signs).

But compose_rotated_adapter_down calls hadamard_kernel.encode() which applies the standard H @ diag(s) (signs first)... this gives the wrong result.

Quick repro (JAX, same math):

  • Apply H @ diag(s) to rows -> max error vs ground truth: 112.9
  • Apply diag(s) @ H to rows -> max error vs ground truth: 0.033 (f32 noise)

Fix: apply Hadamard butterfly to A_down rows first, then multiply by signs -- instead of the current order... lmk if this is intended though
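
The same check, re-expressed as a tiny standalone Rust sketch (4-point unnormalized transform, made-up adapter row and input; not the repo's kernel code):

// H_rht(x) = butterfly(s ⊙ x); the precomposed row must therefore be
// diag(s) · H · row, i.e. butterfly first, then signs.
fn butterfly4(v: &mut [f32; 4]) {
    let (a, b, c, d) = (v[0], v[1], v[2], v[3]);
    *v = [a + b + c + d, a - b + c - d, a + b - c - d, a - b - c + d];
}

fn main() {
    let s = [1.0, -1.0, -1.0, 1.0f32];
    let a_down_row = [0.3, -1.2, 0.5, 0.7f32];
    let x = [0.9, -0.4, 0.2, 1.1f32];

    // Ground truth: A_down row dotted with H_rht(x).
    let mut hx = [x[0] * s[0], x[1] * s[1], x[2] * s[2], x[3] * s[3]];
    butterfly4(&mut hx);
    let reference: f32 = a_down_row.iter().zip(&hx).map(|(a, b)| a * b).sum();

    // Current kernel order on the row (H · diag(s)): signs first, then butterfly.
    let mut wrong = [a_down_row[0] * s[0], a_down_row[1] * s[1], a_down_row[2] * s[2], a_down_row[3] * s[3]];
    butterfly4(&mut wrong);
    let wrong_dot: f32 = wrong.iter().zip(&x).map(|(a, b)| a * b).sum();

    // Proposed order on the row (diag(s) · H): butterfly first, then signs.
    let mut right = a_down_row;
    butterfly4(&mut right);
    for i in 0..4 {
        right[i] *= s[i];
    }
    let right_dot: f32 = right.iter().zip(&x).map(|(a, b)| a * b).sum();

    println!("reference {reference}, signs-then-butterfly {wrong_dot}, butterfly-then-signs {right_dot}");
    assert!((right_dot - reference).abs() < 1e-4);
    assert!((wrong_dot - reference).abs() > 1e-2);
}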

Contributor Author

CC-Yeh commented Apr 22, 2026

Noticed a potential issue with the sign/Hadamard ordering in compose_rotated_adapter_down.

The precomposition needs A_down' @ x == A_down @ H_rht(x), which requires applying H_rht^T to the rows of A_down, not H_rht.

Since H_rht = H @ diag(s) (signs first, then butterfly -- matching simdgroup_random_hadamard_transform), the transpose is H_rht^T = diag(s) @ H (butterfly first, then signs).

But compose_rotated_adapter_down calls hadamard_kernel.encode() which applies the standard H @ diag(s) (signs first)... this gives the wrong result.

Quick repro (JAX, same math):

  • Apply H @ diag(s) to rows -> max error vs ground truth: 112.9
  • Apply diag(s) @ H to rows -> max error vs ground truth: 0.033 (f32 noise)

Fix: apply Hadamard butterfly to A_down rows first, then multiply by signs -- instead of the current order... lmk if this is intended though

Not intended at all, thanks for catching that!

Contributor

uuuvn commented May 8, 2026

@CC-Yeh what's the status of this pr? We definitely want it after the review fixes

Contributor Author

CC-Yeh commented May 8, 2026

@CC-Yeh what's the status of this pr? We definitely want it after the review fixes

I forgot to compute A_down for the layers that can't be fused (no preceding RMSNorm). After patching that, performance is the same as main; still trying to figure out a way to speed this up. Will spend 1-2 more days on this.

