200 changes: 153 additions & 47 deletions tests/pytorch/test_mhc.py
@@ -14,13 +14,14 @@
mhc_fused_aggregate,
mhc_fused_expand_combine,
mhc_fused_projection,
mhc_generate_mix_and_aggregate,
)

# Disable TF32 for matmul to ensure consistency between the fused and reference implementations
torch.backends.cuda.matmul.allow_tf32 = False


def mhc_projection_ref(x, phi):
def mhc_projection_ref(x, phi, norm_weight):
"""
Reference operator for mHC's projection building operation.

@@ -29,19 +30,20 @@ def mhc_projection_ref(x, phi):
- phi_pre: (n, nC)
- phi_post: (n, nC)
- phi_res: (n^2, nC)
norm_weight: (nC,) or None; if not None, phi is multiplied element-wise by norm_weight before the projection
n: number of Hyper Connection streams
C: hidden dimension per stream
"""
x_dtype = x.dtype
x = x.to(torch.float32)
phi = phi.to(torch.float32)

Hs = x @ phi.T # (M, 2n + n^2)

x_fp32 = x.to(torch.float32) # Use fp32 for better numerical stability in variance calculation
x_fp32 = x.to(torch.float32)
ms = (x_fp32 * x_fp32).mean(dim=1)

return Hs.to(x_dtype), ms
phi_fp32 = phi.to(torch.float32)
if norm_weight is not None:
phi_fp32 = phi_fp32 * norm_weight.to(torch.float32)[None, :]
Hs = x_fp32 @ phi_fp32.T # (M, 2n + n^2)

return Hs, ms


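As a quick orientation for the shapes involved, here is a minimal sketch that mirrors what mhc_projection_ref above computes (the tiny sizes are illustrative, not taken from mhc_configs, and the fused kernel is not invoked):

import torch

n, C, M = 2, 3, 4                      # 2 streams, 3 channels per stream, 4 tokens
nC, N = n * C, 2 * n + n * n           # N = 2n + n^2 projection rows
x = torch.randn(M, nC)
phi = torch.randn(N, nC)
w = torch.randn(nC)                    # optional per-channel weight folded into phi

Hs = x @ (phi * w).T                   # (M, N): pre (n), post (n), res (n^2) logits per token
ms = (x * x).mean(dim=1)               # (M,): per-token mean square, later consumed by the scaling step
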
def mhc_scale_ref(H, alpha, beta, ms, n):
@@ -139,9 +141,9 @@ def mhc_aggregate_ref(x, H_pre, n):
s, b, C, n = x.shape
H_pre = H_pre.view(s, b, n, 1)

out = (x @ H_pre).view(s, b, C)
out = (x.to(H_pre.dtype) @ H_pre).view(s, b, C)

return out
return out.to(x.dtype)


def mhc_expand_combine_ref(f, bias, H_post, x, H_res, n):
@@ -267,25 +269,44 @@ def get_tols(dtype):


@pytest.mark.parametrize("cfg", mhc_configs, ids=MHCConfig.desc)
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=["fp32", "bf16"])
def test_mhc_projection(cfg: MHCConfig, dtype):
@pytest.mark.parametrize(
"dtypes",
[
(torch.float32, torch.float32),
(torch.bfloat16, torch.bfloat16),
(torch.bfloat16, torch.float32),
],
ids=["x_fp32_phi_fp32", "x_bf16_phi_bf16", "x_bf16_phi_fp32"],
)
@pytest.mark.parametrize("has_norm_weight", [False, True], ids=["no_norm_weight", "norm_weight"])
def test_mhc_projection(cfg: MHCConfig, dtypes, has_norm_weight):
reset_rng_states()

s, b, C, n = cfg.s, cfg.b, cfg.C, cfg.n
nC = n * C
N = 2 * n + n * n

tols = get_tols(dtype)
x_dtype = dtypes[0]
phi_dtype = dtypes[1]
tols = get_tols(x_dtype)
use_tf32 = False

x = torch.randn(s * b, nC, device="cuda", requires_grad=True, dtype=dtype)
phi = torch.randn(N, nC, dtype=dtype, requires_grad=True, device="cuda")

x = torch.randn(s * b, nC, device="cuda", requires_grad=True, dtype=x_dtype)
phi = torch.randn(N, nC, dtype=phi_dtype, requires_grad=True, device="cuda")
x_ref = x.detach().clone().requires_grad_(True)
phi_ref = phi.detach().clone().requires_grad_(True)

ref_out_Hs, ref_out_ms = mhc_projection_ref(x_ref, phi_ref)
fused_out_Hs_padded, fused_out_ms = mhc_fused_projection(x, phi, use_tf32)
if has_norm_weight:
norm_weight = torch.randn(nC, device="cuda", requires_grad=True, dtype=x_dtype)
norm_weight_ref = norm_weight.detach().clone().requires_grad_(True)
else:
norm_weight = None
norm_weight_ref = None

ref_out_Hs, ref_out_ms = mhc_projection_ref(x_ref, phi_ref, norm_weight_ref)
fused_out_Hs_padded, fused_out_ms = mhc_fused_projection(
x, phi, norm_weight=norm_weight, use_tf32=use_tf32
)
fused_out_Hs = fused_out_Hs_padded[:, :N]

torch.testing.assert_close(fused_out_Hs, ref_out_Hs, **tols)
@@ -295,10 +316,12 @@ def test_mhc_projection(cfg: MHCConfig, dtype):

torch.testing.assert_close(x.grad, x_ref.grad, **tols)
torch.testing.assert_close(phi.grad, phi_ref.grad, **tols)
if has_norm_weight:
torch.testing.assert_close(norm_weight.grad, norm_weight_ref.grad, **tols)


@pytest.mark.parametrize("cfg", mhc_configs, ids=MHCConfig.desc)
@pytest.mark.parametrize("dtype", [torch.float32], ids=["fp32"])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=["fp32", "bf16"])
def test_mhc_scale(cfg: MHCConfig, dtype):
reset_rng_states()

@@ -329,37 +352,57 @@ def test_mhc_scale(cfg: MHCConfig, dtype):
torch.cat([fused_out[i] for i in range(3)], dim=-1).sum().backward()

torch.testing.assert_close(H_padded.grad[:, :N], H_ref.grad, **tols)
torch.testing.assert_close(ms.grad, ms_ref.grad, **tols)
torch.testing.assert_close(alpha.grad, alpha_ref.grad, **tols)
torch.testing.assert_close(beta.grad, beta_ref.grad, **tols)
torch.testing.assert_close(ms.grad, ms_ref.grad, **tols)


@pytest.mark.parametrize("cfg", mhc_configs, ids=MHCConfig.desc)
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=["fp32", "bf16"])
def test_mhc_combined(cfg: MHCConfig, dtype):
@pytest.mark.parametrize(
"dtypes",
[
(torch.float32, torch.float32),
(torch.bfloat16, torch.bfloat16),
(torch.bfloat16, torch.float32),
],
ids=["x_fp32_phi_fp32", "x_bf16_phi_bf16", "x_bf16_phi_fp32"],
)
@pytest.mark.parametrize("has_norm_weight", [False, True], ids=["no_norm_weight", "norm_weight"])
def test_mhc_rmsnorm(cfg: MHCConfig, dtypes, has_norm_weight):
# Verify that the fused kernel is equivalent to applying RMSNorm in the standard order (normalize, then project)
reset_rng_states()

s, b, C, n = cfg.s, cfg.b, cfg.C, cfg.n
N = 2 * n + n * n
nC = n * C

tols = get_tols(dtype)
x_dtype = dtypes[0]
phi_dtype = dtypes[1]
tols = get_tols(x_dtype)
use_tf32 = False

x = torch.randn(s * b, nC, device="cuda", requires_grad=True, dtype=dtype)
phi = torch.randn(N, nC, dtype=dtype, requires_grad=True, device="cuda")

alpha = torch.randn(3, device="cuda", requires_grad=True, dtype=dtype)
beta = torch.randn(1, 2 * n + n * n, device="cuda", requires_grad=True, dtype=dtype)
x = torch.randn(s * b, nC, device="cuda", requires_grad=True, dtype=x_dtype)
phi = torch.randn(N, nC, dtype=phi_dtype, requires_grad=True, device="cuda")
alpha = torch.randn(3, device="cuda", requires_grad=True, dtype=phi_dtype)
beta = torch.randn(1, 2 * n + n * n, device="cuda", requires_grad=True, dtype=phi_dtype)

x_ref = x.detach().clone().requires_grad_(True)
phi_ref = phi.detach().clone().requires_grad_(True)

alpha_ref = alpha.detach().clone().requires_grad_(True)
beta_ref = beta.detach().clone().requires_grad_(True)

ref_out_H, ref_out_r = mhc_projection_ref(x_ref, phi_ref)
fused_out_H_padded, fused_out_r = mhc_fused_projection(x, phi, use_tf32)
if has_norm_weight:
norm_weight = torch.randn(nC, device="cuda", requires_grad=True, dtype=x_dtype)
norm_weight_ref = norm_weight.detach().clone().requires_grad_(True)
else:
norm_weight = None
norm_weight_ref = None

ref_out_H, ref_out_r = mhc_projection_ref(x_ref, phi_ref, norm_weight_ref)
fused_out_H_padded, fused_out_r = mhc_fused_projection(
x, phi, norm_weight=norm_weight, use_tf32=use_tf32
)

ref_H_pre, ref_H_post, ref_H_res = mhc_scale_ref(
ref_out_H[:, :N], alpha_ref, beta_ref, ref_out_r, n
@@ -368,17 +411,19 @@ def test_mhc_combined(cfg: MHCConfig, dtype):
fused_out_H_padded, alpha, beta, fused_out_r, n
)

def mhc_combined(x_ref, phi_ref, alpha_ref, beta_ref):
dtype = x_ref.dtype
x_ref = x_ref.to(torch.float32)
phi_ref = phi_ref.to(torch.float32)
alpha_ref = alpha_ref.to(torch.float32)
beta_ref = beta_ref.to(torch.float32)

def mhc_combined(x_ref, phi_ref, alpha_ref, beta_ref, norm_weight_ref):
# Check that, after splitting RMSNorm into two steps across projection and scaling,
# theresult is close to applying RMSNorm in the correct order
x_rmsnorm = F.rms_norm(x_ref, normalized_shape=(nC,))
H = x_rmsnorm @ phi_ref.T
# the result is close to applying RMSNorm in the correct order.
# Run RMSNorm in fp32 so the bf16 case has the same precision pattern as the
# kernel/ref (F.rms_norm on bf16 input would round x_rmsnorm back to bf16).
eps = torch.finfo(torch.float32).eps
norm_weight_fp32 = (
norm_weight_ref.to(torch.float32) if norm_weight_ref is not None else None
)
x_rmsnorm = F.rms_norm(
x_ref.to(torch.float32), normalized_shape=(nC,), weight=norm_weight_fp32, eps=eps
)
H = x_rmsnorm @ phi_ref.T.to(torch.float32)
H_pre = H[:, :n]
H_post = H[:, n : 2 * n]
H_res = H[:, 2 * n :]
@@ -391,25 +436,86 @@ def mhc_combined(x_ref, phi_ref, alpha_ref, beta_ref):
out_post = 2 * out_post.sigmoid()
out_res = out_res

return out_pre.to(dtype), out_post.to(dtype), out_res.to(dtype)
return out_pre, out_post, out_res # Return in FP32 to match the kernel's behavior

combined_H_pre, combined_H_post, combined_H_res = mhc_combined(
x_ref, phi_ref, alpha_ref, beta_ref
x_ref, phi_ref, alpha_ref, beta_ref, norm_weight_ref
)

torch.testing.assert_close(combined_H_pre, ref_H_pre, **tols)
torch.testing.assert_close(combined_H_post, ref_H_post, **tols)
torch.testing.assert_close(combined_H_res, ref_H_res, **tols)

torch.testing.assert_close(ref_H_pre, fused_H_pre, **tols)
torch.testing.assert_close(ref_H_post, fused_H_post, **tols)
torch.testing.assert_close(ref_H_res, fused_H_res, **tols)

torch.testing.assert_close(combined_H_pre, fused_H_pre, **tols)
torch.testing.assert_close(combined_H_post, fused_H_post, **tols)
torch.testing.assert_close(combined_H_res, fused_H_res, **tols)

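The split that test_mhc_rmsnorm exercises relies on a simple algebraic identity: the per-channel RMSNorm weight can be folded into phi, and the 1/rms factor deferred to the scaling step. A minimal self-contained check under that assumption (sizes are illustrative; exactly where eps enters mhc_scale_ref is an implementation detail, here it is placed under the square root):

import torch
import torch.nn.functional as F

M, nC, N = 4, 8, 6
x = torch.randn(M, nC, dtype=torch.float32)
phi = torch.randn(N, nC, dtype=torch.float32)
w = torch.randn(nC, dtype=torch.float32)
eps = torch.finfo(torch.float32).eps

# Standard order: normalize first, then project.
lhs = F.rms_norm(x, normalized_shape=(nC,), weight=w, eps=eps) @ phi.T

# Split order: project with the weight folded into phi, divide by rms afterwards.
ms = (x * x).mean(dim=1, keepdim=True)
rhs = (x @ (phi * w).T) / torch.sqrt(ms + eps)

torch.testing.assert_close(lhs, rhs, rtol=1e-5, atol=1e-5)
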

@pytest.mark.parametrize("cfg", mhc_configs, ids=MHCConfig.desc)
@pytest.mark.parametrize("dtype", [torch.float32], ids=["fp32"])
def test_mhc_fuse_grad_acc(cfg: MHCConfig, dtype):
# Skip bf16 tests since in the unfused path we accumulate 3 bf16 gradients, whereas in the fused path
# we accumulate 3 fp32 gradients and cast to bf16 at the end, which gives the two paths different precision patterns

reset_rng_states()

s, b, C, n = cfg.s, cfg.b, cfg.C, cfg.n
N = 2 * n + n * n
nC = n * C

tols = get_tols(dtype)
use_tf32 = False

x = torch.randn(s, b, C, n, device="cuda", requires_grad=True, dtype=dtype)
phi = torch.randn(N, nC, dtype=dtype, requires_grad=True, device="cuda")

alpha = torch.randn(3, device="cuda", requires_grad=True, dtype=dtype)
beta = torch.randn(1, 2 * n + n * n, device="cuda", requires_grad=True, dtype=dtype)
x_ref = x.detach().clone().requires_grad_(True)
phi_ref = phi.detach().clone().requires_grad_(True)

alpha_ref = alpha.detach().clone().requires_grad_(True)
beta_ref = beta.detach().clone().requires_grad_(True)

def end_to_end(x, phi, alpha, beta, fused_grad_x_acc):
fused_grad_x_acc_buffer = None
if fused_grad_x_acc:
fused_grad_x_acc_buffer = torch.empty_like(x, dtype=torch.float32)
aggregated, H_post, H_res = mhc_generate_mix_and_aggregate(
x, phi, alpha, beta, None, use_tf32, fused_grad_x_acc_buffer
)
expanded_combined = mhc_fused_expand_combine(
aggregated,
None,
H_post,
x,
H_res,
n,
False,
fused_grad_x_acc_buffer,
)

return expanded_combined

expanded_combined_fuse_grad = end_to_end(
x_ref, phi_ref, alpha_ref, beta_ref, fused_grad_x_acc=True
)
expanded_combined_no_fuse_grad = end_to_end(x, phi, alpha, beta, fused_grad_x_acc=False)

grad_output = torch.randn_like(expanded_combined_fuse_grad)
expanded_combined_fuse_grad.backward(grad_output)
expanded_combined_no_fuse_grad.backward(grad_output)

torch.testing.assert_close(x.grad, x_ref.grad, **tols)

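The bf16 restriction noted at the top of this test comes down to rounding order: summing three bf16 gradients rounds after every addition, while the fused fp32 buffer rounds only once when it is finally cast. A standalone illustration of that difference (the tensors here are arbitrary and not tied to the mHC kernels):

import torch

grads = [torch.randn(1024, dtype=torch.float32) for _ in range(3)]

# Unfused path: each partial gradient is already bf16, so every add rounds.
acc_bf16 = grads[0].to(torch.bfloat16)
for g in grads[1:]:
    acc_bf16 = acc_bf16 + g.to(torch.bfloat16)

# Fused path: accumulate in fp32, cast to bf16 once at the end.
acc_fused = (grads[0] + grads[1] + grads[2]).to(torch.bfloat16)

# The two results generally differ by a few ULPs, hence the fp32-only parametrization.
print((acc_bf16.float() - acc_fused.float()).abs().max())
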

@pytest.mark.parametrize("cfg", mhc_configs, ids=MHCConfig.desc)
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=["fp32", "bf16"])
@pytest.mark.parametrize("recompute", [False, True], ids=["no_recompute", "recompute"])
def test_mhc_sinkhorn(cfg: MHCConfig, dtype, recompute):
def test_mhc_sinkhorn(cfg: MHCConfig, dtype):
reset_rng_states()

s, b, C, n = cfg.s, cfg.b, cfg.C, cfg.n
@@ -420,7 +526,7 @@ def test_mhc_sinkhorn(cfg: MHCConfig, dtype, recompute):
x_ref = x.detach().clone().requires_grad_(True)

ref_out = mhc_sinkhorn_ref(x_ref, n)
fused_out = mhc_fused_sinkhorn(x, n, recompute)
fused_out = mhc_fused_sinkhorn(x, n)

torch.testing.assert_close(fused_out, ref_out, **tols)

Expand All @@ -446,7 +552,7 @@ def test_mhc_aggregate(cfg: MHCConfig, dtype):
H_pre_ref = H_pre.detach().clone().requires_grad_(True)

ref_out = mhc_aggregate_ref(x_ref, H_pre_ref, n)
fused_out = mhc_fused_aggregate(x, H_pre, n, False)
fused_out = mhc_fused_aggregate(x, H_pre, n, use_tf32=False)

torch.testing.assert_close(fused_out, ref_out, **tols)

Expand Down Expand Up @@ -482,7 +588,7 @@ def test_mhc_expand_combine(cfg: MHCConfig, dtype, with_bias):
H_res_ref = H_res.detach().clone().requires_grad_(True)

ref_out = mhc_expand_combine_ref(f_ref, bias_ref, H_post_ref, x_ref, H_res_ref, n)
fused_out = mhc_fused_expand_combine(f, bias, H_post, x, H_res, n, False)
fused_out = mhc_fused_expand_combine(f, bias, H_post, x, H_res, n=n, use_tf32=False)

torch.testing.assert_close(fused_out, ref_out, **tols)
