-
Notifications
You must be signed in to change notification settings - Fork 721
[Pyt][Common] Enabling/Guarding sm120 support (non - attention) #2833
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
8a9c3b8
f8827d9
8fa6d64
0098c65
58f9f10
56952cb
ca0f5a7
a626817
2e33d70
1d0c411
5f20fc0
18eb4b7
940f574
725d26b
a03146f
e1b582d
aa579ee
fb60b0b
beb8932
10c744d
c76a6ea
6876c03
8ab7d6e
4f11e0a
0b4100e
d77b5e6
fb23df3
6327875
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -30,6 +30,11 @@ | |
| warnings.filterwarnings("ignore", category=FutureWarning) | ||
| warnings.filterwarnings("ignore", category=UserWarning) | ||
|
|
||
| FP8_DEFAULT_RTOL_ATOL = (0.125, 0.0625) | ||
| FP8_CS_SM120_DETERMINISTIC_RTOL_ATOL = (0.4, 0.25) | ||
| BF16_DEFAULT_RTOL_ATOL = (0.025, 0.00125) | ||
| BF16_SM120_DETERMINISTIC_OVERLAP_RTOL_ATOL = (0.05, 0.01) | ||
|
|
||
|
|
||
| class multi_module_model(torch.nn.Module): | ||
| def __init__(self, module, num_layers, *args, **kwargs): | ||
|
|
@@ -551,9 +556,33 @@ def run_fwd_bwd(model, x): | |
|
|
||
| # Now validate accuracy | ||
| if not bool(numerics_failed.item()): | ||
| is_sm120 = torch.cuda.get_device_capability() == (12, 0) | ||
| is_deterministic_mode = os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1") == "0" | ||
| for i, (test_g, ref_g) in enumerate(zip(test_grads, ref_grads)): | ||
| rtol = 0.125 if opts.fp8 else 0.025 | ||
| atol = 0.0625 if opts.fp8 else 0.00125 | ||
| if opts.fp8: | ||
| if ( | ||
| opts.quantization == "fp8_current_scaling" | ||
| and is_sm120 | ||
| and is_deterministic_mode | ||
| ): | ||
| # SM120 deterministic mode disables fused attn, so rt uses alternate attn backends. | ||
| # Combined with FP8 CS, this path needs the looser distributed fp8_cs tolerance policy. | ||
|
Comment on lines
+563
to
+569
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If the discrepancy is due to changes in the attention backend, we should only relax the tols with |
||
| rtol, atol = FP8_CS_SM120_DETERMINISTIC_RTOL_ATOL | ||
| else: | ||
| rtol, atol = FP8_DEFAULT_RTOL_ATOL | ||
| else: | ||
| rtol, atol = BF16_DEFAULT_RTOL_ATOL | ||
| if ( | ||
| is_sm120 | ||
| and is_deterministic_mode | ||
| and opts.layer_type == te.TransformerLayer | ||
| and opts.num_layers > 1 | ||
| and opts.overlap_rs_dgrad | ||
| ): | ||
| # SM120 + deterministic training disables fused attn . | ||
| # Rt then selects an alternate attn backend, and | ||
| # the overlap path can show tiny BF16 accumulation-order drift vs reference. | ||
| rtol, atol = BF16_SM120_DETERMINISTIC_OVERLAP_RTOL_ATOL | ||
| grad_failed, grad_info = _compare_tensors(names[i], test_g, ref_g, rtol, atol) | ||
| dist_print(grad_info, src=WORLD_RANK, error=grad_failed) | ||
| numerics_failed[0] = int(grad_failed) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,11 +14,31 @@ | |
|
|
||
| recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) | ||
|
|
||
| SM120_SR_EQUIVALENCE_ATOL = 2e-7 | ||
|
|
||
| seed = 12345 | ||
| torch.manual_seed(seed) | ||
| torch.cuda.manual_seed(seed) | ||
|
|
||
|
|
||
| def _assert_sr_vs_rn_behavior( | ||
| me_sr: torch.Tensor, | ||
| me_rn: torch.Tensor, | ||
| me_t_sr: torch.Tensor, | ||
| me_t_rn: torch.Tensor, | ||
| ) -> None: | ||
| if torch.cuda.get_device_capability() == (12, 0): | ||
| # SM120 currently disables NVFP4 stochastic rounding in backend paths, | ||
| # so SR and RN should be numerically equivalent. | ||
|
Comment on lines
+31
to
+32
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: I'd expect a function called |
||
| torch.testing.assert_close(me_sr, me_rn, atol=SM120_SR_EQUIVALENCE_ATOL, rtol=0.0) | ||
| torch.testing.assert_close(me_t_sr, me_t_rn, atol=SM120_SR_EQUIVALENCE_ATOL, rtol=0.0) | ||
| else: | ||
| assert me_sr < me_rn, "Stochastic rounding failed - error larger than the round to nearest." | ||
| assert ( | ||
| me_t_sr < me_t_rn | ||
| ), "Stochastic rounding failed - error larger than the round to nearest." | ||
|
|
||
|
|
||
| def unpack_fp4(x: torch.Tensor) -> torch.Tensor: | ||
| repeated = x.repeat_interleave(2, dim=1) | ||
| repeated[:, 0::2] &= 0x0F | ||
|
|
@@ -247,7 +267,7 @@ def check_quantization_nvfp4_versus_reference( | |
| me_t_rn = torch.sqrt((error_t_rn * error_t_rn).mean()) | ||
| sr_result = torch.zeros_like(x).float() | ||
| sr_t_result = torch.zeros_like(x).float().t().contiguous() | ||
| for i in range(n_iters): | ||
| for _ in range(n_iters): | ||
| q_sr, s_sr, q_t_sr, s_t_sr = quantize_fp4( | ||
| x, use_stochastic_rounding=True, use_2D=use_2D, use_RHT=use_RHT | ||
| ) | ||
|
|
@@ -278,8 +298,7 @@ def check_quantization_nvfp4_versus_reference( | |
|
|
||
| print(f"RMSE SR: {me_sr:.3e} | RMSE RN: {me_rn:.3e}") | ||
| print(f"RMSE SR_t: {me_t_sr:.3e} | RMSE RN_t: {me_t_rn:.3e}") | ||
| assert me_sr < me_rn, "Stochastic rounding failed - error larger than the round to nearest." | ||
| assert me_t_sr < me_t_rn, "Stochastic rounding failed - error larger than the round to nearest." | ||
| _assert_sr_vs_rn_behavior(me_sr, me_rn, me_t_sr, me_t_rn) | ||
|
|
||
|
|
||
| def check_group_quantization_nvfp4_versus_reference( | ||
|
|
@@ -362,10 +381,7 @@ def check_group_quantization_nvfp4_versus_reference( | |
|
|
||
| print(f"RMSE SR: {me_sr:.3e} | RMSE RN: {me_rn:.3e}") | ||
| print(f"RMSE SR_t: {me_t_sr:.3e} | RMSE RN_t: {me_t_rn:.3e}") | ||
| assert me_sr < me_rn, "Stochastic rounding failed - error larger than the round to nearest." | ||
| assert ( | ||
| me_t_sr < me_t_rn | ||
| ), "Stochastic rounding failed - error larger than the round to nearest." | ||
| _assert_sr_vs_rn_behavior(me_sr, me_rn, me_t_sr, me_t_rn) | ||
|
|
||
|
|
||
| @pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems like a proper bug. If we run on SM 12.0, we want the test to fail rather than giving us a false pass.