From 03b7f58896541d02203cde0c418ec64f5c8a8211 Mon Sep 17 00:00:00 2001 From: ArthurLiu Date: Wed, 22 Apr 2026 01:50:50 -0500 Subject: [PATCH] Fix FMHA codegne group mode dispatch --- example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py | 12 +++++++++--- example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 4 ++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 7105f1aa5c..5e7e2a2ffd 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -827,14 +827,20 @@ def dvcheck(self) -> str: @property def max_seq_q_cond(self) -> str: if self.tile.max_seq_q != 0: - return f" && (t.seqlen_q <= {self.tile.max_seq_q})" + if self.mode == "group": + return f" && (t.max_seqlen_q <= {self.tile.max_seq_q})" + else: + return f" && (t.seqlen_q <= {self.tile.max_seq_q})" else: return "" @property def extra_cond(self) -> str: if self.tr_load == "t" and self.tile.max_seq_q == 0 and self.tile.F_bn0 == 128 and self.tile.F_bhdq == 128: - return " && (t.seqlen_k <= 256)" + if self.mode == "group": + return " && (t.max_seqlen_k <= 256)" + else: + return " && (t.seqlen_k <= 256)" else: return "" @@ -1057,7 +1063,7 @@ def get_bwd_blobs( hdim = tile.F_bhdq if (mode == "group") and (spad1d == "f"): continue - if (mode == "group" or ("no" not in mask)) and tile.max_seq_q != 0: + if ("no" not in mask) and tile.max_seq_q != 0: continue if (bias == "no" or bias == "alibi") and dbias == "t": continue diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index c64a19104e..6e16aa8eef 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -328,6 +328,8 @@ def seqtune(self, max_bm0: int) -> str: if self.bm0 == max_bm0 or self.bm0 == 64: return "true/*fall back to largest tile*/" else: + if self.mode == "group": + return f"a.max_seqlen_q <= {self.bm0}" return f"a.seqlen_q <= {self.bm0}" @property @@ -1136,6 +1138,8 @@ def get_pipelines( ): pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "t", sink)) # fmt: skip pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "t", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "t", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "t", sink)) # fmt: skip # group mode spad + pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "t", sink)) # fmt: skip # group mode spad+dpad # # qr_async_trload_v3 bf16/fp16 not ready # if (hdim, hdim_v) == (128, 128):