Changes from all commits
31 commits
3e305f7
[PyTorch] Debug weight matrix usages for dgrad GEMM (#1637)
timmoon10 Apr 4, 2025
1bbeab1
Blockwise float8 quantizer and quantized tensor class (#1513)
kwyss-nvidia Apr 4, 2025
be1f647
[JAX-Q] Distributed MXFP8 flax layer tests (#1643)
jberchtold-nvidia Apr 4, 2025
fbcbcb0
Add GEMM logic for blockwise quantized tensors.
kwyss-nvidia Feb 28, 2025
522ffbe
Update NVTE_BLOCK_SCALING for GEMM.
kwyss-nvidia Mar 10, 2025
d7e1fce
Gate feature on CUDA 12.9
kwyss-nvidia Mar 6, 2025
f212c81
Gemm typo.
kwyss-nvidia Mar 11, 2025
48b2d57
Remove unnecessary type converter change.
kwyss-nvidia Mar 11, 2025
5761589
Reflect epilogue availability and test supported epilogues.
kwyss-nvidia Mar 11, 2025
07b19b7
GEMM simplifications from recipe branch.
kwyss-nvidia Mar 12, 2025
c4a41b8
Format py code.
kwyss-nvidia Mar 15, 2025
51ed2fb
Update GEMM DGelu tests to match support depending on output dtype.
kwyss-nvidia Apr 1, 2025
e7af140
Force pow2Scales in GEMM
kwyss-nvidia Apr 2, 2025
596a009
Add GEMM test to pytorch test suite.
kwyss-nvidia Apr 2, 2025
4aa6067
Add copyright to GEMM test.
kwyss-nvidia Apr 2, 2025
758dc4a
Update import for GEMM test.
kwyss-nvidia Apr 4, 2025
7d5b5d9
Add license.
kwyss-nvidia Apr 4, 2025
ff884e2
[JAX] Flatten_axis for quantization and Sharding propagation fixes (#…
phu0ngng Apr 4, 2025
efdf8e0
Update test gemm supported predicate.
kwyss-nvidia Apr 4, 2025
a9f209a
Use sgemm like interfaces and naming.
kwyss-nvidia Apr 5, 2025
861c870
Rewrite GEMM comment.
kwyss-nvidia Apr 5, 2025
ada6438
MR Feedback.
kwyss-nvidia Apr 5, 2025
e484269
Refactor GEMM param canonicalization
timmoon10 Apr 6, 2025
9f0707e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 6, 2025
cf36b99
Merge branch 'main' into kwyss/cublas_gemm_github_mr
timmoon10 Apr 6, 2025
ba605f1
[PyTorch][Common] Refactor RoPE (#1626)
yaox12 Apr 7, 2025
a3ba4df
Fix cpp warnings (#1639)
yaox12 Apr 7, 2025
c84d170
Support FP8 primary weight in FSDP training (#1630)
shjwudp Apr 7, 2025
b362a6e
Removing NVTE_NO_SCALING (#1650)
phu0ngng Apr 7, 2025
f3123cf
Prune number of tests.
kwyss-nvidia Apr 7, 2025
4e4c59e
Merge branch 'main' into kwyss/cublas_gemm_github_mr
timmoon10 Apr 7, 2025
15 changes: 8 additions & 7 deletions examples/jax/encoder/test_model_parallel_encoder.py
@@ -57,13 +57,14 @@ def __call__(self, x, mask, disable_dropout=False):
             self_attn_mask_type="padding",
             enable_relative_embedding=False,
             enable_sequence_parallel=self.enable_seq_paral,
+            mlp_activations=("gelu", "linear"),
         )
         x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)

         x = x.reshape(x.shape[0], -1)

         if self.enable_seq_paral:
-            # Trigger all-gather to collect a complete tensor alone seqence on each device.
+            # Trigger all-gather to collect a complete tensor alone sequence on each device.
             x = jax.lax.with_sharding_constraint(
                 x, jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None)
             )
@@ -459,30 +460,30 @@ def setUpClass(cls):
     def test_te_bf16(self):
         """Test Transformer Engine with BF16"""
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.50 and actual[1] > 0.76
+        assert actual[0] < 0.455 and actual[1] > 0.785

     @unittest.skipIf(not is_fp8_supported, fp8_reason)
     def test_te_delayed_scaling_fp8(self):
         """Test Transformer Engine with DelayedScaling FP8"""
         self.args.use_fp8 = True
         self.args.fp8_recipe = "DelayedScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.50 and actual[1] > 0.76
+        assert actual[0] < 0.455 and actual[1] > 0.785

     @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
     def test_te_mxfp8(self):
         """Test Transformer Engine with MXFP8"""
         self.args.use_fp8 = True
         self.args.fp8_recipe = "MXFP8BlockScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.50 and actual[1] > 0.76
+        assert actual[0] < 0.455 and actual[1] > 0.785

     @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
     def test_te_bf16_with_sp(self):
         """Test Transformer Engine with BF16 + SP"""
         self.args.enable_sp = True
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.50 and actual[1] > 0.76
+        assert actual[0] < 0.455 and actual[1] > 0.785

     @unittest.skipIf(not is_fp8_supported, fp8_reason)
     def test_te_delayed_scaling_fp8_with_sp(self):
@@ -491,7 +492,7 @@ def test_te_delayed_scaling_fp8_with_sp(self):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "DelayedScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.50 and actual[1] > 0.76
+        assert actual[0] < 0.455 and actual[1] > 0.785

     @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
     def test_te_mxfp8_with_sp(self):
@@ -500,7 +501,7 @@ def test_te_mxfp8_with_sp(self):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "MXFP8BlockScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.50 and actual[1] > 0.76
+        assert actual[0] < 0.455 and actual[1] > 0.785


 if __name__ == "__main__":
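A note on the sequence-parallel branch in the hunk above: jax.lax.with_sharding_constraint does not move data by itself; it pins an intermediate to a sharding, and XLA inserts whatever collective (here, an all-gather over the sequence dimension) is needed to satisfy it. A standalone sketch of the pattern, where the one-axis mesh and the "data" axis name stand in for the test's DEVICE_DP_AXIS and are assumptions of this sketch:

    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    # One-axis device mesh; "data" plays the role of DEVICE_DP_AXIS above.
    mesh = Mesh(jax.devices(), axis_names=("data",))
    spec = NamedSharding(mesh, PartitionSpec("data", None))

    @jax.jit
    def gather_seq(x):
        # Shard only the leading (batch) dim and replicate the rest: any
        # sequence sharding produced upstream must be all-gathered here.
        return jax.lax.with_sharding_constraint(x, spec)

    y = gather_seq(jnp.ones((8, 1024)))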
3 changes: 3 additions & 0 deletions qa/L0_pytorch_unittest/test.sh
@@ -30,6 +30,9 @@ PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py"
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py"
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py"
+python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py"
+python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"
+python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py"
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py"
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py"
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py"
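The three new entries cover the blockwise-FP8 stack end to end: the quantized tensor class, the scaling kernels, and the GEMM, with the *_exact suites presumably asserting exact numerics. Illustrative of that style of check (a pytest-style sketch, not the actual test code): values that are already FP8-representable, scaled by a power of two, must survive a quantize/dequantize round trip losslessly.

    # Illustrative sketch, not the TE test: exact FP8 values with pow2 scales
    # round-trip bit-exactly through quantize/dequantize.
    import torch

    def test_roundtrip_exact():
        fp8_vals = torch.tensor([0.5, 1.0, 1.5, 448.0])
        x = fp8_vals * 2.0**-3                    # pow2 scaling preserves exactness
        q = (x * 2.0**3).to(torch.float8_e4m3fn)  # quantize with pow2 scale
        deq = q.float() * 2.0**-3                 # dequantize
        assert torch.equal(deq, x)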
1 change: 1 addition & 0 deletions tests/cpp/operator/CMakeLists.txt
@@ -11,6 +11,7 @@ add_executable(test_operator
     test_cast_mxfp8_gated_swiglu.cu
     test_qdq.cu
     test_cast_mxfp8.cu
+    test_cast_float8blockwise.cu
     test_dequantize_mxfp8.cu
     test_transpose.cu
     test_cast_transpose.cu