Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
c8ab68a
Add a generic fused softmax (#1440)
yidong72 Aug 5, 2022
c1174a8
[transformer] unittest: less mem consumption for generic softmax (#1448)
crcrpar Aug 9, 2022
13c393e
Use `xmlrunner.XMLTestRunner` accordingly in `tests/L0/run_test.py` (…
crcrpar Aug 15, 2022
14ce259
Update mlp_cuda test (#1425)
crcrpar Aug 9, 2022
7c3f9d1
Use `xmlrunner.XMLTestRunner` accordingly in `tests/L0/run_test.py` (…
crcrpar Aug 15, 2022
519a038
Update `apex.mlp` to use fp16 in `autocast` (#1477)
crcrpar Sep 9, 2022
f4c4b86
Skip flaky test for ROCm
Dec 28, 2022
bb7b64d
Label smoothing in vocab parallel cross entropy (#1457)
MaximumEntropy Aug 24, 2022
3225191
introducing `APEX_RUN_WITH_SLOW_TESTS` env var (#1489)
crcrpar Sep 16, 2022
7c3cae3
[contrib][DistributedFusedAdam] Support overlapped grad sync with Meg…
timmoon10 Sep 20, 2022
23ef6ff
[transformer] Allow for skipping stream sync (#1505)
crcrpar Oct 12, 2022
127842e
Skip test_grad_scaler in test_dist_adam.py
Dec 29, 2022
589a5a6
Add run_rocm_extensions.sh for skipFlakyTest and skipIfRocm
Dec 29, 2022
4ca3c93
add tearDown (#1508)
Aidyn-A Oct 13, 2022
0b55a2f
Support overlapped grad sync with Megatron interleaved pipeline paral…
timmoon10 Oct 27, 2022
28de67e
update run_transformer tests to default to using pytorch native UCC i…
Fuzzkatt Oct 27, 2022
4a268de
update exception message (#1524)
crcrpar Oct 27, 2022
df833cb
Update megatron fused softmax follow megatron-lm (#1539)
yaoyu-33 Nov 22, 2022
5e5331a
Resolve filename collision issue in compilation on ROCm (Ref: #77)
Dec 29, 2022
f02cc05
Refactor run_transformer bert minimal and gpt minimal tests (#1540)
Fuzzkatt Dec 8, 2022
80a4954
check `is_ucc_available` (#1523)
crcrpar Oct 27, 2022
666e769
Support fused_weight_gradient_mlp_cuda for ROCm
Dec 30, 2022
52f036d
Fix bugs in run_rocm_extensions.py
Dec 30, 2022
483fb50
Move test_label_smoothing.py to xentropy folder
Dec 30, 2022
72f978c
Add some test folders for those ROCm does not support
Dec 30, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions apex/_autocast_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import torch


__all__ = ["_cast_if_autocast_enabled"]


def _get_autocast_dtypes() -> Sequence[torch.dtype]:
if torch.cuda.is_bf16_supported():
return [torch.half, torch.bfloat16]
Expand Down
236 changes: 126 additions & 110 deletions apex/contrib/optimizers/distributed_fused_adam.py

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions apex/contrib/test/optimizers/test_dist_adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from torch.testing._internal import common_utils
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
from apex.testing.common_utils import skipIfRocm

class SimpleModel(torch.nn.Module):

Expand Down Expand Up @@ -250,6 +251,7 @@ def test_clip_grad_norm(self):
dist_model.parameters()):
torch.testing.assert_close(dist_param, ref_param)

@skipIfRocm
def test_grad_scaler(self):

torch.manual_seed(self.seed + self.rank)
Expand Down
78 changes: 64 additions & 14 deletions apex/contrib/test/run_rocm_extensions.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,76 @@
import argparse
import os
import unittest
import sys


test_dirs = ["groupbn", "fused_dense", "layer_norm", "multihead_attn", "transducer", "focal_loss", "index_mul_2d", "."] # "." for test_label_smoothing.py
ROCM_BLACKLIST = [
"layer_norm"
#test_dirs = ["fused_dense", "layer_norm", "multihead_attn", "transducer", "focal_loss", "index_mul_2d", "optimizers", ".", "groupbn"] # "." for test_label_smoothing.py
#ROCM_BLACKLIST = [
# "layer_norm"
#]

TEST_ROOT = os.path.dirname(os.path.abspath(__file__))
TEST_DIRS = [
"fused_dense",
"layer_norm", # not fully supported on ROCm
"conv_bias_relu",# not fully supported on ROCm
"fmha", # not fully supported on ROCm
#"cudnn_gbn", # not fully supported on ROCm
#"bottleneck", # not fully supported on ROCm
"multihead_attn",
"transducer",
"focal_loss",
"index_mul_2d",
"optimizers",
"xentropy",
"clip_grad",
"groupbn",
]

DEFAULT_TEST_DIRS = [
"fused_dense",
"multihead_attn",
"transducer",
"focal_loss",
"index_mul_2d",
"optimizers",
"xentropy",
"clip_grad",
"groupbn",
]

runner = unittest.TextTestRunner(verbosity=2)
def parse_args():
parser = argparse.ArgumentParser(
description="Extension test runner",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--include",
nargs="+",
choices=TEST_DIRS,
default=DEFAULT_TEST_DIRS,
help="select a set of tests to run (defaults to ALL tests).",
)
args, _ = parser.parse_known_args()
return args

errcode = 0
def main(args):
runner = unittest.TextTestRunner(verbosity=2)
errcode = 0
for test_dir in args.include:
test_dir = os.path.join(TEST_ROOT, test_dir)
print(test_dir)
suite = unittest.TestLoader().discover(test_dir)

for test_dir in test_dirs:
if test_dir in ROCM_BLACKLIST:
continue
suite = unittest.TestLoader().discover(test_dir)
print("\nExecuting tests from " + test_dir)

print("\nExecuting tests from " + test_dir)
result = runner.run(suite)

result = runner.run(suite)
if not result.wasSuccessful():
errcode = 1

if not result.wasSuccessful():
errcode = 1
sys.exit(errcode)

sys.exit(errcode)
if __name__ == '__main__':
args = parse_args()
main(args)
2 changes: 2 additions & 0 deletions apex/contrib/test/run_rocm_extensions.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
#!/bin/bash
APEX_TEST_WITH_ROCM=1 APEX_SKIP_FLAKY_TEST=1 python3 run_rocm_extensions.py
19 changes: 13 additions & 6 deletions apex/mlp/mlp.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from copy import copy
import math

import torch
from torch import nn

from apex._autocast_utils import _cast_if_autocast_enabled
import mlp_cuda
from .. import amp


class MlpFunction(torch.autograd.Function):
@staticmethod
Expand All @@ -21,7 +24,11 @@ def backward(ctx, grad_o):
del ctx.outputs
return (None, None, *grads)

mlp_function = amp.half_function(MlpFunction.apply)

def mlp_function(bias, activation, *args):
autocast_args = _cast_if_autocast_enabled(bias, activation, *args)
return MlpFunction.apply(*autocast_args)


class MLP(torch.nn.Module):
"""Launch MLP in C++
Expand All @@ -32,16 +39,16 @@ class MLP(torch.nn.Module):
relu (bool): Default True
"""
def __init__(self, mlp_sizes, bias=True, activation='relu'):
super(MLP, self).__init__()
super().__init__()
self.num_layers = len(mlp_sizes) - 1
self.mlp_sizes = copy(mlp_sizes)
self.bias = 1 if bias else 0

if activation is 'none':
if activation == 'none':
self.activation = 0
elif activation is 'relu':
elif activation == 'relu':
self.activation = 1
elif activation is 'sigmoid':
elif activation == 'sigmoid':
self.activation = 2
else:
raise TypeError("activation must be relu or none.")
Expand Down
9 changes: 9 additions & 0 deletions apex/transformer/_ucc_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from torch import distributed as dist

HAS_UCC = hasattr(dist, "is_ucc_available") and dist.is_ucc_available()
if not HAS_UCC:
try:
import torch_ucc
HAS_UCC = True
except ImportError:
HAS_UCC = False
98 changes: 94 additions & 4 deletions apex/transformer/functional/fused_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,73 @@ def backward(ctx, output_grads):


def scaled_masked_softmax(inputs, mask, scale):
# input is 4D tensor (b, np, sq, sk)
if mask is not None:
args = _cast_if_autocast_enabled(inputs, mask, scale)
with torch.cuda.amp.autocast(enabled=False):
return ScaledMaskedSoftmax.apply(*args)
else:
args = _cast_if_autocast_enabled(inputs, scale)
with torch.cuda.amp.autocast(enabled=False):
return ScaledSoftmax.apply(*args)


class GenericScaledMaskedSoftmax(torch.autograd.Function):
@staticmethod
def forward(ctx, inputs, mask, scale):
import generic_scaled_masked_softmax_cuda

scale_t = torch.tensor([scale])
softmax_results = generic_scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results

@staticmethod
def backward(ctx, output_grads):
import generic_scaled_masked_softmax_cuda_new

softmax_results, scale_t = ctx.saved_tensors

input_grads = generic_scaled_masked_softmax_cuda.backward(output_grads, softmax_results, scale_t[0])
return input_grads, None, None


def generic_scaled_masked_softmax(inputs, mask, scale):
# input is 4D tensor (b, np, sq, sk)
args = _cast_if_autocast_enabled(inputs, mask, scale)
with torch.cuda.amp.autocast(enabled=False):
return ScaledMaskedSoftmax.apply(*args)
return GenericScaledMaskedSoftmax.apply(*args)


class ScaledSoftmax(torch.autograd.Function):
"""
Fused operation which performs following two operations in sequence
1. Scale the tensor.
2. Perform softmax.
"""

@staticmethod
def forward(ctx, inputs, scale):
import scaled_softmax_cuda

scale_t = torch.tensor([scale])

softmax_results = scaled_softmax_cuda.forward(
inputs, scale_t[0]
)
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results

@staticmethod
def backward(ctx, output_grads):
import scaled_softmax_cuda

softmax_results, scale_t = ctx.saved_tensors

input_grads = scaled_softmax_cuda.backward(
output_grads, softmax_results, scale_t[0]
)
return input_grads, None, None


class FusedScaleMaskSoftmax(torch.nn.Module):
Expand Down Expand Up @@ -164,14 +227,14 @@ def is_kernel_available(self, mask, b, np, sq, sk):
and self.input_in_float16 # input must be fp16
and (
self.attn_mask_type == AttnMaskType.causal
or (self.attn_mask_type == AttnMaskType.padding and mask is not None)
or self.attn_mask_type == AttnMaskType.padding
)
and 16 < sk <= 2048 # sk must be 16 ~ 2048
and 16 < sk <= 4096 # sk must be 16 ~ 4096
and sq % 4 == 0 # sq must be divisor of 4
and sk % 4 == 0 # sk must be divisor of 4
and attn_batches % 4 == 0 # np * b must be divisor of 4
):
if 0 <= sk <= 2048:
if 0 <= sk <= 4096:
batch_per_block = self.get_batch_per_block(sq, sk, b, np)

if self.attn_mask_type == AttnMaskType.causal:
Expand Down Expand Up @@ -209,3 +272,30 @@ def get_batch_per_block(sq, sk, b, np):
import scaled_masked_softmax_cuda

return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np)

class GenericFusedScaleMaskSoftmax(FusedScaleMaskSoftmax):
"""
Generic version of FusedSacleMaskSoftmax.
It removes the seq-len limitations and has slight performance degragation compared with FusedScaleMaskSoftmax

fused operation: scaling + mask + softmax

Arguments:
input_in_fp16: flag to indicate if input in fp16 data format.
input_in_bf16: flag to indicate if input in bf16 data format.
scaled_masked_softmax_fusion: flag to indicate user want to use softmax fusion
mask_func: mask function to be applied.
softmax_in_fp32: if true, softmax in performed at fp32 precision.
scale: scaling factor used in input tensor scaling.
"""

def __init__(
self, input_in_fp16, input_in_bf16, scaled_masked_softmax_fusion, mask_func, softmax_in_fp32, scale,
):
super().__init__(input_in_fp16, input_in_bf16, AttnMaskType.padding, scaled_masked_softmax_fusion, mask_func, softmax_in_fp32, scale)
self.scaled_masked_softmax_fusion = generic_scaled_masked_softmax

def is_kernel_available(self, mask, b, np, sq, sk):
if self.scaled_masked_softmax_fusion and 0 < sk: # user want to fuse # sk must be 1 ~
return True
return False
13 changes: 3 additions & 10 deletions apex/transformer/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import torch

from apex.transformer.log_util import get_transformer_logger
from apex.transformer._ucc_util import HAS_UCC


_logger = get_transformer_logger(__name__)
Expand Down Expand Up @@ -126,7 +127,8 @@ def initialize_model_parallel(
assert default_backend is None or default_backend in ("nccl", "ucc")
assert p2p_backend is None or p2p_backend in ("nccl", "ucc")
if "ucc" in (default_backend, p2p_backend):
check_torch_ucc_availability()
if not HAS_UCC:
raise ImportError("UCC backend requires pytorch source build with UCC installed and enabled")
warnings.warn("`ucc` backend support is experimental", ExperimentalWarning)
if default_backend == "ucc":
warnings.warn("The UCC's functionality as `default_backend` is not well verified", ExperimentalWarning)
Expand Down Expand Up @@ -671,12 +673,3 @@ def destroy_model_parallel():

# Used to warn when the UCC is specified.
class ExperimentalWarning(Warning): pass


def check_torch_ucc_availability() -> None:
try:
import torch_ucc # NOQA
except ImportError:
raise ImportError(
"UCC backend requires [torch_ucc](https://github.com/facebookresearch/torch_ucc) but not found"
)
Loading