diff --git a/tests/pytorch/test_mhc.py b/tests/pytorch/test_mhc.py
index 541ce9a8c2..1ec69255ee 100644
--- a/tests/pytorch/test_mhc.py
+++ b/tests/pytorch/test_mhc.py
@@ -14,13 +14,14 @@
     mhc_fused_aggregate,
     mhc_fused_expand_combine,
     mhc_fused_projection,
+    mhc_generate_mix_and_aggregate,
 )
 
 # Disable TF32 for matmul to ensure consistency between the fused and reference implementations
 torch.backends.cuda.matmul.allow_tf32 = False
 
 
-def mhc_projection_ref(x, phi):
+def mhc_projection_ref(x, phi, norm_weight):
     """
     Reference operator for mHC's projection building operation.
 
@@ -29,19 +30,20 @@ def mhc_projection_ref(x, phi):
         - phi_pre: (n, nC)
         - phi_post: (n, nC)
         - phi_res: (n^2, nC)
+    norm_weight: (nC,) or None. If not None, phi is multiplied element-wise by norm_weight before the projection.
     n: number of Hyper Connection streams
     C: hidden dimension per stream
     """
-    x_dtype = x.dtype
-    x = x.to(torch.float32)
-    phi = phi.to(torch.float32)
-
-    Hs = x @ phi.T  # (M, 2n + n^2)
-    x_fp32 = x.to(torch.float32)  # Use fp32 for better numerical stability in variance calculation
+    x_fp32 = x.to(torch.float32)
     ms = (x_fp32 * x_fp32).mean(dim=1)
-    return Hs.to(x_dtype), ms
+    phi_fp32 = phi.to(torch.float32)
+    if norm_weight is not None:
+        phi_fp32 = phi_fp32 * norm_weight.to(torch.float32)[None, :]
+    Hs = x_fp32 @ phi_fp32.T  # (M, 2n + n^2)
+
+    return Hs, ms
 
 
 def mhc_scale_ref(H, alpha, beta, ms, n):
@@ -139,9 +141,9 @@ def mhc_aggregate_ref(x, H_pre, n):
     s, b, C, n = x.shape
     H_pre = H_pre.view(s, b, n, 1)
 
-    out = (x @ H_pre).view(s, b, C)
+    out = (x.to(H_pre.dtype) @ H_pre).view(s, b, C)
 
-    return out
+    return out.to(x.dtype)
 
 
 def mhc_expand_combine_ref(f, bias, H_post, x, H_res, n):
@@ -267,25 +269,44 @@ def get_tols(dtype):
 
 @pytest.mark.parametrize("cfg", mhc_configs, ids=MHCConfig.desc)
-@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=["fp32", "bf16"])
-def test_mhc_projection(cfg: MHCConfig, dtype):
+@pytest.mark.parametrize(
+    "dtypes",
+    [
+        (torch.float32, torch.float32),
+        (torch.bfloat16, torch.bfloat16),
+        (torch.bfloat16, torch.float32),
+    ],
+    ids=["x_fp32_phi_fp32", "x_bf16_phi_bf16", "x_bf16_phi_fp32"],
+)
+@pytest.mark.parametrize("has_norm_weight", [False, True], ids=["no_norm_weight", "norm_weight"])
+def test_mhc_projection(cfg: MHCConfig, dtypes, has_norm_weight):
     reset_rng_states()
 
     s, b, C, n = cfg.s, cfg.b, cfg.C, cfg.n
     nC = n * C
     N = 2 * n + n * n
 
-    tols = get_tols(dtype)
+    x_dtype = dtypes[0]
+    phi_dtype = dtypes[1]
+    tols = get_tols(x_dtype)
     use_tf32 = False
 
-    x = torch.randn(s * b, nC, device="cuda", requires_grad=True, dtype=dtype)
-    phi = torch.randn(N, nC, dtype=dtype, requires_grad=True, device="cuda")
-
+    x = torch.randn(s * b, nC, device="cuda", requires_grad=True, dtype=x_dtype)
+    phi = torch.randn(N, nC, dtype=phi_dtype, requires_grad=True, device="cuda")
     x_ref = x.detach().clone().requires_grad_(True)
    phi_ref = phi.detach().clone().requires_grad_(True)
 
-    ref_out_Hs, ref_out_ms = mhc_projection_ref(x_ref, phi_ref)
-    fused_out_Hs_padded, fused_out_ms = mhc_fused_projection(x, phi, use_tf32)
+    if has_norm_weight:
+        norm_weight = torch.randn(nC, device="cuda", requires_grad=True, dtype=x_dtype)
+        norm_weight_ref = norm_weight.detach().clone().requires_grad_(True)
+    else:
+        norm_weight = None
+        norm_weight_ref = None
+
+    ref_out_Hs, ref_out_ms = mhc_projection_ref(x_ref, phi_ref, norm_weight_ref)
+    fused_out_Hs_padded, fused_out_ms = mhc_fused_projection(
+        x, phi, norm_weight=norm_weight, use_tf32=use_tf32
+    )
     fused_out_Hs = fused_out_Hs_padded[:, :N]
torch.testing.assert_close(fused_out_Hs, ref_out_Hs, **tols) @@ -295,10 +316,12 @@ def test_mhc_projection(cfg: MHCConfig, dtype): torch.testing.assert_close(x.grad, x_ref.grad, **tols) torch.testing.assert_close(phi.grad, phi_ref.grad, **tols) + if has_norm_weight: + torch.testing.assert_close(norm_weight.grad, norm_weight_ref.grad, **tols) @pytest.mark.parametrize("cfg", mhc_configs, ids=MHCConfig.desc) -@pytest.mark.parametrize("dtype", [torch.float32], ids=["fp32"]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=["fp32", "bf16"]) def test_mhc_scale(cfg: MHCConfig, dtype): reset_rng_states() @@ -329,28 +352,39 @@ def test_mhc_scale(cfg: MHCConfig, dtype): torch.cat([fused_out[i] for i in range(3)], dim=-1).sum().backward() torch.testing.assert_close(H_padded.grad[:, :N], H_ref.grad, **tols) + torch.testing.assert_close(ms.grad, ms_ref.grad, **tols) torch.testing.assert_close(alpha.grad, alpha_ref.grad, **tols) torch.testing.assert_close(beta.grad, beta_ref.grad, **tols) - torch.testing.assert_close(ms.grad, ms_ref.grad, **tols) @pytest.mark.parametrize("cfg", mhc_configs, ids=MHCConfig.desc) -@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=["fp32", "bf16"]) -def test_mhc_combined(cfg: MHCConfig, dtype): +@pytest.mark.parametrize( + "dtypes", + [ + (torch.float32, torch.float32), + (torch.bfloat16, torch.bfloat16), + (torch.bfloat16, torch.float32), + ], + ids=["x_fp32_phi_fp32", "x_bf16_phi_bf16", "x_bf16_phi_fp32"], +) +@pytest.mark.parametrize("has_norm_weight", [False, True], ids=["no_norm_weight", "norm_weight"]) +def test_mhc_rmsnorm(cfg: MHCConfig, dtypes, has_norm_weight): + # Verify if the fused kernel is equivalent to applying RMSNorm in the normal order reset_rng_states() s, b, C, n = cfg.s, cfg.b, cfg.C, cfg.n N = 2 * n + n * n nC = n * C - tols = get_tols(dtype) + x_dtype = dtypes[0] + phi_dtype = dtypes[1] + tols = get_tols(x_dtype) use_tf32 = False - x = torch.randn(s * b, nC, device="cuda", requires_grad=True, dtype=dtype) - phi = torch.randn(N, nC, dtype=dtype, requires_grad=True, device="cuda") - - alpha = torch.randn(3, device="cuda", requires_grad=True, dtype=dtype) - beta = torch.randn(1, 2 * n + n * n, device="cuda", requires_grad=True, dtype=dtype) + x = torch.randn(s * b, nC, device="cuda", requires_grad=True, dtype=x_dtype) + phi = torch.randn(N, nC, dtype=phi_dtype, requires_grad=True, device="cuda") + alpha = torch.randn(3, device="cuda", requires_grad=True, dtype=phi_dtype) + beta = torch.randn(1, 2 * n + n * n, device="cuda", requires_grad=True, dtype=phi_dtype) x_ref = x.detach().clone().requires_grad_(True) phi_ref = phi.detach().clone().requires_grad_(True) @@ -358,8 +392,17 @@ def test_mhc_combined(cfg: MHCConfig, dtype): alpha_ref = alpha.detach().clone().requires_grad_(True) beta_ref = beta.detach().clone().requires_grad_(True) - ref_out_H, ref_out_r = mhc_projection_ref(x_ref, phi_ref) - fused_out_H_padded, fused_out_r = mhc_fused_projection(x, phi, use_tf32) + if has_norm_weight: + norm_weight = torch.randn(nC, device="cuda", requires_grad=True, dtype=x_dtype) + norm_weight_ref = norm_weight.detach().clone().requires_grad_(True) + else: + norm_weight = None + norm_weight_ref = None + + ref_out_H, ref_out_r = mhc_projection_ref(x_ref, phi_ref, norm_weight_ref) + fused_out_H_padded, fused_out_r = mhc_fused_projection( + x, phi, norm_weight=norm_weight, use_tf32=use_tf32 + ) ref_H_pre, ref_H_post, ref_H_res = mhc_scale_ref( ref_out_H[:, :N], alpha_ref, beta_ref, ref_out_r, n @@ -368,17 +411,19 @@ def 
test_mhc_combined(cfg: MHCConfig, dtype):
         fused_out_H_padded, alpha, beta, fused_out_r, n
     )
 
-    def mhc_combined(x_ref, phi_ref, alpha_ref, beta_ref):
-        dtype = x_ref.dtype
-        x_ref = x_ref.to(torch.float32)
-        phi_ref = phi_ref.to(torch.float32)
-        alpha_ref = alpha_ref.to(torch.float32)
-        beta_ref = beta_ref.to(torch.float32)
-
+    def mhc_combined(x_ref, phi_ref, alpha_ref, beta_ref, norm_weight_ref):
         # Check if after spliting RMSNorm to two steps in projection and scaling,
-        # theresult is close to applying RMSNorm in the correct order
-        x_rmsnorm = F.rms_norm(x_ref, normalized_shape=(nC,))
-        H = x_rmsnorm @ phi_ref.T
+        # the result is close to applying RMSNorm in the correct order.
+        # Run RMSNorm in fp32 so the bf16 case has the same precision pattern as the
+        # kernel/ref (F.rms_norm on bf16 input would round x_rmsnorm back to bf16).
+        eps = torch.finfo(torch.float32).eps
+        norm_weight_fp32 = (
+            norm_weight_ref.to(torch.float32) if norm_weight_ref is not None else None
+        )
+        x_rmsnorm = F.rms_norm(
+            x_ref.to(torch.float32), normalized_shape=(nC,), weight=norm_weight_fp32, eps=eps
+        )
+        H = x_rmsnorm @ phi_ref.T.to(torch.float32)
         H_pre = H[:, :n]
         H_post = H[:, n : 2 * n]
         H_res = H[:, 2 * n :]
@@ -391,25 +436,86 @@ def mhc_combined(x_ref, phi_ref, alpha_ref, beta_ref):
         out_post = 2 * out_post.sigmoid()
         out_res = out_res
 
-        return out_pre.to(dtype), out_post.to(dtype), out_res.to(dtype)
+        return out_pre, out_post, out_res  # Return in FP32 to match the kernel's behavior
 
     combined_H_pre, combined_H_post, combined_H_res = mhc_combined(
-        x_ref, phi_ref, alpha_ref, beta_ref
+        x_ref, phi_ref, alpha_ref, beta_ref, norm_weight_ref
     )
 
     torch.testing.assert_close(combined_H_pre, ref_H_pre, **tols)
     torch.testing.assert_close(combined_H_post, ref_H_post, **tols)
     torch.testing.assert_close(combined_H_res, ref_H_res, **tols)
 
+    torch.testing.assert_close(ref_H_pre, fused_H_pre, **tols)
+    torch.testing.assert_close(ref_H_post, fused_H_post, **tols)
+    torch.testing.assert_close(ref_H_res, fused_H_res, **tols)
+
     torch.testing.assert_close(combined_H_pre, fused_H_pre, **tols)
     torch.testing.assert_close(combined_H_post, fused_H_post, **tols)
     torch.testing.assert_close(combined_H_res, fused_H_res, **tols)
 
 
+@pytest.mark.parametrize("cfg", mhc_configs, ids=MHCConfig.desc)
+@pytest.mark.parametrize("dtype", [torch.float32], ids=["fp32"])
+def test_mhc_fuse_grad_acc(cfg: MHCConfig, dtype):
+    # Skip bf16 tests since in the unfused path we accumulate 3 bf16 gradients, whereas in the fused path
+    # we accumulate 3 fp32 gradients and then cast to bf16 at the end, which gives the two paths different precision patterns.
+
+    reset_rng_states()
+
+    s, b, C, n = cfg.s, cfg.b, cfg.C, cfg.n
+    N = 2 * n + n * n
+    nC = n * C
+
+    tols = get_tols(dtype)
+    use_tf32 = False
+
+    x = torch.randn(s, b, C, n, device="cuda", requires_grad=True, dtype=dtype)
+    phi = torch.randn(N, nC, dtype=dtype, requires_grad=True, device="cuda")
+
+    alpha = torch.randn(3, device="cuda", requires_grad=True, dtype=dtype)
+    beta = torch.randn(1, 2 * n + n * n, device="cuda", requires_grad=True, dtype=dtype)
+    x_ref = x.detach().clone().requires_grad_(True)
+    phi_ref = phi.detach().clone().requires_grad_(True)
+
+    alpha_ref = alpha.detach().clone().requires_grad_(True)
+    beta_ref = beta.detach().clone().requires_grad_(True)
+
+    def end_to_end(x, phi, alpha, beta, fused_grad_x_acc):
+        fused_grad_x_acc_buffer = None
+        if fused_grad_x_acc:
+            fused_grad_x_acc_buffer = torch.empty_like(x, dtype=torch.float32)
+        aggregated, H_post, H_res =
mhc_generate_mix_and_aggregate( + x, phi, alpha, beta, None, use_tf32, fused_grad_x_acc_buffer + ) + expanded_combined = mhc_fused_expand_combine( + aggregated, + None, + H_post, + x, + H_res, + n, + False, + fused_grad_x_acc_buffer, + ) + + return expanded_combined + + expanded_combined_fuse_grad = end_to_end( + x_ref, phi_ref, alpha_ref, beta_ref, fused_grad_x_acc=True + ) + expanded_combined_no_fuse_grad = end_to_end(x, phi, alpha, beta, fused_grad_x_acc=False) + + grad_output = torch.randn_like(expanded_combined_fuse_grad) + expanded_combined_fuse_grad.backward(grad_output) + expanded_combined_no_fuse_grad.backward(grad_output) + + torch.testing.assert_close(x.grad, x_ref.grad, **tols) + + @pytest.mark.parametrize("cfg", mhc_configs, ids=MHCConfig.desc) @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=["fp32", "bf16"]) -@pytest.mark.parametrize("recompute", [False, True], ids=["no_recompute", "recompute"]) -def test_mhc_sinkhorn(cfg: MHCConfig, dtype, recompute): +def test_mhc_sinkhorn(cfg: MHCConfig, dtype): reset_rng_states() s, b, C, n = cfg.s, cfg.b, cfg.C, cfg.n @@ -420,7 +526,7 @@ def test_mhc_sinkhorn(cfg: MHCConfig, dtype, recompute): x_ref = x.detach().clone().requires_grad_(True) ref_out = mhc_sinkhorn_ref(x_ref, n) - fused_out = mhc_fused_sinkhorn(x, n, recompute) + fused_out = mhc_fused_sinkhorn(x, n) torch.testing.assert_close(fused_out, ref_out, **tols) @@ -446,7 +552,7 @@ def test_mhc_aggregate(cfg: MHCConfig, dtype): H_pre_ref = H_pre.detach().clone().requires_grad_(True) ref_out = mhc_aggregate_ref(x_ref, H_pre_ref, n) - fused_out = mhc_fused_aggregate(x, H_pre, n, False) + fused_out = mhc_fused_aggregate(x, H_pre, n, use_tf32=False) torch.testing.assert_close(fused_out, ref_out, **tols) @@ -482,7 +588,7 @@ def test_mhc_expand_combine(cfg: MHCConfig, dtype, with_bias): H_res_ref = H_res.detach().clone().requires_grad_(True) ref_out = mhc_expand_combine_ref(f_ref, bias_ref, H_post_ref, x_ref, H_res_ref, n) - fused_out = mhc_fused_expand_combine(f, bias, H_post, x, H_res, n, False) + fused_out = mhc_fused_expand_combine(f, bias, H_post, x, H_res, n=n, use_tf32=False) torch.testing.assert_close(fused_out, ref_out, **tols) diff --git a/transformer_engine/common/triton/mhc.py b/transformer_engine/common/triton/mhc.py index 965bb437ff..65023a4349 100644 --- a/transformer_engine/common/triton/mhc.py +++ b/transformer_engine/common/triton/mhc.py @@ -12,6 +12,8 @@ import triton import triton.language as tl +MAX_GRID_DIM_Y = 65535 # Maximum grid dimension in Y direction for current CUDA architectures + def projection_config_fwd(): block_m = [64, 128] @@ -34,23 +36,11 @@ def projection_config_fwd(): return configs -def projection_config_bwd(): - block_m = [32, 128] - block_k = [128] - warps = [2] - stages = [2, 3, 4] - - configs = [] - for m, bk, w, s in itertools.product(block_m, block_k, warps, stages): - configs.append( - triton.Config({"BLOCK_SIZE_M": m, "BLOCK_SIZE_K": bk}, num_warps=w, num_stages=s) - ) - if os.environ.get("NVTE_DISABLE_TRITON_AUTOTUNING", "0") == "1": - configs = configs[:1] - return configs - - -@triton.autotune(configs=projection_config_fwd(), key=["M", "K"], reset_to_zero=["h_ptr", "ms_ptr"]) +@triton.autotune( + configs=projection_config_fwd(), + key=["M", "K", "USE_TMA"], + reset_to_zero=["h_ptr", "ms_ptr"], +) @triton.jit def _mhc_projection_fwd_fused( x_ptr, # (M, K) @@ -67,12 +57,14 @@ def _mhc_projection_fwd_fused( stride_hm: tl.constexpr, stride_hn: tl.constexpr, stride_ms: tl.constexpr, + stride_norm_weight: tl.constexpr, # 
Meta-parameters
     BLOCK_SIZE_M: tl.constexpr,
     BLOCK_SIZE_K: tl.constexpr,
     STEP_SIZE_K: tl.constexpr,
     BLOCK_SIZE_N: tl.constexpr,
     precision: tl.constexpr,
+    USE_TMA: tl.constexpr,  # If True, load x and phi via TMA tensor descriptors (Hopper+ only). Falls back to pointer-arithmetic tl.load otherwise.
 ):
     pid_m = tl.program_id(axis=0)
     pid_k = tl.program_id(axis=1)
@@ -86,8 +78,9 @@ def _mhc_projection_fwd_fused(
     tl.assume(stride_hm == 32)
     tl.assume(stride_hn == 1)
     tl.assume(stride_ms == 1)
+    tl.assume(stride_norm_weight == 1)
 
-    tl.assume(BLOCK_SIZE_M % 32 == 0)
+    tl.assume(BLOCK_SIZE_M % 8 == 0)
     tl.assume(BLOCK_SIZE_K % 32 == 0)
     tl.assume(BLOCK_SIZE_N == 32)
 
@@ -98,24 +91,53 @@ def _mhc_projection_fwd_fused(
     h_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
     ms_acc = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
 
+    if USE_TMA:
+        x_desc = tl.make_tensor_descriptor(
+            x_ptr,
+            shape=[M, K],
+            strides=[stride_xm, 1],
+            block_shape=[BLOCK_SIZE_M, STEP_SIZE_K],
+        )
+        phi_desc = tl.make_tensor_descriptor(
+            phi_ptr,
+            shape=[N, K],
+            strides=[stride_phin, 1],
+            block_shape=[BLOCK_SIZE_N, STEP_SIZE_K],
+        )
+
     k_base = pid_k * BLOCK_SIZE_K
     for k_start in range(0, tl.cdiv(BLOCK_SIZE_K, STEP_SIZE_K)):
-        k_offs = k_base + k_start * STEP_SIZE_K + tl.arange(0, STEP_SIZE_K)
+        k_off = k_base + k_start * STEP_SIZE_K
+        k_offs = k_off + tl.arange(0, STEP_SIZE_K)
         mask_k = k_offs < K
-        x_ptrs = x_ptr + offs_m[:, None] * stride_xm + k_offs[None, :] * stride_xk
-        x = tl.load(
-            x_ptrs, mask=mask_m[:, None] & mask_k[None, :], other=0.0
-        )  # (BLOCK_SIZE_M, BLOCK_SIZE_K)
-        phi_ptrs = phi_ptr + offs_n_full[:, None] * stride_phin + k_offs[None, :] * stride_phik
-        phi = tl.load(
-            phi_ptrs,
-            mask=(offs_n_full[:, None] < N) & mask_k[None, :],
-            other=0.0,
-            cache_modifier=".ca",
-        )  # (BLOCK_SIZE_N, BLOCK_SIZE_K)
-        ms_acc += tl.sum(x * x, axis=1)
+
+        if USE_TMA:
+            x = tl.load_tensor_descriptor(x_desc, [pid_m * BLOCK_SIZE_M, k_off])
+            phi = tl.load_tensor_descriptor(phi_desc, [0, k_off])
+        else:
+            x_ptrs = x_ptr + offs_m[:, None] * stride_xm + k_offs[None, :] * stride_xk
+            x = tl.load(
+                x_ptrs, mask=mask_m[:, None] & mask_k[None, :], other=0.0
+            )  # (BLOCK_SIZE_M, BLOCK_SIZE_K)
+            phi_ptrs = phi_ptr + offs_n_full[:, None] * stride_phin + k_offs[None, :] * stride_phik
+            phi = tl.load(
+                phi_ptrs,
+                mask=(offs_n_full[:, None] < N) & mask_k[None, :],
+                other=0.0,
+                cache_modifier=".ca",
+            )  # (BLOCK_SIZE_N, BLOCK_SIZE_K)
+
+        ms_acc += tl.sum(x.to(tl.float32) * x.to(tl.float32), axis=1)
+
+        # Triton currently has a bug where, for small block sizes, tl.dot(x, phi.T) uses SMEM to transpose the matrix
+        # instead of emitting an ldmatrix instruction with the `.trans` modifier, which leads to bank conflicts and a performance regression.
+        # See https://github.com/triton-lang/triton/issues/6569#issuecomment-2841739082
         h_acc = tl.dot(
-            x, tl.trans(phi, (1, 0)), h_acc, input_precision=precision, out_dtype=tl.float32
+            x.to(phi.dtype),
+            tl.trans(phi, (1, 0)),
+            h_acc,
+            input_precision=precision,
+            out_dtype=tl.float32,
         )
 
     h_ptrs = h_ptr + offs_m[:, None] * stride_hm + offs_n_full[None, :] * stride_hn
@@ -129,15 +151,35 @@ def _mhc_projection_fwd_fused(
     tl.atomic_add(ms_ptrs, ms, mask=masks_ms, sem="relaxed")
 
 
+def projection_config_bwd_dx():
+    block_m = [32, 128]
+    block_k = [128]
+    warps = [2]
+    stages = [2, 3, 4]
+
+    configs = []
+    for m, bk, w, s in itertools.product(block_m, block_k, warps, stages):
+        configs.append(
+            triton.Config({"BLOCK_SIZE_M": m, "BLOCK_SIZE_K": bk}, num_warps=w, num_stages=s)
+        )
+    if
os.environ.get("NVTE_DISABLE_TRITON_AUTOTUNING", "0") == "1": + configs = configs[:1] + return configs + + @triton.autotune( - configs=projection_config_bwd(), + configs=projection_config_bwd_dx(), key=["M", "K"], + # When FUSE_GRAD_X_ACC=True the kernel does a read-modify-write on grad_x_ptr; without + # restore_value the autotune timing trials accumulate onto the buffer and corrupt it. + restore_value=["grad_x_ptr"], ) @triton.jit -def _mhc_projection_bwd_fused( +def _mhc_projection_bwd_fused_dx( x_ptr, grad_x_ptr, # (M, K) phi_ptr, # (N, K) + norm_weight_ptr, # (K,) grad_h_ptr, # (M, N) grad_ms_ptr, # (M,) M, @@ -149,6 +191,7 @@ def _mhc_projection_bwd_fused( stride_grad_xk: tl.constexpr, stride_phin, stride_phik: tl.constexpr, + stride_norm_weight: tl.constexpr, stride_grad_phin, stride_grad_phik: tl.constexpr, stride_grad_hm: tl.constexpr, @@ -159,6 +202,8 @@ def _mhc_projection_bwd_fused( BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, precision: tl.constexpr, + FUSE_GRAD_X_ACC: tl.constexpr, + HAS_NORM_WEIGHT: tl.constexpr, ): pid_m = tl.program_id(axis=0) pid_k = tl.program_id(axis=1) @@ -174,6 +219,7 @@ def _mhc_projection_bwd_fused( tl.assume(stride_grad_phin == K) tl.assume(stride_grad_phik == 1) tl.assume(stride_grad_ms == 1) + tl.assume(stride_norm_weight == 1) tl.assume(BLOCK_SIZE_M % 32 == 0) tl.assume(BLOCK_SIZE_K % 32 == 0) @@ -204,19 +250,164 @@ def _mhc_projection_bwd_fused( phi = tl.load( phi_ptrs, mask=(offs_n_full[:, None] < N) & mask_k[None, :], other=0.0 ) # (BLOCK_SIZE_N, BLOCK_SIZE_K) + + if HAS_NORM_WEIGHT: + norm_weight_ptrs = norm_weight_ptr + offs_k * stride_norm_weight + norm_weight = tl.load(norm_weight_ptrs, mask=mask_k, other=0.0, cache_modifier=".ca").to( + phi.dtype + ) # (BLOCK_SIZE_K,) + phi = phi.to(tl.float32) * norm_weight.to(tl.float32)[None, :] + grad_ms = tl.load( grad_ms_ptrs, mask=offs_ms < M, other=0.0, cache_modifier=".ca" ) # (BLOCK_SIZE_M,) grad_x = x * (grad_ms * 2 / tl.cast(K, tl.float32))[:, None] grad_x = tl.dot( - grad_h, phi, acc=grad_x, input_precision=precision, out_dtype=tl.float32 + grad_h.to(phi.dtype), phi, acc=grad_x, input_precision=precision, out_dtype=tl.float32 ) # (BLOCK_SIZE_M, BLOCK_SIZE_K) grad_x_ptrs = grad_x_ptr + offs_m[:, None] * stride_grad_xm + offs_k[None, :] * stride_grad_xk - grad_x = grad_x.to(x.dtype) + if FUSE_GRAD_X_ACC: # If fused gradient accumulation is enabled, the buffer is always fp32 + grad_x_acc = tl.load(grad_x_ptrs, mask=mask_m[:, None] & mask_k[None, :], other=0.0) + grad_x = grad_x.to(tl.float32) + grad_x_acc + else: + grad_x = grad_x.to(x.dtype) tl.store(grad_x_ptrs, grad_x, mask=mask_m[:, None] & mask_k[None, :]) +def projection_config_bwd_dphi(): + block_m = [512, 1024, 2048] + step_m = [32] + block_k = [128, 256] + warps = [2] + stages = [2, 3, 4] + + configs = [] + for bm, sm, bk, w, s in itertools.product(block_m, step_m, block_k, warps, stages): + configs.append( + triton.Config( + {"BLOCK_SIZE_M": bm, "STEP_SIZE_M": sm, "BLOCK_SIZE_K": bk}, + num_warps=w, + num_stages=s, + ) + ) + return configs + + +@triton.autotune( + configs=projection_config_bwd_dphi(), + key=["M", "K"], + reset_to_zero=["grad_phi_ptr", "grad_norm_weight_ptr"], +) +@triton.jit +def _mhc_projection_bwd_fused_dphi( + x_ptr, # (M, K) + grad_H_ptr, # (M, 32) + phi_ptr, # (N, K), N=24 in our case since n = 4 + norm_weight_ptr, # (K,) + grad_phi_ptr, # (N, K), N=24 in our case since n = 4 + grad_norm_weight_ptr, # (K,) + M, + N, + K, + stride_xm, + stride_xk: tl.constexpr, + stride_grad_Hm: tl.constexpr, + 
stride_grad_Hn: tl.constexpr, + stride_phin, + stride_phik: tl.constexpr, + stride_norm_weight: tl.constexpr, + stride_grad_phin, + stride_grad_phik: tl.constexpr, + stride_grad_norm_weight: tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + STEP_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + precision: tl.constexpr, +): + pid_k = tl.program_id(axis=0) + pid_m = tl.program_id(axis=1) + + tl.assume(pid_k >= 0) + tl.assume(stride_xm > 0) + tl.assume(stride_xk == 1) + tl.assume(stride_grad_Hm == 32) + tl.assume(stride_grad_Hn == 1) + tl.assume(stride_phin == K) + tl.assume(stride_phik == 1) + tl.assume(stride_grad_phin == K) + tl.assume(stride_grad_phin == stride_phin) + tl.assume(stride_grad_phik == 1) + tl.assume(stride_grad_norm_weight == 1) + tl.assume(stride_norm_weight == 1) + + tl.assume(BLOCK_SIZE_M % 128 == 0) + tl.assume(BLOCK_SIZE_K % 64 == 0) + tl.assume(BLOCK_SIZE_N == 32) + tl.assume(STEP_SIZE_M % 32 == 0) + + offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + mask_k = offs_k < K + offs_n_full = tl.arange(0, BLOCK_SIZE_N) + mask_n = offs_n_full < N + + grad_psi_acc = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=tl.float32) + + m_start = pid_m * BLOCK_SIZE_M + m_end = tl.minimum(m_start + BLOCK_SIZE_M, M) + for m_idx in range(0, tl.cdiv(m_end - m_start, STEP_SIZE_M)): + offs_m = m_start + m_idx * STEP_SIZE_M + tl.arange(0, STEP_SIZE_M) + mask_m = offs_m < M + x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk + x = tl.load( + x_ptrs, mask=mask_m[:, None] & mask_k[None, :], other=0.0 + ) # (STEP_SIZE_M, BLOCK_SIZE_K) + grad_H_ptrs = ( + grad_H_ptr + offs_m[:, None] * stride_grad_Hm + offs_n_full[None, :] * stride_grad_Hn + ) + grad_H = tl.load( + grad_H_ptrs, mask=mask_m[:, None] & mask_n[None, :], other=0.0 + ) # (STEP_SIZE_M, BLOCK_SIZE_N) + + grad_psi_acc = tl.dot( + tl.trans(grad_H, (1, 0)), + x.to(grad_H.dtype), + acc=grad_psi_acc, + out_dtype=tl.float32, + input_precision=precision, + ) + + phi_ptrs = phi_ptr + offs_n_full[:, None] * stride_phin + offs_k[None, :] * stride_phik + phi = tl.load( + phi_ptrs, mask=(offs_n_full[:, None] < N) & mask_k[None, :], other=0.0 + ) # (BLOCK_SIZE_N, BLOCK_SIZE_K) + norm_weight_ptrs = norm_weight_ptr + offs_k * stride_norm_weight + norm_weight = tl.load( + norm_weight_ptrs, mask=mask_k, other=0.0, cache_modifier=".cg" + ) # (BLOCK_SIZE_K,) + phi = phi.to(tl.float32) + norm_weight = norm_weight.to(tl.float32) + + # Keep grad_psi in SRAM and get grad_phi & grad_norm_weight + grad_phi = grad_psi_acc * norm_weight[None, :].to(grad_psi_acc.dtype) # (32, BLOCK_SIZE_K) + grad_norm_weight = tl.sum(grad_psi_acc * phi.to(grad_psi_acc.dtype), axis=0) # (BLOCK_SIZE_K,) + + grad_phi_ptrs = ( + grad_phi_ptr + offs_n_full[:, None] * stride_grad_phin + offs_k[None, :] * stride_grad_phik + ) + grad_norm_weight_ptrs = grad_norm_weight_ptr + offs_k * stride_grad_norm_weight + + tl.atomic_add( + grad_phi_ptrs, + grad_phi, + mask=(offs_n_full[:, None] < N) & mask_k[None, :], + sem="relaxed", + ) + tl.atomic_add(grad_norm_weight_ptrs, grad_norm_weight, mask=mask_k, sem="relaxed") + + def scale_config(): block_m = [128] warps = [4] @@ -749,6 +940,22 @@ def _mhc_sinkhorn_fwd_fused( tl.store(output_ptrs, P, mask=mask_batch[:, None]) +def aggregate_config_fwd(): + block_m = [1, 2, 4] + block_c = [128, 256] + warps = [1, 2, 4] + stages = [1, 2, 3, 4] + + configs = [] + for m, c, w, s in itertools.product(block_m, block_c, warps, stages): + configs.append( + 
triton.Config({"BLOCK_SIZE_M": m, "BLOCK_SIZE_C": c}, num_warps=w, num_stages=s) + ) + if os.environ.get("NVTE_DISABLE_TRITON_AUTOTUNING", "0") == "1": + configs = configs[:1] + return configs + + @triton.autotune( configs=sinkhorn_config(), key=["M"], @@ -854,25 +1061,21 @@ def _mhc_sinkhorn_bwd_fused( ) -def aggregate_config(): - block_m = [1, 2, 4] - block_c = [64, 128, 256] - warps = [1, 2, 4] - stages = [1, 2, 3, 4] +def aggregate_prune_fwd(configs, named_args, **kwargs): + M = named_args.get("M", kwargs.get("M", None)) - configs = [] - for m, c, w, s in itertools.product(block_m, block_c, warps, stages): - configs.append( - triton.Config({"BLOCK_SIZE_M": m, "BLOCK_SIZE_C": c}, num_warps=w, num_stages=s) + pruned_configs = list( + filter( + lambda config: triton.cdiv(M, config.kwargs["BLOCK_SIZE_M"]) <= MAX_GRID_DIM_Y, configs ) - if os.environ.get("NVTE_DISABLE_TRITON_AUTOTUNING", "0") == "1": - configs = configs[:1] - return configs + ) + return pruned_configs @triton.autotune( - configs=aggregate_config(), + configs=aggregate_config_fwd(), key=["M", "C"], + prune_configs_by={"early_config_prune": aggregate_prune_fwd}, ) @triton.jit def _mhc_aggregate_fwd( @@ -949,7 +1152,51 @@ def _mhc_aggregate_fwd( tl.store(output_ptrs, out, mask=mask_m[:, None] & mask_c[None, :]) -@triton.autotune(configs=aggregate_config(), key=["M", "C"], reset_to_zero=["grad_H_pre_ptr"]) +def aggregate_config_bwd(): + block_m = [1, 2, 4] + block_c = [64, 128, 256] + step_c = [32, 64] + warps = [1, 2, 4] + stages = [1, 2, 3, 4] + + configs = [] + for bm, bc, sc, w, s in itertools.product(block_m, block_c, step_c, warps, stages): + configs.append( + triton.Config( + {"BLOCK_SIZE_M": bm, "BLOCK_SIZE_C": bc, "STEP_SIZE_C": sc}, + num_warps=w, + num_stages=s, + ) + ) + return configs + + +def aggregate_prune_bwd(configs, named_args, **kwargs): + M = named_args.get("M", kwargs.get("M", None)) + + pruned_configs = list( + filter( + lambda config: triton.cdiv(M, config.kwargs["BLOCK_SIZE_M"]) <= MAX_GRID_DIM_Y, + configs, + ) + ) + + # Triton will skip calling prune function if the autotune returns only one config, which breaks the determinism override here + # So we need to apply NVTE_DISABLE_TRITON_AUTOTUNING in the pruner instead + if os.environ.get("NVTE_DISABLE_TRITON_AUTOTUNING", "0") == "1": + pruned_configs = pruned_configs[:1] + return pruned_configs + + +@triton.autotune( + configs=aggregate_config_bwd(), + key=["M", "C"], + reset_to_zero=["grad_H_pre_ptr"], + # When FUSE_GRAD_X_ACC=True the kernel does a read-modify-write on grad_x_ptr; without + # restore_value the autotune timing trials accumulate onto the buffer and corrupt it. 
+ restore_value=["grad_x_ptr"], + prune_configs_by={"early_config_prune": aggregate_prune_bwd}, +) @triton.jit def _mhc_aggregate_bwd( grad_output_ptr, # (M, C) @@ -969,7 +1216,9 @@ def _mhc_aggregate_bwd( # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_C: tl.constexpr, + STEP_SIZE_C: tl.constexpr, precision: tl.constexpr, + FUSE_GRAD_X_ACC: tl.constexpr, ): """ Forward: @@ -992,38 +1241,14 @@ def _mhc_aggregate_bwd( tl.assume(stride_grad_output_m > 0 and stride_grad_output_c == 1) tl.assume(BLOCK_SIZE_C % 32 == 0) + tl.assume(STEP_SIZE_C % 32 == 0) + tl.assume(BLOCK_SIZE_C % STEP_SIZE_C == 0) offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_c = pid_c * BLOCK_SIZE_C + tl.arange(0, BLOCK_SIZE_C) - offs_cn = pid_c * BLOCK_SIZE_C * n + tl.arange(0, BLOCK_SIZE_C * n) mask_m = offs_m < M - mask_c = offs_c < C - mask_cn = offs_cn < C * n - - grad_output_ptrs = ( - grad_output_ptr - + offs_m[:, None] * stride_grad_output_m - + offs_c[None, :] * stride_grad_output_c - ) - grad_output = tl.load( - grad_output_ptrs, mask=mask_m[:, None] & mask_c[None, :], other=0.0 - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C) - - x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_cn[None, :] * stride_xCn - x = tl.load( - x_ptrs, mask=mask_m[:, None] & mask_cn[None, :], other=0.0 - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C * n) - grad_H_pre = tl.dot( - tl.reshape(grad_output, (BLOCK_SIZE_M, 1, BLOCK_SIZE_C)), - tl.reshape(x, (BLOCK_SIZE_M, BLOCK_SIZE_C, n)), - input_precision=precision, - out_dtype=tl.float32, - ) - grad_H_pre = tl.reshape(grad_H_pre, (BLOCK_SIZE_M * n,)) # (BLOCK_SIZE_M * n) - offs_grad_H_pre = pid_m * BLOCK_SIZE_M * n + tl.arange(0, BLOCK_SIZE_M * n) - grad_H_pre_ptrs = grad_H_pre_ptr + offs_grad_H_pre - tl.atomic_add(grad_H_pre_ptrs, grad_H_pre, mask=offs_grad_H_pre < M * n, sem="relaxed") + offs_c_start = pid_c * BLOCK_SIZE_C + offs_cn_start = pid_c * BLOCK_SIZE_C * n H_pre_offs = pid_m * BLOCK_SIZE_M * n + tl.arange(0, BLOCK_SIZE_M * n) H_pre = tl.load( @@ -1031,19 +1256,59 @@ def _mhc_aggregate_bwd( ) # (BLOCK_SIZE_M * n) H_pre = tl.reshape(H_pre, (BLOCK_SIZE_M, n)) # (BLOCK_SIZE_M, n) - # grad_x = grad_output @ H_pre.T: (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, n) = (BLOCK_SIZE_M, BLOCK_SIZE_C, n) - grad_x = grad_output[:, :, None] * H_pre[:, None, :] # (BLOCK_SIZE_M, BLOCK_SIZE_C, n) - grad_x = tl.reshape(grad_x, (BLOCK_SIZE_M, BLOCK_SIZE_C * n)) + grad_H_pre_acc = tl.zeros((BLOCK_SIZE_M, 1, n), dtype=tl.float32) + for i in tl.range(0, BLOCK_SIZE_C, STEP_SIZE_C, loop_unroll_factor=2): + offs_c = offs_c_start + i + tl.arange(0, STEP_SIZE_C) + offs_cn = offs_cn_start + i * n + tl.arange(0, STEP_SIZE_C * n) + mask_c = offs_c < C + mask_cn = offs_cn < C * n + + grad_output_ptrs = ( + grad_output_ptr + + offs_m[:, None] * stride_grad_output_m + + offs_c[None, :] * stride_grad_output_c + ) + grad_output = tl.load( + grad_output_ptrs, mask=mask_m[:, None] & mask_c[None, :], other=0.0 + ) # (BLOCK_SIZE_M, STEP_SIZE_C) - grad_x_ptrs = grad_x_ptr + offs_m[:, None] * stride_grad_xm + offs_cn[None, :] * stride_grad_xCn - tl.store( - grad_x_ptrs, - grad_x, - mask=mask_m[:, None] & mask_cn[None, :], - ) + x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_cn[None, :] * stride_xCn + x = tl.load( + x_ptrs, mask=mask_m[:, None] & mask_cn[None, :], other=0.0 + ) # (BLOCK_SIZE_M, STEP_SIZE_C * n) + + grad_H_pre_acc = tl.dot( + tl.reshape(grad_output, (BLOCK_SIZE_M, 1, STEP_SIZE_C)), + tl.reshape(x, (BLOCK_SIZE_M, STEP_SIZE_C, n)), + acc=grad_H_pre_acc, + input_precision=precision, + 
out_dtype=tl.float32, + ) + + # grad_x = grad_output @ H_pre.T: (BLOCK_SIZE_M, STEP_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, n) = (BLOCK_SIZE_M, STEP_SIZE_C, n) + grad_x = grad_output[:, :, None] * H_pre[:, None, :] # (BLOCK_SIZE_M, STEP_SIZE_C, n) + grad_x = tl.reshape(grad_x, (BLOCK_SIZE_M, STEP_SIZE_C * n)) + + grad_x_ptrs = ( + grad_x_ptr + offs_m[:, None] * stride_grad_xm + offs_cn[None, :] * stride_grad_xCn + ) + if FUSE_GRAD_X_ACC: # If fused gradient accumulation is enabled, the buffer is always fp32 + grad_x_acc = tl.load(grad_x_ptrs, mask=mask_m[:, None] & mask_cn[None, :], other=0.0) + grad_x = grad_x.to(tl.float32) + grad_x_acc + tl.store( + grad_x_ptrs, + grad_x, + mask=mask_m[:, None] & mask_cn[None, :], + ) + + grad_H_pre = tl.reshape(grad_H_pre_acc, (BLOCK_SIZE_M * n,)) # (BLOCK_SIZE_M * n) + offs_grad_H_pre = pid_m * BLOCK_SIZE_M * n + tl.arange(0, BLOCK_SIZE_M * n) + grad_H_pre_ptrs = grad_H_pre_ptr + offs_grad_H_pre + tl.atomic_add(grad_H_pre_ptrs, grad_H_pre, mask=offs_grad_H_pre < M * n, sem="relaxed") -def expand_combine_config(): + +def expand_combine_config_fwd(): block_m = [1, 2, 4] block_c = [128, 256] warps = [1, 2] @@ -1054,18 +1319,34 @@ def expand_combine_config(): configs.append( triton.Config({"BLOCK_SIZE_M": m, "BLOCK_SIZE_C": c}, num_warps=w, num_stages=s) ) - if os.environ.get("NVTE_DISABLE_TRITON_AUTOTUNING", "0") == "1": - configs = configs[:1] return configs +def expand_combine_prune_fwd(configs, named_args, **kwargs): + M = named_args.get("M", kwargs.get("M", None)) + + pruned_configs = list( + filter( + lambda config: triton.cdiv(M, config.kwargs["BLOCK_SIZE_M"]) <= MAX_GRID_DIM_Y, configs + ) + ) + + # Triton will skip calling prune function if the autotune returns only one config, which breaks the determinism override here + # So we need to apply NVTE_DISABLE_TRITON_AUTOTUNING in the pruner instead + if os.environ.get("NVTE_DISABLE_TRITON_AUTOTUNING", "0") == "1": + pruned_configs = pruned_configs[:1] + return pruned_configs + + @triton.autotune( - configs=expand_combine_config(), + configs=expand_combine_config_fwd(), key=["M", "C"], + prune_configs_by={"early_config_prune": expand_combine_prune_fwd}, ) @triton.jit def _mhc_expand_combine_fwd( f_ptr, # (M, C) + bias_ptr, # (C,), or None if HAS_BIAS is False H_post_ptr, # (M, n) x_ptr, # (M, C, n) H_res_ptr, # (M, n, n) @@ -1075,6 +1356,7 @@ def _mhc_expand_combine_fwd( n: tl.constexpr, stride_fm, stride_fc, + stride_bias, # Not used if HAS_BIAS is False stride_xm, stride_xCn, stride_output_m, @@ -1082,9 +1364,10 @@ def _mhc_expand_combine_fwd( # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_C: tl.constexpr, + HAS_BIAS: tl.constexpr, ): """ - output = f @ H_post: (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, n) = (BLOCK_SIZE_M, BLOCK_SIZE_C, n) + output = (f + bias[None, :, None]) @ H_post: (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, n) = (BLOCK_SIZE_M, BLOCK_SIZE_C, n) + x @ H_res: (BLOCK_SIZE_M, BLOCK_SIZE_C, n) @ (BLOCK_SIZE_M, n, n) = (BLOCK_SIZE_M, BLOCK_SIZE_C, n) """ pid_m = tl.program_id(1) @@ -1095,6 +1378,7 @@ def _mhc_expand_combine_fwd( tl.assume(C > 0) tl.assume(n == 4) tl.assume(stride_fm > 0 and stride_fc == 1) + tl.assume(stride_bias == 1) tl.assume(stride_xm > 0 and stride_xCn == 1) tl.assume(stride_output_m > 0 and stride_output_Cn == 1) @@ -1109,6 +1393,8 @@ def _mhc_expand_combine_fwd( f_ptrs = f_ptr + offs_m[:, None] * stride_fm + offs_c[None, :] * stride_fc f = tl.load(f_ptrs, mask=mask_m[:, None] & mask_c[None, :], other=0.0) + if HAS_BIAS: + bias = 
tl.load(bias_ptr + offs_c * stride_bias, mask=mask_c, other=0.0) # (BLOCK_SIZE_C,) offs_H_post = pid_m * BLOCK_SIZE_M * n + tl.arange(0, BLOCK_SIZE_M * n) H_post = tl.load( @@ -1116,10 +1402,12 @@ def _mhc_expand_combine_fwd( ) H_post = tl.reshape(H_post, (BLOCK_SIZE_M, n)) # (BLOCK_SIZE_M, n) - # Residual connection path: res_out = f @ H_post: + # Residual connection path: res_out = f @ H_post + bias @ H_post: # (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, n) = (BLOCK_SIZE_M, n, BLOCK_SIZE_C) # Due to broadcasting, it's equivalent to a multiplicaiton out_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_C, n), dtype=tl.float32) + if HAS_BIAS: + out_acc = tl.fma(bias[None, :, None], H_post[:, None, :], out_acc) out_acc = tl.fma(f[:, :, None], H_post[:, None, :], out_acc) H_res_offs = pid_m * BLOCK_SIZE_M * n * n + tl.arange(0, BLOCK_SIZE_M * n * n) @@ -1167,332 +1455,59 @@ def _mhc_expand_combine_fwd( tl.store(output_ptrs, out, mask=mask_m[:, None] & mask_cn[None, :]) -@triton.autotune( - configs=expand_combine_config(), - key=["M", "C"], - reset_to_zero=["grad_H_post_ptr", "grad_H_res_ptr"], -) -@triton.jit -def _mhc_expand_combine_bwd( - grad_output_ptr, # (M, C, n) - f_ptr, # (M, C) - H_post_ptr, # (M, n) - x_ptr, # (M, C, n) - H_res_ptr, # (M, n, n) - grad_H_post_ptr, # (M, n) - grad_f_ptr, # (M, C) - grad_H_res_ptr, # (M, n, n) - grad_x_ptr, # (M, C, n) - M, - C, - n: tl.constexpr, - stride_grad_output_m, - stride_grad_output_Cn, - stride_fm, - stride_fc, - stride_xm, - stride_xCn, - stride_grad_fm, - stride_grad_fc, - stride_grad_xm, - stride_grad_xCn, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_C: tl.constexpr, - precision: tl.constexpr, -): - """ - Each block - It reads - - (BLOCK_SIZE_M, BLOCK_SIZE_C) of f, which is the output of the attention / FFN module - - (BLOCK_SIZE_M, n) of H_post, which is applied for the transformation of the attention / FFN output - - (BLOCK_SIZE_M, BLOCK_SIZE_C, n) of x, which is the skip connection's input - - (BLOCK_SIZE_M, n*n) of H_res, which is applied for the transformation of the skip connection - and writes - - (BLOCK_SIZE_M, n) of grad_H_post - - (BLOCK_SIZE_M, BLOCK_SIZE_C) of grad_f - - (BLOCK_SIZE_M, n, n) of grad_H_res - - (BLOCK_SIZE_M, BLOCK_SIZE_C, n) of grad_x - - Forward: - out = f @ H_post + x @ H_res - Backward: - GEMM: - grad_H_post = f.T @ grad_output: (BLOCK_SIZE_M, 1, BLOCK_SIZE_C) @ (BLOCK_SIZE_M, BLOCK_SIZE_C, n) = (BLOCK_SIZE_M, 1, n) - grad_H_res = x.T @ grad_output: (BLOCK_SIZE_M, n, BLOCK_SIZE_C) @ (BLOCK_SIZE_M, BLOCK_SIZE_C, n) = (BLOCK_SIZE_M, n, n) - Not GEMM: - grad_f = grad_output @ H_post.T: (BLOCK_SIZE_M, BLOCK_SIZE_C, n) @ (BLOCK_SIZE_M, n, 1) = (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) - grad_x = grad_output @ H_res.T: (BLOCK_SIZE_M, BLOCK_SIZE_C, n) @ (BLOCK_SIZE_M, n, n) = (BLOCK_SIZE_M, BLOCK_SIZE_C, n) - """ - - pid_m = tl.program_id(1) - pid_c = tl.program_id(0) - - tl.static_assert(n == 4) - tl.assume(M > 0) - tl.assume(C > 0) - tl.assume(n == 4) - tl.assume(stride_fm > 0 and stride_fc == 1) - tl.assume(stride_xm > 0 and stride_xCn == 1) - tl.assume(stride_grad_output_m > 0 and stride_grad_output_Cn == 1) - tl.assume(stride_grad_fm > 0 and stride_grad_fc == 1) - tl.assume(stride_grad_xm > 0 and stride_grad_xCn == 1) - - tl.assume(BLOCK_SIZE_C % 32 == 0) - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_c = pid_c * BLOCK_SIZE_C + tl.arange(0, BLOCK_SIZE_C) - offs_cn = pid_c * BLOCK_SIZE_C * n + tl.arange(0, BLOCK_SIZE_C * n) - mask_m = offs_m < M - mask_c = offs_c < C - mask_cn = 
offs_cn < C * n - - f_ptrs = f_ptr + offs_m[:, None] * stride_fm + offs_c[None, :] * stride_fc - f = tl.load(f_ptrs, mask=mask_m[:, None] & mask_c[None, :], other=0.0) - - H_post_offs = pid_m * BLOCK_SIZE_M * n + tl.arange(0, BLOCK_SIZE_M * n) - H_post = tl.load(H_post_ptr + H_post_offs, mask=H_post_offs < M * n, other=0.0) - H_post = tl.reshape(H_post, (BLOCK_SIZE_M, n)) # (BLOCK_SIZE_M, n) - - H_res_offs = pid_m * BLOCK_SIZE_M * n * n + tl.arange(0, BLOCK_SIZE_M * n * n) - H_res = tl.load( - H_res_ptr + H_res_offs, mask=H_res_offs < M * n * n, other=0.0 - ) # (BLOCK_SIZE_M, n, n) - H_res = tl.reshape(H_res, (BLOCK_SIZE_M, n, n)) # (BLOCK_SIZE_M, n, n) - - grad_out_ptrs = ( - grad_output_ptr - + offs_m[:, None] * stride_grad_output_m - + offs_cn[None, :] * stride_grad_output_Cn - ) - grad_out = tl.load( - grad_out_ptrs, mask=mask_m[:, None] & mask_cn[None, :], other=0.0 - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C * n) - grad_out = tl.reshape( - grad_out, (BLOCK_SIZE_M, BLOCK_SIZE_C, n) - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C, n) - - # grad_H_post = f.T @ grad_output # (BLOCK_SIZE_M, 1, BLOCK_SIZE_C) @ (BLOCK_SIZE_M, BLOCK_SIZE_C, n) = (BLOCK_SIZE_M, 1, n) - grad_H_post = tl.dot( - tl.reshape(f, (BLOCK_SIZE_M, 1, BLOCK_SIZE_C)), - tl.reshape(grad_out, (BLOCK_SIZE_M, BLOCK_SIZE_C, n)), - input_precision=precision, - out_dtype=tl.float32, - ) # (BLOCK_SIZE_M, 1, n) - grad_H_post = tl.reshape(grad_H_post, (BLOCK_SIZE_M * n,)) # (BLOCK_SIZE_M * n) - offs_grad_H_post = pid_m * BLOCK_SIZE_M * n + tl.arange(0, BLOCK_SIZE_M * n) - grad_H_post_ptrs = grad_H_post_ptr + offs_grad_H_post - tl.atomic_add(grad_H_post_ptrs, grad_H_post, mask=offs_grad_H_post < M * n, sem="relaxed") - - x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_cn[None, :] * stride_xCn - x = tl.load( - x_ptrs, mask=mask_m[:, None] & mask_cn[None, :], other=0.0 - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C*n) - x = tl.reshape(x, (BLOCK_SIZE_M, BLOCK_SIZE_C, n)) # (BLOCK_SIZE_M, BLOCK_SIZE_C, n) - - # grad_H_res = x.T @ grad_output: (BLOCK_SIZE_M, n, BLOCK_SIZE_C) @ (BLOCK_SIZE_M, BLOCK_SIZE_C, n) = (BLOCK_SIZE_M, n, n) - grad_H_res = tl.dot( - tl.trans(x, (0, 2, 1)), grad_out, input_precision=precision, out_dtype=tl.float32 - ) # (BLOCK_SIZE_M, n, n) - grad_H_res = tl.reshape(grad_H_res, (BLOCK_SIZE_M * n * n,)) # (BLOCK_SIZE_M * n * n) - offs_grad_H_res = pid_m * BLOCK_SIZE_M * n * n + tl.arange(0, BLOCK_SIZE_M * n * n) - grad_H_res_ptrs = grad_H_res_ptr + offs_grad_H_res - tl.atomic_add( - grad_H_res_ptrs, grad_H_res.to(tl.float32), mask=offs_grad_H_res < M * n * n, sem="relaxed" - ) - - grad_out_reshape = tl.reshape( - grad_out, (BLOCK_SIZE_M, BLOCK_SIZE_C, 2, 2) - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C, 2, 2) - grad_out01, grad_out23 = tl.split( - grad_out_reshape - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C, 2), (BLOCK_SIZE_M, BLOCK_SIZE_C, 2) - grad_out0, grad_out1 = tl.split( - grad_out01 - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C), (BLOCK_SIZE_M, BLOCK_SIZE_C) - grad_out2, grad_out3 = tl.split( - grad_out23 - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C), (BLOCK_SIZE_M, BLOCK_SIZE_C) - - # grad_f = grad_output @ H_post.T: (BLOCK_SIZE_M, 1, n) @ (BLOCK_SIZE_M, n, BLOCK_SIZE_C) = (BLOCK_SIZE_M, 1, BLOCK_SIZE_C) - # Triton doesn't support dot prod with inner dimension < 16, so we need to hack this: - # grad_f = grad_out[:, :, 0] @ H_post.T[:, 0, :] (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, 1) - # + grad_out[:, :, 1] @ H_post.T[:, 1, :] - # + grad_out[:, :, 2] @ H_post.T[:, 2, :] - # + grad_out[:, :, 3] @ H_post.T[:, 3, :] - # where H_post.T[:, i, :] = H_post[:, :, i] - H_post = 
tl.reshape(H_post, (BLOCK_SIZE_M, 2, 2)) - H_post01, H_post23 = tl.split(H_post) # (BLOCK_SIZE_M, 2), (BLOCK_SIZE_M, 2) - H_post0, H_post1 = tl.split(H_post01) # (BLOCK_SIZE_M,), (BLOCK_SIZE_M,) - H_post2, H_post3 = tl.split(H_post23) # (BLOCK_SIZE_M,), (BLOCK_SIZE_M,) - - grad_f_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_C), dtype=tl.float32) - # (BLOCK_SIZE_M, BLOCK_SIZE_C) * (BLOCK_SIZE_M, 1) -> (BLOCK_SIZE_M, BLOCK_SIZE_C) - grad_f_acc = tl.fma(grad_out0, H_post0[:, None], grad_f_acc) - grad_f_acc = tl.fma(grad_out1, H_post1[:, None], grad_f_acc) - grad_f_acc = tl.fma(grad_out2, H_post2[:, None], grad_f_acc) - grad_f_acc = tl.fma(grad_out3, H_post3[:, None], grad_f_acc) - grad_f = grad_f_acc.to(f.dtype) - - grad_f_ptrs = grad_f_ptr + offs_m[:, None] * stride_grad_fm + offs_c[None, :] * stride_grad_fc - tl.store(grad_f_ptrs, grad_f, mask=mask_m[:, None] & mask_c[None, :]) - - # grad_x = grad_output @ H_res.T: (BLOCK_SIZE_M, BLOCK_SIZE_C, n) @ (BLOCK_SIZE_M, n, n) = (BLOCK_SIZE_M, n, BLOCK_SIZE_C) - # The inner dim is n=4 which is too small for triton, so we will manually unroll the matmul - # grad_x = grad_out[:, :, 0] @ H_res.T[:, 0, :] - # + grad_out[:, :, 1] @ H_res.T[:, 1, :] - # + grad_out[:, :, 2] @ H_res.T[:, 2, :] - # + grad_out[:, :, 3] @ H_res.T[:, 3, :] - # where H_res.T[:, i, :] = H_res[:, :, i] - # Due to broadcasting, it's equivalent to multiplying each H_res[:, i, :].T with grad_out[:, i, :] - - H_res_reshape = tl.reshape(H_res, (BLOCK_SIZE_M, n, 2, 2)) # (BLOCK_SIZE_M, n, 2, 2) - H_res01, H_res23 = tl.split(H_res_reshape) # (BLOCK_SIZE_M, n, 2), (BLOCK_SIZE_M, n, 2) - H_res0, H_res1 = tl.split(H_res01) # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n) - H_res2, H_res3 = tl.split(H_res23) # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n) - - grad_x_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_C, n), dtype=tl.float32) - grad_x_acc = tl.fma(grad_out0[:, :, None], H_res0[:, None, :], grad_x_acc) - grad_x_acc = tl.fma(grad_out1[:, :, None], H_res1[:, None, :], grad_x_acc) - grad_x_acc = tl.fma(grad_out2[:, :, None], H_res2[:, None, :], grad_x_acc) - grad_x_acc = tl.fma(grad_out3[:, :, None], H_res3[:, None, :], grad_x_acc) - - grad_x = grad_x_acc.to(x.dtype) - grad_x = tl.reshape(grad_x, (BLOCK_SIZE_M, BLOCK_SIZE_C * n)) # (BLOCK_SIZE_M, BLOCK_SIZE_C*n) - - grad_x_ptrs = grad_x_ptr + offs_m[:, None] * stride_grad_xm + offs_cn[None, :] * stride_grad_xCn - tl.store(grad_x_ptrs, grad_x, mask=mask_m[:, None] & mask_cn[None, :]) - - -@triton.autotune( - configs=expand_combine_config(), - key=["M", "C"], -) -@triton.jit -def _mhc_expand_combine_with_bias_fwd( - f_ptr, # (M, C) - bias_ptr, # (C,) - H_post_ptr, # (M, n) - x_ptr, # (M, C, n) - H_res_ptr, # (M, n, n) - output_ptr, # # (M, C, n) - M, - C, - n: tl.constexpr, - stride_fm, - stride_fc, - stride_bias, - stride_xm, - stride_xCn, - stride_output_m, - stride_output_Cn, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_C: tl.constexpr, -): - """ - output = (f + bias[None, :, None]) @ H_post: (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, n) = (BLOCK_SIZE_M, BLOCK_SIZE_C, n) - + x @ H_res: (BLOCK_SIZE_M, BLOCK_SIZE_C, n) @ (BLOCK_SIZE_M, n, n) = (BLOCK_SIZE_M, BLOCK_SIZE_C, n) - """ - pid_m = tl.program_id(1) - pid_c = tl.program_id(0) - - tl.static_assert(n == 4) - tl.assume(M > 0) - tl.assume(C > 0) - tl.assume(n == 4) - tl.assume(stride_fm > 0 and stride_fc == 1) - tl.assume(stride_bias == 1) - tl.assume(stride_xm > 0 and stride_xCn == 1) - tl.assume(stride_output_m > 0 and stride_output_Cn == 1) - - tl.assume(BLOCK_SIZE_C % 32 == 0) - 
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_c = pid_c * BLOCK_SIZE_C + tl.arange(0, BLOCK_SIZE_C) - offs_cn = pid_c * BLOCK_SIZE_C * n + tl.arange(0, BLOCK_SIZE_C * n) - mask_m = offs_m < M - mask_c = offs_c < C - mask_cn = offs_cn < C * n +def expand_combine_config_bwd(): + block_m = [1, 2, 4] + block_c = [128, 256] + step_c = [32, 64] + warps = [1, 2] + stages = [1, 2, 3, 4] - f_ptrs = f_ptr + offs_m[:, None] * stride_fm + offs_c[None, :] * stride_fc - f = tl.load(f_ptrs, mask=mask_m[:, None] & mask_c[None, :], other=0.0) - bias = tl.load(bias_ptr + offs_c * stride_bias, mask=mask_c, other=0.0) # (BLOCK_SIZE_C,) + configs = [] + for m, c, sc, w, s in itertools.product(block_m, block_c, step_c, warps, stages): + configs.append( + triton.Config( + {"BLOCK_SIZE_M": m, "BLOCK_SIZE_C": c, "STEP_SIZE_C": sc}, + num_warps=w, + num_stages=s, + ) + ) + return configs - offs_H_post = pid_m * BLOCK_SIZE_M * n + tl.arange(0, BLOCK_SIZE_M * n) - H_post = tl.load( - H_post_ptr + offs_H_post, mask=offs_H_post < M * n, other=0.0, cache_modifier=".ca" - ) - H_post = tl.reshape(H_post, (BLOCK_SIZE_M, n)) # (BLOCK_SIZE_M, n) - # Residual connection path: res_out = f @ H_post + bias @ H_post: - # (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, n) = (BLOCK_SIZE_M, n, BLOCK_SIZE_C) - # Due to broadcasting, it's equivalent to a multiplicaiton - out_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_C, n), dtype=tl.float32) - out_acc = tl.fma(bias[None, :, None], H_post[:, None, :], out_acc) - out_acc = tl.fma(f[:, :, None], H_post[:, None, :], out_acc) +def expand_combine_prune_bwd(configs, named_args, **kwargs): + M = named_args.get("M", kwargs.get("M", None)) - H_res_offs = pid_m * BLOCK_SIZE_M * n * n + tl.arange(0, BLOCK_SIZE_M * n * n) - H_res = tl.load( - H_res_ptr + H_res_offs, mask=H_res_offs < M * n * n, other=0.0, cache_modifier=".ca" + pruned_configs = list( + filter( + lambda config: triton.cdiv(M, config.kwargs["BLOCK_SIZE_M"]) <= MAX_GRID_DIM_Y, + configs, + ) ) - H_res = tl.reshape(H_res, (BLOCK_SIZE_M, n, n)) # (BLOCK_SIZE_M, n, n) - - x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_cn[None, :] * stride_xCn - x = tl.load( - x_ptrs, mask=mask_m[:, None] & mask_cn[None, :], other=0.0 - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C, n) - - # Manifold connection path: manifold_out = H_res @ x: - # (BLOCK_SIZE_M, BLOCK_SIZE_C, n) @ (BLOCK_SIZE_M, n, n) = (BLOCK_SIZE_M, BLOCK_SIZE_C, n) - # triton doesn't support dot prod with inner dimension < 16, so we need to manually unroll the computation for n=4: - # x @ H_res = x[:, :, 0] @ H_res[:, 0, :] - # + x[:, :, 1] @ H_res[:, 1, :] - # + x[:, :, 2] @ H_res[:, 2, :] - # + x[:, :, 3] @ H_res[:, 3, :] - - x_reshape = tl.reshape(x, (BLOCK_SIZE_M, BLOCK_SIZE_C, 2, 2)) - x01, x23 = tl.split( - x_reshape - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C, 2), (BLOCK_SIZE_M, BLOCK_SIZE_C, 2) - x0, x1 = tl.split(x01) # (BLOCK_SIZE_M, BLOCK_SIZE_C), (BLOCK_SIZE_M, BLOCK_SIZE_C) - x2, x3 = tl.split(x23) # (BLOCK_SIZE_M, BLOCK_SIZE_C), (BLOCK_SIZE_M, BLOCK_SIZE_C) - - H_resT = tl.reshape(tl.trans(H_res, (0, 2, 1)), (BLOCK_SIZE_M, n, 2, 2)) - H_res01, H_res23 = tl.split(H_resT) # (BLOCK_SIZE_M, n, 2), (BLOCK_SIZE_M, n, 2) - H_res0, H_res1 = tl.split(H_res01) # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n) - H_res2, H_res3 = tl.split(H_res23) # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n) - out_acc = tl.fma(x0[:, :, None], H_res0[:, None, :], out_acc) - out_acc = tl.fma(x1[:, :, None], H_res1[:, None, :], out_acc) - out_acc = tl.fma(x2[:, :, None], H_res2[:, None, :], out_acc) - out_acc = 
tl.fma(x3[:, :, None], H_res3[:, None, :], out_acc) - - out = out_acc.to(x.dtype) - out = tl.reshape(out, (BLOCK_SIZE_M, BLOCK_SIZE_C * n)) # (BLOCK_SIZE_M, BLOCK_SIZE_C*n) - - output_ptrs = ( - output_ptr + offs_m[:, None] * stride_output_m + offs_cn[None, :] * stride_output_Cn - ) - tl.store(output_ptrs, out, mask=mask_m[:, None] & mask_cn[None, :]) + # Triton will skip calling prune function if the autotune returns only one config, which breaks the determinism override here + # So we need to apply NVTE_DISABLE_TRITON_AUTOTUNING in the pruner instead + if os.environ.get("NVTE_DISABLE_TRITON_AUTOTUNING", "0") == "1": + pruned_configs = pruned_configs[:1] + return pruned_configs @triton.autotune( - configs=expand_combine_config(), + configs=expand_combine_config_bwd(), key=["M", "C"], reset_to_zero=["grad_H_post_ptr", "grad_H_res_ptr", "grad_bias_ptr"], + prune_configs_by={"early_config_prune": expand_combine_prune_bwd}, ) @triton.jit -def _mhc_expand_combine_with_bias_bwd( +def _mhc_expand_combine_bwd( grad_output_ptr, # (M, C, n) f_ptr, # (M, C) - bias_ptr, # (C,) + bias_ptr, # (C,), or None if HAS_BIAS is False H_post_ptr, # (M, n) x_ptr, # (M, C, n) H_res_ptr, # (M, n, n) grad_H_post_ptr, # (M, n) grad_f_ptr, # (M, C) - grad_bias_ptr, # (C,) + grad_bias_ptr, # (C,), or None if HAS_BIAS is False grad_H_res_ptr, # (M, n, n) grad_x_ptr, # (M, C, n) M, @@ -1502,18 +1517,21 @@ def _mhc_expand_combine_with_bias_bwd( stride_grad_output_Cn, stride_fm, stride_fc, - stride_bias, + stride_bias, # Not used if HAS_BIAS is False stride_xm, stride_xCn, stride_grad_fm, stride_grad_fc, - stride_grad_bias, + stride_grad_bias, # Not used if HAS_BIAS is False stride_grad_xm, stride_grad_xCn, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_C: tl.constexpr, + STEP_SIZE_C: tl.constexpr, precision: tl.constexpr, + HAS_BIAS: tl.constexpr, + FUSE_GRAD_X_ACC: tl.constexpr, ): """ Each block @@ -1557,137 +1575,169 @@ def _mhc_expand_combine_with_bias_bwd( tl.assume(BLOCK_SIZE_C % 32 == 0) offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_c = pid_c * BLOCK_SIZE_C + tl.arange(0, BLOCK_SIZE_C) - offs_cn = pid_c * BLOCK_SIZE_C * n + tl.arange(0, BLOCK_SIZE_C * n) mask_m = offs_m < M - mask_c = offs_c < C - mask_cn = offs_cn < C * n - f_ptrs = f_ptr + offs_m[:, None] * stride_fm + offs_c[None, :] * stride_fc - f = tl.load(f_ptrs, mask=mask_m[:, None] & mask_c[None, :], other=0.0) + offs_c_start = pid_c * BLOCK_SIZE_C + offs_cn_start = pid_c * BLOCK_SIZE_C * n - bias = tl.load(bias_ptr + offs_c * stride_bias, mask=mask_c, other=0.0) # (BLOCK_SIZE_C,) + grad_H_post_acc = tl.zeros((BLOCK_SIZE_M, 1, n), dtype=tl.float32) + grad_H_res_acc = tl.zeros((BLOCK_SIZE_M, n, n), dtype=tl.float32) H_post_offs = pid_m * BLOCK_SIZE_M * n + tl.arange(0, BLOCK_SIZE_M * n) H_post = tl.load(H_post_ptr + H_post_offs, mask=H_post_offs < M * n, other=0.0) - H_post = tl.reshape(H_post, (BLOCK_SIZE_M, n)) # (BLOCK_SIZE_M, n) + H_post_reshape = tl.reshape(H_post, (BLOCK_SIZE_M, 2, 2)) + H_post01, H_post23 = tl.split(H_post_reshape) # (BLOCK_SIZE_M, 2), (BLOCK_SIZE_M, 2) + H_post0, H_post1 = tl.split(H_post01) # (BLOCK_SIZE_M,), (BLOCK_SIZE_M,) + H_post2, H_post3 = tl.split(H_post23) # (BLOCK_SIZE_M,), (BLOCK_SIZE_M,) H_res_offs = pid_m * BLOCK_SIZE_M * n * n + tl.arange(0, BLOCK_SIZE_M * n * n) H_res = tl.load( H_res_ptr + H_res_offs, mask=H_res_offs < M * n * n, other=0.0 ) # (BLOCK_SIZE_M, n, n) H_res = tl.reshape(H_res, (BLOCK_SIZE_M, n, n)) # (BLOCK_SIZE_M, n, n) + H_res_reshape = tl.reshape(H_res, (BLOCK_SIZE_M, 
n, 2, 2)) # (BLOCK_SIZE_M, n, 2, 2) + H_res01, H_res23 = tl.split(H_res_reshape) # (BLOCK_SIZE_M, n, 2), (BLOCK_SIZE_M, n, 2) + H_res0, H_res1 = tl.split(H_res01) # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n) + H_res2, H_res3 = tl.split(H_res23) # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n) - grad_out_ptrs = ( - grad_output_ptr - + offs_m[:, None] * stride_grad_output_m - + offs_cn[None, :] * stride_grad_output_Cn - ) - grad_out = tl.load( - grad_out_ptrs, mask=mask_m[:, None] & mask_cn[None, :], other=0.0 - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C * n) - grad_out = tl.reshape( - grad_out, (BLOCK_SIZE_M, BLOCK_SIZE_C, n) - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C, n) + for i in tl.range(0, BLOCK_SIZE_C, STEP_SIZE_C, loop_unroll_factor=2): + offs_c = offs_c_start + i + tl.arange(0, STEP_SIZE_C) + offs_cn = offs_cn_start + i * n + tl.arange(0, STEP_SIZE_C * n) + mask_c = offs_c < C + mask_cn = offs_cn < C * n + + f_ptrs = f_ptr + offs_m[:, None] * stride_fm + offs_c[None, :] * stride_fc + f = tl.load(f_ptrs, mask=mask_m[:, None] & mask_c[None, :], other=0.0) + + if HAS_BIAS: + bias = tl.load( + bias_ptr + offs_c * stride_bias, mask=mask_c, other=0.0 + ) # (STEP_SIZE_C,) + + grad_out_ptrs = ( + grad_output_ptr + + offs_m[:, None] * stride_grad_output_m + + offs_cn[None, :] * stride_grad_output_Cn + ) + grad_out = tl.load( + grad_out_ptrs, mask=mask_m[:, None] & mask_cn[None, :], other=0.0 + ) # (BLOCK_SIZE_M, STEP_SIZE_C * n) + grad_out = tl.reshape( + grad_out, (BLOCK_SIZE_M, STEP_SIZE_C, n) + ) # (BLOCK_SIZE_M, STEP_SIZE_C, n) + + # grad_H_post = f.T @ grad_output # (BLOCK_SIZE_M, 1, STEP_SIZE_C) @ (BLOCK_SIZE_M, STEP_SIZE_C, n) = (BLOCK_SIZE_M, 1, n) + grad_H_post_acc = tl.dot( + tl.reshape(f, (BLOCK_SIZE_M, 1, STEP_SIZE_C)), + tl.reshape(grad_out, (BLOCK_SIZE_M, STEP_SIZE_C, n)), + acc=grad_H_post_acc, + input_precision=precision, + out_dtype=tl.float32, + ) # (BLOCK_SIZE_M, 1, n) + if HAS_BIAS: + grad_H_post_acc = tl.dot( + tl.broadcast_to(bias[None, None, :], (BLOCK_SIZE_M, 1, STEP_SIZE_C)), + tl.reshape(grad_out, (BLOCK_SIZE_M, STEP_SIZE_C, n)), + acc=grad_H_post_acc, + input_precision=precision, + out_dtype=tl.float32, + ) # (BLOCK_SIZE_M, 1, n) + + x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_cn[None, :] * stride_xCn + x = tl.load( + x_ptrs, mask=mask_m[:, None] & mask_cn[None, :], other=0.0 + ) # (BLOCK_SIZE_M, STEP_SIZE_C*n) + x = tl.reshape(x, (BLOCK_SIZE_M, STEP_SIZE_C, n)) # (BLOCK_SIZE_M, STEP_SIZE_C, n) + + # grad_H_res = x.T @ grad_output: (BLOCK_SIZE_M, n, STEP_SIZE_C) @ (BLOCK_SIZE_M, STEP_SIZE_C, n) = (BLOCK_SIZE_M, n, n) + grad_H_res_acc = tl.dot( + tl.trans(x, (0, 2, 1)), + grad_out, + acc=grad_H_res_acc, + input_precision=precision, + out_dtype=tl.float32, + ) # (BLOCK_SIZE_M, n, n) + + grad_out_reshape = tl.reshape( + grad_out, (BLOCK_SIZE_M, STEP_SIZE_C, 2, 2) + ) # (BLOCK_SIZE_M, STEP_SIZE_C, 2, 2) + grad_out01, grad_out23 = tl.split( + grad_out_reshape + ) # (BLOCK_SIZE_M, STEP_SIZE_C, 2), (BLOCK_SIZE_M, STEP_SIZE_C, 2) + grad_out0, grad_out1 = tl.split( + grad_out01 + ) # (BLOCK_SIZE_M, STEP_SIZE_C), (BLOCK_SIZE_M, STEP_SIZE_C) + grad_out2, grad_out3 = tl.split( + grad_out23 + ) # (BLOCK_SIZE_M, STEP_SIZE_C), (BLOCK_SIZE_M, STEP_SIZE_C) + + # grad_f = grad_output @ H_post.T: (BLOCK_SIZE_M, 1, n) @ (BLOCK_SIZE_M, n, STEP_SIZE_C) = (BLOCK_SIZE_M, 1, STEP_SIZE_C) + # Triton doesn't support dot prod with inner dimension < 16, so we need to hack this: + # = grad_out[:, :, 0] @ H_post.T[:, 0, :] (BLOCK_SIZE_M, STEP_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, 1) + # + grad_out[:, :, 1] @ H_post.T[:, 1, :] + # 
+ grad_out[:, :, 2] @ H_post.T[:, 2, :] + # + grad_out[:, :, 3] @ H_post.T[:, 3, :] + # where H_post.T[:, i, :] = H_post[:, :, i] + + grad_f_acc = tl.zeros((BLOCK_SIZE_M, STEP_SIZE_C), dtype=tl.float32) + # (BLOCK_SIZE_M, STEP_SIZE_C) * (BLOCK_SIZE_M, 1) -> (BLOCK_SIZE_M, STEP_SIZE_C) + grad_f_acc = tl.fma(grad_out0, H_post0[:, None], grad_f_acc) + grad_f_acc = tl.fma(grad_out1, H_post1[:, None], grad_f_acc) + grad_f_acc = tl.fma(grad_out2, H_post2[:, None], grad_f_acc) + grad_f_acc = tl.fma(grad_out3, H_post3[:, None], grad_f_acc) + grad_f = grad_f_acc.to(f.dtype) + + grad_f_ptrs = ( + grad_f_ptr + offs_m[:, None] * stride_grad_fm + offs_c[None, :] * stride_grad_fc + ) + tl.store(grad_f_ptrs, grad_f, mask=mask_m[:, None] & mask_c[None, :]) + + if HAS_BIAS: + grad_bias = tl.sum(grad_f_acc, axis=0) # (STEP_SIZE_C,) + # This is reduction over M dimension, so it has nothing to do with whether we use split-C. It only depends on determinism or not. + grad_bias_ptrs = grad_bias_ptr + offs_c * stride_grad_bias + tl.atomic_add(grad_bias_ptrs, grad_bias, mask=mask_c, sem="relaxed") + + # grad_x = grad_output @ H_res.T: (BLOCK_SIZE_M, STEP_SIZE_C, n) @ (BLOCK_SIZE_M, n, n) = (BLOCK_SIZE_M, n, STEP_SIZE_C) + # The inner dim is n=4 which is too small for triton, so we will manually unroll the matmul + # grad_x = grad_out[:, :, 0] @ H_res.T[:, 0, :] + # + grad_out[:, :, 1] @ H_res.T[:, 1, :] + # + grad_out[:, :, 2] @ H_res.T[:, 2, :] + # + grad_out[:, :, 3] @ H_res.T[:, 3, :] + # where H_res.T[:, i, :] = H_res[:, :, i] + # Due to broadcasting, it's equivalent to multiplying each H_res[:, i, :].T with grad_out[:, i, :] + + grad_x_acc = tl.zeros((BLOCK_SIZE_M, STEP_SIZE_C, n), dtype=tl.float32) + grad_x_acc = tl.fma(grad_out0[:, :, None], H_res0[:, None, :], grad_x_acc) + grad_x_acc = tl.fma(grad_out1[:, :, None], H_res1[:, None, :], grad_x_acc) + grad_x_acc = tl.fma(grad_out2[:, :, None], H_res2[:, None, :], grad_x_acc) + grad_x_acc = tl.fma(grad_out3[:, :, None], H_res3[:, None, :], grad_x_acc) + + if FUSE_GRAD_X_ACC: + grad_x = grad_x_acc # If fusing gradient accumulation, the buffer should be always fp32 so we don't cast here + else: + grad_x = grad_x_acc.to(x.dtype) + grad_x = tl.reshape( + grad_x, (BLOCK_SIZE_M, STEP_SIZE_C * n) + ) # (BLOCK_SIZE_M, STEP_SIZE_C*n) + + grad_x_ptrs = ( + grad_x_ptr + offs_m[:, None] * stride_grad_xm + offs_cn[None, :] * stride_grad_xCn + ) + tl.store(grad_x_ptrs, grad_x, mask=mask_m[:, None] & mask_cn[None, :]) - # grad_H_post = f.T @ grad_output # (BLOCK_SIZE_M, 1, BLOCK_SIZE_C) @ (BLOCK_SIZE_M, BLOCK_SIZE_C, n) = (BLOCK_SIZE_M, 1, n) - grad_H_post = tl.dot( - tl.reshape(f, (BLOCK_SIZE_M, 1, BLOCK_SIZE_C)), - tl.reshape(grad_out, (BLOCK_SIZE_M, BLOCK_SIZE_C, n)), - input_precision=precision, - out_dtype=tl.float32, - ) # (BLOCK_SIZE_M, 1, n) - grad_H_post = tl.dot( - tl.broadcast_to(bias[None, None, :], (BLOCK_SIZE_M, 1, BLOCK_SIZE_C)), - tl.reshape(grad_out, (BLOCK_SIZE_M, BLOCK_SIZE_C, n)), - acc=grad_H_post, - input_precision=precision, - out_dtype=tl.float32, - ) # (BLOCK_SIZE_M, 1, n) - grad_H_post = tl.reshape(grad_H_post, (BLOCK_SIZE_M * n,)) # (BLOCK_SIZE_M * n) + grad_H_post = tl.reshape(grad_H_post_acc, (BLOCK_SIZE_M * n,)) # (BLOCK_SIZE_M * n) offs_grad_H_post = pid_m * BLOCK_SIZE_M * n + tl.arange(0, BLOCK_SIZE_M * n) grad_H_post_ptrs = grad_H_post_ptr + offs_grad_H_post - tl.atomic_add(grad_H_post_ptrs, grad_H_post, mask=offs_grad_H_post < M * n, sem="relaxed") - x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_cn[None, :] * stride_xCn - x = tl.load( 
- x_ptrs, mask=mask_m[:, None] & mask_cn[None, :], other=0.0 - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C*n) - x = tl.reshape(x, (BLOCK_SIZE_M, BLOCK_SIZE_C, n)) # (BLOCK_SIZE_M, BLOCK_SIZE_C, n) - - # grad_H_res = x.T @ grad_output: (BLOCK_SIZE_M, n, BLOCK_SIZE_C) @ (BLOCK_SIZE_M, BLOCK_SIZE_C, n) = (BLOCK_SIZE_M, n, n) - grad_H_res = tl.dot( - tl.trans(x, (0, 2, 1)), grad_out, input_precision=precision, out_dtype=tl.float32 - ) # (BLOCK_SIZE_M, n, n) - grad_H_res = tl.reshape(grad_H_res, (BLOCK_SIZE_M * n * n,)) # (BLOCK_SIZE_M * n * n) + grad_H_res = tl.reshape(grad_H_res_acc, (BLOCK_SIZE_M * n * n,)) # (BLOCK_SIZE_M * n * n) offs_grad_H_res = pid_m * BLOCK_SIZE_M * n * n + tl.arange(0, BLOCK_SIZE_M * n * n) grad_H_res_ptrs = grad_H_res_ptr + offs_grad_H_res + + tl.atomic_add(grad_H_post_ptrs, grad_H_post, mask=offs_grad_H_post < M * n, sem="relaxed") tl.atomic_add( - grad_H_res_ptrs, grad_H_res.to(tl.float32), mask=offs_grad_H_res < M * n * n, sem="relaxed" + grad_H_res_ptrs, + grad_H_res.to(tl.float32), + mask=offs_grad_H_res < M * n * n, + sem="relaxed", ) - - grad_out_reshape = tl.reshape( - grad_out, (BLOCK_SIZE_M, BLOCK_SIZE_C, 2, 2) - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C, 2, 2) - grad_out01, grad_out23 = tl.split( - grad_out_reshape - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C, 2), (BLOCK_SIZE_M, BLOCK_SIZE_C, 2) - grad_out0, grad_out1 = tl.split( - grad_out01 - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C), (BLOCK_SIZE_M, BLOCK_SIZE_C) - grad_out2, grad_out3 = tl.split( - grad_out23 - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C), (BLOCK_SIZE_M, BLOCK_SIZE_C) - - # grad_f = grad_output @ H_post.T: (BLOCK_SIZE_M, 1, n) @ (BLOCK_SIZE_M, n, BLOCK_SIZE_C) = (BLOCK_SIZE_M, 1, BLOCK_SIZE_C) - # Triton doesn't support dot prod with inner dimension < 16, so we need to hack this: - # = grad_out[:, :, 0] @ H_post.T[:, 0, :] (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, 1) - # + grad_out[:, :, 1] @ H_post.T[:, 1, :] - # + grad_out[:, :, 2] @ H_post.T[:, 2, :] - # + grad_out[:, :, 3] @ H_post.T[:, 3, :] - # where H_post.T[:, i, :] = H_post[:, :, i] - H_post = tl.reshape(H_post, (BLOCK_SIZE_M, 2, 2)) - H_post01, H_post23 = tl.split(H_post) # (BLOCK_SIZE_M, 2), (BLOCK_SIZE_M, 2) - H_post0, H_post1 = tl.split(H_post01) # (BLOCK_SIZE_M,), (BLOCK_SIZE_M,) - H_post2, H_post3 = tl.split(H_post23) # (BLOCK_SIZE_M,), (BLOCK_SIZE_M,) - - grad_f_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_C), dtype=tl.float32) - # (BLOCK_SIZE_M, BLOCK_SIZE_C) * (BLOCK_SIZE_M, 1) -> (BLOCK_SIZE_M, BLOCK_SIZE_C) - grad_f_acc = tl.fma(grad_out0, H_post0[:, None], grad_f_acc) - grad_f_acc = tl.fma(grad_out1, H_post1[:, None], grad_f_acc) - grad_f_acc = tl.fma(grad_out2, H_post2[:, None], grad_f_acc) - grad_f_acc = tl.fma(grad_out3, H_post3[:, None], grad_f_acc) - grad_f = grad_f_acc.to(f.dtype) - - grad_f_ptrs = grad_f_ptr + offs_m[:, None] * stride_grad_fm + offs_c[None, :] * stride_grad_fc - tl.store(grad_f_ptrs, grad_f, mask=mask_m[:, None] & mask_c[None, :]) - - grad_bias = tl.sum(grad_f_acc, axis=0) # (BLOCK_SIZE_C,) - grad_bias_ptrs = grad_bias_ptr + offs_c * stride_grad_bias - tl.atomic_add(grad_bias_ptrs, grad_bias, mask=mask_c, sem="relaxed") - - # grad_x = grad_output @ H_res.T: (BLOCK_SIZE_M, BLOCK_SIZE_C, n) @ (BLOCK_SIZE_M, n, n) = (BLOCK_SIZE_M, n, BLOCK_SIZE_C) - # The inner dim is n=4 which is too small for triton, so we will manually unroll the matmul - # grad_x = grad_out[:, :, 0] @ H_res.T[:, 0, :] - # + grad_out[:, :, 1] @ H_res.T[:, 1, :] - # + grad_out[:, :, 2] @ H_res.T[:, 2, :] - # + grad_out[:, :, 3] @ H_res.T[:, 3, :] - # where H_res.T[:, i, :] 
= H_res[:, :, i] - # Due to broadcasting, it's equivalent to multiplying each H_res[:, i, :].T with grad_out[:, i, :] - - H_res_reshape = tl.reshape(H_res, (BLOCK_SIZE_M, n, 2, 2)) # (BLOCK_SIZE_M, n, 2, 2) - H_res01, H_res23 = tl.split(H_res_reshape) # (BLOCK_SIZE_M, n, 2), (BLOCK_SIZE_M, n, 2) - H_res0, H_res1 = tl.split(H_res01) # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n) - H_res2, H_res3 = tl.split(H_res23) # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n) - - grad_x_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_C, n), dtype=tl.float32) - grad_x_acc = tl.fma(grad_out0[:, :, None], H_res0[:, None, :], grad_x_acc) - grad_x_acc = tl.fma(grad_out1[:, :, None], H_res1[:, None, :], grad_x_acc) - grad_x_acc = tl.fma(grad_out2[:, :, None], H_res2[:, None, :], grad_x_acc) - grad_x_acc = tl.fma(grad_out3[:, :, None], H_res3[:, None, :], grad_x_acc) - - grad_x = grad_x_acc.to(x.dtype) - grad_x = tl.reshape(grad_x, (BLOCK_SIZE_M, BLOCK_SIZE_C * n)) # (BLOCK_SIZE_M, BLOCK_SIZE_C*n) - - grad_x_ptrs = grad_x_ptr + offs_m[:, None] * stride_grad_xm + offs_cn[None, :] * stride_grad_xCn - tl.store(grad_x_ptrs, grad_x, mask=mask_m[:, None] & mask_cn[None, :]) diff --git a/transformer_engine/pytorch/triton/mhc.py b/transformer_engine/pytorch/triton/mhc.py index 987216e327..3565b17ff4 100644 --- a/transformer_engine/pytorch/triton/mhc.py +++ b/transformer_engine/pytorch/triton/mhc.py @@ -5,28 +5,54 @@ """PyTorch wrapper functions for mHC (manifold Hyper-Connection) Triton kernels.""" import os +from typing import Optional import torch import triton from transformer_engine.common.triton.mhc import ( + _mhc_projection_bwd_fused_dphi, + _mhc_projection_bwd_fused_dx, _mhc_scale_fwd_fused, _mhc_scale_bwd_fused, - _mhc_expand_combine_with_bias_fwd, - _mhc_expand_combine_with_bias_bwd, _mhc_expand_combine_fwd, _mhc_expand_combine_bwd, _mhc_aggregate_fwd, _mhc_aggregate_bwd, _mhc_projection_fwd_fused, - _mhc_projection_bwd_fused, - _mhc_sinkhorn_fwd_fused, _mhc_sinkhorn_fwd_fused_recompute, - _mhc_sinkhorn_bwd_fused, _mhc_sinkhorn_bwd_fused_recompute, + _mhc_sinkhorn_fwd_fused, + _mhc_sinkhorn_bwd_fused, ) from transformer_engine.pytorch.cpp_extensions.gemm import general_gemm +def _support_tma(): + return torch.cuda.get_device_capability()[0] >= 9 + + +def _tma_aligned(t): + return (t.stride(0) * t.element_size()) % 16 == 0 and t.data_ptr() % 16 == 0 + + +_tma_allocator_initialized = False + + +def _init_tma_allocator(): + # TMA descriptors require a global memory allocation. Registered once on first use. + global _tma_allocator_initialized # pylint: disable=global-statement + if _tma_allocator_initialized: + return + + def alloc_fn( + size: int, alignment: int, stream: Optional[int] + ): # pylint: disable=unused-argument + return torch.empty(size, device="cuda", dtype=torch.int8) + + triton.set_allocator(alloc_fn) + _tma_allocator_initialized = True + + def check_deterministic(operator: str): """ Checks if the non-deterministic algorithm is allowed for the given operator. If not, raises an assertion error with instructions on how to allow it. @@ -39,6 +65,106 @@ def check_deterministic(operator: str): ) +def mhc_generate_mix_and_aggregate( + x: torch.Tensor, + phi: torch.Tensor, + alpha: torch.Tensor, + beta: torch.Tensor, + norm_weight: Optional[torch.Tensor] = None, + use_tf32: bool = True, + fused_grad_x_acc_buffer: Optional[torch.Tensor] = None, +): + """ + Generate the mix matrix H_pre, H_post, H_res and apply H_pre to x to aggregate n streams + This wraps projection, scale, sinkhorn, and aggregate operations into one function. 
+ + To use mHC in your model: + ``` + layer_input, H_post, H_res = mhc_generate_mix_and_aggregate(x, phi, alpha, beta) + layer_output = layer(layer_input) # Attn / FFN layer + x = mhc_fused_expand_combine(layer_output, bias, H_post, x, H_res, n) + ``` + + This API accepts both BF16 and FP32 parameters, though the DeepSeek V4 recipe is: + - x: BF16 + - phi, alpha, beta: FP32 + + Parameters + ---------- + x : torch.Tensor + input tensor of shape (s, b, C, n), where s is the sequence length, b is the batch size, C is the hidden dimension per hyper connection, and n is the number of hyper connections, + dtype is torch.bfloat16 or torch.float32 + Note that C is equal to the original hidden dimension divided by n. + phi : torch.Tensor + projection matrix of shape (N, nC), where N=2n+n*n (=24 for n=4), and nC is the hidden dimension after expansion (n times C), + dtype is torch.bfloat16 or torch.float32 + norm_weight : torch.Tensor or None + optional, the weight for RMSNorm, of shape (nC,), i.e. the learnable per-element affine parameters (gamma) of RMSNorm + dtype is torch.bfloat16 or torch.float32 + alpha : torch.Tensor + scaling factor for H, of shape (3,), where + alpha[0] is applied to H[:, 0:n] for H_pre + alpha[1] is applied to H[:, n:2n] for H_post + alpha[2] is applied to H[:, 2n:2n+n*n] for H_res + dtype is torch.bfloat16 or torch.float32 + beta : torch.Tensor + bias term for H, of shape (1, 2*n+n*n), where + beta[0, 0:n] is applied to H[:, 0:n] for H_pre + beta[0, n:2n] is applied to H[:, n:2n] for H_post + beta[0, 2n:2n+n*n] is applied to H[:, 2n:2n+n*n] for H_res + dtype is torch.bfloat16 or torch.float32 + use_tf32 : bool + whether to use TF32 for matrix multiplications + fused_grad_x_acc_buffer : Optional[torch.Tensor] + A pre-allocated buffer for inplace gradient accumulation to avoid PyTorch autograd overhead.
+ If not None, triton kernels will accumulate the gradient of x into this same buffer to avoid copying the gradient by PyTorch, which should be reused + during the backward of mhc_fused_aggregate, mhc_fused_expand_combine and mhc_fused_projection operations + Note: the buffer must have dtype float32, and it will be cast to the activation's dtype and be returned in mhc_fused_projection + + Returns + ------- + out : torch.Tensor + out of shape (s, b, C), which is the aggregated result after applying H_pre to x, which will be fed into attention / FFN + with the same dtype as x + H_post : torch.Tensor + H_post of shape (s, b, n), which will be used in the post-processing after attention / FFN in `mhc_fused_expand_combine` + with dtype float32 + H_res : torch.Tensor + H_res of shape (s, b, n, n), which will be used to mix the residual connection in `mhc_fused_expand_combine` + with dtype float32 + """ + check_deterministic("mhc_generate_mix_and_aggregate") + s, b, C, n = x.shape + assert ( + n == 4 + ), "Only n=4 is supported in this implementation, where n is the Hyper Connection number" + if fused_grad_x_acc_buffer is not None: + assert ( + fused_grad_x_acc_buffer.dtype == torch.float32 + ), "fused_grad_x_acc_buffer must be fp32" + assert ( + fused_grad_x_acc_buffer.numel() == x.numel() + ), "fused_grad_x_acc_buffer.numel() must match x.numel()" + nC = n * C + H, ms = mhc_fused_projection( + x.view(s * b, nC), + phi, + norm_weight=norm_weight, + use_tf32=use_tf32, + fused_grad_x_acc_buffer=fused_grad_x_acc_buffer, + ) + H_pre, H_post, H_res = mhc_fused_scale(H, alpha, beta, ms, n) + H_res = mhc_fused_sinkhorn(H_res.view(s, b, n, n), n, recompute_hist=True, iters=20) + out = mhc_fused_aggregate( + x, + H_pre.view(s, b, n), + n, + use_tf32=use_tf32, + fused_grad_x_acc_buffer=fused_grad_x_acc_buffer, + ) + return out, H_post.view(s, b, n), H_res + + def mhc_fused_sinkhorn( H_res: torch.Tensor, n: int = 4, recompute_hist: bool = True, iters: int = 20 ): @@ -52,6 +178,7 @@ def mhc_fused_sinkhorn( ---------- H_res : torch.Tensor input H_res matrix of shape (s, b, n, n) that needs to be normalized into a doubly stochastic matrix. + dtype is torch.bfloat16 or torch.float32 n : int number of hyper connections, where only n=4 is supported in the current implementation recompute_hist : bool @@ -63,6 +190,7 @@ def mhc_fused_sinkhorn( ------- out : torch.Tensor out of shape (s, b, n, n), which is the final H_res after Sinkhorn normalization + with the same dtype as H_res """ assert n == 4, "Only n=4 is supported in this implementation" out = mHCSinkhornOp.apply(H_res, n, recompute_hist, iters) @@ -70,7 +198,11 @@ def mhc_fused_sinkhorn( def mhc_fused_scale( - H: torch.Tensor, alpha: torch.Tensor, beta: torch.Tensor, ms: torch.Tensor, n: int + H: torch.Tensor, + alpha: torch.Tensor, + beta: torch.Tensor, + ms: torch.Tensor, + n: int, ): """ Fused scale operation to compute the scaled H matrices (see eq. 
16-18, section 4.3.1 of the DeepSeek mHC paper): @@ -96,6 +228,7 @@ def mhc_fused_scale( beta[0, 0:n] is applied to H[:, 0:n] for H_pre beta[0, n:2n] is applied to H[:, n:2n] for H_post beta[0, 2n:2n+n*n] is applied to H[:, 2n:2n+n*n] for H_res + Note: we assume alpha and beta have the same dtype, and according to the DeepSeek paper they should be fp32 ms : torch.Tensor mean square for each row of H from the projection kernel, of shape (M,), used for RMSNorm scaling n : int @@ -104,15 +237,17 @@ def mhc_fused_scale( Returns ------- h_pre : torch.Tensor - Scaled H_pre of shape (M, n), which aggregates (s, b, C, n) input of a Hyper Connection block into (s, b, n) as the input of attention / MLP + Scaled H_pre of shape (M, n), which aggregates (s, b, C, n) input of a Hyper Connection block into (s, b, n) as the input of attention / MLP, + with the same dtype as H h_post : torch.Tensor - Scaled H_post of shape (M, n), which expands the output of attention / MLP of shape (s, b, n) back to (s, b, C, n) for the residual connection + Scaled H_post of shape (M, n), which expands the output of attention / MLP of shape (s, b, n) back to (s, b, C, n) for the residual connection, + with the same dtype as H h_res : torch.Tensor - Scaled H_res of shape (M, n*n), which mixes the n streams of the (s, b, C, n) input of a Hyper Connection block + Scaled H_res of shape (M, n*n), which mixes the n streams of the (s, b, C, n) input of a Hyper Connection block, + with the same dtype as H """ assert n == 4, "Only n=4 is supported in this implementation" - check_deterministic("mhc_fused_scale") out = mHCScaleFusedOp.apply(H, alpha, beta, ms, n) h_pre = out[..., :n] h_post = out[..., n : 2 * n] @@ -120,7 +255,13 @@ def mhc_fused_scale( return h_pre, h_post, h_res -def mhc_fused_aggregate(x: torch.Tensor, H_pre: torch.Tensor, n: int, use_tf32: bool = True): +def mhc_fused_aggregate( + x: torch.Tensor, + H_pre: torch.Tensor, + n: int, + use_tf32: bool = True, + fused_grad_x_acc_buffer: Optional[torch.Tensor] = None, +): """ Aggregate operation to merge n activation streams into one (see section 4.3.1 of the DeepSeek mHC paper): out = x @ H_pre: (s, b, C, n) @ (s, b, n, 1) -> (s, b, C, 1) -> (s, b, C) after squeezing the last dimension @@ -130,22 +271,37 @@ def mhc_fused_aggregate(x: torch.Tensor, H_pre: torch.Tensor, n: int, use_tf32: x : torch.Tensor input activation tensor of shape (s, b, C, n), where s is the sequence length, b is the batch size, C is the hidden dimension per hyper connection, and n is the number of hyper connections. Note that C is equal to the original hidden dimension divided by n. + dtype is torch.bfloat16 or torch.float32 H_pre: torch.Tensor input H_pre matrix of shape (s, b, n) + dtype is torch.bfloat16 or torch.float32 n: int number of hyper connections, where only n=4 is supported in the current implementation use_tf32: bool whether to use TF32 precision for matmul operations. If False, it will use ieee for better precision. This is mainly used by our unittests since TF32 precision will introduce some errors and cause tests to fail + fused_grad_x_acc_buffer : Optional[torch.Tensor] + A pre-allocated buffer for inplace gradient accumulation to avoid PyTorch autograd overhead. + If not None, triton kernels will accumulate the gradient of x into this same buffer to avoid copying the gradient by PyTorch. + This optimization requires the operation order to be mhc_fused_projection -> mhc_fused_aggregate -> mhc_fused_expand_combine. 
+ Note: the buffer must have dtype float32, and it will be cast to the activation's dtype and be returned in mhc_fused_projection Returns ------- out: torch.Tensor - output activation tensor of shape (s, b, C), which is the aggregated output after merging n hyper connections + output activation tensor of shape (s, b, C), which is the aggregated output after merging n hyper connections, + with the same dtype as x """ assert n == 4, "Only n=4 is supported in this implementation" + if fused_grad_x_acc_buffer is not None: + assert ( + fused_grad_x_acc_buffer.dtype == torch.float32 + ), "fused_grad_x_acc_buffer must be fp32" + assert ( + fused_grad_x_acc_buffer.numel() == x.numel() + ), "fused_grad_x_acc_buffer.numel() must match x.numel()" check_deterministic("mhc_fused_aggregate") - out = mHCAggregateOp.apply(x, H_pre, n, use_tf32) + out = mHCAggregateOp.apply(x, H_pre, n, use_tf32, fused_grad_x_acc_buffer) return out @@ -157,6 +313,7 @@ def mhc_fused_expand_combine( H_res: torch.Tensor, n: int, use_tf32: bool = True, + fused_grad_x_acc_buffer: Optional[torch.Tensor] = None, ): """ Expand and combine operation for merging n hyper connections (see section 4.3.1 of the DeepSeek mHC paper): @@ -167,26 +324,44 @@ def mhc_fused_expand_combine( ---------- f : torch.Tensor input activation tensor of shape (s, b, C), which is the output from the attention / FFN sub-layer in a transformer block + dtype is torch.bfloat16 or torch.float32 bias : torch.Tensor or None optional bias tensor of shape (C,) from the last linear layer, where f + bias is fused in this kernel for better performance + dtype is torch.bfloat16 or torch.float32 H_post : torch.Tensor input H_post matrix of shape (s, b, n) + dtype is torch.bfloat16 or torch.float32 x : torch.Tensor input activation tensor of shape (s, b, C, n), which is the hyper connection input before the aggregation operation + dtype is torch.bfloat16 or torch.float32 H_res : torch.Tensor input H_res matrix of shape (s, b, n, n) + dtype is torch.bfloat16 or torch.float32 n : int - number of hyper connections + number of hyper connections, where only n=4 is supported in the current implementation use_tf32 : bool whether to use TF32 precision for matmul operations. If False, it will use ieee for better precision. This is mainly used by our unittests since TF32 precision will introduce some errors and cause tests to fail + fused_grad_x_acc_buffer : Optional[torch.Tensor] + A pre-allocated buffer for inplace gradient accumulation to avoid PyTorch autograd overhead. + If not None, triton kernels will accumulate the gradient of x into this same buffer to avoid copying the gradient by PyTorch. + This optimization requires the operation order to be mhc_fused_projection -> mhc_fused_aggregate -> mhc_fused_expand_combine. 
+ Note: the buffer must have dtype float32, and it will be cast to the activation's dtype and be returned in mhc_fused_projection Returns ------- out : torch.Tensor - out of shape (s, b, C, n), which is the expanded and combined output after merging n hyper connections + out of shape (s, b, C, n), which is the expanded and combined output after merging n hyper connections, + with the same dtype as x """ assert n == 4, "Only n=4 is supported in this implementation" + if fused_grad_x_acc_buffer is not None: + assert ( + fused_grad_x_acc_buffer.dtype == torch.float32 + ), "fused_grad_x_acc_buffer must be fp32" + assert ( + fused_grad_x_acc_buffer.numel() == x.numel() + ), "fused_grad_x_acc_buffer.numel() must match x.numel()" check_deterministic("mhc_fused_expand_combine") out = mHCExpandCombineOp.apply( f, @@ -196,41 +371,76 @@ def mhc_fused_expand_combine( H_res, n, use_tf32, + fused_grad_x_acc_buffer, ) return out -def mhc_fused_projection(x: torch.Tensor, phi: torch.Tensor, use_tf32: bool = True): +def mhc_fused_projection( + x: torch.Tensor, + phi: torch.Tensor, + use_tf32: bool = True, + norm_weight: Optional[torch.Tensor] = None, + fused_grad_x_acc_buffer: Optional[torch.Tensor] = None, +): """ Fused projection operation to compute H matrices and mean square for RMSNorm (see eq. 14-15, section 4.3.1 of the DeepSeek mHC paper): H = x @ phi^T: (M, K) @ (K, N) -> (M, N), which is padded to (M, 32) for better memory access pattern in the next kernels. ms = mean(x^2, dim=-1): (M,) + If norm_weight is provided, it will be absorbed into phi. In this case, the operation becomes: + Projection: + - H = x @ (phi.T * norm_weight) = x @ phi.T * norm_weight + - ms = mean(x^2, dim=-1) + - H = H / sqrt(ms) = x @ (phi.T * norm_weight) / sqrt(ms), where this step is fused into `mhc_fused_scale` + which is equivalent to performing the computation in the normal order: + - x_normalized = RMSNorm(x) = x * norm_weight / sqrt(ms) + - H = x_normalized @ phi.T = (x / sqrt(ms) @ phi.T) * norm_weight + Note: the current implementation only supports n=4 Parameters ---------- x : torch.Tensor input tensor of shape (M, K), where M=s*b is the batch size and K=nC is the hidden dimension after expansion. + dtype is torch.bfloat16 or torch.float32 phi : torch.Tensor projection matrix of shape (N, K), where N=2n+n*n (=24 for n=4) + dtype is torch.bfloat16 or torch.float32 use_tf32 : bool whether to use TF32 precision for matmul operations. If False, it will use ieee for better precision. This is mainly used by our unittests since TF32 precision will introduce some errors and cause tests to fail. + norm_weight : torch.Tensor or None + optional, the weight for RMSNorm, of shape (K,), which is the learnable per-element affine parameters (gamma) applied to RMSNorm + dtype is torch.bfloat16 or torch.float32 + fused_grad_x_acc_buffer : Optional[torch.Tensor] + A pre-allocated buffer for inplace gradient accumulation to avoid PyTorch autograd overhead. + If not None, triton kernels will accumulate the gradient of x into this same buffer to avoid copying the gradient by PyTorch. + This optimization requires the operation order to be mhc_fused_projection -> mhc_fused_aggregate -> mhc_fused_expand_combine. + Note: the buffer must have dtype float32, and it will be cast to the activation's dtype and be returned in mhc_fused_projection Returns ------- H : torch.Tensor - Projected matrix of shape (M, 32), where only the first N elements in the last dimension are valid. 
+ Projected matrix of shape (M, 32), where only the first N elements in the last dimension are valid, + with dtype float32 ms : torch.Tensor - Mean square of shape (M,), which is used for RMSNorm in the next kernel. + Mean square of shape (M,), which is used for RMSNorm in the next kernel, + with dtype float32 """ assert ( phi.shape[0] == 24 ), "Currently only n=4 is supported, which means phi should have 24 in its first dimension" check_deterministic("mhc_fused_projection") - H, ms = mHCProjectionOp.apply(x, phi, use_tf32) + if fused_grad_x_acc_buffer is not None: + assert ( + fused_grad_x_acc_buffer.dtype == torch.float32 + ), "fused_grad_x_acc_buffer must be fp32" + assert ( + fused_grad_x_acc_buffer.numel() == x.numel() + ), "fused_grad_x_acc_buffer.numel() must match x.numel()" + H, ms = mHCProjectionOp.apply(x, phi, norm_weight, use_tf32, fused_grad_x_acc_buffer) return H, ms @@ -240,16 +450,20 @@ class mHCProjectionOp(torch.autograd.Function): """ @staticmethod - def forward(ctx, x, phi, use_tf32=True): + def forward(ctx, x, phi, norm_weight=None, use_tf32=True, fused_grad_x_acc_buffer=None): """ The forward pass of the fused projection operation. Computes H = x @ phi^T and the mean + If norm_weight is provided, it will be absorbed into phi square ms = mean(x^2, dim=-1) for RMSNorm in a single fused kernel. Parameters: ctx : The context object. x (tensor): The input tensor of shape (M, K), where M=s*b is the flattened batch dimension and K=nC is the hidden dimension after expansion. phi (tensor): The projection matrix of shape (N, K), where N=2n+n*n (=24 for n=4). + norm_weight (tensor or None): Optional tensor of shape (K,). RMSNorm's learnable per-element affine parameters. use_tf32 (bool): Whether to use TF32 precision for matmul operations. If False, uses IEEE for better precision. + fused_grad_x_acc_buffer (torch.Tensor or None): A pre-allocated buffer for inplace gradient accumulation to avoid PyTorch autograd overhead. Returns: tuple: A tuple of (H, ms) where H is the projected matrix of shape (M, 32) padded for memory alignment (only the first N elements are valid), and ms is the mean square of shape (M,) in FP32. @@ -267,9 +481,7 @@ def forward(ctx, x, phi, use_tf32=True): # Pad H to (s, b, 32) for better memory access pattern in the kernel, but only the first N elements in the last dimension are valid H = torch.zeros((M, 32), device=device, dtype=torch.float32) - ms = torch.zeros( - (M,), device=device, dtype=torch.float32 - ) # Mean square for x, used to compute RMSNorm in the next kernel + ms = torch.zeros((M,), device=device, dtype=torch.float32) # pylint: disable=unnecessary-lambda-assignment grid = lambda META: ( @@ -277,6 +489,26 @@ def forward(ctx, x, phi, use_tf32=True): triton.cdiv(K, META["BLOCK_SIZE_K"]), ) + use_tma = _support_tma() and _tma_aligned(x) and _tma_aligned(phi) + if use_tma: + _init_tma_allocator() + + ctx.save_for_backward(x, phi, ms, norm_weight) + ctx.phi_dtype = phi.dtype + ctx.fused_grad_x_acc_buffer = fused_grad_x_acc_buffer + + if norm_weight is not None: + phi = phi.to(torch.float32) * norm_weight.to(torch.float32) + + precision = "tf32" if ctx.use_tf32 else "ieee" + # If upcasting from bf16 to fp32 takes place inside the triton kernel, triton will ignore "ieee" precision and use tf32 anyway + # See https://github.com/triton-lang/triton/issues/10176 for detail.
+ # Therefore, we need to use tf32x3 instead which at least has better accuracy than tf32 just to make the tests pass. In production + # precision should be tf32 so it's not affected. + if precision == "ieee" and x.dtype == torch.bfloat16 and phi.dtype == torch.float32: + precision = "tf32x3" + ctx.precision = precision + _mhc_projection_fwd_fused[grid]( x_ptr=x, # (M, K) phi_ptr=phi, # (N, K) @@ -292,22 +524,30 @@ def forward(ctx, x, phi, use_tf32=True): stride_hm=32, stride_hn=1, stride_ms=1, + stride_norm_weight=1, BLOCK_SIZE_N=32, - precision="tf32" if use_tf32 else "ieee", + precision=precision, + USE_TMA=use_tma, ) - ctx.save_for_backward(x, phi, ms) - ctx.phi_dtype = phi.dtype - - return H.to(ctx.dtype), ms # Keep ms in fp32 + return H, ms # Keep both in fp32, which will be passed to sigmoid in mHCScaleFusedOp @staticmethod def backward(ctx, grad_H, grad_ms): """ The backward pass of the fused projection operation. Computes gradients for x and phi. - grad_phi = grad_H^T @ x, truncated to the first N rows. - grad_x = grad_H @ phi + 2 * x * grad_ms / K, where the second term is the gradient contribution from + - grad_psi = grad_H^T @ x: (2n + n^2, M) @ (M, nC) = (2n + n^2, nC), where grad_H's last dim is padded to 32 + If norm_weight is None: + - grad_phi = grad_psi + Otherwise, + - grad_phi = grad_psi * norm_weight: (2n + n^2, nC) * (nC,) = (2n + n^2, nC) + - grad_norm_weight = sum(grad_psi * phi, dim=0): ((2n + n^2, nC) * (2n + n^2, nC)).sum(dim=0) -> (nC,) + Reorder a bit: + - grad_phi = grad_H^T @ x * norm_weight + - grad_norm_weight = sum((grad_H^T @ x) * phi, dim=0) + + - grad_x = grad_H @ phi + 2 * x * grad_ms / K, where the second term is the gradient contribution from the mean square computation fused in the forward pass. Parameters: @@ -316,9 +556,9 @@ def backward(ctx, grad_H, grad_ms): grad_ms (tensor): The gradient of the loss with respect to the mean square, of shape (M,). Returns: - tuple: A tuple with the gradients (grad_x, grad_phi, None). + tuple: A tuple with the gradients (grad_x, grad_phi, grad_norm_weight, None). 
""" - x, phi, ms = ctx.saved_tensors + x, phi, ms, norm_weight = ctx.saved_tensors M, K = x.shape device = x.device @@ -332,12 +572,56 @@ def backward(ctx, grad_H, grad_ms): M, ) - grad_x = torch.empty((M, K), device=device, dtype=x.dtype) + if ctx.fused_grad_x_acc_buffer is not None: + grad_x = ctx.fused_grad_x_acc_buffer.view_as(x) + else: + grad_x = torch.empty((M, K), device=device, dtype=x.dtype) + + if norm_weight is not None: + # With norm_weight, we need a fused kernel to perform GEMM and output both phi & norm_weight gradients + # pylint: disable=unnecessary-lambda-assignment + grid = lambda META: ( + triton.cdiv(K, META["BLOCK_SIZE_K"]), + triton.cdiv(M, META["BLOCK_SIZE_M"]), + ) - grad_x = torch.empty((M, K), device=device, dtype=x.dtype) - grad_phi = general_gemm(x, grad_H, out_dtype=torch.float32, layout="NT")[0][:N, :].to( - phi.dtype - ) # (2n + n^2, M) @ (M, nC) = (2n + n^2, nC); grad_H's last dim is padded to 32 + # For reduction over M, we should prefer parallelizing over M since it's likely to be better, unless determinism is enforced + grad_phi = torch.zeros_like(phi, dtype=torch.float32) + grad_norm_weight = torch.zeros_like(norm_weight, dtype=torch.float32) + + _mhc_projection_bwd_fused_dphi[grid]( + x_ptr=x, # (M, K) + grad_H_ptr=grad_H, # (M, 32) + phi_ptr=phi, # (N, K) + norm_weight_ptr=norm_weight, # (K,) + grad_phi_ptr=grad_phi, # (N, K) + grad_norm_weight_ptr=grad_norm_weight, # (K,) + M=M, + N=N, + K=K, + stride_xm=K, + stride_xk=1, + stride_grad_Hm=32, + stride_grad_Hn=1, + stride_phin=K, + stride_phik=1, + stride_norm_weight=1, + stride_grad_phin=K, + stride_grad_phik=1, + stride_grad_norm_weight=1, + BLOCK_SIZE_N=32, + precision="tf32" if ctx.use_tf32 else "ieee", + ) + + grad_phi = grad_phi.to(phi.dtype) + grad_norm_weight = grad_norm_weight.to(norm_weight.dtype) + else: + # Without norm_weight, this is only a GEMM with no fusion needed so we let cuBLAS handle it + grad_phi = general_gemm( + x.to(grad_H.dtype), grad_H, out_dtype=torch.float32, layout="NT" + )[0][:N, :] + grad_phi = grad_phi.to(phi.dtype) + grad_norm_weight = None # pylint: disable=unnecessary-lambda-assignment grid = lambda META: ( @@ -345,10 +629,11 @@ def backward(ctx, grad_H, grad_ms): triton.cdiv(K, META["BLOCK_SIZE_K"]), ) - _mhc_projection_bwd_fused[grid]( + _mhc_projection_bwd_fused_dx[grid]( x_ptr=x, grad_x_ptr=grad_x, # (M, K) phi_ptr=phi, # (N, K) + norm_weight_ptr=norm_weight, # (K,) grad_h_ptr=grad_H, # (M, 32) grad_ms_ptr=grad_ms, # (M,) M=M, @@ -360,16 +645,19 @@ def backward(ctx, grad_H, grad_ms): stride_grad_xk=1, stride_phin=K, stride_phik=1, + stride_norm_weight=1, stride_grad_phin=K, stride_grad_phik=1, stride_grad_hm=32, stride_grad_hn=1, stride_grad_ms=1, BLOCK_SIZE_N=32, - precision="tf32" if ctx.use_tf32 else "ieee", + precision=ctx.precision, + FUSE_GRAD_X_ACC=ctx.fused_grad_x_acc_buffer is not None, + HAS_NORM_WEIGHT=norm_weight is not None, ) - return grad_x.to(ctx.dtype), grad_phi.to(ctx.dtype), None + return grad_x.to(x.dtype), grad_phi, grad_norm_weight, None, None, None, None class mHCScaleFusedOp(torch.autograd.Function): @@ -507,10 +795,10 @@ def backward(ctx, grad_out): ) return ( - grad_h.to(ctx.dtype), - grad_alpha.to(ctx.dtype), - grad_beta.to(ctx.dtype), - grad_ms.to(ctx.dtype), + grad_h, + grad_alpha.to(alpha.dtype), + grad_beta.to(alpha.dtype), # We assume alpha and beta have the same dtype + grad_ms, None, ) @@ -676,7 +964,6 @@ def backward(ctx, grad_out): ) grad_res = grad_res.view(s, b, n, n) - return grad_res.to(ctx.dtype), None, None, None @@ 
-686,7 +973,7 @@ class mHCAggregateOp(torch.autograd.Function): """ @staticmethod - def forward(ctx, x, H_pre, n, use_tf32=True): + def forward(ctx, x, H_pre, n, use_tf32=True, fused_grad_x_acc_buffer=None): """ The forward pass of the aggregate operation. Merges n activation streams into one by computing a weighted sum using H_pre: @@ -699,6 +986,7 @@ def forward(ctx, x, H_pre, n, use_tf32=True): H_pre (tensor): The pre-connection matrix of shape (s, b, n), used as weights for aggregation. n (int): The number of hyper connections (only n=4 is supported). use_tf32 (bool): Whether to use TF32 precision for matmul operations. + fused_grad_x_acc_buffer (torch.Tensor or None): A pre-allocated buffer for inplace gradient accumulation to avoid PyTorch autograd overhead. Returns: tensor: The aggregated output of shape (s, b, C). @@ -735,6 +1023,7 @@ def forward(ctx, x, H_pre, n, use_tf32=True): ctx.save_for_backward(x, H_pre) ctx.n = n ctx.use_tf32 = use_tf32 + ctx.fused_grad_x_acc_buffer = fused_grad_x_acc_buffer return out @@ -763,7 +1052,11 @@ def backward(ctx, grad_output): assert n == 4, "Only n=4 is supported in this implementation" M = s * b - grad_x = torch.empty_like(x) + if ctx.fused_grad_x_acc_buffer is not None: + grad_x = ctx.fused_grad_x_acc_buffer.view_as(x) + else: + grad_x = torch.empty_like(x) + grad_H_pre = torch.zeros( (s, b, n), dtype=torch.float32, device=H_pre.device ) # We need to use atomic_add for this so we need higher precision @@ -790,11 +1083,15 @@ def backward(ctx, grad_output): stride_grad_xm=nC, stride_grad_xCn=1, precision="tf32" if ctx.use_tf32 else "ieee", + FUSE_GRAD_X_ACC=ctx.fused_grad_x_acc_buffer is not None, ) grad_H_pre = grad_H_pre.to(H_pre.dtype) # Cast back to the original dtype of H_pre - return grad_x, grad_H_pre, None, None + if ctx.fused_grad_x_acc_buffer is not None: + grad_x = None + + return grad_x, grad_H_pre, None, None, None, None class mHCExpandCombineOp(torch.autograd.Function): @@ -803,7 +1100,7 @@ class mHCExpandCombineOp(torch.autograd.Function): """ @staticmethod - def forward(ctx, f, bias, H_post, x, H_res, n, use_tf32=True): + def forward(ctx, f, bias, H_post, x, H_res, n, use_tf32=True, fused_grad_x_acc_buffer=None): """ The forward pass of the expand and combine operation. Expands the sub-layer output f back to n streams using H_post, and combines with the residual connections using H_res: @@ -819,6 +1116,7 @@ def forward(ctx, f, bias, H_post, x, H_res, n, use_tf32=True): H_res (tensor): The residual connection matrix of shape (s, b, n, n). n (int): The number of hyper connections (only n=4 is supported). use_tf32 (bool): Whether to use TF32 precision for matmul operations. + fused_grad_x_acc_buffer (torch.Tensor or None): A pre-allocated buffer for inplace gradient accumulation to avoid PyTorch autograd overhead. Returns: tensor: The expanded and combined output of shape (s, b, C, n). 
@@ -843,45 +1141,29 @@ def forward(ctx, f, bias, H_post, x, H_res, n, use_tf32=True): triton.cdiv(M, META["BLOCK_SIZE_M"]), ) - if bias is None: - _mhc_expand_combine_fwd[grid]( - f_ptr=f, - H_post_ptr=H_post, - x_ptr=x, - H_res_ptr=H_res, - output_ptr=out, - M=M, - C=C, - n=n, - stride_fm=C, - stride_fc=1, - stride_xm=Cn, - stride_xCn=1, - stride_output_m=Cn, - stride_output_Cn=1, - ) - else: - _mhc_expand_combine_with_bias_fwd[grid]( - f_ptr=f, - bias_ptr=bias, - H_post_ptr=H_post, - x_ptr=x, - H_res_ptr=H_res, - output_ptr=out, - M=M, - C=C, - n=n, - stride_fm=C, - stride_fc=1, - stride_bias=1, - stride_xm=Cn, - stride_xCn=1, - stride_output_m=Cn, - stride_output_Cn=1, - ) + _mhc_expand_combine_fwd[grid]( + f_ptr=f, + bias_ptr=bias, + H_post_ptr=H_post, + x_ptr=x, + H_res_ptr=H_res, + output_ptr=out, + M=M, + C=C, + n=n, + stride_fm=C, + stride_fc=1, + stride_bias=1, + stride_xm=Cn, + stride_xCn=1, + stride_output_m=Cn, + stride_output_Cn=1, + HAS_BIAS=bias is not None, + ) ctx.n = n ctx.have_bias = bias is not None + ctx.fused_grad_x_acc_buffer = fused_grad_x_acc_buffer if bias is not None: ctx.save_for_backward(f, bias, H_post, x, H_res) else: @@ -919,81 +1201,71 @@ def backward(ctx, grad_output): M = s * b grad_f = torch.empty_like(f) - grad_bias = torch.zeros_like(bias, dtype=torch.float32) if bias is not None else None + if ctx.fused_grad_x_acc_buffer is not None: + grad_x = ctx.fused_grad_x_acc_buffer.view_as(x) + else: + grad_x = torch.empty_like(x) + + # Since triton's autotune will reset grad_bias pointer when tuning, we need an empty placeholder here + grad_bias = torch.empty(1, device=grad_output.device, dtype=grad_output.dtype) grad_H_post = torch.zeros_like( H_post, dtype=torch.float32 ) # We need to use atomic_add for this so we need higher precision - grad_x = torch.empty_like(x) grad_H_res = torch.zeros_like( H_res, dtype=torch.float32 ) # We need to use atomic_add for this so we need higher precision + if bias is not None: + grad_bias = torch.zeros_like(bias, dtype=torch.float32) + # pylint: disable=unnecessary-lambda-assignment grid = lambda META: ( triton.cdiv(C, META["BLOCK_SIZE_C"]), triton.cdiv(M, META["BLOCK_SIZE_M"]), ) + _mhc_expand_combine_bwd[grid]( + grad_output_ptr=grad_output, + f_ptr=f, + bias_ptr=bias, + H_post_ptr=H_post, + x_ptr=x, + H_res_ptr=H_res, + grad_H_post_ptr=grad_H_post, + grad_f_ptr=grad_f, + grad_bias_ptr=grad_bias, + grad_H_res_ptr=grad_H_res, + grad_x_ptr=grad_x, + M=M, + C=C, + n=n, + stride_grad_output_m=n * C, + stride_grad_output_Cn=1, + stride_fm=C, + stride_fc=1, + stride_bias=1, + stride_xm=n * C, + stride_xCn=1, + stride_grad_fm=C, + stride_grad_fc=1, + stride_grad_bias=1, + stride_grad_xm=n * C, + stride_grad_xCn=1, + precision="tf32" if ctx.use_tf32 else "ieee", + HAS_BIAS=bias is not None, + FUSE_GRAD_X_ACC=ctx.fused_grad_x_acc_buffer is not None, + ) + + # If no bias, replace the grad_bias placeholder with None if bias is None: - _mhc_expand_combine_bwd[grid]( - grad_output_ptr=grad_output, - f_ptr=f, - H_post_ptr=H_post, - x_ptr=x, - H_res_ptr=H_res, - grad_H_post_ptr=grad_H_post, - grad_f_ptr=grad_f, - grad_H_res_ptr=grad_H_res, - grad_x_ptr=grad_x, - M=M, - C=C, - n=n, - stride_grad_output_m=n * C, - stride_grad_output_Cn=1, - stride_fm=C, - stride_fc=1, - stride_xm=n * C, - stride_xCn=1, - stride_grad_fm=C, - stride_grad_fc=1, - stride_grad_xm=n * C, - stride_grad_xCn=1, - precision="tf32" if ctx.use_tf32 else "ieee", - ) - else: - _mhc_expand_combine_with_bias_bwd[grid]( - grad_output_ptr=grad_output, - f_ptr=f, - 
bias_ptr=bias, - H_post_ptr=H_post, - x_ptr=x, - H_res_ptr=H_res, - grad_H_post_ptr=grad_H_post, - grad_f_ptr=grad_f, - grad_bias_ptr=grad_bias, - grad_H_res_ptr=grad_H_res, - grad_x_ptr=grad_x, - M=M, - C=C, - n=n, - stride_grad_output_m=n * C, - stride_grad_output_Cn=1, - stride_fm=C, - stride_fc=1, - stride_bias=1, - stride_xm=n * C, - stride_xCn=1, - stride_grad_fm=C, - stride_grad_fc=1, - stride_grad_bias=1, - stride_grad_xm=n * C, - stride_grad_xCn=1, - precision="tf32" if ctx.use_tf32 else "ieee", - ) + grad_bias = None grad_H_post = grad_H_post.to(H_post.dtype) # Cast back to the original dtype of H_post grad_H_res = grad_H_res.to(H_res.dtype) # Cast back to the original dtype of H_res if bias is not None: grad_bias = grad_bias.to(bias.dtype) - return grad_f, grad_bias, grad_H_post, grad_x, grad_H_res, None, None + if ctx.fused_grad_x_acc_buffer is not None: + grad_x = None + + return grad_f, grad_bias, grad_H_post, grad_x, grad_H_res, None, None, None, None
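For context, a minimal end-to-end sketch (not part of this diff) of how the new fused wrapper and the optional fused_grad_x_acc_buffer could be used, based on the docstrings above. The shapes, the random parameters, and the stand-in torch.nn.Linear sub-layer are illustrative assumptions only:

```
# Usage sketch, assuming the module path shown in the diff header and n=4 streams.
import torch
from transformer_engine.pytorch.triton.mhc import (
    mhc_generate_mix_and_aggregate,
    mhc_fused_expand_combine,
)

s, b, C, n = 128, 2, 512, 4  # sequence, batch, per-stream hidden dim, hyper-connection streams
N = 2 * n + n * n            # 24 projection rows for n=4

# DeepSeek V4 recipe from the docstring: x in BF16, phi/alpha/beta in FP32.
x = torch.randn(s, b, C, n, device="cuda", dtype=torch.bfloat16, requires_grad=True)
phi = torch.randn(N, n * C, device="cuda", dtype=torch.float32, requires_grad=True)
alpha = torch.randn(3, device="cuda", dtype=torch.float32, requires_grad=True)
beta = torch.randn(1, N, device="cuda", dtype=torch.float32, requires_grad=True)
norm_weight = torch.randn(n * C, device="cuda", dtype=torch.bfloat16, requires_grad=True)

# Optional FP32 buffer, reused by the backward of projection / aggregate / expand_combine
# to accumulate grad_x in place; it must match x.numel().
grad_x_buffer = torch.zeros(x.numel(), device="cuda", dtype=torch.float32)

# Stand-in for the attention / FFN sub-layer (hypothetical; any (s, b, C) -> (s, b, C) module works).
layer = torch.nn.Linear(C, C, device="cuda", dtype=torch.bfloat16)

# Note: these fused ops use atomics; check_deterministic may require non-deterministic
# execution to be explicitly allowed before running.
layer_input, H_post, H_res = mhc_generate_mix_and_aggregate(
    x, phi, alpha, beta, norm_weight=norm_weight, fused_grad_x_acc_buffer=grad_x_buffer
)
layer_output = layer(layer_input)  # (s, b, C)
out = mhc_fused_expand_combine(
    layer_output, None, H_post, x, H_res, n, fused_grad_x_acc_buffer=grad_x_buffer
)  # (s, b, C, n), feeds the next block
out.sum().backward()
```

Reusing a single FP32 buffer lets the projection, aggregate, and expand_combine backward kernels accumulate grad_x in place instead of materializing and summing separate per-op gradients in autograd, which is the motivation for the fused_grad_x_acc_buffer argument described in the docstrings above.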