diff --git a/.github/workflows/rocm-ci.yml b/.github/workflows/rocm-ci.yml new file mode 100644 index 000000000..b5aa06faf --- /dev/null +++ b/.github/workflows/rocm-ci.yml @@ -0,0 +1,202 @@ +name: Apex ROCm CI + +on: + pull_request: + types: [opened, synchronize, ready_for_review] + branches: + - master + - release/1.8.0 + - release/1.9.0 + - release/1.10.0 + workflow_dispatch: + inputs: + apex_gitref: + description: 'Apex branch or commit SHA to build' + required: false + default: 'master' + type: string + docker_image: + description: 'Docker image to use' + required: false + default: 'rocm/pytorch:latest' + type: string + run_extension: + description: 'Run Extension Import tests' + required: false + default: true + type: boolean + run_l0: + description: 'Run L0 tests' + required: false + default: true + type: boolean + run_contrib: + description: 'Run Contrib tests' + required: false + default: true + type: boolean + run_halo: + description: 'Run Peer Halo Exchange tests' + required: false + default: true + type: boolean + run_syncbn: + description: 'Run Distributed Synced BatchNorm tests' + required: false + default: true + type: boolean + +env: + DOCKER_IMAGE: ${{ inputs.docker_image || 'rocm/pytorch:latest' }} + +jobs: + build: + name: Build Apex Wheel + runs-on: build-only-apex + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + # Uses the specified branch on manual runs; defaults to the PR/Push context otherwise + ref: ${{ github.event_name == 'workflow_dispatch' && inputs.apex_gitref || '' }} + submodules: recursive + + - name: Pull Docker Image + run: | + docker pull ${{ env.DOCKER_IMAGE }} + + - name: Start Background Docker Container + run: | + docker run -d --name apex-build-container \ + -v ${{ github.workspace }}:/workspace -w /workspace \ + ${{ env.DOCKER_IMAGE }} sleep infinity + + - name: Build Apex Wheel + run: | + docker exec apex-build-container bash -c " + pip install --upgrade pip + pip install build ninja wheel packaging + + python3 -m build --wheel --no-isolation -C--build-option=--cpp_ext -C--build-option=--cuda_ext + + chown -R $(id -u):$(id -g) dist/ + " + + - name: Run Extension Import tests + if: ${{ github.event_name != 'workflow_dispatch' || inputs.run_extension }} + run: | + docker exec apex-build-container bash -c " + set -eo pipefail + + pip install expecttest onnxscript + pip install dist/apex-*.whl + + cd tests + python3 test_extension_import.py 2>&1 | tee ../extension_import_results.log + " + + - name: Cleanup Build Container + if: always() + run: docker rm -f apex-build-container + + - name: Upload Wheel Artifact + uses: actions/upload-artifact@v4 + with: + name: apex-wheel + path: dist/*.whl + retention-days: 7 + + test: + name: Run Unit Tests + timeout-minutes: 720 + runs-on: linux-apex-mi325-8 + needs: build + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + ref: ${{ github.event_name == 'workflow_dispatch' && inputs.apex_gitref || '' }} + submodules: recursive + + - name: Download Wheel Artifact + uses: actions/download-artifact@v4 + with: + name: apex-wheel + path: dist/ + + - name: Pull Docker Image + run: | + docker pull ${{ env.DOCKER_IMAGE }} + + - name: Start Background Docker Container + run: | + docker run -d --name apex-test-container \ + --device=/dev/kfd --device=/dev/dri --group-add=video --ipc=host \ + -e OMP_NUM_THREADS=8 \ + -e TORCH_NCCL_ASYNC_ERROR_HANDLING=1 \ + -e NCCL_DEBUG=WARN \ + -v ${{ github.workspace }}:/workspace -w /workspace \ + ${{ env.DOCKER_IMAGE }} sleep infinity + + - name: Install Dependencies and Built Wheel + run: | + docker exec apex-test-container bash -c " + set -e + pip install expecttest onnxscript + pip install dist/apex-*.whl + " + + - name: Run L0 tests + if: ${{ (always()) && (github.event_name != 'workflow_dispatch' || inputs.run_l0) }} + run: | + docker exec apex-test-container bash -c " + set -eo pipefail + cd tests/L0 + sh run_rocm.sh 2>&1 | tee ../../L0_results.log + " + + - name: Run Contrib tests + if: ${{ (success() || failure()) && (github.event_name != 'workflow_dispatch' || inputs.run_contrib) }} + run: | + docker exec apex-test-container bash -c " + set -eo pipefail + cd apex/contrib/test + python3 run_rocm_extensions.py 2>&1 | tee ../../../contrib_results.log + " + + - name: Run Peer Halo Exchange tests + if: ${{ (success() || failure()) && (github.event_name != 'workflow_dispatch' || inputs.run_halo) }} + run: | + docker exec apex-test-container bash -c " + set -eo pipefail + export HSA_FORCE_FINE_GRAIN_PCIE=1 + export HSA_ENABLE_SDMA=0 + torchrun --nproc_per_node 8 apex/contrib/peer_memory/peer_halo_exchange_module_tests.py 2>&1 | tee halo_results.log + " + + - name: Run Distributed Synced BatchNorm tests + if: ${{ (success() || failure()) && (github.event_name != 'workflow_dispatch' || inputs.run_syncbn) }} + run: | + docker exec apex-test-container bash -c " + set -eo pipefail + cd tests/distributed/synced_batchnorm + sh unit_test.sh 2>&1 | tee ../../../syncbn_results.log + " + + - name: Fix Artifact Permissions + if: always() + run: | + docker exec apex-test-container bash -c "chown -R $(id -u):$(id -g) *.log" + + - name: Cleanup Background Container + if: always() + run: docker rm -f apex-test-container + + - name: Upload Test Logs + if: always() + uses: actions/upload-artifact@v4 + with: + name: test-logs + path: | + *.log + retention-days: 14 diff --git a/.gitignore b/.gitignore index d30f85c34..da67982aa 100644 --- a/.gitignore +++ b/.gitignore @@ -145,3 +145,10 @@ dmypy.json # Cython debug symbols cython_debug/ +*.hip +*_hip.* +*hip* + + +#file temporarily created for build process +apex/git_version_info_installed.py \ No newline at end of file diff --git a/.gitmodules b/.gitmodules index 6479428db..7b4e73190 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,7 +1,6 @@ -[submodule "apex/contrib/csrc/multihead_attn/cutlass"] - path = apex/contrib/csrc/multihead_attn/cutlass - url = https://github.com/NVIDIA/cutlass.git - branch = v1.2.0 [submodule "apex/contrib/csrc/cudnn-frontend"] path = apex/contrib/csrc/cudnn-frontend url = https://github.com/NVIDIA/cudnn-frontend.git +[submodule "third_party/aiter"] + path = third_party/aiter + url = https://github.com/ROCm/aiter diff --git a/.jenkins/docker/build.sh b/.jenkins/docker/build.sh new file mode 100644 index 000000000..1dc09902e --- /dev/null +++ b/.jenkins/docker/build.sh @@ -0,0 +1 @@ +sudo docker build . --rm -t apex diff --git a/.jenkins/docker/launch.sh b/.jenkins/docker/launch.sh new file mode 100644 index 000000000..1e8d08d52 --- /dev/null +++ b/.jenkins/docker/launch.sh @@ -0,0 +1 @@ +sudo docker run -it -v $HOME:/data --rm --privileged --device=/dev/dri --device=/dev/kfd --network host --group-add video apex diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 000000000..8bf9a1705 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,7 @@ +ARG FROM_IMAGE=lcskrishna/rocm-pytorch:rocm3.3_ubuntu16.04_py3.6_pytorch_bfloat16_mgpu + +FROM ${FROM_IMAGE} +RUN \ + git clone --recursive https://github.com/ROCmSoftwarePlatform/apex.git && \ + cd apex && \ + python3.6 setup.py install --cpp_ext --cuda_ext diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 000000000..a5dc0456c --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,2 @@ +recursive-include apex/contrib/csrc * +recursive-include apex/csrc * \ No newline at end of file diff --git a/Makefile b/Makefile new file mode 100644 index 000000000..99e44805f --- /dev/null +++ b/Makefile @@ -0,0 +1,17 @@ +PYTHON = python3 +PIP = $(PYTHON) -m pip + +clean: # This will remove ALL build folders. + @test -d build/ && echo "Deleting build folder" || true + @test -d build/ && rm -r build/ || true + @test -d dist/ && echo "Deleting dist folder" || true + @test -d dist/ && rm -r dist/ || true + @test -d apex.egg-info/ && echo "Deleting apex.egg-info folder" || true + @test -d apex.egg-info/ && rm -r apex.egg-info/ || true + + $(PYTHON) scripts/clean.py # remove the apex extensions installed at torch extensions folder + +aiter: + $(PIP) uninstall -y aiter + cd third_party/aiter && $(PIP) install . --no-build-isolation --no-deps + diff --git a/README.md b/README.md index bc2d34188..dfd26c557 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@ # Introduction -This repository holds NVIDIA-maintained utilities to streamline mixed precision and distributed training in Pytorch. +This repository holds ROCm variant of Nvidia's Apex: https://github.com/NVIDIA/apex. +The aim of Apex repository is to streamline mixed precision and distributed training in Pytorch. Some of the code here will be included in upstream Pytorch eventually. The intent of Apex is to make up-to-date utilities available to users as quickly as possible. @@ -21,9 +22,9 @@ different flags to `amp.initialize`. [API Documentation](https://nvidia.github.io/apex/amp.html) -[Comprehensive Imagenet example](https://github.com/NVIDIA/apex/tree/master/examples/imagenet) +[Comprehensive Imagenet example](https://github.com/rocm/apex/tree/master/examples/imagenet) -[DCGAN example coming soon...](https://github.com/NVIDIA/apex/tree/master/examples/dcgan) +[DCGAN example coming soon...](https://github.com/rocm/apex/tree/master/examples/dcgan) [Moving to the new Amp API](https://nvidia.github.io/apex/amp.html#transition-guide-for-old-api-users) (for users of the deprecated "Amp" and "FP16_Optimizer" APIs) @@ -35,11 +36,11 @@ optimized for NVIDIA's NCCL communication library. [API Documentation](https://nvidia.github.io/apex/parallel.html) -[Python Source](https://github.com/NVIDIA/apex/tree/master/apex/parallel) +[Python Source](https://github.com/rocm/apex/tree/master/apex/parallel) -[Example/Walkthrough](https://github.com/NVIDIA/apex/tree/master/examples/simple/distributed) +[Example/Walkthrough](https://github.com/rocm/apex/tree/master/examples/simple/distributed) -The [Imagenet example](https://github.com/NVIDIA/apex/tree/master/examples/imagenet) +The [Imagenet example](https://github.com/rocm/apex/tree/master/examples/imagenet) shows use of `apex.parallel.DistributedDataParallel` along with `apex.amp`. ### Synchronized Batch Normalization @@ -99,42 +100,283 @@ Note that we recommend restoring the model using the same `opt_level`. Also note # Installation ## Containers -NVIDIA PyTorch Containers are available on NGC: https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch. -The containers come with all the custom extensions available at the moment. - -See [the NGC documentation](https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/index.html) for details such as: -- how to pull a container -- how to run a pulled container -- release notes +ROCm pytorch containers contain apex package and these are available from https://hub.docker.com/r/rocm/pytorch. ## From Source -To install Apex from source, we recommend using the nightly Pytorch obtainable from https://github.com/pytorch/pytorch. +Torch must be installed before installing apex. We recommend using the nightly Pytorch obtainable from https://github.com/rocm/pytorch. The latest stable release obtainable from https://pytorch.org should also work. + +Apex on ROCm supports both python only build and extension build. +Note: Pytorch version recommended is >=1.5 for extension build. + +### The following command will install all the extensions, which will be built and linked at runtime using [PyTorch's JIT (just-in-time) loader](https://pytorch.org/docs/stable/cpp_extension.html): +This requires ninja to be installed +``` +pip install . --no-build-isolation +``` + +### Supported Versions +| ``APEX Version`` | ``APEX branch`` | ``Torch Version`` | +|------------------|-----------------|-------------------| +| ``1.9.0`` | release/1.9.0 | ``2.9`` | +| ``1.8.0`` | release/1.8.0 | ``2.8`` | +| ``1.7.0`` | release/1.7.0 | ``2.7`` | +| ``1.6.0`` | release/1.6.0 | ``2.6`` | +| ``1.5.0`` | release/1.5.0 | ``2.5`` | +| ``1.4.0`` | release/1.4.0 | ``2.4`` | +| ``1.3.0`` | release/1.3.0 | ``2.3`` | +| ``1.2.0`` | release/1.2.0 | ``2.2`` | +| ``1.1.0`` | release/1.1.0 | ``2.1`` | +| ``1.0.0`` | release/1.0.0 | ``2.0`` and older | + + +The relation between APEX and ROCm PyTorch is maintained in file `related_commits` in [ROCm PyTorch release branches](https://github.com/ROCm/pytorch/branches/all?query=release) in the following format. + +``` +ubuntu|pytorch|apex|release/1.0.0|06c33eee43f7a22f3ed7d9c3e5be0ddd757dc345|https://github.com/ROCmSoftwarePlatform/apex +centos|pytorch|apex|release/1.0.0|06c33eee43f7a22f3ed7d9c3e5be0ddd757dc345|https://github.com/ROCmSoftwarePlatform/apex +``` + +### To pre-build and install all the supported extensions while installing apex, use the following command in apex folder: +``` +APEX_BUILD_CPP_OPS=1 APEX_BUILD_CUDA_OPS=1 pip install . --no-build-isolation +``` + +It is also possible to pre-build and install specific extensions by using the following command in apex folder: +``` +APEX_BUILD_=1 pip install . --no-build-isolation +``` +The following extensions are supported: +| extension | environment to build specific extension | install option | +|-----------|-----------|-----------| +| amp_C | APEX_BUILD_AMP_C=1 | APEX_BUILD_CUDA_OPS=1 | +| apex_C | APEX_BUILD_APEX_C=1 | APEX_BUILD_CPP_OPS=1 | +| bnp | APEX_BUILD_BNP=1 | APEX_BUILD_CUDA_OPS=1 | +| distributed_adam_cuda | APEX_BUILD_DISTRIBUTED_ADAM=1 | APEX_BUILD_CUDA_OPS=1 | +| distributed_lamb_cuda | APEX_BUILD_DISTRIBUTED_LAMB=1 | APEX_BUILD_CUDA_OPS=1 | +| fast_multihead_attn | APEX_BUILD_FAST_MULTIHEAD_ATTN=1 | APEX_BUILD_CUDA_OPS=1 | +| focal_loss_cuda | APEX_BUILD_FOCAL_LOSS=1 | APEX_BUILD_CUDA_OPS=1 | +| fused_adam_cuda | APEX_BUILD_FUSED_ADAM=1 | APEX_BUILD_CUDA_OPS=1 | +| fused_bias_swiglu | APEX_BUILD_FUSED_BIAS_SWIGLU=1 | APEX_BUILD_CUDA_OPS=1 | +| fused_conv_bias_relu | APEX_BUILD_FUSED_CONV_BIAS_RELU=1 | APEX_BUILD_CUDA_OPS=1 | +| fused_dense_cuda | APEX_BUILD_FUSED_DENSE=1 | APEX_BUILD_CUDA_OPS=1 | +| fused_index_mul_2d | APEX_BUILD_FUSED_INDEX_MUL_2D=1 | APEX_BUILD_CUDA_OPS=1 | +| fused_lamb_cuda | APEX_BUILD_FUSED_LAMB=1 | APEX_BUILD_CUDA_OPS=1 | +| fused_layer_norm_cuda | APEX_BUILD_FUSED_LAYER_NORM=1 | APEX_BUILD_CUDA_OPS=1 | +| fused_rotary_positional_embedding | APEX_BUILD_FUSED_ROPE=1 | APEX_BUILD_CUDA_OPS=1 | +| fused_weight_gradient_mlp_cuda | APEX_BUILD_FUSED_WEIGHT_GRADIENT_MLP=1 | APEX_BUILD_CUDA_OPS=1 | +| generic_scaled_masked_softmax_cuda | APEX_BUILD_GENERIC_SCALED_MASKED_SOFTMAX_CUDA=1 | APEX_BUILD_CUDA_OPS=1 | +| mlp_cuda | APEX_BUILD_MLP=1 | APEX_BUILD_CUDA_OPS=1 | +| _apex_nccl_allocator | APEX_BUILD_NCCL_ALLOCATOR=1 | APEX_BUILD_CUDA_OPS=1 | +| nccl_p2p_cuda | APEX_BUILD_NCCL_P2P=1 | APEX_BUILD_CUDA_OPS=1 | +| peer_memory_cuda | APEX_BUILD_PEER_MEMORY=1 | APEX_BUILD_CUDA_OPS=1 | +| scaled_masked_softmax_cuda | APEX_BUILD_SCALED_MASKED_SOFTMAX_CUDA=1 | APEX_BUILD_CUDA_OPS=1 | +| scaled_softmax_cuda | APEX_BUILD_SCALED_SOFTMAX_CUDA=1 | APEX_BUILD_CUDA_OPS=1 | +| scaled_upper_triang_masked_softmax_cuda | APEX_BUILD_SCALED_UPPER_TRIANG_MASKED_SOFTMAX_CUDA=1 | APEX_BUILD_CUDA_OPS=1 | +| syncbn | APEX_BUILD_SYNCBN=1 | APEX_BUILD_CUDA_OPS=1 | +| transducer_joint_cuda | APEX_BUILD_TRANSDUCER_JOINT=1 | APEX_BUILD_CUDA_OPS=1 | +| transducer_loss_cuda | APEX_BUILD_TRANSDUCER_LOSS=1 | APEX_BUILD_CUDA_OPS=1 | +| xentropy_cuda | APEX_BUILD_XENTROPY=1 | APEX_BUILD_CUDA_OPS=1 | + +For example, to build FUSED_DENSE​ you can use the following command: +``` +APEX_BUILD_FUSED_DENSE​=1 pip install . --no-build-isolation +``` +This will pre-build and install FUSED_DENSE​ module and rest of the modules are installed to be JIT built and loaded at runtime. + +Aiter backend can be built and used for fused rope. To install aiter: +``` +make aiter +``` + +To use aiter in fused rope, you can use the flag ```USE_ROCM_AITER_ROPE_BACKEND=1```. + +### To add a new module into jit loader + +What is JIT (just-in-time) load? Just-in-time load helps to build the specific modules that are used without needing to build all modules during installation time. This helps to significantly reduce installation time. Without JIT load, it would take roughtly 30 minutes to install apex. With JIT load, it takes less than 1 minute to install apex. + +A python script is provided to ease the process of adding a new module to JIT load. +For this, the user must create C++/CUDA source code for a new apex module in either csrc or apex/contrib/csrc folder. +This script helps to create a builder and a loader for the apex module. +The builder creates the .so file for the apex module (during installation or jit load time) and the loader loads the .so file when the module is imported. + +To run the script: + +``` +python scripts/jit_module.py +``` + +The user should provide the name used to import the module i.e. import fused_bias_swiglu. +If the user does not provide the module name, the script will ask for the module name +``` +What is the name of the module? +``` + +The script is interactive and asks two questions +1. Is this a CUDA module? (Y/n) +2. Enter the sources (comma separated) Press Enter to skip + +If the user answers yes to cuda module, it builds with CUDAOpBuilder otherwise it builds as a cpu operation with CPUOpBuilder. The default is cuda operation. +The user must mention the list of .cpp, .h, .cu files used to compile the module as a comma separated list. +This argument is used to define the return value of sources() method in the builder module. +This will be used to also find the list of directories (include_paths() method) i.e. -I flag in g++ compiler. +The user can decide to skip the list of sources and add it manually to the builder file created by the script. + +e.g. +``` +python scripts/jit_module.py fused_bias_swiglu +1. Is this a CUDA module? (Y/n) y +2. Enter the sources (comma separated) Press Enter to skip csrc/megatron/fused_bias_swiglu.cpp,csrc/megatron/fused_bias_swiglu_cuda.cu +``` + +**Directory structure (fused_bias_swiglu example):** + +The above example creates a builder - [op_builder/fused_bias_swiglu.py](op_builder/fused_bias_swiglu.py) and a loader - [compatibility/fused_bias_swiglu.py](compatibility/fused_bias_swiglu.py). + +``` +apex/ # repo root +├── csrc/ # C++/CUDA source (or apex/contrib/csrc) +│ └── megatron/ +│ ├── fused_bias_swiglu.cpp # PyBind11 module defs, declarations +│ └── fused_bias_swiglu_cuda.cu # CUDA kernels / implementation +├── op_builder/ # Builder: compiles sources → .so +│ └── fused_bias_swiglu.py # FusedBiasSwiGLUBuilder (NAME = "fused_bias_swiglu", sources(), etc.) +├── compatibility/ # Loader: JIT-loads .so when module is imported +│ └── fused_bias_swiglu.py # _FusedBiasSwiGLUModule (loads via apex.op_builder.FusedBiasSwiGLUBuilder) +└── apex/ # Python package + └── fused_bias_swiglu/ # User-facing API (import apex.fused_bias_swiglu) + ├── __init__.py + └── fused_bias_swiglu.py # imports fused_bias_swiglu, wraps forward/backward, etc. +``` + -The latest stable release obtainable from https://pytorch.org should also work. +The user must not edit the loader code. + +The script creates an initial builder code and the users can edit the methods in the module. + +The builder module is created in op_builder folder and must override either CPUOpBuilder or CUDAOpBuilder class and define the following attributes and methods: + +| Attribute | Purpose | +|-----------|-----------| +| BUILD_VAR | The environment variable to indicate prebuilding the module when installing apex e.g. APEX_BUILD_FUSED_BIAS_SWIGLU for fused_bias_swiglu| +| INCLUDE_FLAG | Either APEX_BUILD_CUDA_OPS or APEX_BUILD_CPU_OPS to indicate whether the module will be built for gpu or cpu | +| NAME | name of module e.g. fused_bias_swiglu | + +| Method | Purpose | Necessary to override | +|-----------|-----------|-----------| +| absolute_name | return the namespace where the module will be installed | Yes | +| sources | list of C++/CUDA source files for the module | Yes | +| include_paths | list of folders where the included headers mentioned in the source files are placed | No | +| cxx_args | return a list of extra compiler flags for the C++ compiler when building C++ sources (e.g. optimization level, preprocessor macros) | No | +| nvcc_args | return a list of extra compiler flags for nvcc when building CUDA sources (e.g. -O3, architecture flags, preprocessor macros) | No | +| is_compatible | can this module be installed and loaded considering the environment e.g.minimum torch version supported | No | +| libraries_args | list of libraries to compile against e.g. MIOpen | No | + + + + + +### To create a wheel and then install apex using the wheel, use the following command in apex folder: +``` +python -m build --wheel --no-isolation (can use the same environment variables to build specific extensions, cpp extensions and cuda extensions) +pip install dist/apex-*.whl​ +``` + +### To uninstall apex and its extensions, use the following command in apex folder: +``` +pip uninstall apex +make clean +``` + +### Enable hipblasLT on ROCm +hipblasLT is supported only on mi300 (gfx942) only. +python setup.py automatically builds apex with hipblasLT support only if GPU device id is gfx942 +To verify if hipblasLT support is enabled, check the build logs +INFO: IS_HIPBLASLT_SUPPORTED value is True ==> indicates apex is built with hipblasLT support +INFO: IS_HIPBLASLT_SUPPORTED value is False ### Linux For performance and full functionality, we recommend installing Apex with CUDA and C++ extensions via ```bash -git clone https://github.com/NVIDIA/apex +git clone https://github.com/rocm/apex cd apex -pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ -``` - -Apex also supports a Python-only build via -```bash -pip install -v --disable-pip-version-check --no-cache-dir ./ +pip install . --no-build-isolation ``` -A Python-only build omits: -- Fused kernels required to use `apex.optimizers.FusedAdam`. -- Fused kernels required to use `apex.normalization.FusedLayerNorm` and `apex.normalization.FusedRMSNorm`. -- Fused kernels that improve the performance and numerical stability of `apex.parallel.SyncBatchNorm`. -- Fused kernels that improve the performance of `apex.parallel.DistributedDataParallel` and `apex.amp`. -`DistributedDataParallel`, `amp`, and `SyncBatchNorm` will still be usable, but they may be slower. ### [Experimental] Windows -`pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .` may work if you were able to build Pytorch from source -on your system. A Python-only build via `pip install -v --no-cache-dir .` is more likely to work. +`pip install . --no-build-isolation` may work if you were able to build Pytorch from source +on your system. A Python-only build via `pip install --no-build-isolation -v --no-cache-dir .` is more likely to work. If you installed Pytorch in a Conda environment, make sure to install Apex in that same environment. + +# Release notes + +## release/1.11.0 + +Added extensions +- fused_conv_bias_relu + +Upgraded extensions +- Create custom python operators for MixedFusedLayerNorm and MixedFusedRMSNorm +- Added Pow implementation for Focal loss cuda kernel to improve computation time + +## release/1.10.0 + +- No new features were added in this release cycle. + +## release/1.9.0 + +- No new features were added in this release cycle. + +## release/1.8.0 + +Unit test related +- Fix transformer unit tests +- Fix fused dense gelu dense unit tests + +Build and installation related +- Support JIT (just-in-time) load cpp and CUDA extensions +- Script to add new module to JIT system + +## release/1.7.0 + +Build and installation related +- Support use of BUILD_VERSION environment to override version.txt when creating apex wheels +- Disable aiter installation by default. make aiter command is used to build apex + +Unit test related +- Include running transformer tests in L0/run_test.py +- Fix transformer unit tests +- Fix batch norm unit tests +- Fix fused dense gelu dense unit tests + +## release/1.6.0 + +Upgraded extensions +- Support unscale_grads in transformer Grad scaler +- Support amp function in fused dense, mlp +- Support blas backend flag in fused dense +- Support not destroying process group for distributed tests +- Upgrade fused adam to support parameters - capturable, master weights, grad scaler +- Upgrade distributed fused adam to support bias_correction, adam_w_mode, overlap_param_sync, store_params, store_param_remainders, with_scaled_states, nccl_ub +- Upgrade distributed fused lamb to support parameters fused_norm, full_ar, set_param_views_to_flat_buffer, skip_allgather, fuse_scale, param_order, nccl_allgather_channels + +Unit test related +- Fix fused dense, fused rope, mlp unit tests +- Add test fused adam unit test +- Include running fused dense tests in L0/run_test.py + + +## release/1.5.0 + +Added extensions +- fused bias swiglu +- fused gradient accumulator +- fused rope + +Upgraded extensions +- Support blaslt backend in fused weight gradient dense module + + + diff --git a/apex/RNN/RNNBackend.py b/apex/RNN/RNNBackend.py index b9d4937ef..a9382e601 100644 --- a/apex/RNN/RNNBackend.py +++ b/apex/RNN/RNNBackend.py @@ -254,17 +254,17 @@ def __init__(self, gate_multiplier, input_size, hidden_size, cell, n_hidden_stat self.gate_size = gate_multiplier * self.hidden_size self.n_hidden_states = n_hidden_states - self.w_ih = nn.Parameter(torch.Tensor(self.gate_size, self.input_size)) - self.w_hh = nn.Parameter(torch.Tensor(self.gate_size, self.output_size)) + self.w_ih = nn.Parameter(torch.empty(self.gate_size, self.input_size)) + self.w_hh = nn.Parameter(torch.empty(self.gate_size, self.output_size)) #Check if there's recurrent projection if(self.output_size != self.hidden_size): - self.w_ho = nn.Parameter(torch.Tensor(self.output_size, self.hidden_size)) + self.w_ho = nn.Parameter(torch.empty(self.output_size, self.hidden_size)) self.b_ih = self.b_hh = None if self.bias: - self.b_ih = nn.Parameter(torch.Tensor(self.gate_size)) - self.b_hh = nn.Parameter(torch.Tensor(self.gate_size)) + self.b_ih = nn.Parameter(torch.empty(self.gate_size)) + self.b_hh = nn.Parameter(torch.empty(self.gate_size)) #hidden states for forward self.hidden = [ None for states in range(self.n_hidden_states)] diff --git a/apex/RNN/cells.py b/apex/RNN/cells.py index 32b61a1be..09b08581d 100644 --- a/apex/RNN/cells.py +++ b/apex/RNN/cells.py @@ -18,8 +18,8 @@ def __init__(self, input_size, hidden_size, bias = False, output_size = None): gate_multiplier = 4 super(mLSTMRNNCell, self).__init__(gate_multiplier, input_size, hidden_size, mLSTMCell, n_hidden_states = 2, bias = bias, output_size = output_size) - self.w_mih = nn.Parameter(torch.Tensor(self.output_size, self.input_size)) - self.w_mhh = nn.Parameter(torch.Tensor(self.output_size, self.output_size)) + self.w_mih = nn.Parameter(torch.empty(self.output_size, self.input_size)) + self.w_mhh = nn.Parameter(torch.empty(self.output_size, self.output_size)) self.reset_parameters() diff --git a/apex/__init__.py b/apex/__init__.py index b1125eb77..afe0c074c 100644 --- a/apex/__init__.py +++ b/apex/__init__.py @@ -39,7 +39,17 @@ def format(self, record): _library_root_logger.propagate = False +def check_if_rocm_pytorch(): + is_rocm_pytorch = False + if hasattr(torch.version, 'hip') and torch.version.hip is not None: + is_rocm_pytorch = True + return is_rocm_pytorch + +IS_ROCM_PYTORCH = check_if_rocm_pytorch() + def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int) -> bool: + if IS_ROCM_PYTORCH: + return True cudnn_available = torch.backends.cudnn.is_available() cudnn_version = torch.backends.cudnn.version() if cudnn_available else None if not (cudnn_available and (cudnn_version >= required_cudnn_version)): diff --git a/apex/_autocast_utils.py b/apex/_autocast_utils.py index e86c6c6a5..3a92a83f3 100644 --- a/apex/_autocast_utils.py +++ b/apex/_autocast_utils.py @@ -3,6 +3,9 @@ import torch +__all__ = ["_cast_if_autocast_enabled"] + + def _get_autocast_dtypes() -> Sequence[torch.dtype]: if torch.cuda.is_bf16_supported(): return [torch.half, torch.bfloat16] diff --git a/apex/amp/__init__.py b/apex/amp/__init__.py index 34d080a69..b4f81cddf 100644 --- a/apex/amp/__init__.py +++ b/apex/amp/__init__.py @@ -1,5 +1,5 @@ -from .amp import init, half_function, float_function, promote_function,\ - register_half_function, register_float_function, register_promote_function +from .amp import init, half_function, bfloat16_function, float_function, promote_function,\ + register_half_function, register_bfloat16_function, register_float_function, register_promote_function from .handle import scale_loss, disable_casts from .frontend import initialize, state_dict, load_state_dict from ._amp_state import master_params, _amp_state diff --git a/apex/amp/_amp_state.py b/apex/amp/_amp_state.py index 1ac9d3116..7e8a329f5 100644 --- a/apex/amp/_amp_state.py +++ b/apex/amp/_amp_state.py @@ -2,18 +2,8 @@ # I'm a C++ guy, not a python guy. I decided this approach because it seemed most C++-like. # But apparently it's ok: # http://effbot.org/pyfaq/how-do-i-share-global-variables-across-modules.htm -import os import torch -TORCH_MAJOR = int(torch.__version__.split('.')[0]) -TORCH_MINOR = int(torch.__version__.split('.')[1]) - - -if TORCH_MAJOR == 1 and TORCH_MINOR < 8: - from torch._six import container_abcs -else: - import collections.abc as container_abcs - class AmpState(object): def __init__(self): diff --git a/apex/amp/_initialize.py b/apex/amp/_initialize.py index 28c5bbbdf..641451f6d 100644 --- a/apex/amp/_initialize.py +++ b/apex/amp/_initialize.py @@ -1,11 +1,13 @@ -import torch -from torch._six import string_classes +import collections.abc as container_abcs +from types import MethodType import functools -import numpy as np import sys -from types import MethodType import warnings -from ._amp_state import _amp_state, warn_or_err, container_abcs + +import numpy as np +import torch + +from ._amp_state import _amp_state, warn_or_err from .handle import disable_casts from .scaler import LossScaler from ._process_optimizer import _process_optimizer @@ -39,7 +41,7 @@ def to_type(dtype, t): def applier(value, fn): if isinstance(value, torch.Tensor): return fn(value) - elif isinstance(value, string_classes): + elif isinstance(value, str): return value elif isinstance(value, np.ndarray): return value @@ -80,10 +82,10 @@ def check_params_fp32(models): for model in models: for name, param in model.named_parameters(): if param.is_floating_point(): - if 'Half' in param.type(): + if 'Half' in param.type() or 'BFloat16' in param.type(): warn_or_err("Found param {} with type {}, expected torch.cuda.FloatTensor.\n" - "When using amp.initialize, you do not need to call .half() on your model\n" - "before passing it, no matter what optimization level you choose.".format( + "When using amp.initialize, you do not need to call .half() or .bfloat16()\n" + "on your model before passing it, no matter what optimization level you choose.".format( name, param.type())) elif not param.is_cuda: warn_or_err("Found param {} with type {}, expected torch.cuda.FloatTensor.\n" @@ -137,7 +139,7 @@ def __init__(self, fn): def __call__(self, module, state_dict, prefix, local_metadata): for key in state_dict: param = state_dict[key] - if 'Half' in param.type(): + if 'Half' in param.type() or 'BFloat16' in param.type(): param = param.to(torch.float32) state_dict[key] = param @@ -189,7 +191,7 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs for model in models: # Patch the forward method to cast incoming data to the correct type, and - # outgoing data to float32, so "the user never needs to call .half()." + # outgoing data to float32, so "the user never needs to call .half()/.bfloat16()." # I like writing things explicitly more than decorators. def patch_forward(old_fwd): def new_fwd(*args, **kwargs): @@ -232,7 +234,9 @@ def new_fwd(*args, **kwargs): if properties.patch_torch_functions: # handle is unused here. It's accessible later through a global value anyway. - handle = amp_init(loss_scale=properties.loss_scale, verbose=(_amp_state.verbosity == 2)) + handle = amp_init(loss_scale=properties.loss_scale, + patch_type=properties.patch_torch_functions_type, + verbose=(_amp_state.verbosity == 2)) for optimizer in optimizers: # Disable Amp casting for the optimizer step, because it should only be # applied to FP32 master params anyway. diff --git a/apex/amp/_process_optimizer.py b/apex/amp/_process_optimizer.py index 471289bba..66c4c3fdf 100644 --- a/apex/amp/_process_optimizer.py +++ b/apex/amp/_process_optimizer.py @@ -1,7 +1,7 @@ import types from ..fp16_utils import master_params_to_model_params from ..multi_tensor_apply import multi_tensor_applier -from ._amp_state import maybe_print +from ._amp_state import maybe_print, _amp_state import torch from ..optimizers import FusedSGD @@ -37,7 +37,7 @@ def lazy_init_with_master_weights(self): fp32_from_fp16_params_this_group = [] for i, param in enumerate(param_group['params']): if param.requires_grad: - if param.type() == 'torch.cuda.HalfTensor': + if param.type() in {'torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor'}: # maybe_print("FP16_Optimizer received torch.cuda.HalfTensor with {}" # .format(param.size())) fp16_params_this_group.append(param) @@ -55,8 +55,8 @@ def lazy_init_with_master_weights(self): fp32_params_this_group.append(param) param_group['params'][i] = param else: - raise TypeError("Optimizer's parameters must be either " - "torch.cuda.FloatTensor or torch.cuda.HalfTensor. " + raise TypeError("Optimizer's parameters must one of " + "torch.cuda.FloatTensor, torch.cuda.HalfTensor, torch.cuda.BFloat16Tensor. " "Received {}".format(param.type())) stash.fp16_groups.append(fp16_params_this_group) @@ -208,13 +208,13 @@ def lazy_init_no_master_weights(self): stash.all_fp32_params = [] for i, param_group in enumerate(self.param_groups): for i, param in enumerate(param_group['params']): - if param.type() == 'torch.cuda.HalfTensor': + if param.type() in {'torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor'}: stash.all_fp16_params.append(param) elif param.type() == 'torch.cuda.FloatTensor': stash.all_fp32_params.append(param) else: - raise TypeError("Optimizer's parameters must be either " - "torch.cuda.FloatTensor or torch.cuda.HalfTensor. " + raise TypeError("Optimizer's parameters must be one of " + "torch.cuda.FloatTensor, torch.cuda.HalfTensor, torch.BFloat16Tensor. " "Received {}".format(param.type())) stash.all_fp16_grad_stash = [None for _ in stash.all_fp16_params] @@ -341,7 +341,7 @@ def _process_optimizer(optimizer, properties): import amp_C optimizer._amp_stash.multi_tensor_scale = amp_C.multi_tensor_scale optimizer._amp_stash.multi_tensor_l2norm = amp_C.multi_tensor_l2norm - optimizer._amp_stash.dummy_overflow_buf = torch.cuda.IntTensor([0]); + optimizer._amp_stash.dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device='cuda') if properties.master_weights: optimizer._lazy_init_maybe_master_weights = types.MethodType( @@ -435,7 +435,7 @@ def new_add_param_group(self, new_group): fp32_from_fp16_params_this_group = [] for i, param in enumerate(new_group['params']): if param.requires_grad: - if param.type() == 'torch.cuda.HalfTensor': + if param.type() in {'torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor'}: fp16_params_this_group.append(param) master_param = param.detach().clone().float() master_param.requires_grad = True @@ -445,8 +445,8 @@ def new_add_param_group(self, new_group): fp32_params_this_group.append(param) new_group['params'][i] = param else: - raise TypeError("Optimizer's parameters must be either " - "torch.cuda.FloatTensor or torch.cuda.HalfTensor. " + raise TypeError("Optimizer's parameters must be one of " + "torch.cuda.FloatTensor, torch.cuda.HalfTensor, torch.cuda.BFloat16Tensor. " "Received {}".format(param.type())) stash.fp16_groups.append(fp16_params_this_group) @@ -471,15 +471,15 @@ def new_add_param_group(self, new_group): # param.grad = None else: for param in new_group['params']: - if param.type() == 'torch.cuda.HalfTensor': + if param.type() in {'torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor'}: stash.all_fp16_params.append(param) stash.all_fp16_grad_stash.append(None) elif param.type() == 'torch.cuda.FloatTensor': stash.all_fp32_params.append(param) stash.all_fp32_grad_stash.append(None) else: - raise TypeError("Optimizer's parameters must be either " - "torch.cuda.FloatTensor or torch.cuda.HalfTensor. " + raise TypeError("Optimizer's parameters must one of " + "torch.cuda.FloatTensor, torch.cuda.HalfTensor, torch.cuda.BFloat16Tensor. " "Received {}".format(param.type())) old_add_param_group(new_group) diff --git a/apex/amp/amp.py b/apex/amp/amp.py index 1eed72d07..b438b3fcc 100644 --- a/apex/amp/amp.py +++ b/apex/amp/amp.py @@ -9,7 +9,6 @@ import torch - _DECORATOR_HANDLE = None _USER_CAST_REGISTRY = set() _USER_PROMOTE_REGISTRY = set() @@ -31,6 +30,9 @@ def half_function(fn): wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=True) return _decorator_helper(fn, utils.maybe_half, wrap_fn) +def bfloat16_function(fn): + wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=True) + return _decorator_helper(fn, utils.maybe_bfloat16, wrap_fn) def float_function(fn): wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=False) @@ -49,6 +51,11 @@ def register_half_function(module, name): name, module)) _USER_CAST_REGISTRY.add((module, name, utils.maybe_half)) +def register_bfloat16_function(module, name): + if not hasattr(module, name): + raise ValueError('No function named {} in module {}.'.format( + name, module)) + _USER_CAST_REGISTRY.add((module, name, utils.maybe_bfloat16)) def register_float_function(module, name): if not hasattr(module, name): @@ -65,7 +72,7 @@ def register_promote_function(module, name): # Top-level function to insert _all_ the hooks. -def init(enabled=True, loss_scale="dynamic", enable_caching=True, verbose=False, allow_banned=False): +def init(enabled=True, loss_scale="dynamic", patch_type=torch.float16, enable_caching=True, verbose=False, allow_banned=False): global _DECORATOR_HANDLE if not enabled: @@ -87,16 +94,30 @@ def init(enabled=True, loss_scale="dynamic", enable_caching=True, verbose=False, wrap.promote(mod, fn, handle, verbose) _USER_PROMOTE_REGISTRY.clear() + # conditionally choose between fp16 and bfloat16 functions list to cache + if patch_type == torch.float16: + low_prec_funcs = 'FP16_FUNCS' + maybe_low_prec = utils.maybe_half + low_prec_tensor = torch.cuda.HalfTensor + elif patch_type == torch.bfloat16: + low_prec_funcs = 'BFLOAT16_FUNCS' + maybe_low_prec = utils.maybe_bfloat16 + low_prec_tensor = torch.cuda.BFloat16Tensor + else: + raise RuntimeError("Unsupported patch_torch_functions_type passed to initialize." + + "Supported types are: torch.float16 and torch.bfloat16.") + # 1) Force-{fp16, fp32} on white- / black-list functions override_modules = [functional_overrides, torch_overrides, tensor_overrides] - cast_table = [('FP16_FUNCS', utils.maybe_half), + cast_table = [(low_prec_funcs, maybe_low_prec), ('FP32_FUNCS', utils.maybe_float)] + for module, (list_name, cast_fn) in itertools.product(override_modules, cast_table): for fn in getattr(module, list_name): - try_caching = (cast_fn == utils.maybe_half) + try_caching = (cast_fn == maybe_low_prec) wrap.cached_cast(module.MODULE, fn, cast_fn, handle, try_caching, verbose) @@ -128,12 +149,12 @@ def init(enabled=True, loss_scale="dynamic", enable_caching=True, verbose=False, for fn in getattr(tensor_overrides, list_name): promote_fn(cls, fn, handle, verbose) - # 3) For any in-place version of a blacklist function, error if any input is fp16. + # 3) For any in-place version of a blacklist function, error if any input is fp16/bfloat16. # NB: this is overly conservative. for fn in utils.as_inplace(torch_overrides.FP32_FUNCS): wrap.err_if_any_half(torch_overrides.MODULE, fn, handle) - # 3.5) For any in-place blacklist method, error if called on fp16 tensor + # 3.5) For any in-place blacklist method, error if called on fp16/bfloat16 tensor for fn in utils.as_inplace(tensor_overrides.FP32_FUNCS): wrap.err_if_arg0_half(tensor_overrides.MODULE, fn, handle, verbose) if compat.tensor_is_float_tensor(): @@ -141,7 +162,7 @@ def init(enabled=True, loss_scale="dynamic", enable_caching=True, verbose=False, # 4) For other in-place methods, match the type of self tensor for fn in utils.as_inplace(itertools.chain( - tensor_overrides.FP16_FUNCS, + getattr(tensor_overrides, low_prec_funcs), tensor_overrides.CASTS)): wrap.promote_match_arg0(tensor_overrides.MODULE, fn, handle, verbose) if compat.tensor_is_float_tensor(): @@ -156,10 +177,10 @@ def init(enabled=True, loss_scale="dynamic", enable_caching=True, verbose=False, torch.nn.modules.rnn._VF = rnn_compat.VariableFunctionsShim() # Wrap all the rnns for x in rnn_compat.RNN_NAMES: - wrap.new_rnn_cast(x.upper(), handle, verbose) + wrap.new_rnn_cast(x.upper(), maybe_low_prec, handle, verbose) # Wrap all the RNN cells - rnn_compat.whitelist_rnn_cells(handle, verbose) + rnn_compat.whitelist_rnn_cells(maybe_low_prec, handle, verbose) # 6) Place error+print message on banned functions. # Or, if allow_banned, then cast to FP32. diff --git a/apex/amp/compat.py b/apex/amp/compat.py index 22276bd47..2725fa845 100644 --- a/apex/amp/compat.py +++ b/apex/amp/compat.py @@ -28,7 +28,8 @@ def is_floating_point(x): torch_type = x.type() return torch_type.endswith('FloatTensor') or \ torch_type.endswith('HalfTensor') or \ - torch_type.endswith('DoubleTensor') + torch_type.endswith('DoubleTensor') or \ + torch_type.endswith('BFloat16Tensor') except AttributeError: return False diff --git a/apex/amp/frontend.py b/apex/amp/frontend.py index da0f05dc9..5ee96b778 100644 --- a/apex/amp/frontend.py +++ b/apex/amp/frontend.py @@ -1,4 +1,5 @@ import torch +import os from ._initialize import _initialize from ._amp_state import _amp_state, warn_or_err, maybe_print from collections import OrderedDict @@ -16,6 +17,10 @@ def __init__(self): "opt_level" : None, "cast_model_type" : None, "patch_torch_functions" : False, + # TODO: patch_torch_functions_type could probably be unified with + # patch_torch_functions. Currently introducing a new attribute + # to be on the safer side and not break stuff. + "patch_torch_functions_type" : None, "keep_batchnorm_fp32" : None, "master_weights" : None, "loss_scale" : 1.0, @@ -53,7 +58,7 @@ def __setattr__(self, name, value): if name in self.options: # print("setting {} {}".format(name, value)) if name == "cast_model_type": - if self.opt_level == "O1" and value is not None: + if self.opt_level in {"O1", "O4"} and value is not None: if value is not False: if value is not torch.float32: warn_or_err("O1 inserts casts around Torch functions rather than " @@ -63,13 +68,25 @@ def __setattr__(self, name, value): "cast_model_type was {}".format(value)) self.options[name] = value elif name == "patch_torch_functions": - if self.opt_level != "O1" and value: + if self.opt_level not in {"O1", "O4"} and value: warn_or_err("Currently, patch_torch_functions=True should only be set by " - "selecting opt_level='O1'.") + "selecting opt_level='O1' or 'O4'.") self.options[name] = value + elif name == "patch_torch_functions_type": + if self.opt_level not in {"O1", "O4"} and value is not None: + warn_or_err("Currently, patch_torch_functions_type should only be set by " + "selecting opt_level='O1' or 'O4'.") + elif self.opt_level == "O1" and value != torch.float16: + warn_or_err("patch_torch_functions_type should only be set to torch.float16 " + "for opt_level='O1.") + elif self.opt_level == "O4" and value != torch.bfloat16: + warn_or_err("patch_torch_functions_type should only be set to torch.bfloat16 " + "for opt_level='O4.") + else: + self.options[name] = value elif name == "keep_batchnorm_fp32": - if self.opt_level == "O1" and value is not None: - warn_or_err("With opt_level O1, batchnorm functions are automatically patched " + if self.opt_level in {"O1", "O4"} and value is not None: + warn_or_err("With opt_level O1 or O4, batchnorm functions are automatically patched " "to run in FP32, so keep_batchnorm_fp32 should be None." + " keep_batchnorm_fp32 was {}".format(value)) if value == "False": @@ -82,9 +99,9 @@ def __setattr__(self, name, value): "or None, found keep_batchnorm_fp32={}".format(value) self.options[name] = value elif name == "master_weights": - if self.opt_level == "O1" and value is not None: - warn_or_err("It doesn't make sense to use master_weights with O1. " - "With O1, your model weights themselves should be FP32.") + if self.opt_level in {"O1", "O4"} and value is not None: + warn_or_err("It doesn't make sense to use master_weights with O1 and O4 . " + "With O1 and O4, your model weights themselves should be FP32.") self.options[name] = value elif name == "loss_scale": if value == "dynamic": @@ -113,6 +130,7 @@ def __call__(self, properties): properties.opt_level = "O3" properties.cast_model_type = torch.float16 properties.patch_torch_functions = False + properties.patch_torch_functions_type = None properties.keep_batchnorm_fp32 = False properties.master_weights = False properties.loss_scale = 1.0 @@ -136,6 +154,7 @@ def __call__(self, properties): properties.opt_level = "O2" properties.cast_model_type = torch.float16 properties.patch_torch_functions = False + properties.patch_torch_functions_type = None properties.keep_batchnorm_fp32 = True properties.master_weights = True properties.loss_scale = "dynamic" @@ -158,6 +177,7 @@ def __call__(self, properties): properties.opt_level = "O1" properties.cast_model_type = None properties.patch_torch_functions = True + properties.patch_torch_functions_type = torch.float16 properties.keep_batchnorm_fp32 = None properties.master_weights = None properties.loss_scale = "dynamic" @@ -177,6 +197,7 @@ def __call__(self, properties): properties.opt_level = "O0" properties.cast_model_type = torch.float32 properties.patch_torch_functions = False + properties.patch_torch_functions_type = None properties.keep_batchnorm_fp32 = None properties.master_weights = False properties.loss_scale = 1.0 @@ -184,11 +205,54 @@ def __call__(self, properties): # properties.enable_ddp_interop = False return properties # modified in place so this isn't really necessary +class O4: + brief = "O4: Insert automatic casts around Pytorch functions and Tensor methods.\n" + more = "The type of your model's weights is not altered. However, internally,\n"\ + "Pytorch functions are patched to cast any Tensor Core-friendly ops to BFLOAT16 for speed,\n"\ + "while operations that might benefit from the additional stability of FP32 are patched\n"\ + "to cast their inputs to fp32.\n"\ + "Loss scaling is not required in O4 mode since bflaot16 has the same dynamic range as fp32." + + def __call__(self, properties): + properties.enabled = True + properties.opt_level = "O4" + properties.cast_model_type = None + properties.patch_torch_functions = True + properties.patch_torch_functions_type = torch.bfloat16 + properties.keep_batchnorm_fp32 = None + properties.master_weights = None + properties.loss_scale = 1 + return properties # modified in place so this isn't really necessary + +class O5: + brief = "O5: BFLOAT16 training with FP32 batchnorm and FP32 master weights.\n" + more = "Calls .bfloat16() on your model, converting the entire model (except for batchnorms)\n"\ + "to BFLOAT16. Batchnorms are retained in FP32 for additional stability.\n"\ + "The forward pass is patched to cast incoming Tensors to BFLOAT16, so you don't need to change\n"\ + "your data pipeline.\n"\ + "O5 creates FP32 master weights outside the model and patches any optimizers to update\n"\ + "these master weights, then copy the master weights into the BFLOAT16 model weights.\n"\ + "Master weights can also improve convergence and stability." + + def __call__(self, properties): + properties.enabled = True + properties.opt_level = "O5" + properties.cast_model_type = torch.bfloat16 + properties.patch_torch_functions = False + properties.patch_torch_functions = None + properties.patch_torch_functions_type = None + properties.keep_batchnorm_fp32 = True + properties.master_weights = True + properties.loss_scale = 1 + return properties # modified in place so this isn't really necessary + opt_levels = {"O3": O3(), "O2": O2(), "O1": O1(), - "O0": O0()} + "O0": O0(), + "O4": O4(), + "O5": O5()} # allow user to directly pass Properties struct as well? @@ -199,6 +263,7 @@ def initialize( opt_level="O1", cast_model_type=None, patch_torch_functions=None, + patch_torch_functions_type=None, keep_batchnorm_fp32=None, master_weights=None, loss_scale=None, @@ -235,10 +300,11 @@ def initialize( enabled (bool, optional, default=True): If False, renders all Amp calls no-ops, so your script should run as if Amp were not present. opt_level (str, optional, default="O1"): Pure or mixed precision optimization level. Accepted values are - "O0", "O1", "O2", and "O3", explained in detail above. + "O0", "O1", "O2", "O3", "O4" and "O5", explained in detail above. cast_model_type (``torch.dtype``, optional, default=None): Optional property override, see above. patch_torch_functions (bool, optional, default=None): Optional property override. + patch_torch_functions_type (``torch.dtype``, optional, default=None): Optional property override keep_batchnorm_fp32 (bool or str, optional, default=None): Optional property override. If passed as a string, must be the string "True" or "False". master_weights (bool, optional, default=None): Optional property override. @@ -321,14 +387,14 @@ def initialize( if opt_level not in opt_levels: raise RuntimeError( "Unexpected optimization level {}. ".format(opt_level) + - "Options are 'O0', 'O1', 'O2', 'O3'. Note that in `O0`, `O1`, etc., the prefix O is the letter O, " + + "Options are 'O0', 'O1', 'O2', 'O3', 'O4', 'O5'. Note that in `O0`, `O1`, etc., the prefix O is the letter O, " + "not the number zero.") else: _amp_state.opt_properties = opt_levels[opt_level](_amp_state.opt_properties) maybe_print("Selected optimization level {}".format(opt_levels[opt_level].brief), True) maybe_print("Defaults for this optimization level are:", True) for k, v in _amp_state.opt_properties.options.items(): - maybe_print("{:22} : {}".format(k, v), True) + maybe_print("{:26} : {}".format(k, v), True) _amp_state.min_loss_scale = min_loss_scale _amp_state.max_loss_scale = max_loss_scale @@ -344,6 +410,8 @@ def initialize( _amp_state.opt_properties.cast_model_type = cast_model_type if patch_torch_functions is not None: _amp_state.opt_properties.patch_torch_functions = patch_torch_functions + if patch_torch_functions_type is not None: + _amp_state.opt_properties.patch_torch_functions_type = patch_torch_functions_type if keep_batchnorm_fp32 is not None: _amp_state.opt_properties.keep_batchnorm_fp32 = keep_batchnorm_fp32 if master_weights is not None: @@ -353,7 +421,12 @@ def initialize( maybe_print("After processing overrides, optimization options are:", True) for k, v in _amp_state.opt_properties.options.items(): - maybe_print("{:22} : {}".format(k, v), True) + maybe_print("{:26} : {}".format(k, v), True) + + + # Set flag to tell F8 that apex.amp is initialized + os.environ["APEX_AMP_ENABLED"] = "1" + return _initialize(models, optimizers, _amp_state.opt_properties, num_losses, cast_model_outputs) diff --git a/apex/amp/lists/functional_overrides.py b/apex/amp/lists/functional_overrides.py index dd009cec6..9ecdf0972 100644 --- a/apex/amp/lists/functional_overrides.py +++ b/apex/amp/lists/functional_overrides.py @@ -26,6 +26,17 @@ 'linear', ] +BFLOAT16_FUNCS = [ + 'conv1d', + 'conv2d', + 'conv3d', + 'conv_transpose1d', + 'conv_transpose2d', + 'conv_transpose3d', + 'conv_tbc', # Undocumented / maybe new? + 'linear', +] + FP32_FUNCS = [ # Interpolation/Upsampling TODO: Remove for 1.2 diff --git a/apex/amp/lists/tensor_overrides.py b/apex/amp/lists/tensor_overrides.py index 18f3e5dcf..d2783cede 100644 --- a/apex/amp/lists/tensor_overrides.py +++ b/apex/amp/lists/tensor_overrides.py @@ -15,6 +15,10 @@ '__matmul__', ]) +BFLOAT16_FUNCS = [ + '__matmul__', +] + FP32_FUNCS = compat.filter_attrs(MODULE, [ '__ipow__', '__pow__', @@ -56,7 +60,7 @@ # between `torch` and `torch.Tensor` (and check with `hasattr`, # because a few random ones aren't defined on Tensor) _self_mod = importlib.import_module(__name__) -for attrname in ['FP16_FUNCS', 'FP32_FUNCS', 'CASTS', 'SEQUENCE_CASTS']: +for attrname in ['FP16_FUNCS', 'BFLOAT16_FUNCS', 'FP32_FUNCS', 'CASTS', 'SEQUENCE_CASTS']: lst = getattr(_self_mod, attrname) for fn in getattr(torch_overrides, attrname): if hasattr(MODULE, fn): diff --git a/apex/amp/lists/torch_overrides.py b/apex/amp/lists/torch_overrides.py index 7dedb05a8..099887038 100644 --- a/apex/amp/lists/torch_overrides.py +++ b/apex/amp/lists/torch_overrides.py @@ -26,6 +26,27 @@ 'mv', ] +BFLOAT16_FUNCS = [ + # Low level functions wrapped by torch.nn layers. + # The wrapper layers contain the weights which are then passed in as a parameter + # to these functions. + 'conv1d', + 'conv2d', + 'conv3d', + 'conv_transpose1d', + 'conv_transpose2d', + 'conv_transpose3d', + 'conv_tbc', + + # BLAS + 'addmm', + 'addmv', + 'addr', + 'matmul', + 'mm', + 'mv', +] + FP32_FUNCS = [ # Pointwise 'acos', diff --git a/apex/amp/rnn_compat.py b/apex/amp/rnn_compat.py index d062ae265..987dba775 100644 --- a/apex/amp/rnn_compat.py +++ b/apex/amp/rnn_compat.py @@ -28,7 +28,7 @@ def has_old_rnns(): except: return False -def whitelist_rnn_cells(handle, verbose): +def whitelist_rnn_cells(cast_fn, handle, verbose): # Different module + function names in old/new RNN cases if has_old_rnns(): fn_names = ['RNNReLUCell', 'RNNTanhCell', 'LSTMCell', 'GRUCell'] @@ -40,7 +40,7 @@ def whitelist_rnn_cells(handle, verbose): # Insert casts on cell functions for fn in fn_names: - wrap.cached_cast(mod, fn, utils.maybe_half, handle, + wrap.cached_cast(mod, fn, cast_fn, handle, try_caching=True, verbose=verbose) if has_old_rnns(): diff --git a/apex/amp/scaler.py b/apex/amp/scaler.py index 99888bc6f..c11f70398 100644 --- a/apex/amp/scaler.py +++ b/apex/amp/scaler.py @@ -6,12 +6,18 @@ def scale_check_overflow_python(model_grad, master_grad, scale, check_overflow=False): # Exception handling for 18.04 compatibility if check_overflow: - cpu_sum = float(model_grad.float().sum()) + if model_grad.is_sparse: + cpu_sum = float(model_grad.float()._values().sum()) + else: + cpu_sum = float(model_grad.float().sum()) if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum: return True if master_grad is not model_grad: # copy_ probably internally short-circuits this - master_grad.copy_(model_grad) + if model_grad.is_sparse: + master_grad.copy_(model_grad.to_dense()) + else: + master_grad.copy_(model_grad) if scale != 1.0: master_grad.mul_(scale) return False @@ -19,7 +25,10 @@ def scale_check_overflow_python(model_grad, master_grad, scale, check_overflow=F def axpby_check_overflow_python(model_grad, stashed_grad, master_grad, a, b, check_overflow=False): # Exception handling for 18.04 compatibility if check_overflow: - cpu_sum = float(model_grad.float().sum()) + if model_grad.is_sparse: + cpu_sum = float(model_grad.float()._values().sum()) + else: + cpu_sum = float(model_grad.float().sum()) if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum: return True @@ -53,7 +62,7 @@ def __init__(self, self._scale_seq_len = scale_window self._unskipped = 0 self._has_overflow = False - self._overflow_buf = torch.cuda.IntTensor([0]) + self._overflow_buf = torch.tensor([0], dtype=torch.int, device='cuda') if multi_tensor_applier.available: import amp_C LossScaler.has_fused_kernel = multi_tensor_applier.available diff --git a/apex/amp/utils.py b/apex/amp/utils.py index 0590cd70a..c27fce5e2 100644 --- a/apex/amp/utils.py +++ b/apex/amp/utils.py @@ -62,6 +62,17 @@ def maybe_half(x, name='', verbose=False): print('Float->Half ({})'.format(name)) return x.half() +def maybe_bfloat16(x, name='', verbose=False): + if is_nested(x): + return type(x)([maybe_bfloat16(y) for y in x]) + + if not x.is_cuda or type_string(x) == 'BFloat16Tensor': + return x + else: + if verbose: + print('Float->BFloat16 ({})'.format(name)) + return x.bfloat16() + def maybe_float(x, name='', verbose=False): if is_nested(x): return type(x)([maybe_float(y) for y in x]) @@ -92,9 +103,12 @@ def cached_cast(cast_fn, x, cache): return type(x)([cached_cast(y) for y in x]) if x in cache: cached_x = cache[x] + next_functions_available = False if x.requires_grad and cached_x.requires_grad: + if len(cached_x.grad_fn.next_functions) > 1: + next_functions_available = True # Make sure x is actually cached_x's autograd parent. - if cached_x.grad_fn.next_functions[1][0].variable is not x: + if next_functions_available and cached_x.grad_fn.next_functions[1][0].variable is not x: raise RuntimeError("x and cache[x] both require grad, but x is not " "cache[x]'s parent. This is likely an error.") # During eval, it's possible to end up caching casted weights with @@ -114,6 +128,8 @@ def cached_cast(cast_fn, x, cache): # connection between x and cached_x. if torch.is_grad_enabled() and x.requires_grad != cached_x.requires_grad: del cache[x] + elif x.requires_grad and cached_x.requires_grad and not next_functions_available: + del cache[x] else: return cached_x @@ -189,22 +205,28 @@ def synthesize_flattened_rnn_weights(fp32_weights, fp16_weights.append(fp16_layer_weights) return fp16_weights +def _str_from_dtype(dtype=torch.float16): + type_to_str = {torch.float16 : 'Half', + torch.bfloat16 : 'BFloat16'} + return type_to_str[dtype] + # Roughly same as above, just the `fp32_weights` aren't nested. # Code kept separate for readability. def new_synthesize_flattened_rnn_weights(fp32_weights, fp16_flat_tensor, rnn_fn='', + dtype=torch.float16, verbose=False): fp16_weights = [] fp32_base_ptr = fp32_weights[0].data_ptr() for w_fp32 in fp32_weights: - w_fp16 = w_fp32.new().half() + w_fp16 = w_fp32.new().to(dtype=dtype) offset = (w_fp32.data_ptr() - fp32_base_ptr) // w_fp32.element_size() w_fp16.set_(fp16_flat_tensor.storage(), offset, w_fp32.shape) w_fp16.copy_(w_fp32) if verbose: - print('Float->Half ({})'.format(rnn_fn)) + print('Float->{} ({})'.format(_str_from_dtype(dtype), rnn_fn)) fp16_weights.append(w_fp16) return fp16_weights diff --git a/apex/amp/wrap.py b/apex/amp/wrap.py index 559d0558d..d0a23fdea 100644 --- a/apex/amp/wrap.py +++ b/apex/amp/wrap.py @@ -51,7 +51,8 @@ def wrapper(*args, **kwargs): if len(types) <= 1: return orig_fn(*args, **kwargs) - elif len(types) == 2 and types == set(['HalfTensor', 'FloatTensor']): + elif len(types) == 2 and (types == set(['HalfTensor', 'FloatTensor']) + or types == set(['BFloat16Tensor', 'FloatTensor'])): new_args = utils.casted_args(cast_fn, args, kwargs) @@ -79,7 +80,8 @@ def wrapper(seq, *args, **kwargs): types = set([utils.type_string(x) for x in seq]) if len(types) <= 1: return orig_fn(seq, *args, **kwargs) - elif types == set(['HalfTensor', 'FloatTensor']): + elif (types == set(['HalfTensor', 'FloatTensor']) or + types == set(['BFloat16Tensor', 'FloatTensor'])): cast_seq = utils.casted_args(maybe_float, seq, {}) return orig_fn(cast_seq, *args, **kwargs) @@ -102,6 +104,8 @@ def wrapper(arg0, *args, **kwargs): if utils.type_string(arg0) == 'HalfTensor': cast_fn = utils.maybe_half + if utils.type_string(arg0) == 'BFloat16Tensor': + cast_fn = utils.maybe_bfloat16 elif utils.type_string(arg0) == 'FloatTensor': cast_fn = utils.maybe_float else: @@ -119,12 +123,12 @@ def err_if_any_half(mod, fn, handle, custom_err_msg=None): @functools.wraps(orig_fn) def wrapper(*args, **kwargs): types = utils.collect_fp_tensor_types(args, kwargs) - if 'HalfTensor' in types: + if 'HalfTensor' in types or 'BFloat16Tensor' in types: if custom_err_msg: raise NotImplementedError(custom_err_msg) else: raise NotImplementedError('Cannot call in-place function ' + - '{} with fp16 arguments.'.format(fn)) + '{} with fp16 or bfloat16 args.'.format(fn)) else: return orig_fn(*args, **kwargs) utils.set_func_save(handle, mod, fn, wrapper) @@ -137,9 +141,9 @@ def err_if_arg0_half(mod, fn, handle, verbose=False): @functools.wraps(orig_fn) def wrapper(arg0, *args, **kwargs): assert compat.is_tensor_like(arg0) - if utils.type_string(arg0) == 'HalfTensor': + if utils.type_string(arg0) in {'HalfTensor', 'BFloat16Tensor'}: raise NotImplementedError('Cannot call in-place method ' + - '{} on fp16 Tensors.'.format(fn)) + '{} with fp16 or bfloat16 args.'.format(fn)) else: cast_fn = utils.verbosify(utils.maybe_float, fn, verbose) new_args = utils.casted_args(cast_fn, args, kwargs) @@ -219,7 +223,7 @@ def fwd_wrapper(*fargs, **fkwargs): return fwd_wrapper utils.set_func_save(handle, backend, fn, rnn_wrapper) -def new_rnn_cast(fn, handle, verbose=False): +def new_rnn_cast(fn, cast_fn, handle, verbose=False): # Forward+backward compatibility around https://github.com/pytorch/pytorch/pull/15744 # For rnn backend calls that route through _rnn_impls, we must patch the ref # that _rnn_impls stashed. For rnn backend calls that directly invoke @@ -232,7 +236,7 @@ def new_rnn_cast(fn, handle, verbose=False): assert isinstance(mod, rnn_compat.VariableFunctionsShim) fn = fn.lower() orig_fn = utils.get_func(mod, fn) - cast_fn = utils.verbosify(utils.maybe_half, fn, verbose) + cast_fn = utils.verbosify(cast_fn, fn, verbose) @functools.wraps(orig_fn) def wrapper(*args, **kwargs): # Exact call signature from modules/rnn.py @@ -247,14 +251,20 @@ def wrapper(*args, **kwargs): else: params_idx = 3 # PackedSequence case + if cast_fn == utils.maybe_half: + dtype = torch.half + elif cast_fn == utils.maybe_bfloat16: + dtype = torch.bfloat16 + else: + raise RuntimeError("Unsupported cast_fn passed. Supports only maybe_half and maybe_bfloat16") new_args = [] for i, arg in enumerate(args): if i == params_idx: num_params = sum([x.numel() for x in arg]) fp16_weight_buf = args[0].new_empty((num_params,), - dtype=torch.half) + dtype=dtype) casted_weights = utils.new_synthesize_flattened_rnn_weights( - arg, fp16_weight_buf, fn, verbose) + arg, fp16_weight_buf, fn, dtype, verbose) new_args.append(casted_weights) elif utils.is_fp_tensor(arg): new_args.append(cast_fn(arg)) diff --git a/apex/contrib/bottleneck/bottleneck.py b/apex/contrib/bottleneck/bottleneck.py index 5ea5694cc..8e98fc3c6 100644 --- a/apex/contrib/bottleneck/bottleneck.py +++ b/apex/contrib/bottleneck/bottleneck.py @@ -5,13 +5,13 @@ from torch import nn from apex import check_cudnn_version_and_warn -import fast_bottleneck +if check_cudnn_version_and_warn(__name__, 8400): + import fast_bottleneck +else: + fast_bottleneck = None import nccl_p2p_cuda as inc -assert check_cudnn_version_and_warn(__name__, 8400) - - def kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'): weight_tensor_nchw = tensor nn.init.kaiming_uniform_(weight_tensor_nchw, a=a, mode=mode, nonlinearity=nonlinearity) diff --git a/apex/contrib/bottleneck/halo_exchangers.py b/apex/contrib/bottleneck/halo_exchangers.py index 5697e3a69..b627fb2da 100644 --- a/apex/contrib/bottleneck/halo_exchangers.py +++ b/apex/contrib/bottleneck/halo_exchangers.py @@ -107,15 +107,10 @@ def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_inp right_tx = self.peer_pool.allocate_peer_tensors(list(right_output_halo.shape), right_output_halo.dtype, channels_last, True) pm.push_pull_halos_1d( self.diagnostics, self.explicit_nhwc, self.numSM, - left_output_halo, left_tx[self.rank_in_group], right_tx[self.wrap_around_left_rank_in_group], left_input_halo, - right_output_halo, right_tx[self.rank_in_group], left_tx[self.wrap_around_right_rank_in_group], right_input_halo, + self.left_zero, left_output_halo, left_tx[self.rank_in_group], right_tx[self.wrap_around_left_rank_in_group], left_input_halo, + self.right_zero, right_output_halo, right_tx[self.rank_in_group], left_tx[self.wrap_around_right_rank_in_group], right_input_halo, self.signals[self.wrap_around_left_rank_in_group], self.signals[self.wrap_around_right_rank_in_group], self.signals[self.rank_in_group] ) - # TODO: Add to push_pull_halos_1d kernel - if self.left_zero: - left_input_halo.zero_() - if self.right_zero: - right_input_halo.zero_() if not inplace: return left_input_halo, right_input_halo diff --git a/apex/contrib/clip_grad/clip_grad.py b/apex/contrib/clip_grad/clip_grad.py index 7d1eb8618..b6411352b 100644 --- a/apex/contrib/clip_grad/clip_grad.py +++ b/apex/contrib/clip_grad/clip_grad.py @@ -1,17 +1,18 @@ -import torch -from torch._six import inf from typing import Union, Iterable +import torch + _kernel_import_succeeded = False try: import amp_C from apex.multi_tensor_apply import multi_tensor_applier _kernel_import_succeeded = True -except: +except ImportError: _kernel_import_succeeded = False _tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]] + def clip_grad_norm_( parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0, error_if_nonfinite: bool = False) -> torch.Tensor: diff --git a/apex/contrib/conv_bias_relu/conv_bias_relu.py b/apex/contrib/conv_bias_relu/conv_bias_relu.py index b3e66c5a9..a75583cbf 100644 --- a/apex/contrib/conv_bias_relu/conv_bias_relu.py +++ b/apex/contrib/conv_bias_relu/conv_bias_relu.py @@ -1,18 +1,19 @@ -import pdb - import torch from torch.autograd import gradcheck -from apex import check_cudnn_version_and_warn -import fused_conv_bias_relu - -check_cudnn_version_and_warn(__name__, 8400) +try: + import fused_conv_bias_relu +except ImportError: + fused_conv_bias_relu = None class ConvBiasReLU_(torch.autograd.Function): @staticmethod - @torch.cuda.amp.custom_fwd(cast_inputs=torch.half) + @torch.amp.custom_fwd(cast_inputs=torch.half, device_type="cuda") def forward(ctx, x, weight, bias, padding, stride): + ctx.bias_shape = bias.shape if bias is not None else None + if bias is not None and bias.dim() != 1: + bias = bias.view(-1) outputs = fused_conv_bias_relu.forward([x, weight, bias], padding, stride) ctx.save_for_backward(x, weight, outputs[0]) ctx.padding = padding @@ -21,20 +22,27 @@ def forward(ctx, x, weight, bias, padding, stride): return outputs[0] @staticmethod - @torch.cuda.amp.custom_bwd + @torch.amp.custom_bwd(device_type="cuda") def backward(ctx, grad_output): bwd_args = [*ctx.saved_tensors, grad_output] padding = ctx.padding stride = ctx.stride grads = fused_conv_bias_relu.backward(bwd_args, padding, stride) - return grads[0], grads[1], grads[2], None, None + grad_bias = grads[2] + if grad_bias is not None and ctx.bias_shape is not None and grad_bias.shape != ctx.bias_shape: + grad_bias = grad_bias.view(ctx.bias_shape) + + return grads[0], grads[1], grad_bias, None, None class ConvBiasMaskReLU_(torch.autograd.Function): @staticmethod - @torch.cuda.amp.custom_fwd(cast_inputs=torch.half) + @torch.amp.custom_fwd(cast_inputs=torch.half, device_type="cuda") def forward(ctx, x, weight, bias, mask, padding, stride): + ctx.bias_shape = bias.shape if bias is not None else None + if bias is not None and bias.dim() != 1: + bias = bias.view(-1) outputs = fused_conv_bias_relu.forward_mask([x, weight, bias, mask], padding, stride) ctx.save_for_backward(x, weight, outputs[0]) ctx.padding = padding @@ -43,20 +51,27 @@ def forward(ctx, x, weight, bias, mask, padding, stride): return outputs[0] @staticmethod - @torch.cuda.amp.custom_bwd + @torch.amp.custom_bwd(device_type="cuda") def backward(ctx, grad_output): bwd_args = [*ctx.saved_tensors, grad_output] padding = ctx.padding stride = ctx.stride grads = fused_conv_bias_relu.backward(bwd_args, padding, stride) - return grads[0], grads[1], grads[2], None, None, None + grad_bias = grads[2] + if grad_bias is not None and ctx.bias_shape is not None and grad_bias.shape != ctx.bias_shape: + grad_bias = grad_bias.view(ctx.bias_shape) + + return grads[0], grads[1], grad_bias, None, None, None class ConvBias_(torch.autograd.Function): @staticmethod - @torch.cuda.amp.custom_fwd(cast_inputs=torch.half) + @torch.amp.custom_fwd(cast_inputs=torch.half, device_type="cuda") def forward(ctx, x, weight, bias, padding, stride): + ctx.bias_shape = bias.shape if bias is not None else None + if bias is not None and bias.dim() != 1: + bias = bias.view(-1) outputs = fused_conv_bias_relu.forward_no_relu([x, weight, bias], padding, stride) ctx.save_for_backward(x, weight) ctx.padding = padding @@ -65,17 +80,20 @@ def forward(ctx, x, weight, bias, padding, stride): return outputs[0] @staticmethod - @torch.cuda.amp.custom_bwd + @torch.amp.custom_bwd(device_type="cuda") def backward(ctx, grad_output): bwd_args = [*ctx.saved_tensors, grad_output] padding = ctx.padding stride = ctx.stride grads = fused_conv_bias_relu.backward_no_relu(bwd_args, padding, stride) - return grads[0], grads[1], grads[2], None, None + grad_bias = grads[2] + if grad_bias is not None and ctx.bias_shape is not None and grad_bias.shape != ctx.bias_shape: + grad_bias = grad_bias.view(ctx.bias_shape) + + return grads[0], grads[1], grad_bias, None, None ConvBiasReLU = ConvBiasReLU_.apply ConvBiasMaskReLU = ConvBiasMaskReLU_.apply ConvBias = ConvBias_.apply - diff --git a/apex/contrib/csrc/conv_bias_relu/conv_bias_relu_rocm.cpp b/apex/contrib/csrc/conv_bias_relu/conv_bias_relu_rocm.cpp new file mode 100644 index 000000000..7668053e2 --- /dev/null +++ b/apex/contrib/csrc/conv_bias_relu/conv_bias_relu_rocm.cpp @@ -0,0 +1,395 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Try to include PyTorch's MIOpen handle helper +#if defined(USE_ROCM) +#include +#endif + +#define MIOPEN_CHECK(status) \ + do { \ + if ((status) != miopenStatusSuccess) { \ + std::fprintf(stderr, "MIOpen error: %d\n", static_cast(status)); \ + std::abort(); \ + } \ + } while (0) + +// Plan Cache for MIOpen Fusion +struct FusionPlanEntry { + miopenFusionPlanDescriptor_t fusion_plan; + miopenFusionOpDescriptor_t conv_op; + miopenFusionOpDescriptor_t bias_op; + miopenFusionOpDescriptor_t activ_op; +}; + +static std::unordered_map plan_cache; + +static std::vector conv_bias_relu_forward_fused(const at::Tensor& x, + const at::Tensor& weight, + const at::Tensor& bias, + int64_t padding, + int64_t stride, + bool use_relu); + +static std::vector conv_bias_forward(const at::Tensor& x, + const at::Tensor& weight, + const at::Tensor& bias, + int64_t padding, + int64_t stride, + bool use_relu) { + miopenHandle_t handle = at::native::getMiopenHandle(); + bool is_nhwc = x.is_contiguous(at::MemoryFormat::ChannelsLast); + miopenDataType_t dtype = (x.scalar_type() == at::kHalf) ? miopenHalf : miopenFloat; + + miopenTensorDescriptor_t x_desc = nullptr; + miopenTensorDescriptor_t w_desc = nullptr; + miopenTensorDescriptor_t y_desc = nullptr; + miopenTensorDescriptor_t b_desc = nullptr; + miopenConvolutionDescriptor_t conv_desc = nullptr; + + auto cleanup = [&]() { + if (b_desc) { + miopenDestroyTensorDescriptor(b_desc); + } + if (y_desc) { + miopenDestroyTensorDescriptor(y_desc); + } + if (w_desc) { + miopenDestroyTensorDescriptor(w_desc); + } + if (x_desc) { + miopenDestroyTensorDescriptor(x_desc); + } + if (conv_desc) { + miopenDestroyConvolutionDescriptor(conv_desc); + } + }; + + MIOPEN_CHECK(miopenCreateTensorDescriptor(&x_desc)); + MIOPEN_CHECK(miopenCreateTensorDescriptor(&w_desc)); + MIOPEN_CHECK(miopenCreateTensorDescriptor(&y_desc)); + MIOPEN_CHECK(miopenCreateConvolutionDescriptor(&conv_desc)); + + if (is_nhwc) { + std::vector x_dims = {(int)x.size(0), (int)x.size(1), (int)x.size(2), (int)x.size(3)}; + std::vector x_strides = {(int)x.stride(0), (int)x.stride(1), (int)x.stride(2), (int)x.stride(3)}; + MIOPEN_CHECK(miopenSetTensorDescriptor(x_desc, dtype, 4, x_dims.data(), x_strides.data())); + + std::vector w_dims = {(int)weight.size(0), (int)weight.size(1), (int)weight.size(2), (int)weight.size(3)}; + std::vector w_strides = {(int)weight.stride(0), (int)weight.stride(1), (int)weight.stride(2), (int)weight.stride(3)}; + MIOPEN_CHECK(miopenSetTensorDescriptor(w_desc, dtype, 4, w_dims.data(), w_strides.data())); + } else { + MIOPEN_CHECK(miopenSet4dTensorDescriptor(x_desc, dtype, x.size(0), x.size(1), x.size(2), x.size(3))); + MIOPEN_CHECK(miopenSet4dTensorDescriptor(w_desc, dtype, weight.size(0), weight.size(1), weight.size(2), weight.size(3))); + } + + int64_t n = x.size(0); + int64_t oc = weight.size(0); + int64_t h = (x.size(2) + 2 * padding - weight.size(2)) / stride + 1; + int64_t w = (x.size(3) + 2 * padding - weight.size(3)) / stride + 1; + std::vector out_shape = {n, oc, h, w}; + + auto out = at::empty(out_shape, x.options().memory_format(is_nhwc ? at::MemoryFormat::ChannelsLast : at::MemoryFormat::Contiguous)); + + if (is_nhwc) { + std::vector y_dims = {(int)out.size(0), (int)out.size(1), (int)out.size(2), (int)out.size(3)}; + std::vector y_strides = {(int)out.stride(0), (int)out.stride(1), (int)out.stride(2), (int)out.stride(3)}; + MIOPEN_CHECK(miopenSetTensorDescriptor(y_desc, dtype, 4, y_dims.data(), y_strides.data())); + } else { + MIOPEN_CHECK(miopenSet4dTensorDescriptor(y_desc, dtype, out.size(0), out.size(1), out.size(2), out.size(3))); + } + + MIOPEN_CHECK(miopenInitConvolutionDescriptor(conv_desc, miopenConvolution, + padding, padding, stride, stride, 1, 1)); + + size_t workspace_size = 0; + MIOPEN_CHECK(miopenConvolutionForwardGetWorkSpaceSize(handle, w_desc, x_desc, conv_desc, y_desc, &workspace_size)); + auto workspace = at::empty({static_cast(workspace_size)}, x.options().dtype(at::kByte)); + void* workspace_ptr = workspace_size ? workspace.data_ptr() : nullptr; + + miopenConvFwdAlgorithm_t algo = miopenConvolutionFwdAlgoGEMM; + miopenConvAlgoPerf_t perf_results; + int returned_algo_count = 0; + miopenStatus_t status = miopenFindConvolutionForwardAlgorithm(handle, + x_desc, x.data_ptr(), + w_desc, weight.data_ptr(), + conv_desc, + y_desc, out.data_ptr(), + 1, &returned_algo_count, + &perf_results, + workspace_ptr, workspace_size, + false); + if (status == miopenStatusSuccess && returned_algo_count > 0) { + algo = perf_results.fwd_algo; + } + + float alpha = 1.0f; + float beta = 0.0f; + MIOPEN_CHECK(miopenConvolutionForward(handle, + &alpha, + x_desc, x.data_ptr(), + w_desc, weight.data_ptr(), + conv_desc, + algo, + &beta, + y_desc, out.data_ptr(), + workspace_ptr, workspace_size)); + + if (bias.defined()) { + MIOPEN_CHECK(miopenCreateTensorDescriptor(&b_desc)); + MIOPEN_CHECK(miopenSet4dTensorDescriptor(b_desc, dtype, 1, (int)oc, 1, 1)); + MIOPEN_CHECK(miopenConvolutionForwardBias(handle, &alpha, b_desc, bias.data_ptr(), &beta, y_desc, out.data_ptr())); + } + + if (use_relu) { + out = at::relu(out); + } + + cleanup(); + return {out}; +} + +static std::vector conv_bias_forward_dispatch(const at::Tensor& x, + const at::Tensor& weight, + const at::Tensor& bias, + int64_t padding, + int64_t stride, + bool use_relu, + bool use_fusion) { + if (x.is_cuda()) { + if (use_fusion) { + return conv_bias_relu_forward_fused(x, weight, bias, padding, stride, use_relu); + } + return conv_bias_forward(x, weight, bias, padding, stride, use_relu); + } + auto out = at::convolution(x, weight, bias, {stride, stride}, {padding, padding}, {1, 1}, false, {0, 0}, 1); + if (use_relu) { + out = at::relu(out); + } + return {out}; +} + +std::string get_cache_key(const at::Tensor& x, const at::Tensor& w, int64_t padding, int64_t stride, bool relu) { + return std::to_string(x.size(0)) + "_" + std::to_string(x.size(1)) + "_" + + std::to_string(x.size(2)) + "_" + std::to_string(x.size(3)) + "_" + + std::to_string(w.size(0)) + "_" + std::to_string(w.size(1)) + "_" + + std::to_string(w.size(2)) + "_" + std::to_string(w.size(3)) + "_" + + std::to_string(padding) + "_" + std::to_string(stride) + "_" + + (x.is_contiguous(at::MemoryFormat::ChannelsLast) ? "NHWC" : "NCHW") + "_" + + (relu ? "RELU" : "NORELU"); +} + +static std::vector conv_bias_relu_forward_fused(const at::Tensor& x, + const at::Tensor& weight, + const at::Tensor& bias, + int64_t padding, + int64_t stride, + bool use_relu) { + + miopenHandle_t handle = at::native::getMiopenHandle(); + std::string key = get_cache_key(x, weight, padding, stride, use_relu); + + bool is_nhwc = x.is_contiguous(at::MemoryFormat::ChannelsLast); + miopenDataType_t dtype = (x.scalar_type() == at::kHalf) ? miopenHalf : miopenFloat; + + // Check cache + if (plan_cache.find(key) == plan_cache.end()) { + miopenFusionPlanDescriptor_t plan = nullptr; + miopenTensorDescriptor_t input_desc = nullptr; + miopenTensorDescriptor_t weight_desc = nullptr; + miopenConvolutionDescriptor_t conv_desc = nullptr; + + MIOPEN_CHECK(miopenCreateTensorDescriptor(&input_desc)); + + if (is_nhwc) { + std::vector dims = {(int)x.size(0), (int)x.size(1), (int)x.size(2), (int)x.size(3)}; + std::vector strides = {(int)x.stride(0), (int)x.stride(1), (int)x.stride(2), (int)x.stride(3)}; + MIOPEN_CHECK(miopenSetTensorDescriptor(input_desc, dtype, 4, dims.data(), strides.data())); + } else { + MIOPEN_CHECK(miopenSet4dTensorDescriptor(input_desc, dtype, x.size(0), x.size(1), x.size(2), x.size(3))); + } + + MIOPEN_CHECK(miopenCreateFusionPlan(&plan, miopenVerticalFusion, input_desc)); + + // 1. Conv Op + miopenFusionOpDescriptor_t conv_op; + MIOPEN_CHECK(miopenCreateConvolutionDescriptor(&conv_desc)); + MIOPEN_CHECK(miopenInitConvolutionDescriptor(conv_desc, miopenConvolution, + padding, padding, stride, stride, 1, 1)); + + MIOPEN_CHECK(miopenCreateTensorDescriptor(&weight_desc)); + if (is_nhwc) { + std::vector w_dims = {(int)weight.size(0), (int)weight.size(1), (int)weight.size(2), (int)weight.size(3)}; + std::vector w_strides = {(int)weight.stride(0), (int)weight.stride(1), (int)weight.stride(2), (int)weight.stride(3)}; + MIOPEN_CHECK(miopenSetTensorDescriptor(weight_desc, dtype, 4, w_dims.data(), w_strides.data())); + } else { + MIOPEN_CHECK(miopenSet4dTensorDescriptor(weight_desc, dtype, weight.size(0), weight.size(1), weight.size(2), weight.size(3))); + } + + MIOPEN_CHECK(miopenCreateOpConvForward(plan, &conv_op, conv_desc, weight_desc)); + + // 2. Bias Op + miopenFusionOpDescriptor_t bias_op = nullptr; + if (bias.defined()) { + miopenTensorDescriptor_t bias_desc = nullptr; + MIOPEN_CHECK(miopenCreateTensorDescriptor(&bias_desc)); + if(is_nhwc) + MIOPEN_CHECK(miopenSet4dTensorDescriptor(bias_desc, dtype, 1, (int)x.size(3), 1, 1)); + else + MIOPEN_CHECK(miopenSet4dTensorDescriptor(bias_desc, dtype, 1, (int)x.size(1), 1, 1)); + MIOPEN_CHECK(miopenCreateOpBiasForward(plan, &bias_op, bias_desc)); + miopenDestroyTensorDescriptor(bias_desc); + } + + // 3. Activation Op + miopenFusionOpDescriptor_t activ_op = nullptr; + if (use_relu) { + MIOPEN_CHECK(miopenCreateOpActivationForward(plan, &activ_op, miopenActivationRELU)); + }else + { + MIOPEN_CHECK(miopenCreateOpActivationForward(plan, &activ_op, miopenActivationCLAMP)); + } + + // Compile + MIOPEN_CHECK(miopenCompileFusionPlan(handle, plan)); + + plan_cache[key].fusion_plan = plan; + plan_cache[key].conv_op = conv_op; + plan_cache[key].bias_op = bias_op; + plan_cache[key].activ_op = activ_op; + + miopenDestroyTensorDescriptor(input_desc); + miopenDestroyTensorDescriptor(weight_desc); + miopenDestroyConvolutionDescriptor(conv_desc); + } + + auto& entry = plan_cache[key]; + + // Calculate output dimensions + int64_t n = x.size(0); + int64_t oc = weight.size(0); + int64_t h = (x.size(2) + 2 * padding - weight.size(2)) / stride + 1; + int64_t w = (x.size(3) + 2 * padding - weight.size(3)) / stride + 1; + std::vector out_shape = {n, oc, h, w}; + + auto out = at::empty(out_shape, x.options().memory_format(is_nhwc ? at::MemoryFormat::ChannelsLast : at::MemoryFormat::Contiguous)); + + miopenTensorDescriptor_t input_desc = nullptr; + miopenTensorDescriptor_t output_desc = nullptr; + miopenOperatorArgs_t args = nullptr; + + MIOPEN_CHECK(miopenCreateTensorDescriptor(&input_desc)); + MIOPEN_CHECK(miopenCreateTensorDescriptor(&output_desc)); + + if (is_nhwc) { + std::vector x_dims = {(int)x.size(0), (int)x.size(1), (int)x.size(2), (int)x.size(3)}; + std::vector x_strides = {(int)x.stride(0), (int)x.stride(1), (int)x.stride(2), (int)x.stride(3)}; + MIOPEN_CHECK(miopenSetTensorDescriptor(input_desc, dtype, 4, x_dims.data(), x_strides.data())); + + std::vector y_dims = {(int)out.size(0), (int)out.size(1), (int)out.size(2), (int)out.size(3)}; + std::vector y_strides = {(int)out.stride(0), (int)out.stride(1), (int)out.stride(2), (int)out.stride(3)}; + MIOPEN_CHECK(miopenSetTensorDescriptor(output_desc, dtype, 4, y_dims.data(), y_strides.data())); + } else { + MIOPEN_CHECK(miopenSet4dTensorDescriptor(input_desc, dtype, (int)x.size(0), (int)x.size(1), (int)x.size(2), (int)x.size(3))); + MIOPEN_CHECK(miopenSet4dTensorDescriptor(output_desc, dtype, (int)out.size(0), (int)out.size(1), (int)out.size(2), (int)out.size(3))); + } + + MIOPEN_CHECK(miopenCreateOperatorArgs(&args)); + + float alpha = 1.0f, beta = 0.0f; + MIOPEN_CHECK(miopenSetOpArgsConvForward(args, entry.conv_op, &alpha, &beta, weight.data_ptr())); + if (entry.bias_op && bias.defined()) { + MIOPEN_CHECK(miopenSetOpArgsBiasForward(args, entry.bias_op, &alpha, &beta, bias.data_ptr())); + } + if (entry.activ_op) { + if (use_relu) + MIOPEN_CHECK(miopenSetOpArgsActivForward(args, entry.activ_op, &alpha, &beta, 0.0, 0.0, 0.0)); + else{ + float alpha1 = -3.402823466e+38F, beta1 = 3.402823466e+38F; + MIOPEN_CHECK(miopenSetOpArgsActivForward(args, entry.activ_op, &alpha, &beta, alpha1, beta1, 0.0)); + } + } + + MIOPEN_CHECK(miopenExecuteFusionPlan(handle, entry.fusion_plan, + input_desc, x.data_ptr(), + output_desc, out.data_ptr(), + args)); + + miopenDestroyOperatorArgs(args); + miopenDestroyTensorDescriptor(input_desc); + miopenDestroyTensorDescriptor(output_desc); + + return {out}; +} + +std::vector conv_bias_relu_forward(std::vector inputs, int64_t padding, int64_t stride) { + auto x = inputs[0]; + auto weight = inputs[1]; + auto bias = inputs[2]; + return conv_bias_forward_dispatch(x, weight, bias, padding, stride, true, true); +} + +std::vector conv_bias_relu_backward(std::vector inputs, int64_t padding, int64_t stride) { + auto x = inputs[0]; + auto weight = inputs[1]; + auto out = inputs[2]; + auto grad_output = inputs[3]; + auto grad_relu = grad_output * (out > 0).to(grad_output.dtype()); + int64_t bias_size = weight.size(0); + std::vector bias_sizes = {bias_size}; + auto grads = at::convolution_backward(grad_relu, x, weight, + bias_sizes, + {stride, stride}, {padding, padding}, {1, 1}, + false, {0, 0}, 1, + {true, true, true}); + return {std::get<0>(grads), std::get<1>(grads), std::get<2>(grads)}; +} + +std::vector conv_bias_forward_api(std::vector inputs, int64_t padding, int64_t stride) { + auto x = inputs[0]; + auto weight = inputs[1]; + auto bias = inputs[2]; + return conv_bias_forward_dispatch(x, weight, bias, padding, stride, false, true); +} + +std::vector conv_bias_backward(std::vector inputs, int64_t padding, int64_t stride) { + auto x = inputs[0]; + auto weight = inputs[1]; + auto grad_output = inputs[2]; + int64_t bias_size = weight.size(0); + std::vector bias_sizes = {bias_size}; + + auto grads = at::convolution_backward(grad_output, x, weight, + bias_sizes, + {stride, stride}, {padding, padding}, {1, 1}, + false, {0, 0}, 1, + {true, true, true}); + return {std::get<0>(grads), std::get<1>(grads), std::get<2>(grads)}; +} + +std::vector conv_bias_mask_relu_forward(std::vector inputs, int64_t padding, int64_t stride) { + auto x = inputs[0]; + auto weight = inputs[1]; + auto bias = inputs[2]; + auto out_vec = conv_bias_forward_dispatch(x, weight, bias, padding, stride, false, false); + auto out = out_vec[0]; + auto mask = inputs[3]; + out = out * mask.to(out.dtype()); + return {at::relu(out)}; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &conv_bias_relu_forward, "Fused Conv-Bias-ReLU forward (ROCm MIOpen Fusion)"); + m.def("backward", &conv_bias_relu_backward, "Conv-Bias-ReLU backward (ROCm)"); + m.def("forward_no_relu", &conv_bias_forward_api, "Conv-Bias forward (ROCm)"); + m.def("backward_no_relu", &conv_bias_backward, "Conv-Bias backward (ROCm)"); + m.def("forward_mask", &conv_bias_mask_relu_forward, "Conv-Bias-Mask-ReLU forward (ROCm)"); +} diff --git a/apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp b/apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp index 15393fbe4..f32b0131b 100644 --- a/apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp +++ b/apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp @@ -21,9 +21,9 @@ at::Tensor focal_loss_backward_cuda( // C++ interface -#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") #define CHECK_CONTIGUOUS(x) \ - AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") + TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") #define CHECK_INPUT(x) \ CHECK_CUDA(x); \ CHECK_CONTIGUOUS(x) @@ -64,7 +64,9 @@ at::Tensor focal_loss_backward( PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("forward", &focal_loss_forward, - "Focal loss calculation forward (CUDA)"); + "Focal loss calculation forward (CUDA)", + py::call_guard()); m.def("backward", &focal_loss_backward, - "Focal loss calculation backward (CUDA)"); + "Focal loss calculation backward (CUDA)", + py::call_guard()); } diff --git a/apex/contrib/csrc/focal_loss/focal_loss_cuda_kernel.cu b/apex/contrib/csrc/focal_loss/focal_loss_cuda_kernel.cu index bda4f8890..a93160bcb 100644 --- a/apex/contrib/csrc/focal_loss/focal_loss_cuda_kernel.cu +++ b/apex/contrib/csrc/focal_loss/focal_loss_cuda_kernel.cu @@ -11,6 +11,12 @@ template bool is_aligned(const void *ptr) noexcept { return !(iptr % alignof(T)); } +__device__ __forceinline__ float fast_pow_ml(float x, float y) { + // Hardware instructions: v_log_f32 followed by v_exp_f32 + // x^y = exp2(y * log2(x)) + return __builtin_amdgcn_exp2f(y * __builtin_amdgcn_logf(x)); +} + template __global__ void focal_loss_forward_cuda_kernel( @@ -94,7 +100,19 @@ __global__ void focal_loss_forward_cuda_kernel( coeff_b2 = sigma; } - accscalar_t coeff_f = coeff_f1 * ::pow(coeff_f2, gamma); + // Specialized pow for common gamma values to reduce VALU pressure + accscalar_t coeff_f; + if (gamma == 2.0f) { + coeff_f = coeff_f1 * (coeff_f2 * coeff_f2); + } else if (gamma == 1.0f) { + coeff_f = coeff_f1 * coeff_f2; + } else if (gamma == 0.0f) { + coeff_f = coeff_f1; + } else { + constexpr bool is_float_v = std::is_same::value; + coeff_f = coeff_f1 * (is_float_v ? (accscalar_t)fast_pow_ml(float(coeff_f2), gamma) : ::pow(coeff_f2, gamma)); + } + accscalar_t coeff_b = coeff_b1 * coeff_b2; accscalar_t loss_t = coeff_f * (base + off_a); diff --git a/apex/contrib/csrc/groupbn/batch_norm.cu b/apex/contrib/csrc/groupbn/batch_norm.cu index 1ec98eeea..92eb11fbe 100644 --- a/apex/contrib/csrc/groupbn/batch_norm.cu +++ b/apex/contrib/csrc/groupbn/batch_norm.cu @@ -63,34 +63,38 @@ at::Tensor nhwc_bn_fwd_train( const int grid_dim_x, const bool coop) { + auto memory_format = x.suggest_memory_format(); + const bool check_channels_last = x.is_contiguous(at::MemoryFormat::ChannelsLast); const int N = x.size(0); - const int H = x.size(1); - const int W = x.size(2); - const int C = x.size(3); + const int H = check_channels_last ? x.size(2) : x.size(1); + const int W = check_channels_last ? x.size(3) : x.size(2); + const int C = check_channels_last ? x.size(1) : x.size(3); // generating new magic number and use that for sync int* magic = magic_tensor.DATA_PTR(); *magic = (*magic + 1) & 0xff; // Allocate output tensor - at::Tensor y = at::empty({N, H, W, C}, x.options()); + at::Tensor y = check_channels_last ? at::empty({N, C, H, W}, x.options().memory_format(memory_format)) : at::empty({N, H, W, C}, x.options()); // Create wrapper NhwcBatchNorm *bn = new NhwcBatchNorm(); - bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group); - bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W); + bn->setInputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W, bn_group); + bn->setOutputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W); bn->setConstants(momentum, epsilon); // set pointers within the wrapper - bn->setInputOutputPointers(x.DATA_PTR(), + bn->setInputOutputPointers(x.contiguous(memory_format).DATA_PTR(), nullptr, - y.DATA_PTR(), + y.contiguous(memory_format).DATA_PTR(), nullptr); - bn->setWeightPointers({scale.DATA_PTR(), bias.DATA_PTR()}, {nullptr, nullptr}); - bn->setParameterPointers({running_mean.DATA_PTR(), running_inv_var.DATA_PTR()}); + bn->setWeightPointers({scale.contiguous().DATA_PTR(), + bias.contiguous().DATA_PTR()}, {nullptr, nullptr}); + bn->setParameterPointers({running_mean.contiguous().DATA_PTR(), + running_inv_var.DATA_PTR()}); // deal with workspace(s) auto workspace_bytes = bn->numWorkspaceBytes(); @@ -111,12 +115,12 @@ at::Tensor nhwc_bn_fwd_train( Workspace ws(total_workspace_bytes); std::vector workspace; - workspace.push_back(minibatch_mean.DATA_PTR()); - workspace.push_back(minibatch_inv_var.DATA_PTR()); + workspace.push_back(minibatch_mean.contiguous().DATA_PTR()); + workspace.push_back(minibatch_inv_var.contiguous().DATA_PTR()); auto stream = at::cuda::getCurrentCUDAStream().stream(); const int retired_cta_bytes = workspace_bytes[2]; - void* retired_ctas = ret_cta.DATA_PTR(); + void* retired_ctas = ret_cta.contiguous().DATA_PTR(); assert(ret_cta.size(0)>=retired_cta_bytes); workspace.push_back(retired_ctas); @@ -130,7 +134,7 @@ at::Tensor nhwc_bn_fwd_train( // Don't fuse in ReLU for now at least bn->fwd(stream, fuse_relu, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop); - return y; + return y.contiguous(memory_format); } at::Tensor nhwc_bn_fwd_eval( @@ -145,30 +149,34 @@ at::Tensor nhwc_bn_fwd_eval( const float epsilon, const bool fuse_relu) { + const bool check_channels_last = x.is_contiguous(at::MemoryFormat::ChannelsLast); + auto memory_format = x.suggest_memory_format(); const int N = x.size(0); - const int H = x.size(1); - const int W = x.size(2); - const int C = x.size(3); + const int H = check_channels_last ? x.size(2) : x.size(1); + const int W = check_channels_last ? x.size(3) : x.size(2); + const int C = check_channels_last ? x.size(1) : x.size(3); // Allocate output tensor - at::Tensor y = at::empty({N, H, W, C}, x.options()); + at::Tensor y = check_channels_last ? at::empty({N, C, H, W}, x.options().memory_format(memory_format)) : at::empty({N, H, W, C}, x.options()); // Create wrapper NhwcBatchNorm *bn = new NhwcBatchNorm(); - bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group); - bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W); + bn->setInputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W, bn_group); + bn->setOutputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W); bn->setConstants(momentum, epsilon); // set pointers within the wrapper - bn->setInputOutputPointers(x.DATA_PTR(), + bn->setInputOutputPointers(x.contiguous(memory_format).DATA_PTR(), nullptr, - y.DATA_PTR(), + y.contiguous(memory_format).DATA_PTR(), nullptr); - bn->setWeightPointers({scale.DATA_PTR(), bias.DATA_PTR()}, {nullptr, nullptr}); - bn->setParameterPointers({running_mean.DATA_PTR(), running_inv_var.DATA_PTR()}); + bn->setWeightPointers({scale.contiguous().DATA_PTR(), + bias.contiguous().DATA_PTR()}, {nullptr, nullptr}); + bn->setParameterPointers({running_mean.contiguous().DATA_PTR(), + running_inv_var.contiguous().DATA_PTR()}); // deal with workspace(s) auto workspace_bytes = bn->numWorkspaceBytes(); @@ -194,7 +202,7 @@ at::Tensor nhwc_bn_fwd_eval( auto stream = at::cuda::getCurrentCUDAStream().stream(); const int retired_cta_bytes = workspace_bytes[2]; - void* retired_ctas = ret_cta.DATA_PTR(); + void* retired_ctas = ret_cta.contiguous().DATA_PTR(); assert(ret_cta.size(0)>=retired_cta_bytes); workspace.push_back(retired_ctas); @@ -208,7 +216,7 @@ at::Tensor nhwc_bn_fwd_eval( // Don't fuse in ReLU for now at least bn->fwdInference(stream, fuse_relu); - return y; + return y.contiguous(memory_format); } @@ -235,10 +243,12 @@ std::vector nhwc_bn_bwd( const int grid_dim_x, const bool coop) { // shape + const bool check_channels_last = x.is_contiguous(at::MemoryFormat::ChannelsLast); + auto memory_format = x.suggest_memory_format(); const int N = x.size(0); - const int H = x.size(1); - const int W = x.size(2); - const int C = x.size(3); + const int H = check_channels_last ? x.size(2) : x.size(1); + const int W = check_channels_last ? x.size(3) : x.size(2); + const int C = check_channels_last ? x.size(1) : x.size(3); // generating new magic number and use that for sync int* magic = magic_tensor.DATA_PTR(); @@ -248,26 +258,30 @@ std::vector nhwc_bn_bwd( at::Tensor x_grad, scale_grad, bias_grad; // Allocate outputs - x_grad = at::empty_like(x); + x_grad = check_channels_last ? at::empty({N, C, H, W}, dy.options().memory_format(memory_format)) : at::empty_like(x); scale_grad = at::empty_like(scale); bias_grad = at::empty_like(bias); // Create wrapper NhwcBatchNorm *bn = new NhwcBatchNorm(); - bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group); - bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W); + bn->setInputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W, bn_group); + bn->setOutputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W); bn->setConstants(momentum, epsilon); // set pointers within the wrapper - bn->setInputOutputPointers(x.DATA_PTR(), - x_grad.DATA_PTR(), + bn->setInputOutputPointers(x.contiguous(memory_format).DATA_PTR(), + x_grad.contiguous(memory_format).DATA_PTR(), nullptr, - dy.DATA_PTR()); + dy.contiguous(memory_format).DATA_PTR()); - bn->setWeightPointers({scale.DATA_PTR(), bias.DATA_PTR()}, {scale_grad.DATA_PTR(), bias_grad.DATA_PTR()}); - bn->setParameterPointers({running_mean.DATA_PTR(), running_inv_var.DATA_PTR()}); + bn->setWeightPointers({scale.contiguous().DATA_PTR(), + bias.contiguous().DATA_PTR()}, + {scale_grad.DATA_PTR(), + bias_grad.DATA_PTR()}); + bn->setParameterPointers({running_mean.contiguous().DATA_PTR(), + running_inv_var.contiguous().DATA_PTR()}); // deal with workspace(s) auto workspace_bytes = bn->numWorkspaceBytes(); @@ -288,12 +302,12 @@ std::vector nhwc_bn_bwd( Workspace ws(total_workspace_bytes); std::vector workspace; - workspace.push_back(minibatch_mean.DATA_PTR()); - workspace.push_back(minibatch_inv_var.DATA_PTR()); + workspace.push_back(minibatch_mean.contiguous().DATA_PTR()); + workspace.push_back(minibatch_inv_var.contiguous().DATA_PTR()); auto stream = at::cuda::getCurrentCUDAStream().stream(); const int retired_cta_bytes = workspace_bytes[2]; - void* retired_ctas = ret_cta.DATA_PTR(); + void* retired_ctas = ret_cta.contiguous().DATA_PTR(); assert(ret_cta.size(0)>=retired_cta_bytes); workspace.push_back(retired_ctas); @@ -306,7 +320,7 @@ std::vector nhwc_bn_bwd( bn->dgrad(stream, fuse_relu, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop); - return std::vector{x_grad, scale_grad, bias_grad}; + return std::vector{x_grad.contiguous(memory_format), scale_grad, bias_grad}; } int nhwc_bn_fwd_occupancy() { diff --git a/apex/contrib/csrc/groupbn/batch_norm.h b/apex/contrib/csrc/groupbn/batch_norm.h index bb79d6758..e52751bce 100644 --- a/apex/contrib/csrc/groupbn/batch_norm.h +++ b/apex/contrib/csrc/groupbn/batch_norm.h @@ -26,7 +26,7 @@ #ifndef MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_H_ #define MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_H_ -#include +#include "dnn.h" #include #include @@ -35,7 +35,8 @@ #include "nhwc_batch_norm_kernel.h" #include "cuda_utils.h" - +#include "c10/macros/Macros.h" +#include #define VERBOSE_DEFAULT false @@ -63,8 +64,8 @@ class NhwcBatchNorm { dim3 calc_fwd_grid(int *loop, const int grid_dim_x); dim3 calc_bwd_grid(int *loop, const int grid_dim_x); - void setInputDescriptor(const cudnnTensorFormat_t format, - const cudnnDataType_t data_type, + void setInputDescriptor(const dnnTensorFormat_t format, + const dnnDataType_t data_type, int n, int c, int h, int w, int bn_group) { m_ = n * h * w; int m_bn_adjusted = m_ * bn_group; @@ -78,8 +79,8 @@ class NhwcBatchNorm { setTensorDescriptor(X_tensor_desc_, format, data_type, n, c, h, w); } - void setOutputDescriptor(const cudnnTensorFormat_t format, - const cudnnDataType_t data_type, + void setOutputDescriptor(const dnnTensorFormat_t format, + const dnnDataType_t data_type, int n, int c, int h, int w) { setTensorDescriptor(Y_tensor_desc_, format, data_type, n, c, h, w); } @@ -120,13 +121,20 @@ class NhwcBatchNorm { eps_ = eps; } - void processCudnnStatus(const cudnnStatus_t& status, + void processCudnnStatus(const dnnStatus_t& status, const std::string& string = std::string(), bool verbose = VERBOSE_DEFAULT) { - if (status != CUDNN_STATUS_SUCCESS) +#ifdef USE_ROCM + if (status != DNN_STATUS_SUCCESS) + LOG(FATAL) << string << " " << miopenGetErrorString(status); + else if (verbose) + LOG(INFO) << string << " " << miopenGetErrorString(status); +#else + if (status != DNN_STATUS_SUCCESS) LOG(FATAL) << string << " " << cudnnGetErrorString(status); else if (verbose) LOG(INFO) << string << " " << cudnnGetErrorString(status); +#endif } void checkCudaStatus(const std::string& string = std::string(), @@ -149,8 +157,8 @@ class NhwcBatchNorm { return retired_cta_bytes; } - cudnnTensorDescriptor_t X_tensor_desc_ = nullptr; - cudnnTensorDescriptor_t Y_tensor_desc_ = nullptr; + dnnTensorDescriptor_t X_tensor_desc_ = nullptr; + dnnTensorDescriptor_t Y_tensor_desc_ = nullptr; void* X_ = nullptr; void* dX_ = nullptr; @@ -182,24 +190,36 @@ class NhwcBatchNorm { std::string name_; private: - void setTensorDescriptor(cudnnTensorDescriptor_t descriptor, - cudnnTensorFormat_t format, - cudnnDataType_t data_type, + void setTensorDescriptor(dnnTensorDescriptor_t descriptor, + dnnTensorFormat_t format, + dnnDataType_t data_type, int n, int c, int h, int w) { - cudnnStatus_t status = CUDNN_STATUS_SUCCESS; + dnnStatus_t status = DNN_STATUS_SUCCESS; +#ifdef USE_ROCM + status = miopenSet4dTensorDescriptor(descriptor, data_type, n, c, h, w); +#else status = cudnnSetTensor4dDescriptor(descriptor, format, data_type, n, c, h, w); +#endif processCudnnStatus(status, "set tensor descriptor"); } - void createTensorDescriptor(cudnnTensorDescriptor_t *descriptor) { - cudnnStatus_t status = CUDNN_STATUS_SUCCESS; + void createTensorDescriptor(dnnTensorDescriptor_t *descriptor) { + dnnStatus_t status = DNN_STATUS_SUCCESS; +#ifdef USE_ROCM + status = miopenCreateTensorDescriptor(descriptor); +#else status = cudnnCreateTensorDescriptor(descriptor); +#endif processCudnnStatus(status, "create tensor_descriptor"); } - void destroyTensorDescriptor(cudnnTensorDescriptor_t descriptor) { - cudnnStatus_t status = CUDNN_STATUS_SUCCESS; + void destroyTensorDescriptor(dnnTensorDescriptor_t descriptor) { + dnnStatus_t status = DNN_STATUS_SUCCESS; +#ifdef USE_ROCM + status = miopenDestroyTensorDescriptor(descriptor); +#else status = cudnnDestroyTensorDescriptor(descriptor); +#endif processCudnnStatus(status, "destroy tensor_descriptor"); } @@ -216,18 +236,18 @@ class NhwcBatchNorm { // Kernel params static const int USE_ONLINE_APPROACH = 1; static const int THREADS_PER_CTA = 512; - static const int THREADS_PER_PIXEL = 16; - static const int C_ELEMENTS_PER_CTA = 64; + static const int THREADS_PER_PIXEL = 32; + static const int C_ELEMENTS_PER_CTA = 128; static const int ELEMENTS_PER_LDG = C_ELEMENTS_PER_CTA / THREADS_PER_PIXEL; static const int MAX_SMEM_WITHOUT_OPT_IN = 48 * 1024; typedef uint16_t StorageType; //typedef float StorageType; // increasing this to 6 causes spills in fwd kernel! - static const int PIXELS_PER_THREAD_IN_REGISTERS_FWD = 5; - static const int PIXELS_PER_THREAD_IN_REGISTERS_BWD = 3; - static const int PIXELS_PER_THREAD_IN_SMEM_FWD = 10; - static const int PIXELS_PER_THREAD_IN_SMEM_BWD = 5; + static const int PIXELS_PER_THREAD_IN_REGISTERS_FWD = 1; + static const int PIXELS_PER_THREAD_IN_REGISTERS_BWD = 1; + static const int PIXELS_PER_THREAD_IN_SMEM_FWD = 0; + static const int PIXELS_PER_THREAD_IN_SMEM_BWD = 0; static const int PIXELS_PER_THREAD_FWD = PIXELS_PER_THREAD_IN_REGISTERS_FWD + \ PIXELS_PER_THREAD_IN_SMEM_FWD; @@ -259,6 +279,57 @@ class NhwcBatchNorm { void _fwdKernelLauncher(cudaStream_t stream, NhwcBatchNormFwdParams params, dim3 grid_dim, int outer_loops, bool use_relu, const int occupancy, const bool coop) { +#ifdef USE_ROCM +#define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY, COOP) \ + do { \ + CHECK(SMEM_SIZE_FWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \ + auto fwd_func = nhwc_batch_norm_fwd< \ + StorageType, \ + THREADS_PER_CTA, \ + THREADS_PER_PIXEL, \ + PIXELS_PER_THREAD_IN_REGISTERS_FWD, \ + PIXELS_PER_THREAD_IN_SMEM_FWD, \ + ELEMENTS_PER_LDG, \ + USE_ONLINE_APPROACH, \ + OUTER_LOOPS, \ + USE_RELU, \ + USE_ADD_RELU, \ + COMPILED_FOR_OCCUPANCY>; \ + if (COMPILED_FOR_OCCUPANCY > 1) { \ + hipFuncSetAttribute((void *) fwd_func, hipFuncAttributePreferredSharedMemoryCarveout, 100); \ + checkCudaStatus(name_ + " fwd ser coop kernel (cudaFuncSetAttribute carveout)"); \ + } \ + void *params_ptr = static_cast(¶ms); \ + using FWD_FUNC = decltype(nhwc_batch_norm_fwd< \ + StorageType, \ + THREADS_PER_CTA, \ + THREADS_PER_PIXEL, \ + PIXELS_PER_THREAD_IN_REGISTERS_FWD, \ + PIXELS_PER_THREAD_IN_SMEM_FWD, \ + ELEMENTS_PER_LDG, \ + USE_ONLINE_APPROACH, \ + OUTER_LOOPS, \ + USE_RELU, \ + USE_ADD_RELU, \ + COMPILED_FOR_OCCUPANCY>); \ + if (COOP) { \ + hipLaunchCooperativeKernel(fwd_func, \ + grid_dim, \ + THREADS_PER_CTA, \ + ¶ms_ptr, \ + SMEM_SIZE_FWD, \ + stream); \ + } else { \ + hipLaunchKernel((void *) fwd_func, \ + grid_dim, \ + THREADS_PER_CTA, \ + ¶ms_ptr, \ + SMEM_SIZE_FWD, \ + stream); \ + } \ + checkCudaStatus(name_ + " fwd ser coop kernel"); \ + } while (0) +#else #define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY, COOP) \ do { \ CHECK(SMEM_SIZE_FWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \ @@ -308,6 +379,7 @@ class NhwcBatchNorm { } \ checkCudaStatus(name_ + " fwd ser coop kernel"); \ } while (0) +#endif // Don't try for an occupancy > 2 as this will squeeze register use and create spills. if (outer_loops == 1 && use_relu) { @@ -338,6 +410,99 @@ class NhwcBatchNorm { void _bwdKernelLauncher(cudaStream_t stream, NhwcBatchNormBwdParams params, dim3 grid_dim, int outer_loops, bool use_relu, const int occupancy, const bool coop) { +#ifdef USE_ROCM +#define LAUNCH_BWD_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \ + do { \ + CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \ + auto bwd_func = nhwc_batch_norm_bwd< \ + StorageType, \ + THREADS_PER_CTA, \ + THREADS_PER_PIXEL, \ + PIXELS_PER_THREAD_IN_REGISTERS_BWD, \ + PIXELS_PER_THREAD_IN_SMEM_BWD, \ + ELEMENTS_PER_LDG, \ + USE_ONLINE_APPROACH, \ + OUTER_LOOPS, \ + COMPILED_FOR_OCCUPANCY>; \ + if (COMPILED_FOR_OCCUPANCY > 1) { \ + hipFuncSetAttribute((void *) bwd_func, hipFuncAttributePreferredSharedMemoryCarveout, 100); \ + checkCudaStatus(name_ + " bwd coop serial kernel (cudaFuncSetAttribute carveout)"); \ + } \ + void *params_ptr = static_cast(¶ms); \ + using BWD_FUNC = decltype(nhwc_batch_norm_bwd< \ + StorageType, \ + THREADS_PER_CTA, \ + THREADS_PER_PIXEL, \ + PIXELS_PER_THREAD_IN_REGISTERS_BWD, \ + PIXELS_PER_THREAD_IN_SMEM_BWD, \ + ELEMENTS_PER_LDG, \ + USE_ONLINE_APPROACH, \ + OUTER_LOOPS, \ + COMPILED_FOR_OCCUPANCY>); \ + if (COOP) { \ + hipLaunchCooperativeKernel(bwd_func, \ + grid_dim, \ + THREADS_PER_CTA, \ + ¶ms_ptr, \ + SMEM_SIZE_BWD, \ + stream); \ + } else { \ + hipLaunchKernel((void *) bwd_func, \ + grid_dim, \ + THREADS_PER_CTA, \ + ¶ms_ptr, \ + SMEM_SIZE_BWD, \ + stream); \ + } \ + checkCudaStatus(name_ + " bwd coop serial kernel"); \ + } while (0) + +#define LAUNCH_BWD_RELU_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \ + do { \ + CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \ + auto bwd_relu_func = nhwc_batch_norm_bwd_relu< \ + StorageType, \ + THREADS_PER_CTA, \ + THREADS_PER_PIXEL, \ + PIXELS_PER_THREAD_IN_REGISTERS_BWD, \ + PIXELS_PER_THREAD_IN_SMEM_BWD, \ + ELEMENTS_PER_LDG, \ + USE_ONLINE_APPROACH, \ + OUTER_LOOPS, \ + COMPILED_FOR_OCCUPANCY>; \ + if (COMPILED_FOR_OCCUPANCY > 1) { \ + hipFuncSetAttribute((void *) bwd_relu_func, hipFuncAttributePreferredSharedMemoryCarveout, 100); \ + checkCudaStatus(name_ + " bwd-relu coop serial kernel (cudaFuncSetAttribute carveout)"); \ + } \ + void *params_ptr = static_cast(¶ms); \ + using BWD_RELU_FUNC = decltype(nhwc_batch_norm_bwd_relu< \ + StorageType, \ + THREADS_PER_CTA, \ + THREADS_PER_PIXEL, \ + PIXELS_PER_THREAD_IN_REGISTERS_BWD, \ + PIXELS_PER_THREAD_IN_SMEM_BWD, \ + ELEMENTS_PER_LDG, \ + USE_ONLINE_APPROACH, \ + OUTER_LOOPS, \ + COMPILED_FOR_OCCUPANCY>); \ + if (COOP) { \ + hipLaunchCooperativeKernel(bwd_relu_func, \ + grid_dim, \ + THREADS_PER_CTA, \ + ¶ms_ptr, \ + SMEM_SIZE_BWD, \ + stream); \ + } else { \ + hipLaunchKernel((void *) bwd_relu_func, \ + grid_dim, \ + THREADS_PER_CTA, \ + ¶ms_ptr, \ + SMEM_SIZE_BWD, \ + stream); \ + } \ + checkCudaStatus(name_ + " bwd-relu coop serial kernel"); \ + } while (0) +#else #define LAUNCH_BWD_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \ do { \ CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \ @@ -429,6 +594,7 @@ class NhwcBatchNorm { } \ checkCudaStatus(name_ + " bwd-relu coop serial kernel"); \ } while (0) +#endif // Don't try for an occupancy > 2 as this will squeeze register use and create spills. if (outer_loops == 1 && use_relu) { @@ -460,7 +626,7 @@ class NhwcBatchNorm { // Calculate the expected fwd kernel occupancy, as dictated by shared memory usage. static int smem_driven_fwd_occupancy(int device_id, const int max_cta_per_sm) { using namespace at::cuda::utils; - int fwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG*sizeof(float); + int fwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/at::cuda::warp_size())*ELEMENTS_PER_LDG*sizeof(float); int fwd_smem_bytes = SMEM_SIZE_FWD + fwd_reduction_bytes; int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / fwd_smem_bytes; return std::min(max_cta_per_sm, occupancy); @@ -469,7 +635,7 @@ class NhwcBatchNorm { // Calculate the expected bwd kernel occupancy, as dictated by shared memory usage. static int smem_driven_bwd_occupancy(int device_id, const int max_cta_per_sm) { using namespace at::cuda::utils; - int bwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG*sizeof(float); + int bwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/at::cuda::warp_size())*ELEMENTS_PER_LDG*sizeof(float); int bwd_smem_bytes = SMEM_SIZE_BWD + bwd_reduction_bytes; int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / bwd_smem_bytes; return std::min(max_cta_per_sm, occupancy); diff --git a/apex/contrib/csrc/groupbn/batch_norm_add_relu.cu b/apex/contrib/csrc/groupbn/batch_norm_add_relu.cu index e383fb800..d3cc61523 100644 --- a/apex/contrib/csrc/groupbn/batch_norm_add_relu.cu +++ b/apex/contrib/csrc/groupbn/batch_norm_add_relu.cu @@ -65,36 +65,40 @@ at::Tensor nhwc_bn_addrelu_fwd_train( const int grid_dim_x, const bool coop) { + auto memory_format = x.suggest_memory_format(); + const bool check_channels_last = x.is_contiguous(at::MemoryFormat::ChannelsLast); const int N = x.size(0); - const int H = x.size(1); - const int W = x.size(2); - const int C = x.size(3); + const int H = check_channels_last ? x.size(2) : x.size(1); + const int W = check_channels_last ? x.size(3) : x.size(2); + const int C = check_channels_last ? x.size(1) : x.size(3); // generating new magic number and use that for sync int* magic = magic_tensor.DATA_PTR(); *magic = (*magic + 1) & 0xff; // Allocate output tensor - at::Tensor y = at::empty({N, H, W, C}, x.options()); + at::Tensor y = check_channels_last? at::empty({N, C, H, W}, x.options().memory_format(memory_format)) : at::empty({N, H, W, C}, x.options()); // Create wrapper NhwcBatchNormAddRelu *bn = new NhwcBatchNormAddRelu(); - bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group); - bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W); + bn->setInputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W, bn_group); + bn->setOutputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W); bn->setConstants(momentum, epsilon); // set pointers within the wrapper - bn->setInputOutputPointers(x.DATA_PTR(), + bn->setInputOutputPointers(x.contiguous(memory_format).DATA_PTR(), nullptr, - y.DATA_PTR(), + y.contiguous(memory_format).DATA_PTR(), nullptr, - z.DATA_PTR(), + z.contiguous(memory_format).DATA_PTR(), nullptr); - bn->setWeightPointers({scale.DATA_PTR(), bias.DATA_PTR()}, {nullptr, nullptr}); - bn->setParameterPointers({running_mean.DATA_PTR(), running_inv_var.DATA_PTR()}); + bn->setWeightPointers({scale.contiguous().DATA_PTR(), + bias.contiguous().DATA_PTR()}, {nullptr, nullptr}); + bn->setParameterPointers({running_mean.contiguous().DATA_PTR(), + running_inv_var.contiguous().DATA_PTR()}); // deal with workspace(s) auto workspace_bytes = bn->numWorkspaceBytes(); @@ -115,13 +119,13 @@ at::Tensor nhwc_bn_addrelu_fwd_train( Workspace ws(total_workspace_bytes); std::vector workspace; - workspace.push_back(minibatch_mean.DATA_PTR()); - workspace.push_back(minibatch_inv_var.DATA_PTR()); - workspace.push_back(bitmask.DATA_PTR()); + workspace.push_back(minibatch_mean.contiguous().DATA_PTR()); + workspace.push_back(minibatch_inv_var.contiguous().DATA_PTR()); + workspace.push_back(bitmask.contiguous().DATA_PTR()); auto stream = at::cuda::getCurrentCUDAStream().stream(); const int retired_cta_bytes = workspace_bytes[3]; - void* retired_ctas = ret_cta.DATA_PTR(); + void* retired_ctas = ret_cta.contiguous().DATA_PTR(); assert(ret_cta.size(0)>=retired_cta_bytes); workspace.push_back(retired_ctas); @@ -136,7 +140,7 @@ at::Tensor nhwc_bn_addrelu_fwd_train( // Don't fuse in ReLU for now at least bn->fwd(stream, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop); - return y; + return y.contiguous(memory_format); } at::Tensor nhwc_bn_addrelu_fwd_eval( @@ -151,32 +155,36 @@ at::Tensor nhwc_bn_addrelu_fwd_eval( const float momentum, const float epsilon) { + auto memory_format = x.suggest_memory_format(); + const bool check_channels_last = x.is_contiguous(at::MemoryFormat::ChannelsLast); const int N = x.size(0); - const int H = x.size(1); - const int W = x.size(2); - const int C = x.size(3); + const int H = check_channels_last ? x.size(2) : x.size(1); + const int W = check_channels_last ? x.size(3) : x.size(2); + const int C = check_channels_last ? x.size(1) : x.size(3); // Allocate output tensor - at::Tensor y = at::empty({N, H, W, C}, x.options()); + at::Tensor y = check_channels_last? at::empty({N, C, H, W}, x.options().memory_format(memory_format)): at::empty({N, H, W, C}, x.options()); // Create wrapper NhwcBatchNormAddRelu *bn = new NhwcBatchNormAddRelu(); - bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group); - bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W); + bn->setInputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W, bn_group); + bn->setOutputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W); bn->setConstants(momentum, epsilon); // set pointers within the wrapper - bn->setInputOutputPointers(x.DATA_PTR(), + bn->setInputOutputPointers(x.contiguous(memory_format).DATA_PTR(), nullptr, - y.DATA_PTR(), + y.contiguous(memory_format).DATA_PTR(), nullptr, - z.DATA_PTR(), + z.contiguous(memory_format).DATA_PTR(), nullptr); - bn->setWeightPointers({scale.DATA_PTR(), bias.DATA_PTR()}, {nullptr, nullptr}); - bn->setParameterPointers({running_mean.DATA_PTR(), running_inv_var.DATA_PTR()}); + bn->setWeightPointers({scale.contiguous().DATA_PTR(), + bias.contiguous().DATA_PTR()}, {nullptr, nullptr}); + bn->setParameterPointers({running_mean.contiguous().DATA_PTR(), + running_inv_var.contiguous().DATA_PTR()}); // deal with workspace(s) auto workspace_bytes = bn->numWorkspaceBytes(); @@ -203,7 +211,7 @@ at::Tensor nhwc_bn_addrelu_fwd_eval( auto stream = at::cuda::getCurrentCUDAStream().stream(); const int retired_cta_bytes = workspace_bytes[3]; - void* retired_ctas = ret_cta.DATA_PTR(); + void* retired_ctas = ret_cta.contiguous().DATA_PTR(); assert(ret_cta.size(0)>=retired_cta_bytes); workspace.push_back(retired_ctas); @@ -217,7 +225,7 @@ at::Tensor nhwc_bn_addrelu_fwd_eval( // Don't fuse in ReLU for now at least bn->fwdInference(stream); - return y; + return y.contiguous(memory_format); } @@ -244,10 +252,12 @@ std::vector nhwc_bn_addrelu_bwd( const int grid_dim_x, const bool coop) { // shape + auto memory_format = x.suggest_memory_format(); + const bool check_channels_last = x.is_contiguous(at::MemoryFormat::ChannelsLast); const int N = x.size(0); - const int H = x.size(1); - const int W = x.size(2); - const int C = x.size(3); + const int H = check_channels_last ? x.size(2) : x.size(1); + const int W = check_channels_last ? x.size(3) : x.size(2); + const int C = check_channels_last ? x.size(1) : x.size(3); // generating new magic number and use that for sync int* magic = magic_tensor.DATA_PTR(); @@ -257,29 +267,32 @@ std::vector nhwc_bn_addrelu_bwd( at::Tensor x_grad, z_grad, scale_grad, bias_grad; // Allocate outputs - x_grad = at::empty_like(x); - z_grad = at::empty_like(x); + x_grad = check_channels_last ? at::empty({N, C, H, W}, dy.options().memory_format(memory_format)) : at::empty_like(x); + z_grad = check_channels_last ? at::empty({N, C, H, W}, dy.options().memory_format(memory_format)) : at::empty_like(x); scale_grad = at::empty_like(scale); bias_grad = at::empty_like(bias); // Create wrapper NhwcBatchNormAddRelu *bn = new NhwcBatchNormAddRelu(); - bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group); - bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W); + bn->setInputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W, bn_group); + bn->setOutputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W); bn->setConstants(momentum, epsilon); // set pointers within the wrapper - bn->setInputOutputPointers(x.DATA_PTR(), - x_grad.DATA_PTR(), + bn->setInputOutputPointers(x.contiguous(memory_format).DATA_PTR(), + x_grad.contiguous(memory_format).DATA_PTR(), nullptr, - dy.DATA_PTR(), + dy.contiguous(memory_format).DATA_PTR(), nullptr, - z_grad.DATA_PTR()); + z_grad.contiguous(memory_format).DATA_PTR()); - bn->setWeightPointers({scale.DATA_PTR(), bias.DATA_PTR()}, {scale_grad.DATA_PTR(), bias_grad.DATA_PTR()}); - bn->setParameterPointers({running_mean.DATA_PTR(), running_inv_var.DATA_PTR()}); + bn->setWeightPointers({scale.contiguous().DATA_PTR(), + bias.contiguous().DATA_PTR()}, + {scale_grad.DATA_PTR(), bias_grad.DATA_PTR()}); + bn->setParameterPointers({running_mean.contiguous().DATA_PTR(), + running_inv_var.contiguous().DATA_PTR()}); // deal with workspace(s) auto workspace_bytes = bn->numWorkspaceBytes(); @@ -300,13 +313,13 @@ std::vector nhwc_bn_addrelu_bwd( Workspace ws(total_workspace_bytes); std::vector workspace; - workspace.push_back(minibatch_mean.DATA_PTR()); - workspace.push_back(minibatch_inv_var.DATA_PTR()); - workspace.push_back(bitmask.DATA_PTR()); + workspace.push_back(minibatch_mean.contiguous().DATA_PTR()); + workspace.push_back(minibatch_inv_var.contiguous().DATA_PTR()); + workspace.push_back(bitmask.contiguous().DATA_PTR()); auto stream = at::cuda::getCurrentCUDAStream().stream(); const int retired_cta_bytes = workspace_bytes[3]; - void* retired_ctas = ret_cta.DATA_PTR(); + void* retired_ctas = ret_cta.contiguous().DATA_PTR(); assert(ret_cta.size(0)>=retired_cta_bytes); workspace.push_back(retired_ctas); @@ -319,7 +332,7 @@ std::vector nhwc_bn_addrelu_bwd( bn->dgrad(stream, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop); - return std::vector{x_grad, z_grad, scale_grad, bias_grad}; + return std::vector{x_grad.contiguous(memory_format), z_grad.contiguous(memory_format), scale_grad, bias_grad}; } int nhwc_bn_addrelu_fwd_occupancy() { diff --git a/apex/contrib/csrc/groupbn/batch_norm_add_relu.h b/apex/contrib/csrc/groupbn/batch_norm_add_relu.h index 3dfe7b269..0481a9408 100644 --- a/apex/contrib/csrc/groupbn/batch_norm_add_relu.h +++ b/apex/contrib/csrc/groupbn/batch_norm_add_relu.h @@ -26,7 +26,7 @@ #ifndef MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_ADD_RELU_H_ #define MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_ADD_RELU_H_ -#include +#include "dnn.h" #include #include @@ -35,7 +35,16 @@ #include "nhwc_batch_norm_kernel.h" #include "cuda_utils.h" +#include "c10/macros/Macros.h" +#include +#ifdef USE_ROCM +using bitmask_t = uint64_t; +using bitmask_pyt_t = int64_t; +#else +using bitmask_t = unsigned int; +using bitmask_pyt_t = int32_t; +#endif #define VERBOSE_DEFAULT false @@ -63,8 +72,8 @@ class NhwcBatchNormAddRelu { dim3 calc_fwd_grid(int *loop, const int grid_dim_x); dim3 calc_bwd_grid(int *loop, const int grid_dim_x); - void setInputDescriptor(const cudnnTensorFormat_t format, - const cudnnDataType_t data_type, + void setInputDescriptor(const dnnTensorFormat_t format, + const dnnDataType_t data_type, int n, int c, int h, int w, int bn_group) { m_ = n * h * w; int m_bn_adjusted = m_ * bn_group; @@ -78,8 +87,8 @@ class NhwcBatchNormAddRelu { setTensorDescriptor(X_tensor_desc_, format, data_type, n, c, h, w); } - void setOutputDescriptor(const cudnnTensorFormat_t format, - const cudnnDataType_t data_type, + void setOutputDescriptor(const dnnTensorFormat_t format, + const dnnDataType_t data_type, int n, int c, int h, int w) { setTensorDescriptor(Y_tensor_desc_, format, data_type, n, c, h, w); } @@ -122,13 +131,20 @@ class NhwcBatchNormAddRelu { eps_ = eps; } - void processCudnnStatus(const cudnnStatus_t& status, + void processCudnnStatus(const dnnStatus_t& status, const std::string& string = std::string(), bool verbose = VERBOSE_DEFAULT) { - if (status != CUDNN_STATUS_SUCCESS) +#ifdef USE_ROCM + if (status != DNN_STATUS_SUCCESS) + LOG(FATAL) << string << " " << miopenGetErrorString(status); + else if (verbose) + LOG(INFO) << string << " " << miopenGetErrorString(status); +#else + if (status != DNN_STATUS_SUCCESS) LOG(FATAL) << string << " " << cudnnGetErrorString(status); else if (verbose) LOG(INFO) << string << " " << cudnnGetErrorString(status); +#endif } void checkCudaStatus(const std::string& string = std::string(), @@ -151,8 +167,8 @@ class NhwcBatchNormAddRelu { return retired_cta_bytes; } - cudnnTensorDescriptor_t X_tensor_desc_ = nullptr; - cudnnTensorDescriptor_t Y_tensor_desc_ = nullptr; + dnnTensorDescriptor_t X_tensor_desc_ = nullptr; + dnnTensorDescriptor_t Y_tensor_desc_ = nullptr; void* X_ = nullptr; void* dX_ = nullptr; @@ -186,24 +202,36 @@ class NhwcBatchNormAddRelu { std::string name_; private: - void setTensorDescriptor(cudnnTensorDescriptor_t descriptor, - cudnnTensorFormat_t format, - cudnnDataType_t data_type, + void setTensorDescriptor(dnnTensorDescriptor_t descriptor, + dnnTensorFormat_t format, + dnnDataType_t data_type, int n, int c, int h, int w) { - cudnnStatus_t status = CUDNN_STATUS_SUCCESS; + dnnStatus_t status = DNN_STATUS_SUCCESS; +#ifdef USE_ROCM + status = miopenSet4dTensorDescriptor(descriptor, data_type, n, c, h, w); +#else status = cudnnSetTensor4dDescriptor(descriptor, format, data_type, n, c, h, w); +#endif processCudnnStatus(status, "set tensor descriptor"); } - void createTensorDescriptor(cudnnTensorDescriptor_t *descriptor) { - cudnnStatus_t status = CUDNN_STATUS_SUCCESS; + void createTensorDescriptor(dnnTensorDescriptor_t *descriptor) { + dnnStatus_t status = DNN_STATUS_SUCCESS; +#ifdef USE_ROCM + status = miopenCreateTensorDescriptor(descriptor); +#else status = cudnnCreateTensorDescriptor(descriptor); +#endif processCudnnStatus(status, "create tensor_descriptor"); } - void destroyTensorDescriptor(cudnnTensorDescriptor_t descriptor) { - cudnnStatus_t status = CUDNN_STATUS_SUCCESS; + void destroyTensorDescriptor(dnnTensorDescriptor_t descriptor) { + dnnStatus_t status = DNN_STATUS_SUCCESS; +#ifdef USE_ROCM + status = miopenDestroyTensorDescriptor(descriptor); +#else status = cudnnDestroyTensorDescriptor(descriptor); +#endif processCudnnStatus(status, "destroy tensor_descriptor"); } @@ -211,7 +239,7 @@ class NhwcBatchNormAddRelu { float *partial_sums_ = nullptr; int *partial_counts_ = nullptr; int *retired_ctas_ = nullptr; - unsigned int *relu_bitmask_ = nullptr; + bitmask_t *relu_bitmask_ = nullptr; void _setFwdParams(NhwcBatchNormFwdParams *params) const; void _setFwdInferenceParams(NhwcBatchNormFwdInferenceParams *params) const; @@ -221,17 +249,17 @@ class NhwcBatchNormAddRelu { // Kernel params static const int USE_ONLINE_APPROACH = 1; static const int THREADS_PER_CTA = 512; - static const int THREADS_PER_PIXEL = 16; - static const int C_ELEMENTS_PER_CTA = 64; + static const int THREADS_PER_PIXEL = 32; + static const int C_ELEMENTS_PER_CTA = 128; static const int ELEMENTS_PER_LDG = C_ELEMENTS_PER_CTA / THREADS_PER_PIXEL; static const int MAX_SMEM_WITHOUT_OPT_IN = 48 * 1024; typedef uint16_t StorageType; // increasing this to 6 causes spills in fwd kernel! - static const int PIXELS_PER_THREAD_IN_REGISTERS_FWD = 5; - static const int PIXELS_PER_THREAD_IN_REGISTERS_BWD = 3; - static const int PIXELS_PER_THREAD_IN_SMEM_FWD = 10; - static const int PIXELS_PER_THREAD_IN_SMEM_BWD = 5; + static const int PIXELS_PER_THREAD_IN_REGISTERS_FWD = 1; + static const int PIXELS_PER_THREAD_IN_REGISTERS_BWD = 1; + static const int PIXELS_PER_THREAD_IN_SMEM_FWD = 0; + static const int PIXELS_PER_THREAD_IN_SMEM_BWD = 0; static const int PIXELS_PER_THREAD_FWD = PIXELS_PER_THREAD_IN_REGISTERS_FWD + \ PIXELS_PER_THREAD_IN_SMEM_FWD; @@ -262,6 +290,58 @@ class NhwcBatchNormAddRelu { // needless register spills. void _fwdKernelLauncher(cudaStream_t stream, NhwcBatchNormFwdParams params, dim3 grid_dim, int outer_loops, const int occupancy, const bool coop) { +#ifdef USE_ROCM +#define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY, COOP) \ + do { \ + CHECK(SMEM_SIZE_FWD <= MAX_SMEM_WITHOUT_OPT_IN) << \ + "Nhwc batchnormaddrelu kernel smem too big."; \ + auto fwd_func = nhwc_batch_norm_fwd< \ + StorageType, \ + THREADS_PER_CTA, \ + THREADS_PER_PIXEL, \ + PIXELS_PER_THREAD_IN_REGISTERS_FWD, \ + PIXELS_PER_THREAD_IN_SMEM_FWD, \ + ELEMENTS_PER_LDG, \ + USE_ONLINE_APPROACH, \ + OUTER_LOOPS, \ + USE_RELU, \ + USE_ADD_RELU, \ + COMPILED_FOR_OCCUPANCY>; \ + if (COMPILED_FOR_OCCUPANCY > 1) { \ + hipFuncSetAttribute((void *) fwd_func, hipFuncAttributePreferredSharedMemoryCarveout, 100); \ + checkCudaStatus(name_ + " fwd ser coop kernel (cudaFuncSetAttribute carveout)"); \ + } \ + void *params_ptr = static_cast(¶ms); \ + using FWD_FUNC = decltype(nhwc_batch_norm_fwd< \ + StorageType, \ + THREADS_PER_CTA, \ + THREADS_PER_PIXEL, \ + PIXELS_PER_THREAD_IN_REGISTERS_FWD, \ + PIXELS_PER_THREAD_IN_SMEM_FWD, \ + ELEMENTS_PER_LDG, \ + USE_ONLINE_APPROACH, \ + OUTER_LOOPS, \ + USE_RELU, \ + USE_ADD_RELU, \ + COMPILED_FOR_OCCUPANCY>); \ + if (COOP) { \ + hipLaunchCooperativeKernel(fwd_func, \ + grid_dim, \ + THREADS_PER_CTA, \ + ¶ms_ptr, \ + SMEM_SIZE_FWD, \ + stream); \ + } else { \ + hipLaunchKernel((void *) fwd_func, \ + grid_dim, \ + THREADS_PER_CTA, \ + ¶ms_ptr, \ + SMEM_SIZE_FWD, \ + stream); \ + } \ + checkCudaStatus(name_ + " fwd ser coop kernel"); \ + } while (0) +#else #define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY, COOP) \ do { \ CHECK(SMEM_SIZE_FWD <= MAX_SMEM_WITHOUT_OPT_IN) << \ @@ -312,6 +392,7 @@ class NhwcBatchNormAddRelu { } \ checkCudaStatus(name_ + " fwd ser coop kernel"); \ } while (0) +#endif // Don't try for an occupancy > 2 as this will squeeze register use and create spills. if (outer_loops == 1) { @@ -332,7 +413,56 @@ class NhwcBatchNormAddRelu { void _bwdKernelLauncher(cudaStream_t stream, NhwcBatchNormBwdParams params, dim3 grid_dim, int outer_loops, const int occupancy, const bool coop) { +#ifdef USE_ROCM #define LAUNCH_BWD_ADD_RELU_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \ + do { \ + CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << \ + "Nhwc batchnormaddrelu kernel smem too big."; \ + auto bwd_add_relu_func = nhwc_batch_norm_bwd_add_relu< \ + StorageType, \ + THREADS_PER_CTA, \ + THREADS_PER_PIXEL, \ + PIXELS_PER_THREAD_IN_REGISTERS_BWD, \ + PIXELS_PER_THREAD_IN_SMEM_BWD, \ + ELEMENTS_PER_LDG, \ + USE_ONLINE_APPROACH, \ + OUTER_LOOPS, \ + COMPILED_FOR_OCCUPANCY>; \ + if (COMPILED_FOR_OCCUPANCY > 1) { \ + hipFuncSetAttribute((void *) bwd_add_relu_func, \ + hipFuncAttributePreferredSharedMemoryCarveout, 100); \ + checkCudaStatus(name_ + \ + " bwd-add-relu coop serial kernel (cudaFuncSetAttribute carveout)"); \ + } \ + void *params_ptr = static_cast(¶ms); \ + using BWD_ADD_RELU_FUNC = decltype(nhwc_batch_norm_bwd_add_relu< \ + StorageType, \ + THREADS_PER_CTA, \ + THREADS_PER_PIXEL, \ + PIXELS_PER_THREAD_IN_REGISTERS_BWD, \ + PIXELS_PER_THREAD_IN_SMEM_BWD, \ + ELEMENTS_PER_LDG, \ + USE_ONLINE_APPROACH, \ + OUTER_LOOPS, \ + COMPILED_FOR_OCCUPANCY>); \ + if (COOP) { \ + hipLaunchCooperativeKernel(bwd_add_relu_func, \ + grid_dim, \ + THREADS_PER_CTA, \ + ¶ms_ptr, \ + SMEM_SIZE_BWD, \ + stream); \ + } else { \ + hipLaunchKernel((void *) bwd_add_relu_func, \ + grid_dim, \ + THREADS_PER_CTA, \ + ¶ms_ptr, \ + SMEM_SIZE_BWD, \ + stream); \ + } \ + checkCudaStatus(name_ + " bwd-add-relu coop serial kernel"); \ + } while (0) +#else do { \ CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << \ "Nhwc batchnormaddrelu kernel smem too big."; \ @@ -380,6 +510,7 @@ class NhwcBatchNormAddRelu { } \ checkCudaStatus(name_ + " bwd-add-relu coop serial kernel"); \ } while (0) +#endif // Don't try for an occupancy > 2 as this will squeeze register use and create spills. if (outer_loops == 1) { @@ -400,7 +531,7 @@ class NhwcBatchNormAddRelu { // Calculate the expected fwd kernel occupancy, as dictated by shared memory usage. static int smem_driven_fwd_occupancy(int device_id, const int max_cta_per_sm) { using namespace at::cuda::utils; - int fwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG*sizeof(float); + int fwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/at::cuda::warp_size())*ELEMENTS_PER_LDG*sizeof(float); int fwd_smem_bytes = SMEM_SIZE_FWD + fwd_reduction_bytes; int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / fwd_smem_bytes; return std::min(max_cta_per_sm, occupancy); @@ -409,7 +540,7 @@ class NhwcBatchNormAddRelu { // Calculate the expected bwd kernel occupancy, as dictated by shared memory usage. static int smem_driven_bwd_occupancy(int device_id, const int max_cta_per_sm) { using namespace at::cuda::utils; - int bwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG*sizeof(float); + int bwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/at::cuda::warp_size())*ELEMENTS_PER_LDG*sizeof(float); int bwd_smem_bytes = SMEM_SIZE_BWD + bwd_reduction_bytes; int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / bwd_smem_bytes; return std::min(max_cta_per_sm, occupancy); @@ -428,9 +559,13 @@ const std::vector NhwcBatchNormAddRelu::numWorkspaceBytes() const { const size_t num_mean_bytes = c_ * sizeof(float); const size_t num_variance_bytes = num_mean_bytes; +#ifdef USE_ROCM + int elems_per_group = ((m_ + 3) & ~3) * 2; +#else int elems_per_group = ((m_ + 31) & ~31) * 2; +#endif int group_count = div_up(c_, C_ELEMENTS_PER_CTA); - const size_t bitmask_bytes = elems_per_group * group_count * sizeof(unsigned int); + const size_t bitmask_bytes = elems_per_group * group_count * sizeof(bitmask_t); const size_t size_sums = grid_y*grid_x*THREADS_PER_PIXEL*\ ELEMENTS_PER_LDG*2*sizeof(float); @@ -448,7 +583,7 @@ void NhwcBatchNormAddRelu::setWorkspacePointers( minibatch_mean_ = static_cast(workspace[0]); minibatch_variance_ = static_cast(workspace[1]); - relu_bitmask_ = static_cast(workspace[2]); + relu_bitmask_ = static_cast(workspace[2]); retired_ctas_ = static_cast(workspace[3]); partial_sums_ = static_cast(workspace[4]); partial_counts_ = static_cast(workspace[5]); diff --git a/apex/contrib/csrc/groupbn/cuda_utils.h b/apex/contrib/csrc/groupbn/cuda_utils.h index 9f003840c..ec13d03d3 100644 --- a/apex/contrib/csrc/groupbn/cuda_utils.h +++ b/apex/contrib/csrc/groupbn/cuda_utils.h @@ -1,4 +1,8 @@ +#ifdef USE_ROCM +#include +#else #include +#endif #ifndef CUDA_UTILS_H #define CUDA_UTILS_H @@ -8,7 +12,11 @@ namespace cuda { namespace utils { static inline int MaxSharedMemoryPerMultiprocessor(int device_id) { +#ifdef USE_ROCM + return getDeviceProperties(device_id)->maxSharedMemoryPerMultiProcessor; +#else return getDeviceProperties(device_id)->sharedMemPerMultiprocessor; +#endif } diff --git a/apex/contrib/csrc/groupbn/dnn.h b/apex/contrib/csrc/groupbn/dnn.h new file mode 100644 index 000000000..f31757083 --- /dev/null +++ b/apex/contrib/csrc/groupbn/dnn.h @@ -0,0 +1,26 @@ +#ifndef DNN_H +#define DNN_H + +#ifdef USE_ROCM +#include +#define DNN_STATUS_SUCCESS miopenStatusSuccess +#define DNN_DATA_HALF miopenHalf +#define DNN_TENSOR_FORMAT 0 + +using dnnTensorFormat_t = int; +using dnnDataType_t = miopenDataType_t; +using dnnStatus_t = miopenStatus_t; +using dnnTensorDescriptor_t = miopenTensorDescriptor_t; +#else +#include +#define DNN_STATUS_SUCCESS CUDNN_STATUS_SUCCESS +#define DNN_DATA_HALF CUDNN_DATA_HALF +#define DNN_TENSOR_FORMAT CUDNN_TENSOR_NHWC + +using dnnTensorFormat_t = cudnnTensorFormat_t; +using dnnDataType_t = cudnnDataType_t; +using dnnStatus_t = cudnnStatus_t; +using dnnTensorDescriptor_t = cudnnTensorDescriptor_t; +#endif + +#endif // DNN_H diff --git a/apex/contrib/csrc/groupbn/nhwc_batch_norm_kernel.h b/apex/contrib/csrc/groupbn/nhwc_batch_norm_kernel.h index 8430f3099..44ec92688 100644 --- a/apex/contrib/csrc/groupbn/nhwc_batch_norm_kernel.h +++ b/apex/contrib/csrc/groupbn/nhwc_batch_norm_kernel.h @@ -26,8 +26,24 @@ #ifndef MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_KERNEL_H_ #define MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_KERNEL_H_ +#ifdef USE_ROCM +#include +#include +#include +#endif #include #include +#include + +#ifdef USE_ROCM +using bitmask_t = uint64_t; +#define BITMASK_OFFSET 2 +#define ONE_BITMASK 1UL +#else +using bitmask_t = unsigned int; +#define BITMASK_OFFSET 2 +#define ONE_BITMASK 1U +#endif #define DEVICE_FUNCTION static inline __device__ @@ -37,6 +53,37 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// +DEVICE_FUNCTION void syncwarp() { +#ifdef USE_ROCM + __builtin_amdgcn_wave_barrier(); +#else + __syncwarp(); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +DEVICE_FUNCTION T shfl_sync(T var, int src_lane) { +#ifdef USE_ROCM + return __shfl(var, src_lane); +#else + return __shfl_sync(0xFFFFFFFFU, var, src_lane); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +DEVICE_FUNCTION bitmask_t ballot(int predicate) { +#ifdef USE_ROCM + return __ballot(predicate); +#else + return __ballot_sync(0xFFFFFFFFU, predicate); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + template< typename T, int ELEMENTS_PER_LDG > struct PackedStorage { enum { PACKED_ELEMENTS_PER_LDG = ELEMENTS_PER_LDG }; @@ -55,12 +102,20 @@ struct PackedStorage { template< int N > DEVICE_FUNCTION void from_float(int (&dst)[N], const float (&src)[2*N]) { + // Convert from two f32s to two f16s (mantissa LSB rounds to nearest even) + // (From 64-bit to 32-bit) + half *dst_ = (half *) dst; #pragma unroll for (int i = 0; i < N; ++i) { +#ifdef USE_ROCM + dst_[2*i] = __float2half(src[2*i]); + dst_[2*i+1] = __float2half(src[2*i+1]); +#else uint16_t lo, hi; asm volatile("cvt.rn.f16.f32 %0, %1;" : "=h"(lo) : "f"(src[2*i+0])); asm volatile("cvt.rn.f16.f32 %0, %1;" : "=h"(hi) : "f"(src[2*i+1])); asm volatile("mov.b32 %0, {%1, %2};" : "=r"(dst[i]) : "h"(lo), "h"(hi)); +#endif } } @@ -78,12 +133,19 @@ DEVICE_FUNCTION void from_float(float (&dst)[N], const float (&src)[N]) { template< int N > DEVICE_FUNCTION void to_float(float (&dst)[2*N], int (&src)[N]) { + // Convert from two f16s to two f32s (From 32-bit to 64-bit) #pragma unroll for (int i = 0; i < N; ++i) { +#ifdef USE_ROCM + half *src_ = (half *) src; + dst[2*i] = __half2float(src_[2*i]); + dst[2*i+1] = __half2float(src_[2*i+1]); +#else uint16_t lo, hi; asm volatile("mov.b32 {%0, %1}, %2;" : "=h"(lo), "=h"(hi) : "r"(src[i])); asm volatile("cvt.f32.f16 %0, %1;" : "=f"(dst[2*i+0]) : "h"(lo)); asm volatile("cvt.f32.f16 %0, %1;" : "=f"(dst[2*i+1]) : "h"(hi)); +#endif } } @@ -106,9 +168,13 @@ DEVICE_FUNCTION void ldg(int (&dst)[1], const uint16_t *gmem) { //////////////////////////////////////////////////////////////////////////////////////////////////// DEVICE_FUNCTION void ldg_stream(int (&dst)[1], const uint16_t *gmem) { +#ifdef USE_ROCM + dst[0] = __ldg((const int*) gmem); +#else unsigned int tmp; asm volatile ("ld.global.cs.nc.s32 %0, [%1];" : "=r"(tmp) : "l" ((const uint *)gmem)); dst[0] = tmp; +#endif } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -122,11 +188,17 @@ DEVICE_FUNCTION void ldg(int (&dst)[2], const uint16_t *gmem) { //////////////////////////////////////////////////////////////////////////////////////////////////// DEVICE_FUNCTION void ldg_stream(int (&dst)[2], const uint16_t *gmem) { +#ifdef USE_ROCM + int2 tmp = __ldg((const int2*) gmem); + dst[0] = tmp.x; + dst[1] = tmp.y; +#else int2 tmp; asm volatile ("ld.global.cs.nc.v2.s32 {%0,%1}, [%2];" : "=r"(tmp.x), "=r"(tmp.y) : "l"((const int2 *)gmem)); dst[0] = tmp.x; dst[1] = tmp.y; +#endif } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -156,22 +228,42 @@ DEVICE_FUNCTION void stg(uint16_t *gmem, int (&src)[1]) { //////////////////////////////////////////////////////////////////////////////////////////////////// DEVICE_FUNCTION void stg_stream(uint16_t *gmem, int (&src)[1]) { +#ifdef USE_ROCM + reinterpret_cast(gmem)[0] = src[0]; +#else unsigned int tmp = src[0]; asm volatile ("st.global.cs.s32 [%0], %1;" :: "l"((uint *)gmem) , "r"(tmp)); +#endif } //////////////////////////////////////////////////////////////////////////////////////////////////// DEVICE_FUNCTION void stg(uint16_t *gmem, int (&src)[2]) { +#ifdef USE_ROCM + half *gmem_ = (half *) gmem; + half *src_ = (half *) src; + for (int i = 0; i < 4; i++) { + gmem_[i] = src_[i]; + } +#else reinterpret_cast(gmem)[0] = make_int2(src[0], src[1]); +#endif } //////////////////////////////////////////////////////////////////////////////////////////////////// DEVICE_FUNCTION void stg_stream(uint16_t *gmem, int (&src)[2]) { +#ifdef USE_ROCM + half *gmem_ = (half *) gmem; + half *src_ = (half *) src; + for (int i = 0; i < 4; i++) { + gmem_[i] = src_[i]; + } +#else asm volatile ("st.global.cs.v2.s32 [%0], {%1,%2};" :: "l"((uint *)gmem) , "r"(src[0]), "r"( src[1])); +#endif } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -194,28 +286,65 @@ DEVICE_FUNCTION void stg_stream(uint16_t *gmem, float (&src)[N]) { //////////////////////////////////////////////////////////////////////////////////////////////////// +#ifdef USE_ROCM +DEVICE_FUNCTION void stg(uint16_t *gmem, float (&src)[4]) { + half *gmem_ = (half *) gmem; + gmem_[0] = __float2half(src[0]); + gmem_[1] = __float2half(src[1]); + gmem_[2] = __float2half(src[2]); + gmem_[3] = __float2half(src[3]); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +DEVICE_FUNCTION void stg_stream(uint16_t *gmem, float (&src)[4]) { + half *gmem_ = (half *) gmem; + gmem_[0] = __float2half(src[0]); + gmem_[1] = __float2half(src[1]); + gmem_[2] = __float2half(src[2]); + gmem_[3] = __float2half(src[3]); +} +#endif + DEVICE_FUNCTION void read_from_gmem(float (&dst)[2], const float *gmem, int idx) { +#ifdef USE_ROCM + dst[0] = gmem[2*idx]; + dst[1] = gmem[2*idx+1]; +#else float2 tmp = __ldg(reinterpret_cast(&gmem[2*idx])); dst[0] = tmp.x; dst[1] = tmp.y; +#endif } //////////////////////////////////////////////////////////////////////////////////////////////////// DEVICE_FUNCTION void read_from_gmem(float (&dst)[4], const float *gmem, int idx) { +#ifdef USE_ROCM + dst[0] = gmem[4*idx]; + dst[1] = gmem[4*idx+1]; + dst[2] = gmem[4*idx+2]; + dst[3] = gmem[4*idx+3]; +#else float4 tmp = __ldg(reinterpret_cast(&gmem[4*idx])); dst[0] = tmp.x; dst[1] = tmp.y; dst[2] = tmp.z; dst[3] = tmp.w; +#endif } //////////////////////////////////////////////////////////////////////////////////////////////////// DEVICE_FUNCTION void read_from_smem(float (&x)[2], const float *smem, int idx) { +#ifdef USE_ROCM + x[0] = smem[2*idx]; + x[1] = smem[2*idx+1]; +#else float2 tmp = *(const float2*) &smem[2*idx]; x[0] = tmp.x; x[1] = tmp.y; +#endif } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -227,43 +356,79 @@ DEVICE_FUNCTION void read_from_smem(int (&x)[1], const int *smem, int idx) { //////////////////////////////////////////////////////////////////////////////////////////////////// DEVICE_FUNCTION void read_from_smem(float (&x)[4], const float *smem, int idx) { +#ifdef USE_ROCM + x[0] = smem[4*idx]; + x[1] = smem[4*idx+1]; + x[2] = smem[4*idx+2]; + x[3] = smem[4*idx+3]; +#else float4 tmp = *(const float4*) &smem[4*idx]; x[0] = tmp.x; x[1] = tmp.y; x[2] = tmp.z; x[3] = tmp.w; +#endif } //////////////////////////////////////////////////////////////////////////////////////////////////// DEVICE_FUNCTION void read_from_smem(int (&x)[2], const int *smem, int idx) { +#ifdef USE_ROCM + x[0] = smem[2*idx]; + x[1] = smem[2*idx+1]; +#else int2 tmp = *(const int2*) &smem[2*idx]; x[0] = tmp.x; x[1] = tmp.y; +#endif } //////////////////////////////////////////////////////////////////////////////////////////////////// DEVICE_FUNCTION void write_to_gmem(float *gmem, int idx, const float (&src)[2]) { +#ifdef USE_ROCM + gmem[2*idx] = src[0]; + gmem[2*idx+1] = src[1]; +#else reinterpret_cast(&gmem[2*idx])[0] = make_float2(src[0], src[1]); +#endif } //////////////////////////////////////////////////////////////////////////////////////////////////// DEVICE_FUNCTION void write_to_gmem(float *gmem, int idx, const float (&src)[4]) { +#ifdef USE_ROCM + gmem[4*idx] = src[0]; + gmem[4*idx+1] = src[1]; + gmem[4*idx+2] = src[2]; + gmem[4*idx+3] = src[3]; +#else reinterpret_cast(&gmem[4*idx])[0] = make_float4(src[0], src[1], src[2], src[3]); +#endif } //////////////////////////////////////////////////////////////////////////////////////////////////// DEVICE_FUNCTION void scaled_write_to_gmem(float *gmem, int idx, const float (&src)[4], const float coeff) { +#ifdef USE_ROCM + gmem[4*idx] = src[0]*coeff; + gmem[4*idx+1] = src[1]*coeff; + gmem[4*idx+2] = src[2]*coeff; + gmem[4*idx+3] = src[3]*coeff; +#else reinterpret_cast(&gmem[4*idx])[0] = make_float4(src[0]*coeff, src[1]*coeff, src[2]*coeff, src[3]*coeff); +#endif } //////////////////////////////////////////////////////////////////////////////////////////////////// DEVICE_FUNCTION void write_to_smem(float *smem, int idx, const float (&x)[2]) { +#ifdef USE_ROCM + smem[2*idx] = x[0]; + smem[2*idx+1] = x[1]; +#else reinterpret_cast(&smem[2*idx])[0] = make_float2(x[0], x[1]); +#endif } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -275,13 +440,25 @@ DEVICE_FUNCTION void write_to_smem(int *smem, int idx, const int (&x)[1]) { //////////////////////////////////////////////////////////////////////////////////////////////////// DEVICE_FUNCTION void write_to_smem(float *smem, int idx, const float (&x)[4]) { +#ifdef USE_ROCM + smem[4*idx] = x[0]; + smem[4*idx+1] = x[1]; + smem[4*idx+2] = x[2]; + smem[4*idx+3] = x[3]; +#else reinterpret_cast(&smem[4*idx])[0] = make_float4(x[0], x[1], x[2], x[3]); +#endif } //////////////////////////////////////////////////////////////////////////////////////////////////// DEVICE_FUNCTION void write_to_smem(int *smem, int idx, const int (&x)[2]) { +#ifdef USE_ROCM + smem[2*idx] = x[0]; + smem[2*idx+1] = x[1]; +#else reinterpret_cast(&smem[2*idx])[0] = make_int2(x[0], x[1]); +#endif } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -369,10 +546,8 @@ DEVICE_FUNCTION void parallel_sums_16x2(float *smem, float (&x)[4], int nhw, void* params_my_data, void** params_pair_datas, int off, const int magic, const int sync_iters) { - // The size of a warp. - const int THREADS_PER_WARP = 32; // The number of warps in a CTA. - const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP; + const int WARPS_PER_CTA = THREADS_PER_CTA / C10_WARP_SIZE; // The number of threads per pixel. const int THREADS_PER_PIXEL = 16; // The number of elements per ldg. @@ -383,15 +558,24 @@ DEVICE_FUNCTION void parallel_sums_16x2(float *smem, float (&x)[4], int nhw, const int MAX_BLOCK_Y = 256; const int MAX_OFFSET = REDUCE_OPS*MAX_BLOCK_Y; // The warp decomposition. - const int warp_id = threadIdx.x / THREADS_PER_WARP; - const int lane_id = threadIdx.x % THREADS_PER_WARP; + const int warp_id = threadIdx.x / C10_WARP_SIZE; + const int lane_id = threadIdx.x % C10_WARP_SIZE; // total size of data per sync iter const int data_total = MAX_OFFSET*THREADS_PER_PIXEL*ELEMENTS_PER_LDG*2; +#ifdef USE_ROCM + for (int offset = THREADS_PER_PIXEL; offset <= C10_WARP_SIZE >> 1; offset <<= 1) { + for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { + x[i] += shfl_sync(x[i], offset + lane_id); + } + } +#else #pragma unroll for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL+lane_id); + x[i] += shfl_sync(x[i], THREADS_PER_PIXEL+lane_id); } +#endif + // The warp leaders, write to SMEM. if (lane_id < THREADS_PER_PIXEL) { @@ -408,25 +592,33 @@ DEVICE_FUNCTION void parallel_sums_16x2(float *smem, float (&x)[4], int nhw, #pragma unroll for (int offset = 1; - offset < WARPS_PER_CTA/(THREADS_PER_WARP / THREADS_PER_PIXEL); ++offset) { + offset < WARPS_PER_CTA/(C10_WARP_SIZE / THREADS_PER_PIXEL); ++offset) { float y[ELEMENTS_PER_LDG]; // Read the mean and variance from the other pixel. - read_from_smem(y, smem, threadIdx.x + offset*THREADS_PER_WARP); + read_from_smem(y, smem, threadIdx.x + offset*C10_WARP_SIZE); // Compute the updated sum. add(x, y); } +#ifdef USE_ROCM + for (int offset = C10_WARP_SIZE >> 1; offset >= THREADS_PER_PIXEL; offset >>= 1) { for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { + x[i] += shfl_sync(x[i], offset + lane_id); + } + } +#else for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL+lane_id); + x[i] += shfl_sync(x[i], THREADS_PER_PIXEL+lane_id); } +#endif // Make sure the data was read from SMEM. - __syncwarp(); + syncwarp(); // Store the final values. if (threadIdx.x < THREADS_PER_PIXEL) { // probably could do it earlier, before sync +#ifndef USE_ROCM // bn_group > 1 is not enabled on HIP for (int sync_iter=0; sync_iter < sync_iters; ++sync_iter) { //float* params_pair_data = (reinterpret_cast(params_pair_datas))[sync_iter]; void* params_pair_data = params_pair_datas[sync_iter]; @@ -469,6 +661,7 @@ DEVICE_FUNCTION void parallel_sums_16x2(float *smem, float (&x)[4], int nhw, add(x, other); } +#endif // finally, after syncing up and accounting for partial sums from // other GPUs as required, write the result @@ -482,22 +675,20 @@ DEVICE_FUNCTION void parallel_sums_16x2(float *smem, float (&x)[4], int nhw, template< int THREADS_PER_CTA > DEVICE_FUNCTION void parallel_sums_8x4(float *smem, float (&x)[4], int nhw) { - // The size of a warp. - const int THREADS_PER_WARP = 32; // The number of warps in a CTA. - const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP; + const int WARPS_PER_CTA = THREADS_PER_CTA / C10_WARP_SIZE; // The number of threads per pixel. const int THREADS_PER_PIXEL = 8; // The number of elements per ldg. const int ELEMENTS_PER_LDG = 4; // The warp decomposition. - const int warp_id = threadIdx.x / THREADS_PER_WARP; - const int lane_id = threadIdx.x % THREADS_PER_WARP; + const int warp_id = threadIdx.x / C10_WARP_SIZE; + const int lane_id = threadIdx.x % C10_WARP_SIZE; #pragma unroll for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL+lane_id); - x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL*2+lane_id); + x[i] += shfl_sync(x[i], THREADS_PER_PIXEL+lane_id); + x[i] += shfl_sync(x[i], THREADS_PER_PIXEL*2+lane_id); } // The warp leaders, write to SMEM. @@ -515,21 +706,21 @@ DEVICE_FUNCTION void parallel_sums_8x4(float *smem, float (&x)[4], int nhw) { #pragma unroll for (int offset = 1; - offset < WARPS_PER_CTA/(THREADS_PER_WARP / THREADS_PER_PIXEL); ++offset) { + offset < WARPS_PER_CTA/(C10_WARP_SIZE / THREADS_PER_PIXEL); ++offset) { float y[ELEMENTS_PER_LDG]; // Read the mean and variance from the other pixel. - read_from_smem(y, smem, threadIdx.x + offset*THREADS_PER_WARP); + read_from_smem(y, smem, threadIdx.x + offset*C10_WARP_SIZE); // Compute the updated sum. add(x, y); } for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL+lane_id); - x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL*2+lane_id); + x[i] += shfl_sync(x[i], THREADS_PER_PIXEL+lane_id); + x[i] += shfl_sync(x[i], THREADS_PER_PIXEL*2+lane_id); } // Make sure the data was read from SMEM. - __syncwarp(); + syncwarp(); // Store the final values. if (threadIdx.x < THREADS_PER_PIXEL) { @@ -542,80 +733,67 @@ DEVICE_FUNCTION void parallel_sums_8x4(float *smem, float (&x)[4], int nhw) { template< int THREADS_PER_CTA, int THREADS_PER_PIXEL, int ELEMENTS_PER_LDG > DEVICE_FUNCTION void parallel_sums(float *smem, float (&x)[ELEMENTS_PER_LDG], int nhw) { - // The size of a warp. - const int THREADS_PER_WARP = 32; - // The number of warps in a CTA. - const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP; - // The number of pixels computed by a single warp. - const int PIXELS_PER_WARP = THREADS_PER_WARP / THREADS_PER_PIXEL; - - // The position in the warp. - const int nhw_in_warp = nhw % PIXELS_PER_WARP; - // The C in the warp. - const int c_in_warp = threadIdx.x % THREADS_PER_PIXEL; - - // Store the values to shared memory. - write_to_smem(smem, threadIdx.x, x); - - // Compute the parallel sums. - for (int offset = PIXELS_PER_WARP/2; offset > 0; offset /= 2) { - // NOP. - __syncwarp(); - - // Read the running sum from the other thread. - float y[ELEMENTS_PER_LDG]; - if (nhw_in_warp < offset) { - read_from_smem(y, smem, threadIdx.x + offset*THREADS_PER_PIXEL); - } - - // Compute the updated sum. - add(x, y); - - // NOP. - __syncwarp(); + const int WARPS_PER_CTA = THREADS_PER_CTA / C10_WARP_SIZE; + // The warp decomposition. + const int warp_id = threadIdx.x / C10_WARP_SIZE; + const int lane_id = threadIdx.x % C10_WARP_SIZE; + // total size of data per sync iter - // Update the sum in SMEM. - if (offset > 1 && nhw_in_warp < offset) { - write_to_smem(smem, threadIdx.x, x); +#ifdef USE_ROCM + for (int offset = THREADS_PER_PIXEL; offset <= C10_WARP_SIZE >> 1; offset <<= 1) { + for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { + x[i] += shfl_sync(x[i], offset + lane_id); } } +#else + #pragma unroll + for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { + x[i] += shfl_sync(x[i], THREADS_PER_PIXEL+lane_id); + } +#endif - // The warps are done. Do the final reduction at the CTA level. - __syncthreads(); // The warp leaders, write to SMEM. - const int idx = (threadIdx.x/THREADS_PER_WARP)*THREADS_PER_PIXEL + c_in_warp; - if (nhw_in_warp == 0) { - write_to_smem(smem, idx, x); + if (lane_id < THREADS_PER_PIXEL) { + write_to_smem(smem, warp_id*THREADS_PER_PIXEL + lane_id, x); } // The data is in SMEM. Do the final reduction. __syncthreads(); - // Read the 1st element to prepare the work. - if (nhw < WARPS_PER_CTA/2) { + // The 1st warp does all the work. + // We do the final reduction each half-warp sequentially reduces the final values. + if (warp_id == 0) { read_from_smem(x, smem, threadIdx.x); - } - // We have the running mean and running m2. Let's build the mean/var of the CTA. - for (int offset = WARPS_PER_CTA/2; offset > 0; offset /= 2) { - // NOP. - __syncwarp(); - - // Read the mean and variance from the other pixel. - float y[ELEMENTS_PER_LDG]; - if (nhw < offset) { - read_from_smem(y, smem, threadIdx.x + offset*THREADS_PER_PIXEL); + #pragma unroll + for (int offset = 1; + offset < WARPS_PER_CTA/(C10_WARP_SIZE / THREADS_PER_PIXEL); ++offset) { + float y[ELEMENTS_PER_LDG]; + // Read the mean and variance from the other pixel. + read_from_smem(y, smem, threadIdx.x + offset*C10_WARP_SIZE); + // Compute the updated sum. + add(x, y); } - // Compute the updated sum. - add(x, y); +#ifdef USE_ROCM + for (int offset = C10_WARP_SIZE >> 1; offset >= THREADS_PER_PIXEL; offset >>= 1) { + for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { + x[i] += shfl_sync(x[i], offset + lane_id); + } + } +#else + for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { + x[i] += shfl_sync(x[i], THREADS_PER_PIXEL+lane_id); + } +#endif - // NOP. - __syncwarp(); + // Make sure the data was read from SMEM. + syncwarp(); - // Store the mean/var for the different pixels. - if (nhw < offset) { + // Store the final values. + if (threadIdx.x < THREADS_PER_PIXEL) { + // probably could do it earlier, before sync write_to_smem(smem, threadIdx.x, x); } } @@ -632,7 +810,7 @@ struct ParallelSums { }; //////////////////////////////////////////////////////////////////////////////////////////////////// - +/* template<> struct ParallelSums<16, 4> { template< int THREADS_PER_CTA > @@ -653,6 +831,7 @@ struct ParallelSums<8, 4> { parallel_sums_8x4(smem, x, nhw); } }; +*/ //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -684,8 +863,12 @@ DEVICE_FUNCTION void inter_block_sync(int* gmem_retired_ctas, int expected_count int retired_ctas = -1; do { __threadfence(); +#ifdef USE_ROCM + retired_ctas = __ldg((const int*) gmem_retired_ctas); +#else asm volatile ("ld.global.cg.b32 %0, [%1];" : "=r"(retired_ctas) : "l"(gmem_retired_ctas)); +#endif } while (retired_ctas != 0); } __syncthreads(); @@ -806,7 +989,7 @@ struct NhwcBatchNormFwdParams { // saved mean/var (refer BN API from cudnn doc) float *gmem_saved_mean, *gmem_saved_var; // ReLU bitmask - unsigned int *gmem_relu_bitmask; + bitmask_t *gmem_relu_bitmask; // The dimensions. int nhw, c; // factor to scale sum of squared errors to get saved variance. Must be 1/nhw. @@ -861,7 +1044,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG; // Shared memory to do CTA-wide parallel sums. - __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG]; + __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/C10_WARP_SIZE)*ELEMENTS_PER_LDG]; // Compute the NHW coordinate of the thread in the CTA. const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL; @@ -878,6 +1061,10 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) // Shared memory buffer to store the extra pixels. extern __shared__ PackedStorageType smem_storage_packed[]; +#ifdef USE_ROCM + const half zero_h = __float2half(0.0F); +#endif + for (int c_blk_index = blockIdx.y; c_blk_index < params.c_blks; c_blk_index += gridDim.y) { // The position in the NHW dimension where the CTA starts. int cta_nhw_regs = blockIdx.x * PIXELS_PER_CTA_IN_REGISTERS; @@ -960,11 +1147,15 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) zero_array(x_storage[i]); is_valid[i] = 0.f; if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) { +#ifndef USE_ROCM if (loop_i == OUTER_LOOPS - 1) { ldg_stream(x_storage[i], &gmem_src[idx*params.c]); } else { +#endif ldg(x_storage[i], &gmem_src[idx*params.c]); +#ifndef USE_ROCM } +#endif is_valid[i] = 1.f; } } @@ -1089,7 +1280,11 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) } // Run the parallel sum accross the CTA to get the local sum. +#ifdef USE_ROCM + ParallelSums::template dispatch( +#else ParallelSums::dispatch( +#endif smem, m1, thread_in_cta_nhw); __syncthreads(); @@ -1106,7 +1301,11 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) } // Run the parallel sum accross the CTA to get the local adjusted variance. +#ifdef USE_ROCM + ParallelSums::template dispatch( +#else ParallelSums::dispatch( +#endif smem, m2, thread_in_cta_nhw); // The workspace in global memory is distributed across the different CTA. @@ -1152,14 +1351,22 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) add(m1, tmp); } +#ifndef USE_ROCM if (params.sync_iters>0) { ParallelSums::dispatchX( smem, m1, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+3, params.magic, params.sync_iters); } else { +#endif +#ifdef USE_ROCM + ParallelSums::template dispatch( +#else ParallelSums::dispatch( +#endif smem, m1, thread_in_cta_nhw); +#ifndef USE_ROCM } +#endif __syncthreads(); // The values in shared memory correspond to the CTA-wide sums. @@ -1209,14 +1416,22 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) } } +#ifndef USE_ROCM if (params.sync_iters>0) { ParallelSums::dispatchX( smem, m2, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+2, params.magic, params.sync_iters); } else { +#endif +#ifdef USE_ROCM + ParallelSums::template dispatch( +#else ParallelSums::dispatch( +#endif smem, m2, thread_in_cta_nhw); +#ifndef USE_ROCM } +#endif __syncthreads(); read_from_smem(m2, smem, thread_in_cta_c); @@ -1263,8 +1478,12 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) // The base pointer to write to. uint16_t *const gmem_dst = ¶ms.gmem_dst[thread_c]; - unsigned int *const gmem_relu_bitmask = params.gmem_relu_bitmask + + bitmask_t *const gmem_relu_bitmask = params.gmem_relu_bitmask + +#ifdef USE_ROCM + ((params.nhw + 3) & ~3) * 2 * c_blk_index; +#else ((params.nhw + 31) & ~31) * 2 * c_blk_index; +#endif // Store the elements in registers. #pragma unroll 1 @@ -1289,23 +1508,31 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) float x1_math[ELEMENTS_PER_LDG]; ldg_stream(x1_math, &gmem_src1[(is_valid ? idx : 0)*params.c]); add(x_math, x1_math); - unsigned int relu_mask; + bitmask_t relu_mask; +#ifdef USE_ROCM + int lane_id = threadIdx.x & 63; +#else int lane_id = threadIdx.x & 31; +#endif #pragma unroll - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - bool rectified = x_math[i] < 0.0F; - unsigned int local_relu_mask = __ballot_sync(0xFFFFFFFFU, rectified); - if (lane_id == i) { + for (int j = 0; j < ELEMENTS_PER_LDG; ++j) { +#ifdef USE_ROCM + bool rectified = __hle(__float2half(x_math[j]), zero_h); +#else + bool rectified = x_math[j] < 0; +#endif + bitmask_t local_relu_mask = ballot(rectified); + if (lane_id == j) { // Thread 0 remembers the relu_mask from the first time through this // loop, Thread 1 the next, Thread 2 the next, and Thread 3 the last. relu_mask = local_relu_mask; } if (rectified) { - x_math[i] = 0.0F; + x_math[j] = 0.0F; } } if (is_valid_nhw && (lane_id < ELEMENTS_PER_LDG)) { - gmem_relu_bitmask[idx * 2 + lane_id] = relu_mask; + gmem_relu_bitmask[idx * BITMASK_OFFSET + lane_id] = relu_mask; } } else if (USE_RELU) { relu_activation(x_math); @@ -1352,21 +1579,29 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) float x1_math[ELEMENTS_PER_LDG]; ldg_stream(x1_math, &gmem_src1[(is_valid ? idx : 0)*params.c]); add(x_math, x1_math); - unsigned int relu_mask; + bitmask_t relu_mask; +#ifdef USE_ROCM + int lane_id = threadIdx.x & 63; +#else int lane_id = threadIdx.x & 31; +#endif #pragma unroll - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - bool rectified = x_math[i] < 0.0F; - unsigned int local_relu_mask = __ballot_sync(0xFFFFFFFFU, rectified); - if (lane_id == i) { + for (int j = 0; j < ELEMENTS_PER_LDG; ++j) { +#ifdef USE_ROCM + bool rectified = __hle(__float2half(x_math[j]), zero_h); +#else + bool rectified = x_math[j] < 0; +#endif + bitmask_t local_relu_mask = ballot(rectified); + if (lane_id == j) { relu_mask = local_relu_mask; } if (rectified) { - x_math[i] = 0.0F; + x_math[j] = 0.0F; } } if (is_valid_nhw && (lane_id < ELEMENTS_PER_LDG)) { - gmem_relu_bitmask[idx * 2 + lane_id] = relu_mask; + gmem_relu_bitmask[idx * BITMASK_OFFSET + lane_id] = relu_mask; } } else if (USE_RELU) { relu_activation(x_math); @@ -1395,7 +1630,7 @@ struct NhwcBatchNormBwdParams { // The mean/inv-var saved from fwd pass float *gmem_saved_mean, *gmem_saved_var; // ReLU bitmask - unsigned int *gmem_relu_bitmask; + bitmask_t *gmem_relu_bitmask; // The dimensions. int nhw, c; // factor to scale sum of squared errors to get saved variance. Must be 1/nhw. @@ -1536,7 +1771,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG; // Shared memory to do CTA-wide parallel sums. - __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG]; + __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/C10_WARP_SIZE)*ELEMENTS_PER_LDG]; // The adapter for the storage. typedef PackedStorage PackedStorage_; @@ -1691,7 +1926,11 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) } // dscale parallel sum +#ifdef USE_ROCM + ParallelSums::template dispatch( +#else ParallelSums::dispatch( +#endif smem, dscale, thread_in_cta_nhw); __syncthreads(); // The values in shared memory correspond to the CTA-wide sums. @@ -1699,7 +1938,11 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) __syncthreads(); // dbias parallel sum +#ifdef USE_ROCM + ParallelSums::template dispatch( +#else ParallelSums::dispatch( +#endif smem, dbias, thread_in_cta_nhw); __syncthreads(); // The values in shared memory correspond to the CTA-wide sums. @@ -1740,13 +1983,21 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) } // dscale parallel sum +#ifndef USE_ROCM if (params.sync_iters>0) { ParallelSums::dispatchX( smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+1, params.magic, params.sync_iters); } else { +#endif +#ifdef USE_ROCM + ParallelSums::template dispatch( +#else ParallelSums::dispatch( +#endif smem, dscale, thread_in_cta_nhw); +#ifndef USE_ROCM } +#endif __syncthreads(); // The values in shared memory correspond to the CTA-wide sums. @@ -1754,13 +2005,21 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) __syncthreads(); // dbias parallel sum +#ifndef USE_ROCM if (params.sync_iters>0) { ParallelSums::dispatchX( smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+0, params.magic, params.sync_iters); } else { +#endif +#ifdef USE_ROCM + ParallelSums::template dispatch( +#else ParallelSums::dispatch( +#endif smem, dbias, thread_in_cta_nhw); +#ifndef USE_ROCM } +#endif __syncthreads(); // The values in shared memory correspond to the CTA-wide sums. @@ -1900,7 +2159,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG; // Shared memory to do CTA-wide parallel sums. - __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG]; + __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/C10_WARP_SIZE)*ELEMENTS_PER_LDG]; // The adapter for the storage. typedef PackedStorage PackedStorage_; @@ -2081,7 +2340,11 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) } // dscale parallel sum +#ifdef USE_ROCM + ParallelSums::template dispatch( +#else ParallelSums::dispatch( +#endif smem, dscale, thread_in_cta_nhw); __syncthreads(); // The values in shared memory correspond to the CTA-wide sums. @@ -2089,7 +2352,11 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) __syncthreads(); // dbias parallel sum +#ifdef USE_ROCM + ParallelSums::template dispatch( +#else ParallelSums::dispatch( +#endif smem, dbias, thread_in_cta_nhw); __syncthreads(); // The values in shared memory correspond to the CTA-wide sums. @@ -2130,13 +2397,21 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) } // dscale parallel sum +#ifndef USE_ROCM if (params.sync_iters>0) { ParallelSums::dispatchX( smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+1, params.magic, params.sync_iters); } else { +#endif +#ifdef USE_ROCM + ParallelSums::template dispatch( +#else ParallelSums::dispatch( +#endif smem, dscale, thread_in_cta_nhw); +#ifndef USE_ROCM } +#endif __syncthreads(); // The values in shared memory correspond to the CTA-wide sums. @@ -2144,13 +2419,21 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) __syncthreads(); // dbias parallel sum +#ifndef USE_ROCM if (params.sync_iters>0) { ParallelSums::dispatchX( smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+0, params.magic, params.sync_iters); } else { +#endif +#ifdef USE_ROCM + ParallelSums::template dispatch( +#else ParallelSums::dispatch( +#endif smem, dbias, thread_in_cta_nhw); +#ifndef USE_ROCM } +#endif __syncthreads(); // The values in shared memory correspond to the CTA-wide sums. @@ -2288,7 +2571,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG; // Shared memory to do CTA-wide parallel sums. - __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG]; + __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/C10_WARP_SIZE)*ELEMENTS_PER_LDG]; // The adapter for the storage. typedef PackedStorage PackedStorage_; @@ -2353,8 +2636,12 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) cta_nhw_smem -= offset; } - const unsigned int *const gmem_relu_bitmask = params.gmem_relu_bitmask + + const bitmask_t *const gmem_relu_bitmask = params.gmem_relu_bitmask + +#ifdef USE_ROCM + ((params.nhw + 3) & ~3) * 2 * c_blk_index; +#else ((params.nhw + 31) & ~31) * 2 * c_blk_index; +#endif #pragma unroll 1 for (int loop_i = 0; loop_i < OUTER_LOOPS; ++loop_i) { @@ -2363,11 +2650,15 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) // Update the number of elements loaded by this CTA. TODO: Skip if <= 0!!! cta_count += max(0, min(PIXELS_PER_CTA_IN_REGISTERS, params.nhw-nhw_regs)); +#ifdef USE_ROCM + int lane_id = threadIdx.x & 63; +#else int lane_id = threadIdx.x & 31; +#endif // Read the elements from memory. float is_valid[PIXELS_PER_THREAD_IN_REGISTERS]; - unsigned int relu_mask[PIXELS_PER_THREAD_IN_REGISTERS]; + bitmask_t relu_mask[PIXELS_PER_THREAD_IN_REGISTERS]; #pragma unroll for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { const int idx = nhw_regs + thread_in_cta_nhw + i*PIXELS_PER_LDG; @@ -2389,7 +2680,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) } if (lane_id < ELEMENTS_PER_LDG) { - relu_mask[i] = gmem_relu_bitmask[idx * 2 + lane_id]; + relu_mask[i] = gmem_relu_bitmask[idx * BITMASK_OFFSET + lane_id]; } } } @@ -2403,8 +2694,8 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) bool rectified[ELEMENTS_PER_LDG]; #pragma unroll for (int j = 0; j < ELEMENTS_PER_LDG; ++j) { - rectified[j] = ((__shfl_sync(0xFFFFFFFFU, relu_mask[i], j) & - (1U << lane_id)) != 0); + rectified[j] = ((shfl_sync(relu_mask[i], j) & + (ONE_BITMASK << lane_id)) != 0); } to_float(x_math, x_storage[i]); to_float(dy_math, dy_storage[i]); @@ -2444,8 +2735,12 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) const bool is_pixel_valid = is_pixel_valid_nhw && is_valid_c; PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG], dy_storage_local[PACKED_ELEMENTS_PER_LDG]; - unsigned int relu_mask; + bitmask_t relu_mask; +#ifdef USE_ROCM + int lane_id = threadIdx.x & 63; +#else int lane_id = threadIdx.x & 31; +#endif zero_array(x_storage_local); zero_array(dy_storage_local); if (is_pixel_valid_nhw) { @@ -2454,14 +2749,14 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) ldg_stream(dy_storage_local, &gmem_dy[idx*params.c]); } if (lane_id < ELEMENTS_PER_LDG) { - relu_mask = gmem_relu_bitmask[idx * 2 + lane_id]; + relu_mask = gmem_relu_bitmask[idx * BITMASK_OFFSET + lane_id]; } } bool rectified[ELEMENTS_PER_LDG]; #pragma unroll for (int j = 0; j < ELEMENTS_PER_LDG; ++j) { - rectified[j] = ((__shfl_sync(0xFFFFFFFFU, relu_mask, j) & - (1U << lane_id)) != 0); + rectified[j] = ((shfl_sync(relu_mask, j) & + (ONE_BITMASK << lane_id)) != 0); } // The offset to store in SMEM. @@ -2499,7 +2794,11 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) } // dscale parallel sum +#ifdef USE_ROCM + ParallelSums::template dispatch( +#else ParallelSums::dispatch( +#endif smem, dscale, thread_in_cta_nhw); __syncthreads(); // The values in shared memory correspond to the CTA-wide sums. @@ -2507,7 +2806,11 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) __syncthreads(); // dbias parallel sum +#ifdef USE_ROCM + ParallelSums::template dispatch( +#else ParallelSums::dispatch( +#endif smem, dbias, thread_in_cta_nhw); __syncthreads(); // The values in shared memory correspond to the CTA-wide sums. @@ -2548,13 +2851,21 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) } // dscale parallel sum +#ifndef USE_ROCM if (params.sync_iters>0) { ParallelSums::dispatchX( smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+1, params.magic, params.sync_iters); } else { +#endif +#ifdef USE_ROCM + ParallelSums::template dispatch( +#else ParallelSums::dispatch( +#endif smem, dscale, thread_in_cta_nhw); +#ifndef USE_ROCM } +#endif __syncthreads(); // The values in shared memory correspond to the CTA-wide sums. @@ -2562,13 +2873,21 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) __syncthreads(); // dbias parallel sum +#ifndef USE_ROCM if (params.sync_iters>0) { ParallelSums::dispatchX( smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+0, params.magic, params.sync_iters); } else { +#endif +#ifdef USE_ROCM + ParallelSums::template dispatch( +#else ParallelSums::dispatch( +#endif smem, dbias, thread_in_cta_nhw); +#ifndef USE_ROCM } +#endif __syncthreads(); // The values in shared memory correspond to the CTA-wide sums. diff --git a/apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda.cpp b/apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda.cpp new file mode 100644 index 000000000..b47c9daa5 --- /dev/null +++ b/apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda.cpp @@ -0,0 +1,145 @@ +#include + +#include +#include + +void index_mul_2d_float_foward_cuda(at::Tensor &out, + const at::Tensor &in1, + const at::Tensor &in2, + const at::Tensor &idx1); + +void index_mul_2d_float_backward_cuda(at::Tensor &grad_in1, + at::Tensor &grad_in2, + const at::Tensor &grad_out, + const at::Tensor &in1, + const at::Tensor &in2, + const at::Tensor &idx1); + +void index_mul_2d_float_backward_backward_cuda(at::Tensor &grad_grad_out, + at::Tensor &grad_in1, + at::Tensor &grad_in2, + const at::Tensor &grad_out, + const at::Tensor &grad_grad_in1, + const at::Tensor &grad_grad_in2, + const at::Tensor &in1, + const at::Tensor &in2, + const at::Tensor &idx1); + +void index_mul_2d_half_foward_cuda(at::Tensor &out, + const at::Tensor &in1, + const at::Tensor &in2, + const at::Tensor &idx1); + +void index_mul_2d_half_backward_cuda(at::Tensor &grad_in1, + at::Tensor &grad_in2, + const at::Tensor &grad_out, + const at::Tensor &in1, + const at::Tensor &in2, + const at::Tensor &idx1); + +void index_mul_2d_half_backward_backward_cuda(at::Tensor &grad_grad_out, + at::Tensor &grad_in1, + at::Tensor &grad_in2, + const at::Tensor &grad_out, + const at::Tensor &grad_grad_in1, + const at::Tensor &grad_grad_in2, + const at::Tensor &in1, + const at::Tensor &in2, + const at::Tensor &idx1); + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) \ + TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) + +void index_mul_2d_float_forward( + at::Tensor &out, + const at::Tensor &in1, + const at::Tensor &in2, + const at::Tensor &idx1) +{ + return index_mul_2d_float_foward_cuda(out, in1, in2, idx1); +} + +void index_mul_2d_float_backward( + at::Tensor &grad_in1, + at::Tensor &grad_in2, + const at::Tensor &grad_out, + const at::Tensor &in1, + const at::Tensor &in2, + const at::Tensor &idx1) +{ + return index_mul_2d_float_backward_cuda(grad_in1, grad_in2, grad_out, in1, in2, idx1); +} + +void index_mul_2d_float_backwrad_backward( + at::Tensor &grad_grad_out, + at::Tensor &grad_in1, + at::Tensor &grad_in2, + const at::Tensor &grad_out, + const at::Tensor &grad_grad_in1, + const at::Tensor &grad_grad_in2, + const at::Tensor &in1, + const at::Tensor &in2, + const at::Tensor &idx1) +{ + return index_mul_2d_float_backward_backward_cuda(grad_grad_out, grad_in1, grad_in2, grad_out, grad_grad_in1, grad_grad_in2, in1, in2, idx1); +} + +void index_mul_2d_half_forward( + at::Tensor &out, + const at::Tensor &in1, + const at::Tensor &in2, + const at::Tensor &idx1) +{ + return index_mul_2d_half_foward_cuda(out, in1, in2, idx1); +} + +void index_mul_2d_half_backward( + at::Tensor &grad_in1, + at::Tensor &grad_in2, + const at::Tensor &grad_out, + const at::Tensor &in1, + const at::Tensor &in2, + const at::Tensor &idx1) +{ + return index_mul_2d_half_backward_cuda(grad_in1, grad_in2, grad_out, in1, in2, idx1); +} + +void index_mul_2d_half_backwrad_backward( + at::Tensor &grad_grad_out, + at::Tensor &grad_in1, + at::Tensor &grad_in2, + const at::Tensor &grad_out, + const at::Tensor &grad_grad_in1, + const at::Tensor &grad_grad_in2, + const at::Tensor &in1, + const at::Tensor &in2, + const at::Tensor &idx1) +{ + return index_mul_2d_half_backward_backward_cuda(grad_grad_out, grad_in1, grad_in2, grad_out, grad_grad_in1, grad_grad_in2, in1, in2, idx1); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("float_forward", &index_mul_2d_float_forward, + "index mul float calculation forward (CUDA)", + py::call_guard()); + m.def("float_backward", &index_mul_2d_float_backward, + "index mul float calculation backward (CUDA)", + py::call_guard()); + m.def("float_backward_backward", &index_mul_2d_float_backwrad_backward, + "index mul float calculation backward backward (CUDA)", + py::call_guard()); + m.def("half_forward", &index_mul_2d_half_forward, + "index mul half calculation forward (CUDA)", + py::call_guard()); + m.def("half_backward", &index_mul_2d_half_backward, + "index mul half calculation backward (CUDA)", + py::call_guard()); + m.def("half_backward_backward", &index_mul_2d_half_backwrad_backward, + "index mul half calculation backward backward (CUDA)", + py::call_guard()); +} + diff --git a/apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda_kernel.cu b/apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda_kernel.cu new file mode 100644 index 000000000..4f18da3bf --- /dev/null +++ b/apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda_kernel.cu @@ -0,0 +1,492 @@ +#include +#include +#include +#ifdef ATEN_ATOMIC_HEADER + #include +#else + #include +#endif + + +__global__ void index_mul_2d_float_dim64( + float *out, + const float *in1, + const float *in2, + const int64_t *idx1, + const int64_t size) +{ + const int tidx = threadIdx.x; + const int tidy = threadIdx.y; + const int bidx = blockIdx.x; + const int start_idx = bidx * blockDim.y + tidy; + constexpr int fea_dim = 64; + + if (start_idx < size) { + int64_t vec_idx1 = (idx1[start_idx] * fea_dim) / 4 + tidx; + int64_t vec_idx2 = (start_idx * fea_dim) / 4 + tidx; + + float4 res, src1, src2; + src1 = reinterpret_cast(in1)[vec_idx1]; + src2 = reinterpret_cast(in2)[vec_idx2]; + res.x = src1.x * src2.x; + res.y = src1.y * src2.y; + res.z = src1.z * src2.z; + res.w = src1.w * src2.w; + reinterpret_cast(out)[vec_idx2] = res; + } +} + +__global__ void index_mul_2d_float( + float *out, + const float *in1, + const float *in2, + const int64_t *idx1, + const int64_t size, + const int64_t fea_dim) +{ + const int tidx = threadIdx.x; + const int tidy = threadIdx.y; + const int bidx = blockIdx.x; + const int start_idx = bidx * blockDim.y + tidy; + const int stride = blockDim.x; + + if (start_idx < size) { + int64_t vec_idx1 = (idx1[start_idx] * fea_dim); + int64_t vec_idx2 = (start_idx * fea_dim); + + for (int i = tidx; i < fea_dim; i += stride) { + out[vec_idx2 + i] = in1[vec_idx1 + i] * in2[vec_idx2 + i]; + } + } +} + +__global__ void index_mul_2d_half( + at::Half *out, + const at::Half *in1, + const at::Half *in2, + const int64_t *idx1, + const int64_t size, + const int64_t fea_dim) +{ + const int tidx = threadIdx.x; + const int tidy = threadIdx.y; + const int bidx = blockIdx.x; + const int start_idx = bidx * blockDim.y + tidy; + const int stride = blockDim.x; + + if (start_idx < size) { + int64_t vec_idx1 = (idx1[start_idx] * fea_dim); + int64_t vec_idx2 = (start_idx * fea_dim); + + for (int i = tidx; i < fea_dim; i += stride) { + out[vec_idx2 + i] = at::Half(static_cast(in1[vec_idx1 + i]) * static_cast(in2[vec_idx2 + i])); + } + } +} + +__global__ void index_mul_2d_grad_float_dim64( + float *grad_in1, + float *grad_in2, + const float *grad_out, + const float *in1, + const float *in2, + const int64_t *idx1, + const int64_t size) +{ + const int tidx = threadIdx.x; + const int tidy = threadIdx.y; + const int bidx = blockIdx.x; + const int start_idx = bidx * blockDim.y + tidy; + constexpr int fea_dim = 64; + + if (start_idx < size) { + int64_t vec_idx1 = (idx1[start_idx] * fea_dim) / 4 + tidx; + int64_t vec_idx2 = (start_idx * fea_dim) / 4 + tidx; + + float4 src_in1, src_in2, src_grad_out, dst_grad_in2; + src_grad_out = reinterpret_cast(grad_out)[vec_idx2]; + src_in1 = reinterpret_cast(in1)[vec_idx1]; + src_in2 = reinterpret_cast(in2)[vec_idx2]; + int64_t grad_in1_base_idx = idx1[start_idx] * fea_dim + tidx * 4; + gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 0, src_grad_out.x * src_in2.x); + gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 1, src_grad_out.y * src_in2.y); + gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 2, src_grad_out.z * src_in2.z); + gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 3, src_grad_out.w * src_in2.w); + dst_grad_in2.x = src_grad_out.x * src_in1.x; + dst_grad_in2.y = src_grad_out.y * src_in1.y; + dst_grad_in2.z = src_grad_out.z * src_in1.z; + dst_grad_in2.w = src_grad_out.w * src_in1.w; + reinterpret_cast(grad_in2)[vec_idx2] = dst_grad_in2; + } +} + +__global__ void index_mul_2d_grad_float( + float *grad_in1, + float *grad_in2, + const float *grad_out, + const float *in1, + const float *in2, + const int64_t *idx1, + const int64_t size, + const int64_t fea_dim) +{ + const int tidx = threadIdx.x; + const int tidy = threadIdx.y; + const int bidx = blockIdx.x; + const int start_idx = bidx * blockDim.y + tidy; + const int stride = blockDim.x; + + if (start_idx < size) { + int64_t vec_idx1 = idx1[start_idx] * fea_dim; + int64_t vec_idx2 = start_idx * fea_dim; + + for (int i = tidx; i < fea_dim; i += stride) { + float src_in1 = in1[vec_idx1 + i]; + float src_in2 = in2[vec_idx2 + i]; + float src_grad_out = grad_out[vec_idx2 + i]; + grad_in2[vec_idx2 + i] = src_grad_out * src_in1; + gpuAtomicAdd(grad_in1 + vec_idx1 + i, src_grad_out * src_in2); + } + } +} + +__global__ void index_mul_2d_grad_half( + at::Half *grad_in1, + at::Half *grad_in2, + const at::Half *grad_out, + const at::Half *in1, + const at::Half *in2, + const int64_t *idx1, + const int64_t size, + const int64_t fea_dim) +{ + const int tidx = threadIdx.x; + const int tidy = threadIdx.y; + const int bidx = blockIdx.x; + const int start_idx = bidx * blockDim.y + tidy; + const int stride = blockDim.x; + + if (start_idx < size) { + int64_t vec_idx1 = idx1[start_idx] * fea_dim; + int64_t vec_idx2 = start_idx * fea_dim; + + for (int i = tidx; i < fea_dim; i += stride) { + float src_in1 = static_cast(in1[vec_idx1 + i]); + float src_in2 = static_cast(in2[vec_idx2 + i]); + float src_grad_out = static_cast(grad_out[vec_idx2 + i]); + grad_in2[vec_idx2 + i] = at::Half(src_grad_out * src_in1); + gpuAtomicAdd(grad_in1 + vec_idx1 + i, at::Half(src_grad_out * src_in2)); + } + } +} + +__global__ void index_mul_2d_grad_grad_float_dim64( + float *grad_grad_out, + float *grad_in1, + float *grad_in2, + const float *grad_out, + const float *grad_grad_in1, + const float *grad_grad_in2, + const float *in1, + const float *in2, + const int64_t *idx1, + const int64_t size) +{ + const int tidx = threadIdx.x; + const int tidy = threadIdx.y; + const int bidx = blockIdx.x; + const int start_idx = bidx * blockDim.y + tidy; + constexpr int fea_dim = 64; + + if (start_idx < size) { + int64_t vec_idx1 = (idx1[start_idx] * fea_dim) / 4 + tidx; + int64_t vec_idx2 = (start_idx * fea_dim) / 4 + tidx; + + float4 src_grad_grad_in1, src_in1, src_grad_grad_in2, src_in2, src_grad_out; + float4 dst_grad_grad_out, dst_grad_in2; + src_grad_grad_in1 = reinterpret_cast(grad_grad_in1)[vec_idx1]; + src_in1 = reinterpret_cast(in1)[vec_idx1]; + src_grad_grad_in2 = reinterpret_cast(grad_grad_in2)[vec_idx2]; + src_in2 = reinterpret_cast(in2)[vec_idx2]; + dst_grad_grad_out.x = src_grad_grad_in1.x * src_in2.x + src_grad_grad_in2.x * src_in1.x; + dst_grad_grad_out.y = src_grad_grad_in1.y * src_in2.y + src_grad_grad_in2.y * src_in1.y; + dst_grad_grad_out.z = src_grad_grad_in1.z * src_in2.z + src_grad_grad_in2.z * src_in1.z; + dst_grad_grad_out.w = src_grad_grad_in1.w * src_in2.w + src_grad_grad_in2.w * src_in1.w; + reinterpret_cast(grad_grad_out)[vec_idx2] = dst_grad_grad_out; + src_grad_out = reinterpret_cast(grad_out)[vec_idx2]; + int64_t grad_in1_base_idx = idx1[start_idx] * fea_dim + tidx * 4; + gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 0, src_grad_grad_in2.x * src_grad_out.x); + gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 1, src_grad_grad_in2.y * src_grad_out.y); + gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 2, src_grad_grad_in2.z * src_grad_out.z); + gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 3, src_grad_grad_in2.w * src_grad_out.w); + dst_grad_in2.x = src_grad_grad_in1.x * src_grad_out.x; + dst_grad_in2.y = src_grad_grad_in1.y * src_grad_out.y; + dst_grad_in2.z = src_grad_grad_in1.z * src_grad_out.z; + dst_grad_in2.w = src_grad_grad_in1.w * src_grad_out.w; + reinterpret_cast(grad_in2)[vec_idx2] = dst_grad_in2; + } +} + +__global__ void index_mul_2d_grad_grad_float( + float *grad_grad_out, + float *grad_in1, + float *grad_in2, + const float *grad_out, + const float *grad_grad_in1, + const float *grad_grad_in2, + const float *in1, + const float *in2, + const int64_t *idx1, + const int64_t size, + const int64_t fea_dim) +{ + const int tidx = threadIdx.x; + const int tidy = threadIdx.y; + const int bidx = blockIdx.x; + const int start_idx = bidx * blockDim.y + tidy; + const int stride = blockDim.x; + + if (start_idx < size) { + int64_t vec_idx1 = idx1[start_idx] * fea_dim; + int64_t vec_idx2 = start_idx * fea_dim; + + for (int i = tidx; i < fea_dim; i += stride) { + float src_grad_grad_in1 = grad_grad_in1[vec_idx1 + i]; + float src_grad_grad_in2 = grad_grad_in2[vec_idx2 + i]; + float src_in1 = in1[vec_idx1 + i]; + float src_in2 = in2[vec_idx2 + i]; + float src_grad_out = grad_out[vec_idx2 + i]; + grad_grad_out[vec_idx2 + i] = src_grad_grad_in1 * src_in2 + src_grad_grad_in2 * src_in1; + grad_in2[vec_idx2 + i] = src_grad_grad_in1 * src_grad_out; + gpuAtomicAdd(grad_in1 + vec_idx1 + i, src_grad_grad_in2 * src_grad_out); + } + } +} + +__global__ void index_mul_2d_grad_grad_half( + at::Half *grad_grad_out, + at::Half *grad_in1, + at::Half *grad_in2, + const at::Half *grad_out, + const at::Half *grad_grad_in1, + const at::Half *grad_grad_in2, + const at::Half *in1, + const at::Half *in2, + const int64_t *idx1, + const int64_t size, + const int64_t fea_dim) +{ + const int tidx = threadIdx.x; + const int tidy = threadIdx.y; + const int bidx = blockIdx.x; + const int start_idx = bidx * blockDim.y + tidy; + const int stride = blockDim.x; + + if (start_idx < size) { + int64_t vec_idx1 = idx1[start_idx] * fea_dim; + int64_t vec_idx2 = start_idx * fea_dim; + + for (int i = tidx; i < fea_dim; i += stride) { + float src_grad_grad_in1 = static_cast(grad_grad_in1[vec_idx1 + i]); + float src_grad_grad_in2 = static_cast(grad_grad_in2[vec_idx2 + i]); + float src_in1 = static_cast(in1[vec_idx1 + i]); + float src_in2 = static_cast(in2[vec_idx2 + i]); + float src_grad_out = static_cast(grad_out[vec_idx2 + i]); + grad_grad_out[vec_idx2 + i] = at::Half(src_grad_grad_in1 * src_in2 + src_grad_grad_in2 * src_in1); + grad_in2[vec_idx2 + i] = at::Half(src_grad_grad_in1 * src_grad_out); + gpuAtomicAdd(grad_in1 + vec_idx1 + i, at::Half(src_grad_grad_in2 * src_grad_out)); + } + } +} + +void index_mul_2d_float_foward_cuda(at::Tensor &out, + const at::Tensor &in1, + const at::Tensor &in2, + const at::Tensor &idx1) { + const int64_t size = in2.size(0); + const int64_t fea_dim = in2.size(1); + if (size < 0){ + return; + } + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + if (fea_dim == 64) { + const int BLOCK_THREADS_DIMX = 16; + const int BLOCK_THREADS_DIMY = 16; + const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY; + dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1); + + index_mul_2d_float_dim64<<>>( + out.data_ptr(), in1.data_ptr(), in2.data_ptr(), + idx1.data_ptr(), size); + } else { + const int BLOCK_THREADS_DIMX = 32; + const int BLOCK_THREADS_DIMY = 8; + const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY; + dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1); + + index_mul_2d_float<<>>( + out.data_ptr(), in1.data_ptr(), in2.data_ptr(), + idx1.data_ptr(), size, fea_dim); + } + + AT_CUDA_CHECK(cudaGetLastError()); +} + +void index_mul_2d_float_backward_cuda(at::Tensor &grad_in1, + at::Tensor &grad_in2, + const at::Tensor &grad_out, + const at::Tensor &in1, + const at::Tensor &in2, + const at::Tensor &idx1) { + const int64_t size = in2.size(0); + const int64_t fea_dim = in2.size(1); + if (size < 0){ + return; + } + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + if (fea_dim == 64) { + const int BLOCK_THREADS_DIMX = 16; + const int BLOCK_THREADS_DIMY = 16; + const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY; + dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1); + + index_mul_2d_grad_float_dim64<<>>( + grad_in1.data_ptr(), grad_in2.data_ptr(), grad_out.data_ptr(), + in1.data_ptr(), in2.data_ptr(), idx1.data_ptr(), size); + + AT_CUDA_CHECK(cudaGetLastError()); + } else { + const int BLOCK_THREADS_DIMX = 32; + const int BLOCK_THREADS_DIMY = 8; + const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY; + dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1); + + index_mul_2d_grad_float<<>>( + grad_in1.data_ptr(), grad_in2.data_ptr(), grad_out.data_ptr(), + in1.data_ptr(), in2.data_ptr(), idx1.data_ptr(), size, fea_dim); + } +} + +void index_mul_2d_float_backward_backward_cuda(at::Tensor &grad_grad_out, + at::Tensor &grad_in1, + at::Tensor &grad_in2, + const at::Tensor &grad_out, + const at::Tensor &grad_grad_in1, + const at::Tensor &grad_grad_in2, + const at::Tensor &in1, + const at::Tensor &in2, + const at::Tensor &idx1) { + const int64_t size = in2.size(0); + const int64_t fea_dim = in2.size(1); + if (size < 0){ + return; + } + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + if (fea_dim == 64) { + const int BLOCK_THREADS_DIMX = 16; + const int BLOCK_THREADS_DIMY = 16; + const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY; + dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1); + + index_mul_2d_grad_grad_float_dim64<<>>( + grad_grad_out.data_ptr(), grad_in1.data_ptr(), grad_in2.data_ptr(), + grad_out.data_ptr(), grad_grad_in1.data_ptr(), grad_grad_in2.data_ptr(), + in1.data_ptr(), in2.data_ptr(), idx1.data_ptr(), size); + } else { + const int BLOCK_THREADS_DIMX = 32; + const int BLOCK_THREADS_DIMY = 8; + const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY; + dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1); + + index_mul_2d_grad_grad_float<<>>( + grad_grad_out.data_ptr(), grad_in1.data_ptr(), grad_in2.data_ptr(), + grad_out.data_ptr(), grad_grad_in1.data_ptr(), grad_grad_in2.data_ptr(), + in1.data_ptr(), in2.data_ptr(), idx1.data_ptr(), size, fea_dim); + } + + AT_CUDA_CHECK(cudaGetLastError()); +} + +void index_mul_2d_half_foward_cuda(at::Tensor &out, + const at::Tensor &in1, + const at::Tensor &in2, + const at::Tensor &idx1) { + const int64_t size = in2.size(0); + const int64_t fea_dim = in2.size(1); + if (size < 0){ + return; + } + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + const int BLOCK_THREADS_DIMX = 32; + const int BLOCK_THREADS_DIMY = 8; + const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY; + dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1); + + index_mul_2d_half<<>>( + out.data_ptr(), in1.data_ptr(), in2.data_ptr(), + idx1.data_ptr(), size, fea_dim); + + AT_CUDA_CHECK(cudaGetLastError()); +} + +void index_mul_2d_half_backward_cuda(at::Tensor &grad_in1, + at::Tensor &grad_in2, + const at::Tensor &grad_out, + const at::Tensor &in1, + const at::Tensor &in2, + const at::Tensor &idx1) { + const int64_t size = in2.size(0); + const int64_t fea_dim = in2.size(1); + if (size < 0){ + return; + } + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + const int BLOCK_THREADS_DIMX = 32; + const int BLOCK_THREADS_DIMY = 8; + const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY; + dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1); + + index_mul_2d_grad_half<<>>( + grad_in1.data_ptr(), grad_in2.data_ptr(), grad_out.data_ptr(), + in1.data_ptr(), in2.data_ptr(), idx1.data_ptr(), size, fea_dim); +} + +void index_mul_2d_half_backward_backward_cuda(at::Tensor &grad_grad_out, + at::Tensor &grad_in1, + at::Tensor &grad_in2, + const at::Tensor &grad_out, + const at::Tensor &grad_grad_in1, + const at::Tensor &grad_grad_in2, + const at::Tensor &in1, + const at::Tensor &in2, + const at::Tensor &idx1) { + const int64_t size = in2.size(0); + const int64_t fea_dim = in2.size(1); + if (size < 0){ + return; + } + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + const int BLOCK_THREADS_DIMX = 32; + const int BLOCK_THREADS_DIMY = 8; + const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY; + dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1); + + index_mul_2d_grad_grad_half<<>>( + grad_grad_out.data_ptr(), grad_in1.data_ptr(), grad_in2.data_ptr(), + grad_out.data_ptr(), grad_grad_in1.data_ptr(), grad_grad_in2.data_ptr(), + in1.data_ptr(), in2.data_ptr(), idx1.data_ptr(), size, fea_dim); + + AT_CUDA_CHECK(cudaGetLastError()); +} diff --git a/apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu b/apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu index dc4e89cf5..62ff78ee3 100644 --- a/apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu +++ b/apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu @@ -67,8 +67,6 @@ void launch_(LaunchParams &launch_params, const bool configure_params } -// Create forward launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG REGISTER_FWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 768, fp16, fp16, fp16, fp32, 1, 4, 1, 16); @@ -168,7 +166,6 @@ REGISTER_FWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 2, 1, 4, 8); REGISTER_FWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 2, 1, 4, 16); REGISTER_FWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(16384, fp16, fp32, fp16, fp32, 2, 1, 4, 16); REGISTER_FWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 2, 1, 4, 16); REGISTER_FWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 2, 1, 4, 16); diff --git a/apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu b/apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu index 76d26a62e..8177f5380 100644 --- a/apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu +++ b/apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu @@ -4,7 +4,7 @@ #include #include -#include +//#include #include #include diff --git a/apex/contrib/csrc/multihead_attn/cutlass b/apex/contrib/csrc/multihead_attn/cutlass deleted file mode 160000 index ed2ed4d66..000000000 --- a/apex/contrib/csrc/multihead_attn/cutlass +++ /dev/null @@ -1 +0,0 @@ -Subproject commit ed2ed4d667ce95e1371bd62db32b6a114e774336 diff --git a/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu b/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu index d56b80768..510a291b9 100644 --- a/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu +++ b/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu @@ -4,7 +4,7 @@ #include #include -#include +//#include #include #include @@ -17,7 +17,7 @@ namespace multihead_attn { namespace encdec { -namespace cublas_gemmex { +namespace rocblas_gemmex { std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const &inputs_q, @@ -85,33 +85,71 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, char a_layout_n{'n'}; char b_layout_n{'n'}; - TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); - // Input Linear Q Fwd - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_T, CUBLAS_OP_N, output_lin_q_dim, batches_q, embed_dim, - static_cast(&alpha), - static_cast(input_weights_q.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(inputs_q.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(&beta), q_lin_results_ptr, - CUDA_R_16F, output_lin_q_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); + // Input Linear Q Fwd + TORCH_CUDABLAS_CHECK((hipblasGemmEx(handle, + CUBLAS_OP_T, + CUBLAS_OP_N, + output_lin_q_dim, + batches_q, + embed_dim, + static_cast(&alpha), + static_cast(input_weights_q.data_ptr()), + HIP_R_16F, + embed_dim, + static_cast(inputs_q.data_ptr()), + HIP_R_16F, + embed_dim, + static_cast(&beta), + q_lin_results_ptr, + HIP_R_16F, + output_lin_q_dim, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT /*algo*/ + ))); + // Input Linear KV Fwd - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_T, CUBLAS_OP_N, output_lin_kv_dim, batches_kv, - embed_dim, static_cast(&alpha), - static_cast(input_weights_kv.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(inputs_kv.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(&beta), k_lin_results_ptr, - CUDA_R_16F, output_lin_kv_dim, CUDA_R_32F, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, + CUBLAS_OP_T, + CUBLAS_OP_N, + output_lin_kv_dim, + batches_kv, + embed_dim, + static_cast(&alpha), + static_cast(input_weights_kv.data_ptr()), + HIP_R_16F, + embed_dim, + static_cast(inputs_kv.data_ptr()), + HIP_R_16F, + embed_dim, + static_cast(&beta), + k_lin_results_ptr, + HIP_R_16F, + output_lin_kv_dim, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT /*algo*/ + )); // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) - gemm_switch_fp32accum( - a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, scale, - static_cast(k_lin_results_ptr), lead_dim_kv, - batch_stride_kv, static_cast(q_lin_results_ptr), lead_dim_q, - batch_stride_q, beta, static_cast(softmax_results_ptr), k_seq_len, - k_seq_len * q_seq_len, attn_batches); + gemm_switch_fp32accum( a_layout_t, + b_layout_n, + k_seq_len, + q_seq_len, + head_dim, + scale, + static_cast(k_lin_results_ptr), + lead_dim_kv, + batch_stride_kv, + static_cast(q_lin_results_ptr), + lead_dim_q, + batch_stride_q, + beta, + static_cast(softmax_results_ptr), + k_seq_len, + k_seq_len*q_seq_len, + attn_batches + ); // Padded Softmax bool softmax_success = false; @@ -145,29 +183,47 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, } // Matmul2 - gemm_switch_fp32accum( - a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, alpha, - static_cast(v_lin_results_ptr), lead_dim_kv, - batch_stride_kv, - (is_training) ? static_cast(dropout_results.data_ptr()) - : static_cast(softmax_results.data_ptr()), - k_seq_len, k_seq_len * q_seq_len, beta, - static_cast(matmul2_results.data_ptr()), head_dim * attn_batches, - head_dim, attn_batches); + gemm_switch_fp32accum( a_layout_n, + b_layout_n, + head_dim, + q_seq_len, + k_seq_len, + alpha, + static_cast(v_lin_results_ptr), + lead_dim_kv, + batch_stride_kv, + (is_training) ? static_cast(dropout_results.data_ptr()) : static_cast(softmax_results.data_ptr()) , + k_seq_len, + k_seq_len*q_seq_len, + beta, + static_cast(matmul2_results.data_ptr()), + head_dim*attn_batches, + head_dim, + attn_batches + ); // Output Linear - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_T, CUBLAS_OP_N, embed_dim, batches_q, embed_dim, - static_cast(&alpha), - static_cast(output_weights.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(matmul2_results.data_ptr()), - CUDA_R_16F, embed_dim, static_cast(&beta), - static_cast(outputs.data_ptr()), CUDA_R_16F, embed_dim, - CUDA_R_32F, - // CUBLAS_GEMM_ALGO1_TENSOR_OP)); - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - - TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); + TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, + CUBLAS_OP_T, + CUBLAS_OP_N, + embed_dim, + batches_q, + embed_dim, + static_cast(&alpha), + static_cast(output_weights.data_ptr()), + HIP_R_16F, + embed_dim, + static_cast(matmul2_results.data_ptr()), + HIP_R_16F, + embed_dim, + static_cast(&beta), + static_cast(outputs.data_ptr()), + HIP_R_16F, + embed_dim, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT /*algo*/ + )); + //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); return {input_lin_q_results, input_lin_kv_results, @@ -240,54 +296,112 @@ std::vector bwd_cuda( char a_layout_n{'n'}; char a_layout_t{'t'}; char b_layout_n{'n'}; - char b_layout_t{'t'}; - - TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); + char b_layout_t{'t'}; + + rocblas_int flags = 0; + + //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); + #ifdef USE_ROCM + #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR) + #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242) + #if USE_GEMM_FLAGS_FP16_ALT_IMPL + #ifdef BACKWARD_PASS_GUARD + flags = at::BACKWARD_PASS_GUARD_CLASS::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0; + #endif + #endif + #endif // Output Linear Dgrad - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches_q, embed_dim, - static_cast(&alpha), - static_cast(output_weights.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(output_grads.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(&beta), - static_cast(output_lin_grads.data_ptr()), CUDA_R_16F, embed_dim, - CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - + TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, + CUBLAS_OP_N, + CUBLAS_OP_N, + embed_dim, + batches_q, + embed_dim, + static_cast(&alpha), + static_cast(output_weights.data_ptr()), + HIP_R_16F, + embed_dim, + static_cast(output_grads.data_ptr()), + HIP_R_16F, + embed_dim, + static_cast(&beta), + static_cast(output_lin_grads.data_ptr()), + HIP_R_16F, + embed_dim, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT /*algo*/ + )); + // Output Linear Wgrad - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, embed_dim, batches_q, - static_cast(&alpha), - static_cast(matmul2_results.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(output_grads.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(&beta), - static_cast(output_weight_grads.data_ptr()), CUDA_R_16F, - embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - + TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, + CUBLAS_OP_N, + CUBLAS_OP_T, + embed_dim, + embed_dim, + batches_q, + static_cast(&alpha), + static_cast(matmul2_results.data_ptr()), + HIP_R_16F, + embed_dim, + static_cast(output_grads.data_ptr()), + HIP_R_16F, + embed_dim, + static_cast(&beta), + static_cast(output_weight_grads.data_ptr()), + HIP_R_16F, + embed_dim, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT /*algo*/ + )); + // MatMul2 Dgrad1 - gemm_switch_fp32accum( - a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, alpha, - static_cast(v_lin_results_ptr), lead_dim_kv, - batch_stride_kv, static_cast(output_lin_grads.data_ptr()), - head_dim * attn_batches, head_dim, beta, - static_cast(matmul2_grads.data_ptr()), k_seq_len, - k_seq_len * q_seq_len, attn_batches); - + gemm_switch_fp32accum( a_layout_t, + b_layout_n, + k_seq_len, + q_seq_len, + head_dim, + alpha, + static_cast(v_lin_results_ptr), + lead_dim_kv, + batch_stride_kv, + static_cast(output_lin_grads.data_ptr()), + head_dim*attn_batches, + head_dim, + beta, + static_cast(matmul2_grads.data_ptr()), + k_seq_len, + k_seq_len*q_seq_len, + attn_batches + ); + // Matmul2 Dgrad2 - gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, - q_seq_len, alpha, - static_cast(output_lin_grads.data_ptr()), - head_dim * attn_batches, head_dim, - static_cast(dropout_results.data_ptr()), - k_seq_len, k_seq_len * q_seq_len, beta, v_lin_grads_ptr, - lead_dim_kv, batch_stride_kv, attn_batches); - - // Apply Dropout Mask and Scale by Dropout Probability - apex_masked_scale_cuda( - static_cast(matmul2_grads.data_ptr()), - static_cast(matmul2_grads.data_ptr()), - static_cast(dropout_mask.data_ptr()), dropout_elems, - (1.0 / (1.0 - dropout_prob))); + gemm_switch_fp32accum( a_layout_n, + b_layout_t, + head_dim, + k_seq_len, + q_seq_len, + alpha, + static_cast(output_lin_grads.data_ptr()), + head_dim*attn_batches, + head_dim, + static_cast(dropout_results.data_ptr()), + k_seq_len, + k_seq_len*q_seq_len, + beta, + v_lin_grads_ptr, + lead_dim_kv, + batch_stride_kv, + attn_batches + ); + + // Apply Dropout Mask and Scale by Dropout Probability + apex_masked_scale_cuda( + static_cast(matmul2_grads.data_ptr()), + static_cast(matmul2_grads.data_ptr()), + static_cast(dropout_mask.data_ptr()), + dropout_elems, + (1.0 / (1.0 - dropout_prob))); // Softmax Grad bool softmax_success = false; @@ -299,70 +413,143 @@ std::vector bwd_cuda( assert(softmax_success); // Matmul1 Dgrad1 - gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, - k_seq_len, scale, k_lin_results_ptr, lead_dim_kv, - batch_stride_kv, - static_cast(matmul2_grads.data_ptr()), - k_seq_len, k_seq_len * q_seq_len, beta, q_lin_grads_ptr, - lead_dim_q, batch_stride_q, attn_batches); - + gemm_switch_fp32accum( a_layout_n, + b_layout_n, + head_dim, + q_seq_len, + k_seq_len, + scale, + k_lin_results_ptr, + lead_dim_kv, + batch_stride_kv, + static_cast(matmul2_grads.data_ptr()), + k_seq_len, + k_seq_len*q_seq_len, + beta, + q_lin_grads_ptr, + lead_dim_q, + batch_stride_q, + attn_batches + ); + // Matmul1 Dgrad2 - gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, - q_seq_len, scale, q_lin_results_ptr, lead_dim_q, - batch_stride_q, - static_cast(matmul2_grads.data_ptr()), - k_seq_len, k_seq_len * q_seq_len, beta, k_lin_grads_ptr, - lead_dim_kv, batch_stride_kv, attn_batches); - - // Input Linear Q Dgrad - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches_q, output_lin_q_dim, - static_cast(&alpha), - static_cast(input_weights_q.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(q_lin_grads_ptr), CUDA_R_16F, - output_lin_q_dim, static_cast(&beta), - static_cast(input_q_grads.data_ptr()), CUDA_R_16F, embed_dim, - CUDA_R_32F, - // CUBLAS_GEMM_ALGO10_TENSOR_OP)); - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - - // Input Linear Q Wgrad - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, output_lin_q_dim, batches_q, - static_cast(&alpha), - static_cast(inputs_q.data_ptr()), CUDA_R_16F, embed_dim, - static_cast(q_lin_grads_ptr), CUDA_R_16F, output_lin_q_dim, - static_cast(&beta), - static_cast(input_weight_q_grads.data_ptr()), CUDA_R_16F, - embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - - // Input Linear KV Dgrad - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches_kv, - output_lin_kv_dim, static_cast(&alpha), - static_cast(input_weights_kv.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(k_lin_grads_ptr), CUDA_R_16F, - output_lin_kv_dim, static_cast(&beta), - static_cast(input_kv_grads.data_ptr()), CUDA_R_16F, embed_dim, - CUDA_R_32F, - // CUBLAS_GEMM_ALGO10_TENSOR_OP)); - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - - // Input Linear KV Wgrad - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, output_lin_kv_dim, - batches_kv, static_cast(&alpha), - static_cast(inputs_kv.data_ptr()), CUDA_R_16F, embed_dim, - static_cast(k_lin_grads_ptr), CUDA_R_16F, output_lin_kv_dim, - static_cast(&beta), - static_cast(input_weight_kv_grads.data_ptr()), CUDA_R_16F, - embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); - - return {input_q_grads, input_kv_grads, input_weight_q_grads, - input_weight_kv_grads, output_weight_grads}; + gemm_switch_fp32accum( a_layout_n, + b_layout_t, + head_dim, + k_seq_len, + q_seq_len, + scale, + q_lin_results_ptr, + lead_dim_q, + batch_stride_q, + static_cast(matmul2_grads.data_ptr()), + k_seq_len, + k_seq_len*q_seq_len, + beta, + k_lin_grads_ptr, + lead_dim_kv, + batch_stride_kv, + attn_batches + ); + + // Input Linear Q Dgrad + TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, + CUBLAS_OP_N, + CUBLAS_OP_N, + embed_dim, + batches_q, + output_lin_q_dim, + static_cast(&alpha), + static_cast(input_weights_q.data_ptr()), + HIP_R_16F, + embed_dim, + static_cast(q_lin_grads_ptr), + HIP_R_16F, + output_lin_q_dim, + static_cast(&beta), + static_cast(input_q_grads.data_ptr()), + HIP_R_16F, + embed_dim, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT /*algo*/ + )); + + // Input Linear Q Wgrad + TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, + CUBLAS_OP_N, + CUBLAS_OP_T, + embed_dim, + output_lin_q_dim, + batches_q, + static_cast(&alpha), + static_cast(inputs_q.data_ptr()), + HIP_R_16F, + embed_dim, + static_cast(q_lin_grads_ptr), + HIP_R_16F, + output_lin_q_dim, + static_cast(&beta), + static_cast(input_weight_q_grads.data_ptr()), + HIP_R_16F, + embed_dim, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT /*algo*/ + )); + + // Input Linear KV Dgrad + TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, + CUBLAS_OP_N, + CUBLAS_OP_N, + embed_dim, + batches_kv, + output_lin_kv_dim, + static_cast(&alpha), + static_cast(input_weights_kv.data_ptr()), + HIP_R_16F, + embed_dim, + static_cast(k_lin_grads_ptr), + HIP_R_16F, + output_lin_kv_dim, + static_cast(&beta), + static_cast(input_kv_grads.data_ptr()), + HIP_R_16F, + embed_dim, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT /*algo*/ + )); + + // Input Linear KV Wgrad + TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, + CUBLAS_OP_N, + CUBLAS_OP_T, + embed_dim, + output_lin_kv_dim, + batches_kv, + static_cast(&alpha), + static_cast(inputs_kv.data_ptr()), + HIP_R_16F, + embed_dim, + static_cast(k_lin_grads_ptr), + HIP_R_16F, + output_lin_kv_dim, + static_cast(&beta), + static_cast(input_weight_kv_grads.data_ptr()), + HIP_R_16F, + embed_dim, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT /*algo*/ + )); + // TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); + return { + input_q_grads, + input_kv_grads, + input_weight_q_grads, + input_weight_kv_grads, + output_weight_grads + }; } -} // end namespace cublas_gemmex -} // end namespace encdec +} // end namespace rocblas_gemmex +} // end namespace encdec } // end namespace multihead_attn + diff --git a/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu b/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu index 10c9b8cef..56da36dcd 100644 --- a/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu +++ b/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu @@ -4,7 +4,7 @@ #include #include -#include +//#include #include #include @@ -15,42 +15,48 @@ #include "layer_norm.cuh" #include "softmax.cuh" #include "strided_batched_gemm.cuh" +#include "type_shim.h" namespace multihead_attn { namespace encdec_norm_add { -namespace cublas_gemmex { - -std::vector fwd_cuda(bool use_time_mask, bool is_training, - int heads, torch::Tensor const &inputs_q, - torch::Tensor const &inputs_kv, - torch::Tensor const &lyr_nrm_gamma_weights, - torch::Tensor const &lyr_nrm_beta_weights, - torch::Tensor const &input_weights_q, - torch::Tensor const &input_weights_kv, - torch::Tensor const &output_weights, - const uint8_t *pad_mask, - float dropout_prob) { - const int embed_dim = inputs_q.size(2); - const int sequences = inputs_q.size(1); - const int q_seq_len = inputs_q.size(0); - const int k_seq_len = inputs_kv.size(0); - const int batches_q = sequences * q_seq_len; - const int batches_kv = sequences * k_seq_len; - const int total_tokens_q = batches_q * embed_dim; - const int head_dim = embed_dim / heads; - const int output_lin_q_dim = embed_dim; - const int output_lin_kv_dim = 2 * embed_dim; - const int attn_batches = heads * sequences; - const int lead_dim_q = attn_batches * head_dim; - const int lead_dim_kv = attn_batches * 2 * head_dim; - const int batch_stride_q = head_dim; - const int batch_stride_kv = 2 * head_dim; - const int dropout_elems = attn_batches * q_seq_len * k_seq_len; - const float alpha = 1.0; - const float beta = 0.0; - const float scale = 1.0 / sqrt(static_cast(head_dim)); - - // There is no reason to use more than one stream as every kernel is +namespace rocblas_gemmex { + +std::vector fwd_cuda( + bool use_time_mask, + bool is_training, + int heads, + torch::Tensor const& inputs_q, + torch::Tensor const& inputs_kv, + torch::Tensor const& lyr_nrm_gamma_weights, + torch::Tensor const& lyr_nrm_beta_weights, + torch::Tensor const& input_weights_q, + torch::Tensor const& input_weights_kv, + torch::Tensor const& output_weights, + const uint8_t* pad_mask, + float dropout_prob + ) +{ + const int embed_dim = inputs_q.size(2); + const int sequences = inputs_q.size(1); + const int q_seq_len = inputs_q.size(0); + const int k_seq_len = inputs_kv.size(0); + const int batches_q = sequences * q_seq_len; + const int batches_kv = sequences * k_seq_len; + const int total_tokens_q = batches_q * embed_dim; + const int head_dim = embed_dim / heads; + const int output_lin_q_dim = embed_dim; + const int output_lin_kv_dim = 2 * embed_dim; + const int attn_batches = heads * sequences; + const int lead_dim_q = attn_batches * head_dim; + const int lead_dim_kv = attn_batches * 2 *head_dim; + const int batch_stride_q = head_dim; + const int batch_stride_kv = 2 * head_dim; + const int dropout_elems = attn_batches * q_seq_len * k_seq_len; + const float alpha = 1.0; + const float beta = 0.0; + const float scale = 1.0 / sqrt(static_cast(head_dim)); + + // There is no reason to use more than one stream as every kernel is // sequentially dependent cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); @@ -96,7 +102,7 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, char a_layout_n{'n'}; char b_layout_n{'n'}; - TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); + //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); // Layer Norm HostApplyLayerNorm( static_cast(lyr_nrm_results.data_ptr()), @@ -109,33 +115,68 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, static_cast(lyr_nrm_beta_weights.data_ptr())); // Input Linear Q Fwd - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_T, CUBLAS_OP_N, output_lin_q_dim, batches_q, embed_dim, - static_cast(&alpha), - static_cast(input_weights_q.data_ptr()), CUDA_R_16F, - embed_dim, - // static_cast(inputs_q.data_ptr()), - static_cast(lyr_nrm_results.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(&beta), q_lin_results_ptr, - CUDA_R_16F, output_lin_q_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - + TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle, + CUBLAS_OP_T, + CUBLAS_OP_N, + output_lin_q_dim, + batches_q, + embed_dim, + static_cast(&alpha), + static_cast(input_weights_q.data_ptr()), + HIP_R_16F /*a_type*/, + embed_dim, + //static_cast(inputs_q.data_ptr()), + static_cast(lyr_nrm_results.data_ptr()), + HIP_R_16F /*b_type*/, + embed_dim, + static_cast(&beta), + q_lin_results_ptr, + HIP_R_16F /*c_type*/, + output_lin_q_dim, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT /*algo*/ + )); + // Input Linear KV Fwd - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_T, CUBLAS_OP_N, output_lin_kv_dim, batches_kv, - embed_dim, static_cast(&alpha), - static_cast(input_weights_kv.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(inputs_kv.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(&beta), k_lin_results_ptr, - CUDA_R_16F, output_lin_kv_dim, CUDA_R_32F, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - + TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, + CUBLAS_OP_T, + CUBLAS_OP_N, + output_lin_kv_dim, + batches_kv, + embed_dim, + static_cast(&alpha), + static_cast(input_weights_kv.data_ptr()), + HIP_R_16F /*a_type*/, + embed_dim, + static_cast(inputs_kv.data_ptr()), + HIP_R_16F /*b_type*/, + embed_dim, + static_cast(&beta), + k_lin_results_ptr, + HIP_R_16F /*c_type*/, + output_lin_kv_dim, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT /*algo*/ + )); // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) - gemm_switch_fp32accum( - a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, scale, - static_cast(k_lin_results_ptr), lead_dim_kv, - batch_stride_kv, static_cast(q_lin_results_ptr), lead_dim_q, - batch_stride_q, beta, static_cast(softmax_results_ptr), k_seq_len, - k_seq_len * q_seq_len, attn_batches); + gemm_switch_fp32accum( a_layout_t, + b_layout_n, + k_seq_len, + q_seq_len, + head_dim, + scale, + static_cast(k_lin_results_ptr), + lead_dim_kv, + batch_stride_kv, + static_cast(q_lin_results_ptr), + lead_dim_q, + batch_stride_q, + beta, + static_cast(softmax_results_ptr), + k_seq_len, + k_seq_len*q_seq_len, + attn_batches + ); // Padded Softmax bool softmax_success = false; @@ -169,30 +210,49 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, } // Matmul2 - gemm_switch_fp32accum( - a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, alpha, - static_cast(v_lin_results_ptr), lead_dim_kv, - batch_stride_kv, - (is_training) ? static_cast(dropout_results.data_ptr()) - : static_cast(softmax_results.data_ptr()), - // static_cast(dropout_results.data_ptr()), - k_seq_len, k_seq_len * q_seq_len, beta, - static_cast(matmul2_results.data_ptr()), head_dim * attn_batches, - head_dim, attn_batches); + gemm_switch_fp32accum( a_layout_n, + b_layout_n, + head_dim, + q_seq_len, + k_seq_len, + alpha, + static_cast(v_lin_results_ptr), + lead_dim_kv, + batch_stride_kv, + (is_training) ? static_cast(dropout_results.data_ptr()) : static_cast(softmax_results.data_ptr()), + //static_cast(dropout_results.data_ptr()), + k_seq_len, + k_seq_len*q_seq_len, + beta, + static_cast(matmul2_results.data_ptr()), + head_dim*attn_batches, + head_dim, + attn_batches + ); // Output Linear - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_T, CUBLAS_OP_N, embed_dim, batches_q, embed_dim, - static_cast(&alpha), - static_cast(output_weights.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(matmul2_results.data_ptr()), - CUDA_R_16F, embed_dim, static_cast(&beta), - static_cast(output_lin_results.data_ptr()), CUDA_R_16F, embed_dim, - CUDA_R_32F, - // CUBLAS_GEMM_ALGO1_TENSOR_OP)); - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - - // End-of-block Dropout-Add + TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, + CUBLAS_OP_T, + CUBLAS_OP_N, + embed_dim, + batches_q, + embed_dim, + static_cast(&alpha), + static_cast(output_weights.data_ptr()), + HIP_R_16F /*a_type*/, + embed_dim, + static_cast(matmul2_results.data_ptr()), + HIP_R_16F /*b_type*/, + embed_dim, + static_cast(&beta), + static_cast(output_lin_results.data_ptr()), + HIP_R_16F /*c_type*/, + embed_dim, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT /*algo*/ + )); + + // End-of-block Dropout-Add if (is_training) { apex_dropout_add_cuda( static_cast(output_lin_results.data_ptr()), @@ -207,7 +267,7 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, static_cast(outputs.data_ptr()), total_tokens_q); } - TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); + //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); return {lyr_nrm_results, lyr_nrm_mean, @@ -293,61 +353,109 @@ std::vector bwd_cuda( char a_layout_n{'n'}; char a_layout_t{'t'}; char b_layout_n{'n'}; - char b_layout_t{'t'}; - - TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); - - // Dropout Add Backward - apex_masked_scale_cuda( - static_cast(output_grads.data_ptr()), - static_cast(dropout_add_grads.data_ptr()), - static_cast(dropout_add_mask.data_ptr()), total_tokens_q, - (1.0 / (1.0 - dropout_prob))); - + char b_layout_t{'t'}; + + //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); + + // Dropout Add Backward + apex_masked_scale_cuda( + static_cast(output_grads.data_ptr()), + static_cast(dropout_add_grads.data_ptr()), + static_cast(dropout_add_mask.data_ptr()), + total_tokens_q, + (1.0 / (1.0 - dropout_prob))); + // Output Linear Dgrad - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches_q, embed_dim, - static_cast(&alpha), - static_cast(output_weights.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(dropout_add_grads.data_ptr()), - CUDA_R_16F, embed_dim, static_cast(&beta), - static_cast(output_lin_grads.data_ptr()), CUDA_R_16F, embed_dim, - CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - + TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, + CUBLAS_OP_N, + CUBLAS_OP_N, + embed_dim, + batches_q, + embed_dim, + static_cast(&alpha), + static_cast(output_weights.data_ptr()), + HIP_R_16F /*a_type*/, + embed_dim, + static_cast(dropout_add_grads.data_ptr()), + HIP_R_16F /*b_type*/, + embed_dim, + static_cast(&beta), + static_cast(output_lin_grads.data_ptr()), + HIP_R_16F /*c_type*/, + embed_dim, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT /*algo*/ + )); + // Output Linear Wgrad - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, embed_dim, batches_q, - static_cast(&alpha), - static_cast(matmul2_results.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(dropout_add_grads.data_ptr()), - CUDA_R_16F, embed_dim, static_cast(&beta), - static_cast(output_weight_grads.data_ptr()), CUDA_R_16F, - embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - + TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, + CUBLAS_OP_N, + CUBLAS_OP_T, + embed_dim, + embed_dim, + batches_q, + static_cast(&alpha), + static_cast(matmul2_results.data_ptr()), + HIP_R_16F /*a_type*/, + embed_dim, + static_cast(dropout_add_grads.data_ptr()), + HIP_R_16F /*b_type*/, + embed_dim, + static_cast(&beta), + static_cast(output_weight_grads.data_ptr()), + HIP_R_16F /*c_type*/, + embed_dim, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT /*algo*/ + )); + // MatMul2 Dgrad1 - gemm_switch_fp32accum( - a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, alpha, - static_cast(v_lin_results_ptr), lead_dim_kv, - batch_stride_kv, static_cast(output_lin_grads.data_ptr()), - head_dim * attn_batches, head_dim, beta, - static_cast(matmul2_grads.data_ptr()), k_seq_len, - k_seq_len * q_seq_len, attn_batches); - + gemm_switch_fp32accum( a_layout_t, + b_layout_n, + k_seq_len, + q_seq_len, + head_dim, + alpha, + static_cast(v_lin_results_ptr), + lead_dim_kv, + batch_stride_kv, + static_cast(output_lin_grads.data_ptr()), + head_dim*attn_batches, + head_dim, + beta, + static_cast(matmul2_grads.data_ptr()), + k_seq_len, + k_seq_len*q_seq_len, + attn_batches + ); + // Matmul2 Dgrad2 - gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, - q_seq_len, alpha, - static_cast(output_lin_grads.data_ptr()), - head_dim * attn_batches, head_dim, - static_cast(dropout_results.data_ptr()), - k_seq_len, k_seq_len * q_seq_len, beta, v_lin_grads_ptr, - lead_dim_kv, batch_stride_kv, attn_batches); - - // Apply Dropout Mask and Scale by Dropout Probability - apex_masked_scale_cuda( - static_cast(matmul2_grads.data_ptr()), - static_cast(matmul2_grads.data_ptr()), - static_cast(dropout_mask.data_ptr()), dropout_elems, - (1.0 / (1.0 - dropout_prob))); + gemm_switch_fp32accum( a_layout_n, + b_layout_t, + head_dim, + k_seq_len, + q_seq_len, + alpha, + static_cast(output_lin_grads.data_ptr()), + head_dim*attn_batches, + head_dim, + static_cast(dropout_results.data_ptr()), + k_seq_len, + k_seq_len*q_seq_len, + beta, + v_lin_grads_ptr, + lead_dim_kv, + batch_stride_kv, + attn_batches + ); + + // Apply Dropout Mask and Scale by Dropout Probability + apex_masked_scale_cuda( + static_cast(matmul2_grads.data_ptr()), + static_cast(matmul2_grads.data_ptr()), + static_cast(dropout_mask.data_ptr()), + dropout_elems, + (1.0 / (1.0 - dropout_prob))); // Softmax Grad bool softmax_success = false; @@ -359,87 +467,158 @@ std::vector bwd_cuda( assert(softmax_success); // Matmul1 Dgrad1 - gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, - k_seq_len, scale, k_lin_results_ptr, lead_dim_kv, - batch_stride_kv, - static_cast(matmul2_grads.data_ptr()), - k_seq_len, k_seq_len * q_seq_len, beta, q_lin_grads_ptr, - lead_dim_q, batch_stride_q, attn_batches); - + gemm_switch_fp32accum( a_layout_n, + b_layout_n, + head_dim, + q_seq_len, + k_seq_len, + scale, + k_lin_results_ptr, + lead_dim_kv, + batch_stride_kv, + static_cast(matmul2_grads.data_ptr()), + k_seq_len, + k_seq_len*q_seq_len, + beta, + q_lin_grads_ptr, + lead_dim_q, + batch_stride_q, + attn_batches + ); + // Matmul1 Dgrad2 - gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, - q_seq_len, scale, q_lin_results_ptr, lead_dim_q, - batch_stride_q, - static_cast(matmul2_grads.data_ptr()), - k_seq_len, k_seq_len * q_seq_len, beta, k_lin_grads_ptr, - lead_dim_kv, batch_stride_kv, attn_batches); - - // Input Linear Q Dgrad - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches_q, output_lin_q_dim, - static_cast(&alpha), - static_cast(input_weights_q.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(q_lin_grads_ptr), CUDA_R_16F, - output_lin_q_dim, static_cast(&beta), - // static_cast(input_q_grads.data_ptr()), - static_cast(input_lin_q_grads.data_ptr()), CUDA_R_16F, embed_dim, - CUDA_R_32F, - // CUBLAS_GEMM_ALGO10_TENSOR_OP)); - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - - // Input Linear Q Wgrad - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, output_lin_q_dim, batches_q, - static_cast(&alpha), - static_cast(inputs_q.data_ptr()), CUDA_R_16F, embed_dim, - static_cast(q_lin_grads_ptr), CUDA_R_16F, output_lin_q_dim, - static_cast(&beta), - static_cast(input_weight_q_grads.data_ptr()), CUDA_R_16F, - embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - - // Input Linear KV Dgrad - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches_kv, - output_lin_kv_dim, static_cast(&alpha), - static_cast(input_weights_kv.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(k_lin_grads_ptr), CUDA_R_16F, - output_lin_kv_dim, static_cast(&beta), - static_cast(input_kv_grads.data_ptr()), CUDA_R_16F, embed_dim, - CUDA_R_32F, - // CUBLAS_GEMM_ALGO10_TENSOR_OP)); - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - - // Input Linear KV Wgrad - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, output_lin_kv_dim, - batches_kv, static_cast(&alpha), - static_cast(inputs_kv.data_ptr()), CUDA_R_16F, embed_dim, - static_cast(k_lin_grads_ptr), CUDA_R_16F, output_lin_kv_dim, - static_cast(&beta), - static_cast(input_weight_kv_grads.data_ptr()), CUDA_R_16F, - embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - + gemm_switch_fp32accum( a_layout_n, + b_layout_t, + head_dim, + k_seq_len, + q_seq_len, + scale, + q_lin_results_ptr, + lead_dim_q, + batch_stride_q, + static_cast(matmul2_grads.data_ptr()), + k_seq_len, + k_seq_len*q_seq_len, + beta, + k_lin_grads_ptr, + lead_dim_kv, + batch_stride_kv, + attn_batches + ); + + // Input Linear Q Dgrad + TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, + CUBLAS_OP_N, + CUBLAS_OP_N, + embed_dim, + batches_q, + output_lin_q_dim, + static_cast(&alpha), + static_cast(input_weights_q.data_ptr()), + HIP_R_16F /*a_type*/, + embed_dim, + static_cast(q_lin_grads_ptr), + HIP_R_16F /*b_type*/, + output_lin_q_dim, + static_cast(&beta), + //static_cast(input_q_grads.data_ptr()), + static_cast(input_lin_q_grads.data_ptr()), + HIP_R_16F /*c_type*/, + embed_dim, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT /*algo*/ + )); + + // Input Linear Q Wgrad + TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, + CUBLAS_OP_N, + CUBLAS_OP_T, + embed_dim, + output_lin_q_dim, + batches_q, + static_cast(&alpha), + static_cast(inputs_q.data_ptr()), + HIP_R_16F /*a_type*/, + embed_dim, + static_cast(q_lin_grads_ptr), + HIP_R_16F /*b_type*/, + output_lin_q_dim, + static_cast(&beta), + static_cast(input_weight_q_grads.data_ptr()), + HIP_R_16F /*c_type*/, + embed_dim, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT /*algo*/ + )); + + // Input Linear KV Dgrad + TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, + CUBLAS_OP_N, + CUBLAS_OP_N, + embed_dim, + batches_kv, + output_lin_kv_dim, + static_cast(&alpha), + static_cast(input_weights_kv.data_ptr()), + HIP_R_16F /*a_type*/, + embed_dim, + static_cast(k_lin_grads_ptr), + HIP_R_16F /*b_type*/, + output_lin_kv_dim, + static_cast(&beta), + static_cast(input_kv_grads.data_ptr()), + HIP_R_16F /*c_type*/, + embed_dim, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT /*algo*/ + )); + + // Input Linear KV Wgrad + TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, + CUBLAS_OP_N, + CUBLAS_OP_T, + embed_dim, + output_lin_kv_dim, + batches_kv, + static_cast(&alpha), + static_cast(inputs_kv.data_ptr()), + HIP_R_16F /*a_type*/, + embed_dim, + static_cast(k_lin_grads_ptr), + HIP_R_16F /*b_type*/, + output_lin_kv_dim, + static_cast(&beta), + static_cast(input_weight_kv_grads.data_ptr()), + HIP_R_16F /*c_type*/, + embed_dim, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT /*algo*/ + )); + // Fused Layer Norm Bwd with Residual Add - HostLayerNormGradient( - static_cast(input_lin_q_grads.data_ptr()), - static_cast(output_grads.data_ptr()), - static_cast(lyr_nrm_mean.data_ptr()), - static_cast(lyr_nrm_invvar.data_ptr()), inputs_q, - static_cast(batches_q), // n1 - static_cast(embed_dim), // n2 - static_cast(lyr_nrm_gamma_weights.data_ptr()), - static_cast(lyr_nrm_beta_weights.data_ptr()), 1.0e-5, - static_cast(input_q_grads.data_ptr()), - static_cast(lyr_nrm_gamma_grads.data_ptr()), - static_cast(lyr_nrm_beta_grads.data_ptr())); - - TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); + HostLayerNormGradient( + static_cast(input_lin_q_grads.data_ptr()), + static_cast(output_grads.data_ptr()), + static_cast(lyr_nrm_mean.data_ptr()), + static_cast(lyr_nrm_invvar.data_ptr()), + inputs_q, + static_cast(batches_q), // n1 + static_cast(embed_dim), // n2 + static_cast(lyr_nrm_gamma_weights.data_ptr()), + static_cast(lyr_nrm_beta_weights.data_ptr()), + 1.0e-5, + static_cast(input_q_grads.data_ptr()), + static_cast(lyr_nrm_gamma_grads.data_ptr()), + static_cast(lyr_nrm_beta_grads.data_ptr()) + ); + + //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); return {input_q_grads, input_kv_grads, lyr_nrm_gamma_grads, lyr_nrm_beta_grads, input_weight_q_grads, input_weight_kv_grads, output_weight_grads}; } -} // end namespace cublas_gemmex -} // end namespace encdec_norm_add +} // end namespace rocblas_gemmex +} // end namespace encdec_norm_add } // end namespace multihead_attn diff --git a/apex/contrib/csrc/multihead_attn/layer_norm.cuh b/apex/contrib/csrc/multihead_attn/layer_norm.cuh index 16c1eeef4..12ea20420 100644 --- a/apex/contrib/csrc/multihead_attn/layer_norm.cuh +++ b/apex/contrib/csrc/multihead_attn/layer_norm.cuh @@ -66,12 +66,12 @@ __device__ void cuWelfordMuSigma2(const T *__restrict__ vals, const int n1, cuWelfordOnlineSum(curr, mu, sigma2, count); } // intra-warp reductions - for (int l = 0; l <= 4; ++l) { + for (int l = 0; l <= 4; ++l) { int srcLaneB = (threadIdx.x + (1 << l)) & 31; - U muB = WARP_SHFL(mu, srcLaneB); - U countB = WARP_SHFL(count, srcLaneB); - U sigma2B = WARP_SHFL(sigma2, srcLaneB); - cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); + U muB = WARP_SHFL(mu, srcLaneB, 32); + U countB = WARP_SHFL(count, srcLaneB, 32); + U sigma2B = WARP_SHFL(sigma2, srcLaneB, 32); + cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count); } // threadIdx.x == 0 has correct values for each warp // inter-warp reductions @@ -107,8 +107,8 @@ __device__ void cuWelfordMuSigma2(const T *__restrict__ vals, const int n1, sigma2 = ubuf[1] / U(n2); // don't care about final value of count, we know count == n2 } else { - mu = WARP_SHFL(mu, 0); - sigma2 = WARP_SHFL(sigma2 / U(n2), 0); + mu = WARP_SHFL(mu, 0, 32); + sigma2 = WARP_SHFL(sigma2 / U(n2), 0, 32); } } } @@ -157,12 +157,12 @@ __device__ void cuWelfordMuSigma2(const at::Half *__restrict__ vals, cuWelfordOnlineSum(curr, mu, sigma2, count); } // intra-warp reductions - for (int l = 0; l <= 4; ++l) { + for (int l = 0; l <= 4; ++l) { int srcLaneB = (threadIdx.x + (1 << l)) & 31; - float muB = WARP_SHFL(mu, srcLaneB); - float countB = WARP_SHFL(count, srcLaneB); - float sigma2B = WARP_SHFL(sigma2, srcLaneB); - cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); + float muB = WARP_SHFL(mu, srcLaneB, 32); + float countB = WARP_SHFL(count, srcLaneB, 32); + float sigma2B = WARP_SHFL(sigma2, srcLaneB, 32); + cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count); } // threadIdx.x == 0 has correct values for each warp // inter-warp reductions @@ -198,15 +198,28 @@ __device__ void cuWelfordMuSigma2(const at::Half *__restrict__ vals, sigma2 = ubuf[1] / float(n2); // don't care about final value of count, we know count == n2 } else { - mu = WARP_SHFL(mu, 0); - sigma2 = WARP_SHFL(sigma2 / float(n2), 0); + mu = WARP_SHFL(mu, 0, 32); + sigma2 = WARP_SHFL(sigma2 / float(n2), 0, 32); } } } -template __device__ U rsqrt(U v) { return U(1) / sqrt(v); } -template <> __device__ float rsqrt(float v) { return rsqrtf(v); } -template <> __device__ double rsqrt(double v) { return rsqrt(v); } +template U rsqrt(U v) { + return U(1) / sqrt(v); +} +//template<> float rsqrt(float v) { +// return rsqrtf(v); +//} + +#if defined USE_ROCM +__device__ float rsqrt(float v) { return rsqrtf(v); } +#else +template<> float rsqrt(float v) { return rsqrtf(v); } +#endif +template<> double rsqrt(double v) { return rsqrt(v); } +// template __device__ U rsqrt(U v) { return U(1) / sqrt(v); } +// template <> __device__ float rsqrt(float v) { return rsqrtf(v); } +// template <> __device__ double rsqrt(double v) { return rsqrt(v); } // This is the un-specialized struct. Note that we prevent instantiation of // this struct by putting an undefined symbol in the function body so it won't @@ -248,7 +261,7 @@ cuApplyLayerNorm(T *__restrict__ output_vals, U *__restrict__ mean, // 1) blockDim.x == warpSize // 2) Tensors are contiguous // - for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) { + for (int i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) { SharedMemory shared; U *buf = shared.getPointer(); U mu, sigma2; @@ -462,7 +475,7 @@ cuComputeGradInput(const T *__restrict__ dout, const T *__restrict__ dout_resid, const T *__restrict__ input, const int n1, const int n2, const U *__restrict__ mean, const U *__restrict__ invvar, U epsilon, const T *gamma, T *grad_input) { - for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) { + for (int i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) { U sum_loss1 = U(0); U sum_loss2 = U(0); const U c_mean = mean[i1]; @@ -508,9 +521,9 @@ cuComputeGradInput(const T *__restrict__ dout, const T *__restrict__ dout_resid, } } // intra-warp reductions - for (int mask = blockDim.x / 2; mask > 0; mask /= 2) { - sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask); - sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask); + for (int mask = blockDim.x / 2; mask > 0; mask /= 2) { + sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask, 32); + sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask, 32); } // inter-warp reductions if (blockDim.y > 1) { diff --git a/apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu b/apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu index f9a031d53..2adb6e93b 100644 --- a/apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu +++ b/apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu @@ -4,7 +4,7 @@ #include #include -#include +//#include #include #include diff --git a/apex/contrib/csrc/multihead_attn/multihead_attn_frontend.cpp b/apex/contrib/csrc/multihead_attn/multihead_attn_frontend.cpp index 5a116cd16..809620e0d 100644 --- a/apex/contrib/csrc/multihead_attn/multihead_attn_frontend.cpp +++ b/apex/contrib/csrc/multihead_attn/multihead_attn_frontend.cpp @@ -119,7 +119,7 @@ torch::Tensor bwd(bool use_mask, int heads, torch::Tensor const &output_grads, } // end namespace fused_softmax namespace encdec { -namespace cublas_gemmex { +namespace rocblas_gemmex { std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const &inputs_q, @@ -228,11 +228,11 @@ bwd(int heads, torch::Tensor const &output_grads, output_weights, dropout_mask, dropout_prob); } -} // end namespace cublas_gemmex +} // end namespace rocblas_gemmex } // end namespace encdec namespace encdec_norm_add { -namespace cublas_gemmex { +namespace rocblas_gemmex { std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const &inputs_q, @@ -381,11 +381,12 @@ bwd(int heads, torch::Tensor const &output_grads, dropout_mask, dropout_add_mask, dropout_prob); } -} // end namespace cublas_gemmex + +} // end namespace rocblas_gemmex } // end namespace encdec_norm_add namespace self { -namespace cublas_gemmex { +namespace rocblas_gemmex { std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const &inputs, @@ -471,10 +472,10 @@ bwd(int heads, torch::Tensor const &output_grads, output_weights, dropout_mask, dropout_prob); } -} // end namespace cublas_gemmex +} // end namespace rocblas_gemmex } // end namespace self namespace self_bias { -namespace cublas_gemmex { +namespace rocblas_gemmex { std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, @@ -492,7 +493,7 @@ std::vector bwd_cuda( // torch::Tensor const& input_biases, // torch::Tensor const& output_biases, torch::Tensor const &dropout_mask, float dropout_prob); - + std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, torch::Tensor const &inputs, torch::Tensor const &input_weights, @@ -564,10 +565,10 @@ bwd(int heads, torch::Tensor const &output_grads, output_weights, dropout_mask, dropout_prob); } -} // end namespace cublas_gemmex +} // end namespace rocblas_gemmex } // namespace self_bias namespace self_bias_additive_mask { -namespace cublas_gemmex { +namespace rocblas_gemmex { std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const &inputs, @@ -587,7 +588,7 @@ std::vector bwd_cuda( // torch::Tensor const& input_biases, // torch::Tensor const& output_biases, torch::Tensor const &dropout_mask, float dropout_prob); - + std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, torch::Tensor const &inputs, torch::Tensor const &input_weights, @@ -657,11 +658,11 @@ bwd(int heads, torch::Tensor const &output_grads, input_weights, output_weights, dropout_mask, dropout_prob); } -} // end namespace cublas_gemmex +} // end namespace rocblas_gemmex } // namespace self_bias_additive_mask namespace self_norm_add { -namespace cublas_gemmex { +namespace rocblas_gemmex { std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const &inputs, @@ -787,7 +788,7 @@ bwd(int heads, torch::Tensor const &output_grads, dropout_mask, dropout_add_mask, dropout_prob); } -} // end namespace cublas_gemmex +} // end namespace rocblas_gemmex } // end namespace self_norm_add } // end namespace multihead_attn @@ -802,31 +803,31 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Self Multihead Attention masked softmax dropout -- Forward."); m.def("mask_softmax_dropout_backward", &multihead_attn::fused_softmax::mask_softmax_dropout::bwd, "Self Multihead Attention masked softmax dropout -- Backward."); - m.def("encdec_multihead_attn_forward", &multihead_attn::encdec::cublas_gemmex::fwd, + m.def("encdec_multihead_attn_forward", &multihead_attn::encdec::rocblas_gemmex::fwd, "Encdec Multihead Attention Forward."); - m.def("encdec_multihead_attn_backward", &multihead_attn::encdec::cublas_gemmex::bwd, + m.def("encdec_multihead_attn_backward", &multihead_attn::encdec::rocblas_gemmex::bwd, "Encdec Multihead Attention Backward."); - m.def("encdec_multihead_attn_norm_add_forward", &multihead_attn::encdec_norm_add::cublas_gemmex::fwd, + m.def("encdec_multihead_attn_norm_add_forward", &multihead_attn::encdec_norm_add::rocblas_gemmex::fwd, "Encdec Multihead Attention Plus Layer Norm and Residual Add Forward."); m.def( - "encdec_multihead_attn_norm_add_backward", &multihead_attn::encdec_norm_add::cublas_gemmex::bwd, + "encdec_multihead_attn_norm_add_backward", &multihead_attn::encdec_norm_add::rocblas_gemmex::bwd, "Encdec Multihead Attention Plus Layer Norm and Residual Add Backward."); - m.def("self_attn_forward", &multihead_attn::self::cublas_gemmex::fwd, + m.def("self_attn_forward", &multihead_attn::self::rocblas_gemmex::fwd, "Self Multihead Attention Forward."); - m.def("self_attn_backward", &multihead_attn::self::cublas_gemmex::bwd, + m.def("self_attn_backward", &multihead_attn::self::rocblas_gemmex::bwd, "Self Multihead Attention Backward."); - m.def("self_attn_bias_forward", &multihead_attn::self_bias::cublas_gemmex::fwd, + m.def("self_attn_bias_forward", &multihead_attn::self_bias::rocblas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward."); - m.def("self_attn_bias_backward", &multihead_attn::self_bias::cublas_gemmex::bwd, + m.def("self_attn_bias_backward", &multihead_attn::self_bias::rocblas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward."); - m.def("self_attn_bias_additive_mask_forward", &multihead_attn::self_bias_additive_mask::cublas_gemmex::fwd, + m.def("self_attn_bias_additive_mask_forward", &multihead_attn::self_bias_additive_mask::rocblas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward."); m.def("self_attn_bias_additive_mask_backward", - &multihead_attn::self_bias_additive_mask::cublas_gemmex::bwd, + &multihead_attn::self_bias_additive_mask::rocblas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward."); - m.def("self_attn_norm_add_forward", &multihead_attn::self_norm_add::cublas_gemmex::fwd, + m.def("self_attn_norm_add_forward", &multihead_attn::self_norm_add::rocblas_gemmex::fwd, "Self Multihead Attention Plus Layer Norm and Residual Add Forward."); - m.def("self_attn_norm_add_backward", &multihead_attn::self_norm_add::cublas_gemmex::bwd, + m.def("self_attn_norm_add_backward", &multihead_attn::self_norm_add::rocblas_gemmex::bwd, "Self Multihead Attention Plus Layer Norm and Residual Add Backward."); } diff --git a/apex/contrib/csrc/multihead_attn/philox.cuh b/apex/contrib/csrc/multihead_attn/philox.cuh index d2076ab5a..7660be679 100644 --- a/apex/contrib/csrc/multihead_attn/philox.cuh +++ b/apex/contrib/csrc/multihead_attn/philox.cuh @@ -7,19 +7,14 @@ class Philox { public: __device__ inline Philox(unsigned long long seed, unsigned long long subsequence, - unsigned long long offset) : STATE(0) { - //key.x = (unsigned int)seed; - //key.y = (unsigned int)(seed >> 32); - //counter = make_uint4(0, 0, 0, 0); - //counter.z = (unsigned int)(subsequence); - //counter.w = (unsigned int)(subsequence >> 32); - //STATE = 0; - //incr_n(offset / 4); - - key = reinterpret_cast(seed); - ull2 * tmp = reinterpret_cast(&counter); - tmp->x = offset / 4; - tmp->y = subsequence; + unsigned long long offset) { + key.x = (unsigned int)seed; + key.y = (unsigned int)(seed >> 32); + counter = make_uint4(0, 0, 0, 0); + counter.z = (unsigned int)(subsequence); + counter.w = (unsigned int)(subsequence >> 32); + STATE = 0; + incr_n(offset / 4); } __device__ inline uint4 operator()() { if (STATE == 0) { @@ -47,10 +42,6 @@ public: } private: - struct ull2 { - uint64_t x; - uint64_t y; - }; uint4 counter; uint4 output; uint2 key; @@ -68,47 +59,26 @@ private: return; ++counter.w; } - - __device__ uint4 incr128 (uint4 ctr) - { - uint4 res; - asm ("add.cc.u32 %0, %4, %8;\n\t" - "addc.cc.u32 %1, %5, %9;\n\t" - "addc.cc.u32 %2, %6, %10;\n\t" - "addc.u32 %3, %7, %11;\n\t" - : "=r"(res.x), "=r"(res.y), "=r"(res.z), "=r"(res.w) - : "r"(ctr.x), "r"(ctr.y), "r"(ctr.z), "r"(ctr.w), - "n"(1), "n"(0), "n"(0), "n"(0)); - return res; - } - __device__ inline void incr() { - counter = incr128(counter); + if (++counter.x) + return; + if (++counter.y) + return; + if (++counter.z) + return; + ++counter.w; } __device__ unsigned int mulhilo32(unsigned int a, unsigned int b, unsigned int *result_high) { *result_high = __umulhi(a, b); return a * b; } - __device__ uint2 mulhilo32_v2 (unsigned int a, unsigned int b) - { - uint2 *res; - unsigned long long tmp; - asm ("mul.wide.u32 %0, %1, %2;\n\t" - : "=l"(tmp) - : "r"(a), "r"(b)); - res = (uint2*)(&tmp); - return *res; - } __device__ inline uint4 single_round(uint4 ctr, uint2 key) { - //unsigned int hi0; - //unsigned int hi1; - //unsigned int lo0 = mulhilo32(kPhiloxSA, ctr.x, &hi0); - //unsigned int lo1 = mulhilo32(kPhiloxSB, ctr.z, &hi1); - //uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0}; - uint2 res0 = mulhilo32_v2(kPhiloxSA, ctr.x); - uint2 res1 = mulhilo32_v2(kPhiloxSB, ctr.z); - uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x}; + unsigned int hi0; + unsigned int hi1; + unsigned int lo0 = mulhilo32(kPhiloxSA, ctr.x, &hi0); + unsigned int lo1 = mulhilo32(kPhiloxSB, ctr.z, &hi1); + uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0}; return ret; } static const unsigned long kPhilox10A = 0x9E3779B9; diff --git a/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu b/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu index fad054ac9..f1128da54 100644 --- a/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu +++ b/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu @@ -4,7 +4,7 @@ #include #include -#include +//#include #include #include @@ -14,35 +14,36 @@ #include "dropout.cuh" #include "softmax.cuh" #include "strided_batched_gemm.cuh" +#include "type_shim.h" namespace multihead_attn { namespace self_bias_additive_mask { -namespace cublas_gemmex { +namespace rocblas_gemmex { std::vector fwd_cuda(bool use_time_mask, bool is_training, - int heads, torch::Tensor const &inputs, - torch::Tensor const &input_weights, - torch::Tensor const &output_weights, - torch::Tensor const &input_biases, - torch::Tensor const &output_biases, - const half *pad_mask, float dropout_prob) { - const int embed_dim = inputs.size(2); - const int sequences = inputs.size(1); - const int q_seq_len = inputs.size(0); - const int k_seq_len = q_seq_len; - const int batches = sequences * q_seq_len; - const int head_dim = embed_dim / heads; - const int output_lin_dim = 3 * embed_dim; - const int attn_batches = heads * sequences; - const int lead_dim = attn_batches * 3 * head_dim; - const int batch_stride = 3 * head_dim; - const int dropout_elems = attn_batches * q_seq_len * k_seq_len; - const float alpha = 1.0; - const float beta_zero = 0.0; - const float beta_one = 1.0; - const float scale = 1.0 / sqrt(static_cast(head_dim)); - - // There is no reason to use more than one stream as every kernel is + int heads, torch::Tensor const& inputs, + torch::Tensor const& input_weights, + torch::Tensor const& output_weights, + torch::Tensor const& input_biases, + torch::Tensor const& output_biases, + const half* pad_mask, float dropout_prob) { + const int embed_dim = inputs.size(2); + const int sequences = inputs.size(1); + const int q_seq_len = inputs.size(0); + const int k_seq_len = q_seq_len; + const int batches = sequences * q_seq_len; + const int head_dim = embed_dim / heads; + const int output_lin_dim = 3 * embed_dim; + const int attn_batches = heads * sequences; + const int lead_dim = attn_batches * 3 * head_dim; + const int batch_stride = 3 * head_dim; + const int dropout_elems = attn_batches * q_seq_len * k_seq_len; + const float alpha = 1.0; + const float beta_zero = 0.0; + const float beta_one = 1.0; + const float scale = 1.0 / sqrt(static_cast(head_dim)); + + // There is no reason to use more than one stream as every kernel is // sequentially dependent cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); @@ -80,24 +81,49 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, char a_layout_n{'n'}; char b_layout_n{'n'}; - TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); // Input Linear Fwd input_lin_results.copy_(input_biases); - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_T, CUBLAS_OP_N, output_lin_dim, batches, embed_dim, - static_cast(&alpha), - static_cast(input_weights.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(inputs.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(&beta_one), q_lin_results_ptr, - CUDA_R_16F, output_lin_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - + TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, + CUBLAS_OP_T, + CUBLAS_OP_N, + output_lin_dim, + batches, + embed_dim, + static_cast(&alpha), + static_cast(input_weights.data_ptr()), + HIP_R_16F, + embed_dim, + static_cast(inputs.data_ptr()), + HIP_R_16F, + embed_dim, + static_cast(&beta_one), + q_lin_results_ptr, + HIP_R_16F, + output_lin_dim, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT /*algo*/ + )); + // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) - gemm_switch_fp32accum( - a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, scale, - static_cast(k_lin_results_ptr), lead_dim, batch_stride, - static_cast(q_lin_results_ptr), lead_dim, batch_stride, - beta_zero, static_cast(bmm1_results_ptr), k_seq_len, - k_seq_len * q_seq_len, attn_batches); + gemm_switch_fp32accum( a_layout_t, + b_layout_n, + k_seq_len, + q_seq_len, + head_dim, + scale, + static_cast(k_lin_results_ptr), + lead_dim, + batch_stride, + static_cast(q_lin_results_ptr), + lead_dim, + batch_stride, + beta_zero, + static_cast(bmm1_results_ptr), + k_seq_len, + k_seq_len*q_seq_len, + attn_batches + ); + // Padded Softmax bool softmax_success = false; if (is_training) { @@ -122,29 +148,49 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, } // Matmul2 - gemm_switch_fp32accum( - a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, alpha, - static_cast(v_lin_results_ptr), lead_dim, batch_stride, - static_cast(dropout_results.data_ptr()), k_seq_len, - k_seq_len * q_seq_len, beta_zero, - static_cast(matmul2_results.data_ptr()), head_dim * attn_batches, - head_dim, attn_batches); + gemm_switch_fp32accum( a_layout_n, + b_layout_n, + head_dim, + q_seq_len, + k_seq_len, + alpha, + static_cast(v_lin_results_ptr), + lead_dim, + batch_stride, + static_cast(dropout_results.data_ptr()), + k_seq_len, + k_seq_len*q_seq_len, + beta_zero, + static_cast(matmul2_results.data_ptr()), + head_dim*attn_batches, + head_dim, + attn_batches + ); outputs.copy_(output_biases); // Output Linear - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_T, CUBLAS_OP_N, embed_dim, batches, embed_dim, - static_cast(&alpha), - static_cast(output_weights.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(matmul2_results.data_ptr()), - CUDA_R_16F, embed_dim, static_cast(&beta_one), - static_cast(outputs.data_ptr()), CUDA_R_16F, embed_dim, - CUDA_R_32F, - // CUBLAS_GEMM_ALGO1_TENSOR_OP)); - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - - TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); + TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, + CUBLAS_OP_T, + CUBLAS_OP_N, + embed_dim, + batches, + embed_dim, + static_cast(&alpha), + static_cast(output_weights.data_ptr()), + HIP_R_16F, + embed_dim, + static_cast(matmul2_results.data_ptr()), + HIP_R_16F, + embed_dim, + static_cast(&beta_one), + static_cast(outputs.data_ptr()), + HIP_R_16F, + embed_dim, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT /*algo*/ + )); + //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); return {input_lin_results, bmm1_results, dropout_results, dropout_mask, matmul2_results, outputs}; @@ -204,103 +250,197 @@ std::vector bwd_cuda( char b_layout_n{'n'}; char b_layout_t{'t'}; - TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); - // Output Linear Dgrad - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, embed_dim, - static_cast(&alpha), - static_cast(output_weights.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(output_grads.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(&beta), - static_cast(output_lin_grads.data_ptr()), CUDA_R_16F, embed_dim, - CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - // Output Linear Wgrad - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, embed_dim, batches, - static_cast(&alpha), - static_cast(matmul2_results.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(output_grads.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(&beta), - static_cast(output_weight_grads.data_ptr()), CUDA_R_16F, - embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, + CUBLAS_OP_N, + CUBLAS_OP_N, + embed_dim, + batches, + embed_dim, + static_cast(&alpha), + static_cast(output_weights.data_ptr()), + HIP_R_16F, + embed_dim, + static_cast(output_grads.data_ptr()), + HIP_R_16F, + embed_dim, + static_cast(&beta), + static_cast(output_lin_grads.data_ptr()), + HIP_R_16F, + embed_dim, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT /*algo*/ + )); - auto output_bias_grads = output_grads.view({-1, embed_dim}).sum(0, false); + // Output Linear Wgrad + TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, + CUBLAS_OP_N, + CUBLAS_OP_T, + embed_dim, + embed_dim, + batches, + static_cast(&alpha), + static_cast(matmul2_results.data_ptr()), + HIP_R_16F, + embed_dim, + static_cast(output_grads.data_ptr()), + HIP_R_16F, + embed_dim, + static_cast(&beta), + static_cast(output_weight_grads.data_ptr()), + HIP_R_16F, + embed_dim, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT /*algo*/ + )); + + auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false); // MatMul2 Dgrad1 - gemm_switch_fp32accum( - a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, alpha, - static_cast(v_lin_results_ptr), lead_dim, batch_stride, - static_cast(output_lin_grads.data_ptr()), - head_dim * attn_batches, head_dim, beta, - static_cast(matmul2_grads.data_ptr()), k_seq_len, - k_seq_len * q_seq_len, attn_batches); - + gemm_switch_fp32accum( a_layout_t, + b_layout_n, + k_seq_len, + q_seq_len, + head_dim, + alpha, + static_cast(v_lin_results_ptr), + lead_dim, + batch_stride, + static_cast(output_lin_grads.data_ptr()), + head_dim*attn_batches, + head_dim, + beta, + static_cast(matmul2_grads.data_ptr()), + k_seq_len, + k_seq_len*q_seq_len, + attn_batches + ); + // Matmul2 Dgrad2 - gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, - q_seq_len, alpha, - static_cast(output_lin_grads.data_ptr()), - head_dim * attn_batches, head_dim, - static_cast(dropout_results.data_ptr()), - k_seq_len, k_seq_len * q_seq_len, beta, v_lin_grads_ptr, - lead_dim, batch_stride, attn_batches); - - // Apply Dropout Mask and Scale by Dropout Probability + gemm_switch_fp32accum( a_layout_n, + b_layout_t, + head_dim, + k_seq_len, + q_seq_len, + alpha, + static_cast(output_lin_grads.data_ptr()), + head_dim*attn_batches, + head_dim, + static_cast(dropout_results.data_ptr()), + k_seq_len, + k_seq_len*q_seq_len, + beta, + v_lin_grads_ptr, + lead_dim, + batch_stride, + attn_batches + ); + + // Apply Dropout Mask and Scale by Dropout Probability // Softmax Grad dispatch_masked_scale_softmax_backward_recompute( - static_cast(matmul2_grads.data_ptr()), - static_cast(matmul2_grads.data_ptr()), - reinterpret_cast(bmm1_results.data_ptr()), - reinterpret_cast(pad_mask.data_ptr()), - static_cast(dropout_mask.data_ptr()), - 1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len, - attn_batches * q_seq_len / sequences, attn_batches * q_seq_len, stream); - + static_cast(matmul2_grads.data_ptr()), + static_cast(matmul2_grads.data_ptr()), + reinterpret_cast(bmm1_results.data_ptr()), + reinterpret_cast(pad_mask.data_ptr()), + static_cast(dropout_mask.data_ptr()), + 1.0/(1.0-dropout_prob), + k_seq_len, + k_seq_len, + attn_batches*q_seq_len/sequences, + attn_batches*q_seq_len, + stream); + // Matmul1 Dgrad1 - gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, - k_seq_len, scale, k_lin_results_ptr, lead_dim, - batch_stride, - static_cast(matmul2_grads.data_ptr()), - k_seq_len, k_seq_len * q_seq_len, beta, q_lin_grads_ptr, - lead_dim, batch_stride, attn_batches); - + gemm_switch_fp32accum( a_layout_n, + b_layout_n, + head_dim, + q_seq_len, + k_seq_len, + scale, + k_lin_results_ptr, + lead_dim, + batch_stride, + static_cast(matmul2_grads.data_ptr()), + k_seq_len, + k_seq_len*q_seq_len, + beta, + q_lin_grads_ptr, + lead_dim, + batch_stride, + attn_batches + ); + // Matmul1 Dgrad2 - gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, - q_seq_len, scale, q_lin_results_ptr, lead_dim, - batch_stride, - static_cast(matmul2_grads.data_ptr()), - k_seq_len, k_seq_len * q_seq_len, beta, k_lin_grads_ptr, - lead_dim, batch_stride, attn_batches); - // Input Linear Dgrad - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, output_lin_dim, - static_cast(&alpha), - static_cast(input_weights.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(input_lin_output_grads.data_ptr()), - // static_cast(q_lin_grads_ptr), - CUDA_R_16F, output_lin_dim, static_cast(&beta), - static_cast(input_grads.data_ptr()), CUDA_R_16F, embed_dim, - CUDA_R_32F, - // CUBLAS_GEMM_ALGO10_TENSOR_OP)); - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - - // Input Linear Wgrad - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, output_lin_dim, batches, - static_cast(&alpha), - static_cast(inputs.data_ptr()), CUDA_R_16F, embed_dim, - static_cast(q_lin_grads_ptr), CUDA_R_16F, output_lin_dim, - static_cast(&beta), - static_cast(input_weight_grads.data_ptr()), CUDA_R_16F, embed_dim, - CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - - auto input_bias_grads = - input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false); - TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); + gemm_switch_fp32accum( a_layout_n, + b_layout_t, + head_dim, + k_seq_len, + q_seq_len, + scale, + q_lin_results_ptr, + lead_dim, + batch_stride, + static_cast(matmul2_grads.data_ptr()), + k_seq_len, + k_seq_len*q_seq_len, + beta, + k_lin_grads_ptr, + lead_dim, + batch_stride, + attn_batches + ); + + // Input Linear Dgrad + TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, + CUBLAS_OP_N, + CUBLAS_OP_N, + embed_dim, + batches, + output_lin_dim, + static_cast(&alpha), + static_cast(input_weights.data_ptr()), + HIP_R_16F, + embed_dim, + static_cast(input_lin_output_grads.data_ptr()), + HIP_R_16F, + output_lin_dim, + static_cast(&beta), + static_cast(input_grads.data_ptr()), + HIP_R_16F, + embed_dim, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT /*algo*/ + )); + + // Input Linear Wgrad + TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, + CUBLAS_OP_N, + CUBLAS_OP_T, + embed_dim, + output_lin_dim, + batches, + static_cast(&alpha), + static_cast(inputs.data_ptr()), + HIP_R_16F, + embed_dim, + static_cast(q_lin_grads_ptr), + HIP_R_16F, + output_lin_dim, + static_cast(&beta), + static_cast(input_weight_grads.data_ptr()), + HIP_R_16F, + embed_dim, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT /*algo*/ + )); + + auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false); + //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); return {input_grads, input_weight_grads, output_weight_grads, input_bias_grads, output_bias_grads}; } -} // end namespace cublas_gemmex -} // namespace self_bias_additive_mask +} // end namespace rocblas_gemmex +} // end namespace self_bias_additive_mask } // end namespace multihead_attn diff --git a/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu b/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu index 227b4fed7..3b23ebb75 100644 --- a/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu +++ b/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu @@ -4,7 +4,7 @@ #include #include -#include +//#include #include #include @@ -14,10 +14,11 @@ #include "dropout.cuh" #include "softmax.cuh" #include "strided_batched_gemm.cuh" +#include "type_shim.h" namespace multihead_attn { namespace self_bias { -namespace cublas_gemmex { +namespace rocblas_gemmex { std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, @@ -78,24 +79,51 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads, char a_layout_n{'n'}; char b_layout_n{'n'}; - TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); + //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); + // Input Linear Fwd input_lin_results.copy_(input_biases); - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_T, CUBLAS_OP_N, output_lin_dim, batches, embed_dim, - static_cast(&alpha), - static_cast(input_weights.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(inputs.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(&beta_one), q_lin_results_ptr, - CUDA_R_16F, output_lin_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, + CUBLAS_OP_T, + CUBLAS_OP_N, + output_lin_dim, + batches, + embed_dim, + static_cast(&alpha), + static_cast(input_weights.data_ptr()), + HIP_R_16F, + embed_dim, + static_cast(inputs.data_ptr()), + HIP_R_16F, + embed_dim, + static_cast(&beta_one), + q_lin_results_ptr, + HIP_R_16F, + output_lin_dim, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT /*algo*/ + )); // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) - gemm_switch_fp32accum( - a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, scale, - static_cast(k_lin_results_ptr), lead_dim, batch_stride, - static_cast(q_lin_results_ptr), lead_dim, batch_stride, - beta_zero, static_cast(softmax_results_ptr), k_seq_len, - k_seq_len * q_seq_len, attn_batches); + gemm_switch_fp32accum( a_layout_t, + b_layout_n, + k_seq_len, + q_seq_len, + head_dim, + scale, + static_cast(k_lin_results_ptr), + lead_dim, + batch_stride, + static_cast(q_lin_results_ptr), + lead_dim, + batch_stride, + beta_zero, + static_cast(softmax_results_ptr), + k_seq_len, + k_seq_len*q_seq_len, + attn_batches + ); + // Padded Softmax bool softmax_success = false; if (pad_mask == nullptr) { @@ -128,30 +156,49 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads, } // Matmul2 - gemm_switch_fp32accum( - a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, alpha, - static_cast(v_lin_results_ptr), lead_dim, batch_stride, - (is_training) ? static_cast(dropout_results.data_ptr()) - : static_cast(softmax_results.data_ptr()), - k_seq_len, k_seq_len * q_seq_len, beta_zero, - static_cast(matmul2_results.data_ptr()), head_dim * attn_batches, - head_dim, attn_batches); + gemm_switch_fp32accum( a_layout_n, + b_layout_n, + head_dim, + q_seq_len, + k_seq_len, + alpha, + static_cast(v_lin_results_ptr), + lead_dim, + batch_stride, + (is_training) ? static_cast(dropout_results.data_ptr()) : static_cast(softmax_results.data_ptr()) , + k_seq_len, + k_seq_len*q_seq_len, + beta_zero, + static_cast(matmul2_results.data_ptr()), + head_dim*attn_batches, + head_dim, + attn_batches + ); outputs.copy_(output_biases); // Output Linear - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_T, CUBLAS_OP_N, embed_dim, batches, embed_dim, - static_cast(&alpha), - static_cast(output_weights.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(matmul2_results.data_ptr()), - CUDA_R_16F, embed_dim, static_cast(&beta_one), - static_cast(outputs.data_ptr()), CUDA_R_16F, embed_dim, - CUDA_R_32F, - // CUBLAS_GEMM_ALGO1_TENSOR_OP)); - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - - TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); + TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, + CUBLAS_OP_T, + CUBLAS_OP_N, + embed_dim, + batches, + embed_dim, + static_cast(&alpha), + static_cast(output_weights.data_ptr()), + HIP_R_16F, + embed_dim, + static_cast(matmul2_results.data_ptr()), + HIP_R_16F, + embed_dim, + static_cast(&beta_one), + static_cast(outputs.data_ptr()), + HIP_R_16F, + embed_dim, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT /*algo*/ + )); + //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); return {input_lin_results, softmax_results, dropout_results, dropout_mask, matmul2_results, outputs}; @@ -211,47 +258,105 @@ std::vector bwd_cuda( char b_layout_n{'n'}; char b_layout_t{'t'}; - TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); + rocblas_int flags = 0; + + //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); + #ifdef USE_ROCM + #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR) + #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242) + #if USE_GEMM_FLAGS_FP16_ALT_IMPL + #ifdef BACKWARD_PASS_GUARD + flags = at::BACKWARD_PASS_GUARD_CLASS::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0; + #endif + #endif + #endif // Output Linear Dgrad - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, embed_dim, - static_cast(&alpha), - static_cast(output_weights.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(output_grads.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(&beta), - static_cast(output_lin_grads.data_ptr()), CUDA_R_16F, embed_dim, - CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, + CUBLAS_OP_N, + CUBLAS_OP_N, + embed_dim, + batches, + embed_dim, + static_cast(&alpha), + static_cast(output_weights.data_ptr()), + HIP_R_16F, + embed_dim, + static_cast(output_grads.data_ptr()), + HIP_R_16F, + embed_dim, + static_cast(&beta), + static_cast(output_lin_grads.data_ptr()), + HIP_R_16F, + embed_dim, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT /*algo*/ + )); + // Output Linear Wgrad - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, embed_dim, batches, - static_cast(&alpha), - static_cast(matmul2_results.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(output_grads.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(&beta), - static_cast(output_weight_grads.data_ptr()), CUDA_R_16F, - embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - - auto output_bias_grads = output_grads.view({-1, embed_dim}).sum(0, false); + TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, + CUBLAS_OP_N, + CUBLAS_OP_T, + embed_dim, + embed_dim, + batches, + static_cast(&alpha), + static_cast(matmul2_results.data_ptr()), + HIP_R_16F, + embed_dim, + static_cast(output_grads.data_ptr()), + HIP_R_16F, + embed_dim, + static_cast(&beta), + static_cast(output_weight_grads.data_ptr()), + HIP_R_16F, + embed_dim, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT /*algo*/ + )); + + auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false); // MatMul2 Dgrad1 - gemm_switch_fp32accum( - a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, alpha, - static_cast(v_lin_results_ptr), lead_dim, batch_stride, - static_cast(output_lin_grads.data_ptr()), - head_dim * attn_batches, head_dim, beta, - static_cast(matmul2_grads.data_ptr()), k_seq_len, - k_seq_len * q_seq_len, attn_batches); + gemm_switch_fp32accum( a_layout_t, + b_layout_n, + k_seq_len, + q_seq_len, + head_dim, + alpha, + static_cast(v_lin_results_ptr), + lead_dim, + batch_stride, + static_cast(output_lin_grads.data_ptr()), + head_dim*attn_batches, + head_dim, + beta, + static_cast(matmul2_grads.data_ptr()), + k_seq_len, + k_seq_len*q_seq_len, + attn_batches + ); // Matmul2 Dgrad2 - gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, - q_seq_len, alpha, - static_cast(output_lin_grads.data_ptr()), - head_dim * attn_batches, head_dim, - static_cast(dropout_results.data_ptr()), - k_seq_len, k_seq_len * q_seq_len, beta, v_lin_grads_ptr, - lead_dim, batch_stride, attn_batches); - - // Apply Dropout Mask and Scale by Dropout Probability + gemm_switch_fp32accum( a_layout_n, + b_layout_t, + head_dim, + k_seq_len, + q_seq_len, + alpha, + static_cast(output_lin_grads.data_ptr()), + head_dim*attn_batches, + head_dim, + static_cast(dropout_results.data_ptr()), + k_seq_len, + k_seq_len*q_seq_len, + beta, + v_lin_grads_ptr, + lead_dim, + batch_stride, + attn_batches + ); + + // Apply Dropout Mask and Scale by Dropout Probability // Softmax Grad dispatch_masked_scale_softmax_backward_stream( static_cast(matmul2_grads.data_ptr()), @@ -262,51 +367,95 @@ std::vector bwd_cuda( attn_batches * q_seq_len, stream); // Matmul1 Dgrad1 - gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, - k_seq_len, scale, k_lin_results_ptr, lead_dim, - batch_stride, - static_cast(matmul2_grads.data_ptr()), - k_seq_len, k_seq_len * q_seq_len, beta, q_lin_grads_ptr, - lead_dim, batch_stride, attn_batches); + gemm_switch_fp32accum( a_layout_n, + b_layout_n, + head_dim, + q_seq_len, + k_seq_len, + scale, + k_lin_results_ptr, + lead_dim, + batch_stride, + static_cast(matmul2_grads.data_ptr()), + k_seq_len, + k_seq_len*q_seq_len, + beta, + q_lin_grads_ptr, + lead_dim, + batch_stride, + attn_batches + ); // Matmul1 Dgrad2 - gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, - q_seq_len, scale, q_lin_results_ptr, lead_dim, - batch_stride, - static_cast(matmul2_grads.data_ptr()), - k_seq_len, k_seq_len * q_seq_len, beta, k_lin_grads_ptr, - lead_dim, batch_stride, attn_batches); - // Input Linear Dgrad - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, output_lin_dim, - static_cast(&alpha), - static_cast(input_weights.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(input_lin_output_grads.data_ptr()), - // static_cast(q_lin_grads_ptr), - CUDA_R_16F, output_lin_dim, static_cast(&beta), - static_cast(input_grads.data_ptr()), CUDA_R_16F, embed_dim, - CUDA_R_32F, - // CUBLAS_GEMM_ALGO10_TENSOR_OP)); - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - - // Input Linear Wgrad - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, output_lin_dim, batches, - static_cast(&alpha), - static_cast(inputs.data_ptr()), CUDA_R_16F, embed_dim, - static_cast(q_lin_grads_ptr), CUDA_R_16F, output_lin_dim, - static_cast(&beta), - static_cast(input_weight_grads.data_ptr()), CUDA_R_16F, embed_dim, - CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - - auto input_bias_grads = - input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false); - TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); + gemm_switch_fp32accum( a_layout_n, + b_layout_t, + head_dim, + k_seq_len, + q_seq_len, + scale, + q_lin_results_ptr, + lead_dim, + batch_stride, + static_cast(matmul2_grads.data_ptr()), + k_seq_len, + k_seq_len*q_seq_len, + beta, + k_lin_grads_ptr, + lead_dim, + batch_stride, + attn_batches + ); + // Input Linear Dgrad + TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, + CUBLAS_OP_N, + CUBLAS_OP_N, + embed_dim, + batches, + output_lin_dim, + static_cast(&alpha), + static_cast(input_weights.data_ptr()), + HIP_R_16F, + embed_dim, + static_cast(input_lin_output_grads.data_ptr()), + HIP_R_16F, + output_lin_dim, + static_cast(&beta), + static_cast(input_grads.data_ptr()), + HIP_R_16F, + embed_dim, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT /*algo*/ + )); + + // Input Linear Wgrad + TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, + CUBLAS_OP_N, + CUBLAS_OP_T, + embed_dim, + output_lin_dim, + batches, + static_cast(&alpha), + static_cast(inputs.data_ptr()), + HIP_R_16F, + embed_dim, + static_cast(q_lin_grads_ptr), + HIP_R_16F, + output_lin_dim, + static_cast(&beta), + static_cast(input_weight_grads.data_ptr()), + HIP_R_16F, + embed_dim, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT /*algo*/ + )); + + auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false); + //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); return {input_grads, input_weight_grads, output_weight_grads, input_bias_grads, output_bias_grads}; } -} // end namespace cublas_gemmex -} // namespace self_bias +} // end namespace rocblas_gemmex +} // end namespace self } // end namespace multihead_attn diff --git a/apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu b/apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu index 9701da7cc..35795cd85 100644 --- a/apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu +++ b/apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu @@ -4,7 +4,7 @@ #include #include -#include +//#include #include #include @@ -14,10 +14,11 @@ #include "dropout.cuh" #include "softmax.cuh" #include "strided_batched_gemm.cuh" +#include "type_shim.h" namespace multihead_attn { namespace self { -namespace cublas_gemmex { +namespace rocblas_gemmex { std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const &inputs, @@ -77,23 +78,47 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, char a_layout_n{'n'}; char b_layout_n{'n'}; - TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); // Input Linear Fwd - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_T, CUBLAS_OP_N, output_lin_dim, batches, embed_dim, - static_cast(&alpha), - static_cast(input_weights.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(inputs.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(&beta), q_lin_results_ptr, - CUDA_R_16F, output_lin_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, + CUBLAS_OP_T, + CUBLAS_OP_N, + output_lin_dim, + batches, + embed_dim, + static_cast(&alpha), + static_cast(input_weights.data_ptr()), + HIP_R_16F, + embed_dim, + static_cast(inputs.data_ptr()), + HIP_R_16F, + embed_dim, + static_cast(&beta), + q_lin_results_ptr, + HIP_R_16F, + output_lin_dim, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT /*algo*/ + )); // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) - gemm_switch_fp32accum( - a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, scale, - static_cast(k_lin_results_ptr), lead_dim, batch_stride, - static_cast(q_lin_results_ptr), lead_dim, batch_stride, - beta, static_cast(softmax_results_ptr), k_seq_len, - k_seq_len * q_seq_len, attn_batches); + gemm_switch_fp32accum( a_layout_t, + b_layout_n, + k_seq_len, + q_seq_len, + head_dim, + scale, + static_cast(k_lin_results_ptr), + lead_dim, + batch_stride, + static_cast(q_lin_results_ptr), + lead_dim, + batch_stride, + beta, + static_cast(softmax_results_ptr), + k_seq_len, + k_seq_len*q_seq_len, + attn_batches + ); // Padded Softmax bool softmax_success = false; @@ -127,26 +152,46 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, } // Matmul2 - gemm_switch_fp32accum( - a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, alpha, - static_cast(v_lin_results_ptr), lead_dim, batch_stride, - (is_training) ? static_cast(dropout_results.data_ptr()) - : static_cast(softmax_results.data_ptr()), - k_seq_len, k_seq_len * q_seq_len, beta, - static_cast(matmul2_results.data_ptr()), head_dim * attn_batches, - head_dim, attn_batches); + gemm_switch_fp32accum( a_layout_n, + b_layout_n, + head_dim, + q_seq_len, + k_seq_len, + alpha, + static_cast(v_lin_results_ptr), + lead_dim, + batch_stride, + (is_training) ? static_cast(dropout_results.data_ptr()) : static_cast(softmax_results.data_ptr()) , + k_seq_len, + k_seq_len*q_seq_len, + beta, + static_cast(matmul2_results.data_ptr()), + head_dim*attn_batches, + head_dim, + attn_batches + ); // Output Linear - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_T, CUBLAS_OP_N, embed_dim, batches, embed_dim, - static_cast(&alpha), - static_cast(output_weights.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(matmul2_results.data_ptr()), - CUDA_R_16F, embed_dim, static_cast(&beta), - static_cast(outputs.data_ptr()), CUDA_R_16F, embed_dim, - CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - - TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); + TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, + CUBLAS_OP_T, + CUBLAS_OP_N, + embed_dim, + batches, + embed_dim, + static_cast(&alpha), + static_cast(output_weights.data_ptr()), + HIP_R_16F, + embed_dim, + static_cast(matmul2_results.data_ptr()), + HIP_R_16F, + embed_dim, + static_cast(&beta), + static_cast(outputs.data_ptr()), + HIP_R_16F, + embed_dim, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT /*algo*/ + )); return {input_lin_results, softmax_results, dropout_results, dropout_mask, matmul2_results, outputs}; @@ -204,54 +249,99 @@ std::vector bwd_cuda( char a_layout_n{'n'}; char a_layout_t{'t'}; char b_layout_n{'n'}; - char b_layout_t{'t'}; - - TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); + char b_layout_t{'t'}; // Output Linear Dgrad - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, embed_dim, - static_cast(&alpha), - static_cast(output_weights.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(output_grads.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(&beta), - static_cast(output_lin_grads.data_ptr()), CUDA_R_16F, embed_dim, - CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - + TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, + CUBLAS_OP_N, + CUBLAS_OP_N, + embed_dim, + batches, + embed_dim, + static_cast(&alpha), + static_cast(output_weights.data_ptr()), + HIP_R_16F, + embed_dim, + static_cast(output_grads.data_ptr()), + HIP_R_16F, + embed_dim, + static_cast(&beta), + static_cast(output_lin_grads.data_ptr()), + HIP_R_16F, + embed_dim, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT /*algo*/ + )); + // Output Linear Wgrad - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, embed_dim, batches, - static_cast(&alpha), - static_cast(matmul2_results.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(output_grads.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(&beta), - static_cast(output_weight_grads.data_ptr()), CUDA_R_16F, - embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - + TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, + CUBLAS_OP_N, + CUBLAS_OP_T, + embed_dim, + embed_dim, + batches, + static_cast(&alpha), + static_cast(matmul2_results.data_ptr()), + HIP_R_16F, + embed_dim, + static_cast(output_grads.data_ptr()), + HIP_R_16F, + embed_dim, + static_cast(&beta), + static_cast(output_weight_grads.data_ptr()), + HIP_R_16F, + embed_dim, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT /*algo*/ + )); + // MatMul2 Dgrad1 - gemm_switch_fp32accum( - a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, alpha, - static_cast(v_lin_results_ptr), lead_dim, batch_stride, - static_cast(output_lin_grads.data_ptr()), - head_dim * attn_batches, head_dim, beta, - static_cast(matmul2_grads.data_ptr()), k_seq_len, - k_seq_len * q_seq_len, attn_batches); - + gemm_switch_fp32accum( a_layout_t, + b_layout_n, + k_seq_len, + q_seq_len, + head_dim, + alpha, + static_cast(v_lin_results_ptr), + lead_dim, + batch_stride, + static_cast(output_lin_grads.data_ptr()), + head_dim*attn_batches, + head_dim, + beta, + static_cast(matmul2_grads.data_ptr()), + k_seq_len, + k_seq_len*q_seq_len, + attn_batches + ); + // Matmul2 Dgrad2 - gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, - q_seq_len, alpha, - static_cast(output_lin_grads.data_ptr()), - head_dim * attn_batches, head_dim, - static_cast(dropout_results.data_ptr()), - k_seq_len, k_seq_len * q_seq_len, beta, v_lin_grads_ptr, - lead_dim, batch_stride, attn_batches); - - // Apply Dropout Mask and Scale by Dropout Probability - apex_masked_scale_cuda( - static_cast(matmul2_grads.data_ptr()), - static_cast(matmul2_grads.data_ptr()), - static_cast(dropout_mask.data_ptr()), dropout_elems, - (1.0 / (1.0 - dropout_prob))); + gemm_switch_fp32accum( a_layout_n, + b_layout_t, + head_dim, + k_seq_len, + q_seq_len, + alpha, + static_cast(output_lin_grads.data_ptr()), + head_dim*attn_batches, + head_dim, + static_cast(dropout_results.data_ptr()), + k_seq_len, + k_seq_len*q_seq_len, + beta, + v_lin_grads_ptr, + lead_dim, + batch_stride, + attn_batches + ); + + // Apply Dropout Mask and Scale by Dropout Probability + apex_masked_scale_cuda( + static_cast(matmul2_grads.data_ptr()), + static_cast(matmul2_grads.data_ptr()), + static_cast(dropout_mask.data_ptr()), + dropout_elems, + (1.0 / (1.0 - dropout_prob))); // Softmax Grad bool softmax_success = false; @@ -263,45 +353,97 @@ std::vector bwd_cuda( assert(softmax_success); // Matmul1 Dgrad1 - gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, - k_seq_len, scale, k_lin_results_ptr, lead_dim, - batch_stride, - static_cast(matmul2_grads.data_ptr()), - k_seq_len, k_seq_len * q_seq_len, beta, q_lin_grads_ptr, - lead_dim, batch_stride, attn_batches); - + gemm_switch_fp32accum( a_layout_n, + b_layout_n, + head_dim, + q_seq_len, + k_seq_len, + scale, + k_lin_results_ptr, + lead_dim, + batch_stride, + static_cast(matmul2_grads.data_ptr()), + k_seq_len, + k_seq_len*q_seq_len, + beta, + q_lin_grads_ptr, + lead_dim, + batch_stride, + attn_batches + ); + // Matmul1 Dgrad2 - gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, - q_seq_len, scale, q_lin_results_ptr, lead_dim, - batch_stride, - static_cast(matmul2_grads.data_ptr()), - k_seq_len, k_seq_len * q_seq_len, beta, k_lin_grads_ptr, - lead_dim, batch_stride, attn_batches); - - // Input Linear Dgrad - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, output_lin_dim, - static_cast(&alpha), - static_cast(input_weights.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(q_lin_grads_ptr), CUDA_R_16F, - output_lin_dim, static_cast(&beta), - static_cast(input_grads.data_ptr()), CUDA_R_16F, embed_dim, - CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - - // Input Linear Wgrad - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, output_lin_dim, batches, - static_cast(&alpha), - static_cast(inputs.data_ptr()), CUDA_R_16F, embed_dim, - static_cast(q_lin_grads_ptr), CUDA_R_16F, output_lin_dim, - static_cast(&beta), - static_cast(input_weight_grads.data_ptr()), CUDA_R_16F, embed_dim, - CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); - - return {input_grads, input_weight_grads, output_weight_grads}; + gemm_switch_fp32accum( a_layout_n, + b_layout_t, + head_dim, + k_seq_len, + q_seq_len, + scale, + q_lin_results_ptr, + lead_dim, + batch_stride, + static_cast(matmul2_grads.data_ptr()), + k_seq_len, + k_seq_len*q_seq_len, + beta, + k_lin_grads_ptr, + lead_dim, + batch_stride, + attn_batches + ); + + // Input Linear Dgrad + TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, + CUBLAS_OP_N, + CUBLAS_OP_N, + embed_dim, + batches, + output_lin_dim, + static_cast(&alpha), + static_cast(input_weights.data_ptr()), + HIP_R_16F, + embed_dim, + static_cast(q_lin_grads_ptr), + HIP_R_16F, + output_lin_dim, + static_cast(&beta), + static_cast(input_grads.data_ptr()), + HIP_R_16F, + embed_dim, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT /*algo*/ + )); + + // Input Linear Wgrad + TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, + CUBLAS_OP_N, + CUBLAS_OP_T, + embed_dim, + output_lin_dim, + batches, + static_cast(&alpha), + static_cast(inputs.data_ptr()), + HIP_R_16F, + embed_dim, + static_cast(q_lin_grads_ptr), + HIP_R_16F, + output_lin_dim, + static_cast(&beta), + static_cast(input_weight_grads.data_ptr()), + HIP_R_16F, + embed_dim, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT /*algo*/ + )); + + return { + input_grads, + input_weight_grads, + output_weight_grads + }; } -} // end namespace cublas_gemmex +} // end namespace rocblas_gemmex } // end namespace self } // end namespace multihead_attn + diff --git a/apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu b/apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu index 7a6ec15cd..17150aea9 100644 --- a/apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu +++ b/apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu @@ -4,7 +4,7 @@ #include #include -#include +//#include #include #include @@ -15,10 +15,11 @@ #include "layer_norm.cuh" #include "softmax.cuh" #include "strided_batched_gemm.cuh" +#include "type_shim.h" namespace multihead_attn { namespace self_norm_add { -namespace cublas_gemmex { +namespace rocblas_gemmex { std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const &inputs, @@ -88,7 +89,6 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, char a_layout_n{'n'}; char b_layout_n{'n'}; - TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); // Layer Norm HostApplyLayerNorm( static_cast(lyr_nrm_results.data_ptr()), @@ -101,23 +101,47 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, static_cast(lyr_nrm_beta_weights.data_ptr())); // Input Linear Fwd - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_T, CUBLAS_OP_N, output_lin_dim, batches, embed_dim, - static_cast(&alpha), - static_cast(input_weights.data_ptr()), CUDA_R_16F, - embed_dim, - // static_cast(inputs.data_ptr()), - static_cast(lyr_nrm_results.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(&beta), q_lin_results_ptr, - CUDA_R_16F, output_lin_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, + CUBLAS_OP_T, + CUBLAS_OP_N, + output_lin_dim, + batches, + embed_dim, + static_cast(&alpha), + static_cast(input_weights.data_ptr()), + HIP_R_16F /*a_type*/, + embed_dim, + //static_cast(inputs.data_ptr()), + static_cast(lyr_nrm_results.data_ptr()), + HIP_R_16F /*b_type*/, + embed_dim, + static_cast(&beta), + q_lin_results_ptr, + HIP_R_16F /*c_type*/, + output_lin_dim, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT /*algo*/ + )); // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) - gemm_switch_fp32accum( - a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, scale, - static_cast(k_lin_results_ptr), lead_dim, batch_stride, - static_cast(q_lin_results_ptr), lead_dim, batch_stride, - beta, static_cast(softmax_results_ptr), k_seq_len, - k_seq_len * q_seq_len, attn_batches); + gemm_switch_fp32accum( a_layout_t, + b_layout_n, + k_seq_len, + q_seq_len, + head_dim, + scale, + static_cast(k_lin_results_ptr), + lead_dim, + batch_stride, + static_cast(q_lin_results_ptr), + lead_dim, + batch_stride, + beta, + static_cast(softmax_results_ptr), + k_seq_len, + k_seq_len*q_seq_len, + attn_batches + ); // Padded Softmax bool softmax_success = false; @@ -151,27 +175,50 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, } // Matmul2 - gemm_switch_fp32accum( - a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, alpha, - static_cast(v_lin_results_ptr), lead_dim, batch_stride, - (is_training) ? static_cast(dropout_results.data_ptr()) - : static_cast(softmax_results.data_ptr()), - // static_cast(dropout_results.data_ptr()), - k_seq_len, k_seq_len * q_seq_len, beta, - static_cast(matmul2_results.data_ptr()), head_dim * attn_batches, - head_dim, attn_batches); + gemm_switch_fp32accum( a_layout_n, + b_layout_n, + head_dim, + q_seq_len, + k_seq_len, + alpha, + static_cast(v_lin_results_ptr), + lead_dim, + batch_stride, + (is_training) ? static_cast(dropout_results.data_ptr()) : static_cast(softmax_results.data_ptr()) , + //static_cast(dropout_results.data_ptr()), + k_seq_len, + k_seq_len*q_seq_len, + beta, + static_cast(matmul2_results.data_ptr()), + head_dim*attn_batches, + head_dim, + attn_batches + ); // Output Linear - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_T, CUBLAS_OP_N, embed_dim, batches, embed_dim, - static_cast(&alpha), - static_cast(output_weights.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(matmul2_results.data_ptr()), - CUDA_R_16F, embed_dim, static_cast(&beta), - static_cast(output_lin_results.data_ptr()), CUDA_R_16F, embed_dim, - CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - - // End-of-block Dropout-Add + TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, + CUBLAS_OP_T, + CUBLAS_OP_N, + embed_dim, + batches, + embed_dim, + static_cast(&alpha), + static_cast(output_weights.data_ptr()), + HIP_R_16F /*a_type*/, + embed_dim, + static_cast(matmul2_results.data_ptr()), + HIP_R_16F /*b_type*/, + embed_dim, + static_cast(&beta), + static_cast(output_lin_results.data_ptr()), + HIP_R_16F /*c_type*/, + embed_dim, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT /*algo*/ + )); + + + // End-of-block Dropout-Add if (is_training) { apex_dropout_add_cuda( static_cast(output_lin_results.data_ptr()), @@ -186,8 +233,6 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, static_cast(outputs.data_ptr()), total_tokens); } - TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); - return {lyr_nrm_results, lyr_nrm_mean, lyr_nrm_invvar, input_lin_results, softmax_results, dropout_results, dropout_mask, matmul2_results, dropout_add_mask, outputs}; @@ -255,10 +300,8 @@ std::vector bwd_cuda( char a_layout_n{'n'}; char a_layout_t{'t'}; char b_layout_n{'n'}; - char b_layout_t{'t'}; - - TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); - + char b_layout_t{'t'}; + // Dropout Add Backward apex_masked_scale_cuda( static_cast(output_grads.data_ptr()), @@ -267,49 +310,96 @@ std::vector bwd_cuda( (1.0 / (1.0 - dropout_prob))); // Output Linear Dgrad - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, embed_dim, - static_cast(&alpha), - static_cast(output_weights.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(dropout_add_grads.data_ptr()), - CUDA_R_16F, embed_dim, static_cast(&beta), - static_cast(output_lin_grads.data_ptr()), CUDA_R_16F, embed_dim, - CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - + TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, + CUBLAS_OP_N, + CUBLAS_OP_N, + embed_dim, + batches, + embed_dim, + static_cast(&alpha), + static_cast(output_weights.data_ptr()), + HIP_R_16F /*a_type*/, + embed_dim, + static_cast(dropout_add_grads.data_ptr()), + HIP_R_16F /*b_type*/, + embed_dim, + static_cast(&beta), + static_cast(output_lin_grads.data_ptr()), + HIP_R_16F /*c_type*/, + embed_dim, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT /*algo*/ + )); + // Output Linear Wgrad - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, embed_dim, batches, - static_cast(&alpha), - static_cast(matmul2_results.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(dropout_add_grads.data_ptr()), - CUDA_R_16F, embed_dim, static_cast(&beta), - static_cast(output_weight_grads.data_ptr()), CUDA_R_16F, - embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, + CUBLAS_OP_N, + CUBLAS_OP_T, + embed_dim, + embed_dim, + batches, + static_cast(&alpha), + static_cast(matmul2_results.data_ptr()), + HIP_R_16F /*a_type*/, + embed_dim, + static_cast(dropout_add_grads.data_ptr()), + HIP_R_16F /*b_type*/, + embed_dim, + static_cast(&beta), + static_cast(output_weight_grads.data_ptr()), + HIP_R_16F /*c_type*/, + embed_dim, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT /*algo*/ + )); // MatMul2 Dgrad1 - gemm_switch_fp32accum( - a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, alpha, - static_cast(v_lin_results_ptr), lead_dim, batch_stride, - static_cast(output_lin_grads.data_ptr()), - head_dim * attn_batches, head_dim, beta, - static_cast(matmul2_grads.data_ptr()), k_seq_len, - k_seq_len * q_seq_len, attn_batches); - + gemm_switch_fp32accum( a_layout_t, + b_layout_n, + k_seq_len, + q_seq_len, + head_dim, + alpha, + static_cast(v_lin_results_ptr), + lead_dim, + batch_stride, + static_cast(output_lin_grads.data_ptr()), + head_dim*attn_batches, + head_dim, + beta, + static_cast(matmul2_grads.data_ptr()), + k_seq_len, + k_seq_len*q_seq_len, + attn_batches + ); + // Matmul2 Dgrad2 - gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, - q_seq_len, alpha, - static_cast(output_lin_grads.data_ptr()), - head_dim * attn_batches, head_dim, - static_cast(dropout_results.data_ptr()), - k_seq_len, k_seq_len * q_seq_len, beta, v_lin_grads_ptr, - lead_dim, batch_stride, attn_batches); - - // Apply Dropout Mask and Scale by Dropout Probability - apex_masked_scale_cuda( - static_cast(matmul2_grads.data_ptr()), - static_cast(matmul2_grads.data_ptr()), - static_cast(dropout_mask.data_ptr()), dropout_elems, - (1.0 / (1.0 - dropout_prob))); + gemm_switch_fp32accum( a_layout_n, + b_layout_t, + head_dim, + k_seq_len, + q_seq_len, + alpha, + static_cast(output_lin_grads.data_ptr()), + head_dim*attn_batches, + head_dim, + static_cast(dropout_results.data_ptr()), + k_seq_len, + k_seq_len*q_seq_len, + beta, + v_lin_grads_ptr, + lead_dim, + batch_stride, + attn_batches + ); + + // Apply Dropout Mask and Scale by Dropout Probability + apex_masked_scale_cuda( + static_cast(matmul2_grads.data_ptr()), + static_cast(matmul2_grads.data_ptr()), + static_cast(dropout_mask.data_ptr()), + dropout_elems, + (1.0 / (1.0 - dropout_prob))); // Softmax Grad bool softmax_success = false; @@ -321,44 +411,90 @@ std::vector bwd_cuda( assert(softmax_success); // Matmul1 Dgrad1 - gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, - k_seq_len, scale, k_lin_results_ptr, lead_dim, - batch_stride, - static_cast(matmul2_grads.data_ptr()), - k_seq_len, k_seq_len * q_seq_len, beta, q_lin_grads_ptr, - lead_dim, batch_stride, attn_batches); - + gemm_switch_fp32accum( a_layout_n, + b_layout_n, + head_dim, + q_seq_len, + k_seq_len, + scale, + k_lin_results_ptr, + lead_dim, + batch_stride, + static_cast(matmul2_grads.data_ptr()), + k_seq_len, + k_seq_len*q_seq_len, + beta, + q_lin_grads_ptr, + lead_dim, + batch_stride, + attn_batches + ); + // Matmul1 Dgrad2 - gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, - q_seq_len, scale, q_lin_results_ptr, lead_dim, - batch_stride, - static_cast(matmul2_grads.data_ptr()), - k_seq_len, k_seq_len * q_seq_len, beta, k_lin_grads_ptr, - lead_dim, batch_stride, attn_batches); - - // Input Linear Dgrad - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, output_lin_dim, - static_cast(&alpha), - static_cast(input_weights.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(q_lin_grads_ptr), CUDA_R_16F, - output_lin_dim, static_cast(&beta), - // static_cast(input_grads.data_ptr()), - static_cast(input_lin_grads.data_ptr()), CUDA_R_16F, embed_dim, - CUDA_R_32F, - // CUBLAS_GEMM_ALGO10_TENSOR_OP)); - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - - // Input Linear Wgrad - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, output_lin_dim, batches, - static_cast(&alpha), - // static_cast(inputs.data_ptr()), - static_cast(lyr_nrm_results.data_ptr()), CUDA_R_16F, - embed_dim, static_cast(q_lin_grads_ptr), CUDA_R_16F, - output_lin_dim, static_cast(&beta), - static_cast(input_weight_grads.data_ptr()), CUDA_R_16F, embed_dim, - CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + gemm_switch_fp32accum( a_layout_n, + b_layout_t, + head_dim, + k_seq_len, + q_seq_len, + scale, + q_lin_results_ptr, + lead_dim, + batch_stride, + static_cast(matmul2_grads.data_ptr()), + k_seq_len, + k_seq_len*q_seq_len, + beta, + k_lin_grads_ptr, + lead_dim, + batch_stride, + attn_batches + ); + + // Input Linear Dgrad + TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, + CUBLAS_OP_N, + CUBLAS_OP_N, + embed_dim, + batches, + output_lin_dim, + static_cast(&alpha), + static_cast(input_weights.data_ptr()), + HIP_R_16F /*a_type*/, + embed_dim, + static_cast(q_lin_grads_ptr), + HIP_R_16F /*b_type*/, + output_lin_dim, + static_cast(&beta), + //static_cast(input_grads.data_ptr()), + static_cast(input_lin_grads.data_ptr()), + HIP_R_16F /*c_type*/, + embed_dim, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT /*algo*/ + )); + + // Input Linear Wgrad + TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, + CUBLAS_OP_N, + CUBLAS_OP_T, + embed_dim, + output_lin_dim, + batches, + static_cast(&alpha), + //static_cast(inputs.data_ptr()), + static_cast(lyr_nrm_results.data_ptr()), + HIP_R_16F /*a_type*/, + embed_dim, + static_cast(q_lin_grads_ptr), + HIP_R_16F /*b_type*/, + output_lin_dim, + static_cast(&beta), + static_cast(input_weight_grads.data_ptr()), + HIP_R_16F /*c_type*/, + embed_dim, + HIPBLAS_COMPUTE_32F, + HIPBLAS_GEMM_DEFAULT /*algo*/ + )); // Fused Layer Norm Bwd with Residual Add HostLayerNormGradient( @@ -374,12 +510,12 @@ std::vector bwd_cuda( static_cast(lyr_nrm_gamma_grads.data_ptr()), static_cast(lyr_nrm_beta_grads.data_ptr())); - TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); + //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); return {input_grads, lyr_nrm_gamma_grads, lyr_nrm_beta_grads, input_weight_grads, output_weight_grads}; } -} // end namespace cublas_gemmex -} // end namespace self_norm_add +} // end namespace rocblas_gemmex +} // end namespace self_norm_add } // end namespace multihead_attn diff --git a/apex/contrib/csrc/multihead_attn/softmax.cuh b/apex/contrib/csrc/multihead_attn/softmax.cuh index 34ed0dc93..6e7da0f71 100644 --- a/apex/contrib/csrc/multihead_attn/softmax.cuh +++ b/apex/contrib/csrc/multihead_attn/softmax.cuh @@ -15,7 +15,15 @@ #include #include #include - +#include +#include +#include + +#ifdef USE_ROCM +#define APEX_WARP_SHFL_XOR(mask, value, offset, width) __shfl_xor(value, offset, width) +#else +#define APEX_WARP_SHFL_XOR __shfl_xor_sync +#endif namespace { template __device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); @@ -161,7 +169,7 @@ __global__ void softmax_warp_forward(input_t *dst, const output_t *src, float val[WARP_BATCH]; #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { - val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE); + val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE); } #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { @@ -186,7 +194,7 @@ __global__ void softmax_warp_forward(input_t *dst, const output_t *src, for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE); + sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE); } } @@ -228,7 +236,7 @@ bool warp_softmax_kernel(int log2_elements, int &warp_size, softmax_forward_func &kernel) { // determine size of a warp const int next_power_of_two = 1 << log2_elements; - warp_size = (next_power_of_two < 32) ? next_power_of_two : 32; + warp_size = (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size(); // determine how many batches a warp should process. batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; @@ -402,7 +410,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward_vec4( float val[WARP_BATCH]; #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { - val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE); + val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE); } #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { @@ -426,7 +434,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward_vec4( for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE); + sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE); } } auto seeds = at::cuda::philox::unpack(philox_args); @@ -564,7 +572,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward( float val[WARP_BATCH]; #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { - val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE); + val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE); } #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { @@ -588,7 +596,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward( for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE); + sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE); } } curandStatePhilox4_32_10_t state; @@ -647,7 +655,7 @@ bool warp_additive_masked_softmax_dropout_kernel( &kernel) { // determine size of a warp const int next_power_of_two = 1 << log2_elements; - warp_size = (next_power_of_two < 32) ? next_power_of_two : 32; + warp_size = (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size(); // determine how many batches a warp should process. batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; @@ -874,7 +882,7 @@ __global__ void additive_masked_softmax_warp_forward( float val[WARP_BATCH]; #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { - val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE); + val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE); } #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { @@ -899,7 +907,7 @@ __global__ void additive_masked_softmax_warp_forward( for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE); + sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE); } } @@ -941,7 +949,7 @@ bool warp_additive_masked_softmax_kernel( additive_masked_softmax_forward_func &kernel) { // determine size of a warp const int next_power_of_two = 1 << log2_elements; - warp_size = (next_power_of_two < 32) ? next_power_of_two : 32; + warp_size = (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size(); // determine how many batches a warp should process. batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; @@ -1164,7 +1172,7 @@ masked_softmax_warp_forward(input_t *dst, const output_t *src, float val[WARP_BATCH]; #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { - val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE); + val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE); } #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { @@ -1189,7 +1197,7 @@ masked_softmax_warp_forward(input_t *dst, const output_t *src, for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE); + sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE); } } @@ -1233,7 +1241,7 @@ bool warp_masked_softmax_kernel( masked_softmax_forward_func &kernel) { // determine size of a warp const int next_power_of_two = 1 << log2_elements; - warp_size = (next_power_of_two < 32) ? next_power_of_two : 32; + warp_size = (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size(); // determine how many batches a warp should process. batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; @@ -1414,7 +1422,7 @@ __global__ void time_masked_softmax_warp_forward( float val[WARP_BATCH]; #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { - val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE); + val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE); } #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { @@ -1439,7 +1447,7 @@ __global__ void time_masked_softmax_warp_forward( for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE); + sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE); } } @@ -1481,7 +1489,7 @@ bool warp_time_masked_softmax_kernel( time_masked_softmax_forward_func &kernel) { // determine size of a warp const int next_power_of_two = 1 << log2_elements; - warp_size = (next_power_of_two < 32) ? next_power_of_two : 32; + warp_size = (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size(); // determine how many batches a warp should process. batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; @@ -1734,7 +1742,7 @@ void dispatch_masked_scale_softmax_backward_masked_out( // This value must match the WARP_SIZE constexpr value computed inside // softmax_warp_backward. int warp_size = - (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size(); // This value must match the WARP_BATCH constexpr value computed inside // softmax_warp_backward. @@ -1848,7 +1856,8 @@ void dispatch_masked_scale_softmax_backward_masked_out_stream( // This value must match the WARP_SIZE constexpr value computed inside // softmax_warp_backward. int warp_size = - (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size(); + // This value must match the WARP_BATCH constexpr value computed inside // softmax_warp_backward. int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; @@ -2151,7 +2160,7 @@ __global__ void masked_scale_softmax_warp_backward_recompute( float val[WARP_BATCH]; #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { - val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE); + val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE); } #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { @@ -2176,7 +2185,7 @@ __global__ void masked_scale_softmax_warp_backward_recompute( for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE); + sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE); } } @@ -2247,7 +2256,7 @@ bool masked_scale_softmax_warp_backward_recompute_kernel( is_log_softmax> &kernel) { // determine size of a warp const int next_power_of_two = 1 << log2_elements; - warp_size = (next_power_of_two < 32) ? next_power_of_two : 32; + warp_size = (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size(); // determine how many batches a warp should process. batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; @@ -2385,7 +2394,8 @@ void dispatch_masked_scale_softmax_backward_stream( // This value must match the WARP_SIZE constexpr value computed inside // softmax_warp_backward. int warp_size = - (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size(); + // This value must match the WARP_BATCH constexpr value computed inside // softmax_warp_backward. int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; @@ -2586,7 +2596,7 @@ void dispatch_softmax_backward_fused_native( // This value must match the WARP_SIZE constexpr value computed inside // softmax_warp_backward. int warp_size = - (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size(); // This value must match the WARP_BATCH constexpr value computed inside // softmax_warp_backward. @@ -2756,7 +2766,7 @@ __global__ void softmax_warp_backward(__half *gradInput, const __half *grad, for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE); + sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE); } } @@ -2798,7 +2808,7 @@ bool warp_softmax_backward_kernel( softmax_backward_func &kernel) { // determine size of a warp const int next_power_of_two = 1 << log2_elements; - warp_size = (next_power_of_two < 32) ? next_power_of_two : 32; + warp_size = (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size(); // determine how many batches a warp should process. batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; @@ -2990,7 +3000,7 @@ masked_softmax_warp_backward(__half *gradInput, const __half *grad, for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE); + sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE); } } @@ -3041,7 +3051,7 @@ bool warp_masked_softmax_backward_kernel( masked_softmax_backward_func &kernel) { // determine size of a warp const int next_power_of_two = 1 << log2_elements; - warp_size = (next_power_of_two < 32) ? next_power_of_two : 32; + warp_size = (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size(); // determine how many batches a warp should process. batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; diff --git a/apex/contrib/csrc/multihead_attn/strided_batched_gemm.cuh b/apex/contrib/csrc/multihead_attn/strided_batched_gemm.cuh index a9e114731..5d45efb3c 100644 --- a/apex/contrib/csrc/multihead_attn/strided_batched_gemm.cuh +++ b/apex/contrib/csrc/multihead_attn/strided_batched_gemm.cuh @@ -4,16 +4,20 @@ #include #include -#include +//#include #include +#include + //#include #include #include -#include "cutlass/cutlass.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/wmma_gemm_traits.h" +//#include "cutlass/cutlass.h" +//#include "cutlass/gemm/gemm.h" +//#include "cutlass/gemm/wmma_gemm_traits.h" + +#include "type_shim.h" namespace { cublasOperation_t convertTransToCublasOperation(char trans) { @@ -29,560 +33,57 @@ cublasOperation_t convertTransToCublasOperation(char trans) { } } -void CublasStridedBatchedGemm( - char transa, char transb, long m, long n, long k, - float alpha, const half *a, long lda, long strideA, const half *b, long ldb, - long strideB, float beta, half *c, long ldc, long strideC, long batchCount, - cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP) { - cublasOperation_t opa = convertTransToCublasOperation(transa); - cublasOperation_t opb = convertTransToCublasOperation(transb); - - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - cublasSetStream(handle, stream); - float fAlpha = alpha; - float fBeta = beta; - // THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); - TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedEx( - handle, opa, opb, (int)m, (int)n, (int)k, (void *)&fAlpha, a, CUDA_R_16F, - (int)lda, strideA, b, CUDA_R_16F, (int)ldb, strideB, (void *)&fBeta, c, - CUDA_R_16F, (int)ldc, strideC, (int)batchCount, CUDA_R_32F, algo)); - // THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); +void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k, + float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB, + float beta, half *c, long ldc, long strideC, long batchCount, hipblasGemmAlgo_t algo) { + cublasOperation_t opa = convertTransToCublasOperation(transa); + cublasOperation_t opb = convertTransToCublasOperation(transb); + + cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + cublasSetStream(handle, stream); + float fAlpha = alpha; + float fBeta = beta; + + TORCH_CUDABLAS_CHECK(hipblasGemmStridedBatchedEx( + handle, + opa, + opb, + (int)m, + (int)n, + (int)k, + (void*)&fAlpha, + a, + HIP_R_16F /*a_type*/, + (int)lda, + strideA, + b, + HIP_R_16F /*b_type*/, + (int)ldb, + strideB, + (void*)&fBeta, + c, + HIP_R_16F /*c_type*/, + (int)ldc, + strideC, + (int)batchCount, + HIPBLAS_COMPUTE_32F, + algo)); } -} // namespace - -template -void CutlassGemm_FP32Accum(cudaStream_t stream, long m, long n, long k, - float alpha, const half *a, long lda, long strideA, - const half *b, long ldb, long strideB, float beta, - half *c, long ldc, long strideC, long batchCount) { - // printf("CUTLASS-> %c%c M: %ld N: %ld K: %ld %d%d%d LDA: %ld LDB: %ld LDC: - // %ld strideA: %ld strideB: %ld strideC: %ld Alpha: %f Beta: %f\n", - // ((int)A_LAYOUT == 0 ? 'T' : 'N'), ((int)B_LAYOUT ==0 ? 'T' : 'N'), m, n, k, - // SRC_A,SRC_B,DST_C, lda, ldb, ldc, strideA, strideB, strideC, alpha, beta); - typedef cutlass::gemm::WmmaGemmTraits< - A_LAYOUT, B_LAYOUT, cutlass::Shape<32, 16, 16>, half, half, half, - cutlass::gemm::LinearScaling, float, - typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp< - typename cutlass::Shape<32, 16, 16>>::Shape, - typename cutlass::Shape<16, 16, 16>, - SRC_A, // kScalarsPerLdgA_ - SRC_B, // kScalarsPerLdgB_ - SRC_A, // KScalarsPerLdsA_ - SRC_B, // KScalarsPerLdsB_ - DST_C, // kScalarsPerLdgCAndStgD_ - DST_C / 2, // kScalarsPerStsD_ - DST_C / 2 // kScalarsPerLdsD_ - > - WmmaGemmTraits; - - typedef cutlass::gemm::Gemm Gemm; - typename Gemm::Params params; - - int result = params.initialize( - m, // M dimension for each batch - n, // N dimension for each batch - k, // K dimension for each batch - alpha, // scalar alpha - a, lda, - strideA, // distance in memory between the first element of neighboring - // batch - b, ldb, - strideB, // distance in memory between the first element of neighboring - // batch - beta, // scalar beta - c, // source matrix C - ldc, - strideC, // distance in memory between the first element of neighboring - // batch - c, // destination matrix C (may be different memory than source C matrix) - ldc, - strideC, // distance in memory between the first element of neighboring - // batch - batchCount); - - AT_ASSERTM(result == 0, "Failed to initialize CUTLASS Gemm::Params object."); - - // batchCount in cutlass batched GEMM kernels maps to gridDim.z, which is - // limited to 16 bits. To implement batched GEMM with larger batch size, we - // fragment it into smaller batched GEMMs of gridDim.z <= 64k - long batchesLeft = batchCount; - long iterBatchCount = std::min(batchesLeft, static_cast((1 << 16) - 1)); - do { - // printf("CUTLASS-> %c%c M: %ld N: %ld K: %ld %d%d%d LDA: %ld LDB: %ld LDC: - // %ld strideA: %ld strideB: %ld strideC: %ld Alpha: %f Beta: %f - // TotalBatches: %ld iterBatchCount %ld\n", ((int)A_LAYOUT == 0 ? 'T' : 'N'), - // ((int)B_LAYOUT ==0 ? 'T' : 'N'), m, n, k, SRC_A,SRC_B,DST_C, lda, ldb, - // ldc, strideA, strideB, strideC, alpha, beta, batchesLeft, iterBatchCount); - int result = - params.initialize(m, // M dimension for each batch - n, // N dimension for each batch - k, // K dimension for each batch - alpha, // scalar alpha - a, lda, - strideA, // distance in memory between the first - // element of neighboring batch - b, ldb, - strideB, // distance in memory between the first - // element of neighboring batch - beta, // scalar beta - c, // source matrix C - ldc, - strideC, // distance in memory between the first - // element of neighboring batch - c, // destination matrix C (may be different memory - // than source C matrix) - ldc, - strideC, // distance in memory between the first - // element of neighboring batch - iterBatchCount); - - AT_ASSERTM(result == 0, - "Failed to initialize CUTLASS Gemm::Params object."); - // Launch the CUTLASS GEMM kernel. - C10_CUDA_CHECK(Gemm::launch(params, stream)); - - // Update batched GEMM params based on completed work - batchesLeft = batchesLeft - iterBatchCount; - a += iterBatchCount * strideA; - b += iterBatchCount * strideB; - c += iterBatchCount * strideC; - ; - - iterBatchCount = std::min(batchesLeft, static_cast((1 << 16) - 1)); - - } while (batchesLeft > 0); -} - -namespace { -void gemm_switch_fp32accum(char transa, char transb, long m, - long n, long k, float alpha, const half *a, long lda, - long strideA, const half *b, long ldb, long strideB, - float beta, half *c, long ldc, long strideC, - long batchCount) { +void gemm_switch_fp32accum(char transa, char transb, long m, long n, long k, + float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB, + float beta, half *c, long ldc, long strideC, long batchCount) { auto stream = c10::cuda::getCurrentCUDAStream(); - // printf("GEMM -> %c%c M: %i N: %i K: %i Alpha: %f Beta: %f\n", (transa == - // 't' ? 'T' : 'N'), (transb =='t' ? 'T' : 'N'), m, n, k, alpha, beta); - if ((transa == 't') && (transb == 'n')) { - if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { - CublasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, - strideA, b, ldb, strideB, beta, c, ldc, strideC, - batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP); - } - else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x7)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x7)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x7)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x7)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x7)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x7)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x7)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x7)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else { - CublasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, - strideA, b, ldb, strideB, beta, c, ldc, strideC, - batchCount); - } - } else if ((transa == 'n') && (transb == 'n')) { - if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { - CublasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, - strideA, b, ldb, strideB, beta, c, ldc, strideC, - batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP); - } - else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x7)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x7)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x7)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x7)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x7)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x7)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x7)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x7)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else { - CublasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, - strideA, b, ldb, strideB, beta, c, ldc, strideC, - batchCount); - } - } else if ((transa == 'n') && (transb == 't')) { - if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { - CublasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, - strideA, b, ldb, strideB, beta, c, ldc, strideC, - batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP); - } - else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x7)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x7)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x7)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x7)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x7)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x7)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x7)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x7)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x3)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x1)) { - CutlassGemm_FP32Accum( - stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, - ldc, strideC, batchCount); - } else { - CublasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, - strideA, b, ldb, strideB, beta, c, ldc, strideC, - batchCount); - } + if ( (transa == 't') && (transb == 'n') ) { + if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, HIPBLAS_GEMM_DEFAULT); } + else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, HIPBLAS_GEMM_DEFAULT); } + } else if ( (transa == 'n') && (transb == 'n') ) { + if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, HIPBLAS_GEMM_DEFAULT); } + else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, HIPBLAS_GEMM_DEFAULT); } + } else if ( (transa == 'n') && (transb == 't') ) { + if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, HIPBLAS_GEMM_DEFAULT); } + else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, HIPBLAS_GEMM_DEFAULT); } } else { AT_ASSERTM(false, "TransA and TransB are invalid"); } @@ -619,7 +120,8 @@ void HgemmStridedBatched(char transa, char transb, long m, long n, long k, float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB, float beta, half *c, long ldc, long strideC, - long batchCount) { + half *d, long ldd, long strideD, long batchCount) { + if ((m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX)) @@ -632,7 +134,9 @@ void HgemmStridedBatched(char transa, char transb, long m, adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); - gemm_switch_fp32accum(transa, transb, m, n, k, alpha, a, lda, strideA, + // gemm_switch_fp32accum(transa, transb, m, n, k, alpha, a, lda, strideA, + // b, ldb, strideB, beta, c, ldc, strideC, batchCount); + gemm_switch_fp32accum(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } diff --git a/apex/contrib/csrc/nccl_allocator/NCCLAllocator.cpp b/apex/contrib/csrc/nccl_allocator/NCCLAllocator.cpp new file mode 100644 index 000000000..ae480f5f4 --- /dev/null +++ b/apex/contrib/csrc/nccl_allocator/NCCLAllocator.cpp @@ -0,0 +1,48 @@ + +#include +#include +#include +#include + +#include + +#define NCCL_CHECK(cmd) \ + do { \ + ncclResult_t result = cmd; \ + if (result != ncclSuccess) { \ + std::string err = "NCCL error in: " + std::string(__FILE__) + ":" + \ + std::to_string(__LINE__) + ", " + \ + std::string(ncclGetErrorString(result)); \ + TORCH_CHECK(false, err); \ + } \ + } while (0) + +void *nccl_alloc_plug(size_t size, int device, void *stream) { + void *ptr; + NCCL_CHECK(ncclMemAlloc(&ptr, size)); + return ptr; +} + +void nccl_free_plug(void *ptr, std::size_t size, int device, void *stream) { + NCCL_CHECK(ncclMemFree(ptr)); +} + +std::shared_ptr nccl_allocator; + +void maybe_init() { + if (!nccl_allocator) { + nccl_allocator = std::make_shared< + torch::cuda::CUDAPluggableAllocator::CUDAPluggableAllocator>( + nccl_alloc_plug, nccl_free_plug); + } +} + +std::shared_ptr +get_nccl_allocator() { + maybe_init(); + return nccl_allocator; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("get_nccl_allocator", []() { return get_nccl_allocator(); }); +}; \ No newline at end of file diff --git a/apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu b/apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu index c386dcfb7..89b29c92d 100644 --- a/apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu +++ b/apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu @@ -5,7 +5,11 @@ #include #include #include +#ifdef USE_ROCM +#include "rccl/rccl.h" +#else #include "nccl.h" +#endif /* * This file implements a crude but effective mechanism for copying data between tenors owned by different ranks diff --git a/apex/contrib/csrc/nccl_p2p/nccl_version.cpp b/apex/contrib/csrc/nccl_p2p/nccl_version.cpp new file mode 100644 index 000000000..421d4ab03 --- /dev/null +++ b/apex/contrib/csrc/nccl_p2p/nccl_version.cpp @@ -0,0 +1,11 @@ +// Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +// This file is used to check the version of NCCL detected. +#include + +#include + +std::tuple get_nccl_version(); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("get_nccl_version", &get_nccl_version); +} \ No newline at end of file diff --git a/apex/contrib/csrc/nccl_p2p/nccl_version_check.cu b/apex/contrib/csrc/nccl_p2p/nccl_version_check.cu new file mode 100644 index 000000000..2b44d2eb6 --- /dev/null +++ b/apex/contrib/csrc/nccl_p2p/nccl_version_check.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +// This file is used to check the version of NCCL detected. +#include +#include + + +std::tuple get_nccl_version() { + return { int(NCCL_MAJOR), int(NCCL_MINOR) }; +} \ No newline at end of file diff --git a/apex/contrib/csrc/optimizers/fused_adam_cuda.cpp b/apex/contrib/csrc/optimizers/fused_adam_cuda.cpp index c03c90fe3..e8ffa4aa1 100644 --- a/apex/contrib/csrc/optimizers/fused_adam_cuda.cpp +++ b/apex/contrib/csrc/optimizers/fused_adam_cuda.cpp @@ -76,11 +76,11 @@ void maybe_cast(at::Tensor & overflow_flag, at::Tensor & p_in, at::Tensor & p_ou } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("strided_check_finite", &strided_check_finite, "Strided finite check."); - m.def("adam", &adam, "Adam optimized CUDA implementation."); - m.def("reversible_adam", &reversible_adam, "Reversible Adam optimized CUDA implementation."); - m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation."); - m.def("maybe_adam_undo", &maybe_adam_undo, "Undo function for Adam optimized CUDA implementation."); - m.def("maybe_cast", &maybe_cast, "Unpack byte tensor containing e5m2 floats."); - m.def("maybe_cast_mt", &maybe_cast_cuda_mt, "Unpack byte tensor containing e5m2 floats."); + m.def("strided_check_finite", &strided_check_finite, "Strided finite check.", py::call_guard()); + m.def("adam", &adam, "Adam optimized CUDA implementation.", py::call_guard()); + m.def("reversible_adam", &reversible_adam, "Reversible Adam optimized CUDA implementation.", py::call_guard()); + m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation.", py::call_guard()); + m.def("maybe_adam_undo", &maybe_adam_undo, "Undo function for Adam optimized CUDA implementation.", py::call_guard()); + m.def("maybe_cast", &maybe_cast, "Unpack byte tensor containing e5m2 floats.", py::call_guard()); + m.def("maybe_cast_mt", &maybe_cast_cuda_mt, "Unpack byte tensor containing e5m2 floats.", py::call_guard()); } diff --git a/apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu b/apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu index 75bd35a10..18b60264a 100644 --- a/apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu +++ b/apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu @@ -235,12 +235,12 @@ void fused_adam_cuda( } cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - if (g.scalar_type() == at::ScalarType::Half) { + if (g.scalar_type() == at::ScalarType::Half || g.scalar_type() == at::ScalarType::BFloat16) { //all other values should be fp32 for half gradients AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type"); //dispatch is done on the gradient type using namespace at; // prevents "toString is undefined" errors - DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel", + DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(g.scalar_type(), 0, "adam_cuda_kernel", using accscalar_t = at::acc_type; adam_cuda_kernel<<>>( p.DATA_PTR(), @@ -309,12 +309,12 @@ void fused_adam_cuda_mt( size_t tl_sz = tensor_lists.size(); AT_ASSERTM(tl_sz == 4 || tl_sz == 5, "expected tensor lists of size 4 or 5"); - if (tensor_lists[3][0].scalar_type() == at::ScalarType::Half) { + if (tensor_lists[3][0].scalar_type() == at::ScalarType::Half || tensor_lists[3][0].scalar_type() == at::ScalarType::BFloat16) { //alher values should be fp32 for half gradients AT_ASSERTM(tensor_lists[0][0].scalar_type() == at::ScalarType::Float, "expected parameter to be of float type"); //dich is done on the gradient type if (tl_sz == 5) { - DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel", + DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel", using accscalar_t = at::acc_type; multi_tensor_apply<5>( BLOCK_SIZE, @@ -331,7 +331,7 @@ void fused_adam_cuda_mt( decay); ); } else { - DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel", + DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel", using accscalar_t = at::acc_type; multi_tensor_apply<4>( BLOCK_SIZE, @@ -847,13 +847,13 @@ void fused_reversible_adam_cuda( } cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - if (g.scalar_type() == at::ScalarType::Half) { + if (g.scalar_type() == at::ScalarType::Half || g.scalar_type() == at::ScalarType::BFloat16) { //all other values should be fp32 for half gradients AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type"); //dispatch is done on the gradient type using namespace at; // prevents "toString is undefined" errors if (p_copy.numel() == 0 || p_copy.scalar_type() == g.scalar_type()) { - DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel", + DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(g.scalar_type(), 0, "adam_cuda_kernel", using accscalar_t = at::acc_type; reversible_adam_cuda_kernel<<>>( p.DATA_PTR(), @@ -872,7 +872,7 @@ void fused_reversible_adam_cuda( ); } else { AT_ASSERTM(p_copy.scalar_type() == at::ScalarType::Byte, "expected parameter to be of byte type"); - DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_e5m2_kernel", + DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(g.scalar_type(), 0, "adam_cuda_e5m2_kernel", using accscalar_t = at::acc_type; reversible_adam_cuda_kernel<<>>( p.DATA_PTR(), @@ -992,12 +992,12 @@ void fused_maybe_adam_undo_cuda( } cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - if (g.scalar_type() == at::ScalarType::Half) { + if (g.scalar_type() == at::ScalarType::Half || g.scalar_type() == at::ScalarType::BFloat16) { //all other values should be fp32 for half gradients AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type"); //dispatch is done on the gradient type using namespace at; // prevents "toString is undefined" errors - DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel", + DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(g.scalar_type(), 0, "adam_cuda_kernel", using accscalar_t = at::acc_type; maybe_adam_undo_cuda_kernel<<>>( overflow_flag.numel() ? overflow_flag.DATA_PTR() : NULL, diff --git a/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp b/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp index 7ae13d514..f586b8d52 100644 --- a/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp +++ b/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp @@ -1,20 +1,36 @@ #include void multi_tensor_fused_adam_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, - at::Tensor per_tensor_beta1, - at::Tensor per_tensor_beta2, - at::Tensor per_tensor_bias_correction, - at::Tensor per_tensor_eps, - at::Tensor per_tensor_weight_decay, - float lr, - float grad_scale, - int step, - int mode); + int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, at::Tensor grad_scale, + float lr, float beta1, float beta2, float eps, int step, int mode, + int bias_correction, float weight_decay); + +void multi_tensor_fused_adam_capturable_cuda( + int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, at::Tensor grad_scale, + at::Tensor lr, float beta1, float beta2, float eps, at::Tensor step, + int mode, int bias_correction, float weight_decay); + +void multi_tensor_fused_adam_with_param_remainders_cuda( + int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, at::Tensor grad_scale, + float lr, float beta1, float beta2, float eps, int step, int mode, + int bias_correction, float weight_decay); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("multi_tensor_fused_adam", &multi_tensor_fused_adam_cuda, - "Multi tensor Adam optimized CUDA implementation."); -} + "CUDA kernels for multi-tensor Adam, " + "with param copy", + py::call_guard()); + m.def("multi_tensor_fused_adam_capturable", + &multi_tensor_fused_adam_capturable_cuda, + "CUDA kernels for multi-tensor Adam, " + "with param copy, capturable for CUDA graph", + py::call_guard()); + m.def("multi_tensor_fused_adam_with_param_remainders", + &multi_tensor_fused_adam_with_param_remainders_cuda, + "CUDA kernel for multi-tensor Adam, " + "with stored param remainders and param copy", + py::call_guard()); +} \ No newline at end of file diff --git a/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu b/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu index a702adab6..817c3e4e6 100644 --- a/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu +++ b/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu @@ -6,159 +6,434 @@ // #include #include + #include -#include "type_shim.h" + #include "multi_tensor_apply.cuh" +#include "type_shim.h" #define BLOCK_SIZE 512 #define ILP 4 -template -__device__ __forceinline__ bool is_aligned(T* p){ - return ((uint64_t)p) % (ILP*sizeof(T)) == 0; +template +__device__ __forceinline__ bool is_aligned(const T* p) { + return ((uint64_t)p) % (ILP * sizeof(T)) == 0; } -template -__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){ - typedef typename std::aligned_storage::type LT; - ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset]; +template +__device__ __forceinline__ void load_store(T* dst, const T* src, + int dst_offset = 0, + int src_offset = 0) { + typedef + typename std::aligned_storage::type LT; + ((LT*)dst)[dst_offset] = ((const LT*)src)[src_offset]; } -typedef enum{ - ADAM_MODE_0 =0, // eps under square root - ADAM_MODE_1 =1 // eps outside square root +// (1-t)*x + t*y +__device__ __forceinline__ float lerp(float t, float x, float y) { + // See https://developer.nvidia.com/blog/lerp-faster-cuda/ + return fma(t, y, fma(-t, x, x)); +} + +typedef enum { + ADAM_MODE_0 = 0, // L2 regularization mode + ADAM_MODE_1 = 1 // Decoupled weight decay mode(AdamW) } adamMode_t; -template -struct DistAdamFunctor -{ +/* Multi-tensor Adam + * + * Updates params in-place and outputs a copy with a desired datatype. + */ +template +struct DistAdamFunctor { + // Vectorized local compute + __device__ __forceinline__ static void local_step( + T p[ILP], T m[ILP], T v[ILP], const GRAD_T g[ILP], const float grad_scale, + const float beta1, const float beta2, const float beta1_correction, + const float beta2_correction, const float eps, const float lr, + adamMode_t mode, const float weight_decay) { + if (mode == ADAM_MODE_0) { // L2 +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + float scaled_grad = (g[ii] * grad_scale) + (weight_decay * p[ii]); + float next_m = lerp(beta1, scaled_grad, m[ii]); + float next_v = lerp(beta2, scaled_grad * scaled_grad, v[ii]); + float next_m_unbiased = next_m / beta1_correction; + float next_v_unbiased = next_v / beta2_correction; + float denom = sqrtf(next_v_unbiased) + eps; + float update = next_m_unbiased / denom; + m[ii] = next_m; + v[ii] = next_v; + p[ii] -= lr * update; + } + } else { // weight decay +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + float scaled_grad = g[ii] * grad_scale; + float next_m = lerp(beta1, scaled_grad, m[ii]); + float next_v = lerp(beta2, scaled_grad * scaled_grad, v[ii]); + float next_m_unbiased = next_m / beta1_correction; + float next_v_unbiased = next_v / beta2_correction; + float denom = sqrtf(next_v_unbiased) + eps; + float update = (next_m_unbiased / denom) + (weight_decay * p[ii]); + m[ii] = next_m; + v[ii] = next_v; + p[ii] -= lr * update; + } + } + } + __device__ __forceinline__ void operator()( - int chunk_size, - volatile int* noop_gmem, - TensorListMetadata& tl, - const float* per_tensor_beta1, - const float* per_tensor_beta2, - const int* per_tensor_bias_correction, - const float* per_tensor_eps, - const float* per_tensor_weight_decay, - const float lr, - const float grad_scale, - const int step, - adamMode_t mode) - { + int chunk_size, volatile int* noop_gmem, TensorListMetadata<5>& tl, + const float* grad_scale_ptr, const float beta1, const float beta2, + const float beta1_correction, const float beta2_correction, + const float eps, const float lr, adamMode_t mode, + const float weight_decay) const { int tensor_loc = tl.block_to_tensor[blockIdx.x]; - int tensor_num = tl.start_tensor_this_launch + tensor_loc; int chunk_idx = tl.block_to_chunk[blockIdx.x]; int n = tl.sizes[tensor_loc]; - float b1 = per_tensor_beta1[tensor_num]; - float b2 = per_tensor_beta2[tensor_num]; - float eps = per_tensor_eps[tensor_num]; - float decay = per_tensor_weight_decay[tensor_num]; + const float grad_scale = *grad_scale_ptr; - float beta1_correction = 1.0f, beta2_correction = 1.0f; - if (per_tensor_bias_correction[tensor_num] == 1) { - beta1_correction = 1 - std::pow(b1, step); - beta2_correction = 1 - std::pow(b2, step); + T* p_in = (T*)tl.addresses[0][tensor_loc]; + p_in += chunk_idx * chunk_size; + T* m = (T*)tl.addresses[1][tensor_loc]; + m += chunk_idx * chunk_size; + T* v = (T*)tl.addresses[2][tensor_loc]; + v += chunk_idx * chunk_size; + const GRAD_T* g = (GRAD_T*)tl.addresses[3][tensor_loc]; + g += chunk_idx * chunk_size; + PARAM_OUT_T* p_out = (PARAM_OUT_T*)tl.addresses[4][tensor_loc]; + p_out += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + n = chunk_size < n ? chunk_size : n; + + const bool aligned = (n % ILP == 0 && is_aligned(p_in) && is_aligned(m) && + is_aligned(v) && is_aligned(g) && is_aligned(p_out)); + + for (int i_start = threadIdx.x * ILP; i_start < n; + i_start += blockDim.x * ILP) { + T local_p[ILP]; + T local_m[ILP]; + T local_v[ILP]; + GRAD_T local_g[ILP]; + PARAM_OUT_T local_p_out[ILP]; + + // Load + if (aligned) { + load_store(local_p, p_in + i_start); + load_store(local_m, m + i_start); + load_store(local_v, v + i_start); + load_store(local_g, g + i_start); + } else { +#pragma unroll + for (int ii = 0, i = i_start; ii < ILP; ii++, i++) { + if (i < n) { + local_p[ii] = p_in[i]; + local_m[ii] = m[i]; + local_v[ii] = v[i]; + local_g[ii] = g[i]; + } else { + local_p[ii] = 0; + local_m[ii] = 0; + local_v[ii] = 0; + local_g[ii] = 0; + } + } + } + + // Local compute + local_step(local_p, local_m, local_v, local_g, grad_scale, beta1, beta2, + beta1_correction, beta2_correction, eps, lr, mode, + weight_decay); +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + local_p_out[ii] = static_cast(local_p[ii]); + } + + // Store + if (aligned) { + load_store(p_in + i_start, local_p); + load_store(m + i_start, local_m); + load_store(v + i_start, local_v); + load_store(p_out + i_start, local_p_out); + } else { +#pragma unroll + for (int ii = 0, i = i_start; ii < ILP; ii++, i++) { + if (i < n) { + p_in[i] = local_p[ii]; + m[i] = local_m[ii]; + v[i] = local_v[ii]; + p_out[i] = local_p_out[ii]; + } + } + } } + } +}; - T* p = (T *)tl.addresses[0][tensor_loc]; - p += chunk_idx*chunk_size; - T* m = (T *)tl.addresses[1][tensor_loc]; - m += chunk_idx*chunk_size; - T* v = (T *)tl.addresses[2][tensor_loc]; - v += chunk_idx*chunk_size; - GRAD_T* g = (GRAD_T *)tl.addresses[3][tensor_loc]; - g += chunk_idx*chunk_size; - GRAD_T* p_copy = NULL; - if (DEPTH == 5) { - p_copy = (GRAD_T *)tl.addresses[4][tensor_loc]; - p_copy += chunk_idx*chunk_size; +/* Multi-tensor Adam with CUDA Graph Support + * + * Updates params in-place and outputs a copy with a desired datatype. + */ +template +struct DistAdamCapturableFunctor { + // Vectorized local compute + __device__ __forceinline__ static void local_step( + T p[ILP], T m[ILP], T v[ILP], const GRAD_T g[ILP], const float grad_scale, + const float beta1, const float beta2, const float beta1_correction, + const float beta2_correction, const float eps, const float lr, + adamMode_t mode, const float weight_decay) { + if (mode == ADAM_MODE_0) { // L2 +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + float scaled_grad = (g[ii] * grad_scale) + (weight_decay * p[ii]); + float next_m = lerp(beta1, scaled_grad, m[ii]); + float next_v = lerp(beta2, scaled_grad * scaled_grad, v[ii]); + float next_m_unbiased = next_m / beta1_correction; + float next_v_unbiased = next_v / beta2_correction; + float denom = sqrtf(next_v_unbiased) + eps; + float update = next_m_unbiased / denom; + m[ii] = next_m; + v[ii] = next_v; + p[ii] -= lr * update; + } + } else { // weight decay +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + float scaled_grad = g[ii] * grad_scale; + float next_m = lerp(beta1, scaled_grad, m[ii]); + float next_v = lerp(beta2, scaled_grad * scaled_grad, v[ii]); + float next_m_unbiased = next_m / beta1_correction; + float next_v_unbiased = next_v / beta2_correction; + float denom = sqrtf(next_v_unbiased) + eps; + float update = (next_m_unbiased / denom) + (weight_decay * p[ii]); + m[ii] = next_m; + v[ii] = next_v; + p[ii] -= lr * update; + } } + } + + __device__ __forceinline__ void operator()( + int chunk_size, volatile int* noop_gmem, TensorListMetadata<5>& tl, + const float* grad_scale_ptr, const float beta1, const float beta2, + const int* step, const int bias_correction, const float eps, + const float* lr, adamMode_t mode, const float weight_decay) const { + assert(noop_gmem); + assert(grad_scale_ptr); + assert(step); + assert(lr); + + if (*noop_gmem == 1) return; + + float beta1_correction = 1.0f, beta2_correction = 1.0f; + if (bias_correction == 1) { + beta1_correction = 1 - pow(beta1, *step); + beta2_correction = 1 - pow(beta2, *step); + } + + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; + + const float grad_scale = *grad_scale_ptr; + + T* p_in = (T*)tl.addresses[0][tensor_loc]; + p_in += chunk_idx * chunk_size; + T* m = (T*)tl.addresses[1][tensor_loc]; + m += chunk_idx * chunk_size; + T* v = (T*)tl.addresses[2][tensor_loc]; + v += chunk_idx * chunk_size; + const GRAD_T* g = (GRAD_T*)tl.addresses[3][tensor_loc]; + g += chunk_idx * chunk_size; + PARAM_OUT_T* p_out = (PARAM_OUT_T*)tl.addresses[4][tensor_loc]; + p_out += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + n = chunk_size < n ? chunk_size : n; - n -= chunk_idx*chunk_size; - - T incoming_p[ILP]; - T incoming_m[ILP]; - T incoming_v[ILP]; - T incoming_g[ILP]; - - // to make things simple, we put aligned case in a different code path - if (n % ILP == 0 && - chunk_size % ILP == 0 && - is_aligned(p) && - is_aligned(m) && - is_aligned(v) && - is_aligned(g) && - is_aligned(p_copy)) { - for (int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x) { - // load - GRAD_T tmp_g[ILP]; - load_store(incoming_p, p, 0, i_start); - load_store(incoming_m, m, 0, i_start); - load_store(incoming_v, v, 0, i_start); - load_store(tmp_g, g, 0, i_start); + const bool aligned = (n % ILP == 0 && is_aligned(p_in) && is_aligned(m) && + is_aligned(v) && is_aligned(g) && is_aligned(p_out)); + + for (int i_start = threadIdx.x * ILP; i_start < n; + i_start += blockDim.x * ILP) { + T local_p[ILP]; + T local_m[ILP]; + T local_v[ILP]; + GRAD_T local_g[ILP]; + PARAM_OUT_T local_p_out[ILP]; + + // Load + if (aligned) { + load_store(local_p, p_in + i_start); + load_store(local_m, m + i_start); + load_store(local_v, v + i_start); + load_store(local_g, g + i_start); + } else { #pragma unroll - for (int ii = 0; ii < ILP; ii++) { - incoming_g[ii] = static_cast(tmp_g[ii]); - T scaled_grad = incoming_g[ii]/grad_scale; - incoming_m[ii] = b1*incoming_m[ii] + (1-b1)*scaled_grad; - incoming_v[ii] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad; - T next_m_unbiased = incoming_m[ii] / beta1_correction; - T next_v_unbiased = incoming_v[ii] / beta2_correction; - float denom; - if (mode == ADAM_MODE_0) - denom = sqrtf(next_v_unbiased + eps); - else // Mode 1 - denom = sqrtf(next_v_unbiased) + eps; - float update = (next_m_unbiased / denom) + (decay * incoming_p[ii]); - incoming_p[ii] = incoming_p[ii] - (lr * update); - if (DEPTH == 5) tmp_g[ii] = static_cast(incoming_p[ii]); + for (int ii = 0, i = i_start; ii < ILP; ii++, i++) { + if (i < n) { + local_p[ii] = p_in[i]; + local_m[ii] = m[i]; + local_v[ii] = v[i]; + local_g[ii] = g[i]; + } else { + local_p[ii] = 0; + local_m[ii] = 0; + local_v[ii] = 0; + local_g[ii] = 0; + } } - load_store(p, incoming_p, i_start, 0); - load_store(m, incoming_m, i_start, 0); - load_store(v, incoming_v, i_start, 0); - if (DEPTH == 5) load_store(p_copy, tmp_g, i_start, 0); } - } else { - for (int i_start = 0; - i_start < n && i_start < chunk_size; - i_start += blockDim.x*ILP) { + // Local compute + local_step(local_p, local_m, local_v, local_g, grad_scale, beta1, beta2, + beta1_correction, beta2_correction, eps, *lr, mode, + weight_decay); #pragma unroll - for (int ii = 0; ii < ILP; ii++) { - incoming_p[ii] = 0; - incoming_m[ii] = 0; - incoming_v[ii] = 0; - incoming_g[ii] = 0; - - int i = i_start + threadIdx.x + ii*blockDim.x; - if (i < n && i < chunk_size) { - incoming_p[ii] = p[i]; - incoming_m[ii] = m[i]; - incoming_v[ii] = v[i]; - incoming_g[ii] = static_cast(g[i]); + for (int ii = 0; ii < ILP; ii++) { + local_p_out[ii] = static_cast(local_p[ii]); + } + + // Store + if (aligned) { + load_store(p_in + i_start, local_p); + load_store(m + i_start, local_m); + load_store(v + i_start, local_v); + load_store(p_out + i_start, local_p_out); + } else { +#pragma unroll + for (int ii = 0, i = i_start; ii < ILP; ii++, i++) { + if (i < n) { + p_in[i] = local_p[ii]; + m[i] = local_m[ii]; + v[i] = local_v[ii]; + p_out[i] = local_p_out[ii]; } } + } + } + } +}; +/* Functor for multi-tensor Adam with implicit main params + * + * If params are BF16 and optimizer state is FP32, it is not necessary + * to store FP32 main params. Instead, store 16-bit param remainder + * and combine with BF16 param to reconstruct the FP32 main param. + */ +template +struct DistAdamWithParamRemaindersFunctor { + __device__ __forceinline__ void operator()( + int chunk_size, volatile int* noop_gmem, TensorListMetadata<6>& tl, + const float* grad_scale_ptr, const float beta1, const float beta2, + const float beta1_correction, const float beta2_correction, + const float eps, const float lr, adamMode_t mode, + const float weight_decay) const { + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; + + const float grad_scale = *grad_scale_ptr; + + int16_t* p_in = (int16_t*)tl.addresses[0][tensor_loc]; + p_in += chunk_idx * chunk_size; + int16_t* p_rem = (int16_t*)tl.addresses[1][tensor_loc]; + p_rem += chunk_idx * chunk_size; + float* m = (float*)tl.addresses[2][tensor_loc]; + m += chunk_idx * chunk_size; + float* v = (float*)tl.addresses[3][tensor_loc]; + v += chunk_idx * chunk_size; + const GRAD_T* g = (GRAD_T*)tl.addresses[4][tensor_loc]; + g += chunk_idx * chunk_size; + int16_t* p_out = (int16_t*)tl.addresses[5][tensor_loc]; + p_out += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + n = chunk_size < n ? chunk_size : n; + + const bool aligned = + (n % ILP == 0 && is_aligned(p_in) && is_aligned(p_rem) && + is_aligned(m) && is_aligned(v) && is_aligned(g) && is_aligned(p_out)); + + for (int i_start = threadIdx.x * ILP; i_start < n; + i_start += blockDim.x * ILP) { + union fp32_or_int162 { + float fp32; + int16_t int16[2]; + }; + fp32_or_int162 local_p[ILP]; + int16_t local_p_bf16[ILP]; + int16_t local_p_rem[ILP]; + float local_m[ILP]; + float local_v[ILP]; + GRAD_T local_g[ILP]; + + // Load + if (aligned) { + load_store(local_p_bf16, p_in + i_start); + load_store(local_p_rem, p_rem + i_start); + load_store(local_m, m + i_start); + load_store(local_v, v + i_start); + load_store(local_g, g + i_start); + } else { #pragma unroll - for (int ii = 0; ii < ILP; ii++) { - int j = i_start + threadIdx.x + ii*blockDim.x; - - if (j < n && j < chunk_size) { - T scaled_grad = incoming_g[ii]/grad_scale; - m[j] = b1*incoming_m[ii] + (1-b1)*scaled_grad; - v[j] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad; - T next_m_unbiased = m[j] / beta1_correction; - T next_v_unbiased = v[j] / beta2_correction; - float denom; - if (mode == ADAM_MODE_0) - denom = sqrtf(next_v_unbiased + eps); - else // Mode 1 - denom = sqrtf(next_v_unbiased) + eps; - float update = (next_m_unbiased / denom) + (decay * incoming_p[ii]); - p[j] = incoming_p[ii] - (lr * update); - if (DEPTH == 5) p_copy[j] = (GRAD_T) p[j]; + for (int ii = 0, i = i_start; ii < ILP; ii++, i++) { + if (i < n) { + local_p_bf16[ii] = p_in[i]; + local_p_rem[ii] = p_rem[i]; + local_m[ii] = m[i]; + local_v[ii] = v[i]; + local_g[ii] = g[i]; + } else { + local_p_bf16[ii] = 0; + local_p_rem[ii] = 0; + local_m[ii] = 0; + local_v[ii] = 0; + local_g[ii] = 0; + } + } + } + + // Reconstruct FP32 params +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + if (local_p_rem[ii] < 0) local_p_bf16[ii]--; // Undo rounding + local_p[ii].int16[1] = local_p_bf16[ii]; + local_p[ii].int16[0] = local_p_rem[ii]; + } + + // Local compute + using LocalFunctor = DistAdamFunctor; + LocalFunctor::local_step(reinterpret_cast(local_p), local_m, + local_v, local_g, grad_scale, beta1, beta2, + beta1_correction, beta2_correction, eps, lr, + mode, weight_decay); + + // Split into BF16 params (rounded-to-nearest) and remainders +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + local_p_bf16[ii] = local_p[ii].int16[1]; + local_p_rem[ii] = local_p[ii].int16[0]; + if (local_p_rem[ii] < 0) local_p_bf16[ii]++; // Round up + } + + // Store + if (aligned) { + load_store(p_rem + i_start, local_p_rem); + load_store(m + i_start, local_m); + load_store(v + i_start, local_v); + load_store(p_out + i_start, local_p_bf16); + } else { +#pragma unroll + for (int ii = 0, i = i_start; ii < ILP; ii++, i++) { + if (i < n) { + p_rem[i] = local_p_rem[ii]; + m[i] = local_m[ii]; + v[i] = local_v[ii]; + p_out[i] = local_p_bf16[ii]; } } } @@ -167,62 +442,95 @@ struct DistAdamFunctor }; void multi_tensor_fused_adam_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, // p, m, v, g, p_copy - at::Tensor per_tensor_beta1, - at::Tensor per_tensor_beta2, - at::Tensor per_tensor_bias_correction, - at::Tensor per_tensor_eps, - at::Tensor per_tensor_weight_decay, - float lr, - float grad_scale, - int step, - int mode) -{ + int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, // p_in, m, v, g, p_out + at::Tensor grad_scale, float lr, float beta1, float beta2, float eps, + int step, int mode, int bias_correction, float weight_decay) { using namespace at; + // Expect p_in, m, v, g, p_out size_t tl_sz = tensor_lists.size(); - AT_ASSERTM(tl_sz == 4 || tl_sz == 5, "expected tensor lists of size 4 or 5"); - - if (tl_sz == 5) { - DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "dist_adam_cuda_kernel", // g - using accscalar_t = at::acc_type; - multi_tensor_apply<5>( - BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - DistAdamFunctor<5, accscalar_t, scalar_t_0>(), - per_tensor_beta1.DATA_PTR(), - per_tensor_beta2.DATA_PTR(), - per_tensor_bias_correction.DATA_PTR(), - per_tensor_eps.DATA_PTR(), - per_tensor_weight_decay.DATA_PTR(), - lr, - grad_scale, - step, - (adamMode_t) mode); - ); - } else { - DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "dist_adam_cuda_kernel", // g - using accscalar_t = at::acc_type; - multi_tensor_apply<4>( - BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - DistAdamFunctor<4, accscalar_t, scalar_t_0>(), - per_tensor_beta1.DATA_PTR(), - per_tensor_beta2.DATA_PTR(), - per_tensor_bias_correction.DATA_PTR(), - per_tensor_eps.DATA_PTR(), - per_tensor_weight_decay.DATA_PTR(), - lr, - grad_scale, - step, - (adamMode_t) mode); - ); + TORCH_CHECK(tl_sz == 5, "expected tensor lists of size 5"); + const auto p_in_type = tensor_lists[0][0].scalar_type(); + const auto g_type = tensor_lists[3][0].scalar_type(); + const auto p_out_type = tensor_lists[4][0].scalar_type(); + + float beta1_correction = 1.0f, beta2_correction = 1.0f; + if (bias_correction == 1) { + beta1_correction = 1 - std::pow(beta1, step); + beta2_correction = 1 - std::pow(beta2, step); } + + DISPATCH_FLOAT_HALF_AND_BFLOAT( + p_in_type, 0, "dist_adam_cuda_kernel", + DISPATCH_FLOAT_HALF_AND_BFLOAT( + g_type, 1, "dist_adam_cuda_kernel", + DISPATCH_FLOAT_HALF_AND_BFLOAT( + p_out_type, 2, "dist_adam_cuda_kernel", + multi_tensor_apply<5>( + BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + DistAdamFunctor(), + grad_scale.data_ptr(), beta1, beta2, beta1_correction, + beta2_correction, eps, lr, (adamMode_t)mode, + weight_decay);))); + C10_CUDA_CHECK(cudaGetLastError()); +} + +void multi_tensor_fused_adam_capturable_cuda( + int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, // p_in, m, v, g, p_out + at::Tensor grad_scale, at::Tensor lr, float beta1, float beta2, float eps, + at::Tensor step, int mode, int bias_correction, float weight_decay) { + using namespace at; + + // Expect p_in, m, v, g, p_out + size_t tl_sz = tensor_lists.size(); + TORCH_CHECK(tl_sz == 5, "expected tensor lists of size 5"); + const auto p_in_type = tensor_lists[0][0].scalar_type(); + const auto g_type = tensor_lists[3][0].scalar_type(); + const auto p_out_type = tensor_lists[4][0].scalar_type(); + + DISPATCH_FLOAT_HALF_AND_BFLOAT( + p_in_type, 0, "dist_adam_capturable_cuda_kernel", + DISPATCH_FLOAT_HALF_AND_BFLOAT( + g_type, 1, "dist_adam_capturable_cuda_kernel", + DISPATCH_FLOAT_HALF_AND_BFLOAT( + p_out_type, 2, "dist_adam_capturable_cuda_kernel", + multi_tensor_apply<5>( + BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + DistAdamCapturableFunctor(), + grad_scale.data_ptr(), beta1, beta2, + step.data_ptr(), bias_correction, eps, + lr.data_ptr(), (adamMode_t)mode, weight_decay);))); C10_CUDA_CHECK(cudaGetLastError()); } + +void multi_tensor_fused_adam_with_param_remainders_cuda( + int chunk_size, at::Tensor noop_flag, + std::vector> + tensor_lists, // p_in, p_rem, m, v, g, p_out + at::Tensor grad_scale, float lr, float beta1, float beta2, float eps, + int step, int mode, int bias_correction, float weight_decay) { + using namespace at; + + // Expect p_in, p_rem, m, v, g, p_out + size_t tl_sz = tensor_lists.size(); + TORCH_CHECK(tl_sz == 6, "expected tensor lists of size 6"); + const auto g_type = tensor_lists[4][0].scalar_type(); + + float beta1_correction = 1.0f, beta2_correction = 1.0f; + if (bias_correction == 1) { + beta1_correction = 1 - std::pow(beta1, step); + beta2_correction = 1 - std::pow(beta2, step); + } + + DISPATCH_FLOAT_HALF_AND_BFLOAT( + g_type, 0, "dist_adam_with_param_remainders_cuda_kernel", + multi_tensor_apply<6>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + DistAdamWithParamRemaindersFunctor(), + grad_scale.data_ptr(), beta1, beta2, + beta1_correction, beta2_correction, eps, lr, + (adamMode_t)mode, weight_decay);); + C10_CUDA_CHECK(cudaGetLastError()); +} \ No newline at end of file diff --git a/apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp b/apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp index 584b2a0e7..b2431a13b 100644 --- a/apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp +++ b/apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp @@ -30,7 +30,7 @@ void multi_tensor_lamb_update_weights_cuda( PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("multi_tensor_lamb_compute_update_term", &multi_tensor_lamb_compute_update_term_cuda, - "Computes update term for LAMB optimizer"); + "Computes update term for LAMB optimizer", py::call_guard()); m.def("multi_tensor_lamb_update_weights", &multi_tensor_lamb_update_weights_cuda, - "Applies update term for LAMB optimizer"); + "Applies update term for LAMB optimizer", py::call_guard()); } diff --git a/apex/contrib/csrc/peer_memory/peer_memory_cuda.cu b/apex/contrib/csrc/peer_memory/peer_memory_cuda.cu index 97cd84500..188900128 100644 --- a/apex/contrib/csrc/peer_memory/peer_memory_cuda.cu +++ b/apex/contrib/csrc/peer_memory/peer_memory_cuda.cu @@ -5,8 +5,15 @@ #include #include #include + +#ifdef USE_ROCM +#include +#include "rccl/rccl.h" +#else #include #include "nccl.h" +#endif + namespace cg = cooperative_groups; #define CUDACHECK(cmd) do { \ @@ -20,6 +27,13 @@ namespace cg = cooperative_groups; } \ } while(0) +// C++17 removes 'register' storage keyword +#if __cplusplus < 201703L +#define REGISTER register +#else +#define REGISTER +#endif + namespace { /* Basic deleter function for from_blob function. @@ -117,7 +131,20 @@ void tensor_strides(at::Tensor t, bool explicit_nhwc, int& stride_N, int& stride } } -template +template +__device__ void __zero(T* dst) +{ + *dst = T(0); +} + +__device__ void __zero(int4* dst) +{ + int4 v; + v.x = v.y = v.z = v.w = 0; + *dst = v; +} + +template __device__ void strided_copy_kernel( T* dst, const int dst_stride_C, const int dst_stride_H, const int dst_stride_W, const T* src, const int src_stride_C, const int src_stride_H, const int src_stride_W, @@ -131,23 +158,28 @@ __device__ void strided_copy_kernel( { size_t c,h,w; if (is_HWC) { - c = i % NC; w = i / NC; + c = i - w * NC; h = w / NW; - w = w % NW; + w = w - h * NW; } else { - w = i % NW; h = i / NW; + w = i - h * NW; c = h / NH; - h = h % NH; + h = h - c * NH; } size_t dst_off = c*dst_stride_C + h*dst_stride_H + w*dst_stride_W; - size_t src_off = c*src_stride_C + h*src_stride_H + w*src_stride_W; - dst[dst_off] = src[src_off]; + if (zero) { + __zero(dst+dst_off); + } else { + size_t src_off = c*src_stride_C + h*src_stride_H + w*src_stride_W; + dst[dst_off] = src[src_off]; + } } } +template __device__ void checked_signal( volatile int* signal1_flag, volatile int* signal2_flag, const int v1, const int v2, const int v3, const int v4 @@ -159,30 +191,120 @@ __device__ void checked_signal( // flush all writes to global memory __threadfence_system(); // wait for top or bottom neighbor to clear signal - register int r1, r2, r3, r4; - bool top_zeroed=false, btm_zeroed=false, top_done=false, btm_done=false; - do { + REGISTER int r1, r2, r3, r4; + if (!(top_zero || btm_zero)) { + bool top_zeroed=false, top_done=false; + bool btm_zeroed=false, btm_done=false; do { - if (!top_zeroed) { - asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal1_flag) : "memory"); - if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) top_zeroed = true; + do { + if (!top_zeroed) { +#ifdef USE_ROCM + r1 = __builtin_nontemporal_load(signal1_flag); + r2 = __builtin_nontemporal_load(signal1_flag + 1); + r3 = __builtin_nontemporal_load(signal1_flag + 2); + r4 = __builtin_nontemporal_load(signal1_flag + 3); +#else + asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal1_flag) : "memory"); +#endif + if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) top_zeroed = true; + } + if (!btm_zeroed) { +#ifdef USE_ROCM + r1 = __builtin_nontemporal_load(signal2_flag); + r2 = __builtin_nontemporal_load(signal2_flag + 1); + r3 = __builtin_nontemporal_load(signal2_flag + 2); + r4 = __builtin_nontemporal_load(signal2_flag + 3); +#else + asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal2_flag) : "memory"); +#endif + if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) btm_zeroed = true; + } + } while((top_zeroed == top_done) && (btm_zeroed == btm_done)); + if (!top_done && top_zeroed) { + // signal to top neighbor my output is ready +#ifdef USE_ROCM + __builtin_nontemporal_store(v1, signal1_flag); + __builtin_nontemporal_store(v2, signal1_flag + 1); + __builtin_nontemporal_store(v3, signal1_flag + 2); + __builtin_nontemporal_store(v4, signal1_flag + 3); +#else + asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal1_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory"); +#endif + top_done = true; } - if (!btm_zeroed) { - asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal2_flag) : "memory"); - if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) btm_zeroed = true; + if (!btm_done && btm_zeroed) { + // signal to bottom neighbor my output is ready +#ifdef USE_ROCM + __builtin_nontemporal_store(v1, signal2_flag); + __builtin_nontemporal_store(v2, signal2_flag + 1); + __builtin_nontemporal_store(v3, signal2_flag + 2); + __builtin_nontemporal_store(v4, signal2_flag + 3); +#else + asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal2_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory"); +#endif + btm_done = true; } - } while((top_zeroed == top_done) && (btm_zeroed == btm_done)); - if (!top_done && top_zeroed) { - // signal to top neighbor my output is ready - asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal1_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory"); - top_done = true; - } - if (!btm_done && btm_zeroed) { - // signal to bottom neighbor my output is ready - asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal2_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory"); - btm_done = true; - } - } while (!top_done || !btm_done); + } while (!top_done || !btm_done); + } else if (top_zero) { + bool btm_zeroed=false, btm_done=false; + do { + do { + if (!btm_zeroed) { +#ifdef USE_ROCM + r1 = __builtin_nontemporal_load(signal2_flag); + r2 = __builtin_nontemporal_load(signal2_flag + 1); + r3 = __builtin_nontemporal_load(signal2_flag + 2); + r4 = __builtin_nontemporal_load(signal2_flag + 3); +#else + asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal2_flag) : "memory"); +#endif + if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) btm_zeroed = true; + } + } while(btm_zeroed == btm_done); + if (!btm_done && btm_zeroed) { + // signal to bottom neighbor my output is ready +#ifdef USE_ROCM + __builtin_nontemporal_store(v1, signal2_flag); + __builtin_nontemporal_store(v2, signal2_flag + 1); + __builtin_nontemporal_store(v3, signal2_flag + 2); + __builtin_nontemporal_store(v4, signal2_flag + 3); +#else + asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal2_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory"); +#endif + btm_done = true; + } + } while (!btm_done); + + } else if (btm_zero) { + bool top_zeroed=false, top_done=false; + do { + do { + if (!top_zeroed) { +#ifdef USE_ROCM + r1 = __builtin_nontemporal_load(signal1_flag); + r2 = __builtin_nontemporal_load(signal1_flag + 1); + r3 = __builtin_nontemporal_load(signal1_flag + 2); + r4 = __builtin_nontemporal_load(signal1_flag + 3); +#else + asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal1_flag) : "memory"); +#endif + if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) top_zeroed = true; + } + } while(top_zeroed == top_done); + if (!top_done && top_zeroed) { + // signal to top neighbor my output is ready +#ifdef USE_ROCM + __builtin_nontemporal_store(v1, signal1_flag); + __builtin_nontemporal_store(v2, signal1_flag + 1); + __builtin_nontemporal_store(v3, signal1_flag + 2); + __builtin_nontemporal_store(v4, signal1_flag + 3); +#else + asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal1_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory"); +#endif + top_done = true; + } + } while (!top_done); + } } } @@ -193,10 +315,17 @@ __device__ void wait_for( { bool is_main_thread = (blockIdx.x == 0 && threadIdx.x == 0) ? true : false; if (is_main_thread) { - register int r1, r2, r3, r4; + REGISTER int r1, r2, r3, r4; // wait for senders to signal their output is read do { +#ifdef USE_ROCM + r1 = __builtin_nontemporal_load(wait_flag); + r2 = __builtin_nontemporal_load(wait_flag + 1); + r3 = __builtin_nontemporal_load(wait_flag + 2); + r4 = __builtin_nontemporal_load(wait_flag + 3); +#else asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(wait_flag) : "memory"); +#endif } while (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4); } cg::this_grid().sync(); // all threads wait for main @@ -210,14 +339,21 @@ __device__ void clear_flag( cg::this_grid().sync(); // wait for all threads in kernel to finish bool is_main_thread = (blockIdx.x == 0 && threadIdx.x == 0) ? true : false; if (is_main_thread) { - register int r1, r2, r3, r4; + REGISTER int r1, r2, r3, r4; r1 = 0; r2 = 0; r3 = 0; r4 = 0; +#ifdef USE_ROCM + __builtin_nontemporal_store(r1, wait_flag); + __builtin_nontemporal_store(r2, wait_flag + 1); + __builtin_nontemporal_store(r3, wait_flag + 2); + __builtin_nontemporal_store(r4, wait_flag + 3); +#else asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(wait_flag), "r"(r1), "r"(r2), "r"(r3), "r"(r4) : "memory"); +#endif } } -template -#if __CUDA_ARCH__ >= 700 +template +#if __CUDA_ARCH__ == 700 || __CUDA_ARCH__ == 800 || __CUDA_ARCH__ == 900 __launch_bounds__(128, 16) #endif __global__ void push_pull_halos_1d_kernel( @@ -241,20 +377,34 @@ __global__ void push_pull_halos_1d_kernel( ) { // push top output halo to transfer buffer - strided_copy_kernel(tox, tox_stride_C, tox_stride_H, tox_stride_W, toh, toh_stride_C, toh_stride_H, toh_stride_W, NC, NH, NW); + if (!top_zero) strided_copy_kernel(tox, tox_stride_C, tox_stride_H, tox_stride_W, toh, toh_stride_C, toh_stride_H, toh_stride_W, NC, NH, NW); // push btm output halo to transfer buffer - strided_copy_kernel(box, box_stride_C, box_stride_H, box_stride_W, boh, boh_stride_C, boh_stride_H, boh_stride_W, NC, NH, NW); + if (!btm_zero) strided_copy_kernel(box, box_stride_C, box_stride_H, box_stride_W, boh, boh_stride_C, boh_stride_H, boh_stride_W, NC, NH, NW); // signal to top and btm neigbhbors that output halos are ready to be read // the choice of values for v1-v4 is arbitrary and does not matter, as long as all ranks use the same values - checked_signal(signal1_flag, signal2_flag, -987751720, 840868300, -225529332, 281513358); + if (!(top_zero || btm_zero)) { + checked_signal(signal1_flag, signal2_flag, -987751720, 840868300, -225529332, 281513358); + } else if (top_zero) { + checked_signal(signal1_flag, signal2_flag, -987751720, 840868300, -225529332, 281513358); + } else if (btm_zero) { + checked_signal(signal1_flag, signal2_flag, -987751720, 840868300, -225529332, 281513358); + } // pull top halo from transfer buffer in peer memory to input - wait_for(wait1_flag, -987751720, 840868300, -225529332, 281513358); - strided_copy_kernel(tih, tih_stride_C, tih_stride_H, tih_stride_W, tix, tix_stride_C, tix_stride_H, tix_stride_W, NC, NH, NW); - clear_flag(wait1_flag); + if (top_zero) { + strided_copy_kernel(tih, tih_stride_C, tih_stride_H, tih_stride_W, tix, tix_stride_C, tix_stride_H, tix_stride_W, NC, NH, NW); + } else { + wait_for(wait1_flag, -987751720, 840868300, -225529332, 281513358); + strided_copy_kernel(tih, tih_stride_C, tih_stride_H, tih_stride_W, tix, tix_stride_C, tix_stride_H, tix_stride_W, NC, NH, NW); + clear_flag(wait1_flag); + } // pull btm halo from transfer buffer in peer memory to input - wait_for(wait2_flag, -987751720, 840868300, -225529332, 281513358); - strided_copy_kernel(bih, bih_stride_C, bih_stride_H, bih_stride_W, bix, bix_stride_C, bix_stride_H, bix_stride_W, NC, NH, NW); - clear_flag(wait2_flag); + if (btm_zero) { + strided_copy_kernel(bih, bih_stride_C, bih_stride_H, bih_stride_W, bix, bix_stride_C, bix_stride_H, bix_stride_W, NC, NH, NW); + } else { + wait_for(wait2_flag, -987751720, 840868300, -225529332, 281513358); + strided_copy_kernel(bih, bih_stride_C, bih_stride_H, bih_stride_W, bix, bix_stride_C, bix_stride_H, bix_stride_W, NC, NH, NW); + clear_flag(wait2_flag); + } } __global__ void delay_kernel(int delay_nanoseconds, int* counter) @@ -343,10 +493,12 @@ void push_pull_halos_1d( bool diagnostics, bool explicit_nhwc, int numSM, // number of SMs to use + bool top_zero, // true if top halo should be zeroed at::Tensor top_out_halo, // top output halo in sender device memory at::Tensor top_out_tx, // top output transfer buffer in sender peer pool memory at::Tensor top_inp_tx, // top input transfer buffer in top neighbor peer pool memory at::Tensor top_inp_halo, // top input halo in receiver device memory + bool btm_zero, // true if btm halo should be zeroed at::Tensor btm_out_halo, // btm output halo in sender device memory at::Tensor btm_out_tx, // btm output transfer buffer in sender peer pool memory at::Tensor btm_inp_tx, // btm input transfer buffer in btm neighbor peer pool memory @@ -368,6 +520,7 @@ void push_pull_halos_1d( TORCH_CHECK(top_signal.is_cuda()); TORCH_CHECK(btm_signal.is_cuda()); TORCH_CHECK(waits.is_cuda()); + TORCH_CHECK(!(top_zero && btm_zero)); // shapes and strides int toh_N, toh_C, toh_H, toh_W; @@ -492,10 +645,34 @@ void push_pull_halos_1d( &NC, &NH, &NW, &top_signal_p, &btm_signal_p, &top_wait_p, &btm_wait_p }; - int numBlocksPerSm; - cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel, numThreads, 0); - dim3 grid(numSM*numBlocksPerSm,1,1); - cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel, grid, block, kernelArgs, 0, current_stream); + if (top_zero) { + int numBlocksPerSm; + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel, numThreads, 0); + dim3 grid(numSM*numBlocksPerSm,1,1); +#ifdef USE_ROCM + hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel, grid, block, kernelArgs, 0, current_stream); +#else + cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel, grid, block, kernelArgs, 0, current_stream); +#endif + } else if (btm_zero) { + int numBlocksPerSm; + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel, numThreads, 0); + dim3 grid(numSM*numBlocksPerSm,1,1); +#ifdef USE_ROCM + hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel, grid, block, kernelArgs, 0, current_stream); +#else + cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel, grid, block, kernelArgs, 0, current_stream); +#endif + } else { + int numBlocksPerSm; + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel, numThreads, 0); + dim3 grid(numSM*numBlocksPerSm,1,1); +#ifdef USE_ROCM + hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel, grid, block, kernelArgs, 0, current_stream); +#else + cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel, grid, block, kernelArgs, 0, current_stream); +#endif + } } else { // cannot do int4 transfers if (diagnostics) printf("CAN NOT DO INT4\n"); @@ -513,13 +690,57 @@ void push_pull_halos_1d( }; int numBlocksPerSm; if (is_nhwc) { - cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel, numThreads, 0); - dim3 grid(numSM*numBlocksPerSm,1,1); - cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel, grid, block, kernelArgs, 0, current_stream); + if (top_zero) { + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel, numThreads, 0); + dim3 grid(numSM*numBlocksPerSm,1,1); +#ifdef USE_ROCM + hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel, grid, block, kernelArgs, 0, current_stream); +#else + cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel, grid, block, kernelArgs, 0, current_stream); +#endif + } else if (btm_zero) { + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel, numThreads, 0); + dim3 grid(numSM*numBlocksPerSm,1,1); +#ifdef USE_ROCM + hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel, grid, block, kernelArgs, 0, current_stream); +#else + cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel, grid, block, kernelArgs, 0, current_stream); +#endif + } else { + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel, numThreads, 0); + dim3 grid(numSM*numBlocksPerSm,1,1); +#ifdef USE_ROCM + hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel, grid, block, kernelArgs, 0, current_stream); +#else + cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel, grid, block, kernelArgs, 0, current_stream); +#endif + } } else { - cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel, numThreads, 0); - dim3 grid(numSM*numBlocksPerSm,1,1); - cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel, grid, block, kernelArgs, 0, current_stream); + if (top_zero) { + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel, numThreads, 0); + dim3 grid(numSM*numBlocksPerSm,1,1); +#ifdef USE_ROCM + hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel, grid, block, kernelArgs, 0, current_stream); +#else + cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel, grid, block, kernelArgs, 0, current_stream); +#endif + } else if (btm_zero) { + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel, numThreads, 0); + dim3 grid(numSM*numBlocksPerSm,1,1); +#ifdef USE_ROCM + hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel, grid, block, kernelArgs, 0, current_stream); +#else + cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel, grid, block, kernelArgs, 0, current_stream); +#endif + } else { + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel, numThreads, 0); + dim3 grid(numSM*numBlocksPerSm,1,1); +#ifdef USE_ROCM + hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel, grid, block, kernelArgs, 0, current_stream); +#else + cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel, grid, block, kernelArgs, 0, current_stream); +#endif + } } } } ); diff --git a/apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh b/apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh index 5c79af90e..4f0169f3d 100644 --- a/apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh +++ b/apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh @@ -32,10 +32,12 @@ namespace apex { namespace contrib { namespace peer_memory { bool diagnostics, bool explicit_nhwc, int numSM, // number of SMs to use + bool top_zero, // true if top halo should be zeroed at::Tensor top_out_halo, // top output halo in sender device memory at::Tensor top_out_tx, // top output transfer buffer in sender peer pool memory at::Tensor top_inp_tx, // top input transfer buffer in top neighbor peer pool memory at::Tensor top_inp_halo, // top input halo in receiver device memory + bool btm_zero, // true if btm halo should be zeroed at::Tensor btm_out_halo, // btm output halo in sender device memory at::Tensor btm_out_tx, // btm output transfer buffer in sender peer pool memory at::Tensor btm_inp_tx, // btm input transfer buffer in btm neighbor peer pool memory diff --git a/apex/contrib/csrc/transducer/transducer_joint_kernel.cu b/apex/contrib/csrc/transducer/transducer_joint_kernel.cu index 1e6a465de..7c1c7c291 100755 --- a/apex/contrib/csrc/transducer/transducer_joint_kernel.cu +++ b/apex/contrib/csrc/transducer/transducer_joint_kernel.cu @@ -1,3 +1,4 @@ +#include #include #include #include @@ -17,12 +18,18 @@ #include "philox.cuh" +#ifdef USE_ROCM +#define SHFL_DOWN(val, laneMask, width) __shfl_down(val, laneMask, width) +#else +#define SHFL_DOWN(val, laneMask, width) __shfl_down_sync(0xffffffff, val, laneMask, width) +#endif + // Warp reduce kernels to reduce N groups of data into N numbers, where N = warpSize / width. // width should be a power of 2 and should be less than warpSize. template __device__ __forceinline__ scalar_t warpReduce(scalar_t x, int width=C10_WARP_SIZE){ for (unsigned offset = width/2; offset > 0; offset /= 2){ - x += __shfl_down_sync(0xffffffff, x, offset, width); + x += SHFL_DOWN(x, offset, width); } return x; } @@ -723,8 +730,8 @@ std::vector transducer_joint_cuda_forward( TORCH_CHECK(opt == 0 or opt == 1, "Got an invalid optimization level ", opt); // Simple heuristics - const int numThread = std::min(128, (static_cast(hiddenSize)+C10_WARP_SIZE-1) - / C10_WARP_SIZE * C10_WARP_SIZE); + const int numThread = std::min(128, (static_cast(hiddenSize)+at::cuda::warp_size()-1) + / at::cuda::warp_size() * at::cuda::warp_size()); if (opt == 0){ // vanilla kernel @@ -856,7 +863,7 @@ std::vector transducer_joint_cuda_backward( const int hiddenSize = grad.size(-1); const auto deviceProperties = at::cuda::getCurrentDeviceProperties(); - const int maxNumWarp = deviceProperties->maxThreadsPerBlock / C10_WARP_SIZE; + const int maxNumWarp = deviceProperties->maxThreadsPerBlock / at::cuda::warp_size(); torch::Tensor fGrad = torch::empty({batchSize, maxFLen, hiddenSize}, tensorOpt); torch::Tensor gGrad = torch::empty({batchSize, maxGLen, hiddenSize}, tensorOpt); @@ -864,7 +871,7 @@ std::vector transducer_joint_cuda_backward( int64_t *batchOffsetPtr = (!packOutput) ? nullptr : batchOffset.data_ptr(); // The number "y" I would like each thread to work on - const int workPerThread = 32; + const int workPerThread = 32; // Since the bwd for f and g have the same thread block size, we need to use the max of the two. int numWarp = largestPowerOfTwo((std::max(maxFLen, maxGLen) + workPerThread-1) / workPerThread); // Would like to have at least 2 warps @@ -874,8 +881,8 @@ std::vector transducer_joint_cuda_backward( // Need smem for transposing the partial sum. The partial sum is in a matrix of the shape // numWarp x warpSize - const int smemSize = numWarp * C10_WARP_SIZE; - const dim3 threads(C10_WARP_SIZE, numWarp, 1); + const int smemSize = numWarp * at::cuda::warp_size(); + const dim3 threads(at::cuda::warp_size(), numWarp, 1); AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_joint_cuda_backward_kernel", ([&] { auto gradPtr = grad.data_ptr(); @@ -899,7 +906,7 @@ std::vector transducer_joint_cuda_backward( if (vectFactor > 1 and hiddenSize%vectFactor == 0 and memAlign){ // If vectorization helps and the alignment requirement is met, use the vectorized // kernel. For simplicity, hiddenSize needs to be a multiple vecFactor. - const dim3 blocks( (hiddenSize+C10_WARP_SIZE*vectFactor-1)/(C10_WARP_SIZE*vectFactor), + const dim3 blocks( (hiddenSize+at::cuda::warp_size()*vectFactor-1)/(at::cuda::warp_size()*vectFactor), maxFLen+maxGLen, batchSize); if (masked){ @@ -938,7 +945,7 @@ std::vector transducer_joint_cuda_backward( } } else{ - const dim3 blocks((hiddenSize+C10_WARP_SIZE-1)/C10_WARP_SIZE, + const dim3 blocks((hiddenSize+at::cuda::warp_size()-1)/at::cuda::warp_size(), maxFLen + maxGLen, batchSize); if (masked){ transducer_joint_combined_backward diff --git a/apex/contrib/csrc/transducer/transducer_loss.cpp b/apex/contrib/csrc/transducer/transducer_loss.cpp index f63a67f1e..91c956239 100644 --- a/apex/contrib/csrc/transducer/transducer_loss.cpp +++ b/apex/contrib/csrc/transducer/transducer_loss.cpp @@ -104,6 +104,6 @@ torch::Tensor transducer_loss_backward( } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &transducer_loss_forward, "transducer loss forward (CUDA)"); - m.def("backward", &transducer_loss_backward, "transducer loss backward (CUDA)"); + m.def("forward", &transducer_loss_forward, "transducer loss forward (CUDA)", py::call_guard()); + m.def("backward", &transducer_loss_backward, "transducer loss backward (CUDA)", py::call_guard()); } diff --git a/apex/contrib/csrc/xentropy/xentropy_kernel.cu b/apex/contrib/csrc/xentropy/xentropy_kernel.cu index 7f4190dd1..4c9f1c4ed 100644 --- a/apex/contrib/csrc/xentropy/xentropy_kernel.cu +++ b/apex/contrib/csrc/xentropy/xentropy_kernel.cu @@ -72,6 +72,7 @@ */ #include #include +#include #include #include @@ -81,6 +82,12 @@ #define ALIGN_BYTES 16 +#ifdef USE_ROCM +#define SYNCWARP(mask) +#else +#define SYNCWARP(mask) __syncwarp(mask) +#endif + using Tensor = at::Tensor; using TensorList = at::TensorList; using ScalarType = at::ScalarType; @@ -122,7 +129,7 @@ inline dim3 SoftMax_getBlockSize(int ILP, uint64_t dim_size) { uint64_t max_block_size = std::min(dim_size / ILP, static_cast(max_threads)); while (block_size < (max_block_size/2)) block_size *= 2; // Launch at least a single warp - the kernel assumes that. - block_size = std::max(block_size, static_cast(32)); + block_size = std::max(block_size, static_cast(at::cuda::warp_size())); return dim3(block_size); } @@ -191,15 +198,15 @@ blockReduce(AccumT* smem, AccumT val, AccumT warpVal = defaultVal; // First warp will perform per-warp reductions for the remaining warps - uint32_t mask = (((uint64_t)1) << (blockDim.x / 32)) - 1; - if (threadIdx.x < 32) { - int lane = threadIdx.x % 32; - if (lane < blockDim.x / 32) { + uint32_t mask = (((uint64_t)1) << (blockDim.x / C10_WARP_SIZE)) - 1; + if (threadIdx.x < C10_WARP_SIZE) { + int lane = threadIdx.x % C10_WARP_SIZE; + if (lane < blockDim.x / C10_WARP_SIZE) { #pragma unroll - for (int i = 0; i < 32; ++i) { - warpVal = r(warpVal, smem[lane * 32 + i]); + for (int i = 0; i < C10_WARP_SIZE; ++i) { + warpVal = r(warpVal, smem[lane * C10_WARP_SIZE + i]); } - __syncwarp(mask); + SYNCWARP(mask); smem[lane] = warpVal; } } @@ -210,7 +217,7 @@ blockReduce(AccumT* smem, AccumT val, AccumT blockVal = defaultVal; if (threadIdx.x == 0) { - for (int i = 0; i < blockDim.x / 32; ++i) { + for (int i = 0; i < blockDim.x / C10_WARP_SIZE; ++i) { blockVal = r(blockVal, smem[i]); } smem[0] = blockVal; @@ -245,16 +252,16 @@ blockReduce(AccumT* smem, AccumT warpVal2 = defaultVal2; // First warp will perform per-warp reductions for the remaining warps - uint32_t mask = (((uint64_t)1) << (blockDim.x / 32)) - 1; - if (threadIdx.x < 32) { - int lane = threadIdx.x % 32; - if (lane < blockDim.x / 32) { + uint32_t mask = (((uint64_t)1) << (blockDim.x / C10_WARP_SIZE)) - 1; + if (threadIdx.x < C10_WARP_SIZE) { + int lane = threadIdx.x % C10_WARP_SIZE; + if (lane < blockDim.x / C10_WARP_SIZE) { #pragma unroll - for (int i = 0; i < 32; ++i) { - warpVal1 = r1(warpVal1, smem[lane * 32 + i]); - warpVal2 = r2(warpVal2, smem[lane * 32 + i + blockDim.x]); + for (int i = 0; i < C10_WARP_SIZE; ++i) { + warpVal1 = r1(warpVal1, smem[lane * C10_WARP_SIZE + i]); + warpVal2 = r2(warpVal2, smem[lane * C10_WARP_SIZE + i + blockDim.x]); } - __syncwarp(mask); + SYNCWARP(mask); smem[lane] = warpVal1; smem[lane + blockDim.x] = warpVal2; } @@ -267,7 +274,7 @@ blockReduce(AccumT* smem, AccumT blockVal2 = defaultVal2; if (threadIdx.x == 0) { - for (int i = 0; i < blockDim.x / 32; ++i) { + for (int i = 0; i < blockDim.x / C10_WARP_SIZE; ++i) { blockVal1 = r1(blockVal1, smem[i]); blockVal2 = r2(blockVal2, smem[i + blockDim.x]); } @@ -574,7 +581,7 @@ std::vector host_softmax_xentropy( const Tensor & labels_, const float smoothing, const bool half_to_float){ - if (half_to_float) AT_ASSERTM(input_.type().scalarType() == ScalarType::Half,"conversion is supported for Half type only"); + if (half_to_float) AT_ASSERTM(input_.type().scalarType() == ScalarType::Half || input_.type().scalarType() == ScalarType::BFloat16,"conversion is supported for Half and BFloat16 type only"); AT_ASSERTM(labels_.type().scalarType() == ScalarType::Long,"Label type should be CUDA Long"); auto input = input_.contiguous(); @@ -605,7 +612,7 @@ std::vector host_softmax_xentropy( dim3 grid(outer_size); using namespace at; - DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "host_softmax_xentropy", + DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(input.scalar_type(), 0, "host_softmax_xentropy", using accscalar_t = at::acc_type; const int ILP = sizeof(float4)/sizeof(scalar_t_0); dim3 block = SoftMax_getBlockSize(ILP, dim_size); @@ -673,7 +680,7 @@ Tensor host_softmax_xentropy_backward( dim3 grid(outer_size); - DISPATCH_FLOAT_AND_HALF(gI.scalar_type(), 0, "host_softmax_xentropy_backward", + DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(gI.scalar_type(), 0, "host_softmax_xentropy_backward", using accscalar_t = acc_type; const int ILP = sizeof(float4)/sizeof(scalar_t_0); dim3 block = SoftMax_getBlockSize(ILP, dim_size); @@ -712,7 +719,7 @@ at::Tensor softmax_xentropy_backward_cuda( const float smoothing) { bool half_to_float = grad_loss.type().scalarType() != logits.type().scalarType(); if (half_to_float) { - AT_ASSERTM((grad_loss.type().scalarType() == ScalarType::Float && logits.type().scalarType() == ScalarType::Half), "expected input and grad types to match, or input to be at::Half and grad to be at::Float"); + AT_ASSERTM((grad_loss.type().scalarType() == ScalarType::Float && (logits.type().scalarType() == ScalarType::Half || logits.type().scalarType() == ScalarType::BFloat16)), "expected input and grad types to match, or input to be at::Half or at::Bfloat16 and grad to be at::Float"); } return host_softmax_xentropy_backward(grad_loss, logits, max_log_sum_exp, labels, smoothing, half_to_float); } diff --git a/apex/contrib/groupbn/batch_norm.py b/apex/contrib/groupbn/batch_norm.py index 17ef196b9..af0b7e9b2 100644 --- a/apex/contrib/groupbn/batch_norm.py +++ b/apex/contrib/groupbn/batch_norm.py @@ -4,11 +4,30 @@ import bnp +def check_if_rocm_pytorch(): + is_rocm_pytorch = False + if torch.__version__ >= '1.5': + from torch.utils.cpp_extension import ROCM_HOME + is_rocm_pytorch = True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False + + return is_rocm_pytorch + +IS_ROCM_PYTORCH = check_if_rocm_pytorch() + +def check_and_convert_channels_last(tensor, torch_channels_last): + if torch_channels_last: + channels_last = tensor.is_contiguous(memory_format = torch.channels_last) + if not channels_last: + tensor = tensor.to(memory_format = torch.channels_last) + return tensor + class bn_NHWC_impl(torch.autograd.Function): @staticmethod - def forward(ctx, x, s, b, rm, riv, mini_m, mini_riv, ret_cta, mom, epsilon, fuse_relu, is_train, bn_group, my_data, pair_data, magic, pair_data2, pair_data3, fwd_occup, fwd_grid_x, bwd_occup, bwd_grid_x, multi_stream): + def forward(ctx, x, s, b, rm, riv, mini_m, mini_riv, ret_cta, mom, epsilon, fuse_relu, is_train, torch_channels_last, bn_group, my_data, pair_data, magic, pair_data2, pair_data3, fwd_occup, fwd_grid_x, bwd_occup, bwd_grid_x, multi_stream): + x = check_and_convert_channels_last(x, torch_channels_last) if is_train: ctx.save_for_backward(x, s, b, rm, riv, mini_m, mini_riv) + ctx.torch_channels_last = torch_channels_last ctx.epsilon = epsilon ctx.momentum = mom ctx.ret_cta = ret_cta @@ -31,6 +50,8 @@ def forward(ctx, x, s, b, rm, riv, mini_m, mini_riv, ret_cta, mom, epsilon, fuse @staticmethod def backward(ctx, grad_y): x, s, b, rm, riv, mini_m, mini_riv = ctx.saved_variables + grad_y = check_and_convert_channels_last(grad_y, ctx.torch_channels_last) + x = check_and_convert_channels_last(x, ctx.torch_channels_last) epsilon = ctx.epsilon mom = ctx.momentum ret_cta = ctx.ret_cta @@ -47,15 +68,26 @@ def backward(ctx, grad_y): dx, dscale, dbias = bnp.bn_bwd_nhwc(x, grad_y, s, b, rm, riv, mini_m, mini_riv, ret_cta, mom, epsilon, fuse_relu, my_data, pair_data, pair_data2, pair_data3, bn_group, magic, bwd_occup, bwd_grid_x, multi_stream) - return dx, dscale, dbias, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None + return dx, dscale, dbias, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None class bn_addrelu_NHWC_impl(torch.autograd.Function): @staticmethod - def forward(ctx, x, z, s, b, rm, riv, mini_m, mini_riv, grid_dim_y, ret_cta, mom, epsilon, is_train, bn_group, my_data, pair_data, magic, pair_data2, pair_data3, fwd_occup, fwd_grid_x, bwd_occup, bwd_grid_x, multi_stream): + def forward(ctx, x, z, s, b, rm, riv, mini_m, mini_riv, grid_dim_y, ret_cta, mom, epsilon, is_train, torch_channels_last, bn_group, my_data, pair_data, magic, pair_data2, pair_data3, fwd_occup, fwd_grid_x, bwd_occup, bwd_grid_x, multi_stream): + x = check_and_convert_channels_last(x, torch_channels_last) + z = check_and_convert_channels_last(z, torch_channels_last) if is_train: - bitmask = torch.cuda.IntTensor(((x.numel()+31)//32) * 2 * grid_dim_y) + if IS_ROCM_PYTORCH: + if torch_channels_last: + nhw = x.shape[0] * x.shape[2] * x.shape[3] + else: + nhw = x.shape[0] * x.shape[1] * x.shape[2] + shape = int(((nhw + 3) & ~3) * 2 * grid_dim_y) + bitmask = torch.cuda.LongTensor(shape) + else: + bitmask = torch.cuda.IntTensor(((x.numel()+31)//32) * 2 * grid_dim_y) ctx.save_for_backward(x, s, b, rm, riv, mini_m, mini_riv, bitmask) + ctx.torch_channels_last = torch_channels_last ctx.epsilon = epsilon ctx.momentum = mom ctx.ret_cta = ret_cta @@ -77,6 +109,8 @@ def forward(ctx, x, z, s, b, rm, riv, mini_m, mini_riv, grid_dim_y, ret_cta, mom @staticmethod def backward(ctx, grad_y): x, s, b, rm, riv, mini_m, mini_riv, bitmask = ctx.saved_variables + grad_y = check_and_convert_channels_last(grad_y, ctx.torch_channels_last) + x = check_and_convert_channels_last(x, ctx.torch_channels_last) epsilon = ctx.epsilon mom = ctx.momentum ret_cta = ctx.ret_cta @@ -92,7 +126,7 @@ def backward(ctx, grad_y): dx, dz, dscale, dbias = bnp.bn_addrelu_bwd_nhwc(x, grad_y, s, b, rm, riv, mini_m, mini_riv, bitmask, ret_cta, mom, epsilon, my_data, pair_data, pair_data2, pair_data3, bn_group, magic, bwd_occup, bwd_grid_x, multi_stream) - return dx, dz, dscale, dbias, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None + return dx, dz, dscale, dbias, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None @@ -100,10 +134,11 @@ def backward(ctx, grad_y): class BatchNorm2d_NHWC(_BatchNorm): # if using BatchNorm2d_NHWC simultaneously with multiple streams set multi_stream to True - def __init__(self, num_features, fuse_relu=False, bn_group=1, max_cta_per_sm=2, cta_launch_margin=12, multi_stream=False): + def __init__(self, num_features, fuse_relu=False, bn_group=1, torch_channels_last=False,max_cta_per_sm=2, cta_launch_margin=12, multi_stream=False): super(BatchNorm2d_NHWC, self).__init__(num_features) self.fuse_relu = fuse_relu + self.torch_channels_last = torch_channels_last self.multi_stream = multi_stream self.minibatch_mean = torch.cuda.FloatTensor(num_features) @@ -201,7 +236,7 @@ def forward(self, x, z=None): self.running_mean, self.running_var, self.minibatch_mean, self.minibatch_riv, self.grid_dim_y, self.ret_cta, self.momentum, - self.eps, self.training, self.bn_group, self.my_data, self.pair_data, (self.magic), self.pair_data2, self.pair_data3, + self.eps, self.training, self.torch_channels_last, self.bn_group, self.my_data, self.pair_data, (self.magic), self.pair_data2, self.pair_data3, self.addrelu_fwd_occupancy, self.addrelu_fwd_grid_dim_x, self.addrelu_bwd_occupancy, self.addrelu_bwd_grid_dim_x, self.multi_stream) @@ -211,7 +246,7 @@ def forward(self, x, z=None): self.running_mean, self.running_var, self.minibatch_mean, self.minibatch_riv, self.ret_cta, self.momentum, - self.eps, self.fuse_relu, self.training, self.bn_group, self.my_data, self.pair_data, (self.magic), self.pair_data2, self.pair_data3, + self.eps, self.fuse_relu, self.training, self.torch_channels_last, self.bn_group, self.my_data, self.pair_data, (self.magic), self.pair_data2, self.pair_data3, self.fwd_occupancy, self.fwd_grid_dim_x, self.bwd_occupancy, self.bwd_grid_dim_x, self.multi_stream) diff --git a/apex/contrib/index_mul_2d/__init__.py b/apex/contrib/index_mul_2d/__init__.py new file mode 100644 index 000000000..edb63d397 --- /dev/null +++ b/apex/contrib/index_mul_2d/__init__.py @@ -0,0 +1 @@ +from .index_mul_2d import index_mul_2d diff --git a/apex/contrib/index_mul_2d/index_mul_2d.py b/apex/contrib/index_mul_2d/index_mul_2d.py new file mode 100644 index 000000000..1d34fe20c --- /dev/null +++ b/apex/contrib/index_mul_2d/index_mul_2d.py @@ -0,0 +1,144 @@ +import torch + +import fused_index_mul_2d + +class IndexMul2d_(torch.autograd.Function): + ''' + Currently only support index in dimension 0 with a 2-dimension tensor. + The shape of indexed in1 must be same with in2. Now this kernel does not support broadcast. + The datatype must be float32 or float16. + ''' + @staticmethod + def forward(ctx, in1: torch.Tensor, in2: torch.Tensor, idx1: torch.Tensor) -> torch.Tensor: + assert in2.size(0) == idx1.size(0) + if ((in1.dtype != torch.float32 and in1.dtype != torch.half) or in2.dtype != in1.dtype): + raise RuntimeError("input1'dtype and input2's dtype must be fp32 or fp16. And input type must be same") + if (in1.dim() != 2 or in2.dim() != 2): + raise RuntimeError("in1 and in2 must be 2-dimension tensor.") + if (idx1.dim() != 1): + raise RuntimeError("idx1 must be 1-dimension tensor.") + + if not in1.is_contiguous(): + in1 = in1.contiguous() + if not in2.is_contiguous(): + in2 = in2.contiguous() + if not idx1.is_contiguous(): + idx1 = idx1.contiguous() + + assert in1.is_contiguous() + assert in2.is_contiguous() + assert idx1.is_contiguous() + + out = torch.empty_like(in2) + + if (in1.dtype == torch.float32): + fused_index_mul_2d.float_forward( + out, + in1, + in2, + idx1) + elif (in1.dtype == torch.half): + fused_index_mul_2d.half_forward( + out, + in1, + in2, + idx1) + + ctx.for_backwards = (in1, in2, idx1) + return out + + @staticmethod + def backward(ctx, grad_out): + + in1, in2, idx1 = ctx.for_backwards + + grad_in1, grad_in2 = index_mul_2d_backward(in1, in2, idx1, grad_out) + + return grad_in1, grad_in2, None + + +class IndexMul2dBackward_(torch.autograd.Function): + @staticmethod + def forward(ctx, in1: torch.Tensor, in2: torch.Tensor, idx1: torch.Tensor, + grad_out: torch.Tensor) -> torch.Tensor: + if not in1.is_contiguous(): + in1 = in1.contiguous() + if not in2.is_contiguous(): + in2 = in2.contiguous() + if not idx1.is_contiguous(): + idx1 = idx1.contiguous() + if not grad_out.is_contiguous(): + grad_out = grad_out.contiguous() + + assert in1.is_contiguous() + assert in2.is_contiguous() + assert idx1.is_contiguous() + assert grad_out.is_contiguous() + + grad_in1 = torch.zeros_like(in1) + grad_in2 = torch.empty_like(in2) + + if (in1.dtype == torch.float32): + fused_index_mul_2d.float_backward( + grad_in1, + grad_in2, + grad_out, + in1, + in2, + idx1) + elif (in1.dtype == torch.half): + fused_index_mul_2d.half_backward( + grad_in1, + grad_in2, + grad_out, + in1, + in2, + idx1) + + ctx.for_backwards = (in1, in2, idx1, grad_out) + return grad_in1, grad_in2 + + @staticmethod + def backward(ctx, grad_grad_in1, grad_grad_in2): + if not grad_grad_in1.is_contiguous(): + grad_grad_in1 = grad_grad_in1.contiguous() + if not grad_grad_in2.is_contiguous(): + grad_grad_in2 = grad_grad_in2.contiguous() + + assert grad_grad_in1.is_contiguous() + assert grad_grad_in2.is_contiguous() + + in1, in2, idx1, grad_out = ctx.for_backwards + + grad_in1 = torch.zeros_like(in1) + grad_in2 = torch.empty_like(in2) + grad_grad_out = torch.empty_like(grad_out) + + if (in1.dtype == torch.float32): + fused_index_mul_2d.float_backward_backward( + grad_grad_out, + grad_in1, + grad_in2, + grad_out, + grad_grad_in1, + grad_grad_in2, + in1, + in2, + idx1) + elif (in1.dtype == torch.half): + fused_index_mul_2d.half_backward_backward( + grad_grad_out, + grad_in1, + grad_in2, + grad_out, + grad_grad_in1, + grad_grad_in2, + in1, + in2, + idx1) + + return grad_in1, grad_in2, None, grad_grad_out + +index_mul_2d = IndexMul2d_.apply +index_mul_2d_backward = IndexMul2dBackward_.apply + diff --git a/apex/contrib/layer_norm/layer_norm.py b/apex/contrib/layer_norm/layer_norm.py index 8a8d26d43..b084b1ace 100644 --- a/apex/contrib/layer_norm/layer_norm.py +++ b/apex/contrib/layer_norm/layer_norm.py @@ -41,8 +41,8 @@ class FastLayerNorm(torch.nn.Module): def __init__(self, hidden_size, eps=1e-5): super().__init__() self.epsilon = eps - self.weight = torch.nn.Parameter(torch.Tensor(hidden_size)) - self.bias = torch.nn.Parameter(torch.Tensor(hidden_size)) + self.weight = torch.nn.Parameter(torch.empty(hidden_size)) + self.bias = torch.nn.Parameter(torch.empty(hidden_size)) self.reset_parameters() def reset_parameters(self): diff --git a/apex/contrib/multihead_attn/encdec_multihead_attn.py b/apex/contrib/multihead_attn/encdec_multihead_attn.py index 1a0deb729..a8691026d 100644 --- a/apex/contrib/multihead_attn/encdec_multihead_attn.py +++ b/apex/contrib/multihead_attn/encdec_multihead_attn.py @@ -10,6 +10,7 @@ from .fast_encdec_multihead_attn_norm_add_func import fast_encdec_attn_norm_add_func from apex.normalization.fused_layer_norm import FusedLayerNorm + @torch.jit.script def jit_dropout_add(x, residual, prob, is_training): # type: (Tensor, Tensor, float, bool) -> Tensor @@ -36,14 +37,14 @@ def __init__(self, embed_dim, num_heads, dropout=0.0, bias=False, include_norm_a self.impl = impl self.scaling = self.head_dim ** -0.5 - self.in_proj_weight_q = Parameter(torch.Tensor(embed_dim, embed_dim)) - self.in_proj_weight_kv = Parameter(torch.Tensor(2 * embed_dim, embed_dim)) - self.out_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim)) + self.in_proj_weight_q = Parameter(torch.empty(embed_dim, embed_dim)) + self.in_proj_weight_kv = Parameter(torch.empty(2 * embed_dim, embed_dim)) + self.out_proj_weight = Parameter(torch.empty(embed_dim, embed_dim)) if self.bias: assert impl != "fast", "ERROR! The Fast implementation does not support biases!" - self.in_proj_bias_q = Parameter(torch.Tensor(embed_dim)) - self.in_proj_bias_kv = Parameter(torch.Tensor(2 * embed_dim)) - self.out_proj_bias = Parameter(torch.Tensor(embed_dim)) + self.in_proj_bias_q = Parameter(torch.empty(embed_dim)) + self.in_proj_bias_kv = Parameter(torch.empty(2 * embed_dim)) + self.out_proj_bias = Parameter(torch.empty(embed_dim)) else: self.register_parameter("in_proj_bias_q", None) self.register_parameter("in_proj_bias_kv", None) @@ -52,8 +53,8 @@ def __init__(self, embed_dim, num_heads, dropout=0.0, bias=False, include_norm_a self.out_proj_bias = None if self.include_norm_add: if impl == "fast": - self.lyr_nrm_gamma_weights = Parameter(torch.Tensor(embed_dim)) - self.lyr_nrm_beta_weights = Parameter(torch.Tensor(embed_dim)) + self.lyr_nrm_gamma_weights = Parameter(torch.empty(embed_dim)) + self.lyr_nrm_beta_weights = Parameter(torch.empty(embed_dim)) self.lyr_nrm = None else: self.register_parameter("lyr_norm_gamma_weights", None) diff --git a/apex/contrib/multihead_attn/encdec_multihead_attn_func.py b/apex/contrib/multihead_attn/encdec_multihead_attn_func.py index cef255ba8..5710e87dd 100644 --- a/apex/contrib/multihead_attn/encdec_multihead_attn_func.py +++ b/apex/contrib/multihead_attn/encdec_multihead_attn_func.py @@ -263,7 +263,8 @@ def backward(ctx, output_grads): dropout_grads = torch._masked_scale(matmul2_dgrad1, dropout_mask, 1.0 / (1.0 - dropout_prob_t[0])) # Softmax Grad (not a publically documented op) - softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1, softmax_results.dtype) + ### softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1, softmax_results) # og + softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1, torch.float32, grad_input=softmax_results) # Matmul1 - DGRAD1 # Input1: (data grads) [seqs*heads, seql_q, seql_k] diff --git a/apex/contrib/multihead_attn/self_multihead_attn.py b/apex/contrib/multihead_attn/self_multihead_attn.py index ceee38c51..2806c4dde 100644 --- a/apex/contrib/multihead_attn/self_multihead_attn.py +++ b/apex/contrib/multihead_attn/self_multihead_attn.py @@ -10,6 +10,7 @@ from .fast_self_multihead_attn_norm_add_func import fast_self_attn_norm_add_func from apex.normalization.fused_layer_norm import FusedLayerNorm + @torch.jit.script def jit_dropout_add(x, residual, prob, is_training): # type: (Tensor, Tensor, float, bool) -> Tensor @@ -53,20 +54,20 @@ def __init__( impl == "fast" and bias ), "additive mask not supported for fast mode without bias" if separate_qkv_params: - self.q_weight = Parameter(torch.Tensor(embed_dim, embed_dim)) - self.k_weight = Parameter(torch.Tensor(embed_dim, embed_dim)) - self.v_weight = Parameter(torch.Tensor(embed_dim, embed_dim)) + self.q_weight = Parameter(torch.empty(embed_dim, embed_dim)) + self.k_weight = Parameter(torch.empty(embed_dim, embed_dim)) + self.v_weight = Parameter(torch.empty(embed_dim, embed_dim)) else: - self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim)) - self.out_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim)) + self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim)) + self.out_proj_weight = Parameter(torch.empty(embed_dim, embed_dim)) if self.bias: if separate_qkv_params: - self.q_bias = Parameter(torch.Tensor(embed_dim)) - self.k_bias = Parameter(torch.Tensor(embed_dim)) - self.v_bias = Parameter(torch.Tensor(embed_dim)) + self.q_bias = Parameter(torch.empty(embed_dim)) + self.k_bias = Parameter(torch.empty(embed_dim)) + self.v_bias = Parameter(torch.empty(embed_dim)) else: - self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim)) - self.out_proj_bias = Parameter(torch.Tensor(embed_dim)) + self.in_proj_bias = Parameter(torch.empty(3 * embed_dim)) + self.out_proj_bias = Parameter(torch.empty(embed_dim)) else: if separate_qkv_params: self.register_parameter("q_bias", None) @@ -82,8 +83,8 @@ def __init__( self.out_proj_bias = None if self.include_norm_add: if impl == "fast": - self.lyr_nrm_gamma_weights = Parameter(torch.Tensor(embed_dim)) - self.lyr_nrm_beta_weights = Parameter(torch.Tensor(embed_dim)) + self.lyr_nrm_gamma_weights = Parameter(torch.empty(embed_dim)) + self.lyr_nrm_beta_weights = Parameter(torch.empty(embed_dim)) self.lyr_nrm = None else: self.register_parameter("lyr_norm_gamma_weights", None) diff --git a/apex/contrib/multihead_attn/self_multihead_attn_func.py b/apex/contrib/multihead_attn/self_multihead_attn_func.py index c27a7203c..f26e70439 100644 --- a/apex/contrib/multihead_attn/self_multihead_attn_func.py +++ b/apex/contrib/multihead_attn/self_multihead_attn_func.py @@ -236,7 +236,8 @@ def backward(ctx, output_grads): dropout_grads = torch._masked_scale(matmul2_dgrad1, dropout_mask, 1.0 / (1.0 - dropout_prob_t[0])) # Softmax Grad (not a publically documented op) - softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1, softmax_results.dtype) + ### softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1, softmax_results) # og + softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1, torch.float32, grad_input=softmax_results) # Matmul1 - DGRAD1 # Input1: (data grads) [seqs*heads, seql_q, seql_k] @@ -301,7 +302,6 @@ def backward(ctx, output_grads): output_bias_grads, None, None, - None, ) diff --git a/apex/contrib/nccl_allocator/__init__.py b/apex/contrib/nccl_allocator/__init__.py new file mode 100644 index 000000000..7a460dc69 --- /dev/null +++ b/apex/contrib/nccl_allocator/__init__.py @@ -0,0 +1 @@ +from .nccl_allocator import * \ No newline at end of file diff --git a/apex/contrib/nccl_allocator/nccl_allocator.py b/apex/contrib/nccl_allocator/nccl_allocator.py new file mode 100644 index 000000000..62fcee756 --- /dev/null +++ b/apex/contrib/nccl_allocator/nccl_allocator.py @@ -0,0 +1,63 @@ +import os +import torch +import _apex_nccl_allocator + +from contextlib import nullcontext + + +__all__ = ["init", "nccl_mem", "create_nccl_mem_pool"] + + +def create_nccl_mem_pool(): + _allocator = _apex_nccl_allocator.get_nccl_allocator() + _pool = torch.cuda.MemPool(_allocator) + return _pool + + +def init() -> None: + os.environ["NCCL_NVLS_ENABLE"] = "1" + os.environ["TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK"] = "0" + + +class nccl_mem: + def __init__(self, pool, enabled = True, device = None, group = None): + self.device = None + self.group = None + self.mem_context = None + self.pool = pool + + if enabled: + if device is None: + self.device = torch.device("cuda", torch.cuda.current_device()) + elif isinstance(device, int): + self.device = torch.device("cuda", device) + elif isinstance(device, str): + assert "cuda" in device, "only cuda devices are supported" + self.device = torch.device(device) + + if group is None: + self.group = torch.distributed.distributed_c10d._get_default_group() + else: + self.group = group + + self.mem_context = torch.cuda.use_mem_pool(self.pool) + else: + self.mem_context = nullcontext() + + def __enter__(self): + self.mem_context.__enter__() + if self.group is not None: + backend = self.group._get_backend(self.device) + try: + backend.deregister_mem_pool(self.pool) + except RuntimeError: + pass + + def __exit__(self, *args): + if self.group is not None: + backend = self.group._get_backend(self.device) + try: + backend.register_mem_pool(self.pool) + except RuntimeError: + pass + self.mem_context.__exit__(*args) \ No newline at end of file diff --git a/apex/contrib/optimizers/distributed_fused_adam.py b/apex/contrib/optimizers/distributed_fused_adam.py index 550068022..65da11218 100644 --- a/apex/contrib/optimizers/distributed_fused_adam.py +++ b/apex/contrib/optimizers/distributed_fused_adam.py @@ -1,26 +1,280 @@ import collections import contextlib +from dataclasses import dataclass import enum -import importlib import inspect import io -import math +import itertools import threading +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, + Union, +) +import warnings import torch -import amp_C +from torch.distributed.distributed_c10d import _get_default_group + +try: + import apex.contrib.nccl_allocator as nccl_allocator +except ImportError: + nccl_allocator = None + from apex.multi_tensor_apply import multi_tensor_applier -from torch.distributed.distributed_c10d import _get_default_group, _get_global_rank +import amp_C +import distributed_adam_cuda + +# Fallback to private functions if using PyTorch <1.13.0 +try: + from torch.distributed.distributed_c10d import get_global_rank +except ImportError: + from torch.distributed.distributed_c10d import _get_global_rank + + get_global_rank = _get_global_rank +try: + from torch.distributed.distributed_c10d import reduce_scatter_tensor +except ImportError: + from torch.distributed.distributed_c10d import _reduce_scatter_base + + reduce_scatter_tensor = _reduce_scatter_base +try: + from torch.distributed.distributed_c10d import all_gather_into_tensor +except ImportError: + from torch.distributed.distributed_c10d import _all_gather_base + + all_gather_into_tensor = _all_gather_base + +# Import context manager to coalesce NCCL calls +# Note: Replace these backward compatibility shims once PyTorch +# exposes a stable public API for coalescing communication. +from torch.distributed.distributed_c10d import _coalescing_manager + +if "device" not in inspect.signature(_coalescing_manager).parameters: + # PyTorch <=1.13.1 does not have device arg + _coalescing_manager_no_device_arg = _coalescing_manager + + @contextlib.contextmanager + def _coalescing_manager(group, device, reqs): + with _coalescing_manager_no_device_arg(group, reqs): + yield + + +if "reqs" in inspect.signature(_coalescing_manager).parameters: + # PyTorch <=2.0.1 handles synchronization externally to coalescing + # manager + _coalescing_manager_with_reqs_arg = _coalescing_manager + + class _CoalescingManager: + def __init__(self): + self.works: List[torch.distributed.Work] = [] + + def append(self, work: torch.distributed.Work) -> None: + if work: + self.works.append(work) + + def wait(self) -> None: + for work in self.works: + work.wait() + + @contextlib.contextmanager + def _coalescing_manager( + group: Optional[torch.distributed.ProcessGroup] = None, + device: Optional[torch.device] = None, + async_ops: bool = False, + ) -> contextlib.AbstractContextManager: + assert device is not None + cm = _CoalescingManager() + with _coalescing_manager_with_reqs_arg( + group, + device, + cm.works, + ): + yield cm + if not async_ops: + cm.wait() + + def _coalescing_manager_append_work( + cm: _CoalescingManager, + work: torch.distributed.Work, + ) -> None: + """Add asynchronous request to coalescing manager""" + cm.append(work) + +else: + # PyTorch >2.0.1 handles synchronization within coalescing + # manager + def _coalescing_manager_append_work( + cm: torch.distributed._CoalescingManager, + work: torch.distributed.Work, + ) -> None: + """Dummy function for backward compatibility + + Coalescing manager already keeps track of asynchronous + communication. + + """ + pass + + +# Import optional CUDA kernels +_FOUND_DEPRECATED_FUSED_ADAM: bool = False +try: + import fused_adam_cuda + + _FOUND_DEPRECATED_FUSED_ADAM = True +except ImportError: + warnings.warn( + "Could not find recommended CUDA kernels when importing " + "`DistributedFusedAdam`. " + "For best performance, Apex should be installed with " + "`--deprecated_fused_adam`." + ) -def _round_to_multiple(number, multiple, round_up=True): + +def _round_to_multiple( + number: int, + multiple: int, + round_up: bool = True, +) -> int: """Assumes arguments are positive integers""" - return (number+multiple-1 if round_up else number) // multiple * multiple + return (number + multiple - 1 if round_up else number) // multiple * multiple + + +def _devices_match(device1: torch.device, device2: torch.device) -> bool: + """Whether two PyTorch devices are equivalent""" + device1 = torch.device(device1) + device2 = torch.device(device2) + if device1.type != device2.type: + return False + if device1.type == "cuda": + index1 = device1.index + index2 = device2.index + if index1 is None: + index1 = torch.cuda.current_device() + if index2 is None: + index2 = torch.cuda.current_device() + if index1 != index2: + return False + return True + + +def _multi_tensor_copy( + buffers_in: List[torch.Tensor], + buffers_out: List[torch.Tensor], + dummy_overflow_buf: Optional[torch.Tensor] = None, +) -> None: + """Copy between corresponding buffers + + Uses fused copy kernel if possible. + """ + + # Group buffers by device and dtype + buffer_groups = collections.defaultdict(list) + for buf_in, buf_out in zip(buffers_in, buffers_out): + if buf_in.data_ptr() == buf_out.data_ptr() or buf_in.numel() == 0: + # Nothing to be done if input and output buffers are same + # or have no entries + continue + if buf_in.dtype == buf_out.dtype: + # Just copy bytes if dtypes are same + buf_in = buf_in.view(torch.uint8) + buf_out = buf_out.view(torch.uint8) + is_cuda = ( + _devices_match(buf_in.device, "cuda") + and _devices_match(buf_out.device, "cuda") + ) + is_contiguous = buf_in.is_contiguous() and buf_out.is_contiguous() + key = ( + buf_in.dtype, + buf_out.dtype, + is_cuda, + is_contiguous, + ) + buffer_groups[key].append((buf_in, buf_out)) + + # Copy each group of buffers + for key, buffers in buffer_groups.items(): + # Check if buffers support fused kernel + dtype_in, dtype_out, is_cuda, is_contiguous = key + supported_dtypes = (torch.float32, torch.float16) + use_fused_kernel = ( + dtype_in in supported_dtypes and dtype_out in supported_dtypes + ) or (dtype_in == torch.uint8 and dtype_out == torch.uint8) + use_fused_kernel = use_fused_kernel and is_cuda and is_contiguous + + # Copy buffers + if use_fused_kernel and _FOUND_DEPRECATED_FUSED_ADAM: + if dummy_overflow_buf is None: + dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device="cuda") + multi_tensor_applier( + fused_adam_cuda.maybe_cast_mt, + dummy_overflow_buf, + list(zip(*buffers)), + ) + else: + # Warning: dummy_overflow_buf was not set in such case + for buf_in, buf_out in buffers: + buf_out.copy_(buf_in) + + +@contextlib.contextmanager +def _disable_pre_forward_hook( + param: torch.nn.Parameter, +) -> contextlib.AbstractContextManager: + """Prevent parameter from calling pre-forward hook""" + hook_is_enabled = getattr( + param, + "_pre_forward_hook_is_enabled", + False, + ) + if hook_is_enabled: + param._pre_forward_hook_is_enabled = False + try: + yield + finally: + if hook_is_enabled: + param._pre_forward_hook_is_enabled = True + + +@torch.no_grad() +def _bf16_rem_to_fp32( + bf16: torch.Tensor, + rem: torch.Tensor, + fp32: torch.Tensor, +) -> None: + """Pack BF16 tensor and 16-bit remainders into FP32 tensor""" + + # Check inputs + assert bf16.size() == rem.size() == fp32.size(), ( + "Tensor dimensions do not match: " + f"bf16={list(bf16.size())}, " + f"rem={list(rem.size())}, " + f"fp32={list(fp32.size())}, " + ) + assert bf16.dtype is torch.bfloat16, f"bf16 buffer has invalid dtype ({bf16.dtype})" + assert rem.dtype is torch.int16, f"rem buffer has invalid dtype ({rem.dtype})" + assert fp32.dtype is torch.float32, f"fp32 buffer has invalid dtype ({fp32.dtype})" + + # Undo bf16 rounding + bf16 = bf16.view(torch.int16) - torch.where(rem < 0, 1, 0) + + # Pack bf16 and remainder into little-endian fp32 + fp32 = fp32.unsqueeze(-1).view(torch.int16) + fp32 = torch.stack((rem, bf16), dim=-1, out=fp32) + class DistributedFusedAdam(torch.optim.Optimizer): - """AdamW optimizer with ZeRO algorithm. + """Adam optimizer with ZeRO algorithm. Currently GPU-only. Requires Apex to be installed via - ``python setup.py install --cuda_ext --cpp_ext``. + ``python setup.py install --cuda_ext --cpp_ext --distributed_adam --deprecated_fused_adam``. This implements the ZeRO-2 algorithm, which distributes the optimizer state and gradients between parallel processes. In @@ -38,11 +292,16 @@ class DistributedFusedAdam(torch.optim.Optimizer): params (iterable): iterable of parameters to optimize or dicts defining parameter groups. lr (float, optional): learning rate. (default: 1e-3) + bias_correction (bool, optional): apply correction factor to + moment estimates. (default: True) betas (Tuple[float, float], optional): coefficients used for computing running averages of gradient and its square. (default: (0.9, 0.999)) eps (float, optional): term added to the denominator to improve numerical stability. (default: 1e-8) + adam_w_mode (boolean, optional): Decouple weight decay + regularization (also known as AdamW algorithm) (default: + True) weight_decay (float, optional): weight decay (L2 penalty) (default: 0) amsgrad (boolean, optional): whether to use the AMSGrad @@ -72,19 +331,50 @@ class DistributedFusedAdam(torch.optim.Optimizer): average_grad_sync (bool, optional): whether to use average reduction for gradient synchronization rather than sum (default: True) - overlap_grad_sync(boolean, optional): whether to overlap + overlap_grad_sync (boolean, optional): whether to overlap gradient synchronization with backward pass compute (default: True) + overlap_param_sync (boolean, optional): whether to overlap + parameter synchronization with forward pass compute + (default: False). This is an experimental feature. bucket_cap_mb (float, optional): bucket size in megabytes (default: 100) - pipeline_size (int, optional): number of buckets to - synchronize simultaneously (default: 2) + pipeline_size (int, optional): number of buckets to process + simultaneously in optimizer step (default: 2) + contiguous_param_buffer (bool, optional): convert parameters + into views into large persistent buffers (default: False). + This enables some performance optimizations (e.g. avoiding + some memory copies), but may add memory overhead (e.g. if + the memory allocator can't reuse the original parameter + buffers). contiguous_grad_buffer (bool, optional): allocate gradient - buckets out of a large persistent buffer (default: False). - This allows individual parameter gradients to be accessed - externally (see grad_buffer_view function). It also - maximizes memory usage and may prevent overlapping - communication and compute. + buckets out of a large persistent buffers (default: + False). This allows individual parameter gradients to be + accessed externally (see grad_buffer_view function). It + enables some performance optimizations (e.g. avoiding some + memory copies), but prevents some memory optimizations + (e.g. the memory allocator can't reuse buffers for + gradient buckets). + store_params (bool, optional): store a distributed copy of the + parameters as optimizer state (default: True). This may be + desirable if the optimizer dtype has higher precision than + the parameter dtype. + store_param_remainders (bool, optional): if model is BF16 and + optimizer is FP32, store bits required to reconstruct FP32 + params (default: False). This is an experimental feature. + with_scaled_states (bool, optional): apply per-tensor scaling + factors to the optimizer state (default: False). As + discussed in `FP8-LM: Training FP8 Large Language + Models`_, this helps maintain a reasonable dynamic range + even when the state is in a low-precision datatype like + FP16. + nccl_ub (bool, optional): enable NCCL user buffers for zero-copy + (default: False). It allows the collectives to use only 1 SM + when IB SHARP is enabled in a one-rank-per-node communication + group. This will help speedup the gemms overlapped with data- + parallel communications. + capturable (bool, optional): whether to use the version of the + optimizer that can be used with CUDA Graphs. (default: False). .. _Adam\: A Method for Stochastic Optimization: https://arxiv.org/abs/1412.6980 @@ -93,9 +383,12 @@ class DistributedFusedAdam(torch.optim.Optimizer): .. _Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101 .. _ZeRO\: Memory Optimizations Toward Training Trillion Parameter Models: https://arxiv.org/abs/1910.02054 + .. _FP8-LM\: Training FP8 Large Language Models: + https://arxiv.org/pdf/2310.18313v2.pdf """ + @dataclass class ParameterFragment: """Buffer ranges for a parameter fragment @@ -103,51 +396,99 @@ class ParameterFragment: parameter bucket. """ - def __init__( - self, - param_group_id, - param_id, - bucket_id, - param_range, - bucket_range, - in_local_shard, - shard_range, - shard_bucket_range, - shard_param_range, - ): - # Parameter group index - self.param_group_id = param_group_id - # Parameter index within parameter group - self.param_id = param_id - # Bucket index - self.bucket_id = bucket_id - # Range within flattened parameter buffer - self.param_range = param_range - # Range within bucket - self.bucket_range = bucket_range - # Whether fragment is in local shard of bucket - self.in_local_shard = in_local_shard - # Range within local shard - self.shard_range = shard_range - # Range of local fragment shard within bucket - self.shard_bucket_range = shard_bucket_range - # Range of local fragment shard within parameter - self.shard_param_range = shard_param_range + + # Parameter group index + param_group_id: int + # Parameter index within parameter group + param_id: int + # Bucket index + bucket_id: int + # Range within flattened parameter buffer + param_range: Tuple[int, int] + # Range within bucket + bucket_range: Tuple[int, int] + # Whether fragment is in local shard of bucket + in_local_shard: bool + # Range within local shard + shard_range: Optional[Tuple[int, int]] + # Range of local fragment shard within bucket + shard_bucket_range: Optional[Tuple[int, int]] + # Range of local fragment shard within parameter + shard_param_range: Optional[Tuple[int, int]] class StateBucket: - def __init__(self, shard_size, dtype, device): - """Optimizer state for a bucket""" + """Optimizer state for a bucket""" + + def __init__( + self, + bucket_size: int, + shard_size: int, + dtype: torch.dtype, + device: torch.device, + grad_sync_dtype: torch.dtype, + param_sync_dtype: torch.dtype, + contiguous_buffer_offset: int = 0, + store_params: bool = False, + store_param_remainders: bool = False, + ): + # Size of parameter bucket + self.bucket_size: int = bucket_size + # Size of local shard of parameter bucket + self.shard_size: int = shard_size + # Data type for state + self.dtype = dtype + # Data type for gradient synchronization + self.grad_sync_dtype = grad_sync_dtype + # Data type for parameter synchronization + self.param_sync_dtype = param_sync_dtype + # Size of the filled region in the bucket + self.filled_size: int = 0 + # Is it able to continue filling + self.able_to_fill: bool = True + # Offset to bucket in contiguous buffers + self.contiguous_buffer_offset: int = contiguous_buffer_offset # Buffer ranges corresponding to parameter fragments - self.fragments = [] + self.fragments: List[ParameterFragment] = [] # Local shard of parameters - self.params_shard = torch.zeros([shard_size], dtype=dtype, device=device) + self.params_shard: Optional[torch.Tensor] = None + if store_params: + self.params_shard = torch.zeros( + [shard_size], + dtype=self.dtype, + device=device, + ) + # Local shard of parameter remainders + self.param_remainders_shard: Optional[torch.Tensor] = None + if store_param_remainders: + self.param_remainders_shard = torch.zeros( + [shard_size], + dtype=torch.int16, + device=device, + ) # Local shard of first moment estimate - self.exp_avg_shard = torch.zeros([shard_size], dtype=dtype, device=device) + self.exp_avg_shard: torch.Tensor = torch.zeros( + [shard_size], + dtype=self.dtype, + device=device, + ) # Local shard of second moment estimate - self.exp_avg_sq_shard = torch.zeros([shard_size], dtype=dtype, device=device) + self.exp_avg_sq_shard: torch.Tensor = torch.zeros( + [shard_size], + dtype=self.dtype, + device=device, + ) + + def dtypes(self) -> Tuple[torch.dtype, torch.dtype, torch.dtype]: + """Datatypes for the bucket's compute and communication""" + return ( + self.dtype, + self.grad_sync_dtype, + self.param_sync_dtype, + ) class GradientStatus(enum.Enum): """Status of gradients within a bucket""" + # Gradients are ready to use READY = enum.auto() # Bucket is partially filled with unreduced gradients @@ -159,228 +500,761 @@ class GradientStatus(enum.Enum): class GradientBucket: """Gradient buffers and state for a bucket""" + def __init__(self): # Local shard of gradients - self.grads_shard = None + self.grads_shard: Optional[torch.Tensor] = None # Local contribution to gradients - self.grads_bucket = None + self.grads_bucket: Optional[torch.Tensor] = None # Buffer for gradient reduce-scatter - self.sync_grads_shard = None + self.sync_grads_shard: Optional[torch.Tensor] = None # Status of gradients - self.status = DistributedFusedAdam.GradientStatus.READY - # Request object for asynchronous communication - self.sync_request = None - - def sync_wait(self): - """Wait for asynchronous communication to finish""" - if self.sync_request is not None: - self.sync_request.wait() - self.sync_request = None - - _step_supports_amp_scaling = True - - def __init__(self, - params, - lr=1e-3, - bias_correction=True, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=0., - amsgrad=False, - dtype=torch.float32, - grad_sync_dtype=None, - param_sync_dtype=None, - device='cuda', - process_group=None, - distributed_process_group=None, - redundant_process_group=None, - average_grad_sync=True, - overlap_grad_sync=True, - bucket_cap_mb=100, - pipeline_size=2, - contiguous_grad_buffer=False, + self.status: GradientStatus = DistributedFusedAdam.GradientStatus.READY + # Params that have generated grads + self.grads_generated: Set[torch.nn.Parameter] = set() + + class ParameterStatus(enum.Enum): + """Status of parameters within a bucket""" + + # Parameters are sharded between processes + SHARDED = enum.auto() + # Asynchronous communication is in progress + SYNCING = enum.auto() + # Parameters are ready to use + READY = enum.auto() + + class ParameterBucket: + """Parameter buffers and state for a bucket""" + + def __init__(self): + # Local shard of parameters + self.params_shard: Optional[torch.Tensor] = None + # Gathered parameter values + self.params_bucket: Optional[torch.Tensor] = None + # Status of parameters + self.status: ParameterStatus = DistributedFusedAdam.ParameterStatus.SHARDED + # Params that have been updated + self.params_updated: Set[torch.nn.Parameter] = set() + + # Enable custom logic for AMP grad scaling + _step_supports_amp_scaling: bool = True + _custom_amp_unscale_grads: bool = True + + def __init__( + self, + params: Union[Iterable[torch.nn.Parameter], Iterable[dict]], + lr: float = 1e-3, + bias_correction: bool = True, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + adam_w_mode: bool = True, + weight_decay: float = 0.0, + amsgrad: bool = False, + dtype: torch.dtype = torch.float32, + grad_sync_dtype: Optional[torch.dtype] = None, + param_sync_dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = "cuda", + process_group: Optional[torch.distributed.ProcessGroup] = None, + distributed_process_group: Optional[torch.distributed.ProcessGroup] = None, + redundant_process_group: Optional[torch.distributed.ProcessGroup] = None, + average_grad_sync: bool = True, + overlap_grad_sync: bool = True, + overlap_param_sync: bool = False, + bucket_cap_mb: float = 100.0, + pipeline_size: int = 2, + contiguous_param_buffer: bool = False, + contiguous_grad_buffer: bool = False, + store_params: bool = True, + store_param_remainders: bool = False, + with_scaled_states: bool = False, + nccl_ub: bool = False, + capturable: bool = False, ): - defaults = dict(lr=lr, bias_correction=bias_correction, - betas=betas, eps=eps, weight_decay=weight_decay) - super(DistributedFusedAdam, self).__init__(params, defaults) + if (with_scaled_states or store_param_remainders) and capturable: + raise Exception(f"{self.__class__.__name__} with scaled states " + "or storing param remainders doesn't support CUDA graph yet.") + + if capturable and not _FOUND_DEPRECATED_FUSED_ADAM: + raise Exception(f"Capturable {self.__class__.__name__} relies on " + "multi_tensor_copy to set dummy_overflow_buf to indicate " + "whether there's gradient Inf/NaN, build APEX with " + "`--deprecated_fused_adam` is essential.") + + if capturable: + raise Exception("Distributed fused adam does not support cudagraph on ROCm") + + # If capturable for CUDA graph + self.capturable: bool = capturable + # If the optimizer is capturable then LR should be a tensor (on GPU) + if capturable: + lr = torch.tensor(lr, dtype=torch.float32, device=device) + + defaults = dict( + lr=lr, + bias_correction=bias_correction, + betas=betas, + eps=eps, + weight_decay=weight_decay, + ) + super().__init__(params, defaults) # Adam options + self.adam_w_mode: bool = adam_w_mode + self.amsgrad: bool = amsgrad if amsgrad: - raise RuntimeError('DistributedFusedAdam does not support the AMSGrad variant.') + raise RuntimeError( + "DistributedFusedAdam does not support the AMSGrad variant." + ) # Datatype options if grad_sync_dtype is None: grad_sync_dtype = dtype if param_sync_dtype is None: param_sync_dtype = dtype - supported_dtypes = [ - (torch.float32, torch.float16), - (torch.float32, torch.float32), - ] - if (dtype, grad_sync_dtype) not in supported_dtypes: + supported_dtypes = (torch.float32, torch.float16, torch.bfloat16) + if ( + dtype not in supported_dtypes + or grad_sync_dtype not in supported_dtypes + ): + raise ValueError( + "Unsupported dtypes for DistributedFusedAdam " + f"(dtype={dtype}, " + f"grad_sync_dtype={grad_sync_dtype}, " + f"param_sync_dtype={param_sync_dtype}))" + ) + self.dtype: torch.dtype = dtype + self.grad_sync_dtype: torch.dtype = grad_sync_dtype + self.param_sync_dtype: torch.dtype = param_sync_dtype + + # Device options + if not _devices_match(device, "cuda"): raise RuntimeError( - 'Invalid dtypes for DistributedFusedAdam ' - f'(dtype={dtype}, ' - f'grad_sync_dtype={grad_sync_dtype}, ' - f'param_sync_dtype={param_sync_dtype}))') - if device != 'cuda': - raise RuntimeError('DistributedFusedAdam only supports GPU') - self.dtype = dtype - self.grad_sync_dtype = grad_sync_dtype - self.param_sync_dtype = param_sync_dtype - self.device = device + "Invalid device for DistributedFusedAdam " f"(device={device})" + ) + self.device: torch.device = torch.device("cuda", torch.cuda.current_device()) # Process groups - self.process_group = ( - _get_default_group() - if process_group is None - else process_group + self.process_group: torch.distributed.ProcessGroup = ( + _get_default_group() if process_group is None else process_group ) - self.distributed_process_group = ( + self.distributed_process_group: torch.distributed.ProcessGroup = ( self.process_group if distributed_process_group is None else distributed_process_group ) - self.redundant_process_group = redundant_process_group - self.process_group_size = torch.distributed.get_world_size(self.process_group) - self.distributed_rank = torch.distributed.get_rank(self.distributed_process_group) - self.distributed_size = torch.distributed.get_world_size(self.distributed_process_group) - self.redundant_size = ( + self.redundant_process_group: Optional[ + torch.distributed.ProcessGroup + ] = redundant_process_group + self.process_group_size: int = torch.distributed.get_world_size( + self.process_group + ) + self.distributed_rank: int = torch.distributed.get_rank( + self.distributed_process_group + ) + self.distributed_size: int = torch.distributed.get_world_size( + self.distributed_process_group + ) + self.redundant_size: int = ( 1 if self.redundant_process_group is None else torch.distributed.get_world_size(self.redundant_process_group) ) if self.process_group_size != self.distributed_size * self.redundant_size: raise RuntimeError( - 'Invalid process group configuration ' - f'(process group size = {self.process_group_size}, ' - f'distributed process group size = {self.distributed_size}, ' - f'redundant process group size = {self.redundant_size})' + "Invalid process group configuration " + f"(process group size = {self.process_group_size}, " + f"distributed process group size = {self.distributed_size}, " + f"redundant process group size = {self.redundant_size})" ) - try: - self._process_group_ranks = [ - _get_global_rank(self.process_group, local_rank) - for local_rank in range(self.distributed_size) - ] - except: - self._process_group_ranks = list(range(self.distributed_size)) + self.process_group_root: int = get_global_rank(self.process_group, 0) # Use average reduction for grad sync - self.average_grad_sync = average_grad_sync + self.average_grad_sync: bool = average_grad_sync # Copy param grads to bucket as soon as available - self.greedy_grad_copy = True - # Synchronize grad buckets as soon as all grads are available - self.overlap_grad_sync = overlap_grad_sync + self.greedy_grad_copy: bool = True + # Synchronize grad buckets as soon as their grads are available + self.overlap_grad_sync: bool = overlap_grad_sync + # Try synchronizing param buckets just before param is needed + self.overlap_param_sync: bool = overlap_param_sync # Number of buckets to synchronize at a time - self.pipeline_size = pipeline_size - # Allocate contiguous buffer for gradients - self.contiguous_grad_buffer = contiguous_grad_buffer + self.pipeline_size: int = pipeline_size + + # Store params or param remainders + if store_param_remainders: + if store_params: + raise RuntimeError( + "Attempted to construct DistributedFusedAdam " + "with store_params=True and store_param_remainders=True" + ) + if self.dtype != torch.float32 or self.param_sync_dtype != torch.bfloat16: + raise RuntimeError( + "DistributedFusedAdam requires " + "BF16 params and FP32 optimizer state " + "when storing parameter remainders " + f"(dtype={self.dtype}, " + f"param_sync_dtype={self.param_sync_dtype}))" + ) + self.store_params: bool = store_params + self.store_param_remainders: bool = store_param_remainders + + # Whether to scale optimizer state + self.with_scaled_states: bool = with_scaled_states + if self.with_scaled_states: + if not self.store_params: + raise RuntimeError( + "Attempted to construct DistributedFusedAdam " + "with with_scaled_state=True and store_params=False" + ) + if self.store_param_remainders: + raise RuntimeError( + "Attempted to construct DistributedFusedAdam " + "with with_scaled_state=True and store_params_remainders=True" + ) + if self.dtype not in (torch.float16, torch.bfloat16): + raise RuntimeError( + "Attempted to construct DistributedFusedAdam " + f"with with_scaled_state=True and dtype={self.dtype} " + "(only fp16 and bf16 are supported)" + ) + if self.param_sync_dtype == torch.float32: + # _local_step_with_scaled_states applies Adam kernel + # to fp32 workspace buffer and relies on + # _check_params_shard_dtypes to copy to param sync + # workspace buffer. However, + # _check_params_shard_dtypes does nothing if + # param_sync_dtype is fp32. + raise RuntimeError( + "Attempted to construct DistributedFusedAdam " + f"with with_scaled_state=True and param_sync_dtype={self.param_sync_dtype}" + ) + # Scaling factors to apply to recover unscaled optimizer state + self._state_scales: dict = {} # Determine bucket sizes dtype_size = torch.finfo(self.grad_sync_dtype).bits // 8 - self.alignment = 128 // dtype_size - bucket_size = 1024*1024*bucket_cap_mb / dtype_size + self.alignment: int = 128 // dtype_size + self.bucket_cap_mb: float = bucket_cap_mb + bucket_size = 1024 * 1024 * bucket_cap_mb / dtype_size shard_size = int(bucket_size / self.distributed_size) shard_size = _round_to_multiple(shard_size, self.alignment, round_up=False) shard_size = max(shard_size, self.alignment) - bucket_size = shard_size * self.distributed_size - self.bucket_size = bucket_size - self.shard_size = shard_size - - # Load CUDA kernels - global fused_adam_cuda, distributed_adam_cuda - fused_adam_cuda = importlib.import_module("fused_adam_cuda") - distributed_adam_cuda = importlib.import_module("distributed_adam_cuda") + self.default_shard_size: int = shard_size # Optimizer state - self.state['buckets'] = [] - self.state['step'] = 0 + self.state["buckets"]: List[StateBucket] = [] + self.state["step"]: torch.Tensor | int = torch.tensor([0], dtype=torch.int, + device=self.device) if self.capturable else 0 - # Objects for gradient synchronization - self._grads_buckets = collections.defaultdict(self.GradientBucket) - self._grads_generated = set() - self._pipeline_streams = [torch.cuda.Stream() for _ in range(self.pipeline_size)] + # Gradient state + self._grads_buckets: Dict[int, GradientBucket] = collections.defaultdict( + self.GradientBucket + ) + # Param state + self._params_buckets: Dict[int, ParameterBucket] = collections.OrderedDict() + + # Whether to allocate contiguous buffers for parameters + self.contiguous_param_buffer: bool = contiguous_param_buffer + # Whether to allocate contiguous buffers for gradients + self.contiguous_grad_buffer: bool = contiguous_grad_buffer + # Whether to use NCCL User Buffer + self.nccl_ub: bool = nccl_ub + # Contiguous buffers for parameters + self._param_buffers: Dict[ + Tuple[torch.dtype, torch.dtype, torch.dtype], torch.Tensor + ] = {} + # Contiguous buffers for gradients + self._grad_buffers: Dict[ + Tuple[torch.dtype, torch.dtype, torch.dtype], torch.Tensor + ] = {} + # Output buffer for gradient shards, only required for NCCL user buffer + if self.nccl_ub: + if not nccl_allocator: + raise RuntimeError("NCCL allocator importing failed but nccl ub is still requested") + elif not self.contiguous_grad_buffer: + raise RuntimeError("NCCL user buffers require contiguous grad buffers") + else: + self._shard_grad_buffers: Dict[ + Tuple[torch.dtype, torch.dtype, torch.dtype], torch.Tensor + ] = {} + + # Side streams for state dict communication + self._pipeline_streams: List[torch.cuda.Stream] = [ + torch.cuda.Stream() for _ in range(self.pipeline_size) + ] + # Side streams for gradients and parameters communication + self._comm_streams: List[torch.cuda.Stream] = [ + torch.cuda.Stream() for _ in range(self.pipeline_size) + ] + self._last_comm_stream_id: int = -1 - # Divide gradients by factor before optimizer step. Used for - # grad clipping and gradient scaler. - self._inv_grad_scale = torch.full([1], 1.0, dtype=self.dtype, device=self.device) + # Scale by factor before optimizer step. Used for grad + # clipping and gradient scaler. + self._grad_scale: torch.Tensor = torch.full( + [], 1.0, dtype=torch.float32, device=self.device + ) # Norm of parameter gradients. Used for gradient clipping and # gradient scaler. - self._grad_norm = None + self._grad_norm: Optional[torch.Tensor] = None + + # Dummy flag for multi-tensor kernels + # Note: Apex multi-tensor kernels have a noop_flag argument + # that is intended to detect non-finite values. It shouldn't + # have any effect with the kernels used in the optimizer, but + # we still set it to zero out of an abundance of caution. + self._dummy_overflow_buf: torch.Tensor = torch.zeros( + [1], dtype=torch.int32, device=self.device + ) # Check if collectives have no_copy option - self._reduce_scatter_no_copy = ( - 'no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args - ) - self._all_gather_no_copy = ( - 'no_copy' in inspect.getfullargspec(torch.distributed.all_gather).args - ) - self._gather_no_copy = ( - 'no_copy' in inspect.getfullargspec(torch.distributed.gather).args + self._gather_no_copy: bool = ( + "no_copy" in inspect.getfullargspec(torch.distributed.gather).args ) + # Make sure parameter values are same across processes + self._broadcast_params() + + # Lock for callbacks + self._lock: threading.Lock = threading.Lock() # Attach hooks for gradient synchronization self._register_post_backward_hooks() + # Attach hooks for param synchronization + if self.overlap_param_sync: + self._register_pre_forward_hooks() + + # Move LR to device + if capturable: + for idx, group in enumerate(self.param_groups): + if len(group['params']) == 0: + continue + for item in ['lr']: + if torch.is_tensor(group[item]): + self.param_groups[idx][item] = group[item].to(device=self.device) + else: + self.param_groups[idx][item] = torch.tensor(group[item], + device=self.device) + + # For better representation string + arg_names = inspect.getfullargspec(DistributedFusedAdam.__init__).args + arg_names.remove('self') + arg_names.remove('params') + for i, group in enumerate(self.param_groups): + for key in sorted(group.keys()): + if key in arg_names: + arg_names.remove(key) + self.args_dict = {name: getattr(self, name) for name in arg_names} + + def __repr__(self) -> str: + # Based on: https://github.com/pytorch/pytorch/blob/v2.3.0-rc12/torch/optim/optimizer.py#L315 + format_string = self.__class__.__name__ + ' (' + for i, group in enumerate(self.param_groups): + format_string += '\n' + format_string += f'Parameter Group {i}\n' + for key in sorted(group.keys()): + if key != 'params': + format_string += f' {key}: {group[key]}\n' + + for key, val in self.args_dict.items(): + if 'process_group' in key and val: + format_string += f'{key}: {hex(id(val))}, world size {val.size()}\n' + else: + format_string += f'{key}: {val}\n' + + format_string += ')' + return format_string + + @torch.no_grad() + def _broadcast_params(self) -> None: + """Broadcast parameter values from root rank""" + process_group = self.process_group + with _coalescing_manager(process_group, self.device, async_ops=True) as cm: + for param_group in self.param_groups: + for param in param_group["params"]: + _coalescing_manager_append_work( + cm, + torch.distributed.broadcast( + param, + src=self.process_group_root, + group=process_group, + async_op=True, + ), + ) + cm.wait() - def _register_post_backward_hooks(self): - """Attach hooks for gradient synchronization + def _make_post_backward_hook( + self, + param: torch.nn.Parameter, + param_group_id: int, + param_id: int, + ) -> Callable: + """Create callback function to call after param generates grad - Optimizer state for parameters are initialized lazily as they - are encountered in the backward pass. + Lazily initialize parameter and try launching grad sync. """ - self._num_grads = 0 - grad_buffer_size = 0 - self._lock = threading.Lock() + + def post_backward_hook(*unused) -> None: + if getattr(param, "_pre_forward_hook_is_enabled", False): + raise RuntimeError( + "A parameter called its post-backward hook " + "before its pre-forward hook. " + "Please manually interact with the parameter " + "before the forward pass (e.g. by calling data_ptr) " + "or run DistributedFusedAdam with overlap_param_sync=False." + ) + with self._lock: + need_to_initialize = "fragments" not in self.state[param] + if need_to_initialize: + self._init_param_state(param, param_group_id, param_id) + if self.greedy_grad_copy: + self._grad_copy(param) + if self.overlap_grad_sync: + self._try_start_bucket_grad_sync( + params=[param], + ignore_last_bucket=need_to_initialize, + ) + + return post_backward_hook + + def _register_post_backward_hooks(self) -> None: + """Attach hooks for gradient synchronization""" self._grad_accs = [] for param_group_id, group in enumerate(self.param_groups): - for param_id, param in enumerate(group['params']): - torch.distributed.broadcast( + for param_id, param in enumerate(group["params"]): + if param.requires_grad: + param_tmp = param.expand_as(param) + grad_acc = param_tmp.grad_fn.next_functions[0][0] + hook = self._make_post_backward_hook( + param, + param_group_id, + param_id, + ) + grad_acc.register_hook(hook) + self._grad_accs.append(grad_acc) + + def _make_pre_forward_hook( + self, + param: torch.nn.Parameter, + param_group_id: int, + param_id: int, + ) -> Callable: + """Create callback function to call before param forward pass + + Make sure param has been synchronized and try launching next + param sync. + + """ + + def pre_forward_hook(*unused) -> None: + with self._lock: + if "fragments" not in self.state[param]: + return + self._param_copy(param) + if self.overlap_param_sync: + self._try_start_bucket_param_sync() + + return pre_forward_hook + + def _register_pre_forward_hooks(self) -> None: + """Attach hooks for parameter synchronization + + If _pre_forward_hook_is_enabled is set in a parameter, then + the callback will be called the first time any of its + attributes are accessed. This is hackily done by + monkey-patching the parameter class, so proceed with caution. + + """ + for param_group_id, group in enumerate(self.param_groups): + for param_id, param in enumerate(group["params"]): + # Monkey-patch parameter class + cls = param.__class__ + if not getattr(cls, "_has_pre_forward_hook", False): + # Monkey-patch magic methods to call __getattribute__ + special_funcs = [ + "__abs__", + "__add__", + "__and__", + "__bool__", + "__complex__", + "__contains__", + "__deepcopy__", + "__delitem__", + "__div__", + "__eq__", + "__float__", + "__floordiv__", + "__ge__", + "__getitem__", + "__gt__", + "__iadd__", + "__iand__", + "__idiv__", + "__ifloordiv__", + "__ilshift__", + "__imod__", + "__imul__", + "__index__", + "__int__", + "__invert__", + "__ior__", + "__ipow__", + "__irshift__", + "__isub__", + "__iter__", + "__itruediv__", + "__ixor__", + "__le__", + "__len__", + "__long__", + "__lshift__", + "__lt__", + "__matmul__", + "__mod__", + "__mul__", + "__neg__", + "__nonzero__", + "__or__", + "__pos__", + "__pow__", + "__radd__", + "__rand__", + "__rdiv__", + "__reduce__", + "__reduce_ex__", + "__reversed__", + "__rfloordiv__", + "__rlshift__", + "__rmatmul__", + "__rmod__", + "__rmul__", + "__ror__", + "__rpow__", + "__rrshift__", + "__rshift__", + "__rsub__", + "__rtruediv__", + "__rxor__", + "__setitem__", + "__sizeof__", + "__sub__", + "__truediv__", + "__xor__", + ] + for func_name in special_funcs: + + def make_augmented_func() -> Callable: + base_func_name = f"_base_{func_name}" + + def augmented_func(self, *args, **kwargs): + return getattr(self, base_func_name)(*args, **kwargs) + + return augmented_func + + setattr(cls, f"_base_{func_name}", getattr(cls, func_name)) + setattr(cls, func_name, make_augmented_func()) + + # Monkey-patch __getattribute__ to call pre-forward hook + def make_getattribute() -> Callable[[str], Any]: + special_attrs = { + "_pre_forward_hook_is_enabled", + "_pre_forward_hook", + "__del__", + "__delattr__", + "__dir__", + "__getattr__", + "__getattribute__", + "__hash__", + "__init__", + "__new__", + "__setattr__", + } + + def getattribute_with_pre_forward_hook(self, name: str): + """Variant of __getattribute__ that can call pre-forward hook""" + if name not in special_attrs: + if getattr(self, "_pre_forward_hook_is_enabled", False): + self._pre_forward_hook_is_enabled = False + self._pre_forward_hook() + return object.__getattribute__(self, name) + + return getattribute_with_pre_forward_hook + + cls.__getattribute__ = make_getattribute() + cls._has_pre_forward_hook = True + + # Register pre-forward callback + param._pre_forward_hook_is_enabled = False + param._pre_forward_hook = self._make_pre_forward_hook( param, - src=self._process_group_ranks[0], - group=self.process_group, + param_group_id, + param_id, ) - if param.requires_grad: - self._num_grads += 1 - - # Callback after gradient is generated - def wrapper(p, p_group_id, p_id): - p_tmp = p.expand_as(p) - grad_acc = p_tmp.grad_fn.next_functions[0][0] - def reduction_hook(*unused): - with self._lock: - if 'fragments' not in self.state[p]: - self._init_param_state(p, p_group_id, p_id) - if self.greedy_grad_copy: - self._grad_copy(p) - if self.overlap_grad_sync: - self._try_start_bucket_grad_sync( - params=[p], - ignore_last_bucket=True, - ) - grad_acc.register_hook(reduction_hook) - self._grad_accs.append(grad_acc) - wrapper(param, param_group_id, param_id) - - # Gradient size, with padding for alignment - grad_size = _round_to_multiple(param.numel(), self.alignment) - grad_buffer_size += grad_size - - # Allocate contiguous gradient buffer if needed - if self.contiguous_grad_buffer: - grad_buffer_size = _round_to_multiple( - grad_buffer_size, - self.bucket_size, + + @torch.no_grad() + def init_param_buffer(self) -> None: + """Allocate contiguous buffers for param buckets + + This converts the parameters into views into contiguous + buffers. This enables some performance optimizations (e.g. + avoiding some memory copies), but may add memory overhead + (e.g. if the memory allocator can't reuse the original + parameter buffers). To minimize memory overhead, this buffer + should be initialized before the first training step. + + """ + + # Make sure all params are initialized + self.contiguous_param_buffer = True + self.init_params() + + # Construct param buffers + buffer_sizes = collections.defaultdict(lambda: 0) + for bucket in self.state["buckets"]: + dtypes = bucket.dtypes() + buffer_sizes[dtypes] = max( + bucket.contiguous_buffer_offset + bucket.bucket_size, + buffer_sizes[dtypes], ) - self._grad_buffer = torch.zeros( - [grad_buffer_size], - dtype=self.dtype, + for dtypes, buffer_size in buffer_sizes.items(): + _, _, param_sync_dtype = dtypes + self._param_buffers[dtypes] = torch.zeros( + [buffer_size], + dtype=param_sync_dtype, device=self.device, ) - def init_params(self, params=None): + # Figure out corresponding positions in params and param buffer + params = list(self.parameters()) + param_flat_views = [] + param_buffer_views = [] + for i, param in enumerate(params): + fragment = self.state[param]["fragments"][0] + bucket_id = fragment.bucket_id + bucket = self.state["buckets"][bucket_id] + param_size = param.numel() + bucket_start, _ = fragment.bucket_range + buffer_offset = bucket.contiguous_buffer_offset + buffer_start = buffer_offset + bucket_start + buffer_end = buffer_start + param_size + param_buffer = self._param_buffers[bucket.dtypes()] + param_buffer_view = param_buffer[buffer_start:buffer_end].detach() + if not _devices_match(param_buffer_view.device, param.device): + raise RuntimeError( + "Attempted to change a parameter with device={param.device} " + f"into a buffer view with device={param_buffer_view.device}" + ) + if param_buffer_view.dtype != param.dtype: + if ( + not torch.is_floating_point(param_buffer_view) + and param_buffer_view.element_size() == param.element_size() + ): + param_buffer_view = param_buffer_view.view(dtype=param.dtype) + else: + raise RuntimeError( + f"Attempted to change a parameter with dtype={param.dtype} " + f"into a buffer view with dtype={param_buffer_view.dtype}" + ) + if param.is_contiguous(memory_format=torch.channels_last): + param = param.permute(0, 2, 3, 1) + param_flat_views.append(param.detach().view(-1)) + param_buffer_views.append(param_buffer_view) + + # Copy values into param buffer + _multi_tensor_copy( + param_flat_views, + param_buffer_views, + dummy_overflow_buf=self._dummy_overflow_buf, + ) + + # Make all params a view into the param buffer + for param, buffer_view in zip(params, param_buffer_views): + # Preserve memory format for param here, i.e. NHWC tensors + # `param.data.set_()` failed to change storage. + # `param.set_()` invalidates bprop hook. + param.data = buffer_view.as_strided(param.size(), param.stride()) + + def _init_grad_buffer(self) -> None: + """Allocate contiguous buffer for grad buckets""" + + # Make sure all params are initialized + self.contiguous_grad_buffer = True + self.init_params() + + # Construct grad buffers + buffer_sizes = collections.defaultdict(lambda: 0) + for bucket in self.state["buckets"]: + dtypes = bucket.dtypes() + buffer_sizes[dtypes] = max( + bucket.contiguous_buffer_offset + bucket.bucket_size, + buffer_sizes[dtypes], + ) + for dtypes, buffer_size in buffer_sizes.items(): + _, grad_sync_dtype, _ = dtypes + if not self.nccl_ub: + self._grad_buffers[dtypes] = torch.zeros( + [buffer_size], dtype=grad_sync_dtype, device=self.device, + ) + else: + pool = nccl_allocator.create_nccl_mem_pool() + with nccl_allocator.nccl_mem(pool): + self._grad_buffers[dtypes] = torch.zeros( + [buffer_size], dtype=grad_sync_dtype, device=self.device, + ) + shard_buffer_size = buffer_size // self.distributed_size + with nccl_allocator.nccl_mem(pool): + self._shard_grad_buffers[dtypes] = torch.zeros( + [shard_buffer_size], dtype=grad_sync_dtype, device=self.device, + ) + + def parameters(self) -> Iterable[torch.nn.Parameter]: + """Returns an iterator over optimizer parameters""" + return itertools.chain.from_iterable( + group["params"] for group in self.param_groups + ) + + def parameter( + self, + *args: Union[int, ParameterFragment], + ) -> torch.nn.Parameter: + """Get optimizer parameter + + Can either accept two ints or one + DistributedFusedAdam.ParameterFragment. + + Arguments: + param_group_id (int): Parameter group index + param_id (int): Parameter index within parameter group + + """ + if ( + len(args) == 2 + and isinstance(args[0], int) + and isinstance(args[1], int) + ): + param_group_id = args[0] + param_id = args[1] + elif len(args) == 1 and isinstance(args[0], self.ParameterFragment): + fragment = args[0] + param_group_id = fragment.param_group_id + param_id = fragment.param_id + else: + raise TypeError( + "Expected input types are " + "[int, int] or [DistributedFusedAdam.ParameterFragment], " + f"but found {[type(arg).__name__ for arg in args]}" + ) + return self.param_groups[param_group_id]["params"][param_id] + + def init_params( + self, + params: Optional[Iterable[torch.nn.Parameter]] = None, + dtype: Optional[torch.dtype] = None, + grad_sync_dtype: Optional[torch.dtype] = None, + param_sync_dtype: Optional[torch.dtype] = None, + ) -> None: """Initialize optimizer state for parameters + Ignores parameters that have already been initialized. + Arguments: params (iterable, optional): parameters to initialize (default: all parameters) @@ -388,111 +1262,320 @@ def init_params(self, params=None): """ # Default cases - if isinstance(params, torch.Tensor): + if params is None: + params = self.parameters() + elif isinstance(params, torch.Tensor): params = [params] - elif params is None: - params = [] - for group in self.param_groups: - params.extend(group['params']) + + # Ignore parameters that have already been initialized + params = [param for param in params if "fragments" not in self.state[param]] + if not params: + return # Get indices corresponding to parameters id_map = dict() for param_group_id, group in enumerate(self.param_groups): - for param_id, param in enumerate(group['params']): + for param_id, param in enumerate(group["params"]): id_map[param] = (param_group_id, param_id) # Initialize parameters for param in params: - if param in id_map and 'fragments' not in self.state[param]: + if param in id_map: param_group_id, param_id = id_map[param] - self._init_param_state(param, param_group_id, param_id) + self._init_param_state( + param, + param_group_id, + param_id, + dtype=dtype, + grad_sync_dtype=grad_sync_dtype, + param_sync_dtype=param_sync_dtype, + ) + + def init_params_bucket( + self, + params: Iterable[torch.nn.Parameter], + dtype: Optional[torch.dtype] = None, + grad_sync_dtype: Optional[torch.dtype] = None, + param_sync_dtype: Optional[torch.dtype] = None, + ) -> None: + """Initialize optimizer state for parameters in one effective bucket + + The buckets corresponding to the provided parameters are + configured so they all perform communication together. Ignores + parameters that have already been initialized. + + Arguments: + params (iterable): parameters to initialize + + """ + + # Ignore parameters that have already been initialized + if isinstance(params, torch.Tensor): + params = [params] + params = [param for param in params if "fragments" not in self.state[param]] + if not params: + return + + # Get indices corresponding to parameters + id_map = dict() + for param_group_id, group in enumerate(self.param_groups): + for param_id, param in enumerate(group["params"]): + id_map[param] = [param_group_id, param_id] + param_ids = [tuple([param] + id_map[param]) for param in params] + + # Mark existings bucket as fully filled + for bucket in self.state["buckets"]: + bucket.able_to_fill = False + + # Initialize optimizer state for parameters + start_bucket_id = len(self.state["buckets"]) + self.init_params( + params, + dtype=dtype, + grad_sync_dtype=grad_sync_dtype, + param_sync_dtype=param_sync_dtype, + ) + end_bucket_id = len(self.state["buckets"]) + + # Make sure all added buckets depend on provided params + for bucket_id in range(start_bucket_id, end_bucket_id): + bucket = self.state["buckets"][bucket_id] + bucket_size = bucket.bucket_size + bucket.able_to_fill = False + ids_in_bucket = set( + (fragment.param_group_id, fragment.param_id) + for fragment in bucket.fragments + ) + for param, param_group_id, param_id in param_ids: + if (param_group_id, param_id) not in ids_in_bucket: + param_size = param.numel() + fragment = self.ParameterFragment( + param_group_id=param_group_id, + param_id=param_id, + bucket_id=bucket_id, + param_range=(param_size, param_size), + bucket_range=(bucket_size, bucket_size), + in_local_shard=False, + shard_range=None, + shard_bucket_range=None, + shard_param_range=None, + ) + self.state[param]["fragments"].append(fragment) + bucket.fragments.append(fragment) + @torch.no_grad() def _init_param_state( - self, - param, - param_group_id, - param_id, - ): + self, + param: torch.nn.Parameter, + param_group_id: int, + param_id: int, + dtype: Optional[torch.dtype] = None, + grad_sync_dtype: Optional[torch.dtype] = None, + param_sync_dtype: Optional[torch.dtype] = None, + ) -> None: """Initialize optimizer state for a parameter""" - # Make sure there is at least one bucket - if not self.state['buckets']: - self.state['buckets'].append( - self.StateBucket(self.shard_size, self.dtype, self.device) + # Return immediately if already initialized + if "fragments" in self.state[param]: + return + self.state[param]["fragments"] = [] + + # Data type configuration + if dtype is None: + dtype = self.dtype + if grad_sync_dtype is None: + grad_sync_dtype = self.grad_sync_dtype + if param_sync_dtype is None: + param_sync_dtype = self.param_sync_dtype + if dtype != self.dtype: + raise ValueError( + "Optimizer states with non-default dtypes are not supported" + ) + supported_dtypes = (torch.float32, torch.float16, torch.bfloat16) + if ( + dtype not in supported_dtypes + or grad_sync_dtype not in supported_dtypes + ): + raise ValueError( + "Unsupported dtypes for DistributedFusedAdam " + f"(dtype={dtype}, " + f"grad_sync_dtype={grad_sync_dtype}, " + f"param_sync_dtype={param_sync_dtype}))" + ) + + # Store params or param remainders + store_params = ( + self.store_params + or dtype != self.dtype + or param_sync_dtype != self.param_sync_dtype + ) + store_param_remainders = ( + self.store_param_remainders + and dtype == self.dtype + and param_sync_dtype == self.param_sync_dtype + ) + + def last_bucket_id() -> int: + """Index of last optimizer state bucket with desired dtypes + + -1 if there are no such buckets. + + """ + dtypes = (dtype, grad_sync_dtype, param_sync_dtype) + bucket_id = len(self.state["buckets"]) - 1 + while bucket_id > 0: + bucket = self.state["buckets"][bucket_id] + if bucket.dtypes() == dtypes: + break + bucket_id -= 1 + return bucket_id + + def make_bucket( + bucket_size: int, + shard_size: int, + buffer_offset: int, + ) -> None: + """Construct new optimizer state bucket""" + self.state["buckets"].append( + self.StateBucket( + bucket_size, + shard_size, + dtype, + self.device, + grad_sync_dtype, + param_sync_dtype, + contiguous_buffer_offset=buffer_offset, + store_params=store_params, + store_param_remainders=store_param_remainders, + ) ) + # Make sure there is at least one bucket with expected dtypes + if last_bucket_id() < 0: + shard_size = self.default_shard_size + bucket_size = shard_size * self.distributed_size + buffer_offset = 0 + make_bucket(bucket_size, shard_size, buffer_offset) + # Split parameter values into fragments # Note: Each fragment resides within a bucket param_start = 0 param_size = param.numel() - self.state[param]['fragments'] = [] while param_start < param_size: - # Get current bucket - bucket_id = len(self.state['buckets']) - 1 - bucket = self.state['buckets'][bucket_id] + bucket_id = last_bucket_id() + bucket = self.state["buckets"][bucket_id] fragment_id = len(bucket.fragments) + bucket_size = bucket.bucket_size + shard_size = bucket.shard_size # Determine fragment position within bucket - if fragment_id == 0: - bucket_start = 0 - else: - _, bucket_start = bucket.fragments[-1].bucket_range - bucket_start = _round_to_multiple(bucket_start, self.alignment) - fragment_size = min(param_size-param_start, self.bucket_size-bucket_start) + bucket_start = _round_to_multiple( + bucket.filled_size, + self.alignment, + round_up=True, + ) + fragment_size = min(param_size - param_start, bucket_size - bucket_start) param_end = param_start + fragment_size bucket_end = bucket_start + fragment_size # Create new bucket if current one is full - if fragment_size <= 0: - self.state['buckets'].append( - self.StateBucket(self.shard_size, self.dtype, self.device) - ) + if fragment_size <= 0 or not bucket.able_to_fill: + shard_size = self.default_shard_size + bucket_size = shard_size * self.distributed_size + buffer_offset = bucket.contiguous_buffer_offset + bucket.bucket_size + make_bucket(bucket_size, shard_size, buffer_offset) continue # Fragment position within local shard shard_id = self.distributed_rank - shard_start = bucket_start - self.shard_size*shard_id - shard_end = bucket_end - self.shard_size*shard_id - shard_start = min(max(shard_start, 0), self.shard_size) - shard_end = min(max(shard_end, 0), self.shard_size) + shard_start = bucket_start - shard_size * shard_id + shard_end = bucket_end - shard_size * shard_id + shard_start = min(max(shard_start, 0), shard_size) + shard_end = min(max(shard_end, 0), shard_size) in_local_shard = shard_start < shard_end + shard_range = None + shard_bucket_range = None + shard_param_range = None if in_local_shard: - shard_bucket_start = shard_start + self.shard_size*shard_id + shard_range = (shard_start, shard_end) + shard_bucket_start = shard_start + shard_size * shard_id shard_bucket_end = shard_bucket_start + shard_end - shard_start + shard_bucket_range = (shard_bucket_start, shard_bucket_end) shard_param_start = shard_bucket_start - bucket_start + param_start shard_param_end = shard_param_start + shard_end - shard_start - else: - shard_bucket_start, shard_bucket_end = None, None - shard_param_start, shard_param_end = None, None + shard_param_range = (shard_param_start, shard_param_end) # Record fragment info fragment = self.ParameterFragment( param_group_id=param_group_id, param_id=param_id, bucket_id=bucket_id, - param_range=(param_start,param_end), - bucket_range=(bucket_start,bucket_end), + param_range=(param_start, param_end), + bucket_range=(bucket_start, bucket_end), in_local_shard=in_local_shard, - shard_range=(shard_start,shard_end), - shard_bucket_range=(shard_bucket_start,shard_bucket_end), - shard_param_range=(shard_param_start,shard_param_end), + shard_range=shard_range, + shard_bucket_range=shard_bucket_range, + shard_param_range=shard_param_range, ) - self.state[param]['fragments'].append(fragment) + self.state[param]["fragments"].append(fragment) bucket.fragments.append(fragment) + bucket.filled_size = bucket_end param_start = param_end - # Initialize master param buffer - for fragment in self.state[param]['fragments']: - if fragment.in_local_shard: - bucket = self.state['buckets'][fragment.bucket_id] - param_start, param_end = fragment.shard_param_range - shard_start, shard_end = fragment.shard_range - model_param_fragment = param.view(-1)[param_start:param_end] - master_param_fragment = bucket.params_shard[shard_start:shard_end] - master_param_fragment.copy_(model_param_fragment) + # Initialize optimizer state scaling factors if needed + if self.with_scaled_states: + for fragment in self.state[param]["fragments"]: + if not fragment.in_local_shard: + continue + bucket_id = fragment.bucket_id + self._state_scales[(param_group_id, param_id, bucket_id)] = dict( + param=torch.zeros([1], dtype=torch.float32, device=self.device), + exp_avg=torch.zeros([1], dtype=torch.float32, device=self.device), + exp_avg_sq=torch.zeros([1], dtype=torch.float32, device=self.device), + ) + + # Initialize main param buffer + if store_params: + for fragment in self.state[param]["fragments"]: + if not fragment.in_local_shard: + continue + bucket_id = fragment.bucket_id + bucket = self.state["buckets"][bucket_id] + # If param is channels last, i.e. tensor with shape (N, C, H, W) + # and stride (HWC, 1, WC, C), then we will turn it into a tensor + # with shape (N, H, W, C) and stride (HWC, WC, C, 1). The purppose + # is to avoid failures when flattening the tensor (`.view(-1)`) + # and stepping the optimizer. + if param.is_contiguous(memory_format=torch.channels_last): + param = param.permute(0, 2, 3, 1) + param_range = slice(*fragment.shard_param_range) + shard_range = slice(*fragment.shard_range) + model_param_fragment = param.detach().view(-1)[param_range] + if self.with_scaled_states: + model_param_fragment = torch.empty_like( + model_param_fragment, + dtype=torch.float32, + ).copy_(model_param_fragment) + self._apply_state_scale( + model_param_fragment, + self._state_scales[(param_group_id, param_id, bucket_id)]["param"], + ) + main_param_fragment = bucket.params_shard[shard_range] + main_param_fragment.copy_(model_param_fragment) + + # Check if buckets are underutilized + if all("fragments" in self.state[param] for param in self.parameters()): + bucket_size = sum(bucket.bucket_size for bucket in self.state["buckets"]) + filled_size = sum(bucket.filled_size for bucket in self.state["buckets"]) + buckets_utilization = filled_size / bucket_size + if buckets_utilization < 0.7: + warnings.warn( + f"Only {buckets_utilization:.1%} of buckets are used. " + "Consider decreasing the bucket_cap_mb argument." + ) - def zero_grad(self, set_to_none=True): + def zero_grad(self, set_to_none: bool = False) -> None: """Clear parameter gradients""" # Reset bucket buffers @@ -500,35 +1583,77 @@ def zero_grad(self, set_to_none=True): # Construct views into contiguous grad buffer, if needed if self.contiguous_grad_buffer: - self._grad_buffer.zero_() - for bucket_id in range(len(self.state['buckets'])): - bucket_start = bucket_id * self.bucket_size - bucket_end = bucket_start + self.bucket_size - bucket = self._grads_buckets[bucket_id] - bucket.grads_bucket = self._grad_buffer[bucket_start:bucket_end] + if not self._grad_buffers: + self._init_grad_buffer() + for grad_buffer in self._grad_buffers.values(): + grad_buffer.zero_() + for bucket_id, bucket in enumerate(self.state["buckets"]): + bucket_size = bucket.bucket_size + buffer_start = bucket.contiguous_buffer_offset + buffer_end = buffer_start + bucket_size + grad_buffer = self._grad_buffers[bucket.dtypes()] + self._grads_buckets[bucket_id].grads_bucket = grad_buffer[ + buffer_start:buffer_end + ] + if self.nccl_ub: + shard_size = bucket.shard_size + shard_buffer_start = ( + bucket.contiguous_buffer_offset // self.distributed_size + ) + shard_buffer_end = shard_buffer_start + shard_size + shard_grad_buffer = self._shard_grad_buffers[bucket.dtypes()] + self._grads_buckets[bucket_id].sync_grads_shard = shard_grad_buffer[ + shard_buffer_start:shard_buffer_end + ] # Reset param grads - for group in self.param_groups: - for param in group['params']: - if param.grad is None or set_to_none: + for param in self.parameters(): + with _disable_pre_forward_hook(param): + need_to_zero = True + if set_to_none: param.grad = None - else: + elif self.contiguous_grad_buffer: + bucket_id = self.state[param]["fragments"][0].bucket_id + bucket = self.state["buckets"][bucket_id] + if param.dtype == bucket.grad_sync_dtype and _devices_match( + param.device, self.device + ): + param.grad = self.grad_buffer_view(param) + need_to_zero = False + if need_to_zero and param.grad is not None: param.grad.zero_() # Reset other state - self._grads_generated = set() - self._inv_grad_scale = torch.full([1], 1.0, dtype=self.dtype, device=self.device) + self._grad_scale.fill_(1.0) self._grad_norm = None + self._dummy_overflow_buf.zero_() - def _grad_copy(self, param): - """Copy parameter gradients to buckets""" + def _grad_copy(self, param: torch.nn.Parameter) -> None: + """Copy parameter gradients to gradient buckets - # Copy param grad to buckets - for fragment in self.state[param]['fragments']: + Initializes gradient buckets if needed. The original parameter + gradient is set to None. + + """ + # Initialize parameter if needed + if "fragments" not in self.state[param]: + for param_group_id, group in enumerate(self.param_groups): + for param_id, param_ in enumerate(group["params"]): + if param is param_: + self._init_param_state(param, param_group_id, param_id) + if "fragments" not in self.state[param]: + raise RuntimeError( + "Could not initialize DistributedFusedAdam with parameter" + ) + + # Copy param grad to buckets + for fragment in self.state[param]["fragments"]: # Get fragment position bucket_id = fragment.bucket_id bucket = self._grads_buckets[bucket_id] + bucket_size = self.state["buckets"][bucket_id].bucket_size + grad_sync_dtype = self.state["buckets"][bucket_id].grad_sync_dtype grad_start, grad_end = fragment.param_range bucket_start, bucket_end = fragment.bucket_range @@ -538,22 +1663,35 @@ def _grad_copy(self, param): bucket.status = self.GradientStatus.PARTIALLY_FILLED # Allocate gradient buffer if needed + if bucket.grads_bucket is None and self.contiguous_grad_buffer: + if not self._grad_buffers: + self._init_grad_buffer() + state_bucket = self.state["buckets"][bucket_id] + buffer_start = state_bucket.contiguous_buffer_offset + buffer_end = buffer_start + bucket_size + grad_buffer = self._grad_buffers[state_bucket.dtypes()] + grad_buffer = grad_buffer[buffer_start:buffer_end] + if ( + bucket.grads_shard is None + or bucket.grads_shard.storage().data_ptr() + != grad_buffer.storage().data_ptr() + ): + bucket.grads_bucket = grad_buffer + bucket.grads_bucket.zero_() if bucket.grads_bucket is None: - if self.contiguous_grad_buffer: - grad_buffer_start = bucket_id * self.bucket_size - grad_buffer_end = grad_buffer_start + self.bucket_size - bucket.grads_bucket = self._grad_buffer[grad_buffer_start:grad_buffer_end] - else: - bucket.grads_bucket = torch.empty( - [self.bucket_size], - dtype=self.grad_sync_dtype, - device=self.device, - ) - bucket.grads_bucket.zero_() + bucket.grads_bucket = torch.zeros( + [bucket_size], + dtype=grad_sync_dtype, + device=self.device, + ) # Copy param grad to bucket if param.grad is not None: - grad_in = param.grad.detach().view(-1)[grad_start:grad_end] + if param.grad.is_contiguous(memory_format=torch.channels_last): + grad_in = param.grad.permute(0, 2, 3, 1) + else: + grad_in = param.grad + grad_in = grad_in.detach().view(-1)[grad_start:grad_end] grad_out = bucket.grads_bucket[bucket_start:bucket_end] if grad_in.data_ptr() != grad_out.data_ptr(): grad_out.add_(grad_in) @@ -561,64 +1699,181 @@ def _grad_copy(self, param): # Free param grad buffer param.grad = None - def grad_buffer_view(self, param): + def _param_copy( + self, + params: Union[torch.nn.Parameter, Iterable[torch.nn.Parameter]], + ) -> None: + """Update parameters with values from parameter buckets + + Synchronizes and deletes parameter buckets as needed. + + """ + + # Get parameter fragments to be synchronized + if isinstance(params, torch.Tensor): + params = [params] + fragments = [] + for param in params: + if "fragments" in self.state[param]: + fragments.extend( + fragment + for fragment in self.state[param]["fragments"] + if fragment.bucket_id in self._params_buckets + ) + + # Return immediately if no fragments need to be synchronized + if not fragments: + return + + # Make sure all needed buckets have been synchronized + buckets = collections.OrderedDict() + for fragment in fragments: + bucket_id = fragment.bucket_id + bucket = self._params_buckets[bucket_id] + buckets[bucket] = bucket.status + if any( + status != self.ParameterStatus.READY for bucket, status in buckets.items() + ): + self._start_bucket_param_sync(buckets.keys()) + self._finish_bucket_param_sync() + + # Copy values from bucket buffers to params + self._param_copy_fragments(fragments) + + # Delete buckets if possible + for fragment in fragments: + bucket_id = fragment.bucket_id + bucket = self._params_buckets[bucket_id] + bucket.params_updated.add(self.parameter(fragment)) + bucket_fragments = self.state["buckets"][bucket_id].fragments + if len(bucket.params_updated) == len(bucket_fragments): + del self._params_buckets[bucket_id] + + def _param_copy_fragments( + self, + fragments: Iterable[ParameterFragment], + ) -> None: + """Update parameter fragments with values from parameter buckets""" + + # Figure out corresponding positions in param buckets and params + buffers_in = [] + buffers_out = [] + for fragment in fragments: + + # Check if fragment needs to be updated + bucket_id = fragment.bucket_id + bucket_start, bucket_end = fragment.bucket_range + param_start, param_end = fragment.param_range + if param_end <= param_start or bucket_id not in self._params_buckets: + continue + + # Corresponding positions in param bucket and param + bucket = self._params_buckets[bucket_id] + param = self.parameter(fragment) + + # Conv with NHWC layout, i.e. shape (N, C, H, W) and stride + # (HWC, 1, WC, C), can't `.view(-1)`. Here to turn it to + # tensor with shape (N, H, W, C) and stride (HWC, WC, C, 1). + if param.is_contiguous(memory_format=torch.channels_last): + param = param.permute(0, 2, 3, 1) + + buffer_in = bucket.params_bucket[bucket_start:bucket_end] + buffer_out = param.detach().view(-1)[param_start:param_end] + + if ( + torch.is_floating_point(buffer_in) + and torch.is_floating_point(buffer_out) + ): + # Cast between floating-point dtypes + buffers_in.append(buffer_in) + buffers_out.append(buffer_out) + else: + # Copy most significant bytes for non-floating-point + # dtypes + # Note: Assume dtypes are little-endian + in_bytes = buffer_in.unsqueeze(-1).view(torch.uint8) + out_bytes = buffer_out.unsqueeze(-1).view(torch.uint8) + copy_size = min(in_bytes.size(-1), out_bytes.size(-1)) + buffers_in.append(in_bytes[..., -copy_size:]) + buffers_out.append(out_bytes[..., -copy_size:]) + if copy_size < out_bytes.size(-1): + out_bytes[..., :-copy_size].zero_() + + # Copy data from parameter buckets to parameters + _multi_tensor_copy( + buffers_in, + buffers_out, + dummy_overflow_buf=self._dummy_overflow_buf, + ) + + def grad_buffer_view(self, param: torch.nn.Parameter) -> torch.Tensor: """Construct view into grad buffer corresponding to param Assumes optimizer is using a contiguous grad buffer. """ + + # Initialize contiguous grad buffers if needed assert self.contiguous_grad_buffer + if not self._grad_buffers: + self._init_grad_buffer() # Figure out corresponding position in grad buffer - param_fragments = self.state[param]['fragments'] - start_bucket_id = param_fragments[0].bucket_id - start_bucket_offset, _ = param_fragments[0].bucket_range - end_bucket_id = param_fragments[-1].bucket_id - _, end_bucket_offset = param_fragments[-1].bucket_range - buffer_start = start_bucket_id * self.bucket_size + start_bucket_offset - buffer_end = end_bucket_id * self.bucket_size + end_bucket_offset + fragment = self.state[param]["fragments"][0] + bucket_id = fragment.bucket_id + bucket = self.state["buckets"][bucket_id] + bucket_start, _ = fragment.bucket_range + buffer_offset = bucket.contiguous_buffer_offset + buffer_start = buffer_offset + bucket_start + buffer_end = buffer_start + param.numel() # Construct view into grad buffer - flat_buffer = self._grad_buffer[buffer_start:buffer_end] - return flat_buffer.detach().view(param.size()) + # Preserve memory format for gradient here + flat_buffer = self._grad_buffers[bucket.dtypes()] + flat_buffer = flat_buffer[buffer_start:buffer_end] + return flat_buffer.detach().as_strided(param.size(), param.stride()) - def _force_bucket_grad_sync(self): + def _force_bucket_grad_sync(self) -> None: """Ensure that all gradient buckets are synchronized""" # Synchronize all unsynchronized buckets - self._finish_bucket_grad_sync() - buckets = [ - bucket - for bucket_id, bucket in sorted(self._grads_buckets.items()) - if bucket.status != self.GradientStatus.READY - ] + Status = self.GradientStatus + buckets = [] + for bucket_id, grads_bucket in sorted(self._grads_buckets.items()): + if grads_bucket.status not in (Status.READY, Status.SYNCING): + buckets.append(grads_bucket) + if grads_bucket.grads_bucket is None: + state_bucket = self.state["buckets"][bucket_id] + grads_bucket.grads_bucket = torch.zeros( + [state_bucket.bucket_size], + dtype=state_bucket.grad_sync_dtype, + device=self.device, + ) if buckets: self._start_bucket_grad_sync(buckets) - self._finish_bucket_grad_sync() + self._finish_bucket_grad_sync() # Fill any unsynchronized gradients with zeros - for bucket_id in range(len(self.state['buckets'])): - bucket = self._grads_buckets[bucket_id] - if bucket.grads_shard is None: - bucket.grads_shard = torch.zeros( - [self.shard_size], - dtype=self.grad_sync_dtype, + for bucket_id in range(len(self.state["buckets"])): + grads_bucket = self._grads_buckets[bucket_id] + if grads_bucket.grads_shard is None: + state_bucket = self.state["buckets"][bucket_id] + grads_bucket.grads_shard = torch.zeros( + [state_bucket.shard_size], + dtype=state_bucket.grad_sync_dtype, device=self.device, ) - # Reset set of generated gradients - self._grads_generated = set() - def _try_start_bucket_grad_sync( - self, - params=[], - ignore_last_bucket=True, - ): - """Launches gradient synchronization if enough buckets are ready + self, + params: Optional[Iterable[torch.nn.Parameter]] = None, + ignore_last_bucket: bool = False, + ) -> None: + """Attempt to launch gradient synchronization - Gradient synchronization is asynchronous. Launches gradient - synchronization if all gradients have been generated or if - there are enough buckets ready to fill pipeline. + Launches gradient synchronization if any bucket has receieved + all its expected gradients. Gradient synchronization is + asynchronous. Arguments: params (iterable): parameters that have had their @@ -631,131 +1886,136 @@ def _try_start_bucket_grad_sync( """ # Register params that have generated grads + if params is None: + params = [] for param in params: - self._grads_generated.add(param) - for fragment in self.state[param]['fragments']: + for fragment in self.state[param]["fragments"]: bucket_id = fragment.bucket_id - bucket_fragments = self.state['buckets'][bucket_id].fragments - is_filled = True - for other_fragment in reversed(bucket_fragments): - param_group_id = other_fragment.param_group_id - param_id = other_fragment.param_id - other_param = self.param_groups[param_group_id]['params'][param_id] - if other_param not in self._grads_generated: - is_filled = False - break - if is_filled: - bucket = self._grads_buckets[bucket_id] - bucket.status = self.GradientStatus.FULLY_FILLED + grads_bucket = self._grads_buckets[bucket_id] + state_bucket = self.state["buckets"][bucket_id] + bucket_fragments = state_bucket.fragments + grads_bucket.grads_generated.add(param) + if len(grads_bucket.grads_generated) == len(bucket_fragments): + grads_bucket.status = self.GradientStatus.FULLY_FILLED + if grads_bucket.grads_bucket is None: + grads_bucket.grads_bucket = torch.zeros( + [state_bucket.bucket_size], + dtype=state_bucket.grad_sync_dtype, + device=self.device, + ) # Launch reductions if enough buckets are ready - if len(self._grads_generated) == self._num_grads: - self._force_bucket_grad_sync() - else: - filled_buckets = [] - for bucket_id, bucket in sorted(self._grads_buckets.items()): - if ignore_last_bucket and bucket_id == len(self.state['buckets'])-1: - continue - if bucket.status == self.GradientStatus.FULLY_FILLED: - filled_buckets.append(bucket) - pipeline_size = _round_to_multiple( - len(filled_buckets), - self.pipeline_size, - ) - if pipeline_size > 0: - self._start_bucket_grad_sync(filled_buckets[:pipeline_size]) + filled_buckets = [] + for bucket_id, bucket in sorted(self._grads_buckets.items()): + if ignore_last_bucket and bucket_id == len(self.state["buckets"]) - 1: + continue + if bucket.status == self.GradientStatus.FULLY_FILLED: + filled_buckets.append(bucket) + if filled_buckets: + self._start_bucket_grad_sync(filled_buckets) - def _start_bucket_grad_sync(self, buckets): + def _start_bucket_grad_sync(self, buckets: List[GradientBucket]) -> None: """Synchronize gradient buckets Gradient synchronization is asynchronous. Involves reduce-scatter over distributed process group and allreduce - over redundant process group. + over redundant process group. Assumes grad bucket buffers are + already initialized. """ - # Call recursively if more buckets than streams - while len(buckets) > self.pipeline_size: - self._start_bucket_grad_sync(buckets[:self.pipeline_size]) - buckets = buckets[self.pipeline_size:] - self._finish_bucket_grad_sync() + # Complete any outstanding grad syncs + # Note: Not needed with contiguous grad buffer since there is + # no memory benefit from eagerly freeing grad buffers. + if not self.contiguous_grad_buffer: + self._finish_bucket_grad_sync() # Reduction operation - if self.average_grad_sync: + if self.average_grad_sync and not self.nccl_ub: reduce_op = torch.distributed.ReduceOp.AVG else: reduce_op = torch.distributed.ReduceOp.SUM - # Reduce gradients - main_stream = torch.cuda.current_stream() - for stream in self._pipeline_streams: - stream.wait_stream(main_stream) - for i, bucket in enumerate(buckets): + # Initialize grad state and buffers + for bucket in buckets: + if bucket.status == self.GradientStatus.SYNCING: + self._finish_bucket_grad_sync() bucket.status = self.GradientStatus.SYNCING - stream = self._pipeline_streams[i % self.pipeline_size] - with torch.cuda.stream(stream): + bucket.grads_generated.clear() + if self.distributed_size == 1: + bucket.sync_grads_shard = bucket.grads_bucket + elif bucket.sync_grads_shard is None: + bucket_size = bucket.grads_bucket.numel() + shard_size = bucket_size // self.distributed_size + bucket.sync_grads_shard = torch.empty( + [shard_size], + dtype=bucket.grads_bucket.dtype, + device=bucket.grads_bucket.device, + ) - # Reduce-scatter over distributed process group - bucket.sync_wait() - if self.distributed_size == 1: - bucket.sync_grads_shard = bucket.grads_bucket - else: - with torch.cuda.stream(main_stream): - bucket.sync_grads_shard = torch.zeros( - [self.shard_size], - dtype=self.grad_sync_dtype, - device=self.device, - ) - grads_bucket_shards = [ - bucket.grads_bucket[i*self.shard_size:(i+1)*self.shard_size] - for i in range(self.distributed_size) - ] - if self._reduce_scatter_no_copy: - no_copy_kwarg = { 'no_copy': True } - else: - no_copy_kwarg = {} - bucket.sync_request = ( - torch.distributed.reduce_scatter( - bucket.sync_grads_shard, - grads_bucket_shards, - op=reduce_op, - group=self.distributed_process_group, - async_op=True, - **no_copy_kwarg, - ) - ) + # Handle case with multiple grad accumulation steps + if bucket.grads_shard is not None: + if bucket.sync_grads_shard.data_ptr() == bucket.grads_shard.data_ptr(): + bucket.grads_shard = bucket.grads_shard.clone() - # All-reduce over redundant process group - # Note: Assuming reduce-scatters are finished in the - # order they are submitted, all-reduces should be - # submitted in a consistent order. There could be race - # conditions if wait doesn't finish in order. - if self.redundant_size > 1: - bucket.sync_wait() - bucket.sync_request = ( - torch.distributed.all_reduce( - bucket.sync_grads_shard, - op=reduce_op, - group=self.redundant_process_group, - async_op=True, + # Side stream for communication + # If new bucket is ready before last bucket communication finishes, use multiple + # communication streams could help pipeline reduce-scatter and all-reduce. + main_stream = torch.cuda.current_stream() + self._last_comm_stream_id = (self._last_comm_stream_id + 1) % len(self._comm_streams) + comm_stream = self._comm_streams[self._last_comm_stream_id] + comm_stream.wait_stream(main_stream) + + # Reduce-scatter over distributed process group + if buckets and self.distributed_size > 1: + with torch.cuda.stream(comm_stream): + group = self.distributed_process_group + with _coalescing_manager(group, self.device, async_ops=True) as cm: + for bucket in buckets: + if self.average_grad_sync and self.nccl_ub: + bucket.grads_bucket /= self.distributed_size + _coalescing_manager_append_work( + cm, + reduce_scatter_tensor( + bucket.sync_grads_shard, + bucket.grads_bucket, + op=reduce_op, + group=group, + async_op=True, + ), ) - ) + cm.wait() + + # All-reduce over redundant process group + if buckets and self.redundant_size > 1: + with torch.cuda.stream(comm_stream): + group = self.redundant_process_group + with _coalescing_manager(group, self.device, async_ops=True) as cm: + for bucket in buckets: + _coalescing_manager_append_work( + cm, + torch.distributed.all_reduce( + bucket.sync_grads_shard, + op=reduce_op, + group=group, + async_op=True, + ), + ) + cm.wait() - def _finish_bucket_grad_sync(self): + def _finish_bucket_grad_sync(self) -> None: """Wait for any gradient synchronizations that are in progress""" + main_stream = torch.cuda.current_stream() + for comm_stream in self._comm_streams: + main_stream.wait_stream(comm_stream) for bucket_id, bucket in sorted(self._grads_buckets.items()): if bucket.status == self.GradientStatus.SYNCING: - - # Finish asynchronous communication - bucket.sync_wait() - # Accumulate gradient in local shard if bucket.grads_shard is None: bucket.grads_shard = bucket.sync_grads_shard else: bucket.grads_shard.add_(bucket.sync_grads_shard) bucket.grads_bucket = None - bucket.sync_grads_shard = None # Reset status bucket.status = self.GradientStatus.READY @@ -763,8 +2023,133 @@ def _finish_bucket_grad_sync(self): # Cached gradient norm has been invalidated self._grad_norm = None + def _try_start_bucket_param_sync( + self, + params: Iterable[torch.nn.Parameter] = None, + ) -> None: + """Attempt to launch parameter synchronization + + Launches parameter synchronization for buckets corresponding + to provided parameters, if needed. If parameters are not + provided and no other synchronizations are in progress, + attempts to find a parameter that still requires + synchronization. Parameter synchronization is asynchronous. + + Arguments: + params (iterable, optional): parameters to synchronize + + """ + + # Default behavior: only launch param sync if no other syncs + # are in progress + if params is None: + params = [] + if any( + bucket.status == self.ParameterStatus.SYNCING + for bucket in self._params_buckets.values() + ): + return + for bucket_id, bucket in self._params_buckets.items(): + if bucket.status == self.ParameterStatus.SHARDED: + params.append( + self.parameter( + self.state["buckets"][bucket_id].fragments[-1] + ) + ) + break + + # Find buckets corresponding to params + bucket_ids = set() + for param in params: + bucket_ids.update( + fragment.bucket_id for fragment in self.state[param]["fragments"] + ) + buckets = [ + self._params_buckets[bucket_id] + for bucket_id in sorted(bucket_ids) + if bucket_id in self._params_buckets + ] + buckets = [ + bucket + for bucket in buckets + if bucket.status == self.ParameterStatus.SHARDED + ] + + # Launch param sync if needed + if buckets: + self._start_bucket_param_sync(buckets) + + def _start_bucket_param_sync(self, buckets: List[ParameterBucket]) -> None: + """Synchronize parameter buckets + + Parameter synchronization is asynchronous. Involves all-gather + over distributed process group. Assumes param shard buffers + are already initialized. + + """ + + # Complete any outstanding param syncs + self._finish_bucket_param_sync() + + # Initialize param state and buffers + buckets = [ + bucket + for bucket in buckets + if bucket.status == self.ParameterStatus.SHARDED + ] + for bucket in buckets: + bucket.status = self.ParameterStatus.SYNCING + if bucket.params_bucket is not None: + pass + elif self.distributed_size == 1: + bucket.params_bucket = bucket.params_shard + else: + shard_size = bucket.params_shard.numel() + bucket_size = shard_size * self.distributed_size + bucket.params_bucket = torch.empty( + [bucket_size], + dtype=bucket.params_shard.dtype, + device=bucket.params_shard.device, + ) + + # Side stream for communication + main_stream = torch.cuda.current_stream() + self._last_comm_stream_id = (self._last_comm_stream_id + 1) % len(self._comm_streams) + comm_stream = self._comm_streams[self._last_comm_stream_id] + comm_stream.wait_stream(main_stream) + + # All-gather over distributed process group + if buckets and self.distributed_size > 1: + with torch.cuda.stream(comm_stream): + group = self.distributed_process_group + with _coalescing_manager(group, self.device, async_ops=True) as cm: + for bucket in buckets: + _coalescing_manager_append_work( + cm, + all_gather_into_tensor( + bucket.params_bucket, + bucket.params_shard, + group=group, + async_op=True, + ), + ) + cm.wait() + + def _finish_bucket_param_sync(self) -> None: + """Wait for any param synchronizations that are in progress""" + main_stream = torch.cuda.current_stream() + for comm_stream in self._comm_streams: + main_stream.wait_stream(comm_stream) + for bucket_id, bucket in self._params_buckets.items(): + if bucket.status == self.ParameterStatus.SYNCING: + bucket.params_shard = None + bucket.status = self.ParameterStatus.READY + @contextlib.contextmanager - def no_sync(self, greedy_grad_copy=False): + def no_sync( + self, + greedy_grad_copy: None = False, + ) -> contextlib.AbstractContextManager: """Disable overlapped gradient synchronization Context manager that is similar to @@ -790,29 +2175,44 @@ def no_sync(self, greedy_grad_copy=False): self.greedy_grad_copy = old_greedy_grad_copy self.overlap_grad_sync = old_overlap_grad_sync - def grad_sync(self): + def grad_sync(self) -> None: """Ensure that all gradients are synchronized""" - for bucket in self.state['buckets']: + for bucket in self.state["buckets"]: for fragment in bucket.fragments: - param_group_id = fragment.param_group_id - param_id = fragment.param_id - param = self.param_groups[param_group_id]['params'][param_id] + param = self.parameter(fragment) if param.grad is not None: self._grad_copy(param) - self._try_start_bucket_grad_sync( - params=[param], - ignore_last_bucket=False, - ) + if not self.contiguous_grad_buffer: + self._try_start_bucket_grad_sync( + params=[param], + ignore_last_bucket=False, + ) self._force_bucket_grad_sync() - def _local_grad_norm(self, parameters=[], norm_type=2.0): + def param_sync(self) -> None: + """Ensure that all parameters are synchronized""" + if self.contiguous_param_buffer: + self._param_copy(self.parameters()) + else: + while self._params_buckets: + bucket_id, bucket = next(iter((self._params_buckets.items()))) + for fragment in reversed(self.state["buckets"][bucket_id].fragments): + self._param_copy(self.parameter(fragment)) + self._params_buckets.clear() + + @torch.no_grad() + def _local_grad_norm( + self, + parameters: Optional[Iterable[torch.nn.Parameter]] = None, + norm_type: float = 2.0, + ) -> torch.Tensor: """Local contribution to parameter gradient norm Returns square of 2-norm. Other norms are not yet supported. If no parameters are provided, the norm is computed for all parameters in optimizer. Provided parameters are assumed to be - in optimizer. + in optimizer and to require gradients. """ norm_type = float(norm_type) @@ -821,38 +2221,78 @@ def _local_grad_norm(self, parameters=[], norm_type=2.0): # Make sure that gradients have been reduced self.grad_sync() - if not parameters or len(parameters) == self._num_grads: + # Check if provided parameters are subset of all parameters + if parameters is not None: + parameters = list(parameters) + params_set = set(parameters) + all_params_set = set() + for bucket in self.state["buckets"]: + for fragment in bucket.fragments: + all_params_set.add(self.parameter(fragment)) + if not params_set.issubset(all_params_set): + raise RuntimeError( + "Attempted to compute gradient norm for a parameter " + "that is not managed by DistributedFusedAdam" + ) + if params_set == all_params_set: + parameters = None + + # Group grads by dtype + grad_groups = collections.defaultdict(list) + if parameters is None: # Compute norm of all local gradients - dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device='cuda') - grad_norm_sq = multi_tensor_applier( - amp_C.multi_tensor_l2norm, - dummy_overflow_buf, - [[bucket.grads_shard for bucket in self._grads_buckets.values()]], - False, - )[0] ** 2 + for bucket_id, grads_bucket in self._grads_buckets.items(): + state_bucket = self.state["buckets"][bucket_id] + dtype = state_bucket.grad_sync_dtype + grad_groups[dtype].append(grads_bucket.grads_shard) else: # Compute norm of selected local gradients - grads = [] for param in parameters: - for fragment in self.state[param]['fragments']: - if fragment.in_local_shard: - bucket = self._grads_buckets[fragment.bucket_id] - shard_start, shard_end = fragment.shard_range - grads.append(bucket.grads_shard[shard_start:shard_end]) - if grads: - dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device='cuda') - grad_norm_sq = multi_tensor_applier( + if "fragments" not in self.state[param]: + continue + for fragment in self.state[param]["fragments"]: + if not fragment.in_local_shard: + continue + shard_start, shard_end = fragment.shard_range + if shard_end <= shard_start: + continue + bucket_id = fragment.bucket_id + grads_bucket = self._grads_buckets[bucket_id] + state_bucket = self.state["buckets"][bucket_id] + grad_groups[state_bucket.grad_sync_dtype].append( + grads_bucket.grads_shard[shard_start:shard_end] + ) + + # Compute norm of each group of grads + grad_norm_sq = None + for grad_group in grad_groups.values(): + grad_group_norm_sq = ( + multi_tensor_applier( amp_C.multi_tensor_l2norm, - dummy_overflow_buf, - [grads], + self._dummy_overflow_buf, + [grad_group], False, - )[0] ** 2 + )[0] + ** 2 + ) + if grad_norm_sq is None: + grad_norm_sq = grad_group_norm_sq else: - grad_norm_sq = torch.zeros([1], dtype=torch.float32, device=self.device) - - return grad_norm_sq.detach().view([]) - - def grad_norm(self, parameters=[], norm_type=2.0, force=False): + grad_norm_sq += grad_group_norm_sq + if grad_norm_sq is None: + grad_norm_sq = torch.zeros([], dtype=torch.float32, device=self.device) + + # Interpret norm as scalar + grad_norm_sq = grad_norm_sq.to(dtype=torch.float32, device=self.device) + grad_norm_sq = grad_norm_sq.view([]) + return grad_norm_sq + + def grad_norm( + self, + parameters: Optional[Iterable[torch.nn.Parameter]] = None, + norm_type: float = 2.0, + force: bool = False, + ) -> torch.Tensor: """Gradient norm of parameters in optimizer The norm is computed over all gradients together, as if they @@ -864,9 +2304,8 @@ def grad_norm(self, parameters=[], norm_type=2.0, force=False): Arguments: parameters (iterable, optional): an iterable of parameters in optimizer (default: all parameters in optimizer). - norm_type (float or int, optional): type of the used - p-norm (default: 2). Only 2-norm is currently - supported. + norm_type (float, optional): type of the used p-norm + (default: 2). Only 2-norm is currently supported. force (bool, optional): ignore cached value and force norm computation (default: False). @@ -884,9 +2323,15 @@ def grad_norm(self, parameters=[], norm_type=2.0, force=False): group=self.distributed_process_group, ) self._grad_norm = grad_norm_sq.sqrt() - return self._grad_norm.detach() - - def clip_grad_norm(self, max_norm, parameters=[], norm_type=2.0): + grad_norm = self._grad_norm * self._grad_scale + return grad_norm.detach() + + def clip_grad_norm( + self, + max_norm: float, + parameters: Optional[Iterable[torch.nn.Parameter]] = None, + norm_type: float = 2.0, + ) -> torch.Tensor: """Clips gradient norm of parameters in optimizer The norm is computed over all gradients together, as if they @@ -898,20 +2343,94 @@ def clip_grad_norm(self, max_norm, parameters=[], norm_type=2.0): communication. Arguments: - max_norm (float or int): max norm of the gradients + max_norm (float): max norm of the gradients parameters (iterable, optional): an iterable of parameters in optimizer (default: all parameters in optimizer). - norm_type (float or int, optional): type of the used + norm_type (float, optional): type of the used p-norm (default: 2) """ assert max_norm > 0 total_norm = self.grad_norm(parameters=parameters, norm_type=norm_type) - inv_clip_coef = (total_norm + 1e-6) / max_norm - self._inv_grad_scale = torch.clamp(inv_clip_coef, min=1.0).view(1) + clip_coef = max_norm / (total_norm + 1e-6) + clip_coef_clamped = torch.clamp(clip_coef, max=1.0) + self._grad_scale *= clip_coef_clamped return total_norm - def step(self, closure=None, *, grad_scaler=None): + @torch.no_grad + def unscale_grads( + self, + *args: Union[Optional[torch.Tensor], Any], + inv_scale: Optional[torch.Tensor] = None, + grad_scaler: Optional[torch.cuda.amp.GradScaler] = None, + ) -> None: + """Custom unscale function for use by AMP gradient scaler + + Either inv_scale or grad_scaler must be provided, but not + both. If grad_scaler is provided, this is equivalent to + calling its unscale_ function. + + Arguments: + inv_scale (torch.Tensor, optional): factor to multiply + gradients. May be provided either as a kwarg or as the + first positional arg. + grad_scaler (torch.cuda.amp.GradScaler): gradient scaler + (default: None) + + """ + + # inv_scale is either kwarg or first positional arg + if inv_scale is None and len(args) >= 1: + inv_scale = args[0] + + # Check for non-finite values + # Note: We compute gradient norm to check for non-finite + # values. This is more conservative and compute intensive than + # directly checking, but it avoids extra communication if we + # have already computed gradient norm e.g. for gradient + # clipping. + found_inf = torch.logical_not(torch.isfinite(self.grad_norm())) + found_inf_per_device = { found_inf.device: found_inf.float() } + + # Get inv_scale from GradScaler if provided + if grad_scaler is not None and grad_scaler._enabled: + grad_scaler_state = grad_scaler._per_optimizer_states[id(self)] + GradScalerOptState = torch.cuda.amp.grad_scaler.OptState + if grad_scaler_state["stage"] is GradScalerOptState.UNSCALED: + raise RuntimeError( + "unscale_grads has already been called since the last GradScaler update" + ) + if grad_scaler_state["stage"] is GradScalerOptState.STEPPED: + raise RuntimeError( + "unscale_grads is being called after optimizer step" + ) + if grad_scaler._scale is None: + raise RuntimeError( + "Attempted unscale_grads with GradScaler that is missing _scale" + ) + if inv_scale is not None: + raise ValueError( + "unscale_grads is being called with both scale_inv and grad_scaler" + ) + inv_scale = grad_scaler._scale.double().reciprocal() + inv_scale = inv_scale.to(dtype=torch.float32, device=self.device) + grad_scaler_state["found_inf_per_device"] = found_inf_per_device + grad_scaler_state["stage"] = GradScalerOptState.UNSCALED + + # Apply inv_scale to grad_scale + if inv_scale is None: + raise ValueError( + "unscale_grads is being called with neither scale_inv and grad_scaler" + ) + self._grad_scale *= inv_scale.view([]) + return found_inf_per_device + + def step( + self, + closure: Optional[Callable] = None, + *, + grad_scaler: Optional[torch.cuda.amp.GradScaler] = None, + ): """Apply Adam optimizer step Arguments: @@ -927,202 +2446,562 @@ def step(self, closure=None, *, grad_scaler=None): if closure is not None: loss = closure() - # Make sure that gradients have been reduced + # Make sure params are initialized + self.init_params() + + # Make sure that parameters and gradients are synchronized + self.param_sync() self.grad_sync() # Apply gradient scaler if provided - # Note: We compute gradient norm to check for non-finite - # values. This is more conservative and compute intensive than - # directly checking, but it avoids extra communication if we - # have already computed gradient norm e.g. for gradient - # clipping. - if grad_scaler is not None: - grad_norm = self.grad_norm() - found_inf = torch.logical_not(torch.isfinite(grad_norm)) - scaler_state = grad_scaler._per_optimizer_states[id(self)] - scaler_state['found_inf_per_device'] = {found_inf.device: found_inf.float()} - if found_inf.item(): + if grad_scaler is not None and grad_scaler._enabled: + grad_scaler_state = grad_scaler._per_optimizer_states[id(self)] + GradScalerOptState = torch.cuda.amp.grad_scaler.OptState + if grad_scaler_state["stage"] is GradScalerOptState.READY: + self.unscale_grads(grad_scaler=grad_scaler) + found_inf = grad_scaler_state["found_inf_per_device"][self.device] + if self.capturable: + self._dummy_overflow_buf.copy_(found_inf) + elif found_inf.item(): return + self._grad_scale = self._grad_scale.to(dtype=torch.float32, device=self.device) + + # Initialize buffers for param syncs + num_buckets = len(self.state["buckets"]) + for bucket_id in reversed(range(num_buckets)): + self._params_buckets[bucket_id] = self.ParameterBucket() + params_bucket = self._params_buckets[bucket_id] + state_bucket = self.state["buckets"][bucket_id] + shard_size = state_bucket.shard_size + dtype = state_bucket.dtype + param_sync_dtype = state_bucket.param_sync_dtype + + if self.contiguous_param_buffer: + # Construct views into contiguous param buffer + if not self._param_buffers: + self.init_param_buffer() + bucket_size = state_bucket.bucket_size + buffer_start = state_bucket.contiguous_buffer_offset + buffer_end = buffer_start + bucket_size + param_buffer = self._param_buffers[state_bucket.dtypes()] + params_bucket.params_bucket = param_buffer[buffer_start:buffer_end] + bucket_start = self.distributed_rank * shard_size + bucket_end = bucket_start + shard_size + params_bucket.params_shard = params_bucket.params_bucket[ + bucket_start:bucket_end + ] + + # Initialize param shard buffer + if self.with_scaled_states: + # Use FP32 workspace buffer with scaled optimizer state + params_bucket.params_shard = None + elif not param_sync_dtype.is_floating_point: + # Make sure param shard buffer is floating-point + if ( + state_bucket.params_shard is not None + and dtype.is_floating_point + ): + params_bucket.params_shard = state_bucket.params_shard + else: + params_bucket.params_shard = torch.empty( + [shard_size], + dtype=self.dtype, + device=self.device, + ) else: - assert grad_scaler._scale is not None - self._inv_grad_scale *= grad_scaler._scale - inv_grad_scale = self._inv_grad_scale.item() + # Allocate param shard buffer if needed + if params_bucket.params_shard is not None: + pass + elif ( + state_bucket.params_shard is not None + and dtype == param_sync_dtype + ): + params_bucket.params_shard = state_bucket.params_shard + else: + params_bucket.params_shard = torch.empty( + [shard_size], + dtype=param_sync_dtype, + device=self.device, + ) - # Construct workspace buffers - params_bucket_buffers = [ - torch.empty( - [self.bucket_size], - dtype=self.param_sync_dtype, - device=self.device, + # Apply optimizer step + self.state["step"] += 1 if not self.capturable else \ + (self._dummy_overflow_buf != 1).to(torch.int) + overlap_first_bucket = ( + self.distributed_size > 1 + and self.overlap_param_sync + and self.state["buckets"] + ) + if overlap_first_bucket: + # Local step and non-blocking param sync + # Note: Overlap param sync of first buckets with optimizer + # step of remaining buckets. + + # Get buckets containing "first" parameter + first_param = self.parameter( + self.state["buckets"][-1].fragments[-1] ) - for _ in range(self.pipeline_size) - ] - if self.grad_sync_dtype == self.param_sync_dtype: - shard_start = self.distributed_rank * self.shard_size - shard_end = shard_start + self.shard_size - params_copy_buffers = [ - params_bucket[shard_start:shard_end] - for params_bucket in params_bucket_buffers - ] + first_bucket_ids = sorted( + fragment.bucket_id + for fragment in self.state[first_param]["fragments"] + ) + + # Local step and launch param sync for first buckets + self._local_step(first_bucket_ids) + self._start_bucket_param_sync( + self._params_buckets[bucket_id] for bucket_id in first_bucket_ids + ) + + # Local step for remaining buckets + first_bucket_ids = set(first_bucket_ids) + self._local_step( + [ + bucket_id + for bucket_id in range(num_buckets) + if bucket_id not in first_bucket_ids + ] + ) + else: - params_copy_buffers = [ - torch.empty( - [self.shard_size], - dtype=self.grad_sync_dtype, - device=self.device, - ) - for _ in range(self.pipeline_size) - ] + # Local step + self._local_step(list(range(num_buckets))) + + # Synchronize params + if self.distributed_size > 1 and self.overlap_param_sync: + # Asynchronous param sync + self._try_start_bucket_param_sync() + for param in self.parameters(): + param._pre_forward_hook_is_enabled = True + else: + # Blocking param sync + self.param_sync() - # Apply optimizer step to each bucket and synchronize params - self.state['step'] += 1 - main_stream = torch.cuda.current_stream() - for stream in self._pipeline_streams: - stream.wait_stream(main_stream) - for bucket_id in range(len(self.state['buckets'])): - stream_id = bucket_id % self.pipeline_size + return loss - # Bucket buffers - fragments = self.state['buckets'][bucket_id].fragments - shard_start = self.distributed_rank * self.shard_size - shard_end = shard_start + self.shard_size - params_bucket = params_bucket_buffers[stream_id] - params_bucket_shard = params_bucket[shard_start:shard_end] - params_shard = self.state['buckets'][bucket_id].params_shard - params_copy = params_copy_buffers[stream_id] - exp_avg = self.state['buckets'][bucket_id].exp_avg_shard - exp_avg_sq = self.state['buckets'][bucket_id].exp_avg_sq_shard - grads = self._grads_buckets[bucket_id].grads_shard - - # Perform compute on parallel stream - stream = self._pipeline_streams[stream_id] - with torch.cuda.stream(stream): + def _local_step(self, bucket_ids: List[int]) -> None: + """Apply optimizer step to local shard of parameter buckets + + Arguments: + bucket_ids (list): bucket indices + + """ + + # Implementation with scaled optimizer state + if self.with_scaled_states: + self._local_step_with_scaled_states(bucket_ids) + return + + # Optimized implementation with BF16 params and 16-bit param + # remainders + if self.store_param_remainders: + bf16_rem_buckets = set() + for bucket_id in bucket_ids: + state_bucket = self.state["buckets"][bucket_id] + if state_bucket.param_remainders_shard is not None: + bf16_rem_buckets.add(bucket_id) + if bf16_rem_buckets: + self._local_step_with_param_remainders(sorted(bf16_rem_buckets)) + bucket_ids = [ + bucket_id + for bucket_id in bucket_ids + if bucket_id not in bf16_rem_buckets + ] + if not bucket_ids: + return - # Find param fragments in local shard - buffers = collections.defaultdict(list) # p, m, v, g, p_copy - for fragment in fragments: - if fragment.in_local_shard: - param_group_id = fragment.param_group_id - shard_start, shard_end = fragment.shard_range - buffers[param_group_id].append([ - params_shard[shard_start:shard_end], - exp_avg[shard_start:shard_end], - exp_avg_sq[shard_start:shard_end], - grads[shard_start:shard_end], - params_copy[shard_start:shard_end], - ]) - - # Fuse param fragments if possible - if len(buffers) == 1: - group_id = list(buffers.keys())[0] - buffers[group_id] = [( - params_shard, - exp_avg, - exp_avg_sq, - grads, - params_copy, - )] - - # Apply optimizer step to each param group - for group_id, group_buffers in buffers.items(): - - # Get param group configs - group = self.param_groups[group_id] - beta1, beta2 = group['betas'] - bias_correction = 1 if group['bias_correction'] else 0 - eps = group['eps'] - weight_decay = group['weight_decay'] - - # Copy param group configs to GPU - num_fragments = len(group_buffers) - beta1 = torch.full([num_fragments], beta1, dtype=self.dtype, device='cuda') - beta2 = torch.full([num_fragments], beta2, dtype=self.dtype, device='cuda') - bias_correction = torch.full([num_fragments], bias_correction, dtype=torch.int32, device='cuda') - eps = torch.full([num_fragments], eps, dtype=self.dtype, device='cuda') - weight_decay = torch.full([num_fragments], weight_decay, dtype=self.dtype, device='cuda') - - # Apply Adam step - dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device='cuda') - multi_tensor_applier( - distributed_adam_cuda.multi_tensor_fused_adam, - dummy_overflow_buf, - list(zip(*group_buffers)), - beta1, - beta2, - bias_correction, - eps, - weight_decay, - group['lr'], - inv_grad_scale, - self.state['step'], - 1, # Set to 0 to apply eps inside sqrt + # Find param fragments for each bucket + buffers = collections.defaultdict(list) # p_in, m, v, g, p_out + for bucket_id in bucket_ids: + state_bucket = self.state["buckets"][bucket_id] + grads_bucket = self._grads_buckets[bucket_id] + params_bucket = self._params_buckets[bucket_id] + + # Optimizer state buffers for local shard + fragments = state_bucket.fragments + exp_avg = state_bucket.exp_avg_shard + exp_avg_sq = state_bucket.exp_avg_sq_shard + grads = grads_bucket.grads_shard + params_out = params_bucket.params_shard + + # Find param fragments in local shard + for fragment in fragments: + if not fragment.in_local_shard: + continue + shard_start, shard_end = fragment.shard_range + if shard_end <= shard_start: + continue + shard_range = slice(shard_start, shard_end) + if state_bucket.params_shard is None: + param = self.parameter(fragment) + if param.is_contiguous(memory_format=torch.channels_last): + param = param.permute(0, 2, 3, 1) + param_range = slice(*fragment.shard_param_range) + param_fragment = param.detach().view(-1)[param_range] + param_fragment = param_fragment.to( + dtype=state_bucket.dtype, device=self.device ) + else: + params_shard = state_bucket.params_shard + param_fragment = params_shard[shard_range] + buffers_key = ( + fragment.param_group_id, + state_bucket.dtype, + state_bucket.grad_sync_dtype, + state_bucket.param_sync_dtype, + ) + buffers[buffers_key].append( + [ + param_fragment, + exp_avg[shard_range], + exp_avg_sq[shard_range], + grads[shard_range], + params_out[shard_range], + ] + ) + + # Apply optimizer step to each param group + adam_func = distributed_adam_cuda.multi_tensor_fused_adam_capturable \ + if self.capturable else distributed_adam_cuda.multi_tensor_fused_adam + for (group_id, _, _, _), group_buffers in buffers.items(): + group = self.param_groups[group_id] + beta1, beta2 = group["betas"] + multi_tensor_applier( + adam_func, + self._dummy_overflow_buf, + list(zip(*group_buffers)), + self._grad_scale, + group["lr"], + beta1, + beta2, + group["eps"], + self.state["step"], + 1 if self.adam_w_mode else 0, + 1 if group["bias_correction"] else 0, + group["weight_decay"], + ) - # Cast parameter dtype if needed - if params_copy.data_ptr() != params_bucket_shard.data_ptr(): - params_bucket_shard.copy_(params_copy) + # Make sure param sync buffer has correct dtype + self._check_params_shard_dtypes( + { + bucket_id: self._params_buckets[bucket_id] + for bucket_id in bucket_ids + } + ) - # Allgather updated parameters - if self.distributed_size > 1: - all_params_bucket_shards = [ - params_bucket[i*self.shard_size:(i+1)*self.shard_size] - for i in range(self.distributed_size) + def _local_step_with_param_remainders( + self, + bucket_ids: List[int], + ) -> None: + """Apply optimizer step to local shard of parameter bucket + + This is an experimental implementation that expects + store_params=False and store_param_remainders=True. The + optimizer dtype must be FP32 and the params must all be BF16 + and GPU. + + Arguments: + bucket_ids (list): bucket indices + + """ + + # Find param fragments for each bucket + buffers = collections.defaultdict(list) # p_in, p_rem, m, v, g, p_out + for bucket_id in bucket_ids: + state_bucket = self.state["buckets"][bucket_id] + grads_bucket = self._grads_buckets[bucket_id] + params_bucket = self._params_buckets[bucket_id] + + # State buffers for local shard + fragments = state_bucket.fragments + param_remainders_shard = state_bucket.param_remainders_shard + exp_avg = state_bucket.exp_avg_shard + exp_avg_sq = state_bucket.exp_avg_sq_shard + grads = grads_bucket.grads_shard + params_out = params_bucket.params_shard + + # Find param fragments in local shard + for fragment in fragments: + if not fragment.in_local_shard: + continue + shard_start, shard_end = fragment.shard_range + if shard_end <= shard_start: + continue + shard_range = slice(shard_start, shard_end) + buffers_key = ( + fragment.param_group_id, + state_bucket.grad_sync_dtype, + ) + param = self.parameter(fragment) + param_range = slice(*fragment.shard_param_range) + param_fragment = param.detach().view(-1)[param_range] + param_fragment = param_fragment.to( + dtype=torch.bfloat16, device=self.device + ) + buffers[buffers_key].append( + [ + param_fragment, + param_remainders_shard[shard_range], + exp_avg[shard_range], + exp_avg_sq[shard_range], + grads[shard_range], + params_out[shard_range], ] - if self._all_gather_no_copy: - no_copy_kwarg = { 'no_copy': True } - else: - no_copy_kwarg = {} - torch.distributed.all_gather( - all_params_bucket_shards, - params_bucket_shard, - group=self.distributed_process_group, - **no_copy_kwarg, - ) + ) - # Copy values to param buffers - buffers = collections.defaultdict(list) # param_in, param_out - for fragment in fragments: - param_group_id = fragment.param_group_id - param_id = fragment.param_id - param = self.param_groups[param_group_id]['params'][param_id] - bucket_start, bucket_end = fragment.bucket_range - param_start, param_end = fragment.param_range - param_in = params_bucket[bucket_start:bucket_end] - param_out = param.detach().view(-1)[param_start:param_end] - if param_in.dtype == param_out.dtype: - # Just copy bytes if buffers have same type - param_in = param_in.view(torch.uint8) - param_out = param_out.view(torch.uint8) - buffers[(param.is_cuda, param.dtype)].append( - (param_in, param_out) - ) - for (is_cuda, dtype), dtype_buffers in buffers.items(): - fused_kernel_dtypes = ( - self.param_sync_dtype, - torch.float32, - torch.float16, - torch.uint8, - ) - if is_cuda and dtype in fused_kernel_dtypes: - dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device='cuda') - multi_tensor_applier( - fused_adam_cuda.maybe_cast_mt, - dummy_overflow_buf, - list(zip(*dtype_buffers)), - ) - else: - for param_in, param_out in dtype_buffers: - param_out.copy_(param_in) + # Apply optimizer step to each param group + for (group_id, _), group_buffers in buffers.items(): + group = self.param_groups[group_id] + beta1, beta2 = group["betas"] + multi_tensor_applier( + distributed_adam_cuda.multi_tensor_fused_adam_with_param_remainders, + self._dummy_overflow_buf, + list(zip(*group_buffers)), + self._grad_scale, + group["lr"], + beta1, + beta2, + group["eps"], + self.state["step"], + 1 if self.adam_w_mode else 0, + 1 if group["bias_correction"] else 0, + group["weight_decay"], + ) - # Synchronize pipeline streams - for stream in self._pipeline_streams: - main_stream.wait_stream(stream) + # Make sure param sync buffer has correct dtype + self._check_params_shard_dtypes( + { + bucket_id: self._params_buckets[bucket_id] + for bucket_id in bucket_ids + } + ) - return loss + @torch.no_grad() + def _local_step_with_scaled_states( + self, + bucket_ids: List[int], + ) -> None: + for bucket_id in bucket_ids: + state_bucket = self.state["buckets"][bucket_id] + grads_bucket = self._grads_buckets[bucket_id] + params_bucket = self._params_buckets[bucket_id] + params_bucket.params_shard = torch.empty_like( + state_bucket.params_shard, + dtype=torch.float32, + ) + + # Find param fragments in local shard + group_buffers = collections.defaultdict(list) # p_in, m, v, g, p_out + scaled_buffers = [] + unscaled_buffers = [] + buffer_scales = [] + for fragment in state_bucket.fragments: + if not fragment.in_local_shard: + continue + shard_start, shard_end = fragment.shard_range + if shard_end <= shard_start: + continue + shard_range = slice(shard_start, shard_end) + param_group_id = fragment.param_group_id + param_id = fragment.param_id + scaled_param = state_bucket.params_shard[shard_range] + scaled_exp_avg = state_bucket.exp_avg_shard[shard_range] + scaled_exp_avg_sq = state_bucket.exp_avg_sq_shard[shard_range] + grads = grads_bucket.grads_shard[shard_range] + param = params_bucket.params_shard[shard_range] + exp_avg = torch.empty_like(scaled_exp_avg, dtype=torch.float32) + exp_avg_sq = torch.empty_like(scaled_exp_avg_sq, dtype=torch.float32) + scales = self._state_scales[(param_group_id, param_id, bucket_id)] + group_buffers[param_group_id].append( + (param, exp_avg, exp_avg_sq, grads, param) + ) + scaled_buffers.extend( + (scaled_param, scaled_exp_avg, scaled_exp_avg_sq) + ) + unscaled_buffers.extend((param, exp_avg, exp_avg_sq)) + buffer_scales.extend( + (scales["param"], scales["exp_avg"], scales["exp_avg_sq"]) + ) + + # Unscale optimizer state + _multi_tensor_copy( + scaled_buffers, + unscaled_buffers, + dummy_overflow_buf=self._dummy_overflow_buf, + ) + for buf, scale in zip(unscaled_buffers, buffer_scales): + buf.mul_(scale) + + # Apply optimizer step to each param group + for group_id, buffers in group_buffers.items(): + group = self.param_groups[group_id] + beta1, beta2 = group["betas"] + multi_tensor_applier( + distributed_adam_cuda.multi_tensor_fused_adam, + self._dummy_overflow_buf, + list(zip(*buffers)), + self._grad_scale, + group["lr"], + beta1, + beta2, + group["eps"], + self.state["step"], + 1 if self.adam_w_mode else 0, + 1 if group["bias_correction"] else 0, + group["weight_decay"], + ) + del group_buffers + + # Make sure param sync buffer has correct dtype + self._check_params_shard_dtypes({bucket_id: params_bucket}) + + # Scale optimizer state + for buf, scale in zip(unscaled_buffers, buffer_scales): + self._apply_state_scale(buf, scale) + _multi_tensor_copy( + unscaled_buffers, + scaled_buffers, + dummy_overflow_buf=self._dummy_overflow_buf, + ) + del scaled_buffers, unscaled_buffers, buffer_scales + + @torch.no_grad() + def _check_params_shard_dtypes( + self, + params_buckets: Dict[int, ParameterBucket], + ) -> None: + """Make sure local shards of parameters are in expected datatypes + + The Adam kernel only supports floating-point datatypes. If we + want to perform parameter synchronization with + non-floating-point dtypes, we need to allocate temporary + buffers that can accommodate the Adam kernel. This function is + responsible for converting these temporary buffers to the + parameter synchronization datatype. + + """ + + # Find param shards that require dtype conversion + buffers_in = [] + buffers_out = [] + for bucket_id, param_bucket in params_buckets.items(): + + # Check if param shard is already in expected dtype + state_bucket = self.state["buckets"][bucket_id] + param_sync_dtype = state_bucket.param_sync_dtype + if param_bucket.params_shard.dtype == param_sync_dtype: + continue + + # Allocate buffer with required dtype + buffer_in = param_bucket.params_shard + buffer_out = torch.empty_like( + param_bucket.params_shard, + dtype=param_sync_dtype, + ) + param_bucket.params_shard = buffer_out + + if ( + torch.is_floating_point(buffer_in) + and torch.is_floating_point(buffer_out) + ): + # Cast between floating-point dtypes + buffers_in.append(buffer_in) + buffers_out.append(buffer_out) + else: + # Copy most significant bytes for non-floating-point + # dtypes + # Note: Assume dtypes are little-endian + in_bytes = buffer_in.unsqueeze(-1).view(torch.uint8) + out_bytes = buffer_out.unsqueeze(-1).view(torch.uint8) + copy_size = min(in_bytes.size(-1), out_bytes.size(-1)) + buffers_in.append(in_bytes[..., -copy_size:]) + buffers_out.append(out_bytes[..., -copy_size:]) + if copy_size < out_bytes.size(-1): + out_bytes[..., :-copy_size].zero_() + + # Perform dtype conversions + _multi_tensor_copy( + buffers_in, + buffers_out, + dummy_overflow_buf=self._dummy_overflow_buf, + ) + + @torch.no_grad() + def _apply_state_scale( + self, + tensor: torch.Tensor, + scale: torch.Tensor, + ) -> None: + """Compute and apply scaling factor for scaled optimizer state - def state_dict(self, gather_on_root=True): + The scaling factor is chosen to maximize the dynamic range + while avoiding numerical overflows. The returned tensors are + the scale (used to unscale the optimizer state) and the + scale-reciprocal (used to generate the scaled optimizer + state). The input tensors are updated in-place. + + """ + if not hasattr(self, "_max_scaled_state"): + self._max_scaled_state = torch.full( + [1], + torch.finfo(self.dtype).max / 2, + dtype=torch.float32, + device=self.device, + ) + min_val, max_val = torch.aminmax(tensor) + absmax = torch.maximum(-min_val, max_val) + absmax = absmax.to(dtype=torch.float32, device=self.device) + torch.div(absmax, self._max_scaled_state, out=scale) + rscale = torch.where(scale > 0, scale.reciprocal(), 0.0) + tensor.mul_(rscale) + + def state_dict( + self, + *, + state_dict_format: Optional[int] = None, + gather_on_root: Optional[bool] = None, + ) -> Optional[dict]: """Get dictionary containing optimizer state + All ranks in the process group must call this function since + it performs communication. The same optimizer state is + returned on all ranks. + + Arguments: + state_dict_format (int, optional): Tag for custom or + deprecated state dict format. + gather_on_root (bool, optional): Option for deprecated v1 + format. + + """ + + # Default state dict format + if state_dict_format is None: + state_dict_format = 2 + + # Construct state dict + state_dict = None + if state_dict_format == 1: + # Deprecated v1 format + kwargs = {} + if gather_on_root is not None: + kwargs["gather_on_root"] = gather_on_root + state_dict = self._state_dict_v1(**kwargs) + elif state_dict_format == 2: + # Default v2 format + state_dict = self._state_dict_v2() + else: + # Unrecognized format + raise ValueError(f"Unrecognized state dict format ({state_dict_format})") + + # Add format tag to state dict + if state_dict is not None: + state_dict["format"] = state_dict_format + + return state_dict + + def _state_dict_v1(self, gather_on_root: bool = True) -> Optional[dict]: + """Get dictionary containing optimizer state (deprecated v1 format) + Default behavior is to perform communication so that the entire optimizer state is returned on the root rank in the process group. In this case, all ranks in the process group @@ -1134,10 +3013,23 @@ def state_dict(self, gather_on_root=True): ranks on the root rank (default: True) """ + warnings.warn( + "Making optimizer state dictionary in deprecated v1 format. " + "Future support is not guaranteed." + ) + if self.with_scaled_states: + raise NotImplementedError( + "Deprecated v1 format does not support scaled state" + ) + state_dict = super().state_dict() if not gather_on_root: return state_dict + # Finish any asynchronous communication + self.grad_sync() + self.param_sync() + # Export local state to byte string state_bytes = io.BytesIO() torch.save(state_dict, state_bytes) @@ -1155,10 +3047,17 @@ def state_dict(self, gather_on_root=True): max_state_size = max(state_sizes) # Construct workspace buffers - chunk_size = self.shard_size * torch.finfo(self.grad_sync_dtype).bits // 8 + chunk_size = ( + self.default_shard_size * torch.finfo(self.grad_sync_dtype).bits // 8 + ) if self.distributed_rank == 0: - gathered_state_bytes = [state_bytes.getvalue()] - gathered_state_bytes.extend(bytearray(size) for size in state_sizes[1:]) + gathered_state_bytes = [ + torch.empty([size], dtype=torch.uint8, device="cpu") + for size in state_sizes + ] + gathered_state_bytes[0].copy_( + torch.frombuffer(state_bytes_view, dtype=torch.uint8) + ) gathered_chunks_buffers = [ torch.empty( [chunk_size * self.distributed_size], @@ -1180,31 +3079,28 @@ def state_dict(self, gather_on_root=True): # Split data into chunks and gather on root rank # Note: Assuming we are using the NCCL backend, communication # must happen on the GPU. We split the data into fixed-size - # chunks so that the GPU memory usage is limited to - # (chunk_size * distributed_size) bytes. - # TODO: Avoid chunking with direct communication between CPUs + # chunks to limit GPU memory usage. main_stream = torch.cuda.current_stream() for stream in self._pipeline_streams: stream.wait_stream(main_stream) for stream_id, offset in enumerate(range(0, max_state_size, chunk_size)): stream_id %= self.pipeline_size - - # Buffers for chunk - if self.distributed_rank == 0: - gathered_chunks = [ - gathered_chunks_buffers[stream_id][i*chunk_size:(i+1)*chunk_size] - for i in range(self.distributed_size) - ] - else: - chunk = chunk_buffers[stream_id] - - # Perform communication on parallel stream stream = self._pipeline_streams[stream_id] with torch.cuda.stream(stream): + # Buffers for chunk + if self.distributed_rank == 0: + gathered_chunks = [ + gathered_chunks_buffers[stream_id][ + i * chunk_size : (i + 1) * chunk_size + ] + for i in range(self.distributed_size) + ] + else: + chunk = chunk_buffers[stream_id] # Copy to GPU if self.distributed_rank != 0 and offset < local_state_size: - local_chunk_size = min(chunk_size, local_state_size-offset) + local_chunk_size = min(chunk_size, local_state_size - offset) chunk[:local_chunk_size].copy_( torch.frombuffer( state_bytes_view, @@ -1216,39 +3112,43 @@ def state_dict(self, gather_on_root=True): ) # Gather on root - if self.distributed_rank == 0: - if self._gather_no_copy: - no_copy_kwarg = { 'no_copy': True } + # Note: Call in main stream to avoid memory pool + # overheads from internal memory allocations in + # gather. + main_stream.wait_stream(stream) + with torch.cuda.stream(main_stream): + if self.distributed_rank == 0: + if self._gather_no_copy: + no_copy_kwarg = {"no_copy": True} + else: + no_copy_kwarg = {} + torch.distributed.gather( + gathered_chunks[0], + gathered_chunks, + dst=self.process_group_root, + group=self.process_group, + **no_copy_kwarg, + ) else: - no_copy_kwarg = {} - torch.distributed.gather( - gathered_chunks[0], - gathered_chunks, - dst=self._process_group_ranks[0], - group=self.process_group, - **no_copy_kwarg, - ) - else: - torch.distributed.gather( - chunk, - dst=self._process_group_ranks[0], - group=self.process_group, - ) + torch.distributed.gather( + chunk, + dst=self.process_group_root, + group=self.process_group, + ) + stream.wait_stream(main_stream) # Copy back to CPU if self.distributed_rank == 0: for rank in range(1, self.distributed_size): - if offset < state_sizes[rank]: - rank_chunk_size = min(chunk_size, state_sizes[rank]-offset) - torch.frombuffer( - gathered_state_bytes[rank], - dtype=torch.uint8, - count=rank_chunk_size, - offset=offset, - ).copy_( - gathered_chunks[rank][:rank_chunk_size], - non_blocking=True, - ) + rank_chunk_start = offset + rank_chunk_end = min(offset + chunk_size, state_sizes[rank]) + rank_chunk_size = rank_chunk_end - rank_chunk_start + if rank_chunk_size > 0: + src = gathered_chunks[rank][:rank_chunk_size] + dst = gathered_state_bytes[rank][ + rank_chunk_start:rank_chunk_end + ] + dst.copy_(src, non_blocking=True) # Synchronize GPU for stream in self._pipeline_streams: @@ -1257,24 +3157,443 @@ def state_dict(self, gather_on_root=True): # Return gathered state data on root rank if self.distributed_rank == 0: - return {'gathered_states': gathered_state_bytes} + return {"gathered_states": gathered_state_bytes} else: return None - def load_state_dict(self, state_dict): + @torch.no_grad() + def _state_dict_v2(self) -> Optional[dict]: + """Get dictionary containing optimizer state (default v2 format) + + All ranks in the process group must call this function since + it performs communication. The same optimizer state is + returned on all ranks. + + """ + + # Make sure params are initialized + self.init_params() + + # Finish any asynchronous communication + self.grad_sync() + self.param_sync() + + # Output tensor format + dtype = torch.float32 if self.with_scaled_states else self.dtype + device = torch.device("cpu") + + # Get state dict from base class + state_dict = super().state_dict() + state_dict["state"] = {"step": state_dict["state"]["step"]} + + # Initialize state dict with CPU buffers + for param in self.parameters(): + # Get param index in state dict + fragment = self.state[param]["fragments"][0] + param_group_id = fragment.param_group_id + param_id = fragment.param_id + index = state_dict["param_groups"][param_group_id]["params"][param_id] + + # Construct CPU buffers with optimizer state + state_dict["state"][index] = dict( + param=torch.zeros_like(param, dtype=dtype, device=device), + exp_avg=torch.zeros_like(param, dtype=dtype, device=device), + exp_avg_sq=torch.zeros_like(param, dtype=dtype, device=device), + ) + + # Workspace buffers for gathering shards on root rank + num_buckets = len(self.state["buckets"]) + max_bucket_size = max(bucket.bucket_size for bucket in self.state["buckets"]) + bucket_buffers = [ + torch.empty( + [max_bucket_size], + dtype=dtype, + device=self.device, + ) + for _ in range(self.pipeline_size) + ] + if self.store_param_remainders: + max_shard_size = max(bucket.shard_size for bucket in self.state["buckets"]) + shard_bf16_buffers = [ + torch.empty([max_shard_size], dtype=torch.bfloat16, device=self.device) + for _ in range(self.pipeline_size) + ] + + # Synchronize streams + main_stream = torch.cuda.current_stream() + for stream in self._pipeline_streams: + stream.wait_stream(main_stream) + + def get_workspace_shard(bucket_id: int) -> torch.Tensor: + """Workspace buffer for local shard""" + bucket = self.state["buckets"][bucket_id] + shard_size = bucket.shard_size + stream_id = bucket_id % self.pipeline_size + shard_range = slice( + shard_size * self.distributed_rank, + shard_size * (self.distributed_rank + 1), + ) + return bucket_buffers[stream_id][shard_range] + + def unscale_shard( + bucket_id: int, + shard: torch.Tensor, + state_key: str, + ) -> torch.Tensor: + """Unscale local shard if needed + + If state buffers are scaled, then the shard is unscaled + and output to a workspace buffer. Otherwise, the shard is + immediately returned. + + """ + if not self.with_scaled_states: + return shard + out = get_workspace_shard(bucket_id) + bucket = self.state["buckets"][bucket_id] + stream_id = bucket_id % self.pipeline_size + stream = self._pipeline_streams[stream_id] + with torch.cuda.stream(stream): + for fragment in bucket.fragments: + if not fragment.in_local_shard: + continue + param_group_id = fragment.param_group_id + param_id = fragment.param_id + shard_range = slice(*fragment.shard_range) + scale = self._state_scales[(param_group_id, param_id, bucket_id)][state_key] + out[shard_range].copy_(shard[shard_range]).mul_(scale) + return out + + def pack_param_shard(bucket_id: int) -> torch.Tensor: + """Pack local shard of param values into contiguous buffer""" + + # Stream objects + stream_id = bucket_id % self.pipeline_size + stream = self._pipeline_streams[stream_id] + + # Bucket objects + bucket = self.state["buckets"][bucket_id] + shard_size = bucket.shard_size + + # Case 1: Param state is already packed + if bucket.params_shard is not None: + return unscale_shard(bucket_id, bucket.params_shard, "param") + + # Case 2: Pack BF16 model params with 16-bit remainders + if bucket.param_remainders_shard is not None: + with torch.cuda.stream(stream): + # Pack bf16 param values + shard_bf16 = shard_bf16_buffers[stream_id][:shard_size] + buffers_in = [] + buffers_out = [] + for fragment in bucket.fragments: + if not fragment.in_local_shard: + continue + param_range = slice(*fragment.shard_param_range) + shard_range = slice(*fragment.shard_range) + param = self.parameter(fragment) + buffers_in.append(param.view(-1)[param_range]) + buffers_out.append(shard_bf16[shard_range]) + _multi_tensor_copy( + buffers_in, + buffers_out, + dummy_overflow_buf=self._dummy_overflow_buf, + ) + + # Reconstruct fp32 from bf16 and remainders + shard_fp32 = get_workspace_shard(bucket_id) + _bf16_rem_to_fp32( + shard_bf16, + bucket.param_remainders_shard, + shard_fp32, + ) + return shard_fp32 + + # Case 3: Pack model params + with torch.cuda.stream(stream): + shard = get_workspace_shard(bucket_id) + buffers_in = [] + buffers_out = [] + for fragment in bucket.fragments: + if not fragment.in_local_shard: + continue + param_range = slice(*fragment.shard_param_range) + shard_range = slice(*fragment.shard_range) + param = self.parameter(fragment) + buffers_in.append(param.view(-1)[param_range]) + buffers_out.append(shard[shard_range]) + _multi_tensor_copy( + buffers_in, + buffers_out, + dummy_overflow_buf=self._dummy_overflow_buf, + ) + return shard + + def start_all_gather(bucket_id: int, shard: torch.Tensor) -> None: + """Launch all-gather on bucket shards + + Communication is done on main stream to ensure consistent + ordering. + + """ + + # Stream objects + stream_id = bucket_id % self.pipeline_size + stream = self._pipeline_streams[stream_id] + + # Workspace buffer + bucket = self.state["buckets"][bucket_id] + bucket_size = bucket.bucket_size + bucket_buffer = bucket_buffers[stream_id][:bucket_size] + + # All-gather shards + main_stream.wait_stream(stream) + all_gather_into_tensor( + bucket_buffer, + shard, + group=self.distributed_process_group, + ) + stream.wait_stream(main_stream) + + def finish_all_gather(bucket_id: int, state_dict_key: str) -> None: + """Finish all-gather on bucket shards + + Data is copied into state dict CPU buffers. + + Splitting the NCCL all-gather and the CPU memcpys into + separate stages helps achieve good overlap when kernel + launches are serialized with + CUDA_DEVICE_MAX_CONNECTIONS=1. In particular, the pipeline + calls start_all_gather(bucket_id+1) before + finish_all_gather(bucket_id). + + """ + + # Stream objects + stream_id = bucket_id % self.pipeline_size + stream = self._pipeline_streams[stream_id] + + # Bucket objects + bucket = self.state["buckets"][bucket_id] + bucket_size = bucket.bucket_size + bucket_buffer = bucket_buffers[stream_id][:bucket_size] + + # Update state dict + with torch.cuda.stream(stream): + for fragment in bucket.fragments: + param_range = slice(*fragment.param_range) + bucket_range = slice(*fragment.bucket_range) + param_group_id = fragment.param_group_id + param_id = fragment.param_id + index = state_dict["param_groups"][param_group_id]["params"][ + param_id + ] + state_buffer = state_dict["state"][index][state_dict_key] + state_fragment = state_buffer.view(-1)[param_range] + bucket_fragment = bucket_buffer[bucket_range] + state_fragment.copy_(bucket_fragment, non_blocking=True) + + # All-gather param state + for bucket_id in range(num_buckets): + shard = pack_param_shard(bucket_id) + start_all_gather(bucket_id, shard) + if bucket_id > 0: + finish_all_gather(bucket_id - 1, "param") + if bucket_id == num_buckets - 1: + finish_all_gather(bucket_id, "param") + + # All-gather exp_avg state + for bucket_id in range(num_buckets): + shard = unscale_shard( + bucket_id, + self.state["buckets"][bucket_id].exp_avg_shard, + "exp_avg", + ) + start_all_gather(bucket_id, shard) + if bucket_id > 0: + finish_all_gather(bucket_id - 1, "exp_avg") + if bucket_id == num_buckets - 1: + finish_all_gather(bucket_id, "exp_avg") + + # All-gather exp_avg_sq state + for bucket_id in range(num_buckets): + shard = unscale_shard( + bucket_id, + self.state["buckets"][bucket_id].exp_avg_sq_shard, + "exp_avg_sq", + ) + start_all_gather(bucket_id, shard) + if bucket_id > 0: + finish_all_gather(bucket_id - 1, "exp_avg_sq") + if bucket_id == num_buckets - 1: + finish_all_gather(bucket_id, "exp_avg_sq") + + # Synchronize GPU and return + for stream in self._pipeline_streams: + main_stream.wait_stream(stream) + main_stream.synchronize() + return state_dict + + def load_state_dict(self, state_dict: dict) -> None: """Load optimizer state""" - # State dict contains state for all ranks - if 'gathered_states' in state_dict: + # Figure out state dict format + state_dict_format = state_dict.pop("format", None) + if state_dict_format is None: + if "buckets" in state_dict or "gathered_states" in state_dict: + state_dict_format = 1 + else: + state_dict_format = 2 + + # Load state dict + if state_dict_format == 1: + # Deprecated v1 format + self._load_state_dict_v1(state_dict) + elif state_dict_format == 2: + # Default v2 format + self._load_state_dict_v2(state_dict) + else: + # Unrecognized format + raise ValueError(f"Unrecognized state dict format ({state_dict_format})") + + def _load_state_dict_v1(self, state_dict: dict) -> None: + """Load optimizer state (deprecated v1 format) + + Parallel configuration (e.g. process group sizes) and + optimizer options must match between saving and loading the + optimizer state. + """ + warnings.warn( + "Loading checkpoint in deprecated v1 format. " + "Future support is not guaranteed." + ) + if self.with_scaled_states: + raise NotImplementedError( + "Deprecated v1 format does not support scaled state" + ) + + # Get state dict for current rank + if "gathered_states" in state_dict: # Deallocate distributed optimizer state to reduce GPU # memory usage - if 'buckets' in self.state: - del self.state['buckets'] + if "buckets" in self.state: + del self.state["buckets"] # Get state for current rank and parse byte string - state_bytes = state_dict['gathered_states'][self.distributed_rank] - state_bytes = io.BytesIO(state_bytes) + state_bytes = state_dict["gathered_states"][self.distributed_rank] + state_bytes = io.BytesIO(state_bytes.numpy()) state_dict = torch.load(state_bytes) - return super().load_state_dict(state_dict) + # Load state dict + super().load_state_dict(state_dict) + + # Handle old state dicts without per-bucket dtypes + for bucket in self.state["buckets"]: + if getattr(bucket, "dtype", None) is None: + bucket.dtype = self.dtype + if getattr(bucket, "grad_sync_dtype", None) is None: + bucket.grad_sync_dtype = self.grad_sync_dtype + if getattr(bucket, "param_sync_dtype", None) is None: + bucket.param_sync_dtype = self.param_sync_dtype + + if bucket.params_shard is not None: + bucket.params_shard = bucket.params_shard.to(self.device) + if bucket.param_remainders_shard is not None: + bucket.param_remainders_shard = bucket.param_remainders_shard.to(self.device) + bucket.exp_avg_shard = bucket.exp_avg_shard.to(self.device) + bucket.exp_avg_sq_shard = bucket.exp_avg_sq_shard.to(self.device) + + @torch.no_grad() + def _load_state_dict_v2(self, state_dict: dict) -> None: + """Load optimizer state (default v2 format) + + The parallel configuration and optimizer options are allowed + to differ between saving and loading the model. + + """ + + # Make sure params are initialized + self.init_params() + + # Finish any asynchronous communication + self.grad_sync() + self.param_sync() + + # Load general state + # Note: State includes bucketing scheme (e.g. + # self.state["buckets"] and self.state[param]["fragments"]). + # This was needed for v1 checkpoints, but not for v2. As a + # kludge, we temporarily set state to dummy dict to avoid + # messing up the bucketing scheme. + state = self.state + self.state = {} + super().load_state_dict( + { + "state": {}, + "param_groups": state_dict["param_groups"], + } + ) + self.state = state + self.state["step"] = state_dict["state"]["step"] + + # Load state for each param + for param in self.parameters(): + # Get param index in state dict + fragment = self.state[param]["fragments"][0] + param_id = fragment.param_id + param_group_id = fragment.param_group_id + index = state_dict["param_groups"][param_group_id]["params"][param_id] + + # Buffers in state dict + param_state = state_dict["state"][index]["param"].view(-1) + exp_avg = state_dict["state"][index]["exp_avg"].view(-1) + exp_avg_sq = state_dict["state"][index]["exp_avg_sq"].view(-1) + + # Copy to local shard of state buckets + for fragment in self.state[param]["fragments"]: + if not fragment.in_local_shard: + continue + bucket_id = fragment.bucket_id + bucket = self.state["buckets"][bucket_id] + param_range = slice(*fragment.shard_param_range) + shard_range = slice(*fragment.shard_range) + if self.with_scaled_states: + scales = self._state_scales[(param_group_id, param_id, bucket_id)] + temp = torch.empty_like( + param_state[param_range], + dtype=torch.float32, + device=self.device, + ) + temp.copy_(param_state[param_range], non_blocking=True) + self._apply_state_scale(temp, scales["param"]) + bucket.params_shard[shard_range].copy_(temp) + temp.copy_(exp_avg[param_range], non_blocking=True) + self._apply_state_scale(temp, scales["exp_avg"]) + bucket.exp_avg_shard[shard_range].copy_(temp) + temp.copy_(exp_avg_sq[param_range], non_blocking=True) + self._apply_state_scale(temp, scales["exp_avg_sq"]) + bucket.exp_avg_sq_shard[shard_range].copy_(temp) + else: + if bucket.params_shard is not None: + bucket.params_shard[shard_range].copy_( + param_state[param_range], + non_blocking=True, + ) + if bucket.param_remainders_shard is not None: + param_state_int16 = param_state.unsqueeze(-1).view(torch.int16) + bucket.param_remainders_shard[shard_range].copy_( + param_state_int16[param_range, 0], + non_blocking=True, + ) + bucket.exp_avg_shard[shard_range].copy_( + exp_avg[param_range], + non_blocking=True, + ) + bucket.exp_avg_sq_shard[shard_range].copy_( + exp_avg_sq[param_range], + non_blocking=True, + ) + + # Synchronize GPU + torch.cuda.current_stream().synchronize() \ No newline at end of file diff --git a/apex/contrib/optimizers/distributed_fused_lamb.py b/apex/contrib/optimizers/distributed_fused_lamb.py index b5ec47b1a..0925bd04a 100644 --- a/apex/contrib/optimizers/distributed_fused_lamb.py +++ b/apex/contrib/optimizers/distributed_fused_lamb.py @@ -1,5 +1,6 @@ import os import math +import inspect import torch import importlib import amp_C @@ -7,36 +8,49 @@ import torch.distributed.distributed_c10d as c10d +# Fallback to private fields if using older PyTorch version +try: + import torch.distributed.distributed_c10d.get_process_group_ranks +except ImportError: + def get_process_group_ranks(group): + return list(c10d._pg_group_ranks[group].keys()) + +_make_nccl_premul_sum = getattr(torch.distributed, "_make_nccl_premul_sum", None) +# Ref: https://github.com/pytorch/pytorch/pull/81272 +if _make_nccl_premul_sum is None: + if hasattr(torch.distributed, "make_nccl_premul_sum"): + _make_nccl_premul_sum = torch.distributed.make_nccl_premul_sum + class DistributedFusedLAMB(torch.optim.Optimizer): """Implements LAMB algorithm. - + Currently GPU-only. Requires Apex to be installed via ``pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./``. - + This version of fused LAMB implements 2 fusions. - + * Fusion of the LAMB update's elementwise operations * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches. - + :class:`apex.optimizers.FusedLAMB`'s usage is identical to any ordinary Pytorch optimizer:: - + opt = apex.optimizers.FusedLAMB(model.parameters(), lr = ....) ... opt.step() - + :class:`apex.optimizers.FusedLAMB` may be used with or without Amp. If you wish to use :class:`FusedLAMB` with Amp, you may choose any ``opt_level``:: - + opt = apex.optimizers.FusedLAMB(model.parameters(), lr = ....) model, opt = amp.initialize(model, opt, opt_level="O0" or "O1 or "O2") ... opt.step() - + In general, ``opt_level="O1"`` is recommended. - + LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. - + Arguments: params (iterable): iterable of parameters to optimize or dicts defining parameter groups. @@ -61,7 +75,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer): weight decay parameter (default: False) step_supports_amp_scaling(boolean, optional): whether to use customized gradient unscaling logic (default: True) - + .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes: https://arxiv.org/abs/1904.00962 .. _On the Convergence of Adam and Beyond: @@ -82,8 +96,8 @@ def add(self, idx): def __init__(self, params, lr=1e-3, bias_correction = True, grad_averaging=True, - betas=(0.9, 0.999), eps=1e-8, - weight_decay=0., max_grad_norm=0., + betas=(0.9, 0.999), eps=1e-8, + weight_decay=0., max_grad_norm=0., adam_w_mode=True, use_nvlamb=False, step_supports_amp_scaling=True, overlap_reductions=True, dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4, @@ -123,13 +137,13 @@ def __init__(self, params, self._verbose = verbose self._clip_after_ar = clip_after_ar self._full_ar = full_ar - self._fuse_scale = fuse_scale + self._fuse_scale = fuse_scale self._L2_grad_norm = None self._set_flat_param_view = set_param_views_to_flat_buffer self._skip_ag = skip_allgather self._fused_norm = fused_norm if not clip_after_ar else False self._current_process_group = c10d._get_default_group() - self._available_ranks = list(c10d._pg_group_ranks[self._current_process_group].keys()) + self._available_ranks = get_process_group_ranks(self._current_process_group) self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size self._world_size = torch.distributed.get_world_size() self._num_groups = self._world_size // self._group_size @@ -143,8 +157,18 @@ def __init__(self, params, # Master weight, moment, gradient buffers self._fp32_p, self._fp32_m, self._fp32_v, self._fp16_p, self._fp16_g = None, None, None, None, None - import inspect - assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option" + # Check if collectives have no_copy option + self._reduce_scatter_no_copy = ( + 'no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args + ) + self._all_gather_no_copy = ( + 'no_copy' in inspect.getfullargspec(torch.distributed.all_gather).args + ) + + if "reduce_scatter_tensor" not in dir(torch.distributed): + torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base + if "all_gather_into_tensor" not in dir(torch.distributed): + torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base self._num_rs_pg = dwu_num_rs_pg self._num_ar_pg = dwu_num_ar_pg @@ -376,14 +400,17 @@ def __shardify(p): list_of_list_of_chunks = [__chunkify(block) for block in list_of_blocks] list_of_list_of_list_of_shards = [[__shardify(chunk) for chunk in chunks] for chunks in list_of_list_of_chunks] return list_of_blocks, list_of_list_of_chunks, list_of_list_of_list_of_shards - def _flat_split_no_shards(p): - def __blockify(p): - return [p[block_id*self._block_size:(block_id+1)*self._block_size] for block_id in range(self._num_blocks)] - def __chunkify(p): - return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)] - list_of_blocks = __blockify(self._flat_grads) - list_of_list_of_chunks = [__chunkify(block) for block in list_of_blocks] - return list_of_blocks, list_of_list_of_chunks + + # note(crcrpar): the function below doesn't seem to be used at all. + # def _flat_split_no_shards(p): + # def __blockify(p): + # return [p[block_id*self._block_size:(block_id+1)*self._block_size] for block_id in range(self._num_blocks)] + # def __chunkify(p): + # return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)] + # list_of_blocks = __blockify(self._flat_grads) + # list_of_list_of_chunks = [__chunkify(block) for block in list_of_blocks] + # return list_of_blocks, list_of_list_of_chunks + def _full_packed_split(p): def __shardify(p): return [p[mega_shard*self._mega_shard_size:(mega_shard+1)*self._mega_shard_size] for mega_shard in range(self._group_size)] @@ -439,7 +466,7 @@ def _split_assign(shards): def _lazy_init_stage2(self): if self._lazy_init_stage2_done: return - if not self._set_flat_param_view: + if not self._set_flat_param_view: # reversing is needed for overlapping allreduce and backprop, but currently not supported for flat param view self._param_order.order.reverse() @@ -576,7 +603,7 @@ def set_is_accumulation_step(self, is_accumulation_step): def set_last_step(self, last_step): self._last_step = last_step - + def _get_flush_block(self): flush_block = [] if self._current_block > 0 and self._grads_generated[self._low_param_i[self._current_block-1]]: @@ -595,19 +622,19 @@ def _get_flush_block(self): def _full_all_reduce_scale(self, block_id, scale): works = [None]*self._num_chunks - if self._clip_after_ar: + if self._clip_after_ar: for chunk_id in range(self._num_chunks): glob_chunk_id = block_id * self._num_chunks + chunk_id ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg] ar_stream.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(ar_stream): - works[chunk_id] = torch.distributed.all_reduce(self._flat_grads_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True,op=torch.distributed.make_nccl_premul_sum((scale,))) + works[chunk_id] = torch.distributed.all_reduce(self._flat_grads_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True,op=_make_nccl_premul_sum(scale)) else: glob_chunk_id = block_id ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg] ar_stream.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(ar_stream): - works0 = torch.distributed.all_reduce(self._flat_grads_blocks[block_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True,op=torch.distributed.make_nccl_premul_sum((scale,))) + works0 = torch.distributed.all_reduce(self._flat_grads_blocks[block_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True,op=_make_nccl_premul_sum(scale)) for i in range(self._num_chunks): works[i]=works0 self._reductions_works[block_id] = works @@ -634,7 +661,23 @@ def _reduce_scatter_and_all_reduce_scale(self, block_id, scale): rs_stream.wait_stream(torch.cuda.current_stream()) rs_stream.wait_stream(self._l2_grad_norm_st) with torch.cuda.stream(rs_stream): - works[chunk_id] = torch.distributed.reduce_scatter(self._fp16_g_chunks[block_id][chunk_id],self._flat_grads_shards[block_id][chunk_id],group=self._rs_pg[glob_chunk_id%self._num_rs_pg],async_op=True,no_copy=True,op=torch.distributed.make_nccl_premul_sum((scale,))) + if self._reduce_scatter_no_copy: + works[chunk_id] = torch.distributed.reduce_scatter( + output=self._fp16_g_chunks[block_id][chunk_id], + input_list=self._flat_grads_shards[block_id][chunk_id], + group=self._rs_pg[glob_chunk_id%self._num_rs_pg], + async_op=True, + no_copy=True, + op=_make_nccl_premul_sum(scale), + ) + else: + works[chunk_id] = torch.distributed.reduce_scatter_tensor( + output=self._fp16_g_chunks[block_id][chunk_id], + input=self._flat_grads_chunks[block_id][chunk_id], + group=self._rs_pg[glob_chunk_id%self._num_rs_pg], + async_op=True, + op=_make_nccl_premul_sum(scale), + ) # Reduction across nodes for each rank if self._num_groups > 1: @@ -656,7 +699,21 @@ def _reduce_scatter_and_all_reduce(self, block_id): rs_stream = self._rs_st[glob_chunk_id%self._num_rs_pg] rs_stream.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(rs_stream): - works[chunk_id] = torch.distributed.reduce_scatter(self._fp16_g_chunks[block_id][chunk_id],self._flat_grads_shards[block_id][chunk_id],group=self._rs_pg[glob_chunk_id%self._num_rs_pg],async_op=True,no_copy=True) + if self._reduce_scatter_no_copy: + works[chunk_id] = torch.distributed.reduce_scatter( + output=self._fp16_g_chunks[block_id][chunk_id], + input_list=self._flat_grads_shards[block_id][chunk_id], + group=self._rs_pg[glob_chunk_id%self._num_rs_pg], + async_op=True, + no_copy=True, + ) + else: + works[chunk_id] = torch.distributed.reduce_scatter_tensor( + output = self._fp16_g_chunks[block_id][chunk_id], + input = self._flat_grads_chunks[block_id][chunk_id], + group = self._rs_pg[glob_chunk_id%self._num_rs_pg], + async_op = True, + ) # Reduction across nodes for each rank if self._num_groups > 1: @@ -719,7 +776,7 @@ def _pipeline_block_reductions(self, block_id): if self._fuse_scale: self._full_all_reduce_scale(block_id, scale) else: - self._full_all_reduce(block_id) + self._full_all_reduce(block_id) else: if self._fuse_scale: self._reduce_scatter_and_all_reduce_scale(block_id, scale) @@ -809,13 +866,37 @@ def _pipeline_step(self): global_grad_norm, self._use_nvlamb) if not self._skip_ag: - # allgather chunking is currently not supported for clip after allreduce + # allgather chunking is currently not supported for clip after allreduce if not self._clip_after_ar: for block in range(self._num_blocks): for chunk in range(self._num_chunks): - torch.distributed.all_gather(self._new_params2_shards[block][chunk], self._fp16_p_chunks[block][chunk], group=self._ag_pg[0], no_copy=True) + if self._all_gather_no_copy: + torch.distributed.all_gather( + tensor_list = self._new_params2_shards[block][chunk], + tensor = self._fp16_p_chunks[block][chunk], + group = self._ag_pg[0], + no_copy = True, + ) + else: + torch.distributed.all_gather_into_tensor( + output_tensor = self._new_params2_blocks[block], + input_tensor = self._fp16_p_chunks[block][chunk], + group = self._ag_pg[0], + ) else: - torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True) + if self._all_gather_no_copy: + torch.distributed.all_gather( + tensor_list = self._new_params_mega_shards, + tensor = self._fp16_p, + group = self._ag_pg[0], + no_copy = True, + ) + else: + torch.distributed.all_gather_into_tensor( + output_tensor = self._new_params, + input_tensor = self._fp16_p, + group = self._ag_pg[0], + ) def _flatten_grad_mt(self, scale): if len(self._grads_fp16) > 0: @@ -931,7 +1012,7 @@ def step(self, closure=None, grad_scaler=None): fused_adam_cuda.maybe_cast_mt, self._overflow_buf, self._packed_flat_to_model_params_fp32) - + torch.cuda.current_stream().wait_stream(self._completion_st) self._reductions_works = [None]*self._num_blocks @@ -977,4 +1058,4 @@ def load_state_dict(self, state_dict): self._fp32_p = state_dict['fp32_p'].to(device="cuda") self._fp32_m = state_dict['fp32_m'].to(device="cuda") self._fp32_v = state_dict['fp32_v'].to(device="cuda") - self._resume_from_checkpoint = True + self._resume_from_checkpoint = True \ No newline at end of file diff --git a/apex/contrib/peer_memory/__init__.py b/apex/contrib/peer_memory/__init__.py index 367dc5854..8d6fa5480 100644 --- a/apex/contrib/peer_memory/__init__.py +++ b/apex/contrib/peer_memory/__init__.py @@ -1,2 +1,3 @@ from .peer_memory import PeerMemoryPool from .peer_halo_exchanger_1d import PeerHaloExchanger1d + diff --git a/apex/contrib/peer_memory/peer_halo_exchange_module_tests.py b/apex/contrib/peer_memory/peer_halo_exchange_module_tests.py index dd77856e3..bd85354af 100644 --- a/apex/contrib/peer_memory/peer_halo_exchange_module_tests.py +++ b/apex/contrib/peer_memory/peer_halo_exchange_module_tests.py @@ -40,8 +40,9 @@ def nccl_halo_ex(peer_rank, peer_group_size, y, half_halo, explicit_nhwc, H_spli btm_out_halo = y[:,:,:,W:W+half_halo] btm_inp_halo = y[:,:,:,W+half_halo:W+2*half_halo] - top_out_halo = top_out_halo.clone(memory_format=torch.preserve_format) - btm_out_halo = btm_out_halo.clone(memory_format=torch.preserve_format) + mf = torch.channels_last if y.is_contiguous(memory_format=torch.channels_last) else torch.contiguous_format + top_out_halo = top_out_halo.contiguous() + btm_out_halo = btm_out_halo.contiguous() top_inp_halos = [torch.empty_like(top_out_halo) for _ in range(peer_group_size)] torch.distributed.all_gather(top_inp_halos, top_out_halo) @@ -49,8 +50,14 @@ def nccl_halo_ex(peer_rank, peer_group_size, y, half_halo, explicit_nhwc, H_spli torch.distributed.all_gather(btm_inp_halos, btm_out_halo) top_rank = (peer_rank + peer_group_size - 1) % peer_group_size btm_rank = (peer_rank + 1) % peer_group_size - top_inp_halo.copy_(btm_inp_halos[top_rank]) - btm_inp_halo.copy_(top_inp_halos[btm_rank]) + if peer_rank == 0: + top_inp_halo.zero_() + else: + top_inp_halo.copy_(btm_inp_halos[top_rank].to(memory_format=mf)) + if peer_rank == peer_group_size-1: + btm_inp_halo.zero_() + else: + btm_inp_halo.copy_(top_inp_halos[btm_rank].to(memory_format=mf)) def single_test(peer_rank, peer_group_size, halo_ex, C, H, W, half_halo, dtype, memory_format, H_split, num_steps, numSM=1): @@ -141,12 +148,13 @@ def main(): rank = torch.distributed.get_rank() world_size = torch.distributed.get_world_size() torch.cuda.set_device(rank) - pool = PeerMemoryPool(rank, world_size, world_size, 64*1024, 2*1024*1024) + peer_ranks = [i for i in range(world_size)] + pool = PeerMemoryPool(64*1024, 2*1024*1024, peer_ranks) num_steps = 100 half_halo = 1 - halo_ex = PeerHaloExchanger1d(rank, world_size, pool, half_halo) + halo_ex = PeerHaloExchanger1d(peer_ranks, rank, pool, half_halo) H_split_tests(1,64,336,200, half_halo,rank,world_size,halo_ex,num_steps) W_split_tests(1,64,200,336, half_halo,rank,world_size,halo_ex,num_steps) diff --git a/apex/contrib/peer_memory/peer_halo_exchanger_1d.py b/apex/contrib/peer_memory/peer_halo_exchanger_1d.py index 33db83c06..cc25693ce 100644 --- a/apex/contrib/peer_memory/peer_halo_exchanger_1d.py +++ b/apex/contrib/peer_memory/peer_halo_exchanger_1d.py @@ -3,9 +3,15 @@ import peer_memory_cuda as pm class PeerHaloExchanger1d: - def __init__(self, rank, peer_group_size, peer_pool, half_halo): - self.peer_group_size = peer_group_size - self.peer_rank = rank % peer_group_size + def __init__(self, ranks, rank_in_group, peer_pool, half_halo): + self.peer_group_size = len(ranks) + self.ranks = ranks + self.peer_rank = rank_in_group + self.low_neighbor = (self.peer_rank + self.peer_group_size - 1) % self.peer_group_size + self.high_neighbor = (self.peer_rank + 1) % self.peer_group_size + self.low_zero = True if self.peer_rank == 0 else False + self.high_zero = True if self.peer_rank == self.peer_group_size - 1 else False + self.peer_pool = peer_pool self.signals = peer_pool.allocate_peer_tensors([2,4], torch.int32, False, False) self.signals[self.peer_rank].zero_() @@ -17,45 +23,43 @@ def __call__(self, y, H_split=True, explicit_nhwc=False, numSM=1, diagnostics=Fa if explicit_nhwc: _, Hs, _, _ = list(y.shape) H = Hs - 2*self.half_halo - top_out_halo = y[:,self.half_halo:2*self.half_halo,:,:] - top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, False, True) - top_inp_halo = y[:,:self.half_halo,:,:] - btm_out_halo = y[:,H:H+self.half_halo,:,:] - btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, False, True) - btm_inp_halo = y[:,H+self.half_halo:H+2*self.half_halo,:,:] + low_out_halo = y[:,self.half_halo:2*self.half_halo,:,:] + low_tx = self.peer_pool.allocate_peer_tensors(list(low_out_halo.shape), low_out_halo.dtype, False, True) + low_inp_halo = y[:,:self.half_halo,:,:] + high_out_halo = y[:,H:H+self.half_halo,:,:] + high_tx = self.peer_pool.allocate_peer_tensors(list(high_out_halo.shape), high_out_halo.dtype, False, True) + high_inp_halo = y[:,H+self.half_halo:H+2*self.half_halo,:,:] else: _, _, Hs, _ = list(y.shape) H = Hs - 2*self.half_halo - top_out_halo = y[:,:,self.half_halo:2*self.half_halo,:] - top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, channels_last, True) - top_inp_halo = y[:,:,:self.half_halo,:] - btm_out_halo = y[:,:,H:H+self.half_halo,:] - btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, channels_last, True) - btm_inp_halo = y[:,:,H+self.half_halo:H+2*self.half_halo,:] + low_out_halo = y[:,:,self.half_halo:2*self.half_halo,:] + low_tx = self.peer_pool.allocate_peer_tensors(list(low_out_halo.shape), low_out_halo.dtype, channels_last, True) + low_inp_halo = y[:,:,:self.half_halo,:] + high_out_halo = y[:,:,H:H+self.half_halo,:] + high_tx = self.peer_pool.allocate_peer_tensors(list(high_out_halo.shape), high_out_halo.dtype, channels_last, True) + high_inp_halo = y[:,:,H+self.half_halo:H+2*self.half_halo,:] else: if explicit_nhwc: _, _, Ws, _ = list(y.shape) W = Ws - 2*self.half_halo - top_out_halo = y[:,:,self.half_halo:2*self.half_halo,:] - top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, False, True) - top_inp_halo = y[:,:,:self.half_halo,:] - btm_out_halo = y[:,:,W:W+self.half_halo,:] - btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, False, True) - btm_inp_halo = y[:,:,W+self.half_halo:W+2*self.half_halo,:] + low_out_halo = y[:,:,self.half_halo:2*self.half_halo,:] + low_tx = self.peer_pool.allocate_peer_tensors(list(low_out_halo.shape), low_out_halo.dtype, False, True) + low_inp_halo = y[:,:,:self.half_halo,:] + high_out_halo = y[:,:,W:W+self.half_halo,:] + high_tx = self.peer_pool.allocate_peer_tensors(list(high_out_halo.shape), high_out_halo.dtype, False, True) + high_inp_halo = y[:,:,W+self.half_halo:W+2*self.half_halo,:] else: _, _, _, Ws = list(y.shape) W = Ws - 2*self.half_halo - top_out_halo = y[:,:,:,self.half_halo:2*self.half_halo] - top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, channels_last, True) - top_inp_halo = y[:,:,:,:self.half_halo] - btm_out_halo = y[:,:,:,W:W+self.half_halo] - btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, channels_last, True) - btm_inp_halo = y[:,:,:,W+self.half_halo:W+2*self.half_halo] - top_neighbor = (self.peer_rank + self.peer_group_size - 1) % self.peer_group_size - btm_neighbor = (self.peer_rank + 1) % self.peer_group_size + low_out_halo = y[:,:,:,self.half_halo:2*self.half_halo] + low_tx = self.peer_pool.allocate_peer_tensors(list(low_out_halo.shape), low_out_halo.dtype, channels_last, True) + low_inp_halo = y[:,:,:,:self.half_halo] + high_out_halo = y[:,:,:,W:W+self.half_halo] + high_tx = self.peer_pool.allocate_peer_tensors(list(high_out_halo.shape), high_out_halo.dtype, channels_last, True) + high_inp_halo = y[:,:,:,W+self.half_halo:W+2*self.half_halo] pm.push_pull_halos_1d( diagnostics, explicit_nhwc, numSM, - top_out_halo, top_tx[self.peer_rank], btm_tx[top_neighbor], top_inp_halo, - btm_out_halo, btm_tx[self.peer_rank], top_tx[btm_neighbor], btm_inp_halo, - self.signals[top_neighbor], self.signals[btm_neighbor], self.signals[self.peer_rank] + self.low_zero, low_out_halo, low_tx[self.peer_rank], high_tx[self.low_neighbor], low_inp_halo, + self.high_zero, high_out_halo, high_tx[self.peer_rank], low_tx[self.high_neighbor], high_inp_halo, + self.signals[self.low_neighbor], self.signals[self.high_neighbor], self.signals[self.peer_rank] ) diff --git a/apex/contrib/sparsity/sparse_masklib.py b/apex/contrib/sparsity/sparse_masklib.py index ed42d0456..48deb633c 100644 --- a/apex/contrib/sparsity/sparse_masklib.py +++ b/apex/contrib/sparsity/sparse_masklib.py @@ -29,8 +29,8 @@ def compute_valid_1d_patterns(m,n): if m==4 and n==2 and valid_m4n2_1d_patterns is not None: return valid_m4n2_1d_patterns patterns = torch.zeros(m) patterns[:n] = 1 - valid_patterns = torch.Tensor(list(set(permutations(patterns.tolist())))) - if m == 4 and n == 2: valid_m4n2_1d_patterns = valid_patterns + valid_patterns = torch.tensor(list(set(permutations(patterns.tolist())))) + if m == 4 and n == 2: valid_m4n2_1d_patterns = valid_patterns return valid_patterns """ m:n 1d structured best """ @@ -109,10 +109,10 @@ def compute_valid_2d_patterns(m,n): patterns[:n] = 1 patterns = list(set(permutations(patterns.tolist()))) patterns = patterns + patterns - patterns = torch.Tensor(list(set(permutations(patterns,m)))) + patterns = torch.empty(list(set(permutations(patterns,m)))) valid = ((patterns.sum(dim=1) <= n).sum(dim=1) == m).nonzero().view(-1) - valid_patterns = torch.Tensor(valid.shape[0],m,m) + valid_patterns = torch.empty(valid.shape[0],m,m) valid_patterns[:] = patterns[valid[:]] if m == 4 and n == 2: valid_m4n2_2d_patterns = valid_patterns diff --git a/apex/contrib/test/conv_bias_relu/test_conv_bias_relu.py b/apex/contrib/test/conv_bias_relu/test_conv_bias_relu.py index 350257c5c..f2a4492d2 100644 --- a/apex/contrib/test/conv_bias_relu/test_conv_bias_relu.py +++ b/apex/contrib/test/conv_bias_relu/test_conv_bias_relu.py @@ -54,11 +54,11 @@ def setUp(self, seed=0): self.conv_stride, self.conv_pad)) def test_conv_bias_relu(self): - with torch.cuda.amp.autocast(dtype=torch.half): + with torch.amp.autocast(device_type="cuda", dtype=torch.half): out = ConvBiasReLU(self.x, self.conv1.weight, self.conv1.bias.reshape(1, -1, 1, 1), self.conv_pad, self.conv_stride) loss = (out.float()**2).sum() / out.numel() loss.backward() - with torch.cuda.amp.autocast(dtype=torch.half): + with torch.amp.autocast(device_type="cuda", dtype=torch.half): out_ = F.relu(self.conv1_(self.x_)) loss_ = (out_**2).sum() / out_.numel() loss_.backward() @@ -69,12 +69,12 @@ def test_conv_bias_relu(self): self.assertTrue(torch.allclose(self.x_.grad, self.x.grad, atol=1e-3, rtol=1e-3, equal_nan=True)) def test_conv_bias(self): - with torch.cuda.amp.autocast(dtype=torch.half): + with torch.amp.autocast(device_type="cuda", dtype=torch.half): out = ConvBias(self.x, self.conv1.weight, self.conv1.bias.reshape(1, -1, 1, 1), self.conv_pad, self.conv_stride) loss = (out.float()**2).sum() / out.numel() loss.backward() - with torch.cuda.amp.autocast(dtype=torch.half): + with torch.amp.autocast(device_type="cuda", dtype=torch.half): out_ = self.conv1_(self.x_) loss_ = (out_**2).sum() / out_.numel() loss_.backward() @@ -85,11 +85,11 @@ def test_conv_bias(self): self.assertTrue(torch.allclose(self.x_.grad, self.x.grad, atol=1e-3, rtol=1e-3, equal_nan=True)) def test_conv_bias_mask_relu(self): - with torch.cuda.amp.autocast(dtype=torch.half): + with torch.amp.autocast(device_type="cuda", dtype=torch.half): out = ConvBiasMaskReLU(self.x, self.conv1.weight, self.conv1.bias.reshape(1, -1, 1, 1), self.mask, self.conv_pad, self.conv_stride) loss = (out.float()**2).sum() / out.numel() loss.backward() - with torch.cuda.amp.autocast(dtype=torch.half): + with torch.amp.autocast(device_type="cuda", dtype=torch.half): out_ = F.relu(self.conv1_(self.x_) * self.mask_) loss_ = (out_**2).sum() / out_.numel() loss_.backward() @@ -100,6 +100,41 @@ def test_conv_bias_mask_relu(self): self.assertTrue(torch.allclose(self.x_.grad, self.x.grad, atol=1e-3, rtol=1e-3, equal_nan=True)) + def test_conv_bias_retinanet(self): + # RetinaNet configuration + batch_size = 32 + in_channels = 256 + out_channels = 2376 + h, w = 100, 100 + + # Input in NHWC format with HALF precision + x = torch.randn(batch_size, in_channels, h, w).cuda()\ + .to(memory_format=torch.channels_last).half() + x_ = x.clone() + x.requires_grad_() + x_.requires_grad_() + + # Conv layer + conv = torch.nn.Conv2d(in_channels, out_channels, 3, + stride=1, padding=1).cuda()\ + .to(memory_format=torch.channels_last) + conv_ = copy.deepcopy(conv) + + # Test with FP16 + with torch.amp.autocast(device_type="cuda", dtype=torch.half): + out = ConvBias(x, conv.weight, conv.bias.reshape(1, -1, 1, 1), 1, 1) + loss = (out.float()**2).sum() / out.numel() + loss.backward() + + # Reference with FP16 + with torch.amp.autocast(device_type="cuda", dtype=torch.half): + out_ = conv_(x_) + loss_ = (out_**2).sum() / out_.numel() + loss_.backward() + + self.assertTrue(torch.allclose(out, out_, atol=1e-2, rtol=1e-2)) + + if __name__ == '__main__': unittest.main() diff --git a/apex/contrib/test/fused_dense/test_gelu.py b/apex/contrib/test/fused_dense/test_gelu.py new file mode 100644 index 000000000..9ff36d5ca --- /dev/null +++ b/apex/contrib/test/fused_dense/test_gelu.py @@ -0,0 +1,46 @@ +from apex import FusedDenseGeluDense +import torch +import torch.nn.functional as F + +batch_size = 4 +in_features = 3 +intermediate_features = 3 +out_features = 2 + +#tst_dtype = torch.float8_e4m3 +# tst_dtype = torch.float8_e5m2 +tst_dtype = torch.float16 + +# I = torch.randn(batch_size, in_features, dtype=tst_dtype, device='cuda') +I = torch.tensor([[1., 2. , 3., 4.], + [1., 2. , 3., 4.], + [1., 2. , 3., 4.], + [1., 2. , 3., 4.], + [1., 2. , 3., 4.]],dtype=tst_dtype, device='cuda') + +# W = torch.randn(out_features, in_features, dtype=tst_dtype, device='cuda') +W = torch.tensor([[1., 1. , 1. , 1. ], + [2., 2. , 2. , 2. ], + [3., 3. , 3. , 3. ]],dtype=tst_dtype, device='cuda') + +# b = torch.randn(in_features, dtype=tst_dtype, device='cuda') +b = torch.tensor([1, 1, 1], dtype=tst_dtype, device='cuda') + +print("Torch-A:\n", I) +print("Torch-B:\n", W) +print("Torch-b:\n", b) + +C = torch.matmul(I, W.t())+b +gelu_output = F.gelu(C) +print("Torch-C:\n", C) +print("Torch-Geli:\n", gelu_output) + +denseGlue = FusedDenseGeluDense.fused_dense_gelu_dense_function(in_features, intermediate_features, out_features) +denseGlue.to(dtype=tst_dtype) +denseGlue.cuda() +y_tst = denseGlue(I) + +print("Torch-aC:\n", aC) +print("GELU tensor:\n", gelu_output) + + diff --git a/apex/contrib/test/fused_dense/test_half.py b/apex/contrib/test/fused_dense/test_half.py new file mode 100644 index 000000000..1f67d2c6e --- /dev/null +++ b/apex/contrib/test/fused_dense/test_half.py @@ -0,0 +1,23 @@ +from apex import fused_dense +import torch + +batch_size = 5 +in_features = 4 +out_features = 3 + +tst_dtype = torch.float8_e5m2 + +I = torch.randn(batch_size, in_features, dtype=tst_dtype, device='cuda') + +W = torch.randn(in_features, out_features, dtype=tst_dtype, device='cuda') + +b = torch.randn(out_features, dtype=tst_dtype, device='cuda') + +print("Torch-A:\n", I) +print("Torch-B:\n", W) +print("Torch-b:\n", b) + + +aC = fused_dense.fused_dense_function(I, W, b) +print("Torch-aC:\n", aC) +torch.testing.assert_close(C, aC, atol=1e-3, rtol=1e-3, equal_nan=True) diff --git a/apex/contrib/test/groupbn/test_groupbn.py b/apex/contrib/test/groupbn/test_groupbn.py new file mode 100644 index 000000000..3df79175b --- /dev/null +++ b/apex/contrib/test/groupbn/test_groupbn.py @@ -0,0 +1,185 @@ +import torch +import unittest +import numpy as np +import random +from apex.contrib.groupbn.batch_norm import BatchNorm2d_NHWC + +def generate_uniform_tensor(size, np_dtype, pyt_dtype, device): + array = None + while array is None or np.isnan(array).any(): + array = np.random.uniform(low=-1.0, high=1.0, size=size).astype(np_dtype) + return torch.from_numpy(array).to(device).to(pyt_dtype) + +def to_channels_last(tensor): + return tensor.permute(0, 2, 3, 1).contiguous() + +def to_channels_first(tensor): + return tensor.permute(0, 3, 1, 2).contiguous() + +class Bn(torch.nn.BatchNorm2d): + def __init__(self, planes, mode): + super(Bn, self).__init__(planes, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + self.mode = mode + + def forward(self, x, z=None): + out = super().forward(x) + if self.mode == 'bn_add_relu': + out = out.add_(z) + if self.mode != 'bn': + out = out.relu_() + return out + +def bn_nhwc_bwd_ref(grad_y, x, mu, ivar, gamma): + sum_dim_c = (0, 1, 2) + grad_y_f32 = grad_y.float() + x_f32 = x.float() + N = x.shape[0] * x.shape[1] * x.shape[2] # nhw + ones = torch.ones(x.shape, dtype=torch.float32, device='cuda') + + xmu = x_f32 - mu + xhat = xmu * ivar + + dbias = torch.sum(grad_y_f32, dim=sum_dim_c) + + dscale = torch.sum(grad_y_f32 * xhat, dim=sum_dim_c) + + dx1 = (gamma * ivar) / N + dx2 = (N * grad_y_f32) - (dbias * ones) + dx3 = -xhat * dscale + dx = dx1 * (dx2 + dx3) + dx = dx.half() + return dx, dscale, dbias + +class TestGroupBN(unittest.TestCase): + + def setUp(self, seed=5, verbose=False): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + self.verbose = verbose + + def test_bn(self): + self.run_group_bn('bn') + + def test_bn_relu(self): + self.run_group_bn('bn_relu') + + def test_bn_add_relu(self): + self.run_group_bn('bn_add_relu') + + def run_group_bn(self, mode): + if self.verbose: + print('Running {}'.format(mode)) + + tensor_sizes = [ + (120, 64, 75, 75), + (120, 128, 38, 38)] + + for i in range(len(tensor_sizes)): + tensor_size = tensor_sizes[i] + num_channels = tensor_size[1] + + # Create input data + input_data = generate_uniform_tensor(tensor_size, np.float16, torch.half, 'cuda') + np.save('input.npy', input_data.detach().cpu().numpy()) + input_data.requires_grad = True + + gbn_input = torch.from_numpy(np.load('input.npy')).cuda().half() + gbn_input.requires_grad = True + + residual_data = None + gbn_residual_data = None + if mode == 'bn': + fuse_relu = False + else: + fuse_relu = True + if mode == 'bn_add_relu': + residual_data = generate_uniform_tensor(tensor_size, np.float16, torch.half, 'cuda') + gbn_residual_data = to_channels_last(residual_data) + + bn_grad = generate_uniform_tensor(input_data.shape, np.float16, torch.half, 'cuda') + + # Create models + batchnorm_model = Bn(num_channels, mode).cuda() + group_batchnorm = BatchNorm2d_NHWC(num_channels, fuse_relu=fuse_relu, bn_group=1,torch_channels_last=False).cuda() + + # Run reference forward + bn_output = batchnorm_model(input_data, residual_data) + + # Run GBN forward + gbn_input_data = to_channels_last(gbn_input) + gbn_output = group_batchnorm(gbn_input_data, gbn_residual_data) + + torch.cuda.synchronize() + + # Run reference backward + # (Use the same input and parameters as GBN) + gbn_grad = to_channels_last(bn_grad) + grad = gbn_grad.clone().detach() + input_data = torch.from_numpy(np.load('input.npy')).cuda().half() + input_data = to_channels_last(input_data) + if mode != 'bn': + grad[gbn_output <= 0] = 0 + bn_output_grad, _, _ = bn_nhwc_bwd_ref( \ + grad, + input_data, + group_batchnorm.minibatch_mean, + group_batchnorm.minibatch_riv, + group_batchnorm.weight) + bn_output_grad = to_channels_first(bn_output_grad) + + # Run GBN backward + gbn_output.backward(gbn_grad) + torch.cuda.synchronize() + + gbn_output = to_channels_first(gbn_output) + gbn_output_grad = gbn_input.grad.detach().clone().cpu() + + ########################## Validate results ########################## + if self.verbose: + print('Validate activation') + self.validate(bn_output.shape, bn_output, gbn_output) + if self.verbose: + print('Validate grad') + self.validate(bn_output_grad.shape, bn_output_grad, gbn_output_grad, is_grad=True) + + def validate(self, tensors, output_ref, output_test, is_grad=False): + output_ref = output_ref.detach().cpu().numpy() + output_test = output_test.detach().cpu().numpy() + + if self.verbose: + print('>>> tensor_size\t{}'.format(tensors)) + print("sum_output_ref {}, isnan {}, max {}, min {}".format( + np.sum(output_ref, dtype=float), np.isnan(output_ref).any(), np.max(output_ref), np.min(output_ref))) + print("sum_output_test {}, isnan {}, max {}, min {}".format( + np.sum(output_test, dtype=float), np.isnan(output_test).any(), np.max(output_test), np.min(output_test))) + + ret = np.array_equal(output_ref, output_test) + if not ret: + ret_allclose = np.allclose( + output_ref, output_test, rtol=1e-3, atol=1e-3, equal_nan=True) + if self.verbose: + print('{}\tshape {}\tidentical {}\tclose {}'.format('cpu/gpu', tensors, ret, ret_allclose)) + output_ref = output_ref.flatten() + output_test = output_test.flatten() + if not ret: + sub = np.absolute(output_ref - output_test) + norm_diff = np.average(sub) + rel = np.divide(sub, np.absolute(output_ref)) + rel[rel == np.inf] = 0 + max_abs_idx = np.argmax(sub) + max_rel_idx = np.argmax(rel) + if self.verbose: + print('max_diff {}, max_rel_diff {}, norm_diff {}'.format(np.max(sub), np.max(rel), np.average(sub))) + print('max_abs pair [{}] {} {}'.format(max_abs_idx, output_ref[max_abs_idx], output_test[max_abs_idx])) + print('max_rel pair [{}] {} {}'.format(max_rel_idx, output_ref[max_rel_idx], output_test[max_rel_idx])) + + result = ret or ret_allclose or (is_grad and norm_diff < 1e-4) + + if self.verbose: + print("Result {}".format("PASS" if result else "FAIL")) + + self.assertTrue(result) + +if __name__ == '__main__': + unittest.main() diff --git a/apex/contrib/test/groupbn/test_groupbn_channel_last.py b/apex/contrib/test/groupbn/test_groupbn_channel_last.py new file mode 100644 index 000000000..5ae36e33a --- /dev/null +++ b/apex/contrib/test/groupbn/test_groupbn_channel_last.py @@ -0,0 +1,194 @@ +import torch +import unittest +import numpy as np +import random +from apex.contrib.groupbn.batch_norm import BatchNorm2d_NHWC + +def generate_uniform_tensor(size, np_dtype, pyt_dtype, device): + array = None + while array is None or np.isnan(array).any(): + array = np.random.uniform(low=-1.0, high=1.0, size=size).astype(np_dtype) + return torch.from_numpy(array).to(device).to(pyt_dtype) + +def to_channels_last(tensor): + #return tensor.permute(0, 2, 3, 1).contiguous() + return tensor.to(memory_format = torch.channels_last) + +def to_channels_first(tensor): + #return tensor.permute(0, 3, 1, 2).contiguous() + return tensor.to(memory_format = torch.contiguous_format) + +class Bn(torch.nn.BatchNorm2d): + def __init__(self, planes, mode): + super(Bn, self).__init__(planes, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + self.mode = mode + + def forward(self, x, z=None): + out = super().forward(x) + if self.mode == 'bn_add_relu': + out = out.add_(z) + if self.mode != 'bn': + out = out.relu_() + return out + +def bn_nhwc_bwd_ref(grad_y, x, mu, ivar, gamma): + grad_y = grad_y.permute(0, 2, 3, 1).contiguous() + x = x.permute(0, 2, 3, 1).contiguous() + sum_dim_c = (0, 1, 2) + grad_y_f32 = grad_y.float() + x_f32 = x.float() + N = x.shape[0] * x.shape[1] * x.shape[2] # nhw + ones = torch.ones(x.shape, dtype=torch.float32, device='cuda') + + xmu = x_f32 - mu + + xhat = xmu * ivar + dbias = torch.sum(grad_y_f32, dim=sum_dim_c) + + dscale = torch.sum(grad_y_f32 * xhat, dim=sum_dim_c) + + dx1 = (gamma * ivar) / N + dx2 = (N * grad_y_f32) - (dbias * ones) + dx3 = -xhat * dscale + dx23 = dx2 + dx3 + dx = dx1 * (dx23) + dx = dx.half() + dx = dx.permute(0, 3, 1, 2).contiguous() + return dx, dscale, dbias + +class TestGroupBNChannelLast(unittest.TestCase): + + def setUp(self, seed=5, verbose=False): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + self.verbose = verbose + + def test_bn_channel_last(self): + self.run_group_bn_channel_last('bn') + + def test_bn_relu_channel_last(self): + self.run_group_bn_channel_last('bn_relu') + + def test_bn_add_relu_channel_last(self): + self.run_group_bn_channel_last('bn_add_relu') + + def run_group_bn_channel_last(self, mode): + if self.verbose: + print('Running {}'.format(mode)) + + tensor_sizes = [ + (120, 64, 75, 75), + (120, 128, 38, 38)] + + for i in range(len(tensor_sizes)): + tensor_size = tensor_sizes[i] + num_channels = tensor_size[1] + + # Create input data + input_data = generate_uniform_tensor(tensor_size, np.float16, torch.half, 'cuda') + np.save('input.npy', input_data.detach().cpu().numpy()) + input_data.requires_grad = True + + gbn_input = torch.from_numpy(np.load('input.npy')).cuda().half() + gbn_input.requires_grad = True + + residual_data = None + gbn_residual_data = None + if mode == 'bn': + fuse_relu = False + else: + fuse_relu = True + if mode == 'bn_add_relu': + residual_data = generate_uniform_tensor(tensor_size, np.float16, torch.half, 'cuda') + gbn_residual_data = to_channels_last(residual_data) + + bn_grad = generate_uniform_tensor(input_data.shape, np.float16, torch.half, 'cuda') + + # Create models + batchnorm_model = Bn(num_channels, mode).cuda() + group_batchnorm = BatchNorm2d_NHWC(num_channels, fuse_relu=fuse_relu, bn_group=1, torch_channels_last=True).cuda() + + # Run reference forward + bn_output = batchnorm_model(input_data, residual_data) + + # Run GBN forward + gbn_input_data = to_channels_last(gbn_input) + #gbn_input_data = gbn_input + gbn_output = group_batchnorm(gbn_input_data, gbn_residual_data) + + torch.cuda.synchronize() + + # Run reference backward + # (Use the same input and parameters as GBN) + gbn_grad = to_channels_last(bn_grad) + #gbn_grad = bn_grad + grad = gbn_grad.clone().detach() + input_data = torch.from_numpy(np.load('input.npy')).cuda().half() + input_data = to_channels_last(input_data) + if mode != 'bn': + grad[gbn_output <= 0] = 0 + bn_output_grad, _, _ = bn_nhwc_bwd_ref( \ + grad, + input_data, + group_batchnorm.minibatch_mean, + group_batchnorm.minibatch_riv, + group_batchnorm.weight) + bn_output_grad = to_channels_first(bn_output_grad) + + # Run GBN backward + gbn_output.backward(gbn_grad) + torch.cuda.synchronize() + + gbn_output = to_channels_first(gbn_output) + gbn_output_grad = gbn_input.grad.detach().clone().cpu() + + ########################## Validate results ########################## + if self.verbose: + print('Validate activation') + self.validate(bn_output.shape, bn_output, gbn_output) + if self.verbose: + print('Validate grad') + self.validate(bn_output_grad.shape, bn_output_grad, gbn_output_grad, is_grad=True) + + def validate(self, tensors, output_ref, output_test, is_grad=False): + output_ref = output_ref.detach().cpu().numpy() + output_test = output_test.detach().cpu().numpy() + + if self.verbose: + print('>>> tensor_size\t{}'.format(tensors)) + print("sum_output_ref {}, isnan {}, max {}, min {}".format( + np.sum(output_ref, dtype=float), np.isnan(output_ref).any(), np.max(output_ref), np.min(output_ref))) + print("sum_output_test {}, isnan {}, max {}, min {}".format( + np.sum(output_test, dtype=float), np.isnan(output_test).any(), np.max(output_test), np.min(output_test))) + + ret = np.array_equal(output_ref, output_test) + if not ret: + ret_allclose = np.allclose( + output_ref, output_test, rtol=1e-3, atol=1e-3, equal_nan=True) + if self.verbose: + print('{}\tshape {}\tidentical {}\tclose {}'.format('cpu/gpu', tensors, ret, ret_allclose)) + output_ref = output_ref.flatten() + output_test = output_test.flatten() + if not ret: + sub = np.absolute(output_ref - output_test) + norm_diff = np.average(sub) + rel = np.divide(sub, np.absolute(output_ref)) + rel[rel == np.inf] = 0 + max_abs_idx = np.argmax(sub) + max_rel_idx = np.argmax(rel) + if self.verbose: + print('max_diff {}, max_rel_diff {}, norm_diff {}'.format(np.max(sub), np.max(rel), np.average(sub))) + print('max_abs pair [{}] {} {}'.format(max_abs_idx, output_ref[max_abs_idx], output_test[max_abs_idx])) + print('max_rel pair [{}] {} {}'.format(max_rel_idx, output_ref[max_rel_idx], output_test[max_rel_idx])) + + result = ret or ret_allclose or (is_grad and norm_diff < 1e-4) + + if self.verbose: + print("Result {}".format("PASS" if result else "FAIL")) + + self.assertTrue(result) + +if __name__ == '__main__': + unittest.main() + diff --git a/apex/contrib/test/index_mul_2d/test_index_mul_2d.py b/apex/contrib/test/index_mul_2d/test_index_mul_2d.py new file mode 100644 index 000000000..d8f37ea3c --- /dev/null +++ b/apex/contrib/test/index_mul_2d/test_index_mul_2d.py @@ -0,0 +1,106 @@ +import random +import unittest + +import torch +import torch.nn.functional as F + +HAS_INDEX_MUL_2D_RELU = None +try: + from apex.contrib.index_mul_2d import index_mul_2d +except ImportError as e: + HAS_INDEX_MUL_2D_RELU = False +else: + HAS_INDEX_MUL_2D_RELU = True + + +@unittest.skipIf(not HAS_INDEX_MUL_2D_RELU, "`apex.contrib.index_mul_2d` is not found.") +class IndexMul2dTest(unittest.TestCase): + def setUp(self, seed=0): + torch.manual_seed(seed) + + self.input1_size = random.randint(1, 1000) + self.input2_size = random.randint(1, 100000) + self.feature_size = random.randint(1, 256) + + self.input1_float = torch.randn(size=(self.input1_size, self.feature_size),).cuda() + self.input2_float = torch.randn(size=(self.input2_size, self.feature_size),).cuda() + self.index1 = torch.randint(low=0, high=self.input1_size, size=(self.input2_size,)).cuda() + + self.input1_float_ = self.input1_float.clone() + self.input2_float_ = self.input2_float.clone() + + self.input1_float.requires_grad_() + self.input1_float_.requires_grad_() + self.input2_float.requires_grad_() + self.input2_float_.requires_grad_() + + self.input1_half = torch.randn(size=(self.input1_size, self.feature_size),).cuda().half() + self.input2_half = torch.randn(size=(self.input2_size, self.feature_size),).cuda().half() + + self.input1_half_ = self.input1_half.clone() + self.input2_half_ = self.input2_half.clone() + + self.input1_half.requires_grad_() + self.input2_half.requires_grad_() + self.input1_half_.requires_grad_() + self.input2_half_.requires_grad_() + + def test_index_mul_float(self): + out = index_mul_2d(self.input1_float, self.input2_float, self.index1) + energy = (out.float()**2).sum() / out.numel() + force = torch.autograd.grad( + energy, + self.input1_float, + grad_outputs=torch.ones_like(energy), + create_graph=True, + )[0] + loss = (out.float()**2).sum() / out.numel() + (force.float()**2).sum() + loss.backward() + + out_ = self.input1_float_[self.index1] * self.input2_float_ + energy_ = (out_.float()**2).sum() / out.numel() + force_ = torch.autograd.grad( + energy_, + self.input1_float_, + grad_outputs=torch.ones_like(energy), + create_graph=True, + )[0] + loss = (out_.float()**2).sum() / out_.numel() + (force_.float()**2).sum() + loss.backward() + + self.assertTrue(torch.allclose(self.input1_float, self.input1_float_, atol=1e-3, rtol=1e-3, equal_nan=True)) + self.assertTrue(torch.allclose(self.input2_float, self.input2_float_, atol=1e-3, rtol=1e-3, equal_nan=True)) + self.assertTrue(torch.allclose(self.input1_float.grad, self.input1_float_.grad, atol=1e-3, rtol=1e-3, equal_nan=True)) + self.assertTrue(torch.allclose(self.input2_float.grad, self.input2_float_.grad, atol=1e-3, rtol=1e-3, equal_nan=True)) + + def test_index_mul_half(self): + out = index_mul_2d(self.input1_half, self.input2_half, self.index1) + energy = (out.float()**2).sum() / out.numel() + force = torch.autograd.grad( + energy, + self.input1_half, + grad_outputs=torch.ones_like(energy), + create_graph=True, + )[0] + loss = (out.float()**2).sum() / out.numel() + (force.float()**2).sum() + loss.backward() + + out_ = self.input1_half_[self.index1] * self.input2_half_ + energy_ = (out_.float()**2).sum() / out.numel() + force_ = torch.autograd.grad( + energy_, + self.input1_half_, + grad_outputs=torch.ones_like(energy), + create_graph=True, + )[0] + loss = (out_.float()**2).sum() / out_.numel() + (force_.float()**2).sum() + loss.backward() + + self.assertTrue(torch.allclose(self.input1_half, self.input1_half_, atol=1e-3, rtol=1e-3, equal_nan=True)) + self.assertTrue(torch.allclose(self.input2_half, self.input2_half_, atol=1e-3, rtol=1e-3, equal_nan=True)) + self.assertTrue(torch.allclose(self.input1_half.grad, self.input1_half_.grad, atol=1e-3, rtol=1e-3, equal_nan=True)) + self.assertTrue(torch.allclose(self.input2_half.grad, self.input2_half_.grad, atol=1e-3, rtol=1e-3, equal_nan=True)) + +if __name__ == '__main__': + unittest.main() + diff --git a/apex/contrib/test/multihead_attn/test_encdec_multihead_attn.py b/apex/contrib/test/multihead_attn/test_encdec_multihead_attn.py index f37e5005f..836fe8433 100644 --- a/apex/contrib/test/multihead_attn/test_encdec_multihead_attn.py +++ b/apex/contrib/test/multihead_attn/test_encdec_multihead_attn.py @@ -40,37 +40,37 @@ def setUp(self, seed=1234): impl='fast') self.tst_layer.cuda().half() self.tst_layer.reset_parameters() - + self.tst_inputs_q = torch.randn(self.seq_length, self.sequences, self.hidden_dim, dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True) self.tst_inputs_k = torch.randn(self.seq_length, self.sequences, self.hidden_dim, dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True) def test_encdec_multihead_attn(self) : + grads = torch.randn_like(self.tst_inputs_q) + ref_outputs,_ = self.ref_layer.forward(self.ref_inputs_q, self.ref_inputs_k, self.ref_inputs_k, - key_padding_mask=None, - need_weights=False, + key_padding_mask=None, + need_weights=False, attn_mask=None, is_training=True) - tst_outputs,_ = self.tst_layer.forward(self.tst_inputs_q, - self.tst_inputs_k, + tst_outputs,_ = self.tst_layer.forward(self.tst_inputs_q, self.tst_inputs_k, - key_padding_mask=None, - need_weights=False, + self.tst_inputs_k, + key_padding_mask=None, + need_weights=False, attn_mask=None, is_training=True) + + self.ref_inputs_q.backward(grads) + self.tst_inputs_q.backward(grads) + self.assertTrue(torch.allclose(self.ref_inputs_q, self.tst_inputs_q, atol=1e-5, rtol=1e-5)) self.assertTrue(torch.allclose(self.ref_inputs_k, self.tst_inputs_k, atol=1e-5, rtol=1e-5)) self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3)) - - with torch.no_grad(): - ref_grads = torch.randn_like(ref_outputs) - tst_grads = ref_grads.clone() - ref_outputs.backward(ref_grads) - tst_outputs.backward(tst_grads) self.assertTrue(torch.allclose(self.ref_inputs_q.grad, self.tst_inputs_q.grad, atol=1e-3, rtol=1e-3)) def test_encdec_multihead_attn_time_mask(self) : diff --git a/apex/contrib/test/multihead_attn/test_self_multihead_attn.py b/apex/contrib/test/multihead_attn/test_self_multihead_attn.py index b1b9f96f5..10d779feb 100644 --- a/apex/contrib/test/multihead_attn/test_self_multihead_attn.py +++ b/apex/contrib/test/multihead_attn/test_self_multihead_attn.py @@ -15,34 +15,36 @@ def setUp(self, seed=1234): self.heads = 16 self.dropout_prob = 0.0 - self.ref_layer = SelfMultiheadAttn(self.hidden_dim, - self.heads, - dropout=self.dropout_prob, - bias=False, - include_norm_add=False, + self.ref_layer = SelfMultiheadAttn(self.hidden_dim, + self.heads, + dropout=self.dropout_prob, + bias=False, + include_norm_add=False, impl='default') self.ref_layer.cuda().half() self.ref_layer.reset_parameters() - self.ref_inputs = torch.randn(self.seq_length, self.sequences, self.hidden_dim, + self.ref_inputs = torch.randn(self.seq_length, self.sequences, self.hidden_dim, dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True) # Reset seed so parameters are identical torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) - - self.tst_layer = SelfMultiheadAttn(self.hidden_dim, - self.heads, - dropout=self.dropout_prob, - bias=False, - include_norm_add=False, + + self.tst_layer = SelfMultiheadAttn(self.hidden_dim, + self.heads, + dropout=self.dropout_prob, + bias=False, + include_norm_add=False, impl='fast') self.tst_layer.cuda().half() self.tst_layer.reset_parameters() - - self.tst_inputs = torch.randn(self.seq_length, self.sequences, self.hidden_dim, + + self.tst_inputs = torch.randn(self.seq_length, self.sequences, self.hidden_dim, dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True) - def test_self_multihead_attn(self) : + def test_self_multihead_attn(self): + grads = torch.randn_like(self.tst_inputs) + ref_outputs,_ = self.ref_layer.forward(self.ref_inputs, self.ref_inputs, self.ref_inputs, @@ -59,15 +61,11 @@ def test_self_multihead_attn(self) : attn_mask=None, is_training=True) + self.ref_inputs.backward(grads) + self.tst_inputs.backward(grads) + self.assertTrue(torch.allclose(self.ref_inputs, self.tst_inputs, atol=1e-5, rtol=1e-5)) self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3)) - - with torch.no_grad(): - ref_grads = torch.randn_like(self.tst_inputs) - tst_grads = ref_grads.clone() - - ref_outputs.backward(ref_grads) - tst_outputs.backward(tst_grads) self.assertTrue(torch.allclose(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3)) def test_self_multihead_attn_time_mask(self) : @@ -75,23 +73,23 @@ def test_self_multihead_attn_time_mask(self) : time_mask_byte= torch.triu(torch.ones(self.tst_inputs.size(0), self.tst_inputs.size(0), device=torch.device("cuda"), dtype=torch.uint8), 1) time_mask_bool= time_mask_byte.to(torch.bool) - ref_outputs,_ = self.ref_layer.forward(self.ref_inputs, - self.ref_inputs, + ref_outputs,_ = self.ref_layer.forward(self.ref_inputs, self.ref_inputs, - key_padding_mask=None, - need_weights=False, + self.ref_inputs, + key_padding_mask=None, + need_weights=False, attn_mask=time_mask_bool, is_training=True) - tst_outputs,_ = self.tst_layer.forward(self.tst_inputs, - self.tst_inputs, + tst_outputs,_ = self.tst_layer.forward(self.tst_inputs, + self.tst_inputs, self.tst_inputs, - key_padding_mask=None, - need_weights=False, + key_padding_mask=None, + need_weights=False, attn_mask=time_mask_byte, is_training=True) - + self.ref_inputs.backward(grads) self.tst_inputs.backward(grads) @@ -104,23 +102,23 @@ def test_self_multihead_attn_pad_mask(self) : pad_mask_byte = torch.tril(torch.ones(self.tst_inputs.size(1), self.tst_inputs.size(0), device=torch.device("cuda"), dtype=torch.uint8), 1) pad_mask_bool = pad_mask_byte.to(torch.bool) - ref_outputs,_ = self.ref_layer.forward(self.ref_inputs, - self.ref_inputs, + ref_outputs,_ = self.ref_layer.forward(self.ref_inputs, self.ref_inputs, - key_padding_mask=pad_mask_bool, - need_weights=False, + self.ref_inputs, + key_padding_mask=pad_mask_bool, + need_weights=False, attn_mask=None, is_training=True) - tst_outputs,_ = self.tst_layer.forward(self.tst_inputs, - self.tst_inputs, + tst_outputs,_ = self.tst_layer.forward(self.tst_inputs, self.tst_inputs, - key_padding_mask=pad_mask_byte, - need_weights=False, + self.tst_inputs, + key_padding_mask=pad_mask_byte, + need_weights=False, attn_mask=None, is_training=True) - + self.ref_inputs.backward(grads) self.tst_inputs.backward(grads) diff --git a/apex/contrib/test/optimizers/test_dist_adam.py b/apex/contrib/test/optimizers/test_dist_adam.py index bd23ce2ae..531dce502 100644 --- a/apex/contrib/test/optimizers/test_dist_adam.py +++ b/apex/contrib/test/optimizers/test_dist_adam.py @@ -1,39 +1,63 @@ from contextlib import contextmanager import io -import os +from typing import Callable, Optional, Tuple +import unittest +import warnings +from contextlib import nullcontext import torch from torch.testing._internal import common_utils -from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam +from torch.testing._internal.common_utils import skipIfRocm + + +SKIP_TEST = None +try: + from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam +except ImportError as e: + SKIP_TEST = e from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase -class SimpleModel(torch.nn.Module): +class SimpleModel(torch.nn.Module): def __init__(self, num_layers, size): super().__init__() - self.layers = torch.nn.ModuleList([ - torch.nn.Linear(size, size, bias=(i%3==0)) - for i in range(num_layers) + self.params = torch.nn.ParameterList([ + torch.nn.Parameter(torch.rand(1, size) + 1) + for _ in range(num_layers) ]) - def forward(self, x): y = 0 - for i, l in enumerate(self.layers): - y += (i+1) * l(x) + for i, param in enumerate(self.params): + y += (i+1) * param * x return y + def make_models( - num_layers, - size, - dtype=torch.float32, - param_sync_dtype=None, - device='cuda', - overlap_communication=True, + num_layers: int, + size: int, + *, + lr: float = 0.1, + adam_w_mode: bool = True, + model_dtype: torch.dtype = torch.float32, + optim_dtype: Optional[torch.dtype] = None, + grad_sync_dtype: Optional[torch.dtype] = None, + param_sync_dtype: Optional[torch.dtype] = None, + device: torch.device = 'cuda', + process_group: Optional[torch.distributed.ProcessGroup] = None, + average_grad_sync: bool = True, + overlap_communication: bool = True, + bucket_cap_mb: float = 71/(4*1024*1024), + contiguous_buffers: bool = False, + store_params: bool = False, + store_param_remainders: bool = False, + with_scaled_states: bool = False, + nccl_ub: bool = False, + with_cuda_graph: bool = False, ): # Construct models with same parameters - ref_model = SimpleModel(num_layers, size).to(dtype=dtype, device=device) - dist_model = SimpleModel(num_layers, size).to(dtype=dtype, device=device) + ref_model = SimpleModel(num_layers, size).to(dtype=model_dtype, device=device) + dist_model = SimpleModel(num_layers, size).to(dtype=model_dtype, device=device) with torch.no_grad(): for ref_param, dist_param in zip(dist_model.parameters(), ref_model.parameters()): @@ -45,31 +69,48 @@ def make_models( ref_model, device_ids=[rank] if device=='cuda' else None, output_device=rank if device=='cuda' else None, + process_group=process_group, ) # Construct optimizers with same hyperparameters - optim_args = dict(lr=0.1, betas=(0.1,0.2), eps=0.25, weight_decay=0.1) - ref_optim = torch.optim.AdamW( + if optim_dtype is None: + optim_dtype = model_dtype + optim_args = dict(lr=lr, betas=(0.1,0.2), eps=0.25, weight_decay=0.1) + ref_optim_class = torch.optim.AdamW if adam_w_mode else torch.optim.Adam + ref_optim = ref_optim_class( [ - {'params': list(ref_model.parameters())[1::2], 'lr': 0.2}, + {'params': list(ref_model.parameters())[1::2], 'lr': lr*2}, {'params': list(ref_model.parameters())[0::2]}, ], **optim_args, ) dist_optim = DistributedFusedAdam( [ - {'params': list(dist_model.parameters())[1::2], 'lr': 0.2}, + {'params': list(dist_model.parameters())[1::2], 'lr': lr*2}, {'params': list(dist_model.parameters())[0::2]}, ], + adam_w_mode=adam_w_mode, overlap_grad_sync=overlap_communication, - bucket_cap_mb=71/(4*1024*1024), - dtype=torch.float32, + overlap_param_sync=overlap_communication, + bucket_cap_mb=bucket_cap_mb, + dtype=optim_dtype, + grad_sync_dtype=grad_sync_dtype, param_sync_dtype=param_sync_dtype, + process_group=process_group, + average_grad_sync=average_grad_sync, + contiguous_param_buffer=contiguous_buffers, + contiguous_grad_buffer=contiguous_buffers, + store_params=store_params, + store_param_remainders=store_param_remainders, + with_scaled_states=with_scaled_states, + nccl_ub=nccl_ub, + capturable=with_cuda_graph, **optim_args, ) return ref_model, ref_optim, dist_model, dist_optim + @contextmanager def dummy_context(): try: @@ -77,83 +118,163 @@ def dummy_context(): finally: pass + +@unittest.skipIf(SKIP_TEST, f"{SKIP_TEST}") class TestDistributedFusedAdam(NcclDistributedTestBase): seed = 1234 def test_matches_pytorch( self, - num_layers=11, - layer_size=7, - batch_size=3, - num_steps=3, - micro_batch_steps=3, - overlap_communication=True, - use_nosync=True, - dtype=torch.float32, - param_sync_dtype=None, - device='cuda', - rtol=None, - atol=None, + rtol: Optional[float] = None, + atol: Optional[float] = None, + num_layers: int = 11, + layer_size: int = 7, + batch_size: int = 3, + num_steps: int = 3, + micro_batch_steps: int = 3, + adam_w_mode: bool = True, + overlap_communication: bool = True, + use_nosync: bool = True, + model_dtype: torch.dtype = torch.float32, + optim_dtype: Optional[torch.dtype] = None, + grad_sync_dtype: Optional[torch.dtype] = None, + param_sync_dtype: Optional[torch.dtype] = None, + device: torch.device = 'cuda', + bucket_cap_mb: float = 71/(4*1024*1024), + contiguous_buffers: bool = False, + store_params: bool = False, + store_param_remainders: bool = False, + with_scaled_states: bool = False, + nccl_ub: bool = False, + init_optim_func: Optional[Callable[[DistributedFusedAdam], None]] = None, + with_cuda_graph: bool = False, ): torch.manual_seed(self.seed + self.rank) # Identical models with data-parallel and ZeRO - ref_model, ref_optim, dist_model, dist_optim = make_models( - num_layers, - layer_size, - dtype=dtype, - param_sync_dtype=param_sync_dtype, - device=device, - overlap_communication=overlap_communication, - ) + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + ref_model, ref_optim, dist_model, dist_optim = make_models( + num_layers, + layer_size, + adam_w_mode=adam_w_mode, + model_dtype=model_dtype, + optim_dtype=optim_dtype, + grad_sync_dtype=grad_sync_dtype, + param_sync_dtype=param_sync_dtype, + device=device, + overlap_communication=overlap_communication, + bucket_cap_mb=bucket_cap_mb, + contiguous_buffers=contiguous_buffers, + store_params=store_params, + store_param_remainders=store_param_remainders, + with_scaled_states=with_scaled_states, + nccl_ub=nccl_ub, + with_cuda_graph=with_cuda_graph, + ) - # Training loop - for step in range(num_steps): + # Initialize distributed optimizer + if init_optim_func is not None: + with torch.cuda.stream(stream): + init_optim_func(dist_optim) - # Reset gradients - ref_optim.zero_grad() - dist_optim.zero_grad() + # Static data + static_xs, static_dys = [], [] + ys_ref, grad_xs_ref = [], [] + ys_dist, grad_xs_dist = [], [] - # Forward and backward passes - for micro_step in range(micro_batch_steps): + graph = torch.cuda.CUDAGraph() if with_cuda_graph else None + CAPTURE_ITERATION = 11 + if with_cuda_graph: + assert num_steps > CAPTURE_ITERATION + 3, \ + "Not enough iterations for CUDA graph test." + # Training loop + with torch.cuda.stream(stream): + for step in range(num_steps): # Synthetic data - x = torch.rand(batch_size, layer_size) - 0.5 - dy = torch.rand_like(x) - 0.5 - x = x.to(dtype=dtype, device=device) - dy = dy.to(dtype=dtype, device=device) + for micro_step in range(micro_batch_steps): + x = torch.rand(batch_size, layer_size) - 0.5 + dy = torch.rand_like(x) - 0.5 + x = x.to(dtype=model_dtype, device=device) + dy = dy.to(dtype=model_dtype, device=device) + if step == 0: + static_xs.append(x) + static_dys.append(dy) + else: + static_xs[micro_step].copy_(x) + static_dys[micro_step].copy_(dy) # Reference implementation - x_ref = x.detach().clone().requires_grad_(True) - y_ref = ref_model(x_ref) - y_ref.backward(dy) + ref_optim.zero_grad() + for micro_step in range(micro_batch_steps): + x, dy = static_xs[micro_step], static_dys[micro_step] + + x_ref = x.detach().clone().requires_grad_(True) + y_ref = ref_model(x_ref) + y_ref.backward(dy) + + if step == 0: + ys_ref.append(y_ref) + grad_xs_ref.append(x_ref.grad) + else: + with torch.no_grad(): + ys_ref[micro_step].copy_(y_ref) + grad_xs_ref[micro_step].copy_(x_ref.grad) + ref_optim.step() # Distributed implementation - x_dist = x.detach().clone().requires_grad_(True) - y_dist = dist_model(x_dist) - backward_context = dummy_context - if use_nosync and micro_step < micro_batch_steps-1: - backward_context = dist_optim.no_sync - with backward_context(): - y_dist.backward(dy) + if not with_cuda_graph or step <= CAPTURE_ITERATION: + if with_cuda_graph and step == CAPTURE_ITERATION: + ctx = torch.cuda.graph(graph) + torch.cuda.synchronize() + else: + ctx = nullcontext() + + with ctx: + dist_optim.zero_grad() + for micro_step in range(micro_batch_steps): + x, dy = static_xs[micro_step], static_dys[micro_step] + + x_dist = x.detach().clone().requires_grad_(True) + y_dist = dist_model(x_dist) + backward_context = dummy_context + if use_nosync and micro_step < micro_batch_steps-1: + backward_context = dist_optim.no_sync + with backward_context(): + y_dist.backward(dy) + + if step == 0: + ys_dist.append(y_dist) + grad_xs_dist.append(x_dist.grad) + else: + with torch.no_grad(): + ys_dist[micro_step].copy_(y_dist) + grad_xs_dist[micro_step].copy_(x_dist.grad) + dist_optim.step() + + if with_cuda_graph and step == CAPTURE_ITERATION: + graph.replay() + else: + graph.replay() # Check that data tensors match - torch.testing.assert_close( - y_dist, y_ref, rtol=rtol, atol=atol) - torch.testing.assert_close( - x_dist.grad, x_ref.grad, rtol=rtol, atol=atol) + for mbs in range(micro_batch_steps): + torch.testing.assert_close( + ys_dist[mbs], ys_ref[mbs], rtol=rtol, atol=atol) + torch.testing.assert_close( + grad_xs_dist[mbs], grad_xs_ref[mbs], rtol=rtol, atol=atol) - # Optimization step - ref_optim.step() - dist_optim.step() + # Check that parameters match + for ref_param, dist_param in zip(ref_model.parameters(), + dist_model.parameters()): + torch.testing.assert_close( + dist_param, ref_param, rtol=rtol, atol=atol) - # Check that parameters match - for ref_param, dist_param in zip(ref_model.parameters(), - dist_model.parameters()): - torch.testing.assert_close( - dist_param, ref_param, rtol=rtol, atol=atol) + def test_matches_pytorch_l2_reg(self): + self.test_matches_pytorch(adam_w_mode=False) def test_matches_pytorch_no_overlap(self): self.test_matches_pytorch( @@ -164,28 +285,119 @@ def test_matches_pytorch_no_overlap(self): def test_matches_pytorch_sync_every_step(self): self.test_matches_pytorch(use_nosync=False) + def test_matches_pytorch_contiguous_buffers(self): + self.test_matches_pytorch(contiguous_buffers=True) + def test_matches_pytorch_fp64(self): self.test_matches_pytorch( - dtype=torch.float64, rtol=1.3e-6, atol=1e-5, + model_dtype=torch.float64, + optim_dtype=torch.float32, ) def test_matches_pytorch_fp16(self): self.test_matches_pytorch( - dtype=torch.float16, - rtol=1e-2, - atol=1e-2, + rtol=5e-3, + atol=1e-5, + micro_batch_steps=1, + model_dtype=torch.float16, + optim_dtype=torch.float16, ) - def test_matches_pytorch_allgather_fp16(self): + def test_matches_pytorch_bf16(self): self.test_matches_pytorch( - dtype=torch.float32, + rtol=5e-2, + atol=1e-5, + micro_batch_steps=1, + model_dtype=torch.bfloat16, + optim_dtype=torch.bfloat16, + ) + + def test_matches_pytorch_fp16_params(self): + self.test_matches_pytorch( + rtol=5e-3, + atol=1e-5, + micro_batch_steps=1, + model_dtype=torch.float16, + optim_dtype=torch.float32, param_sync_dtype=torch.float16, - rtol=1e-2, - atol=1e-2, + store_params=True, + ) + + def test_matches_pytorch_bf16_grads(self): + self.test_matches_pytorch( + rtol=5e-2, + atol=1e-5, + micro_batch_steps=1, + model_dtype=torch.float32, + optim_dtype=torch.float32, + grad_sync_dtype=torch.bfloat16, + ) + + def test_matches_pytorch_bf16_param_remainders(self): + self.test_matches_pytorch( + rtol=5e-2, + atol=1e-5, + micro_batch_steps=1, + model_dtype=torch.bfloat16, + optim_dtype=torch.float32, + param_sync_dtype=torch.bfloat16, + store_params=False, + store_param_remainders=True, + ) + + def test_matches_pytorch_multi_dtypes(self): + def init_optim(optim: DistributedFusedAdam): + params = list(optim.parameters()) + optim.init_params(params[0::3], grad_sync_dtype=torch.bfloat16) + optim.init_params(params[1::3], param_sync_dtype=torch.bfloat16) + self.test_matches_pytorch( + rtol=5e-2, + atol=1e-5, + init_optim_func=init_optim, + ) + + def test_matches_pytorch_int64_param_sync(self): + self.test_matches_pytorch( + param_sync_dtype=torch.int64, + ) + + def test_matches_pytorch_int32_param_sync_contiguous_buffers(self): + self.test_matches_pytorch( + param_sync_dtype=torch.int32, + contiguous_buffers=True, ) + def test_matches_pytorch_uint8_param_sync(self): + self.test_matches_pytorch( + rtol=0.5, + atol=0.05, + model_dtype=torch.float16, + optim_dtype=torch.float16, + micro_batch_steps=1, + param_sync_dtype=torch.uint8, + ) + + def test_matches_pytorch_scaled_state(self): + self.test_matches_pytorch( + rtol=5e-2, + atol=1e-5, + micro_batch_steps=1, + model_dtype=torch.bfloat16, + optim_dtype=torch.float16, + param_sync_dtype=torch.int, + store_params=True, + with_scaled_states=True, + ) + + def test_matches_pytorch_nccl_ub(self): + self.test_matches_pytorch( + contiguous_buffers=True, + nccl_ub=True, + ) + + def test_raises_on_mismatch(self): torch.manual_seed(self.seed + self.rank) @@ -200,9 +412,9 @@ def test_raises_on_mismatch(self): # Only perform training step with distributed model dist_optim.zero_grad() - x = torch.rand(3, layer_size) + 0.5 + x = torch.rand(3, layer_size) - 0.5 x = x.to(dtype=torch.float32, device='cuda') - dy = torch.rand_like(x) + 0.5 + dy = torch.rand_like(x) - 0.5 y = dist_model(x) y.backward(dy) dist_optim.step() @@ -227,8 +439,8 @@ def test_clip_grad_norm(self): xs = [3, 1, 4, 1, 5, 9] dys = [1, -1, 1, -1, 1, -1] for x, dy in zip(xs, dys): - x = torch.tensor([x], dtype=torch.float32, device='cuda') - dy = torch.tensor([dy], dtype=torch.float32, device='cuda') + x = torch.tensor([[x]], dtype=torch.float32, device='cuda') + dy = torch.tensor([[dy]], dtype=torch.float32, device='cuda') # Reference implementation ref_optim.zero_grad() @@ -262,15 +474,15 @@ def test_grad_scaler(self): backoff_factor=0.876, growth_interval=1, ) - ref_scaler = torch.cuda.amp.GradScaler(**grad_scaler_args) - dist_scaler = torch.cuda.amp.GradScaler(**grad_scaler_args) + ref_scaler = torch.amp.GradScaler('cuda', **grad_scaler_args) + dist_scaler = torch.amp.GradScaler('cuda', **grad_scaler_args) # Training steps with pre-determined gradients xs = [3, 1, 4, 1, 5, 9] dys = [1, float('inf'), 1, 1, float('nan'), -1] for x, dy in zip(xs, dys): - x = torch.tensor([x], dtype=torch.float32, device='cuda') - dy = torch.tensor([dy], dtype=torch.float32, device='cuda') + x = torch.tensor([[x]], dtype=torch.float32, device='cuda') + dy = torch.tensor([[dy]], dtype=torch.float32, device='cuda') # Reference implementation ref_optim.zero_grad() @@ -291,101 +503,375 @@ def test_grad_scaler(self): dist_model.parameters()): torch.testing.assert_close(dist_param, ref_param) - def test_checkpoint(self): + def test_checkpoint( + self, + rtol: Optional[float] = None, + atol: Optional[float] = None, + num_layers: int = 2, + layer_size: int = 2, + num_steps: int = 3, + save_group_size: Optional[int] = None, + load_group_size: Optional[int] = None, + save_model_kwargs: Optional[dict] = None, + load_model_kwargs: Optional[dict] = None, + ): + """Test state_dict and load_state_dict functions + + Two models are constructed, possibly on different process + groups. One of the models is trained for a few steps, a + checkpoint is saved, and the checkpoint is loaded on the other + model. Both models are then trained for a few steps and + checked to make sure that they produce identical results. + + Arguments: + rtol (float): Relative tolerance for numerical checks (see + torch.allclose). + atol (float): Absolute tolerance for numerical checks (see + torch.allclose). + num_layers (int): Number of layers in test model. + layer_size (int): Number of features in model layers. + num_steps (int): Number of training steps to perform + before and after checkpointing. + save_group_size (int): Process group size for model that + saves the checkpoint. Uses the default process group + by default. + load_group_size (int): Process group size for model that + loads the checkpoint. Uses the default process group + by default. + save_model_kwargs (dict): keyword arguments passed to + make_models when constructing the model that saves the + checkpoint. + load_model_kwargs (dict): keyword arguments passed to + make_models when constructing the model that loads the + checkpoint. + + """ + + # Initialize process groups + world_size = torch.distributed.get_world_size() + if save_group_size is None: + save_group_size = world_size + save_group = None + else: + if save_group_size > world_size: + self.skipTest( + f"Requires {save_group_size} ranks, found {world_size}" + ) + save_ranks = list(range(save_group_size)) + save_group = torch.distributed.new_group(ranks=save_ranks) + if load_group_size is None: + load_group_size = world_size + load_group = None + else: + if load_group_size > world_size: + self.skipTest( + f"Requires {load_group_size} ranks, found {world_size}" + ) + load_ranks = list(range(load_group_size)) + load_group = torch.distributed.new_group(ranks=load_ranks) # Construct two models with same config and different params - num_layers = 5 - layer_size = 2 - torch.manual_seed(self.seed + self.rank) - _, _, model_save, optim_save = make_models(num_layers, layer_size) - _, _, model_load, optim_load = make_models(num_layers, layer_size) + torch.manual_seed(self.seed) + if self.rank < save_group_size: + if not save_model_kwargs: + save_model_kwargs = {} + _, _, model_save, optim_save = make_models( + num_layers, + layer_size, + lr=0.1, + process_group=save_group, + average_grad_sync=False, + overlap_communication=False, + **save_model_kwargs, + ) + optim_save.init_params(reversed(list(model_save.parameters()))) + torch.manual_seed(self.seed+1) + if self.rank < load_group_size: + if not load_model_kwargs: + load_model_kwargs = {} + _, _, model_load, optim_load = make_models( + num_layers, + layer_size, + lr=1234., + process_group=load_group, + average_grad_sync=False, + overlap_communication=False, + **load_model_kwargs, + ) + optim_load.init_params(list(model_load.parameters())) + + batch_size = 2 * save_group_size * load_group_size + def make_global_batch() -> torch.Tensor: + """Generate random tensor on root rank and broadcast""" + x = torch.empty(batch_size, layer_size, device='cuda') + if self.rank == 0: + torch.rand(x.size(), out=x) + x -= 0.5 + torch.distributed.broadcast(x, src=0) + return x + + def to_local_batch( + global_batch: torch.Tensor, + group: Optional[torch.distributed.ProcessGroup], + ) -> Optional[torch.Tensor]: + """Get local portion of tensor that is replicated across all ranks""" + group_size = torch.distributed.get_world_size(group) + if group_size < 0: + return None + local_batch_size = batch_size // group_size + batch_start = self.rank * local_batch_size + batch_end = (self.rank + 1) * local_batch_size + return global_batch[batch_start:batch_end, ...] + + def to_global_batch( + local_batch: torch.Tensor, + group: Optional[torch.distributed.ProcessGroup], + ) -> torch.Tensor: + """Gather distributed tensor and broadcast to all ranks""" + + # Allocate buffer + global_batch = torch.empty(batch_size, layer_size, device='cuda') + + # Gather data on root rank + group_size = torch.distributed.get_world_size(group) + if group_size > 0: + local_batches = None + if self.rank == 0: + local_batch_size = batch_size // group_size + local_batches = [ + global_batch[rank*local_batch_size:(rank+1)*local_batch_size, ...] + for rank in range(group_size) + ] + torch.distributed.gather( + local_batch, + local_batches, + dst=0, + group=group, + ) + + # Broadcast data to all ranks + torch.distributed.broadcast(global_batch, src=0) + return global_batch # Train one of the models - num_steps = 3 - micro_batch_steps = 2 - batch_size = 4 + torch.manual_seed(self.seed+2) for step in range(num_steps): - optim_save.zero_grad() - for micro_step in range(micro_batch_steps): - x = torch.rand(batch_size, layer_size) - 0.5 - dy = torch.rand_like(x) - 0.5 - x = x.cuda() - dy = dy.cuda() + if self.rank < save_group_size: + optim_save.zero_grad() + x = make_global_batch() + dy = make_global_batch() + if self.rank < save_group_size: + x = to_local_batch(x, save_group) + dy = to_local_batch(dy, save_group) y = model_save(x) y.backward(dy) - optim_save.step() + optim_save.step() # Make sure models are different - for param_save, param_load in zip(model_save.parameters(), - model_load.parameters()): - self.assertRaises( - AssertionError, - torch.testing.assert_close, - param_load, param_save, - ) - - # Save state on root rank and load on all ranks - state_dict = { - 'model': model_save.state_dict(), - 'optim': optim_save.state_dict(), - } - if self.rank == 0: - state_bytes = io.BytesIO() - torch.save(state_dict, state_bytes) - state_bytes = [state_bytes.getvalue()] - else: - state_bytes = [None] - torch.distributed.broadcast_object_list(state_bytes, src=0) - state_bytes = io.BytesIO(state_bytes[0]) - state_dict = torch.load(state_bytes, map_location='cuda') - model_load.load_state_dict(state_dict['model']) - optim_load.load_state_dict(state_dict['optim']) + if self.rank < min(save_group_size, load_group_size): + for param_save, param_load in zip(model_save.parameters(), + model_load.parameters()): + self.assertRaises( + AssertionError, + torch.testing.assert_close, + param_load, + param_save, + rtol=rtol, + atol=atol, + ) + + # Save state + state_bytes = None + if self.rank < save_group_size: + state_dict = { + 'model': model_save.state_dict(), + 'optim': optim_save.state_dict(), + } + byte_stream = io.BytesIO() + torch.save(state_dict, byte_stream) + state_bytes = byte_stream.getvalue() + + # Broadcast state from root rank and load + if self.rank < load_group_size: + if load_group_size != save_group_size: + if self.rank != 0: + state_bytes = None + state_bytes = [state_bytes] + torch.distributed.broadcast_object_list( + state_bytes, + src=0, + group=load_group, + ) + state_bytes = state_bytes[0] + state_dict = torch.load(io.BytesIO(state_bytes)) + model_load.load_state_dict(state_dict['model']) + optim_load.load_state_dict(state_dict['optim']) # Make sure models are identical - for param_save, param_load in zip(model_save.parameters(), - model_load.parameters()): - torch.testing.assert_close(param_load, param_save) + if self.rank < min(save_group_size, load_group_size): + for param_save, param_load in zip(model_save.parameters(), + model_load.parameters()): + torch.testing.assert_close( + param_load, + param_save, + rtol=rtol, + atol=atol + ) # Train both models - num_steps = 3 - micro_batch_steps = 3 - batch_size = 5 + torch.manual_seed(self.seed+3) for step in range(num_steps): - # Reset gradients - optim_save.zero_grad() - optim_load.zero_grad() - - # Forward and backward passes - for micro_step in range(micro_batch_steps): - - # Synthetic data - x = torch.rand(batch_size, layer_size) - 0.5 - dy = torch.rand_like(x) - 0.5 - x = x.cuda() - dy = dy.cuda() - - # Forward and backward pass - x_save = x.detach().clone().requires_grad_(True) + # Reset grads + if self.rank < save_group_size: + optim_save.zero_grad() + if self.rank < load_group_size: + optim_load.zero_grad() + + # Synthetic data + x = make_global_batch() + dy = make_global_batch() + + # Training step for model that saved checkpoint + y_save = None + dx_save = None + if self.rank < save_group_size: + x_save = to_local_batch(x, save_group) + x_save = x_save.detach().clone().requires_grad_(True) + dy_save = to_local_batch(dy, save_group) y_save = model_save(x_save) - y_save.backward(dy) - x_load = x.detach().clone().requires_grad_(True) + y_save.backward(dy_save) + dx_save = x_save.grad + y_save = to_global_batch(y_save, save_group) + dx_save = to_global_batch(dx_save, save_group) + + # Training step for model that loaded checkpoint + y_load = None + dx_load = None + if self.rank < load_group_size: + x_load = to_local_batch(x, load_group) + x_load = x_load.detach().clone().requires_grad_(True) + dy_load = to_local_batch(dy, load_group) y_load = model_load(x_load) - y_load.backward(dy) + y_load.backward(dy_load) + dx_load = x_load.grad + y_load = to_global_batch(y_load, load_group) + dx_load = to_global_batch(dx_load, load_group) - # Check that data tensors match - torch.testing.assert_close(y_load, y_save) - torch.testing.assert_close(x_load.grad, x_save.grad) + # Check that data tensors match + torch.testing.assert_close(y_load, y_save, rtol=rtol, atol=atol) + torch.testing.assert_close(dx_load, dx_save, rtol=rtol, atol=atol) # Optimizer step - optim_save.step() - optim_load.step() + if self.rank < save_group_size: + optim_save.step() + if self.rank < load_group_size: + optim_load.step() # Check that parameters match - for param_save, param_load in zip(model_save.parameters(), - model_load.parameters()): - torch.testing.assert_close(param_load, param_save) + if self.rank < min(save_group_size, load_group_size): + for param_save, param_load in zip(model_save.parameters(), + model_load.parameters()): + torch.testing.assert_close( + param_load, + param_save, + rtol=rtol, + atol=atol, + ) + + def test_checkpoint_save_1gpu(self): + """Test loading checkpoint with one GPU""" + self.test_checkpoint(save_group_size=1) + + def test_checkpoint_load_1gpu(self): + """Test saving checkpoint with one GPU""" + self.test_checkpoint(load_group_size=1) + + def test_checkpoint_bf16(self): + """Test checkpoint with BF16 model""" + self.test_checkpoint( + rtol=5e-2, + atol=1e-5, + save_model_kwargs=dict( + model_dtype=torch.bfloat16, + optim_dtype=torch.float32, + param_sync_dtype=torch.bfloat16, + store_params=False, + store_param_remainders=True, + ), + load_model_kwargs=dict( + model_dtype=torch.bfloat16, + optim_dtype=torch.float32, + param_sync_dtype=torch.bfloat16, + store_params=False, + store_param_remainders=True, + ), + ) + + def test_checkpoint_scaled_state(self): + """Test checkpoint with scaled FP16 state""" + self.test_checkpoint( + rtol=5e-2, + atol=1e-5, + save_model_kwargs=dict( + model_dtype=torch.bfloat16, + optim_dtype=torch.float16, + param_sync_dtype=torch.int, + store_params=True, + with_scaled_states=True, + ), + load_model_kwargs=dict( + model_dtype=torch.bfloat16, + optim_dtype=torch.float16, + param_sync_dtype=torch.int, + store_params=True, + with_scaled_states=True, + ), + ) + + def test_bucket_low_utilization_warning(self): + """Test warning when bucket utilization is low""" + layer_size = 2*1024*1024 + num_layers = 4 + fairish_bucket_cap_mb = 4*num_layers*layer_size/(1024*1024) + + # Check that warning is raised when bucket utilization is low + with self.assertWarnsRegex(Warning, ".*Consider decreasing the bucket_cap_mb argument."): + self.test_matches_pytorch( + num_layers=num_layers, + layer_size=layer_size, + overlap_communication=False, + bucket_cap_mb=fairish_bucket_cap_mb * 2, + ) + + # Check that warning is not raised when bucket utilization is high + with warnings.catch_warnings(record=True) as warns: + self.test_matches_pytorch( + num_layers=num_layers, + layer_size=layer_size, + overlap_communication=False, + bucket_cap_mb=fairish_bucket_cap_mb, + ) + for w in warns: + self.assertNotRegex(str(w.message), ".*Consider decreasing the bucket_cap_mb argument.") + + + def test_cuda_graph(self): + """Test distributed adam with CUDA graph""" + if self.world_size < 8: + self.skipTest(f"{self.world_size=} is expected to be >= 8") + self.test_matches_pytorch( + rtol=5e-3, + atol=1e-5, + num_steps=15, + micro_batch_steps=1, + model_dtype=torch.float16, + optim_dtype=torch.float16, + contiguous_buffers=True, + with_cuda_graph=True, + ) + if __name__ == "__main__": # Assume script has been run with torchrun - common_utils.run_tests() + common_utils.run_tests() \ No newline at end of file diff --git a/apex/contrib/test/optimizers/test_distributed_fused_lamb.py b/apex/contrib/test/optimizers/test_distributed_fused_lamb.py new file mode 100644 index 000000000..d8f56117a --- /dev/null +++ b/apex/contrib/test/optimizers/test_distributed_fused_lamb.py @@ -0,0 +1,124 @@ +import os +import inspect +import torch +from torch.cuda.amp import GradScaler +from torch.testing._internal import common_utils +from apex.parallel.distributed import flat_dist_call +from apex.contrib.optimizers.distributed_fused_lamb import DistributedFusedLAMB +from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase + +def get_init_weights_func(): + @torch.no_grad() + def init_weights(m): + if isinstance(m, torch.nn.Linear): + m.weight.fill_(1.0) + return init_weights + +class ModelFoo(torch.nn.Module): + def __init__(self): + super(ModelFoo, self).__init__() + self.linear = torch.nn.Linear(128, 128, bias = False) + self.loss = torch.nn.MSELoss() + + def forward(self, input_tensor, gt): + y = self.linear(input_tensor) + loss = self.loss(y, gt) + return loss + +# A test for distributed fused Lamb optimizer: run several iterations and see if loss decreases +# There are two instances of the same test because based on `world_size` the optimizer decides what collectives operation to use. +# If torch.distributed.get_world_size() == torch.cuda.device_count() it uses only `all_gather`. +# If torch.distributed.get_world_size() < torch.cuda.device_count() it uses both `all_gather` and `reduce_scatter`. +class NcclDistributedFusedLAMB(NcclDistributedTestBase): + @property + def world_size(self) -> int: + return torch.cuda.device_count() + + @common_utils.parametrize("no_copy", [False, True]) + @common_utils.parametrize("opt_kwargs", [ + dict(overlap_reductions=True, dwu_num_blocks=2, dwu_num_chunks=2, + fused_norm=False, fuse_scale=False, clip_after_ar=True, + full_ar=False), + dict(overlap_reductions=False, dwu_num_blocks=1, dwu_num_chunks=1, + fused_norm=True, fuse_scale=True, clip_after_ar=False), + ]) + def test_distributed_fused_lamb(self, no_copy, opt_kwargs): + if no_copy and 'no_copy' not in inspect.getfullargspec(torch.distributed.reduce_scatter).args: + self.skipTest("does not support no_copy") + if no_copy and 'no_copy' not in inspect.getfullargspec(torch.distributed.all_gather).args: + self.skipTest("does not support no_copy") + + assert torch.distributed.is_initialized() + gpu_count = torch.distributed.get_world_size() + + init_scale = 100 + lr = torch.tensor(0.1).cuda() + grad_scaler = GradScaler(init_scale=init_scale, growth_interval=1000) + + model = ModelFoo() + model = model.cuda().half() + model.apply(get_init_weights_func()) + + param_optimizer = list(model.named_parameters()) + no_decay = ['bias', 'gamma', 'beta', 'LayerNorm'] + optimizer_grouped_parameters = [ + {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, + {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} + ] + + if 'full_ar' not in opt_kwargs: + opt_kwargs['full_ar'] = gpu_count == torch.cuda.device_count() + + # Aidyn-A: not sure what parameters are the best for testing purposes, + # setting up whatever I think appropriate. + optimizer = DistributedFusedLAMB( + optimizer_grouped_parameters, + lr=0.1, + betas=(0.9, 0.9), + eps=1e-6, + max_grad_norm=1.0, + dwu_group_size=gpu_count, + dwu_num_rs_pg=1, + dwu_num_ar_pg=1, + dwu_num_ag_pg=1, + use_nvlamb=False, + set_param_views_to_flat_buffer=False, + e5m2_allgather=False, + **opt_kwargs + ) + optimizer.set_global_scale(init_scale) + + optimizer._reduce_scatter_no_copy = no_copy + optimizer._all_gather_no_copy = no_copy + + flat_dist_call([param.data for param in model.parameters()], torch.distributed.broadcast, (0,) ) + + x = torch.randn(4096, 128, dtype=torch.float16).cuda() + y = torch.randn(4096, 128, dtype=torch.float16).cuda() + + losses = [] + for _ in range(10): + loss = model(x, y) + optimizer._lazy_init_stage1() + grad_scaler.scale(loss).backward() + optimizer._lazy_init_stage2() + optimizer._lr = lr + optimizer.complete_reductions() + optimizer.set_global_scale(grad_scaler._get_scale_async()) + grad_scaler.step(optimizer) + grad_scaler.update() + optimizer.zero_grad(set_to_none=True) + + losses.append(loss.item()) + + self.assertTrue(losses == sorted(losses, reverse=True)) + +common_utils.instantiate_parametrized_tests(NcclDistributedFusedLAMB) + +class NcclDistributedFusedLAMB_partial_ar(NcclDistributedFusedLAMB): + @property + def world_size(self) -> int: + return max(torch.cuda.device_count()-1, 1) + +if __name__ == "__main__": + common_utils.run_tests() diff --git a/apex/contrib/test/run_rocm_extensions.py b/apex/contrib/test/run_rocm_extensions.py new file mode 100644 index 000000000..1c9add5d8 --- /dev/null +++ b/apex/contrib/test/run_rocm_extensions.py @@ -0,0 +1,28 @@ +import unittest +import sys + + +test_dirs = ["groupbn", "layer_norm", "multihead_attn", "transducer", "focal_loss", "index_mul_2d", ".", \ + "optimizers", "clip_grad"] # "." for test_label_smoothing.py +ROCM_BLACKLIST = [ + "layer_norm" +] + +if __name__ == '__main__': + runner = unittest.TextTestRunner(verbosity=2) + + errcode = 0 + + for test_dir in test_dirs: + if test_dir in ROCM_BLACKLIST: + continue + suite = unittest.TestLoader().discover(test_dir) + + print("\nExecuting tests from " + test_dir) + + result = runner.run(suite) + + if not result.wasSuccessful(): + errcode = 1 + + sys.exit(errcode) diff --git a/apex/contrib/test/transducer/test_transducer_joint.py b/apex/contrib/test/transducer/test_transducer_joint.py index c1c8dd1e7..3a19482db 100755 --- a/apex/contrib/test/transducer/test_transducer_joint.py +++ b/apex/contrib/test/transducer/test_transducer_joint.py @@ -121,6 +121,7 @@ def test_transducer_joint(self): def test_transducer_joint_vec(self): self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=False, dropout=False) + # @unittest.skip("Skipped the test on ROCm. Please also refer to https://github.com/ROCmSoftwarePlatform/apex/issues/89") def test_transducer_joint_pack(self): self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=False, dropout=False) @@ -133,25 +134,30 @@ def test_transducer_joint_relu(self): def test_transducer_joint_vec_relu(self): self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=True, dropout=False) + # @unittest.skip("Skipped the test on ROCm. Please also refer to https://github.com/ROCmSoftwarePlatform/apex/issues/89") def test_transducer_joint_pack_relu(self): self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=True, dropout=False) def test_transducer_joint_vec_pack_relu(self): self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=False) + @unittest.skip("Skipped the test on ROCm. Please also refer to https://github.com/ROCmSoftwarePlatform/apex/issues/89") def test_transducer_joint_relu_dropout(self): self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=True) + @unittest.skip("Skipped the test on ROCm. Please also refer to https://github.com/ROCmSoftwarePlatform/apex/issues/89") def test_transducer_joint_vec_relu_dropout(self): self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=True, dropout=True) + @unittest.skip("Skipped the test on ROCm. Please also refer to https://github.com/ROCmSoftwarePlatform/apex/issues/89") def test_transducer_joint_pack_relu_dropout(self): self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=True, dropout=True) + @unittest.skip("Skipped the test on ROCm. Please also refer to https://github.com/ROCmSoftwarePlatform/apex/issues/89") def test_transducer_joint_vec_pack_relu_dropout(self): self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=True) if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() diff --git a/apex/csrc b/apex/csrc new file mode 120000 index 000000000..e96d28eb5 --- /dev/null +++ b/apex/csrc @@ -0,0 +1 @@ +../csrc \ No newline at end of file diff --git a/apex/fused_dense/fused_dense.py b/apex/fused_dense/fused_dense.py index d36078c94..97377a423 100644 --- a/apex/fused_dense/fused_dense.py +++ b/apex/fused_dense/fused_dense.py @@ -1,13 +1,15 @@ import torch from torch import nn import fused_dense_cuda -from .. import amp +from apex._autocast_utils import _cast_if_autocast_enabled +import math + #implements fused GEMM+bias in forward pass using mlp_cuda from apex class FusedDenseFunc(torch.autograd.Function): @staticmethod def forward(ctx, input, weight, bias): ctx.save_for_backward(input, weight) - output = fused_dense_cuda.linear_bias_forward(input, weight, bias) + output = fused_dense_cuda.linear_bias_forward(input, weight, bias.t()) return output @staticmethod @@ -33,52 +35,114 @@ def backward(ctx, grad_output): class FusedDenseGeluDenseFunc(torch.autograd.Function): @staticmethod - def forward(ctx, input, weight1, bias1, weight2, bias2): - ctx.save_for_backward(input, weight1, weight2) - output1, output2, gelu_in = fused_dense_cuda.linear_gelu_linear_forward(input, weight1, bias1, weight2, bias2) - ctx.save_for_backward(input, weight1, weight2, gelu_in, output1) + def forward(ctx, input, weight, bias, weight2, bias2): + ''' + The forward method of the FusedDenseGELUDense layer performs the following operations: + Applies the first dense layer (dense1) to the input tensor. + Applies the GELU activation function (act) to the result. + Applies the second dense layer (dense2) to the GELU-activated output. + ''' + ctx.save_for_backward(input, weight, weight2) + output, output2, gelu = fused_dense_cuda.linear_gelu_linear_forward(input, weight, bias, weight2, bias2) + ctx.save_for_backward(input, weight, weight2, gelu, output) return output2 @staticmethod def backward(ctx, grad_output): - input, weight1, weight2, gelu_in, output1 = ctx.saved_tensors - grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2 = fused_dense_cuda.linear_gelu_linear_backward(input, gelu_in, output1, weight1, weight2, grad_output) - return grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2 + input, weight, weight2, gelu, output = ctx.saved_tensors + grad_input, grad_weight, grad_bias, grad_weight2, grad_bias2 = fused_dense_cuda.linear_gelu_linear_backward(input, gelu, output, weight, weight2, grad_output) + return grad_input, grad_weight, grad_bias, grad_weight2, grad_bias2 + +def fused_dense_function(input, weight, bias): + args = _cast_if_autocast_enabled(input, weight, bias) + with torch.amp.autocast('cuda', enabled=False): + return FusedDenseFunc.apply(*args) +def dense_no_bias_function(input, weight): + args = _cast_if_autocast_enabled(input, weight) + with torch.amp.autocast('cuda', enabled=False): + return DenseNoBiasFunc.apply(*args) -fused_dense_function = amp.half_function(FusedDenseFunc.apply) -dense_no_bias_function = amp.half_function(DenseNoBiasFunc.apply) -fused_dense_gelu_dense_function = amp.half_function(FusedDenseGeluDenseFunc.apply) +def fused_dense_gelu_dense_function(input, weight1, bias1, weight2, bias2): + args = _cast_if_autocast_enabled(input, weight1, bias1, weight2, bias2) + with torch.amp.autocast('cuda', enabled=False): + return FusedDenseGeluDenseFunc.apply(*args) class FusedDense(nn.Module): def __init__(self, in_features, out_features, bias=True): super(FusedDense, self).__init__() self.in_features = in_features self.out_features = out_features - self.weight = nn.Parameter(torch.Tensor(out_features, in_features)) + self.weight = nn.Parameter(torch.randn(out_features, in_features)) if bias: - self.bias = nn.Parameter(torch.Tensor(out_features)) + self.bias = nn.Parameter(torch.randn(out_features)) else: #assert False, "no-bias option not added yet" self.register_parameter('bias', None) + self.reset_parameters() + def forward(self, input): if self.bias is not None: return fused_dense_function(input, self.weight, self.bias) else: return dense_no_bias_function(input, self.weight) + + def reset_parameters(self): + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(self.bias, -bound, bound) + +#======================================================================================= +# +#======================================================================================= class FusedDenseGeluDense(nn.Module): + ''' + https://zeta.apac.ai/en/latest/zeta/nn/modules/fused_gelu_dense/ + module combines dense layers with GELU activations in a single neural network layer. + layer consists of two dense sub-layers, each followed by a GELU activation function. + It takes an input tensor and passes it through these sub-layers to produce the final output. + Parameters: + dim (int): Input dimension. + dim_out (int): Output dimension. + bias (bool, optional): Whether to include bias terms. Defaults to True. + has_fp16_weights (bool, optional): Whether to use fp16 weights. Defaults to False. + threshold (float, optional): Threshold for quantization. Defaults to 6.0. + + layer consists of the following internal layers: + dense1: The first dense layer. + act: The GELU activation function. + dense2: The second dense layer. + + ''' def __init__(self, in_features, intermediate_features, out_features, bias=True): super(FusedDenseGeluDense, self).__init__() assert bias == True, "DenseGeluDense module without bias is currently not supported" self.in_features = in_features self.intermediate_features = intermediate_features self.out_features = out_features - self.weight1 = nn.Parameter(torch.Tensor(intermediate_features, in_features)) - self.bias1 = nn.Parameter(torch.Tensor(intermediate_features)) - self.weight2 = nn.Parameter(torch.Tensor(out_features, intermediate_features)) - self.bias2 = nn.Parameter(torch.Tensor(out_features)) + self.weight1 = nn.Parameter(torch.randn(intermediate_features, in_features)) + self.bias1 = nn.Parameter(torch.randn(intermediate_features)) + self.weight2 = nn.Parameter(torch.randn(out_features, intermediate_features)) + self.bias2 = nn.Parameter(torch.randn(out_features)) + self.reset_parameters() + + + def reset_parameters(self): + nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.weight2, a=math.sqrt(5)) + if self.bias1 is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(self.bias1, -bound, bound) + if self.bias2 is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(self.bias2, -bound, bound) + def forward(self, input): return fused_dense_gelu_dense_function(input, self.weight1, self.bias1, self.weight2, self.bias2) diff --git a/apex/git_version_info.py b/apex/git_version_info.py new file mode 100644 index 000000000..ee9e7c6c7 --- /dev/null +++ b/apex/git_version_info.py @@ -0,0 +1,34 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# Portions of this code were adapted from DeepSpeed: +# https://github.com/microsoft/DeepSpeed +# Modified for ROCm Apex + +try: + # This is populated by setup.py + from .git_version_info_installed import * # noqa: F401 # type: ignore +except ModuleNotFoundError: + import os + if os.path.isfile('version.txt'): + # Will be missing from checkouts that haven't been installed (e.g., readthedocs) + version = open('version.txt', 'r').read().strip() + else: + version = "0.0.0" + git_hash = '[none]' + git_branch = '[none]' + + from .op_builder.all_ops import ALL_OPS + installed_ops = dict.fromkeys(ALL_OPS.keys(), False) + torch_info = {'version': "0.0", "cuda_version": "0.0", "hip_version": "0.0"} + +# compatible_ops list is recreated for each launch +from .op_builder.all_ops import ALL_OPS + +compatible_ops = dict.fromkeys(ALL_OPS.keys(), False) +for op_name, builder in ALL_OPS.items(): + op_compatible = builder.is_compatible() + compatible_ops[op_name] = op_compatible + compatible_ops["apex_not_implemented"] = False \ No newline at end of file diff --git a/apex/mlp/mlp.py b/apex/mlp/mlp.py index bae38f3f8..31b853292 100644 --- a/apex/mlp/mlp.py +++ b/apex/mlp/mlp.py @@ -1,9 +1,12 @@ from copy import copy import math + import torch from torch import nn + +from apex._autocast_utils import _cast_if_autocast_enabled import mlp_cuda -from .. import amp + class MlpFunction(torch.autograd.Function): @staticmethod @@ -21,7 +24,11 @@ def backward(ctx, grad_o): del ctx.outputs return (None, None, *grads) -mlp_function = amp.half_function(MlpFunction.apply) + +def mlp_function(bias, activation, *args): + autocast_args = _cast_if_autocast_enabled(bias, activation, *args) + return MlpFunction.apply(*autocast_args) + class MLP(torch.nn.Module): """Launch MLP in C++ diff --git a/apex/multi_tensor_apply/__init__.py b/apex/multi_tensor_apply/__init__.py index 0a80e3c54..31e2a53de 100644 --- a/apex/multi_tensor_apply/__init__.py +++ b/apex/multi_tensor_apply/__init__.py @@ -1,4 +1,5 @@ from .multi_tensor_apply import MultiTensorApply -multi_tensor_applier = MultiTensorApply(2048*32) +multi_tensor_applier = MultiTensorApply(256*32) +multi_tensor_applier_l2norm = MultiTensorApply(2048*32) diff --git a/apex/normalization/fused_layer_norm.py b/apex/normalization/fused_layer_norm.py index d873969f4..d5485cd9d 100644 --- a/apex/normalization/fused_layer_norm.py +++ b/apex/normalization/fused_layer_norm.py @@ -5,6 +5,7 @@ from torch.nn.parameter import Parameter from torch.nn import init from torch.nn import functional as F +from typing import List, Tuple from apex._autocast_utils import _cast_if_autocast_enabled @@ -12,6 +13,11 @@ fused_layer_norm_cuda = None +# PyTorch supports `torch.library.custom_op` since 2.4.0. +def supports_custom_op() -> bool: + return hasattr(torch.library, "custom_op") + + # Reference implementation from Huggingface def manual_rms_norm(input, normalized_shape, weight, eps): # layer norm should always be calculated in float32 @@ -24,181 +30,850 @@ def manual_rms_norm(input, normalized_shape, weight, eps): # convert into half-precision if necessary if weight.dtype in [torch.float16, torch.bfloat16]: - input = input.to(self.weight.dtype) + input = input.to(weight.dtype) return weight * input class FusedLayerNormAffineFunction(torch.autograd.Function): @staticmethod - def forward(ctx, input, weight, bias, normalized_shape, eps): + def forward(ctx, input, weight, bias, normalized_shape, eps, memory_efficient=False): global fused_layer_norm_cuda if fused_layer_norm_cuda is None: fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") ctx.normalized_shape = normalized_shape ctx.eps = eps + ctx.memory_efficient = memory_efficient input_ = input.contiguous() weight_ = weight.contiguous() bias_ = bias.contiguous() output, mean, invvar = fused_layer_norm_cuda.forward_affine( input_, ctx.normalized_shape, weight_, bias_, ctx.eps ) - ctx.save_for_backward(input_, weight_, bias_, mean, invvar) + if ctx.memory_efficient: + ctx.save_for_backward(output, weight_, bias_, None, invvar) + else: + ctx.save_for_backward(input_, weight_, bias_, mean, invvar) return output @staticmethod def backward(ctx, grad_output): - input_, weight_, bias_, mean, invvar = ctx.saved_tensors + input_or_output, weight_, bias_, mean, invvar = ctx.saved_tensors grad_input = grad_weight = grad_bias = None grad_input, grad_weight, grad_bias = fused_layer_norm_cuda.backward_affine( - grad_output.contiguous(), mean, invvar, input_, ctx.normalized_shape, weight_, bias_, ctx.eps + grad_output.contiguous(), mean, invvar, input_or_output, + ctx.normalized_shape, weight_, bias_, ctx.eps, ctx.memory_efficient + ) + return grad_input, grad_weight, grad_bias, None, None, None + + +if supports_custom_op(): + + @torch.library.custom_op("apex::fused_layer_norm_affine_fwd", mutates_args=()) + def fused_layer_norm_affine_fwd( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + normalized_shape: List[int], + eps: float, + memory_efficient: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + global fused_layer_norm_cuda + if fused_layer_norm_cuda is None: + fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") + + input_ = input.contiguous() + weight_ = weight.contiguous() + bias_ = bias.contiguous() + output, mean, invvar = fused_layer_norm_cuda.forward_affine( + input_, normalized_shape, weight_, bias_, eps + ) + return output, mean, invvar + + @fused_layer_norm_affine_fwd.register_fake + def fused_layer_norm_affine_fwd_fake( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + normalized_shape: List[int], + eps: float, + memory_efficient: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + input = input.contiguous() + weight = weight.contiguous() + bias = bias.contiguous() + idiff = input.ndim - len(normalized_shape) + n = 1 + for i in range(idiff): + n *= input.shape[i] + if input.dtype in [torch.float16, torch.bfloat16]: + dtype = torch.float32 + else: + dtype = input.dtype + mean = torch.empty([n], dtype=dtype, device=input.device) + invvar = torch.empty_like(mean) + return torch.empty_like(input), mean, invvar + + @torch.library.custom_op("apex::fused_layer_norm_affine_bwd", mutates_args=()) + def fused_layer_norm_affine_bwd( + grad_output: torch.Tensor, + mean: torch.Tensor, + invvar: torch.Tensor, + input_or_output: torch.Tensor, + normalized_shape: List[int], + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + memory_efficient: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + grad_input, grad_weight, grad_bias = fused_layer_norm_cuda.backward_affine( + grad_output.contiguous(), + mean, + invvar, + input_or_output, + normalized_shape, + weight, + bias, + eps, + memory_efficient, + ) + return grad_input, grad_weight, grad_bias + + @fused_layer_norm_affine_bwd.register_fake + def fused_layer_norm_affine_bwd_fake( + grad_output: torch.Tensor, + mean: torch.Tensor, + invvar: torch.Tensor, + input_or_output: torch.Tensor, + normalized_shape: List[int], + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + memory_efficient: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + grad_input = torch.empty_like(input_or_output) + grad_weight = torch.empty_like(weight) + grad_bias = torch.empty_like(bias) + return grad_input, grad_weight, grad_bias + + def _fused_layer_norm_affine_backward(ctx, grad_output, grad_mean, grad_invvar): + input_or_output, weight_, bias_, mean, invvar = ctx.saved_tensors + grad_input = grad_weight = grad_bias = None + grad_input, grad_weight, grad_bias = fused_layer_norm_affine_bwd( + grad_output, + mean, + invvar, + input_or_output, + ctx.normalized_shape, + weight_, + bias_, + ctx.eps, + ctx.memory_efficient, ) - return grad_input, grad_weight, grad_bias, None, None + return grad_input, grad_weight, grad_bias, None, None, None + + def _fused_layer_norm_affine_setup_context(ctx, inputs, output): + input, weight, bias, normalized_shape, eps, memory_efficient = inputs + output, mean, invvar = output + input_ = input.contiguous() + weight_ = weight.contiguous() + bias_ = bias.contiguous() + if memory_efficient: + ctx.save_for_backward(output, weight_, bias_, None, invvar) + else: + ctx.save_for_backward(input_, weight_, bias_, mean, invvar) + ctx.normalized_shape = normalized_shape + ctx.eps = eps + ctx.memory_efficient = memory_efficient + + fused_layer_norm_affine_fwd.register_autograd( + _fused_layer_norm_affine_backward, + setup_context=_fused_layer_norm_affine_setup_context, + ) class FusedRMSNormAffineFunction(torch.autograd.Function): @staticmethod - def forward(ctx, input, weight, normalized_shape, eps): + def forward(ctx, input, weight, normalized_shape, eps, memory_efficient=False): global fused_layer_norm_cuda if fused_layer_norm_cuda is None: fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") ctx.normalized_shape = normalized_shape ctx.eps = eps + ctx.memory_efficient = memory_efficient input_ = input.contiguous() weight_ = weight.contiguous() output, invvar = fused_layer_norm_cuda.rms_forward_affine( input_, ctx.normalized_shape, weight_, ctx.eps) - ctx.save_for_backward(input_, weight_, invvar) + if ctx.memory_efficient: + ctx.save_for_backward(output, weight_, invvar) + else: + ctx.save_for_backward(input_, weight_, invvar) return output @staticmethod def backward(ctx, grad_output): - input_, weight_, invvar = ctx.saved_tensors + input_or_output, weight_, invvar = ctx.saved_tensors grad_input = grad_weight = None grad_input, grad_weight = fused_layer_norm_cuda.rms_backward_affine( - grad_output.contiguous(), invvar, input_, ctx.normalized_shape, weight_, ctx.eps + grad_output.contiguous(), invvar, input_or_output, + ctx.normalized_shape, weight_, ctx.eps, ctx.memory_efficient + ) + return grad_input, grad_weight, None, None, None + +if supports_custom_op(): + @torch.library.custom_op("apex::fused_rms_norm_affine_fwd", mutates_args=()) + def fused_rms_norm_affine_fwd( + input: torch.Tensor, + weight: torch.Tensor, + normalized_shape: List[int], + eps: float, + memory_efficient: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + global fused_layer_norm_cuda + if fused_layer_norm_cuda is None: + fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") + + input_ = input.contiguous() + weight_ = weight.contiguous() + output, invvar = fused_layer_norm_cuda.rms_forward_affine( + input_, normalized_shape, weight_, eps + ) + return output, invvar + + + @fused_rms_norm_affine_fwd.register_fake + def fused_rms_norm_affine_fwd_fake( + input: torch.Tensor, + weight: torch.Tensor, + normalized_shape: List[int], + eps: float, + memory_efficient: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + input = input.contiguous() + weight = weight.contiguous() + idiff = input.ndim - len(normalized_shape) + n = 1 + for i in range(idiff): + n *= input.shape[i] + if input.dtype in [torch.float16, torch.bfloat16]: + dtype = torch.float32 + else: + dtype = input.dtype + return ( + torch.empty_like(input), + torch.empty( + [n], + dtype=dtype, + device=input.device, + requires_grad=input.requires_grad, + memory_format=torch.contiguous_format, + ), + ) + + + @torch.library.custom_op("apex::fused_rms_norm_affine_bwd", mutates_args=()) + def fused_rms_norm_affine_bwd( + grad_output: torch.Tensor, + invvar: torch.Tensor, + input_or_output: torch.Tensor, + normalized_shape: List[int], + weight: torch.Tensor, + eps: float, + memory_efficient: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + grad_input, grad_weight = fused_layer_norm_cuda.rms_backward_affine( + grad_output.contiguous(), + invvar, + input_or_output, + normalized_shape, + weight, + eps, + memory_efficient, ) - return grad_input, grad_weight, None, None + return grad_input, grad_weight + + + @fused_rms_norm_affine_bwd.register_fake + def fused_rms_norm_affine_bwd_fake( + grad_output: torch.Tensor, + invvar: torch.Tensor, + input_or_output: torch.Tensor, + normalized_shape: List[int], + weight: torch.Tensor, + eps: float, + memory_efficient: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + grad_input = torch.empty_like(input_or_output) + grad_weight = torch.empty_like(weight) + return grad_input, grad_weight + + + def _fused_rms_norm_affine_backward(ctx, grad_output, grad_invvar): + input_or_output, weight_, invvar = ctx.saved_tensors + grad_input = grad_weight = None + grad_input, grad_weight = fused_rms_norm_affine_bwd( + grad_output, + invvar, + input_or_output, + ctx.normalized_shape, + weight_, + ctx.eps, + ctx.memory_efficient, + ) + return grad_input, grad_weight, None, None, None + + + def _fused_rms_norm_affine_setup_context(ctx, inputs, output): + input_, weight_, normalized_shape, eps, memory_efficient = inputs + output_, invvar = output + input_ = input_.contiguous() + weight_ = weight_.contiguous() + if memory_efficient: + ctx.save_for_backward(output_, weight_, invvar) + else: + ctx.save_for_backward(input_, weight_, invvar) + ctx.normalized_shape = normalized_shape + ctx.eps = eps + ctx.memory_efficient = memory_efficient + + + fused_rms_norm_affine_fwd.register_autograd( + _fused_rms_norm_affine_backward, + setup_context=_fused_rms_norm_affine_setup_context + ) class FusedLayerNormAffineMixedDtypesFunction(FusedLayerNormAffineFunction): @staticmethod - def forward(ctx, input, weight, bias, normalized_shape, eps): + def forward(ctx, input, weight, bias, normalized_shape, eps, memory_efficient=False): global fused_layer_norm_cuda if fused_layer_norm_cuda is None: fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") ctx.normalized_shape = normalized_shape ctx.eps = eps + ctx.memory_efficient = memory_efficient input_ = input.contiguous() weight_ = weight.contiguous() bias_ = bias.contiguous() output, mean, invvar = fused_layer_norm_cuda.forward_affine_mixed_dtypes( input_, ctx.normalized_shape, weight_, bias_, ctx.eps ) - ctx.save_for_backward(input_, weight_, bias_, mean, invvar) + if ctx.memory_efficient: + ctx.save_for_backward(output, weight_, bias_, None, invvar) + else: + ctx.save_for_backward(input_, weight_, bias_, mean, invvar) return output class FusedRMSNormAffineMixedDtypesFunction(FusedRMSNormAffineFunction): @staticmethod - def forward(ctx, input, weight, normalized_shape, eps): + def forward(ctx, input, weight, normalized_shape, eps, memory_efficient=False): global fused_layer_norm_cuda if fused_layer_norm_cuda is None: fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") ctx.normalized_shape = normalized_shape ctx.eps = eps + ctx.memory_efficient = memory_efficient input_ = input.contiguous() weight_ = weight.contiguous() output, invvar = fused_layer_norm_cuda.rms_forward_affine_mixed_dtypes( input_, ctx.normalized_shape, weight_, ctx.eps ) - - ctx.save_for_backward(input_, weight_, invvar) + if ctx.memory_efficient: + ctx.save_for_backward(output, weight_, invvar) + else: + ctx.save_for_backward(input_, weight_, invvar) return output +if supports_custom_op(): + + @torch.library.custom_op("apex::fused_layer_norm_affine_mixed_dtypes_fwd", mutates_args=()) + def fused_layer_norm_affine_mixed_dtypes_fwd( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + normalized_shape: List[int], + eps: float, + memory_efficient: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + global fused_layer_norm_cuda + if fused_layer_norm_cuda is None: + fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") + + input_ = input.contiguous() + weight_ = weight.contiguous() + bias_ = bias.contiguous() + output, mean, invvar = fused_layer_norm_cuda.forward_affine_mixed_dtypes( + input_, normalized_shape, weight_, bias_, eps + ) + return output, mean, invvar + + @fused_layer_norm_affine_mixed_dtypes_fwd.register_fake + def fused_layer_norm_affine_mixed_dtypes_fwd_fake( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + normalized_shape: List[int], + eps: float, + memory_efficient: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + input = input.contiguous() + weight = weight.contiguous() + bias = bias.contiguous() + idiff = input.ndim - len(normalized_shape) + n = 1 + for i in range(idiff): + n *= input.shape[i] + if input.dtype in [torch.float16, torch.bfloat16]: + stat_dtype = torch.float32 + else: + stat_dtype = input.dtype + mean = torch.empty([n], dtype=stat_dtype, device=input.device) + invvar = torch.empty_like(mean) + output = torch.empty_like(input, dtype=weight.dtype) + return output, mean, invvar + + def _fused_layer_norm_affine_mixed_dtypes_backward(ctx, grad_output, grad_mean, grad_invvar): + input_or_output, weight_, bias_, mean, invvar = ctx.saved_tensors + grad_input, grad_weight, grad_bias = fused_layer_norm_affine_bwd( + grad_output, + mean, + invvar, + input_or_output, + ctx.normalized_shape, + weight_, + bias_, + ctx.eps, + ctx.memory_efficient, + ) + return grad_input, grad_weight, grad_bias, None, None, None + + def _fused_layer_norm_affine_mixed_dtypes_setup_context(ctx, inputs, output): + input, weight, bias, normalized_shape, eps, memory_efficient = inputs + output, mean, invvar = output + input_ = input.contiguous() + weight_ = weight.contiguous() + bias_ = bias.contiguous() + if memory_efficient: + ctx.save_for_backward(output, weight_, bias_, None, invvar) + else: + ctx.save_for_backward(input_, weight_, bias_, mean, invvar) + ctx.normalized_shape = normalized_shape + ctx.eps = eps + ctx.memory_efficient = memory_efficient + + fused_layer_norm_affine_mixed_dtypes_fwd.register_autograd( + _fused_layer_norm_affine_mixed_dtypes_backward, + setup_context=_fused_layer_norm_affine_mixed_dtypes_setup_context, + ) + + @torch.library.custom_op("apex::fused_rms_norm_affine_mixed_dtypes_fwd", mutates_args=()) + def fused_rms_norm_affine_mixed_dtypes_fwd( + input: torch.Tensor, + weight: torch.Tensor, + normalized_shape: List[int], + eps: float, + memory_efficient: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + global fused_layer_norm_cuda + if fused_layer_norm_cuda is None: + fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") + + input_ = input.contiguous() + weight_ = weight.contiguous() + output, invvar = fused_layer_norm_cuda.rms_forward_affine_mixed_dtypes( + input_, normalized_shape, weight_, eps + ) + return output, invvar + + @fused_rms_norm_affine_mixed_dtypes_fwd.register_fake + def fused_rms_norm_affine_mixed_dtypes_fwd_fake( + input: torch.Tensor, + weight: torch.Tensor, + normalized_shape: List[int], + eps: float, + memory_efficient: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + input = input.contiguous() + weight = weight.contiguous() + idiff = input.ndim - len(normalized_shape) + n = 1 + for i in range(idiff): + n *= input.shape[i] + if input.dtype in [torch.float16, torch.bfloat16]: + stat_dtype = torch.float32 + else: + stat_dtype = input.dtype + output = torch.empty_like(input, dtype=weight.dtype) + invvar = torch.empty( + [n], + dtype=stat_dtype, + device=input.device, + requires_grad=input.requires_grad, + memory_format=torch.contiguous_format, + ) + return output, invvar + + def _fused_rms_norm_affine_mixed_dtypes_backward(ctx, grad_output, grad_invvar): + input_or_output, weight_, invvar = ctx.saved_tensors + grad_input, grad_weight = fused_rms_norm_affine_bwd( + grad_output, + invvar, + input_or_output, + ctx.normalized_shape, + weight_, + ctx.eps, + ctx.memory_efficient, + ) + return grad_input, grad_weight, None, None, None + + def _fused_rms_norm_affine_mixed_dtypes_setup_context(ctx, inputs, output): + input_, weight_, normalized_shape, eps, memory_efficient = inputs + output_, invvar = output + input_ = input_.contiguous() + weight_ = weight_.contiguous() + if memory_efficient: + ctx.save_for_backward(output_, weight_, invvar) + else: + ctx.save_for_backward(input_, weight_, invvar) + ctx.normalized_shape = normalized_shape + ctx.eps = eps + ctx.memory_efficient = memory_efficient + + fused_rms_norm_affine_mixed_dtypes_fwd.register_autograd( + _fused_rms_norm_affine_mixed_dtypes_backward, + setup_context=_fused_rms_norm_affine_mixed_dtypes_setup_context, + ) + + class FusedLayerNormFunction(torch.autograd.Function): @staticmethod - def forward(ctx, input, normalized_shape, eps): + def forward(ctx, input, normalized_shape, eps, memory_efficient=False): global fused_layer_norm_cuda if fused_layer_norm_cuda is None: fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") ctx.normalized_shape = normalized_shape ctx.eps = eps + ctx.memory_efficient = memory_efficient input_ = input.contiguous() output, mean, invvar = fused_layer_norm_cuda.forward(input_, ctx.normalized_shape, ctx.eps) - ctx.save_for_backward(input_, mean, invvar) + if ctx.memory_efficient: + ctx.save_for_backward(output, None, invvar) + else: + ctx.save_for_backward(input_, mean, invvar) return output @staticmethod def backward(ctx, grad_output): - input_, mean, invvar = ctx.saved_tensors - grad_input = None + input_or_output, mean, invvar = ctx.saved_tensors + grad_input = fused_layer_norm_cuda.backward( + grad_output.contiguous(), mean, invvar, input_or_output, + ctx.normalized_shape, ctx.eps, ctx.memory_efficient + ) + return grad_input, None, None, None + + +if supports_custom_op(): + + @torch.library.custom_op("apex::fused_layer_norm_fwd", mutates_args=()) + def fused_layer_norm_fwd( + input: torch.Tensor, + normalized_shape: List[int], + eps: float, + memory_efficient: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + global fused_layer_norm_cuda + if fused_layer_norm_cuda is None: + fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") + + input_ = input.contiguous() + output, mean, invvar = fused_layer_norm_cuda.forward( + input_, normalized_shape, eps + ) + return output, mean, invvar + + @fused_layer_norm_fwd.register_fake + def fused_layer_norm_fwd_fake( + input: torch.Tensor, + normalized_shape: List[int], + eps: float, + memory_efficient: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + input = input.contiguous() + idiff = input.ndim - len(normalized_shape) + n = 1 + for i in range(idiff): + n *= input.shape[i] + if input.dtype in [torch.float16, torch.bfloat16]: + dtype = torch.float32 + else: + dtype = input.dtype + mean = torch.empty([n], dtype=dtype, device=input.device) + invvar = torch.empty_like(mean) + return torch.empty_like(input), mean, invvar + + @torch.library.custom_op("apex::fused_layer_norm_bwd", mutates_args=()) + def fused_layer_norm_bwd( + grad_output: torch.Tensor, + mean: torch.Tensor, + invvar: torch.Tensor, + input_or_output: torch.Tensor, + normalized_shape: List[int], + eps: float, + memory_efficient: bool = False, + ) -> torch.Tensor: grad_input = fused_layer_norm_cuda.backward( - grad_output.contiguous(), mean, invvar, input_, ctx.normalized_shape, ctx.eps + grad_output.contiguous(), + mean, + invvar, + input_or_output, + normalized_shape, + eps, + memory_efficient, ) - return grad_input, None, None + return grad_input + + @fused_layer_norm_bwd.register_fake + def fused_layer_norm_bwd_fake( + grad_output: torch.Tensor, + mean: torch.Tensor, + invvar: torch.Tensor, + input_or_output: torch.Tensor, + normalized_shape: List[int], + eps: float, + memory_efficient: bool = False, + ) -> torch.Tensor: + grad_input = torch.empty_like(input_or_output) + return grad_input + + def _fused_layer_norm_backward(ctx, grad_output, grad_mean, grad_invvar): + input_or_output, mean, invvar = ctx.saved_tensors + grad_input = fused_layer_norm_bwd( + grad_output, + mean, + invvar, + input_or_output, + ctx.normalized_shape, + ctx.eps, + ctx.memory_efficient, + ) + return grad_input, None, None, None + + def _fused_layer_norm_setup_context(ctx, inputs, output): + input, normalized_shape, eps, memory_efficient = inputs + output, mean, invvar = output + input_ = input.contiguous() + if memory_efficient: + ctx.save_for_backward(output, None, invvar) + else: + ctx.save_for_backward(input_, mean, invvar) + ctx.normalized_shape = normalized_shape + ctx.eps = eps + ctx.memory_efficient = memory_efficient + + fused_layer_norm_fwd.register_autograd( + _fused_layer_norm_backward, + setup_context=_fused_layer_norm_setup_context, + ) class FusedRMSNormFunction(torch.autograd.Function): @staticmethod - def forward(ctx, input, normalized_shape, eps): + def forward(ctx, input, normalized_shape, eps, memory_efficient=False): global fused_layer_norm_cuda if fused_layer_norm_cuda is None: fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") ctx.normalized_shape = normalized_shape ctx.eps = eps + ctx.memory_efficient = memory_efficient input_ = input.contiguous() output, invvar = fused_layer_norm_cuda.rms_forward(input_, ctx.normalized_shape, ctx.eps) - ctx.save_for_backward(input_, invvar) + if ctx.memory_efficient: + ctx.save_for_backward(output, invvar) + else: + ctx.save_for_backward(input_, invvar) return output @staticmethod def backward(ctx, grad_output): - input_, invvar = ctx.saved_tensors + input_or_output, invvar = ctx.saved_tensors grad_input = None grad_input = fused_layer_norm_cuda.rms_backward( - grad_output.contiguous(), invvar, input_, ctx.normalized_shape, ctx.eps + grad_output.contiguous(), invvar, input_or_output, + ctx.normalized_shape, ctx.eps, ctx.memory_efficient ) - return grad_input, None, None + return grad_input, None, None, None -def fused_layer_norm_affine(input, weight, bias, normalized_shape, eps=1e-6): - args = _cast_if_autocast_enabled(input, weight, bias, normalized_shape, eps) - with torch.cuda.amp.autocast(enabled=False): - return FusedLayerNormAffineFunction.apply(*args) +if supports_custom_op(): + @torch.library.custom_op("apex::fused_rms_norm_fwd", mutates_args=()) + def fused_rms_norm_fwd( + input: torch.Tensor, + normalized_shape: List[int], + eps: float, + memory_efficient: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + global fused_layer_norm_cuda + if fused_layer_norm_cuda is None: + fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") + input_ = input.contiguous() + output, invvar = fused_layer_norm_cuda.rms_forward( + input_, normalized_shape, eps + ) + return output, invvar + + + @fused_rms_norm_fwd.register_fake + def fused_rms_norm_fwd_fake( + input: torch.Tensor, + normalized_shape: List[int], + eps: float, + memory_efficient: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + input = input.contiguous() + idiff = input.ndim - len(normalized_shape) + n = 1 + for i in range(idiff): + n *= input.shape[i] + if input.dtype in [torch.float16, torch.bfloat16]: + dtype = torch.float32 + else: + dtype = input.dtype + return ( + torch.empty_like(input), + torch.empty( + [n], + dtype=dtype, + device=input.device, + requires_grad=input.requires_grad, + memory_format=torch.contiguous_format, + ), + ) -def fused_layer_norm(input, normalized_shape, eps=1e-6): - args = _cast_if_autocast_enabled(input, normalized_shape, eps) - with torch.cuda.amp.autocast(enabled=False): - return FusedLayerNormFunction.apply(*args) + @torch.library.custom_op("apex::fused_rms_norm_bwd", mutates_args=()) + def fused_rms_norm_bwd( + grad_output: torch.Tensor, + invvar: torch.Tensor, + input_or_output: torch.Tensor, + normalized_shape: List[int], + eps: float, + memory_efficient: bool = False, + ) -> torch.Tensor: + grad_input = fused_layer_norm_cuda.rms_backward( + grad_output.contiguous(), + invvar, + input_or_output, + normalized_shape, + eps, + memory_efficient, + ) + return grad_input -def mixed_dtype_fused_layer_norm_affine(input, weight, bias, normalized_shape, eps=1e-6): - args = _cast_if_autocast_enabled(input, weight, bias, normalized_shape, eps) - with torch.cuda.amp.autocast(enabled=False): - return FusedLayerNormAffineMixedDtypesFunction.apply(*args) + @fused_rms_norm_bwd.register_fake + def fused_rms_norm_bwd_fake( + grad_output: torch.Tensor, + invvar: torch.Tensor, + input_or_output: torch.Tensor, + normalized_shape: List[int], + eps: float, + memory_efficient: bool = False, + ) -> torch.Tensor: + grad_input = torch.empty_like(input_or_output) + return grad_input -def fused_rms_norm_affine(input, weight, normalized_shape, eps=1e-6): - args = _cast_if_autocast_enabled(input, weight, normalized_shape, eps) - with torch.cuda.amp.autocast(enabled=False): - return FusedRMSNormAffineFunction.apply(*args) + def _fused_rms_norm_backward(ctx, grad_output, grad_invvar): + input_or_output, invvar = ctx.saved_tensors + grad_input = None + grad_input = fused_rms_norm_bwd( + grad_output, + invvar, + input_or_output, + ctx.normalized_shape, + ctx.eps, + ctx.memory_efficient, + ) + return grad_input, None, None, None -def fused_rms_norm(input, normalized_shape, eps=1e-6): - args = _cast_if_autocast_enabled(input, normalized_shape, eps) - with torch.cuda.amp.autocast(enabled=False): - return FusedRMSNormFunction.apply(*args) + def _fused_rms_norm_setup_context(ctx, inputs, output): + input_, normalized_shape, eps, memory_efficient = inputs + output_, invvar = output + input_ = input_.contiguous() + if memory_efficient: + ctx.save_for_backward(output_, invvar) + else: + ctx.save_for_backward(input_, invvar) + ctx.normalized_shape = normalized_shape + ctx.eps = eps + ctx.memory_efficient = memory_efficient + + + fused_rms_norm_fwd.register_autograd( + _fused_rms_norm_backward, + setup_context=_fused_rms_norm_setup_context + ) -def mixed_dtype_fused_rms_norm_affine(input, weight, normalized_shape, eps=1e-6): - args = _cast_if_autocast_enabled(input, weight, normalized_shape, eps) - with torch.cuda.amp.autocast(enabled=False): - return FusedRMSNormAffineMixedDtypesFunction.apply(*args) + +def fused_layer_norm_affine(input, weight, bias, normalized_shape, eps=1e-6, memory_efficient=False): + args = _cast_if_autocast_enabled(input, weight, bias, normalized_shape, eps, memory_efficient) + with torch.amp.autocast('cuda', enabled=False): + if supports_custom_op(): + return fused_layer_norm_affine_fwd(*args)[0] + else: + return FusedLayerNormAffineFunction.apply(*args) + + +def fused_layer_norm(input, normalized_shape, eps=1e-6, memory_efficient=False): + args = _cast_if_autocast_enabled(input, normalized_shape, eps, memory_efficient) + with torch.amp.autocast('cuda', enabled=False): + if supports_custom_op(): + return fused_layer_norm_fwd(*args)[0] + else: + return FusedLayerNormFunction.apply(*args) + + +def mixed_dtype_fused_layer_norm_affine(input, weight, bias, normalized_shape, eps=1e-6, memory_efficient=False): + args = _cast_if_autocast_enabled(input, weight, bias, normalized_shape, eps, memory_efficient) + with torch.amp.autocast('cuda', enabled=False): + if supports_custom_op(): + return fused_layer_norm_affine_mixed_dtypes_fwd(*args)[0] + else: + return FusedLayerNormAffineMixedDtypesFunction.apply(*args) + + +def fused_rms_norm_affine(input, weight, normalized_shape, eps=1e-6, memory_efficient=False): + args = _cast_if_autocast_enabled(input, weight, normalized_shape, eps, memory_efficient) + with torch.amp.autocast('cuda', enabled=False): + if supports_custom_op(): + return fused_rms_norm_affine_fwd(*args)[0] + else: + return FusedRMSNormAffineFunction.apply(*args) + + +def fused_rms_norm(input, normalized_shape, eps=1e-6, memory_efficient=False): + args = _cast_if_autocast_enabled(input, normalized_shape, eps, memory_efficient) + with torch.amp.autocast('cuda', enabled=False): + if supports_custom_op(): + return fused_rms_norm_fwd(*args)[0] + else: + return FusedRMSNormFunction.apply(*args) + + +def mixed_dtype_fused_rms_norm_affine(input, weight, normalized_shape, eps=1e-6, memory_efficient=False): + args = _cast_if_autocast_enabled(input, weight, normalized_shape, eps, memory_efficient) + with torch.amp.autocast('cuda', enabled=False): + if supports_custom_op(): + return fused_rms_norm_affine_mixed_dtypes_fwd(*args)[0] + else: + return FusedRMSNormAffineMixedDtypesFunction.apply(*args) class FusedLayerNorm(torch.nn.Module): @@ -261,7 +936,7 @@ class FusedLayerNorm(torch.nn.Module): .. _`Layer Normalization`: https://arxiv.org/abs/1607.06450 """ - def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): + def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True, memory_efficient=False): super().__init__() global fused_layer_norm_cuda @@ -272,9 +947,10 @@ def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): self.normalized_shape = torch.Size(normalized_shape) self.eps = eps self.elementwise_affine = elementwise_affine + self.memory_efficient = memory_efficient if self.elementwise_affine: - self.weight = Parameter(torch.Tensor(*normalized_shape)) - self.bias = Parameter(torch.Tensor(*normalized_shape)) + self.weight = Parameter(torch.empty(*normalized_shape)) + self.bias = Parameter(torch.empty(*normalized_shape)) else: self.register_parameter("weight", None) self.register_parameter("bias", None) @@ -286,12 +962,14 @@ def reset_parameters(self): init.zeros_(self.bias) def forward(self, input): - if not input.is_cuda: + if torch.jit.is_tracing() or torch.jit.is_scripting() or torch.compiler.is_compiling() or not input.is_cuda: return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps) if self.elementwise_affine: - return fused_layer_norm_affine(input, self.weight, self.bias, self.normalized_shape, self.eps) + return fused_layer_norm_affine( + input, self.weight, self.bias, self.normalized_shape, self.eps, self.memory_efficient + ) else: - return fused_layer_norm(input, self.normalized_shape, self.eps) + return fused_layer_norm(input, self.normalized_shape, self.eps, self.memory_efficient) def extra_repr(self): return "{normalized_shape}, eps={eps}, " "elementwise_affine={elementwise_affine}".format(**self.__dict__) @@ -357,7 +1035,7 @@ class FusedRMSNorm(torch.nn.Module): .. _`Root Mean Square Layer Normalization`: https://arxiv.org/pdf/1910.07467.pdf """ - def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): + def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True, memory_efficient=False): super().__init__() global fused_layer_norm_cuda @@ -368,8 +1046,9 @@ def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): self.normalized_shape = torch.Size(normalized_shape) self.eps = eps self.elementwise_affine = elementwise_affine + self.memory_efficient = memory_efficient if self.elementwise_affine: - self.weight = Parameter(torch.Tensor(*normalized_shape)) + self.weight = Parameter(torch.empty(*normalized_shape)) else: self.register_parameter("weight", None) self.reset_parameters() @@ -379,13 +1058,15 @@ def reset_parameters(self): init.ones_(self.weight) def forward(self, input): - if not input.is_cuda: + if torch.jit.is_tracing() or torch.jit.is_scripting() or torch.compiler.is_compiling() or not input.is_cuda: return manual_rms_norm(input, self.normalized_shape, self.weight, self.eps) if self.elementwise_affine: - return fused_rms_norm_affine(input, self.weight, self.normalized_shape, self.eps) + return fused_rms_norm_affine( + input, self.weight, self.normalized_shape, self.eps, self.memory_efficient + ) else: - return fused_rms_norm(input, self.normalized_shape, self.eps) + return fused_rms_norm(input, self.normalized_shape, self.eps, self.memory_efficient) def extra_repr(self): return "{normalized_shape}, eps={eps}, " "elementwise_affine={elementwise_affine}".format(**self.__dict__) @@ -397,7 +1078,7 @@ def extra_repr(self): # See: `layer_norm_affine` and `layer_norm_affine_mixed_dtypes` in "csrc/layer_norm_cuda.cpp" class MixedFusedLayerNorm(FusedLayerNorm): - def __init__(self, normalized_shape, eps=1e-5, **kwargs): + def __init__(self, normalized_shape, eps=1e-5, *, memory_efficient=False, **kwargs): if "elementwise_affine" in kwargs: import warnings warnings.warn("MixedFusedLayerNorm does not support `elementwise_affine` argument") @@ -405,13 +1086,16 @@ def __init__(self, normalized_shape, eps=1e-5, **kwargs): if not elementwise_affine: raise RuntimeError("MixedFusedLayerNorm does not support `elementwise_affine = False`") - super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=True) - + super().__init__( + normalized_shape=normalized_shape, eps=eps, elementwise_affine=True, memory_efficient=memory_efficient + ) def forward(self, input: torch.Tensor): # NOTE (mkozuki): CPU path is here mainly for unittest sake. - if not input.is_cuda: + if torch.jit.is_tracing() or torch.jit.is_scripting() or torch.compiler.is_compiling() or not input.is_cuda: return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps) - return mixed_dtype_fused_layer_norm_affine(input, self.weight, self.bias, self.normalized_shape, self.eps) + return mixed_dtype_fused_layer_norm_affine( + input, self.weight, self.bias, self.normalized_shape, self.eps, self.memory_efficient + ) # MixedFusedLayerNorm differs from FusedLayerNorm in that this layer norm uses parameter's dtype @@ -419,7 +1103,7 @@ def forward(self, input: torch.Tensor): # See: `layer_norm_affine` and `layer_norm_affine_mixed_dtypes` in "csrc/layer_norm_cuda.cpp" class MixedFusedRMSNorm(FusedRMSNorm): - def __init__(self, normalized_shape, eps=1e-5, **kwargs): + def __init__(self, normalized_shape, eps=1e-5, *, memory_efficient=False, **kwargs): if "elementwise_affine" in kwargs: import warnings warnings.warn("MixedFusedRMSNorm does not support `elementwise_affine` argument") @@ -427,11 +1111,13 @@ def __init__(self, normalized_shape, eps=1e-5, **kwargs): if not elementwise_affine: raise RuntimeError("MixedFusedRMSNorm does not support `elementwise_affine = False`") - super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=True) - + super().__init__( + normalized_shape=normalized_shape, eps=eps, elementwise_affine=True, memory_efficient=memory_efficient + ) def forward(self, input: torch.Tensor): # NOTE (mkozuki): CPU path is here mainly for unittest sake. - # TODO Manual RMS Norm Implementation Here - if not input.is_cuda: + if torch.jit.is_tracing() or torch.jit.is_scripting() or torch.compiler.is_compiling() or not input.is_cuda: return manual_rms_norm(input, self.normalized_shape, self.weight, self.eps) - return mixed_dtype_fused_rms_norm_affine(input, self.weight, self.normalized_shape, self.eps) + return mixed_dtype_fused_rms_norm_affine( + input, self.weight, self.normalized_shape, self.eps, self.memory_efficient + ) \ No newline at end of file diff --git a/apex/op_builder b/apex/op_builder new file mode 120000 index 000000000..1e19f3e8d --- /dev/null +++ b/apex/op_builder @@ -0,0 +1 @@ +../op_builder \ No newline at end of file diff --git a/apex/optimizers/__init__.py b/apex/optimizers/__init__.py index 25c178c5f..888a4af08 100644 --- a/apex/optimizers/__init__.py +++ b/apex/optimizers/__init__.py @@ -4,3 +4,4 @@ from .fused_lamb import FusedLAMB from .fused_adagrad import FusedAdagrad from .fused_mixed_precision_lamb import FusedMixedPrecisionLamb +from .fused_lars import FusedLARS diff --git a/apex/optimizers/fused_adagrad.py b/apex/optimizers/fused_adagrad.py index d72a68c5d..8d1ef6f32 100644 --- a/apex/optimizers/fused_adagrad.py +++ b/apex/optimizers/fused_adagrad.py @@ -91,7 +91,7 @@ def step(self, closure=None): if len(state) == 0: # Exponential moving average of gradient values state['sum'] = torch.zeros_like(p.data) - if p.dtype == torch.float16: + if p.dtype in {torch.float16, torch.bfloat16}: g_16.append(p.grad.data) p_16.append(p.data) h_16.append(state['sum']) @@ -100,7 +100,7 @@ def step(self, closure=None): p_32.append(p.data) h_32.append(state['sum']) else: - raise RuntimeError('FusedAdagrad only support fp16 and fp32.') + raise RuntimeError('FusedAdagrad only support fp16, bfloat16 and fp32.') if(len(g_16) > 0): multi_tensor_applier(self.multi_tensor_adagrad, diff --git a/apex/optimizers/fused_adam.py b/apex/optimizers/fused_adam.py index c7c135b0a..2ecfc077d 100644 --- a/apex/optimizers/fused_adam.py +++ b/apex/optimizers/fused_adam.py @@ -53,6 +53,11 @@ class FusedAdam(torch.optim.Optimizer): True for decoupled weight decay(also known as AdamW) (default: True) set_grad_none (bool, optional): whether set grad to None when zero_grad() method is called. (default: True) + capturable (bool, optional): whether to use the version of the optimizer + that can be used with CUDA Graphs. (default: False) + master_weights (bool, optional): whether to maintain FP32 master weights + in the optimizer with FP16 mixed precision training, currently can + only be used with capturable set to True. (default: False) .. _Adam - A Method for Stochastic Optimization: https://arxiv.org/abs/1412.6980 @@ -62,20 +67,52 @@ class FusedAdam(torch.optim.Optimizer): def __init__(self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-8, adam_w_mode=True, - weight_decay=0., amsgrad=False, set_grad_none=True): + weight_decay=0., amsgrad=False, set_grad_none=True, + capturable=False, master_weights=False): if amsgrad: raise RuntimeError('FusedAdam does not support the AMSGrad variant.') + if master_weights and not capturable: + raise RuntimeError('Master weights is currently only supported with the capturable version.') + # If the optimizer is capturable then LR should be a tensor (on GPU) + lr = torch.tensor(lr, dtype=torch.float32) if capturable else lr defaults = dict(lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay) super(FusedAdam, self).__init__(params, defaults) self.adam_w_mode = 1 if adam_w_mode else 0 self.set_grad_none = set_grad_none + + self.capturable = capturable + self.master_weights = master_weights + + # Create full precision master weights + self.param_groups_master = [] + for i, pg in enumerate(self.param_groups): + param_list = pg['params'] + self.param_groups_master.append({ + 'params': [ + p.clone().detach().float() if self.master_weights else None + for p in param_list + ], + }) + + if capturable: + for idx, group in enumerate(self.param_groups): + if len(group['params']) == 0: + continue + device = group['params'][0].device + for item in ['lr']: + self.param_groups[idx][item] = group[item].to(device=device) + + self._step_supports_amp_scaling = True + if multi_tensor_applier.available: import amp_C # Skip buffer - self._dummy_overflow_buf = torch.cuda.IntTensor([0]) + self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device='cuda') self.multi_tensor_adam = amp_C.multi_tensor_adam + self.multi_tensor_adam_capturable = amp_C.multi_tensor_adam_capturable + self.multi_tensor_adam_capturable_master = amp_C.multi_tensor_adam_capturable_master else: raise RuntimeError('apex.optimizers.FusedAdam requires cuda extensions') @@ -87,7 +124,7 @@ def zero_grad(self): else: super(FusedAdam, self).zero_grad() - def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None): + def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None, grad_scaler=None): """Performs a single optimization step. Arguments: @@ -102,23 +139,28 @@ def step(self, closure=None, grads=None, output_params=None, scale=None, grad_no if closure is not None: loss = closure() - for group in self.param_groups: + for group, group_master in zip(self.param_groups, self.param_groups_master): + if len(group['params']) == 0: + continue + device = group['params'][0].device bias_correction = 1 if group['bias_correction'] else 0 beta1, beta2 = group['betas'] # assume same step across group now to simplify things # per parameter step can be easily support by making it tensor, or pass list into kernel if 'step' in group: - group['step'] += 1 + group['step'] += 1 if not self.capturable else (self._dummy_overflow_buf != 1).to(torch.int) else: - group['step'] = 1 + group['step'] = 1 if not self.capturable else torch.tensor([1], dtype=torch.int, device=device) # create lists for multi-tensor apply g_16, p_16, m_16, v_16 = [], [], [], [] g_bf, p_bf, m_bf, v_bf = [], [], [], [] g_32, p_32, m_32, v_32 = [], [], [], [] + p_16_master = [] + p_32_master = [] - for p in group['params']: + for p, p_master in zip(group['params'], group_master['params']): if p.grad is None: continue if p.grad.data.is_sparse: @@ -128,11 +170,13 @@ def step(self, closure=None, grads=None, output_params=None, scale=None, grad_no # State initialization if len(state) == 0: # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(p.data) + state['exp_avg'] = torch.zeros_like(p.data).float() # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros_like(p.data) + state['exp_avg_sq'] = torch.zeros_like(p.data).float() - if p.dtype == torch.float16: + if p.dtype in {torch.float16, torch.bfloat16}: + if self.master_weights: + p_16_master.append(p_master.data) g_16.append(p.grad.data) p_16.append(p.data) m_16.append(state['exp_avg']) @@ -143,51 +187,119 @@ def step(self, closure=None, grads=None, output_params=None, scale=None, grad_no m_bf.append(state['exp_avg']) v_bf.append(state['exp_avg_sq']) elif p.dtype == torch.float32: + if self.master_weights: + p_32_master.append(p_master.data) g_32.append(p.grad.data) p_32.append(p.data) m_32.append(state['exp_avg']) v_32.append(state['exp_avg_sq']) else: - raise RuntimeError('FusedAdam only support fp16 and fp32.') - - if(len(g_16) > 0): - multi_tensor_applier(self.multi_tensor_adam, - self._dummy_overflow_buf, - [g_16, p_16, m_16, v_16], - group['lr'], - beta1, - beta2, - group['eps'], - group['step'], - self.adam_w_mode, - bias_correction, - group['weight_decay']) - if g_bf: - multi_tensor_applier( - self.multi_tensor_adam, - self._dummy_overflow_buf, - [g_bf, p_bf, m_bf, v_bf], - group['lr'], - beta1, - beta2, - group['eps'], - group['step'], - self.adam_w_mode, - bias_correction, - group['weight_decay'], + raise RuntimeError('FusedAdam only support fp16, bfloat16 and fp32.') + + # If the optimizer is capturable, then if there's a grad scaler it works + # on the GPU + a different multi_tensor_applier should be called + if self.capturable: + # overflow check of gradients + found_inf = ( + grad_scaler._check_inf_per_device(self)[device] + if grad_scaler is not None else torch.zeros((1,), device=device) ) - if(len(g_32) > 0): - multi_tensor_applier(self.multi_tensor_adam, - self._dummy_overflow_buf, - [g_32, p_32, m_32, v_32], - group['lr'], - beta1, - beta2, - group['eps'], - group['step'], - self.adam_w_mode, - bias_correction, - group['weight_decay']) + self._dummy_overflow_buf.copy_(found_inf) + + # get unscale scale factor + scale, inv_scale = None, None + if grad_scaler: + scale = grad_scaler._get_scale_async() + inv_scale = scale.double().reciprocal().float() + else: + scale = torch.ones((1,), device=device) + inv_scale = torch.ones((1,), device=device) + + if len(g_16) > 0: + multi_tensor_applier(self.multi_tensor_adam_capturable_master if self.master_weights + else self.multi_tensor_adam_capturable, + self._dummy_overflow_buf, + [g_16, p_16, m_16, v_16, p_16_master] if self.master_weights + else [g_16, p_16, m_16, v_16], + group['lr'], + beta1, + beta2, + group['eps'], + group['step'], + self.adam_w_mode, + bias_correction, + group['weight_decay'], + inv_scale) + + if len(g_bf) > 0: + multi_tensor_applier( + self.multi_tensor_adam_capturable, + self._dummy_overflow_buf, + [g_bf, p_bf, m_bf, v_bf], + group['lr'], + beta1, + beta2, + group['eps'], + group['step'], + self.adam_w_mode, + bias_correction, + group['weight_decay'], + inv_scale) + + if len(g_32) > 0: + multi_tensor_applier(self.multi_tensor_adam_capturable_master if self.master_weights + else self.multi_tensor_adam_capturable, + self._dummy_overflow_buf, + [g_32, p_32, m_32, v_32, p_32_master] if self.master_weights + else [g_32, p_32, m_32, v_32], + group['lr'], + beta1, + beta2, + group['eps'], + group['step'], + self.adam_w_mode, + bias_correction, + group['weight_decay'], + inv_scale) + else: + if len(g_16) > 0: + multi_tensor_applier(self.multi_tensor_adam, + self._dummy_overflow_buf, + [g_16, p_16, m_16, v_16], + group['lr'], + beta1, + beta2, + group['eps'], + group['step'], + self.adam_w_mode, + bias_correction, + group['weight_decay']) + + if len(g_bf) > 0: + multi_tensor_applier( + self.multi_tensor_adam, + self._dummy_overflow_buf, + [g_bf, p_bf, m_bf, v_bf], + group['lr'], + beta1, + beta2, + group['eps'], + group['step'], + self.adam_w_mode, + bias_correction, + group['weight_decay']) + if len(g_32) > 0: + multi_tensor_applier(self.multi_tensor_adam, + self._dummy_overflow_buf, + [g_32, p_32, m_32, v_32], + group['lr'], + beta1, + beta2, + group['eps'], + group['step'], + self.adam_w_mode, + bias_correction, + group['weight_decay']) return loss diff --git a/apex/optimizers/fused_lamb.py b/apex/optimizers/fused_lamb.py index 854525dcf..a77e0cd54 100644 --- a/apex/optimizers/fused_lamb.py +++ b/apex/optimizers/fused_lamb.py @@ -1,5 +1,5 @@ import torch -from apex.multi_tensor_apply import multi_tensor_applier +from apex.multi_tensor_apply import multi_tensor_applier, multi_tensor_applier_l2norm class FusedLAMB(torch.optim.Optimizer): @@ -72,7 +72,7 @@ def __init__(self, params, lr=1e-3, bias_correction=True, grad_averaging=grad_averaging, max_grad_norm=max_grad_norm) super(FusedLAMB, self).__init__(params, defaults) - if multi_tensor_applier.available: + if multi_tensor_applier.available and multi_tensor_applier_l2norm.available: import amp_C self.multi_tensor_l2norm=amp_C.multi_tensor_l2norm # Skip buffer @@ -121,16 +121,16 @@ def step(self, closure=None): g_norm_32, g_norm_16 = torch.zeros(1, device=device), torch.zeros(1, device=device) # compute grad norm for two lists if len(g_all_32) > 0: - g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm, + g_norm_32 = multi_tensor_applier_l2norm(self.multi_tensor_l2norm, self._dummy_overflow_buf, [g_all_32], False)[0] if len(g_all_16) > 0: - g_norm_16 = multi_tensor_applier(self.multi_tensor_l2norm, + g_norm_16 = multi_tensor_applier_l2norm(self.multi_tensor_l2norm, self._dummy_overflow_buf, [g_all_16], False)[0] # blend two grad norms to get global grad norm - global_grad_norm = multi_tensor_applier(self.multi_tensor_l2norm, + global_grad_norm = multi_tensor_applier_l2norm(self.multi_tensor_l2norm, self._dummy_overflow_buf, [[g_norm_32, g_norm_16]], False)[0] @@ -166,7 +166,7 @@ def step(self, closure=None): # Exponential moving average of gradient values state['exp_avg_sq'] = torch.zeros_like(p.data) - if p.dtype == torch.float16: + if p.dtype in {torch.float16, torch.bfloat16}: g_16.append(p.grad.data) p_16.append(p.data) m_16.append(state['exp_avg']) @@ -177,7 +177,7 @@ def step(self, closure=None): m_32.append(state['exp_avg']) v_32.append(state['exp_avg_sq']) else: - raise RuntimeError('FusedLAMB only support fp16 and fp32.') + raise RuntimeError('FusedLAMB only support fp16, bfloat16 and fp32.') if(len(g_16) > 0): multi_tensor_applier(self.multi_tensor_lamb, diff --git a/apex/optimizers/fused_lars.py b/apex/optimizers/fused_lars.py new file mode 100644 index 000000000..3e60b2cce --- /dev/null +++ b/apex/optimizers/fused_lars.py @@ -0,0 +1,224 @@ +import torch +from torch.optim.optimizer import Optimizer, required +from torch import nn +from torch.nn.parameter import Parameter +from apex.multi_tensor_apply import multi_tensor_applier + +class FusedLARS(Optimizer): + def __init__(self, params, lr=required, momentum=0, dampening=0, + weight_decay=0, trust_coefficient=0.001, eps=0.0, + nesterov=False, wd_after_momentum=False, + materialize_master_grads=True, set_grad_none=False): + + if lr is not required and lr < 0.0: + raise ValueError("Invalid learning rate: {}".format(lr)) + if momentum < 0.0: + raise ValueError("Invalid momentum value: {}".format(momentum)) + if weight_decay < 0.0: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + + defaults = dict(lr=lr, momentum=momentum, dampening=dampening, + weight_decay=weight_decay, nesterov=nesterov, trust_coefficient=trust_coefficient, eps=eps, is_skipped=False) + if nesterov and (momentum <= 0 or dampening != 0): + raise ValueError("Nesterov momentum requires a momentum and zero dampening") + super(FusedLARS, self).__init__(params, defaults) + + self.wd_after_momentum = wd_after_momentum + self.materialize_master_grads = materialize_master_grads + self.most_recent_scale = 1.0 + self.scale_set_by_backward = False + self.set_grad_none = set_grad_none + self.trust_coefficient = trust_coefficient + self.eps = eps + + if multi_tensor_applier.available: + import amp_C + # Skip buffer + self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=self.param_groups[0]["params"][0].device) + self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm + self.multi_tensor_lars = amp_C.multi_tensor_lars + self._dummy_overflow_buf = torch.cuda.IntTensor(1).zero_() + else: + raise RuntimeError('apex.optimizers.FusedLARS requires cuda extensions') + + def __setstate__(self, state): + super(FusedLARS, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('nesterov', False) + + def zero_grad(self): + if self.set_grad_none: + for group in self.param_groups: + for p in group['params']: + p.grad = None + else: + super(FusedLARS, self).zero_grad() + + def get_momentums(self, params): + momentums = [] + first_run = True + for p in params: + if p.grad is None: + continue + + param_state = self.state[p] + d_p = p.grad.data + # torch.optim.SGD initializes momentum in the main loop, we have + # to do it here, and track whether or not we've done so, so that + # momentum application can be skipped in the main kernel. + if 'momentum_buffer' not in param_state: + first_run = True + buf = param_state['momentum_buffer'] = torch.zeros_like(p.data) + momentums.append(buf) + else: + first_run = False + momentums.append(param_state['momentum_buffer']) + return momentums, first_run + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + explicit_master_params = (hasattr(self, "_amp_stash") and + hasattr(self._amp_stash, "fp32_from_fp16_groups")) + explicit_master_params = False + + for gid, group in enumerate(self.param_groups): + weight_decay = group['weight_decay'] + momentum = group['momentum'] + dampening = group['dampening'] + nesterov = group['nesterov'] + lr = group['lr'] + is_skipped = group['is_skipped'] + + # For each group, there are 3 possible combinations we need to consider: + # grad_type, param_to_update_type, momentum_type, requires_fp16_model_copy + # 1. fp16, fp16, fp16, No + # 2. fp32, fp32, fp32, No + # 3. fp16, fp32, fp32, Yes + + first_runs = [True, True] + g_norms_grp = [] + w_norms_grp = [] + + + # I think a bit of code divergence in exchange for naming clarity is worthwhile + if explicit_master_params: + print('explicit_master_params') + stash = self._amp_stash + + fp32_params = [p for p in stash.fp32_from_fp32_groups[gid] if p.grad is not None] + fp32_grads = [p.grad for p in stash.fp32_from_fp32_groups[gid] if p.grad is not None] + fp32_momentums, first_runs[1] = self.get_momentums(fp32_params) + + if self.materialize_master_grads: + fp16_model_params = [p for i, p in enumerate( + stash.fp16_groups[gid]) if stash.fp32_from_fp16_groups[gid][i].grad is not None] + fp32_from_fp16_grads = [p.grad for p in stash.fp32_from_fp16_groups[gid] if p.grad is not None] + fp32_from_fp16_params = [p for p in stash.fp32_from_fp16_groups[gid] if p.grad is not None] + fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params) + + fp16_set = [fp32_from_fp16_grads, fp32_from_fp16_params, + fp32_from_fp16_momentums, fp16_model_params] + else: + fp16_model_params = [p for p in stash.fp16_groups[gid] if p.grad is not None] + fp16_model_grads = [p.grad for p in stash.fp16_groups[gid] if p.grad is not None] + fp32_from_fp16_params = [p for i, p in enumerate( + stash.fp32_from_fp16_groups[gid]) if stash.fp16_groups[gid][i].grad is not None] + fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params) + + fp16_set = [fp16_model_grads, fp32_from_fp16_params, + fp32_from_fp16_momentums, fp16_model_params] + + launch_sets= [fp16_set, [fp32_grads, fp32_params, fp32_momentums]] + + else: + fp16_params = [p for p in group['params'] if (p.dtype == torch.float16 and p.grad is not None)] + #fp16_grads = [p.grad for p in group['params'] if (p.dtype == torch.float16 and p.grad is not None)] + fp16_grads = [] + for p in fp16_params: + if p.is_contiguous(): + fp16_grads.append(p.grad) + elif p.is_contiguous(memory_format=torch.channels_last): + fp16_grads.append(p.grad.to(memory_format=torch.channels_last)) + fp16_momentums, first_runs[0] = self.get_momentums(fp16_params) + # Compute L2 norms + if len(fp16_params) > 0: + w_norms = multi_tensor_applier( + self.multi_tensor_l2norm, + self._dummy_overflow_buf, + [[p.data for p in fp16_params]], + True)[1] + g_norms = multi_tensor_applier( + self.multi_tensor_l2norm, + self._dummy_overflow_buf, + [[p.data for p in fp16_grads]], + True)[1] + else: + w_norms = [] + g_norms = [] + w_norms_grp.append(w_norms) + g_norms_grp.append(g_norms) + + fp32_params = [p for p in group['params'] if (p.dtype == torch.float32 and p.grad is not None)] + fp32_grads = [] + for p in fp32_params: + if p.is_contiguous(): + fp32_grads.append(p.grad) + elif p.is_contiguous(memory_format=torch.channels_last): + fp32_grads.append(p.grad.to(memory_format=torch.channels_last)) + fp32_momentums, first_runs[1] = self.get_momentums(fp32_params) + # Compute L2 norms + if len(fp32_params) > 0: + w_norms = multi_tensor_applier( + self.multi_tensor_l2norm, + self._dummy_overflow_buf, + [[p.data for p in fp32_params]], + True)[1] + g_norms = multi_tensor_applier( + self.multi_tensor_l2norm, + self._dummy_overflow_buf, + [[p.data for p in fp32_grads]], + True)[1] + else: + w_norms = [] + g_norms = [] + w_norms_grp.append(w_norms) + g_norms_grp.append(g_norms) + + launch_sets = [[fp16_grads, fp16_params, fp16_momentums], + [fp32_grads, fp32_params, fp32_momentums]] + + for s, (launch_set, first_run, g_norms, w_norms) in enumerate(zip(launch_sets, first_runs, g_norms_grp, w_norms_grp)): + assert len(launch_set[0]) == len(launch_set[1]) + assert len(launch_set[0]) == len(launch_set[2]) + if len(launch_set[0]) > 0: + multi_tensor_applier( + self.multi_tensor_lars, + self._dummy_overflow_buf, + launch_set, + g_norms, + w_norms, + group['lr'], + group['trust_coefficient'], + self.eps, + weight_decay, + momentum, + dampening, + nesterov, + first_run, + self.wd_after_momentum, + 1.0/self.most_recent_scale, + group['is_skipped']) + + self.most_recent_scale = 1.0 + self.scale_set_by_backward = False + + return loss diff --git a/apex/optimizers/fused_mixed_precision_lamb.py b/apex/optimizers/fused_mixed_precision_lamb.py index f1b2902ca..7ecda4f51 100644 --- a/apex/optimizers/fused_mixed_precision_lamb.py +++ b/apex/optimizers/fused_mixed_precision_lamb.py @@ -3,7 +3,7 @@ from itertools import chain from collections import defaultdict, abc as container_abcs -from apex.multi_tensor_apply import multi_tensor_applier +from apex.multi_tensor_apply import multi_tensor_applier, multi_tensor_applier_l2norm class FusedMixedPrecisionLamb(torch.optim.Optimizer): @@ -32,7 +32,7 @@ def __init__(self, params, lr=1e-3, step=0, bias_correction=True, for item in tensor_state: self.param_groups[idx][item] = group[item].to(device=device) - if multi_tensor_applier.available: + if multi_tensor_applier.available and multi_tensor_applier_l2norm.available: import amp_C self.multi_tensor_l2norm=amp_C.multi_tensor_l2norm_mp # Skip buffer @@ -180,7 +180,7 @@ def step(self, closure=None, grad_scaler=None): # grad_norm is of scaled gradients. # So, multiply `max_grad_norm` by scale. max_grad_norm = self.defaults['max_grad_norm'] * scale - grad_norm = multi_tensor_applier( + grad_norm = multi_tensor_applier_l2norm( self.multi_tensor_l2norm, self._dummy_overflow_buf, [grad_list], diff --git a/apex/optimizers/fused_novograd.py b/apex/optimizers/fused_novograd.py index 2820ae36c..b3ec5acb9 100644 --- a/apex/optimizers/fused_novograd.py +++ b/apex/optimizers/fused_novograd.py @@ -144,7 +144,7 @@ def step(self, closure=None): # Exponential moving average of gradient values state['exp_avg'] = torch.zeros_like(p.data) - if p.dtype == torch.float16: + if p.dtype in {torch.float16, torch.bfloat16}: g_16.append(p.grad.data) p_16.append(p.data) m_16.append(state['exp_avg']) @@ -153,7 +153,7 @@ def step(self, closure=None): p_32.append(p.data) m_32.append(state['exp_avg']) else: - raise RuntimeError('FusedNovoGrad only support fp16 and fp32.') + raise RuntimeError('FusedNovoGrad only support fp16, bfloat16 and fp32.') # we store per weight norm as one tensor for one group/precision combination # different from optim.Adam, we store norm here(not ^2) so we can unify calculation for norm types diff --git a/apex/optimizers/fused_sgd.py b/apex/optimizers/fused_sgd.py index e7bdcb2b9..88f26f27a 100644 --- a/apex/optimizers/fused_sgd.py +++ b/apex/optimizers/fused_sgd.py @@ -175,15 +175,33 @@ def step(self, closure=None): if self.materialize_master_grads: fp16_model_params = [p for i, p in enumerate( stash.fp16_groups[gid]) if stash.fp32_from_fp16_groups[gid][i].grad is not None] - fp32_from_fp16_grads = [p.grad for p in stash.fp32_from_fp16_groups[gid] if p.grad is not None] fp32_from_fp16_params = [p for p in stash.fp32_from_fp16_groups[gid] if p.grad is not None] + fp32_from_fp16_grads = [] + for p in fp32_from_fp16_params: + if p.is_contiguous(memory_format=torch.contiguous_format): + fp32_from_fp16_grads.append(p.grad) + elif p.is_contiguous(memory_format=torch.channels_last): + fp32_from_fp16_grads.append(p.grad.to(memory_format=torch.channels_last)) + elif p.is_contiguous(memory_format=torch.channel_last_3d): + fp32_from_fp16_grads.append(p.grad.to(memory_format=torch.channel_last_3d)) + else: + assert(False), "Unsupported memory format. Supports only contiguous_format, channels_last, or channel_last_3d." fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params) fp16_set = [fp32_from_fp16_grads, fp32_from_fp16_params, fp32_from_fp16_momentums, fp16_model_params] else: fp16_model_params = [p for p in stash.fp16_groups[gid] if p.grad is not None] - fp16_model_grads = [p.grad for p in stash.fp16_groups[gid] if p.grad is not None] + fp16_model_grads = [] + for p in fp16_model_params: + if p.is_contiguous(memory_format=torch.contiguous_format): + fp16_model_grads.append(p.grad) + elif p.is_contiguous(memory_format=torch.channels_last): + fp16_model_grads.append(p.grad.to(memory_format=torch.channels_last)) + elif p.is_contiguous(memory_format=torch.channel_last_3d): + fp16_model_grads.append(p.grad.to(memory_format=torch.channel_last_3d)) + else: + assert(False), "Unsupported memory format. Supports only contiguous_format, channels_last, or channel_last_3d." fp32_from_fp16_params = [p for i, p in enumerate( stash.fp32_from_fp16_groups[gid]) if stash.fp16_groups[gid][i].grad is not None] fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params) @@ -194,11 +212,29 @@ def step(self, closure=None): launch_sets= [fp16_set, [fp32_grads, fp32_params, fp32_momentums]] else: fp16_params = [p for p in group['params'] if (p.dtype == torch.float16 and p.grad is not None)] - fp16_grads = [p.grad for p in group['params'] if (p.dtype == torch.float16 and p.grad is not None)] + fp16_grads = [] + for p in fp16_params: + if p.is_contiguous(memory_format=torch.contiguous_format): + fp16_grads.append(p.grad) + elif p.is_contiguous(memory_format=torch.channels_last): + fp16_grads.append(p.grad.to(memory_format=torch.channels_last)) + elif p.is_contiguous(memory_format=torch.channel_last_3d): + fp16_grads.append(p.grad.to(memory_format=torch.channel_last_3d)) + else: + assert(False), "Unsupported memory format. Supports only contiguous_format, channels_last, or channel_last_3d." fp16_momentums, first_runs[0] = self.get_momentums(fp16_params) fp32_params = [p for p in group['params'] if (p.dtype == torch.float32 and p.grad is not None)] - fp32_grads = [p.grad for p in group['params'] if (p.dtype == torch.float32 and p.grad is not None)] + fp32_grads = [] + for p in fp32_params: + if p.is_contiguous(memory_format=torch.contiguous_format): + fp32_grads.append(p.grad) + elif p.is_contiguous(memory_format=torch.channels_last): + fp32_grads.append(p.grad.to(memory_format=torch.channels_last)) + elif p.is_contiguous(memory_format=torch.channel_last_3d): + fp32_grads.append(p.grad.to(memory_format=torch.channel_last_3d)) + else: + assert(False), "Unsupported memory format. Supports only contiguous_format, channels_last, or channel_last_3d." fp32_momentums, first_runs[1] = self.get_momentums(fp32_params) launch_sets = [[fp16_grads, fp16_params, fp16_momentums], @@ -208,6 +244,7 @@ def step(self, closure=None): assert len(launch_set[0]) == len(launch_set[1]) assert len(launch_set[0]) == len(launch_set[2]) if len(launch_set[0]) > 0: + # multi_tensor_applier has nhwc support: https://github.com/NVIDIA/apex/pull/732 multi_tensor_applier( self.multi_tensor_sgd, self._dummy_overflow_buf, diff --git a/apex/parallel/distributed.py b/apex/parallel/distributed.py index 5267c834a..6aa6a6e8a 100644 --- a/apex/parallel/distributed.py +++ b/apex/parallel/distributed.py @@ -48,8 +48,8 @@ def apply_flat_dist_call(bucket, call, extra_args=None): for buf, synced in zip(bucket, unflatten(coalesced, bucket)): buf.copy_(synced) -def split_half_float_double(tensors): - dtypes = ["torch.cuda.HalfTensor", "torch.cuda.FloatTensor", "torch.cuda.DoubleTensor"] +def split_half_float_double_bfloat16(tensors): + dtypes = ["torch.cuda.HalfTensor", "torch.cuda.FloatTensor", "torch.cuda.DoubleTensor", "torch.cuda.BFloat16Tensor"] buckets = [] for i, dtype in enumerate(dtypes): bucket = [t for t in tensors if t.type() == dtype] @@ -240,7 +240,8 @@ def __init__(self, self.param_type_to_tmp_i = {"torch.cuda.HalfTensor" : 0, "torch.cuda.FloatTensor" : 1, - "torch.cuda.DoubleTensor" : 2} + "torch.cuda.DoubleTensor" : 2, + "torch.cuda.BFloat16Tensor" : 3} if multi_tensor_applier.available: # TODO: I really need to centralize the C++ backed imports @@ -498,7 +499,7 @@ def allreduce_fallback(self): else: grads = [param.grad.data for param in self.module.parameters() if param.grad is not None] - split_buckets = split_half_float_double(grads) + split_buckets = split_half_float_double_bfloat16(grads) # If retain_allreduce_buffers is True and delay_allreduce is False, # this will only be done during the first backward pass, ignored by the @@ -578,8 +579,8 @@ def forward(self, *inputs, **kwargs): if self.needs_refresh: self.active_i_buckets = [] self.buckets = [] - self.tmp_buckets = [[], [], []] # [running half, float, double buckets] - self.tmp_numels = [0, 0, 0] + self.tmp_buckets = [[], [], [], []] # [running half, float, double, bfloat16 buckets] + self.tmp_numels = [0, 0, 0, 0] self.bucket_sizes = [] self.param_id_to_active_i = {id(param) : i for i, param in enumerate(param_list)} self.param_id_to_bucket = {} diff --git a/apex/testing/__init__.py b/apex/testing/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/apex/testing/common_utils.py b/apex/testing/common_utils.py new file mode 100644 index 000000000..82b660f9b --- /dev/null +++ b/apex/testing/common_utils.py @@ -0,0 +1,33 @@ +''' +This file contains common utility functions for running the unit tests on ROCM. +''' + +import torch +import os +import sys +from functools import wraps +import unittest + + +TEST_WITH_ROCM = os.getenv('APEX_TEST_WITH_ROCM', '0') == '1' +SKIP_FLAKY_TEST = os.getenv('APEX_SKIP_FLAKY_TEST', '0') == '1' + +## Wrapper to skip the unit tests. +def skipIfRocm(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + if TEST_WITH_ROCM: + raise unittest.SkipTest("test doesn't currently work on ROCm stack.") + else: + fn(*args, **kwargs) + return wrapper + +## Wrapper to skip the flaky unit tests. +def skipFlakyTest(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + if SKIP_FLAKY_TEST: + raise unittest.SkipTest("Test is flaky.") + else: + fn(*args, **kwargs) + return wrapper diff --git a/apex/transformer/amp/grad_scaler.py b/apex/transformer/amp/grad_scaler.py index 5bcd061d9..931110afc 100644 --- a/apex/transformer/amp/grad_scaler.py +++ b/apex/transformer/amp/grad_scaler.py @@ -35,6 +35,12 @@ def __init__( enabled=enabled, ) + def _unscale_grads_(self, optimizer, *args): + if getattr(optimizer, "_custom_amp_unscale_grads", False): + return optimizer.unscale_grads(*args) + else: + return super()._unscale_grads_(optimizer, *args) + def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs): retval = None found_inf = torch.cuda.FloatTensor([sum(v.item() for v in optimizer_state["found_inf_per_device"].values())]) diff --git a/apex/transformer/functional/__init__.py b/apex/transformer/functional/__init__.py index d770c8859..f307df79f 100644 --- a/apex/transformer/functional/__init__.py +++ b/apex/transformer/functional/__init__.py @@ -1,5 +1,15 @@ +from apex.transformer.functional.fused_rope import ( + fused_apply_rotary_pos_emb, + fused_apply_rotary_pos_emb_cached, + fused_apply_rotary_pos_emb_thd, + fused_apply_rotary_pos_emb_2d, +) from apex.transformer.functional.fused_softmax import FusedScaleMaskSoftmax __all__ = [ "FusedScaleMaskSoftmax", + "fused_apply_rotary_pos_emb", + "fused_apply_rotary_pos_emb_cached", + "fused_apply_rotary_pos_emb_thd", + "fused_apply_rotary_pos_emb_2d", ] diff --git a/apex/transformer/functional/fused_rope.py b/apex/transformer/functional/fused_rope.py new file mode 100644 index 000000000..e74906151 --- /dev/null +++ b/apex/transformer/functional/fused_rope.py @@ -0,0 +1,565 @@ +# coding=utf-8 +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Tuple, Union +import torch +import os +from torch.utils.cpp_extension import ROCM_HOME +import warnings + +TORCH_MAJOR = int(torch.__version__.split('.')[0]) +TORCH_MINOR = int(torch.__version__.split('.')[1]) + +def check_if_rocm_pytorch(): + is_rocm_pytorch = False + if TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 5): + is_rocm_pytorch = True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False + return is_rocm_pytorch + +IS_ROCM_PYTORCH = check_if_rocm_pytorch() + +# an envrionment variable to explicitly switch on/off aiter backend +# by default it is 1, which means aiter backend is enabled +USE_ROCM_AITER_ROPE_BACKEND = int(os.environ.get("USE_ROCM_AITER_ROPE_BACKEND", 1)) == 1 + +# a flag to switch between the native apex kernel and native aiter kernel +# by default it is False +AITER_ROPE_BACKEND = False +''' +False - native kernel in apex repo +True - aiter native kernel +''' + +# switch on aiter backend if it is rocm and aiter is enabled from the user +if IS_ROCM_PYTORCH and USE_ROCM_AITER_ROPE_BACKEND: + try: + import aiter + AITER_ROPE_BACKEND = True + warnings.warn("Aiter backend is selected for fused RoPE. This has lower precision. To disable aiter, export USE_ROCM_AITER_ROPE_BACKEND=0", UserWarning) + except ImportError: + AITER_ROPE_BACKEND = False +if not AITER_ROPE_BACKEND: + import fused_rotary_positional_embedding + warnings.warn("Using the native apex kernel for RoPE.", UserWarning) + + +class FusedRoPEFunc(torch.autograd.Function): + """ + Fused RoPE function + + This implementation assumes the input tensor to be in `sbhd` format and the RoPE tensor to be + of shape (s, 1, 1, d). It accepts arbitrary memory layouts to avoid the expensive + `.contiguous()` calls, thus it may not achieve the best memory access pattern. + """ + + @staticmethod + def forward( + ctx, + t: torch.Tensor, + freqs: torch.Tensor, + transpose_output_memory: bool = False, + ) -> torch.Tensor: + raise ValueError("Invalid forward implementation.") + + @staticmethod + def backward( + ctx, grad_output: torch.Tensor + ) -> Tuple[Union[torch.Tensor, None], ...]: + raise ValueError("Invalid backward implementation.") + +class FusedRoPEFuncApex(FusedRoPEFunc): + @staticmethod + def forward( + ctx, + t: torch.Tensor, + freqs: torch.Tensor, + transpose_output_memory: bool = False, + ) -> torch.Tensor: + output = fused_rotary_positional_embedding.forward( + t, freqs, transpose_output_memory + ) + ctx.save_for_backward(freqs) + ctx.transpose_output_memory = transpose_output_memory + return output + + @staticmethod + def backward( + ctx, grad_output: torch.Tensor + ) -> Tuple[Union[torch.Tensor, None], ...]: + (freqs,) = ctx.saved_tensors + grad_input = fused_rotary_positional_embedding.backward( + grad_output, freqs, ctx.transpose_output_memory + ) + return grad_input, None, None + +class FusedRoPEFuncAiter(FusedRoPEFunc): + @staticmethod + def forward( + ctx, + t: torch.Tensor, + freqs: torch.Tensor, + transpose_output_memory: bool = False, + ) -> torch.Tensor: + s = t.shape[0] + b = t.shape[1] + h = t.shape[2] + d = t.shape[3] + # t is of shape [s, b, h, d] + # freqs is of shape [s, 1, 1, d] + + act_options = {'dtype': t.dtype, 'device': t.device, 'requires_grad': False} + if transpose_output_memory: + output = torch.empty((b, s, h, d), **act_options).transpose(0, 1) + else: + output = torch.empty((s, b, h, d), **act_options) + aiter.rope_fwd_impl(output, t, freqs, 0, False, False) + + ctx.save_for_backward(freqs) + ctx.transpose_output_memory = transpose_output_memory + + return output + + + @staticmethod + def backward( + ctx, grad_output: torch.Tensor + ) -> Tuple[Union[torch.Tensor, None], ...]: + (freqs,) = ctx.saved_tensors + + s = grad_output.shape[0] + b = grad_output.shape[1] + h = grad_output.shape[2] + d = grad_output.shape[3] + + act_options = {'dtype': grad_output.dtype, 'device': grad_output.device, 'requires_grad': False} + if ctx.transpose_output_memory: + grad_input = torch.empty((b, s, h, d), **act_options).transpose(0, 1) + else: + grad_input = torch.empty((s, b, h, d), **act_options) + aiter.rope_bwd_impl(grad_input, grad_output, freqs, 0, False, False) + + return grad_input, None, None + + +def fused_apply_rotary_pos_emb( + t: torch.Tensor, + freqs: torch.Tensor, + transpose_output_memory: bool = False, +) -> torch.Tensor: + """Apply rotary positional embedding to input tensor T in `sbhd` format, where + s: sequence length + b: batch size + h: head num + d: dim of each head + + Args: + t (Tensor): Input tensor T is of shape [s, b, h, d] + freqs (Tensor): Rotary Positional embedding tensor freq is of shape [s, 1, 1, d] and + `float` dtype + transpose_output_memory (bool): Default to False. Whether to transpose the 's' and 'b' + dimension of the output's underlying memory format. This is very helpful when you want to + get a contiguous tensor after calling `output.transpose(0, 1)`. + + Returns: + Tensor: The input tensor after applying RoPE + """ + FusedRoPEFunc = FusedRoPEFuncAiter if AITER_ROPE_BACKEND else FusedRoPEFuncApex + return FusedRoPEFunc.apply(t, freqs, transpose_output_memory) + +class FusedRoPECachedFunc(torch.autograd.Function): + """ + Fused RoPE function + + This implementation assumes the input tensor to be in `sbhd` format and the RoPE tensor to be + of shape (s, 1, 1, d). It accepts arbitrary memory layouts to avoid the expensive + `.contiguous()` calls, thus it may not achieve the best memory access pattern. + """ + + @staticmethod + def forward( + ctx, + t: torch.Tensor, + cos_: torch.Tensor, + sin_: torch.Tensor, + transpose_output_memory: bool = False, + ) -> torch.Tensor: + raise ValueError("Invalid forward implementation.") + + @staticmethod + def backward( + ctx, grad_output: torch.Tensor + ) -> Tuple[Union[torch.Tensor, None], ...]: + raise ValueError("Invalid backward implementation.") + +class FusedRoPECachedFuncApex(FusedRoPECachedFunc): + @staticmethod + def forward( + ctx, + t: torch.Tensor, + cos_: torch.Tensor, + sin_: torch.Tensor, + transpose_output_memory: bool = False, + ) -> torch.Tensor: + output = fused_rotary_positional_embedding.forward_cached( + t, cos_, sin_, transpose_output_memory + ) + ctx.save_for_backward(cos_, sin_) + ctx.transpose_output_memory = transpose_output_memory + + return output + + @staticmethod + def backward( + ctx, grad_output: torch.Tensor + ) -> Tuple[Union[torch.Tensor, None], ...]: + cos_, sin_ = ctx.saved_tensors + grad_input = fused_rotary_positional_embedding.backward_cached( + grad_output, cos_, sin_, ctx.transpose_output_memory + ) + return grad_input, None, None, None + +class FusedRoPECachedFuncAiter(FusedRoPECachedFunc): + @staticmethod + def forward( + ctx, + t: torch.Tensor, + cos_: torch.Tensor, + sin_: torch.Tensor, + transpose_output_memory: bool = False, + ) -> torch.Tensor: + s = t.shape[0] + b = t.shape[1] + h = t.shape[2] + d = t.shape[3] + # t is of shape [s, b, h, d] + # freqs is of shape [s, 1, 1, d] + + act_options = {'dtype': t.dtype, 'device': t.device, 'requires_grad': False} + if transpose_output_memory: + output = torch.empty((b, s, h, d), **act_options).transpose(0, 1) + else: + output = torch.empty((s, b, h, d), **act_options) + aiter.rope_cached_fwd_impl(output, t, cos_, sin_, 0, False, False) + + ctx.save_for_backward(cos_, sin_) + ctx.transpose_output_memory = transpose_output_memory + + return output + + @staticmethod + def backward( + ctx, grad_output: torch.Tensor + ) -> Tuple[Union[torch.Tensor, None], ...]: + cos_, sin_ = ctx.saved_tensors + + s = grad_output.shape[0] + b = grad_output.shape[1] + h = grad_output.shape[2] + d = grad_output.shape[3] + + act_options = {'dtype': grad_output.dtype, 'device': grad_output.device, 'requires_grad': False} + if ctx.transpose_output_memory: + grad_input = torch.empty((b, s, h, d), **act_options).transpose(0, 1) + else: + grad_input = torch.empty((s, b, h, d), **act_options) + aiter.rope_cached_bwd_impl(grad_input, grad_output, cos_, sin_, 0, False, False) + return grad_input, None, None, None + +def fused_apply_rotary_pos_emb_cached( + t: torch.Tensor, + cos_: torch.Tensor, + sin_: torch.Tensor, + transpose_output_memory: bool = False, +) -> torch.Tensor: + """Apply rotary positional embedding to input tensor T in `sbhd` format, where + s: sequence length + b: batch size + h: head num + d: dim of each head + + Args: + t (Tensor): Input tensor T is of shape [s, b, h, d] + cos_ (Tensor): Cached cosine of the rotary positional embedding tensor is of + shape [s, 1, 1, d] and dtype either `float` or the same as `t`. + sin_ (Tensor): Cached sine of the rotary positional embedding tensor is of + shape [s, 1, 1, d] and dtype either `float` or the same as `t`. + transpose_output_memory (bool): Default to False. Whether to transpose the 's' and 'b' + dimension of the output's underlying memory format. This is very helpful when you want to + get a contiguous tensor after calling `output.transpose(0, 1)`. + + Returns: + Tensor: The input tensor after applying RoPE + """ + FusedRoPEFunc = FusedRoPECachedFuncAiter if AITER_ROPE_BACKEND else FusedRoPECachedFuncApex + return FusedRoPEFunc.apply(t, cos_, sin_, transpose_output_memory) + +class FusedRoPETHDFunc(torch.autograd.Function): + """ + Fused RoPE function for `thd` format. + + This implementation accepts arbitrary memory layouts to avoid the expensive + `.contiguous()` calls, thus it may not achieve the best memory access pattern. + """ + + @staticmethod + def forward( + ctx, + t: torch.Tensor, + cu_seqlens: torch.Tensor, + freqs: torch.Tensor, + ) -> torch.Tensor: + raise ValueError("Invalid forward implementation.") + + @staticmethod + def backward( + ctx, grad_output: torch.Tensor + ) -> Tuple[Union[torch.Tensor, None], ...]: + raise ValueError("Invalid backward implementation.") + +class FusedRoPETHDFuncApex(FusedRoPETHDFunc): + @staticmethod + def forward( + ctx, + t: torch.Tensor, + cu_seqlens: torch.Tensor, + freqs: torch.Tensor, + ) -> torch.Tensor: + output = fused_rotary_positional_embedding.forward_thd( + t, cu_seqlens, freqs + ) + ctx.save_for_backward(cu_seqlens, freqs) + return output + + @staticmethod + def backward( + ctx, grad_output: torch.Tensor + ) -> Tuple[Union[torch.Tensor, None], ...]: + cu_seqlens, freqs = ctx.saved_tensors + grad_input = fused_rotary_positional_embedding.backward_thd( + grad_output, cu_seqlens, freqs + ) + return grad_input, None, None + +class FusedRoPETHDFuncAiter(FusedRoPETHDFunc): + + @staticmethod + def forward( + ctx, + t: torch.Tensor, + cu_seqlens: torch.Tensor, + freqs: torch.Tensor, + ) -> torch.Tensor: + t1 = t.shape[0] + h = t.shape[1] + d = t.shape[2] + # t is of shape [t, h, d] + + act_options = {'dtype': t.dtype, 'device': t.device, 'requires_grad': False} + output = torch.empty((t1, h, d), **act_options) + aiter.rope_thd_fwd_impl(output, t, cu_seqlens, freqs, 0, False, False) + + ctx.save_for_backward(cu_seqlens, freqs) + + return output + + @staticmethod + def backward( + ctx, grad_output: torch.Tensor + ) -> Tuple[Union[torch.Tensor, None], ...]: + cu_seqlens, freqs = ctx.saved_tensors + + t = grad_output.shape[0] + h = grad_output.shape[1] + d = grad_output.shape[2] + # t is of shape [t, h, d] + + act_options = {'dtype': grad_output.dtype, 'device': grad_output.device, 'requires_grad': False} + grad_input = torch.empty((t, h, d), **act_options) + aiter.rope_thd_bwd_impl(grad_input, grad_output, cu_seqlens, freqs, 0, False, False) + + return grad_input, None, None + +def fused_apply_rotary_pos_emb_thd( + t: torch.Tensor, + cu_seqlens: torch.Tensor, + freqs: torch.Tensor, +) -> torch.Tensor: + """Apply rotary positional embedding to input tensor T in `thd` format, where + t: cumulative sum of sequence lengths + h: head num + d: dim of each head + + Args: + t (Tensor): Input tensor T is of shape [t, h, d] + cu_seqlens (Tensor): Cumulative sum of sequence lengths in a batch for `t`, + with shape [b + 1] and dtype torch.int32. + freqs (Tensor): Rotary Positional embedding tensor freq is of shape [max_s, 1, 1, d] and + `float` dtype + + Returns: + Tensor: The input tensor after applying RoPE + """ + FusedRoPEFunc = FusedRoPETHDFuncAiter if AITER_ROPE_BACKEND else FusedRoPETHDFuncApex + return FusedRoPEFunc.apply(t, cu_seqlens, freqs) + +class FusedRoPE2DFunc(torch.autograd.Function): + """ + Fused 2D RoPE function + """ + @staticmethod + def forward( + ctx, + t: torch.Tensor, + img_h: int, + img_w: int, + cos_h: torch.Tensor, + sin_h: torch.Tensor, + cos_w: torch.Tensor, + sin_w: torch.Tensor, + ) -> torch.Tensor: + raise ValueError("Invalid forward implementation.") + + @staticmethod + def backward( + ctx, grad_output: torch.Tensor + ) -> Tuple[Union[torch.Tensor, None], ...]: + raise ValueError("Invalid backward implementation.") + +class FusedRoPE2DFuncApex(FusedRoPE2DFunc): + @staticmethod + def forward( + ctx, + t: torch.Tensor, + img_h: int, + img_w: int, + cos_h: torch.Tensor, + sin_h: torch.Tensor, + cos_w: torch.Tensor, + sin_w: torch.Tensor, + ) -> torch.Tensor: + t = t.view(t.shape[0], img_h, img_w, t.shape[2], t.shape[3]) + output = fused_rotary_positional_embedding.forward_2d( + t, cos_h, sin_h, cos_w, sin_w + ) + ctx.save_for_backward(cos_h, sin_h, cos_w, sin_w) + ctx.img_h = img_h + ctx.img_w = img_w + return output + + @staticmethod + def backward( + ctx, grad_output: torch.Tensor + ) -> Tuple[Union[torch.Tensor, None], ...]: + + cos_h, sin_h, cos_w, sin_w = ctx.saved_tensors + + grad_output = grad_output.view( + grad_output.shape[0], + ctx.img_h, + ctx.img_w, + grad_output.shape[2], + grad_output.shape[3], + ) + grad_input = fused_rotary_positional_embedding.backward_2d( + grad_output, cos_h, sin_h, cos_w, sin_w + ) + return grad_input, None, None, None, None, None, None + +class FusedRoPE2DFuncAiter(FusedRoPE2DFunc): + @staticmethod + def forward( + ctx, + t: torch.Tensor, + img_h: int, + img_w: int, + cos_h: torch.Tensor, + sin_h: torch.Tensor, + cos_w: torch.Tensor, + sin_w: torch.Tensor, + ) -> torch.Tensor: + + s = t.shape[0] + h = t.shape[2] + d = t.shape[3] + # t is of shape [s, ih*iw, h, d] + + act_options = {'dtype': t.dtype, 'device': t.device, 'requires_grad': False} + output = torch.empty((s, img_h * img_w, h, d), **act_options) + aiter.rope_2d_fwd_impl(output, t, cos_h, sin_h, cos_w, sin_w, img_h, img_w, 0, False, False) + ctx.save_for_backward(cos_h, sin_h, cos_w, sin_w) + ctx.img_h = img_h + ctx.img_w = img_w + + return output + + @staticmethod + def backward( + ctx, grad_output: torch.Tensor + ) -> Tuple[Union[torch.Tensor, None], ...]: + + cos_h, sin_h, cos_w, sin_w = ctx.saved_tensors + + s = grad_output.shape[0] + h = grad_output.shape[2] + d = grad_output.shape[3] + # t is of shape [s, ih* iw, h, d] + + act_options = {'dtype': grad_output.dtype, 'device': grad_output.device, 'requires_grad': False} + grad_input = torch.empty((s, ctx.img_h * ctx.img_w, h, d), **act_options) + aiter.rope_2d_bwd_impl(grad_input, grad_output, cos_h, sin_h, cos_w, sin_w, ctx.img_h, ctx.img_w, 0, False, False) + + return grad_input, None, None, None, None, None, None + +def fused_apply_rotary_pos_emb_2d( + t: torch.Tensor, + img_h: int, + img_w: int, + cos_h: torch.Tensor, + sin_h: torch.Tensor, + cos_w: torch.Tensor, + sin_w: torch.Tensor, +) -> torch.Tensor: + """Apply rotary positional embedding to input tensor T in `bshd` format, where + b: batch size + s: sequence length + h: head num + d: dim of each head + + Args: + t (Tensor): Input tensor T is of shape [b, s, h, d] + img_h (int): s == img_h * img_w + img_w (int): s == img_h * img_w + cos_h (Tensor): shape [1, H, 1, d // 2] and dtype either `float` or + the same as `t`. H >= img_h. + sin_h (Tensor): shape [1, H, 1, d // 2] and dtype either `float` or + the same as `t`. H >= img_h. + cos_w (Tensor): shape [1, W, 1, d // 2] and dtype either `float` or + the same as `t`. W >= img_w. + sin_w (Tensor): shape [1, W, 1, d // 2] and dtype either `float` or + the same as `t`. W >= img_w. + + Returns: + Tensor: The input tensor after applying RoPE + """ + assert ( + t.size(1) == img_h * img_w + ), "The sequence length should be equal to img_h * img_w" + assert ( + cos_h.size() == sin_h.size() + ), "The shape of cos_h and sin_h should be the same" + assert ( + cos_w.size() == sin_w.size() + ), "The shape of cos_w and sin_w should be the same" + FusedRoPEFunc = FusedRoPE2DFuncAiter if AITER_ROPE_BACKEND else FusedRoPE2DFuncApex + return FusedRoPEFunc.apply(t, img_h, img_w, cos_h, sin_h, cos_w, sin_w) \ No newline at end of file diff --git a/apex/transformer/functional/fused_softmax.py b/apex/transformer/functional/fused_softmax.py index 8ceaffef9..83243ef7b 100644 --- a/apex/transformer/functional/fused_softmax.py +++ b/apex/transformer/functional/fused_softmax.py @@ -92,10 +92,73 @@ def backward(ctx, output_grads): def scaled_masked_softmax(inputs, mask, scale): + # input is 4D tensor (b, np, sq, sk) + if mask is not None: + args = _cast_if_autocast_enabled(inputs, mask, scale) + with torch.cuda.amp.autocast(enabled=False): + return ScaledMaskedSoftmax.apply(*args) + else: + args = _cast_if_autocast_enabled(inputs, scale) + with torch.cuda.amp.autocast(enabled=False): + return ScaledSoftmax.apply(*args) + + +class GenericScaledMaskedSoftmax(torch.autograd.Function): + @staticmethod + def forward(ctx, inputs, mask, scale): + import generic_scaled_masked_softmax_cuda + + scale_t = torch.tensor([scale]) + softmax_results = generic_scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0]) + ctx.save_for_backward(softmax_results, scale_t) + return softmax_results + + @staticmethod + def backward(ctx, output_grads): + import generic_scaled_masked_softmax_cuda_new + + softmax_results, scale_t = ctx.saved_tensors + + input_grads = generic_scaled_masked_softmax_cuda.backward(output_grads, softmax_results, scale_t[0]) + return input_grads, None, None + + +def generic_scaled_masked_softmax(inputs, mask, scale): # input is 4D tensor (b, np, sq, sk) args = _cast_if_autocast_enabled(inputs, mask, scale) - with torch.cuda.amp.autocast(enabled=False): - return ScaledMaskedSoftmax.apply(*args) + with torch.amp.autocast('cuda', enabled=False): + return GenericScaledMaskedSoftmax.apply(*args) + + +class ScaledSoftmax(torch.autograd.Function): + """ + Fused operation which performs following two operations in sequence + 1. Scale the tensor. + 2. Perform softmax. + """ + + @staticmethod + def forward(ctx, inputs, scale): + import scaled_softmax_cuda + + scale_t = torch.tensor([scale]) + + softmax_results = scaled_softmax_cuda.forward( + inputs, scale_t[0] + ) + ctx.save_for_backward(softmax_results, scale_t) + return softmax_results + + @staticmethod + def backward(ctx, output_grads): + import scaled_softmax_cuda + + softmax_results, scale_t = ctx.saved_tensors + + input_grads = scaled_softmax_cuda.backward( + output_grads, softmax_results, scale_t[0] + ) + return input_grads, None, None class FusedScaleMaskSoftmax(torch.nn.Module): @@ -166,12 +229,12 @@ def is_kernel_available(self, mask, b, np, sq, sk): self.attn_mask_type == AttnMaskType.causal or (self.attn_mask_type == AttnMaskType.padding and mask is not None) ) - and 16 < sk <= 2048 # sk must be 16 ~ 2048 + and 16 < sk <= 16384 # sk must be 16 ~ 16384 and sq % 4 == 0 # sq must be divisor of 4 and sk % 4 == 0 # sk must be divisor of 4 and attn_batches % 4 == 0 # np * b must be divisor of 4 ): - if 0 <= sk <= 2048: + if 0 <= sk <= 16384: batch_per_block = self.get_batch_per_block(sq, sk, b, np) if self.attn_mask_type == AttnMaskType.causal: diff --git a/apex/transformer/parallel_state.py b/apex/transformer/parallel_state.py index a8d16bfd3..ed834fed6 100644 --- a/apex/transformer/parallel_state.py +++ b/apex/transformer/parallel_state.py @@ -575,6 +575,12 @@ def get_virtual_pipeline_model_parallel_world_size(): return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE +def set_virtual_pipeline_model_parallel_world_size(size): + """Return the virtual pipeline-parallel world size.""" + global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = size + + def get_tensor_model_parallel_src_rank(): """Calculate the global rank corresponding to the first local rank in the tensor model parallel group.""" diff --git a/apex/transformer/pipeline_parallel/p2p_communication.py b/apex/transformer/pipeline_parallel/p2p_communication.py index 6c4b0d93d..0399be2b8 100644 --- a/apex/transformer/pipeline_parallel/p2p_communication.py +++ b/apex/transformer/pipeline_parallel/p2p_communication.py @@ -96,11 +96,18 @@ def _run_p2pops( reqs = torch.distributed.batch_isend_irecv(ops) if async_comm: - assert len(reqs) == len(ops) - tensor_send_prev_req = None if tensor_send_prev is None else reqs.pop(0) - tensor_recv_prev_req = None if tensor_recv_prev is None else reqs.pop(0) - tensor_send_next_req = None if tensor_send_next is None else reqs.pop(0) - tensor_recv_next_req = None if tensor_recv_next is None else reqs.pop(0) + if len(ops) == 0 or len(reqs) == len(ops): + tensor_send_prev_req = None if tensor_send_prev is None else reqs.pop(0) + tensor_recv_prev_req = None if tensor_recv_prev is None else reqs.pop(0) + tensor_send_next_req = None if tensor_send_next is None else reqs.pop(0) + tensor_recv_next_req = None if tensor_recv_next is None else reqs.pop(0) + elif len(reqs) == 1: + tensor_send_prev_req = None if tensor_send_prev is None else reqs[0] + tensor_recv_prev_req = None if tensor_recv_prev is None else reqs[0] + tensor_send_next_req = None if tensor_send_next is None else reqs[0] + tensor_recv_next_req = None if tensor_recv_next is None else reqs[0] + else: + assert False, "failed to manage p2p requests and handles" return (tensor_send_prev_req, tensor_recv_prev_req, tensor_send_next_req, tensor_recv_next_req) else: for req in reqs: diff --git a/apex/transformer/tensor_parallel/layers.py b/apex/transformer/tensor_parallel/layers.py index e2d7e524c..346dfaa7a 100644 --- a/apex/transformer/tensor_parallel/layers.py +++ b/apex/transformer/tensor_parallel/layers.py @@ -401,7 +401,7 @@ def linear_with_grad_accumulation_and_async_allreduce( sequence_parallel_enabled, False, # use_16bit_in_wgrad_accum_fusion ) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast('cuda',enabled=False): return LinearWithGradAccumulationAndAsyncCommunication.apply(*args) @@ -422,7 +422,7 @@ def linear_with_grad_accumulation_and_async_allreduce_in16bit( sequence_parallel_enabled, True, # use_16bit_in_wgrad_accum_fusion ) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast('cuda',enabled=False): return LinearWithGradAccumulationAndAsyncCommunication.apply(*args) diff --git a/apex/transformer/testing/distributed_test_base.py b/apex/transformer/testing/distributed_test_base.py index b01ca2c5d..4ab93762e 100644 --- a/apex/transformer/testing/distributed_test_base.py +++ b/apex/transformer/testing/distributed_test_base.py @@ -20,7 +20,10 @@ _TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION = Version("470.42.01") _driver_version = None if torch.cuda.is_available(): - _driver_version = parse(collect_env.get_nvidia_driver_version(collect_env.run)) + if collect_env.get_nvidia_driver_version(collect_env.run) != None: + _driver_version = parse(collect_env.get_nvidia_driver_version(collect_env.run)) + else: + _driver_version = None HAS_TORCH_UCC_COMPAT_NVIDIA_DRIVER = _driver_version is not None and _driver_version >= _TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION @@ -45,8 +48,13 @@ def world_size(self) -> int: def init_method(self): return f"{common_utils.FILE_SCHEMA}{self.file_name}" + @property + def destroy_pg_upon_exit(self) -> bool: + # Overriding base test class: do not auto destroy PG upon exit. + return False + @classmethod - def _run(cls, rank, test_name, file_name, pipe): + def _run(cls, rank, test_name, file_name, pipe, **kwargs): self = cls(test_name) self.assertTrue(torch.cuda.is_available()) self.assertTrue(hasattr(self, "DISTRIBUTED_BACKEND")) diff --git a/build.sh b/build.sh new file mode 100755 index 000000000..54ed12093 --- /dev/null +++ b/build.sh @@ -0,0 +1,16 @@ +#!/bin/bash -x + +export PYTORCH_ROCM_ARCH=gfx942 +# export TENSILE_DB=0x40 +# export HIPBLASLT_LOG_MASK=0xff + + +python setup.py develop --cuda_ext --cpp_ext +cp build/lib.linux-x86_64-cpython-39/fused_dense_cuda.cpython-39-x86_64-linux-gnu.so /opt/conda/envs/py_3.9/lib/python3.9/site-packages/. + +# export HIPBLASLT_LOG_FILE=hipblaslt_bgrad.log + +# python apex/contrib/test/fused_dense/test_fused_dense_1.py + +# python apex/contrib/test/fused_dense/test_half_T.py +# python apex/contrib/test/fused_dense/test_half_NT.py diff --git a/compatibility/__init__.py b/compatibility/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/compatibility/_apex_nccl_allocator.py b/compatibility/_apex_nccl_allocator.py new file mode 100644 index 000000000..6a029d1ee --- /dev/null +++ b/compatibility/_apex_nccl_allocator.py @@ -0,0 +1,37 @@ +import sys +import importlib + +class _ApexNcclAllocatorModule: + def __init__(self): + self._loaded_module = None + self._loading = False + + def _load_module(self): + if self._loaded_module is None and not self._loading: + self._loading = True + try: + apex_op_builder = importlib.import_module('apex.op_builder') + builder = getattr(apex_op_builder, 'NCCLAllocatorBuilder') + self._loaded_module = builder().load() + except Exception as e: + self._loading = False + raise ImportError(f"Failed to load _apex_nccl_allocator : {e}") + finally: + self._loading = False + return self._loaded_module + + def __getattr__(self, name): + if name.startswith("_") and name != "__class__": + raise AttributeError(f"module _apex_nccl_allocator has no attribute '{name}'") + return getattr(self._load_module(), name) + + def __dir__(self): + try: + return dir(self._load_module()) + except: + return [] + + def __repr__(self): + return "" + +sys.modules[__name__] = _ApexNcclAllocatorModule() \ No newline at end of file diff --git a/compatibility/amp_C.py b/compatibility/amp_C.py new file mode 100644 index 000000000..f9257c596 --- /dev/null +++ b/compatibility/amp_C.py @@ -0,0 +1,37 @@ +import sys +import importlib + +class _AmpCModule: + def __init__(self): + self._loaded_module = None + self._loading = False + + def _load_module(self): + if self._loaded_module is None and not self._loading: + self._loading = True + try: + apex_op_builder = importlib.import_module('apex.op_builder') + builder = getattr(apex_op_builder, 'AmpCBuilder') + self._loaded_module = builder().load() + except Exception as e: + self._loading = False + raise ImportError(f"Failed to load amp_C : {e}") + finally: + self._loading = False + return self._loaded_module + + def __getattr__(self, name): + if name.startswith("_"): + raise AttributeError(f"module amp_C has no attribute '{name}'") + return getattr(self._load_module(), name) + + def __dir__(self): + try: + return dir(self._load_module()) + except: + return [] + + def __repr__(self): + return "" + +sys.modules[__name__] = _AmpCModule() \ No newline at end of file diff --git a/compatibility/apex_C.py b/compatibility/apex_C.py new file mode 100644 index 000000000..39bac5264 --- /dev/null +++ b/compatibility/apex_C.py @@ -0,0 +1,37 @@ +import sys +import importlib + +class _ApexCModule: + def __init__(self): + self._loaded_module = None + self._loading = False + + def _load_module(self): + if self._loaded_module is None and not self._loading: + self._loading = True + try: + apex_op_builder = importlib.import_module('apex.op_builder') + builder = getattr(apex_op_builder, 'ApexCBuilder') + self._loaded_module = builder().load() + except Exception as e: + self._loading = False + raise ImportError(f"Failed to load apex_C : {e}") + finally: + self._loading = False + return self._loaded_module + + def __getattr__(self, name): + if name.startswith("_"): + raise AttributeError(f"module apex_C has no attribute '{name}'") + return getattr(self._load_module(), name) + + def __dir__(self): + try: + return dir(self._load_module()) + except: + return [] + + def __repr__(self): + return "" + +sys.modules[__name__] = _ApexCModule() \ No newline at end of file diff --git a/compatibility/bnp.py b/compatibility/bnp.py new file mode 100644 index 000000000..b03ba798c --- /dev/null +++ b/compatibility/bnp.py @@ -0,0 +1,37 @@ +import sys +import importlib + +class _BnpModule: + def __init__(self): + self._loaded_module = None + self._loading = False + + def _load_module(self): + if self._loaded_module is None and not self._loading: + self._loading = True + try: + apex_op_builder = importlib.import_module('apex.op_builder') + builder = getattr(apex_op_builder, 'BnpBuilder') + self._loaded_module = builder().load() + except Exception as e: + self._loading = False + raise ImportError(f"Failed to load bnp : {e}") + finally: + self._loading = False + return self._loaded_module + + def __getattr__(self, name): + if name.startswith("_"): + raise AttributeError(f"module bnp has no attribute '{name}'") + return getattr(self._load_module(), name) + + def __dir__(self): + try: + return dir(self._load_module()) + except: + return [] + + def __repr__(self): + return "" + +sys.modules[__name__] = _BnpModule() \ No newline at end of file diff --git a/compatibility/distributed_adam_cuda.py b/compatibility/distributed_adam_cuda.py new file mode 100644 index 000000000..2566dce11 --- /dev/null +++ b/compatibility/distributed_adam_cuda.py @@ -0,0 +1,37 @@ +import sys +import importlib + +class _DistributedAdamCudaModule: + def __init__(self): + self._loaded_module = None + self._loading = False + + def _load_module(self): + if self._loaded_module is None and not self._loading: + self._loading = True + try: + apex_op_builder = importlib.import_module('apex.op_builder') + builder = getattr(apex_op_builder, 'DistributedAdamBuilder') + self._loaded_module = builder().load() + except Exception as e: + self._loading = False + raise ImportError(f"Failed to load distributed_adam_cuda : {e}") + finally: + self._loading = False + return self._loaded_module + + def __getattr__(self, name): + if name.startswith("_"): + raise AttributeError(f"module distributed_adam_cuda has no attribute '{name}'") + return getattr(self._load_module(), name) + + def __dir__(self): + try: + return dir(self._load_module()) + except: + return [] + + def __repr__(self): + return "" + +sys.modules[__name__] = _DistributedAdamCudaModule() \ No newline at end of file diff --git a/compatibility/distributed_lamb_cuda.py b/compatibility/distributed_lamb_cuda.py new file mode 100644 index 000000000..7f0b64f3e --- /dev/null +++ b/compatibility/distributed_lamb_cuda.py @@ -0,0 +1,37 @@ +import sys +import importlib + +class _DistributedLambCudaModule: + def __init__(self): + self._loaded_module = None + self._loading = False + + def _load_module(self): + if self._loaded_module is None and not self._loading: + self._loading = True + try: + apex_op_builder = importlib.import_module('apex.op_builder') + builder = getattr(apex_op_builder, 'DistributedLambBuilder') + self._loaded_module = builder().load() + except Exception as e: + self._loading = False + raise ImportError(f"Failed to load distributed_lamb_cuda : {e}") + finally: + self._loading = False + return self._loaded_module + + def __getattr__(self, name): + if name.startswith("_"): + raise AttributeError(f"module distributed_lamb_cuda has no attribute '{name}'") + return getattr(self._load_module(), name) + + def __dir__(self): + try: + return dir(self._load_module()) + except: + return [] + + def __repr__(self): + return "" + +sys.modules[__name__] = _DistributedLambCudaModule() \ No newline at end of file diff --git a/compatibility/fast_multihead_attn.py b/compatibility/fast_multihead_attn.py new file mode 100644 index 000000000..a9e060b87 --- /dev/null +++ b/compatibility/fast_multihead_attn.py @@ -0,0 +1,37 @@ +import sys +import importlib + +class _FastMultiheadAttnModule: + def __init__(self): + self._loaded_module = None + self._loading = False + + def _load_module(self): + if self._loaded_module is None and not self._loading: + self._loading = True + try: + apex_op_builder = importlib.import_module('apex.op_builder') + builder = getattr(apex_op_builder, 'FastMultiheadAttnBuilder') + self._loaded_module = builder().load() + except Exception as e: + self._loading = False + raise ImportError(f"Failed to load fast_multihead_attn : {e}") + finally: + self._loading = False + return self._loaded_module + + def __getattr__(self, name): + if name.startswith("_"): + raise AttributeError(f"module fast_multihead_attn has no attribute '{name}'") + return getattr(self._load_module(), name) + + def __dir__(self): + try: + return dir(self._load_module()) + except: + return [] + + def __repr__(self): + return "" + +sys.modules[__name__] = _FastMultiheadAttnModule() \ No newline at end of file diff --git a/compatibility/focal_loss_cuda.py b/compatibility/focal_loss_cuda.py new file mode 100644 index 000000000..c7b364faf --- /dev/null +++ b/compatibility/focal_loss_cuda.py @@ -0,0 +1,37 @@ +import sys +import importlib + +class _FocalLossCudaModule: + def __init__(self): + self._loaded_module = None + self._loading = False + + def _load_module(self): + if self._loaded_module is None and not self._loading: + self._loading = True + try: + apex_op_builder = importlib.import_module('apex.op_builder') + builder = getattr(apex_op_builder, 'FocalLossBuilder') + self._loaded_module = builder().load() + except Exception as e: + self._loading = False + raise ImportError(f"Failed to load focal_loss_cuda : {e}") + finally: + self._loading = False + return self._loaded_module + + def __getattr__(self, name): + if name.startswith("_"): + raise AttributeError(f"module focal_loss_cuda has no attribute '{name}'") + return getattr(self._load_module(), name) + + def __dir__(self): + try: + return dir(self._load_module()) + except: + return [] + + def __repr__(self): + return "" + +sys.modules[__name__] = _FocalLossCudaModule() \ No newline at end of file diff --git a/compatibility/fused_adam_cuda.py b/compatibility/fused_adam_cuda.py new file mode 100644 index 000000000..bf31ca739 --- /dev/null +++ b/compatibility/fused_adam_cuda.py @@ -0,0 +1,37 @@ +import sys +import importlib + +class _FusedAdamCudaModule: + def __init__(self): + self._loaded_module = None + self._loading = False + + def _load_module(self): + if self._loaded_module is None and not self._loading: + self._loading = True + try: + apex_op_builder = importlib.import_module('apex.op_builder') + builder = getattr(apex_op_builder, 'FusedAdamBuilder') + self._loaded_module = builder().load() + except Exception as e: + self._loading = False + raise ImportError(f"Failed to load fused_adam_cuda : {e}") + finally: + self._loading = False + return self._loaded_module + + def __getattr__(self, name): + if name.startswith("_"): + raise AttributeError(f"module fused_adam_cuda has no attribute '{name}'") + return getattr(self._load_module(), name) + + def __dir__(self): + try: + return dir(self._load_module()) + except: + return [] + + def __repr__(self): + return "" + +sys.modules[__name__] = _FusedAdamCudaModule() \ No newline at end of file diff --git a/compatibility/fused_bias_swiglu.py b/compatibility/fused_bias_swiglu.py new file mode 100644 index 000000000..e9f066f4a --- /dev/null +++ b/compatibility/fused_bias_swiglu.py @@ -0,0 +1,37 @@ +import sys +import importlib + +class _FusedBiasSwiGLUModule: + def __init__(self): + self._loaded_module = None + self._loading = False + + def _load_module(self): + if self._loaded_module is None and not self._loading: + self._loading = True + try: + apex_op_builder = importlib.import_module('apex.op_builder') + builder = getattr(apex_op_builder, 'FusedBiasSwiGLUBuilder') + self._loaded_module = builder().load() + except Exception as e: + self._loading = False + raise ImportError(f"Failed to load fused_bias_swiglu : {e}") + finally: + self._loading = False + return self._loaded_module + + def __getattr__(self, name): + if name.startswith("_"): + raise AttributeError(f"module fused_bias_swiglu has no attribute '{name}'") + return getattr(self._load_module(), name) + + def __dir__(self): + try: + return dir(self._load_module()) + except: + return [] + + def __repr__(self): + return "" + +sys.modules[__name__] = _FusedBiasSwiGLUModule() \ No newline at end of file diff --git a/compatibility/fused_conv_bias_relu.py b/compatibility/fused_conv_bias_relu.py new file mode 100644 index 000000000..32668b797 --- /dev/null +++ b/compatibility/fused_conv_bias_relu.py @@ -0,0 +1,37 @@ +import sys +import importlib + +class _FusedConvBiasReluModule: + def __init__(self): + self._loaded_module = None + self._loading = False + + def _load_module(self): + if self._loaded_module is None and not self._loading: + self._loading = True + try: + apex_op_builder = importlib.import_module('apex.op_builder') + builder = getattr(apex_op_builder, 'FusedConvBiasReluBuilder') + self._loaded_module = builder().load() + except Exception as e: + self._loading = False + raise ImportError(f"Failed to load fused_conv_bias_relu : {e}") + finally: + self._loading = False + return self._loaded_module + + def __getattr__(self, name): + if name.startswith("_"): + raise AttributeError(f"module fused_conv_bias_relu has no attribute '{name}'") + return getattr(self._load_module(), name) + + def __dir__(self): + try: + return dir(self._load_module()) + except: + return [] + + def __repr__(self): + return "" + +sys.modules[__name__] = _FusedConvBiasReluModule() \ No newline at end of file diff --git a/compatibility/fused_dense_cuda.py b/compatibility/fused_dense_cuda.py new file mode 100644 index 000000000..0d28badb2 --- /dev/null +++ b/compatibility/fused_dense_cuda.py @@ -0,0 +1,37 @@ +import sys +import importlib + +class _FusedDenseCudaModule: + def __init__(self): + self._loaded_module = None + self._loading = False + + def _load_module(self): + if self._loaded_module is None and not self._loading: + self._loading = True + try: + apex_op_builder = importlib.import_module('apex.op_builder') + builder = getattr(apex_op_builder, 'FusedDenseBuilder') + self._loaded_module = builder().load() + except Exception as e: + self._loading = False + raise ImportError(f"Failed to load fused_dense_cuda : {e}") + finally: + self._loading = False + return self._loaded_module + + def __getattr__(self, name): + if name.startswith("_"): + raise AttributeError(f"module fused_dense_cuda has no attribute '{name}'") + return getattr(self._load_module(), name) + + def __dir__(self): + try: + return dir(self._load_module()) + except: + return [] + + def __repr__(self): + return "" + +sys.modules[__name__] = _FusedDenseCudaModule() \ No newline at end of file diff --git a/compatibility/fused_index_mul_2d.py b/compatibility/fused_index_mul_2d.py new file mode 100644 index 000000000..c036877df --- /dev/null +++ b/compatibility/fused_index_mul_2d.py @@ -0,0 +1,37 @@ +import sys +import importlib + +class _FusedIndexMul2dModule: + def __init__(self): + self._loaded_module = None + self._loading = False + + def _load_module(self): + if self._loaded_module is None and not self._loading: + self._loading = True + try: + apex_op_builder = importlib.import_module('apex.op_builder') + builder = getattr(apex_op_builder, 'FusedIndexMul2dBuilder') + self._loaded_module = builder().load() + except Exception as e: + self._loading = False + raise ImportError(f"Failed to load fused_index_mul_2d : {e}") + finally: + self._loading = False + return self._loaded_module + + def __getattr__(self, name): + if name.startswith("_"): + raise AttributeError(f"module fused_index_mul_2d has no attribute '{name}'") + return getattr(self._load_module(), name) + + def __dir__(self): + try: + return dir(self._load_module()) + except: + return [] + + def __repr__(self): + return "" + +sys.modules[__name__] = _FusedIndexMul2dModule() \ No newline at end of file diff --git a/compatibility/fused_lamb_cuda.py b/compatibility/fused_lamb_cuda.py new file mode 100644 index 000000000..3ab88d443 --- /dev/null +++ b/compatibility/fused_lamb_cuda.py @@ -0,0 +1,37 @@ +import sys +import importlib + +class _FusedLambCudaModule: + def __init__(self): + self._loaded_module = None + self._loading = False + + def _load_module(self): + if self._loaded_module is None and not self._loading: + self._loading = True + try: + apex_op_builder = importlib.import_module('apex.op_builder') + builder = getattr(apex_op_builder, 'FusedLambBuilder') + self._loaded_module = builder().load() + except Exception as e: + self._loading = False + raise ImportError(f"Failed to load fused_lamb_cuda : {e}") + finally: + self._loading = False + return self._loaded_module + + def __getattr__(self, name): + if name.startswith("_"): + raise AttributeError(f"module fused_lamb_cuda has no attribute '{name}'") + return getattr(self._load_module(), name) + + def __dir__(self): + try: + return dir(self._load_module()) + except: + return [] + + def __repr__(self): + return "" + +sys.modules[__name__] = _FusedLambCudaModule() \ No newline at end of file diff --git a/compatibility/fused_layer_norm_cuda.py b/compatibility/fused_layer_norm_cuda.py new file mode 100644 index 000000000..2722e0252 --- /dev/null +++ b/compatibility/fused_layer_norm_cuda.py @@ -0,0 +1,44 @@ +import sys +import importlib + +class _FusedLayerCudaModule: + def __init__(self): + self._loaded_module = None + self._loading = False + + def _load_module(self): + if self._loaded_module is None and not self._loading: + self._loading = True + try: + #import the builder + apex_op_builder = importlib.import_module('apex.op_builder') + mlp_builder = getattr(apex_op_builder, 'FusedLayerNormBuilder') + + #load the module + self._loaded_module = mlp_builder().load() + except Exception as e: + self._loading = False + raise ImportError(f"Failed to load fused_layer_norm_cuda : {e}") + finally: + self._loading = False + return self._loaded_module + + def __getattr__(self, name): + if name.startswith("_"): + raise AttributeError(f"module fused_layer_norm_cuda has no attribute '{name}'") + + module = self._load_module() + return getattr(module, name) + + def __dir__(self): + try: + module = self._load_module() + return dir(module) + except: + return [] + + def __repr__(self): + return "" + +#replace module with lazy loader +sys.modules[__name__] = _FusedLayerCudaModule() \ No newline at end of file diff --git a/compatibility/fused_rotary_positional_embedding.py b/compatibility/fused_rotary_positional_embedding.py new file mode 100644 index 000000000..d4f87bd33 --- /dev/null +++ b/compatibility/fused_rotary_positional_embedding.py @@ -0,0 +1,37 @@ +import sys +import importlib + +class _FusedRotaryPositionalEmbeddingModule: + def __init__(self): + self._loaded_module = None + self._loading = False + + def _load_module(self): + if self._loaded_module is None and not self._loading: + self._loading = True + try: + apex_op_builder = importlib.import_module('apex.op_builder') + builder = getattr(apex_op_builder, 'FusedRopeBuilder') + self._loaded_module = builder().load() + except Exception as e: + self._loading = False + raise ImportError(f"Failed to load fused_rotary_positional_embedding : {e}") + finally: + self._loading = False + return self._loaded_module + + def __getattr__(self, name): + if name.startswith("_"): + raise AttributeError(f"module fused_rotary_positional_embedding has no attribute '{name}'") + return getattr(self._load_module(), name) + + def __dir__(self): + try: + return dir(self._load_module()) + except: + return [] + + def __repr__(self): + return "" + +sys.modules[__name__] = _FusedRotaryPositionalEmbeddingModule() \ No newline at end of file diff --git a/compatibility/fused_weight_gradient_mlp_cuda.py b/compatibility/fused_weight_gradient_mlp_cuda.py new file mode 100644 index 000000000..219d9355b --- /dev/null +++ b/compatibility/fused_weight_gradient_mlp_cuda.py @@ -0,0 +1,37 @@ +import sys +import importlib + +class _FusedWeightGradientMlpCudaModule: + def __init__(self): + self._loaded_module = None + self._loading = False + + def _load_module(self): + if self._loaded_module is None and not self._loading: + self._loading = True + try: + apex_op_builder = importlib.import_module('apex.op_builder') + builder = getattr(apex_op_builder, 'FusedWeightGradientMlpCudaBuilder') + self._loaded_module = builder().load() + except Exception as e: + self._loading = False + raise ImportError(f"Failed to load fused_weight_gradient_mlp_cuda : {e}") + finally: + self._loading = False + return self._loaded_module + + def __getattr__(self, name): + if name.startswith("_"): + raise AttributeError(f"module fused_weight_gradient_mlp_cuda has no attribute '{name}'") + return getattr(self._load_module(), name) + + def __dir__(self): + try: + return dir(self._load_module()) + except: + return [] + + def __repr__(self): + return "" + +sys.modules[__name__] = _FusedWeightGradientMlpCudaModule() \ No newline at end of file diff --git a/compatibility/generic_scaled_masked_softmax_cuda.py b/compatibility/generic_scaled_masked_softmax_cuda.py new file mode 100644 index 000000000..fa50ca52c --- /dev/null +++ b/compatibility/generic_scaled_masked_softmax_cuda.py @@ -0,0 +1,37 @@ +import sys +import importlib + +class _GenericScaledMaskedSoftmaxCudaModule: + def __init__(self): + self._loaded_module = None + self._loading = False + + def _load_module(self): + if self._loaded_module is None and not self._loading: + self._loading = True + try: + apex_op_builder = importlib.import_module('apex.op_builder') + builder = getattr(apex_op_builder, 'GenericScaledMaskedSoftmaxCudaBuilder') + self._loaded_module = builder().load() + except Exception as e: + self._loading = False + raise ImportError(f"Failed to load generic_scaled_masked_softmax_cuda : {e}") + finally: + self._loading = False + return self._loaded_module + + def __getattr__(self, name): + if name.startswith("_"): + raise AttributeError(f"module generic_scaled_masked_softmax_cuda has no attribute '{name}'") + return getattr(self._load_module(), name) + + def __dir__(self): + try: + return dir(self._load_module()) + except: + return [] + + def __repr__(self): + return "" + +sys.modules[__name__] = _GenericScaledMaskedSoftmaxCudaModule() \ No newline at end of file diff --git a/compatibility/mlp_cuda.py b/compatibility/mlp_cuda.py new file mode 100644 index 000000000..4c873d560 --- /dev/null +++ b/compatibility/mlp_cuda.py @@ -0,0 +1,44 @@ +import sys +import importlib + +class _MLPCudaModule: + def __init__(self): + self._loaded_module = None + self._loading = False + + def _load_module(self): + if self._loaded_module is None and not self._loading: + self._loading = True + try: + #import the builder + apex_op_builder = importlib.import_module('apex.op_builder') + mlp_builder = getattr(apex_op_builder, 'MlpBuilder') + + #load the module + self._loaded_module = mlp_builder().load() + except Exception as e: + self._loading = False + raise ImportError(f"Failed to load mlp_cuda : {e}") + finally: + self._loading = False + return self._loaded_module + + def __getattr__(self, name): + if name.startswith("_"): + raise AttributeError(f"module mlp_cuda has no attribute '{name}'") + + module = self._load_module() + return getattr(module, name) + + def __dir__(self): + try: + module = self._load_module() + return dir(module) + except: + return [] + + def __repr__(self): + return "" + +#replace module with lazy loader +sys.modules[__name__] = _MLPCudaModule() \ No newline at end of file diff --git a/compatibility/nccl_p2p_cuda.py b/compatibility/nccl_p2p_cuda.py new file mode 100644 index 000000000..d937cb95e --- /dev/null +++ b/compatibility/nccl_p2p_cuda.py @@ -0,0 +1,37 @@ +import sys +import importlib + +class _NcclP2pCudaModule: + def __init__(self): + self._loaded_module = None + self._loading = False + + def _load_module(self): + if self._loaded_module is None and not self._loading: + self._loading = True + try: + apex_op_builder = importlib.import_module('apex.op_builder') + builder = getattr(apex_op_builder, 'NCCLP2PBuilder') + self._loaded_module = builder().load() + except Exception as e: + self._loading = False + raise ImportError(f"Failed to load nccl_p2p_cuda : {e}") + finally: + self._loading = False + return self._loaded_module + + def __getattr__(self, name): + if name.startswith("_"): + raise AttributeError(f"module nccl_p2p_cuda has no attribute '{name}'") + return getattr(self._load_module(), name) + + def __dir__(self): + try: + return dir(self._load_module()) + except: + return [] + + def __repr__(self): + return "" + +sys.modules[__name__] = _NcclP2pCudaModule() \ No newline at end of file diff --git a/compatibility/peer_memory_cuda.py b/compatibility/peer_memory_cuda.py new file mode 100644 index 000000000..d909ec1b9 --- /dev/null +++ b/compatibility/peer_memory_cuda.py @@ -0,0 +1,37 @@ +import sys +import importlib + +class _PeerMemoryCudaModule: + def __init__(self): + self._loaded_module = None + self._loading = False + + def _load_module(self): + if self._loaded_module is None and not self._loading: + self._loading = True + try: + apex_op_builder = importlib.import_module('apex.op_builder') + builder = getattr(apex_op_builder, 'PeerMemoryBuilder') + self._loaded_module = builder().load() + except Exception as e: + self._loading = False + raise ImportError(f"Failed to load peer_memory_cuda : {e}") + finally: + self._loading = False + return self._loaded_module + + def __getattr__(self, name): + if name.startswith("_"): + raise AttributeError(f"module peer_memory_cuda has no attribute '{name}'") + return getattr(self._load_module(), name) + + def __dir__(self): + try: + return dir(self._load_module()) + except: + return [] + + def __repr__(self): + return "" + +sys.modules[__name__] = _PeerMemoryCudaModule() \ No newline at end of file diff --git a/compatibility/scaled_masked_softmax_cuda.py b/compatibility/scaled_masked_softmax_cuda.py new file mode 100644 index 000000000..77ed74e47 --- /dev/null +++ b/compatibility/scaled_masked_softmax_cuda.py @@ -0,0 +1,37 @@ +import sys +import importlib + +class _ScaledMaskedSoftmaxCudaModule: + def __init__(self): + self._loaded_module = None + self._loading = False + + def _load_module(self): + if self._loaded_module is None and not self._loading: + self._loading = True + try: + apex_op_builder = importlib.import_module('apex.op_builder') + builder = getattr(apex_op_builder, 'ScaledMaskedSoftmaxCudaBuilder') + self._loaded_module = builder().load() + except Exception as e: + self._loading = False + raise ImportError(f"Failed to load scaled_masked_softmax_cuda : {e}") + finally: + self._loading = False + return self._loaded_module + + def __getattr__(self, name): + if name.startswith("_"): + raise AttributeError(f"module scaled_masked_softmax_cuda has no attribute '{name}'") + return getattr(self._load_module(), name) + + def __dir__(self): + try: + return dir(self._load_module()) + except: + return [] + + def __repr__(self): + return "" + +sys.modules[__name__] = _ScaledMaskedSoftmaxCudaModule() \ No newline at end of file diff --git a/compatibility/scaled_softmax_cuda.py b/compatibility/scaled_softmax_cuda.py new file mode 100644 index 000000000..d7a4427e3 --- /dev/null +++ b/compatibility/scaled_softmax_cuda.py @@ -0,0 +1,37 @@ +import sys +import importlib + +class _ScaledSoftmaxCudaModule: + def __init__(self): + self._loaded_module = None + self._loading = False + + def _load_module(self): + if self._loaded_module is None and not self._loading: + self._loading = True + try: + apex_op_builder = importlib.import_module('apex.op_builder') + builder = getattr(apex_op_builder, 'ScaledSoftmaxCudaBuilder') + self._loaded_module = builder().load() + except Exception as e: + self._loading = False + raise ImportError(f"Failed to load scaled_softmax_cuda : {e}") + finally: + self._loading = False + return self._loaded_module + + def __getattr__(self, name): + if name.startswith("_"): + raise AttributeError(f"module scaled_softmax_cuda has no attribute '{name}'") + return getattr(self._load_module(), name) + + def __dir__(self): + try: + return dir(self._load_module()) + except: + return [] + + def __repr__(self): + return "" + +sys.modules[__name__] = _ScaledSoftmaxCudaModule() \ No newline at end of file diff --git a/compatibility/scaled_upper_triang_masked_softmax_cuda.py b/compatibility/scaled_upper_triang_masked_softmax_cuda.py new file mode 100644 index 000000000..8da9b5c67 --- /dev/null +++ b/compatibility/scaled_upper_triang_masked_softmax_cuda.py @@ -0,0 +1,38 @@ +import sys +import importlib + +class _ScaledUpperTriangMaskedSoftmaxCudaModule: + def __init__(self): + self._loaded_module = None + self._loading = False + + def _load_module(self): + if self._loaded_module is None and not self._loading: + self._loading = True + try: + apex_op_builder = importlib.import_module('apex.op_builder') + name = 'ScaledUpperTriangMaskedSoftmaxCudaBuilder' + builder = getattr(apex_op_builder, name) + self._loaded_module = builder().load() + except Exception as e: + self._loading = False + raise ImportError(f"Failed to load scaled_upper_triang_masked_softmax_cuda : {e}") + finally: + self._loading = False + return self._loaded_module + + def __getattr__(self, name_attr): + if name_attr.startswith("_"): + raise AttributeError(f"module scaled_upper_triang_masked_softmax_cuda has no attribute '{name_attr}'") + return getattr(self._load_module(), name_attr) + + def __dir__(self): + try: + return dir(self._load_module()) + except: + return [] + + def __repr__(self): + return "" + +sys.modules[__name__] = _ScaledUpperTriangMaskedSoftmaxCudaModule() \ No newline at end of file diff --git a/compatibility/syncbn.py b/compatibility/syncbn.py new file mode 100644 index 000000000..b619575dc --- /dev/null +++ b/compatibility/syncbn.py @@ -0,0 +1,37 @@ +import sys +import importlib + +class _SyncbnModule: + def __init__(self): + self._loaded_module = None + self._loading = False + + def _load_module(self): + if self._loaded_module is None and not self._loading: + self._loading = True + try: + apex_op_builder = importlib.import_module('apex.op_builder') + builder = getattr(apex_op_builder, 'SyncBnBuilder') + self._loaded_module = builder().load() + except Exception as e: + self._loading = False + raise ImportError(f"Failed to load syncbn : {e}") + finally: + self._loading = False + return self._loaded_module + + def __getattr__(self, name): + if name.startswith("_"): + raise AttributeError(f"module syncbn has no attribute '{name}'") + return getattr(self._load_module(), name) + + def __dir__(self): + try: + return dir(self._load_module()) + except: + return [] + + def __repr__(self): + return "" + +sys.modules[__name__] = _SyncbnModule() \ No newline at end of file diff --git a/compatibility/transducer_joint_cuda.py b/compatibility/transducer_joint_cuda.py new file mode 100644 index 000000000..e06705fde --- /dev/null +++ b/compatibility/transducer_joint_cuda.py @@ -0,0 +1,37 @@ +import sys +import importlib + +class _TransducerJointCudaModule: + def __init__(self): + self._loaded_module = None + self._loading = False + + def _load_module(self): + if self._loaded_module is None and not self._loading: + self._loading = True + try: + apex_op_builder = importlib.import_module('apex.op_builder') + builder = getattr(apex_op_builder, 'TransducerJointBuilder') + self._loaded_module = builder().load() + except Exception as e: + self._loading = False + raise ImportError(f"Failed to load transducer_joint_cuda : {e}") + finally: + self._loading = False + return self._loaded_module + + def __getattr__(self, name): + if name.startswith("_"): + raise AttributeError(f"module transducer_joint_cuda has no attribute '{name}'") + return getattr(self._load_module(), name) + + def __dir__(self): + try: + return dir(self._load_module()) + except: + return [] + + def __repr__(self): + return "" + +sys.modules[__name__] = _TransducerJointCudaModule() \ No newline at end of file diff --git a/compatibility/transducer_loss_cuda.py b/compatibility/transducer_loss_cuda.py new file mode 100644 index 000000000..d5a2c0f36 --- /dev/null +++ b/compatibility/transducer_loss_cuda.py @@ -0,0 +1,37 @@ +import sys +import importlib + +class _TransducerLossCudaModule: + def __init__(self): + self._loaded_module = None + self._loading = False + + def _load_module(self): + if self._loaded_module is None and not self._loading: + self._loading = True + try: + apex_op_builder = importlib.import_module('apex.op_builder') + builder = getattr(apex_op_builder, 'TransducerLossBuilder') + self._loaded_module = builder().load() + except Exception as e: + self._loading = False + raise ImportError(f"Failed to load transducer_loss_cuda : {e}") + finally: + self._loading = False + return self._loaded_module + + def __getattr__(self, name): + if name.startswith("_"): + raise AttributeError(f"module transducer_loss_cuda has no attribute '{name}'") + return getattr(self._load_module(), name) + + def __dir__(self): + try: + return dir(self._load_module()) + except: + return [] + + def __repr__(self): + return "" + +sys.modules[__name__] = _TransducerLossCudaModule() \ No newline at end of file diff --git a/compatibility/xentropy_cuda.py b/compatibility/xentropy_cuda.py new file mode 100644 index 000000000..ff4dc9733 --- /dev/null +++ b/compatibility/xentropy_cuda.py @@ -0,0 +1,37 @@ +import sys +import importlib + +class _XentropyCudaModule: + def __init__(self): + self._loaded_module = None + self._loading = False + + def _load_module(self): + if self._loaded_module is None and not self._loading: + self._loading = True + try: + apex_op_builder = importlib.import_module('apex.op_builder') + builder = getattr(apex_op_builder, 'XentropyBuilder') + self._loaded_module = builder().load() + except Exception as e: + self._loading = False + raise ImportError(f"Failed to load xentropy_cuda : {e}") + finally: + self._loading = False + return self._loaded_module + + def __getattr__(self, name): + if name.startswith("_"): + raise AttributeError(f"module xentropy_cuda has no attribute '{name}'") + return getattr(self._load_module(), name) + + def __dir__(self): + try: + return dir(self._load_module()) + except: + return [] + + def __repr__(self): + return "" + +sys.modules[__name__] = _XentropyCudaModule() \ No newline at end of file diff --git a/contrib/csrc b/contrib/csrc new file mode 120000 index 000000000..4e941d8b2 --- /dev/null +++ b/contrib/csrc @@ -0,0 +1 @@ +../apex/contrib/csrc \ No newline at end of file diff --git a/csrc/amp_C_frontend.cpp b/csrc/amp_C_frontend.cpp index 36a88aa6e..d9da549b1 100644 --- a/csrc/amp_C_frontend.cpp +++ b/csrc/amp_C_frontend.cpp @@ -81,6 +81,33 @@ void multi_tensor_adam_cuda( const int bias_correction, const float weight_decay); +void multi_tensor_adam_capturable_cuda( + int chunk_size, + at::Tensor noop_flag, + std::vector> tensor_lists, + at::Tensor lr, + const float beta1, + const float beta2, + const float epsilon, + at::Tensor step, + const int mode, + const int bias_correction, + const float weight_decay, + at::Tensor inv_scale); + +void multi_tensor_adam_capturable_master_cuda( + int chunk_size, + at::Tensor noop_flag, + std::vector> tensor_lists, + at::Tensor lr, + const float beta1, + const float beta2, + const float epsilon, + at::Tensor step, + const int mode, + const int bias_correction, + const float weight_decay, + at::Tensor inv_scale); void multi_tensor_adagrad_cuda( int chunk_size, @@ -144,6 +171,24 @@ void multi_tensor_lamb_mp_cuda( at::Tensor found_inf, at::Tensor inv_scale); +void multi_tensor_lars_cuda( + int chunk_size, + at::Tensor noop_flag, + std::vector> tensor_lists, + at::Tensor grad_norms, + at::Tensor param_norms, + float lr, + float trust_coefficient, + float epsilon, + float weight_decay, + float momentum, + float dampening, + bool nesterov, + bool first_run, + bool wd_after_momentum, + float scale, + const bool is_skipped); + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("multi_tensor_scale", &multi_tensor_scale_cuda, "Fused overflow check + scale for a list of contiguous tensors"); @@ -162,7 +207,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("multi_tensor_lamb_stage2_cuda", &multi_tensor_lamb_stage2_cuda, "Completes application of gradient to parameters for LAMB optimizer"); m.def("multi_tensor_adam", &multi_tensor_adam_cuda, - "Compute and apply gradient update to parameters for Adam optimizer"); + "Compute and apply gradient update to parameters for Adam optimizer", py::call_guard()); + m.def("multi_tensor_adam_capturable", &multi_tensor_adam_capturable_cuda, + "Compute and apply gradient update to parameters for Adam optimizer with CUDA graph support and LR scheduling", + py::call_guard()); + m.def("multi_tensor_adam_capturable_master", &multi_tensor_adam_capturable_master_cuda, + "Compute and apply gradient update to parameters for Adam optimizer with CUDA graph support, LR scheduling and FP32 master weights", + py::call_guard()); m.def("multi_tensor_adagrad", &multi_tensor_adagrad_cuda, "Compute and apply gradient update to parameters for Adam optimizer"); m.def("multi_tensor_novograd", &multi_tensor_novograd_cuda, @@ -171,4 +222,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Computes and apply update for LAMB optimizer"); m.def("multi_tensor_lamb_mp", &multi_tensor_lamb_mp_cuda, "Computes and apply update for LAMB optimizer"); + m.def("multi_tensor_lars", &multi_tensor_lars_cuda, + "Fused LARS optimizer for list of contiguous tensors"); } diff --git a/csrc/fused_dense.cpp b/csrc/fused_dense.cpp deleted file mode 100644 index db5bd0d59..000000000 --- a/csrc/fused_dense.cpp +++ /dev/null @@ -1,192 +0,0 @@ -#include -#include -#include - -#include - - -template -int linear_bias_forward_cuda(at::Tensor input, T *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace); - -template -int linear_bias_backward_cuda(T *input, T *weight, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, T *d_input, void *lt_workspace); - -template -int linear_gelu_linear_forward_cuda(T *input, T *weight1, T *bias1, T *weight2, T *bias2, int in_features, int hidden_features, int batch_size, int out_features, T *output1, T *output2, T *gelu_in, void *lt_workspace) ; - -template -int linear_gelu_linear_backward_cuda(T *input, T *gelu_in, T *output1, T *weight1, T *weight2, T *d_output1, T *d_output2, int in_features, int batch_size, int hidden_features, int out_features, T *d_weight1, T *d_weight2, T *d_bias1, T *d_bias2, T *d_input, void *lt_workspace); - -at::Tensor linear_bias_forward(at::Tensor input, at::Tensor weight, at::Tensor bias) { - - auto batch_size = input.size(0); - auto in_features = input.size(1); - - int out_features = weight.size(0); - - //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data()); - - // create output/workspace tensor - auto out = at::empty({batch_size, out_features}, input.type()); - //auto reserved_space = at::empty({reserved_size}, inputs[0].type()); - // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB - auto lt_workspace = at::empty({1 << 22}, input.type()); - - AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "linear_bias_forward", [&] { - scalar_t* w_ptr = weight.data_ptr(); - scalar_t* b_ptr = bias.data_ptr(); - auto result = linear_bias_forward_cuda( - input, - w_ptr, - bias, - in_features, - batch_size, - out_features, - out, - //out.data_ptr(), - // reserved_space.data_ptr(), - (void*) (lt_workspace.data_ptr())); - }); - - return {out}; -} - -std::vector linear_bias_backward(at::Tensor input, at::Tensor weight, at::Tensor d_output) { - - auto batch_size = input.size(0); - auto in_features = input.size(1); - - int out_features = weight.size(0); - - //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data()); - - // create output/workspace tensor - auto d_weight = at::empty({out_features, in_features}, input.type()); -#if defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600 - auto d_bias = d_output.view({-1, out_features}).sum(0, false); -#else - auto d_bias = at::empty({out_features}, input.type()); -#endif - auto d_input = at::empty({batch_size, in_features}, input.type()); - //auto reserved_space = at::empty({reserved_size}, inputs[0].type()); - // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB - auto lt_workspace = at::empty({1 << 22}, input.type()); - - AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "linear_bias_backward", [&] { - scalar_t* w_ptr = weight.data_ptr(); - scalar_t* d_b_ptr = d_bias.data_ptr(); - auto result = linear_bias_backward_cuda( - input.data_ptr(), - w_ptr, - d_output.data_ptr(), - in_features, - batch_size, - out_features, - d_weight.data_ptr(), - d_bias.data_ptr(), - d_input.data_ptr(), - // reserved_space.data_ptr(), - (void*) (lt_workspace.data_ptr())); - }); - - return {d_input, d_weight, d_bias}; -} - -std::vector linear_gelu_linear_forward(at::Tensor input, at::Tensor weight1, at::Tensor bias1, at::Tensor weight2, at::Tensor bias2) { - - auto batch_size = input.size(0); - auto in_features = input.size(1); - - int hidden_features = weight1.size(0); - int out_features = weight2.size(0); - - //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data()); - - // create output/workspace tensor - auto output1 = at::empty({batch_size, hidden_features}, input.type()); - auto gelu_in = at::empty({batch_size, hidden_features}, input.type()); - auto output2 = at::empty({batch_size, out_features}, input.type()); - //auto reserved_space = at::empty({reserved_size}, inputs[0].type()); - // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB - auto lt_workspace = at::empty({1 << 22}, input.type()); - - AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "linear_gelu_linear_forward", [&] { - scalar_t* w1_ptr = weight1.data_ptr(); - scalar_t* b1_ptr = bias1.data_ptr(); - scalar_t* w2_ptr = weight2.data_ptr(); - scalar_t* b2_ptr = bias2.data_ptr(); - auto result = linear_gelu_linear_forward_cuda( - input.data_ptr(), - w1_ptr, - b1_ptr, - w2_ptr, - b2_ptr, - in_features, - hidden_features, - batch_size, - out_features, - output1.data_ptr(), - output2.data_ptr(), - gelu_in.data_ptr(), - // reserved_space.data_ptr(), - (void*) (lt_workspace.data_ptr())); - }); - - return {output1, output2, gelu_in}; -} - -std::vector linear_gelu_linear_backward(at::Tensor input, at::Tensor gelu_in, at::Tensor output1, at::Tensor weight1, at::Tensor weight2, at::Tensor d_output2) { - - auto batch_size = input.size(0); - auto in_features = input.size(1); - - int hidden_features = weight1.size(0); - int out_features = weight2.size(0); - - //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data()); - - // create output/workspace tensor - auto d_weight1 = at::empty({hidden_features, in_features}, input.type()); - auto d_weight2 = at::empty({out_features, hidden_features}, input.type()); - auto d_bias1 = at::empty({hidden_features}, input.type()); - auto d_bias2 = at::empty({out_features}, input.type()); - auto d_input = at::empty({batch_size, in_features}, input.type()); - auto d_output1 = at::empty({batch_size, hidden_features}, input.type()); - //auto reserved_space = at::empty({reserved_size}, inputs[0].type()); - // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB - auto lt_workspace = at::empty({1 << 22}, input.type()); - - AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "linear_bias_backward", [&] { - //scalar_t* w_ptr = weight.data_ptr(); - //scalar_t* d_b_ptr = d_bias.data_ptr(); - auto result = linear_gelu_linear_backward_cuda( - input.data_ptr(), - gelu_in.data_ptr(), - output1.data_ptr(), - weight1.data_ptr(), - weight2.data_ptr(), - d_output1.data_ptr(), - d_output2.data_ptr(), - in_features, - batch_size, - hidden_features, - out_features, - d_weight1.data_ptr(), - d_weight2.data_ptr(), - d_bias1.data_ptr(), - d_bias2.data_ptr(), - d_input.data_ptr(), - // reserved_space.data_ptr(), - (void*) (lt_workspace.data_ptr())); - }); - - return {d_input, d_weight1, d_bias1, d_weight2, d_bias2}; -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("linear_bias_forward", &linear_bias_forward, "linear bias forward"); - m.def("linear_bias_backward", &linear_bias_backward, "linear bias backward"); - m.def("linear_gelu_linear_forward", &linear_gelu_linear_forward, "linear gelu linear forward"); - m.def("linear_gelu_linear_backward", &linear_gelu_linear_backward, "linear gelu linear backward"); -} - diff --git a/csrc/fused_dense_base.cpp b/csrc/fused_dense_base.cpp new file mode 100644 index 000000000..0e62768b0 --- /dev/null +++ b/csrc/fused_dense_base.cpp @@ -0,0 +1,21 @@ +#include +#include +#include +#include +#include + +at::Tensor linear_bias_forward( at::Tensor input, at::Tensor weight, at::Tensor bias); + +std::vector linear_bias_backward( at::Tensor input, at::Tensor weight, at::Tensor d_output); + +std::vector linear_gelu_linear_forward( at::Tensor input, at::Tensor weight1, at::Tensor bias1, at::Tensor weight2, at::Tensor bias2); + +std::vector linear_gelu_linear_backward( at::Tensor input, at::Tensor gelu_in, at::Tensor output1, at::Tensor weight1, at::Tensor weight2, at::Tensor d_output2); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("linear_bias_forward", &linear_bias_forward, "linear bias forward"); + m.def("linear_bias_backward", &linear_bias_backward, "linear bias backward"); + m.def("linear_gelu_linear_forward", &linear_gelu_linear_forward, "linear gelu linear forward"); + m.def("linear_gelu_linear_backward", &linear_gelu_linear_backward, "linear gelu linear backward"); +} + diff --git a/csrc/fused_dense_cuda.cu b/csrc/fused_dense_cuda.cu index c12d264a1..15c076f68 100644 --- a/csrc/fused_dense_cuda.cu +++ b/csrc/fused_dense_cuda.cu @@ -1,1437 +1,594 @@ -#include -#include + #include #include #include #include -#include +#include -/* Includes, cuda */ -#include -#include +#include +#include +#include -#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000 -// includes cublaslt -#include +#include +#include +#include +#include +#include + +#define DEBUG 0 + +#include "type_shim.h" + +#ifndef CHECK_HIP_ERROR +#define CHECK_HIP_ERROR(error) \ + if (error != hipSuccess) \ + { \ + fprintf(stderr, \ + "Hip error: '%s'(%d) at %s:%d\n", \ + hipGetErrorString(error), \ + error, \ + __FILE__, \ + __LINE__); \ + exit(EXIT_FAILURE); \ + } #endif -// FP64 Wrapper around cublas GEMMEx -cublasStatus_t gemm_bias( - cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float* alpha, - double* A, - int lda, - double* B, - int ldb, - const float* beta, - double* C, - int ldc) { - return cublasGemmEx( - handle, - transa, - transb, - m, - n, - k, - alpha, - A, - CUDA_R_64F, - lda, - B, - CUDA_R_64F, - ldb, - beta, - C, - CUDA_R_64F, - ldc, - CUDA_R_64F, - CUBLAS_GEMM_DEFAULT); -} - -// FP32 Wrapper around cublas GEMMEx -cublasStatus_t gemm_bias( - cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float* alpha, - float* A, - int lda, - float* B, - int ldb, - const float* beta, - float* C, - int ldc) { - return cublasGemmEx( - handle, - transa, - transb, - m, - n, - k, - alpha, - A, - CUDA_R_32F, - lda, - B, - CUDA_R_32F, - ldb, - beta, - C, - CUDA_R_32F, - ldc, - CUDA_R_32F, - CUBLAS_GEMM_DEFAULT); -} - -// FP16 Tensor core wrapper around cublas GEMMEx -cublasStatus_t gemm_bias( - cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float* alpha, - at::Half* A, - int lda, - at::Half* B, - int ldb, - const float* beta, - at::Half* C, - int ldc) { - return cublasGemmEx( - handle, - transa, - transb, - m, - n, - k, - alpha, - A, - CUDA_R_16F, - lda, - B, - CUDA_R_16F, - ldb, - beta, - C, - CUDA_R_16F, - ldc, - CUDA_R_32F, - CUBLAS_GEMM_DEFAULT_TENSOR_OP); -} - - -#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600 - - -int gemm_bias_lt( - cublasLtHandle_t ltHandle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float *alpha, /* host pointer */ - at::Half* A, - int lda, - at::Half* B, - int ldb, - const float *beta, /* host pointer */ - at::Half* C, - int ldc, - void *workspace, - size_t workspaceSize, - cudaStream_t stream, - bool use_bias, - const void* bias) { - cublasStatus_t status = CUBLAS_STATUS_SUCCESS; - - cublasLtMatmulDescOpaque_t operationDesc = {}; - cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {}; - cublasLtMatmulPreferenceOpaque_t preference = {}; - - int returnedResults = 0; - cublasLtMatmulHeuristicResult_t heuristicResult = {}; - cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT; - - // Create operation descriptor; see cublasLtMatmulDescAttributes_t - // for details about defaults; here we just set the transforms for - // A and B. - status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - if (use_bias) { - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias)); - if (status != CUBLAS_STATUS_SUCCESS) { - goto CLEANUP; - } - epilogue = CUBLASLT_EPILOGUE_BIAS; - } - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)); - if (status != CUBLAS_STATUS_SUCCESS) { - goto CLEANUP; +#ifndef CHECK_HIPBLASLT_ERROR +#define CHECK_HIPBLASLT_ERROR(error) \ + if (error != HIPBLAS_STATUS_SUCCESS) \ + { \ + fprintf(stderr, "hipBLASLt error(Err=%d) at %s:%d\n", error, __FILE__, __LINE__); \ + fprintf(stderr, "\n"); \ + exit(EXIT_FAILURE); \ } +#endif - // Create matrix descriptors. Not setting any extra attributes. - status = cublasLtMatrixLayoutInit( - &Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatrixLayoutInit( - &Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - // Create preference handle; In general, extra attributes can be - // used here to disable tensor ops or to make sure algo selected - // will work with badly aligned A, B, C. However, for simplicity - // here we assume A,B,C are always well aligned (e.g., directly - // come from cudaMalloc) - status = cublasLtMatmulPreferenceInit(&preference); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulPreferenceSetAttribute( - &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - // We just need the best available heuristic to try and run matmul. - // There is no guarantee that this will work. For example, if A is - // badly aligned, you can request more (e.g. 32) algos and try to - // run them one by one until something works. - status = cublasLtMatmulAlgoGetHeuristic( - ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - if (returnedResults == 0) { - status = CUBLAS_STATUS_NOT_SUPPORTED; - goto CLEANUP; +#define DISPATCH_TYPES(TYPE, NAME, ...) \ + switch (TYPE) \ + { \ + case at::ScalarType::Half: \ + { \ + constexpr auto compute_t = CUBLAS_COMPUTE_32F; \ + constexpr auto compute_datatype_t = CUDA_R_32F; \ + constexpr auto datatype_t = CUDA_R_16F; \ + using scalar_t = at::Half; \ + __VA_ARGS__(); \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + constexpr auto compute_t = CUBLAS_COMPUTE_32F; \ + constexpr auto compute_datatype_t = CUDA_R_32F; \ + constexpr auto datatype_t = CUDA_R_16BF; \ + using scalar_t = at::BFloat16; \ + __VA_ARGS__(); \ + break; \ + } \ + case at::ScalarType::Float: \ + { \ + constexpr auto compute_t = CUBLAS_COMPUTE_32F; \ + constexpr auto compute_datatype_t = CUDA_R_32F; \ + constexpr auto datatype_t = CUDA_R_32F; \ + using scalar_t = float; \ + __VA_ARGS__(); \ + break; \ + } \ + case at::ScalarType::Double: \ + { \ + constexpr auto compute_t = CUBLAS_COMPUTE_64F; \ + constexpr auto compute_datatype_t = CUDA_R_64F; \ + constexpr auto datatype_t = CUDA_R_64F; \ + using scalar_t = double; \ + __VA_ARGS__(); \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented type "); \ } - status = cublasLtMatmul(ltHandle, - &operationDesc, - alpha, - A, - &Adesc, - B, - &Bdesc, - beta, - C, - &Cdesc, - C, - &Cdesc, - //&heuristicResult.algo, - NULL, - workspace, - workspaceSize, - stream); - -CLEANUP: - // Descriptors are no longer needed as all GPU work was already - // enqueued. - return status == CUBLAS_STATUS_SUCCESS ? 0 : 1; -} - - - +hipDataType get_dtype(at::Tensor A) +{ + hipDataType dataType; - -int gemm_bias_lt( - cublasLtHandle_t ltHandle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float *alpha, /* host pointer */ - double* A, - int lda, - double* B, - int ldb, - const float *beta, /* host pointer */ - double* C, - int ldc, - void *workspace, - size_t workspaceSize, - cudaStream_t stream, - bool use_bias, - const void* bias) { - return 1; -} - -int gemm_bias_lt( - cublasLtHandle_t ltHandle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float *alpha, /* host pointer */ - float *A, - int lda, - float *B, - int ldb, - const float *beta, /* host pointer */ - float *C, - int ldc, - void *workspace, - size_t workspaceSize, - cudaStream_t stream, - bool use_bias, - const void* bias) { - cublasStatus_t status = CUBLAS_STATUS_SUCCESS; - - cublasLtMatmulDescOpaque_t operationDesc = {}; - cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {}; - cublasLtMatmulPreferenceOpaque_t preference = {}; - - int returnedResults = 0; - cublasLtMatmulHeuristicResult_t heuristicResult = {}; - cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT; - - // Create operation descriptor; see cublasLtMatmulDescAttributes_t - // for details about defaults; here we just set the transforms for - // A and B. - status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - if (use_bias) { - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias)); - if (status != CUBLAS_STATUS_SUCCESS) { - goto CLEANUP; - } - epilogue = CUBLASLT_EPILOGUE_BIAS; + if (A.scalar_type() == at::ScalarType::BFloat16) + { + dataType = HIP_R_16F; } - - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)); - if (status != CUBLAS_STATUS_SUCCESS) { - goto CLEANUP; + if (A.scalar_type() == at::ScalarType::Half) + { + dataType = HIP_R_16F; } - - // Create matrix descriptors. Not setting any extra attributes. - status = cublasLtMatrixLayoutInit( - &Adesc, CUDA_R_32F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatrixLayoutInit( - &Bdesc, CUDA_R_32F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_32F, m, n, ldc); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - // Create preference handle; In general, extra attributes can be - // used here to disable tensor ops or to make sure algo selected - // will work with badly aligned A, B, C. However, for simplicity - // here we assume A,B,C are always well aligned (e.g., directly - // come from cudaMalloc) - status = cublasLtMatmulPreferenceInit(&preference); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulPreferenceSetAttribute( - &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - // We just need the best available heuristic to try and run matmul. - // There is no guarantee that this will work. For example, if A is - // badly aligned, you can request more (e.g. 32) algos and try to - // run them one by one until something works. - status = cublasLtMatmulAlgoGetHeuristic( - ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - if (returnedResults == 0) { - status = CUBLAS_STATUS_NOT_SUPPORTED; - goto CLEANUP; + if (A.scalar_type() == at::ScalarType::Float) + { + dataType = HIP_R_32F; + } + if (A.scalar_type() == at::ScalarType::Double) + { + dataType = HIP_R_64F; + } + // The E4M3 is mainly used for the weights, and the E5M2 is for the gradient. + if (A.scalar_type() == at::ScalarType::Float8_e5m2fnuz) + { + dataType = HIP_R_8F_E5M2_FNUZ; + } + if (A.scalar_type() == at::ScalarType::Float8_e4m3fnuz) + { + dataType = HIP_R_8F_E4M3_FNUZ; } - status = cublasLtMatmul(ltHandle, - &operationDesc, - alpha, - A, - &Adesc, - B, - &Bdesc, - beta, - C, - &Cdesc, - C, - &Cdesc, - &heuristicResult.algo, - workspace, - workspaceSize, - stream); - -CLEANUP: - // Descriptors are no longer needed as all GPU work was already - // enqueued. - return status == CUBLAS_STATUS_SUCCESS ? 0 : 1; + return dataType; } - -int gemm_bias_gelu_lt( - cublasLtHandle_t ltHandle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float *alpha, /* host pointer */ - at::Half* A, - int lda, - at::Half* B, - int ldb, - const float *beta, /* host pointer */ - at::Half* C, - int64_t ldc, - void *workspace, - size_t workspaceSize, - cudaStream_t stream, +/******************************************************************************************************************************************************** + * + * D = Epilogue{ (alpha_s * (A * B) + beta_s * C) + bias_v } * scaleD_v + * + ******************************************************************************************************************************************************/ +int gemm_lt( + hipblasOperation_t trans_a, + hipblasOperation_t trans_b, + const float *alpha, + const float *beta, + at::Tensor A, + at::Tensor B, + at::Tensor C, + at::Tensor bias, + at::Tensor gelu, bool use_bias, - const void* gelu_in, - const void* bias) { - cublasStatus_t status = CUBLAS_STATUS_SUCCESS; - - cublasLtMatmulDescOpaque_t operationDesc = {}; - cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {}; - cublasLtMatmulPreferenceOpaque_t preference = {}; - - int returnedResults = 0; - cublasLtMatmulHeuristicResult_t heuristicResult = {}; - cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_GELU_AUX; - - // Create operation descriptor; see cublasLtMatmulDescAttributes_t - // for details about defaults; here we just set the transforms for - // A and B. - status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &gelu_in, sizeof(gelu_in)); - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc)); + bool use_grad, + bool use_gelu) +{ - if (use_bias) { - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias)); - if (status != CUBLAS_STATUS_SUCCESS) { - goto CLEANUP; - } - epilogue = CUBLASLT_EPILOGUE_GELU_AUX_BIAS; - } + hipStream_t stream; + hipblasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); + hipblasGetStream(handle, &stream); - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)); - if (status != CUBLAS_STATUS_SUCCESS) { - goto CLEANUP; +#if DEBUG + std::cout << "gemm_lt " << std::endl; +#endif + if ((trans_a == HIPBLAS_OP_T) && (trans_b == HIPBLAS_OP_T)) + { + std::cout << "Both Transose is not supported"; + return HIPBLAS_STATUS_NOT_SUPPORTED; } - // Create matrix descriptors. Not setting any extra attributes. - status = cublasLtMatrixLayoutInit( - &Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatrixLayoutInit( - &Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - // Create preference handle; In general, extra attributes can be - // used here to disable tensor ops or to make sure algo selected - // will work with badly aligned A, B, C. However, for simplicity - // here we assume A,B,C are always well aligned (e.g., directly - // come from cudaMalloc) - status = cublasLtMatmulPreferenceInit(&preference); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulPreferenceSetAttribute( - &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - // We just need the best available heuristic to try and run matmul. - // There is no guarantee that this will work. For example, if A is - // badly aligned, you can request more (e.g. 32) algos and try to - // run them one by one until something works. - status = cublasLtMatmulAlgoGetHeuristic( - ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - if (returnedResults == 0) { - status = CUBLAS_STATUS_NOT_SUPPORTED; - goto CLEANUP; + /* ============================================================================================ + * Matrix layout: + * 1. Set the Data type of matrix elements. + * 3. Set the layout: Size/shape of the matrix. This depends if transpose is needed or not. + * 4. Set the leading dimentions + * + */ + hipblasLtMatrixLayout_t matA = nullptr, matB = nullptr, matC = nullptr; + + hipDataType dtype_a = get_dtype(A); + hipDataType dtype_b = get_dtype(B); + hipDataType dtype_c = get_dtype(C); + + int64_t m = trans_a == HIPBLAS_OP_T ? A.size(0) : A.size(1); + int64_t k = trans_a == HIPBLAS_OP_T ? A.size(1) : A.size(0); + int64_t n = trans_b == HIPBLAS_OP_T ? B.size(1) : B.size(0); + + int64_t lda = 0, ldb = 0, ldd = 0; + + if ((trans_a == HIPBLAS_OP_T) && (trans_b != HIPBLAS_OP_T)) + { + lda = k; + ldb = k; + } // TN + else if ((trans_a != HIPBLAS_OP_T) && (trans_b == HIPBLAS_OP_T)) + { + lda = m; + ldb = n; + } // NT + else if ((trans_a != HIPBLAS_OP_T) && (trans_b != HIPBLAS_OP_T)) + { + lda = m; + ldb = k; + } // NN + + CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matA, dtype_a, trans_a == HIPBLAS_OP_T ? k : m, + trans_a == HIPBLAS_OP_T ? m : k, lda)); + + CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matB, dtype_b, trans_b == HIPBLAS_OP_T ? n : k, + trans_b == HIPBLAS_OP_T ? k : n, ldb)); + + CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matC, dtype_c, m, n, m)); + + /* ============================================================================================ + * Matmul desc: + * 1. Create operation descriptor with compute data type + * 2. Set transpose operation + */ + hipblasLtMatmulDesc_t matmulDesc = nullptr; + + hipblasComputeType_t desc_computeType = HIPBLAS_COMPUTE_32F; + hipDataType desc_dataType = HIP_R_32F; + + if (A.scalar_type() == at::ScalarType::Double) + { + desc_computeType = HIPBLAS_COMPUTE_64F; + desc_dataType = HIP_R_64F; } - status = cublasLtMatmul(ltHandle, - &operationDesc, - alpha, - A, - &Adesc, - B, - &Bdesc, - beta, - C, - &Cdesc, - C, - &Cdesc, - //&heuristicResult.algo, - NULL, - workspace, - workspaceSize, - stream); - -CLEANUP: - // Descriptors are no longer needed as all GPU work was already - // enqueued. - return status == CUBLAS_STATUS_SUCCESS ? 0 : 1; -} + CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescCreate(&matmulDesc, desc_computeType, desc_dataType)); -int gemm_bias_gelu_lt( - cublasLtHandle_t ltHandle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float *alpha, /* host pointer */ - double* A, - int lda, - double* B, - int ldb, - const float *beta, /* host pointer */ - double* C, - int ldc, - void *workspace, - size_t workspaceSize, - cudaStream_t stream, - bool use_bias, - const void *gelu_in, - const void* bias) { - return 1; -} + CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_TRANSA, + &trans_a, sizeof(trans_a))); + CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_TRANSB, + &trans_b, sizeof(trans_b))); -int gemm_bias_gelu_lt( - cublasLtHandle_t ltHandle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float *alpha, /* host pointer */ - float *A, - int lda, - float *B, - int ldb, - const float *beta, /* host pointer */ - float *C, - int64_t ldc, - void *workspace, - size_t workspaceSize, - cudaStream_t stream, - bool use_bias, - const void* gelu_in, - const void* bias) { - cublasStatus_t status = CUBLAS_STATUS_SUCCESS; - - cublasLtMatmulDescOpaque_t operationDesc = {}; - cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {}; - cublasLtMatmulPreferenceOpaque_t preference = {}; - - int returnedResults = 0; - cublasLtMatmulHeuristicResult_t heuristicResult = {}; - cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_GELU_AUX; - - // Create operation descriptor; see cublasLtMatmulDescAttributes_t - // for details about defaults; here we just set the transforms for - // A and B. - status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &gelu_in, sizeof(gelu_in)); - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc)); - - if (use_bias) { - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias)); - if (status != CUBLAS_STATUS_SUCCESS) { - goto CLEANUP; - } - epilogue = CUBLASLT_EPILOGUE_GELU_AUX_BIAS; - } + /* ============================================================================================ + * Configure epilogue + * 1. Set mat-mul post-ops: bias, bgradb, gelu. + * 2. + */ - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)); - if (status != CUBLAS_STATUS_SUCCESS) { - goto CLEANUP; - } - - // Create matrix descriptors. Not setting any extra attributes. - status = cublasLtMatrixLayoutInit( - &Adesc, CUDA_R_32F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatrixLayoutInit( - &Bdesc, CUDA_R_32F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_32F, m, n, ldc); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - // Create preference handle; In general, extra attributes can be - // used here to disable tensor ops or to make sure algo selected - // will work with badly aligned A, B, C. However, for simplicity - // here we assume A,B,C are always well aligned (e.g., directly - // come from cudaMalloc) - status = cublasLtMatmulPreferenceInit(&preference); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulPreferenceSetAttribute( - &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - // We just need the best available heuristic to try and run matmul. - // There is no guarantee that this will work. For example, if A is - // badly aligned, you can request more (e.g. 32) algos and try to - // run them one by one until something works. - status = cublasLtMatmulAlgoGetHeuristic( - ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - if (returnedResults == 0) { - status = CUBLAS_STATUS_NOT_SUPPORTED; - goto CLEANUP; - } - status = cublasLtMatmul(ltHandle, - &operationDesc, - alpha, - A, - &Adesc, - B, - &Bdesc, - beta, - C, - &Cdesc, - C, - &Cdesc, - //&heuristicResult.algo, - NULL, - workspace, - workspaceSize, - stream); - -CLEANUP: - // Descriptors are no longer needed as all GPU work was already - // enqueued. - return status == CUBLAS_STATUS_SUCCESS ? 0 : 1; -} + hipblasLtEpilogue_t epilogue = HIPBLASLT_EPILOGUE_DEFAULT; + hipDataType dtype_bias = get_dtype(bias); + hipDataType dtype_gelu = get_dtype(gelu); + auto d_bias = static_cast(bias.data_ptr()); + auto d_gelu = static_cast(gelu.data_ptr()); + int64_t ld_gelu = (int64_t)gelu.size(0); -int gemm_bgradb_lt( - cublasLtHandle_t ltHandle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float *alpha, /* host pointer */ - at::Half* A, - int lda, - at::Half* B, - int ldb, - const float *beta, /* host pointer */ - at::Half* C, - int ldc, - void *workspace, - size_t workspaceSize, - cudaStream_t stream, - bool use_bias, - const void* bgrad) { - cublasStatus_t status = CUBLAS_STATUS_SUCCESS; - - cublasLtMatmulDescOpaque_t operationDesc = {}; - cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {}; - cublasLtMatmulPreferenceOpaque_t preference = {}; - - int returnedResults = 0; - cublasLtMatmulHeuristicResult_t heuristicResult = {}; - cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT; - - // Create operation descriptor; see cublasLtMatmulDescAttributes_t - // for details about defaults; here we just set the transforms for - // A and B. - status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - if (use_bias) { - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad)); - if (status != CUBLAS_STATUS_SUCCESS) { - goto CLEANUP; + if (use_bias && use_gelu) + { + if (use_grad) + { + epilogue = HIPBLASLT_EPILOGUE_DGELU_BGRAD; } - epilogue = CUBLASLT_EPILOGUE_BGRADB; - } - - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)); - if (status != CUBLAS_STATUS_SUCCESS) { - goto CLEANUP; + else + { + epilogue = HIPBLASLT_EPILOGUE_GELU_BIAS; + } + CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_BIAS_POINTER, + &d_bias, sizeof(d_bias))); + CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, + &dtype_bias, sizeof(dtype_bias))); + + CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, + &d_gelu, sizeof(d_gelu))); + CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, + &ld_gelu, sizeof(ld_gelu))); + // CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE, + // &dtype_gelu, sizeof(dtype_gelu))); } + else if (use_bias) + { + if (use_grad) + { + epilogue = HIPBLASLT_EPILOGUE_BGRADB; + } + else + { + epilogue = HIPBLASLT_EPILOGUE_BIAS; + } + CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_BIAS_POINTER, + &d_bias, sizeof(d_bias))); - // Create matrix descriptors. Not setting any extra attributes. - status = cublasLtMatrixLayoutInit( - &Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatrixLayoutInit( - &Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - // Create preference handle; In general, extra attributes can be - // used here to disable tensor ops or to make sure algo selected - // will work with badly aligned A, B, C. However, for simplicity - // here we assume A,B,C are always well aligned (e.g., directly - // come from cudaMalloc) - status = cublasLtMatmulPreferenceInit(&preference); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulPreferenceSetAttribute( - &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - // We just need the best available heuristic to try and run matmul. - // There is no guarantee that this will work. For example, if A is - // badly aligned, you can request more (e.g. 32) algos and try to - // run them one by one until something works. - status = cublasLtMatmulAlgoGetHeuristic( - ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - if (returnedResults == 0) { - status = CUBLAS_STATUS_NOT_SUPPORTED; - goto CLEANUP; + CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, + &dtype_bias, sizeof(dtype_bias))); + } + else if (use_gelu) + { + if (use_grad) + { + epilogue = HIPBLASLT_EPILOGUE_DGELU; + } + else + { + epilogue = HIPBLASLT_EPILOGUE_GELU_AUX; + } + CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, + &d_gelu, sizeof(d_gelu))); + CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, + &ld_gelu, sizeof(ld_gelu))); + // CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE, + // &dtype_gelu, sizeof(dtype_gelu))); } - status = cublasLtMatmul(ltHandle, - &operationDesc, - alpha, - A, - &Adesc, - B, - &Bdesc, - beta, - C, - &Cdesc, - C, - &Cdesc, - //&heuristicResult.algo, - NULL, - workspace, - workspaceSize, - stream); - -CLEANUP: - // Descriptors are no longer needed as all GPU work was already - // enqueued. - return status == CUBLAS_STATUS_SUCCESS ? 0 : 1; -} - + CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE, + &epilogue, sizeof(epilogue))); + /* ============================================================================================ + * Algo Get Heuristic + * 1. retrieves the possible algorithms for given input matrices A, B and C, and the output matrix D. + * decription/layout. In our case matrux C and D are same. search result is in heuristicResultsArray[]. + */ + hipblasLtMatmulPreference_t pref; + const int request_solutions = 1; + int returnedAlgoCount = 0; + uint64_t workspace_size = 0; + void *workspace = nullptr; + hipblasLtMatmulHeuristicResult_t heuristicResult[request_solutions]; + CHECK_HIPBLASLT_ERROR(hipblasLtMatmulPreferenceCreate(&pref)); + CHECK_HIPBLASLT_ERROR(hipblasLtMatmulAlgoGetHeuristic(handle, matmulDesc, matA, matB, matC, matC, + pref, request_solutions, heuristicResult, + &returnedAlgoCount)); -int gemm_bgradb_lt( - cublasLtHandle_t ltHandle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float *alpha, /* host pointer */ - double* A, - int lda, - double* B, - int ldb, - const float *beta, /* host pointer */ - double* C, - int ldc, - void *workspace, - size_t workspaceSize, - cudaStream_t stream, - bool use_bias, - const void* bgrad) { - return 1; -} - -int gemm_bgradb_lt( - cublasLtHandle_t ltHandle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float *alpha, /* host pointer */ - float *A, - int lda, - float *B, - int ldb, - const float *beta, /* host pointer */ - float *C, - int ldc, - void *workspace, - size_t workspaceSize, - cudaStream_t stream, - bool use_bias, - const void* bgrad) { - cublasStatus_t status = CUBLAS_STATUS_SUCCESS; - - cublasLtMatmulDescOpaque_t operationDesc = {}; - cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {}; - cublasLtMatmulPreferenceOpaque_t preference = {}; - - int returnedResults = 0; - cublasLtMatmulHeuristicResult_t heuristicResult = {}; - cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT; - - // Create operation descriptor; see cublasLtMatmulDescAttributes_t - // for details about defaults; here we just set the transforms for - // A and B. - status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - if (use_bias) { - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad)); - if (status != CUBLAS_STATUS_SUCCESS) { - goto CLEANUP; - } - epilogue = CUBLASLT_EPILOGUE_BGRADB; + if (returnedAlgoCount == 0) + { + std::cerr << "No valid solution found!" << std::endl; + return HIPBLAS_STATUS_NOT_SUPPORTED; } - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)); - if (status != CUBLAS_STATUS_SUCCESS) { - goto CLEANUP; + for (int i = 0; i < returnedAlgoCount; i++) + { + workspace_size = max(workspace_size, heuristicResult[i].workspaceSize); } - // Create matrix descriptors. Not setting any extra attributes. - status = cublasLtMatrixLayoutInit( - &Adesc, CUDA_R_32F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatrixLayoutInit( - &Bdesc, CUDA_R_32F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_32F, m, n, ldc); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - // Create preference handle; In general, extra attributes can be - // used here to disable tensor ops or to make sure algo selected - // will work with badly aligned A, B, C. However, for simplicity - // here we assume A,B,C are always well aligned (e.g., directly - // come from cudaMalloc) - status = cublasLtMatmulPreferenceInit(&preference); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulPreferenceSetAttribute( - &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - // We just need the best available heuristic to try and run matmul. - // There is no guarantee that this will work. For example, if A is - // badly aligned, you can request more (e.g. 32) algos and try to - // run them one by one until something works. - status = cublasLtMatmulAlgoGetHeuristic( - ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - if (returnedResults == 0) { - status = CUBLAS_STATUS_NOT_SUPPORTED; - goto CLEANUP; - } + hipMalloc(&workspace, workspace_size); + CHECK_HIPBLASLT_ERROR(hipblasLtMatmulPreferenceSetAttribute(pref, HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &workspace, sizeof(workspace_size))); + + /* ============================================================================================ + * Matmul + */ + const void *d_a = static_cast(A.data_ptr()); + const void *d_b = static_cast(B.data_ptr()); + void *d_c = static_cast(C.data_ptr()); + + CHECK_HIPBLASLT_ERROR(hipblasLtMatmul(handle, matmulDesc, alpha, d_a, matA, + d_b, matB, beta, static_cast(d_c), + matC, d_c, matC, &heuristicResult[0].algo, + workspace, workspace_size, stream)); + +#if DEBUG + std::cout << "\nTensor-A:\n" << A + << "\nTensor-B:\n" << B + << "\nTensor-C:\n" << C + << "\nTensor-Bias:\n" << bias << std::endl; + std::cout << "\nSizes: A[" << A.size(0) << "," << A.size(1) << "]" << std::endl; + std::cout << "\nSizes: B[" << B.size(0) << "," << B.size(1) << "]" << std::endl; + std::cout << "\nSizes: C[" << C.size(0) << "," << C.size(1) << "]" << std::endl; + std::cout << "\nValues:: m:" << m << ", k:" << k << ", n:" << n << std::endl; + std::cout << "lda: " << lda << "\tldb: " << ldb << "\tldd: " << ldd << "\tm: " << m << "\tk: " << k << "\tn: " << n << std::endl; +#endif - status = cublasLtMatmul(ltHandle, - &operationDesc, - alpha, - A, - &Adesc, - B, - &Bdesc, - beta, - C, - &Cdesc, - C, - &Cdesc, - &heuristicResult.algo, - workspace, - workspaceSize, - stream); - -CLEANUP: - // Descriptors are no longer needed as all GPU work was already - // enqueued. - return status == CUBLAS_STATUS_SUCCESS ? 0 : 1; -} + CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutDestroy(matA)); + CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutDestroy(matB)); + CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutDestroy(matC)); + CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescDestroy(matmulDesc)); + CHECK_HIPBLASLT_ERROR(hipblasLtMatmulPreferenceDestroy(pref)); + return HIPBLAS_STATUS_SUCCESS; +} -int gemm_dgelu_bgradb_lt( - cublasLtHandle_t ltHandle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float *alpha, /* host pointer */ - at::Half* A, - int lda, - at::Half* B, - int ldb, - const float *beta, /* host pointer */ - at::Half* C, - int64_t ldc, - void *workspace, - size_t workspaceSize, - cudaStream_t stream, - const void *gelu_in, - const void *bgrad) { - cublasStatus_t status = CUBLAS_STATUS_SUCCESS; - - cublasLtMatmulDescOpaque_t operationDesc = {}; - cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {}; - cublasLtMatmulPreferenceOpaque_t preference = {}; - - int returnedResults = 0; - cublasLtMatmulHeuristicResult_t heuristicResult = {}; - cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DGELU_BGRAD; - - // Create operation descriptor; see cublasLtMatmulDescAttributes_t - // for details about defaults; here we just set the transforms for - // A and B. - status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad)); - if (status != CUBLAS_STATUS_SUCCESS) { - goto CLEANUP; - } - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &gelu_in, sizeof(gelu_in)); - if (status != CUBLAS_STATUS_SUCCESS) { - goto CLEANUP; - } - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc)); - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)); - if (status != CUBLAS_STATUS_SUCCESS) { - goto CLEANUP; - } +template +hipblasStatus_t gemm_bias( hipblasOperation_t transa, hipblasOperation_t transb, + int64_t m, int64_t n, int64_t k, const float *alpha, const float *beta, + const TensorType *A, const TensorType *B, TensorType *C) +{ + hipblasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); + int64_t lda = n; + int64_t ldb = k; + int64_t ldc = m; - // Create matrix descriptors. Not setting any extra attributes. - status = cublasLtMatrixLayoutInit( - &Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatrixLayoutInit( - &Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - // Create preference handle; In general, extra attributes can be - // used here to disable tensor ops or to make sure algo selected - // will work with badly aligned A, B, C. However, for simplicity - // here we assume A,B,C are always well aligned (e.g., directly - // come from cudaMalloc) - status = cublasLtMatmulPreferenceInit(&preference); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulPreferenceSetAttribute( - &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - // We just need the best available heuristic to try and run matmul. - // There is no guarantee that this will work. For example, if A is - // badly aligned, you can request more (e.g. 32) algos and try to - // run them one by one until something works. - status = cublasLtMatmulAlgoGetHeuristic( - ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - if (returnedResults == 0) { - status = CUBLAS_STATUS_NOT_SUPPORTED; - goto CLEANUP; - } - status = cublasLtMatmul(ltHandle, - &operationDesc, - alpha, - A, - &Adesc, - B, - &Bdesc, - beta, - C, - &Cdesc, - C, - &Cdesc, - //&heuristicResult.algo, - NULL, - workspace, - workspaceSize, - stream); - -CLEANUP: - // Descriptors are no longer needed as all GPU work was already - // enqueued. - return status == CUBLAS_STATUS_SUCCESS ? 0 : 1; +#if DEBUG + std::cout << "gemm_bias " << std::endl; +#endif + return hipblasGemmEx(handle, transa, transb, m, n, k, alpha, A, DataType, lda, B, DataType, + ldb, beta, C, DataType, ldc, ComputeType, CUBLAS_GEMM_DEFAULT); } -int gemm_dgelu_bgradb_lt( - cublasLtHandle_t ltHandle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float *alpha, /* host pointer */ - double *A, - int lda, - double *B, - int ldb, - const float *beta, /* host pointer */ - double *C, - int ldc, - void *workspace, - size_t workspaceSize, - cudaStream_t stream, - const void *gelu_in, - const void *bgrad) { - return 1; -} -int gemm_dgelu_bgradb_lt( - cublasLtHandle_t ltHandle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float *alpha, /* host pointer */ - float *A, - int lda, - float *B, - int ldb, - const float *beta, /* host pointer */ - float *C, - int64_t ldc, - void *workspace, - size_t workspaceSize, - cudaStream_t stream, - const void *gelu_in, - const void *bgrad) { - cublasStatus_t status = CUBLAS_STATUS_SUCCESS; - - cublasLtMatmulDescOpaque_t operationDesc = {}; - cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {}; - cublasLtMatmulPreferenceOpaque_t preference = {}; - - int returnedResults = 0; - cublasLtMatmulHeuristicResult_t heuristicResult = {}; - cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DGELU_BGRAD; - - // Create operation descriptor; see cublasLtMatmulDescAttributes_t - // for details about defaults; here we just set the transforms for - // A and B. - status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad)); - if (status != CUBLAS_STATUS_SUCCESS) { - goto CLEANUP; - } - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &gelu_in, sizeof(gelu_in)); - if (status != CUBLAS_STATUS_SUCCESS) { - goto CLEANUP; - } - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc)); +/**************************************************************************** + * output[batch_size, out_features] = input[batch_size, in_features] * weight[out_features,in_features] + bias[out_features] + ****************************************************************************/ +at::Tensor linear_bias_forward(at::Tensor input, at::Tensor weight, at::Tensor bias) +{ + const float alpha = 1.0, beta = 0.0; - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)); - if (status != CUBLAS_STATUS_SUCCESS) { - goto CLEANUP; - } + int64_t batch_size = input.size(0); // input[batch_size, in_features] + int64_t in_features = input.size(1); + int64_t out_features = weight.size(0); // weight[out_features,in_features] - // Create matrix descriptors. Not setting any extra attributes. - status = cublasLtMatrixLayoutInit( - &Adesc, CUDA_R_32F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatrixLayoutInit( - &Bdesc, CUDA_R_32F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_32F, m, n, ldc); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - // Create preference handle; In general, extra attributes can be - // used here to disable tensor ops or to make sure algo selected - // will work with badly aligned A, B, C. However, for simplicity - // here we assume A,B,C are always well aligned (e.g., directly - // come from cudaMalloc) - status = cublasLtMatmulPreferenceInit(&preference); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulPreferenceSetAttribute( - &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - // We just need the best available heuristic to try and run matmul. - // There is no guarantee that this will work. For example, if A is - // badly aligned, you can request more (e.g. 32) algos and try to - // run them one by one until something works. - status = cublasLtMatmulAlgoGetHeuristic( - ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - if (returnedResults == 0) { - status = CUBLAS_STATUS_NOT_SUPPORTED; - goto CLEANUP; - } - status = cublasLtMatmul(ltHandle, - &operationDesc, - alpha, - A, - &Adesc, - B, - &Bdesc, - beta, - C, - &Cdesc, - C, - &Cdesc, - //&heuristicResult.algo, - NULL, - workspace, - workspaceSize, - stream); - -CLEANUP: - // Descriptors are no longer needed as all GPU work was already - // enqueued. - return status == CUBLAS_STATUS_SUCCESS ? 0 : 1; -} + at::Tensor dummy_gelu = at::empty({0}, torch::device(torch::kCUDA).dtype(input.scalar_type())); + // ********************************************************************************** + // output[batch_size, out_features] = input[batch_size, in_features] * weight[out_features,in_features] + bias[out_features] + // ********************************************************************************** + auto output = at::zeros({batch_size, out_features}, torch::device(torch::kCUDA).dtype(input.scalar_type())); +#if DEBUG + std::cout << "linear_bias_forward " << std::endl; #endif + if (at::globalContext().blasPreferredBackend() == at::BlasBackend::Cublaslt) { + CHECK_HIPBLASLT_ERROR(gemm_lt(HIPBLAS_OP_T, HIPBLAS_OP_N, &alpha, &beta, weight, input, output, bias, dummy_gelu, true, false, false)); + } else { + DISPATCH_TYPES(input.scalar_type(), "linear_bias_forward", [&] { + auto result = gemm_bias( + HIPBLAS_OP_T, HIPBLAS_OP_N, out_features, batch_size, in_features, + &alpha, &beta, weight.data_ptr(), input.data_ptr(), output.data_ptr()); + if (result != 0) { fprintf(stderr, "INVALID RESULT for linear_bias_forward\n"); } + }); + } -template -int linear_bias_forward_cuda(at::Tensor input, T *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace) { - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - // Get the stream from cublas handle to reuse for biasReLU kernel. - cudaStream_t stream; - cublasGetStream(handle, &stream); - const float alpha = 1.0; - const float beta_zero = 0.0; - const float beta_one = 1.0; - int status = 1; -#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600 - status = gemm_bias_lt( - (cublasLtHandle_t)handle, - CUBLAS_OP_T, - CUBLAS_OP_N, - out_features, - batch_size, - in_features, - &alpha, /* host pointer */ - weight, - in_features, - input.data_ptr(), - in_features, - &beta_zero, /* host pointer */ - output.data_ptr(), - out_features, - lt_workspace, - 1 << 22, - stream, - true, - static_cast(bias.data_ptr())); -#endif - if (status != 0){ - output.copy_(bias); - status = gemm_bias( - handle, - CUBLAS_OP_T, - CUBLAS_OP_N, - out_features, - batch_size, - in_features, - &alpha, - weight, - in_features, - input.data_ptr(), - in_features, - &beta_one, - output.data_ptr(), - out_features); - } - return status; + return {output}; } - -template -int linear_bias_backward_cuda(T *input, T *weight, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, T *d_input, void *lt_workspace) { - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - // Get the stream from cublas handle to reuse for biasReLU kernel. - cudaStream_t stream; - cublasGetStream(handle, &stream); - const float alpha = 1.0; - const float beta_zero = 0.0; - int status = 1; -#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600 - status = gemm_bgradb_lt( - (cublasLtHandle_t)handle, - CUBLAS_OP_N, - CUBLAS_OP_T, - in_features, - out_features, - batch_size, - &alpha, /* host pointer */ - input, - in_features, - d_output, - out_features, - &beta_zero, /* host pointer */ - d_weight, - in_features, - lt_workspace, - 1 << 22, - stream, - true, - static_cast(d_bias)); +/**************************************************************************** + * In the backward pass, we compute the gradients of the loss with respect to input, weight, and bias. + * The key matrix operations are: + * 1. Gradient of Input : grad_input[batch_size, in_features] = output[batch_size, out_features] * weight[out_features,in_features] + * 2. Gradient of Weights: grad_weight[out_features,in_features] = input[batch_size, in_features] * output[batch_size, out_features] + * 3. Gradient of Bias : grad_bias=sum(dY) + **************************************************************************/ +std::vector linear_bias_backward(at::Tensor input, at::Tensor weight, at::Tensor output) +{ + const float alpha = 1.0, beta = 0.0; + + int64_t batch_size = input.size(0); // input[batch_size, in_features] + int64_t in_features = input.size(1); + int64_t out_features = weight.size(0); // weight[out_features,in_features] + + auto grad_bias = at::zeros(out_features, torch::device(torch::kCUDA).dtype(input.scalar_type())); + auto dummy_gelu = at::empty({0}, torch::device(torch::kCUDA).dtype(input.scalar_type())); + auto grad_weight = at::zeros({out_features,in_features}, torch::device(torch::kCUDA).dtype(input.scalar_type())); + auto grad_input = at::zeros({batch_size, in_features}, torch::device(torch::kCUDA).dtype(input.scalar_type())); + +#if DEBUG + std::cout << "linear_bias_backward " << std::endl; #endif - - - if (status != 0){ - - status = gemm_bias( - handle, - CUBLAS_OP_N, - CUBLAS_OP_T, - in_features, - out_features, - batch_size, - &alpha, - input, - in_features, - d_output, - out_features, - &beta_zero, - d_weight, - in_features); - } - - status = gemm_bias( - handle, - CUBLAS_OP_N, - CUBLAS_OP_N, - in_features, - batch_size, - out_features, - &alpha, - weight, - in_features, - d_output, - out_features, - &beta_zero, - d_input, - in_features); - return status; - + if (at::globalContext().blasPreferredBackend() == at::BlasBackend::Cublaslt) { + // ********************************************************************************** + // Gradient of Input : + // grad_input [batch_size, in_features] = output[batch_size, out_features] * Weight[out_features,in_features] + // ********************************************************************************** + CHECK_HIPBLASLT_ERROR(gemm_lt(HIPBLAS_OP_N, HIPBLAS_OP_N, &alpha, &beta, weight, output, grad_input, grad_bias, dummy_gelu, false, false, false)); + + // ********************************************************************************** + // Gradient of Weights: + // grad_weight[out_features,in_features] = input[batch_size, in_features](T) * output[batch_size, out_features] + // ********************************************************************************** + CHECK_HIPBLASLT_ERROR(gemm_lt(HIPBLAS_OP_N, HIPBLAS_OP_T, &alpha, &beta, input, output, grad_weight, grad_bias, dummy_gelu, true, false, false)); + + // ********************************************************************************** + // ToDo: Check why HipBLASLt fail to get bgrad above so this step is not needed. + // db=sum(dY) + // ********************************************************************************** + grad_bias = output.sum(0, false); + } else { + DISPATCH_TYPES(input.scalar_type(), "linear_bias_forward", [&] { + auto result = gemm_bias( + HIPBLAS_OP_N, HIPBLAS_OP_T, in_features, out_features, batch_size, + &alpha, &beta, input.data_ptr(), output.data_ptr(), grad_weight.data_ptr()); + if (result != 0) { fprintf(stderr, "INVALID RESULT for linear_bias_forward\n"); } + }); + + DISPATCH_TYPES(input.scalar_type(), "linear_bias_forward", [&] { + auto result = gemm_bias( + HIPBLAS_OP_N, HIPBLAS_OP_N, in_features, batch_size, out_features, + &alpha, &beta, weight.data_ptr(), output.data_ptr(), grad_input.data_ptr()); + if (result != 0) { fprintf(stderr, "INVALID RESULT for linear_bias_forward\n"); } + }); + } + return {grad_input, grad_weight, grad_bias}; } -template -int linear_gelu_linear_forward_cuda(T *input, T *weight1, T *bias1, T *weight2, T *bias2, int in_features, int hidden_features, int batch_size, int out_features, T *output1, T *output2, T *gelu_in, void *lt_workspace) { - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - // Get the stream from cublas handle to reuse for biasReLU kernel. - cudaStream_t stream; - cublasGetStream(handle, &stream); - const float alpha = 1.0; - const float beta_zero = 0.0; - int status = 1; -#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600 - status = gemm_bias_gelu_lt( - (cublasLtHandle_t)handle, - CUBLAS_OP_T, - CUBLAS_OP_N, - hidden_features, - batch_size, - in_features, - &alpha, /* host pointer */ - weight1, - in_features, - input, - in_features, - &beta_zero, /* host pointer */ - output1, - hidden_features, - lt_workspace, - 1 << 22, - stream, - true, - static_cast(gelu_in), - static_cast(bias1)); - status = gemm_bias_lt( - (cublasLtHandle_t)handle, - CUBLAS_OP_T, - CUBLAS_OP_N, - out_features, - batch_size, - hidden_features, - &alpha, /* host pointer */ - weight2, - hidden_features, - output1, - hidden_features, - &beta_zero, /* host pointer */ - output2, - out_features, - lt_workspace, - 1 << 22, - stream, - true, - static_cast(bias2)); - return status; -#else - return 1; +/**************************************************************************** + * + * [Linear] https://pytorch.org/docs/stable/generated/torch.nn.Linear.html + * [GELU] https://pytorch.org/docs/stable/generated/torch.nn.GELU.html + * + * module combines dense layers with GELU activations in a single neural network layer. + * layer consists of two dense sub-layers, each followed by a GELU activation function. + * It takes an input tensor and passes it through these sub-layers to produce the final output. + * + * layer consists of the following internal layers: + * dense1: The first dense layer. + * output[batch_size, hidden_features] = input[batch_size, in_features] * weight[hidden_features,in_features] + bias[hidden_features] + * activation: The GELU(Gaussian Error Linear Units) activation function. + * dense2: The second dense layer. + * output2[batch_size,out_features] = output[batch_size, hidden_features] * weight2[out_features, hidden_features] + bias2[out_features + * Parameters: + * input (torch.Tensor): (∗,Hin ) where ∗ is batch_size and Hin=in_features + * weight (torch.Tensor): the learnable weights of the module of shape(out_features,in_features). + * bias (torch.Tensor): the learnable bias of the module of shape(out_features) + * + * Output: (*,Hout ) where all but the last dimension are the same shape as the input and Hout = out_features. + * + **************************************************************************/ +std::vector linear_gelu_linear_forward(at::Tensor input, at::Tensor weight, at::Tensor bias, + at::Tensor weight2, at::Tensor bias2) +{ + const float alpha = 1.0, beta = 0.0; + + int64_t batch_size = input.size(0); // input[batch_size, in_features] + int64_t in_features = input.size(1); // bias[hidden_features] and bias2[out_features] + int64_t hidden_features = weight.size(0); // weight[hidden_features, in_features] + int64_t out_features = weight2.size(0); // weight2[out_features, hidden_features] + + + at::Tensor dummy_gelu = at::empty({0}, torch::device(torch::kCUDA).dtype(input.scalar_type())); + + // ********************************************************************************** + // output[batch_size, hidden_features] = input[batch_size, in_features] * weight[hidden_features,in_features] + bias[hidden_features] + // ********************************************************************************** + at::Tensor output = at::zeros({batch_size, hidden_features}, torch::device(torch::kCUDA).dtype(input.scalar_type())); + at::Tensor gelu = at::zeros({batch_size, hidden_features}, torch::device(torch::kCUDA).dtype(input.scalar_type())); + + // ********************************************************************************** + // output2[batch_size,out_features] = output[batch_size, hidden_features] * weight2[out_features, hidden_features] + bias2[out_features] + // ********************************************************************************** + at::Tensor output2 = at::zeros({batch_size,out_features}, torch::device(torch::kCUDA).dtype(input.scalar_type())); // output2[batch_size,out_features] + +#if DEBUG + std::cout << "linear_gelu_linear_forward " << std::endl; #endif + if (at::globalContext().blasPreferredBackend() == at::BlasBackend::Cublaslt) { + CHECK_HIPBLASLT_ERROR(gemm_lt(HIPBLAS_OP_T, HIPBLAS_OP_N, &alpha, &beta, weight, input, output, bias, gelu, true, false, true)); + CHECK_HIPBLASLT_ERROR(gemm_lt(HIPBLAS_OP_T, HIPBLAS_OP_N, &alpha, &beta, weight2, output, output2, bias2, dummy_gelu, true, false, false)); + } else { + std::cout << "linear_gelu_linear_forward not implimented for non-MI300 GPU" << std::endl; + } + return {output, output2, gelu}; } -template -int linear_gelu_linear_backward_cuda(T *input, T *gelu_in, T *output1, T *weight1, T *weight2, T *d_output1, T *d_output2, int in_features, int batch_size, int hidden_features, int out_features, T *d_weight1, T *d_weight2, T *d_bias1, T *d_bias2, T *d_input, void *lt_workspace) { - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - // Get the stream from cublas handle to reuse for biasReLU kernel. - cudaStream_t stream; - cublasGetStream(handle, &stream); - const float alpha = 1.0; - const float beta_zero = 0.0; - int status = 1; -#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600 -//wgrad for first gemm - status = gemm_bgradb_lt( - (cublasLtHandle_t)handle, - CUBLAS_OP_N, - CUBLAS_OP_T, - hidden_features, - out_features, - batch_size, - &alpha, /* host pointer */ - output1, - hidden_features, - d_output2, - out_features, - &beta_zero, /* host pointer */ - d_weight2, - hidden_features, - lt_workspace, - 1 << 22, - stream, - true, - static_cast(d_bias2)); -//dgrad for second GEMM - status = gemm_dgelu_bgradb_lt( - (cublasLtHandle_t)handle, - CUBLAS_OP_N, - CUBLAS_OP_N, - hidden_features, - batch_size, - out_features, - &alpha, /* host pointer */ - weight2, - hidden_features, - d_output2, - out_features, - &beta_zero, /* host pointer */ - d_output1, - hidden_features, - lt_workspace, - 1 << 22, - stream, - static_cast(gelu_in), - static_cast(d_bias1)); -//wgrad for the first GEMM - status = gemm_bias( - handle, - CUBLAS_OP_N, - CUBLAS_OP_T, - in_features, - hidden_features, - batch_size, - &alpha, - input, - in_features, - d_output1, - hidden_features, - &beta_zero, - d_weight1, - in_features); - -//dgrad for the first GEMM - status = gemm_bias( - handle, - CUBLAS_OP_N, - CUBLAS_OP_N, - in_features, - batch_size, - hidden_features, - &alpha, - weight1, - in_features, - d_output1, - hidden_features, - &beta_zero, - d_input, - in_features); +/**************************************************************************** + * In the backward pass, we compute the gradients of the loss with respect to input, weight, and bias. + * The key matrix operations are: + * For second gemm + * 1. Gradient of Input (dX): grad_output[batch_size, hidden_features] = output2[batch_size,out_features] ⋅ weight2[out_features, hidden_features] + * 2. Gradient of Weights (dW): grad_weight[hidden_features, in_features] = output[batch_size, hidden_features](T) ⋅ output2[batch_size,out_features] + * For First gemm + * 1. Gradient of Input (dX): grad_input[batch_size, in_features] = output[batch_size, hidden_features] ⋅ weight[hidden_features,in_features](T) + * 2. Gradient of Weights (dW): grad_weight[hidden_features, in_features] = input[batch_size, in_features](T) ⋅ output[batch_size, hidden_features] + **************************************************************************/ +std::vector linear_gelu_linear_backward(at::Tensor input, at::Tensor gelu, at::Tensor output, at::Tensor weight, + at::Tensor weight2, at::Tensor output2) +{ + const float alpha = 1.0, beta = 0.0; + + int64_t batch_size = input.size(0); + int64_t in_features = input.size(1); + int64_t hidden_features = weight.size(0); + int64_t out_features = weight2.size(0); + + hipblasStatus_t status = HIPBLAS_STATUS_NOT_INITIALIZED; + + hipblasOperation_t trans_a = HIPBLAS_OP_T; + hipblasOperation_t trans_b = HIPBLAS_OP_N; + + at::Tensor grad_weight = at::zeros({hidden_features, in_features}, torch::device(torch::kCUDA).dtype(input.scalar_type())); + at::Tensor grad_weight2 = at::zeros({out_features, hidden_features}, torch::device(torch::kCUDA).dtype(input.scalar_type())); + at::Tensor grad_bias = at::zeros({hidden_features}, torch::device(torch::kCUDA).dtype(input.scalar_type())); + at::Tensor grad_bias2 = at::zeros({out_features}, torch::device(torch::kCUDA).dtype(input.scalar_type())); + at::Tensor grad_input = at::zeros({batch_size, in_features}, torch::device(torch::kCUDA).dtype(input.scalar_type())); + at::Tensor grad_output = at::zeros({batch_size, hidden_features}, torch::device(torch::kCUDA).dtype(input.scalar_type())); + + at::Tensor dummy_gelu = at::empty({0}, torch::device(torch::kCUDA).dtype(input.scalar_type())); +#if DEBUG + std::cout << "linear_gelu_linear_backward " << std::endl; #endif - return status; - + if (at::globalContext().blasPreferredBackend() == at::BlasBackend::Cublaslt) { + // ********************************************************************************** + // Gradient For second gemm : + // grad_output[batch_size, hidden_features] = output2[batch_size,out_features] ⋅ weight2[out_features, hidden_features] + // grad_weight[out_features,in_features] = input[batch_size, in_features](T) * output[batch_size, out_features] + // ********************************************************************************** + CHECK_HIPBLASLT_ERROR(gemm_lt(HIPBLAS_OP_N, HIPBLAS_OP_N, &alpha, &beta, weight2, output2, grad_output, grad_bias2, dummy_gelu, false, false, false)); + CHECK_HIPBLASLT_ERROR(gemm_lt(HIPBLAS_OP_N, HIPBLAS_OP_T, &alpha, &beta, output2, output, grad_weight2, grad_bias2, dummy_gelu, true, false, false)); + grad_bias2 = output2.sum(0, false); // ToDo: Check why HipBLASLt fail to get bgrad above so this step is not needed. + + // ********************************************************************************** + // Gradient For First gemm : + // grad_input [batch_size, in_features] = output[batch_size, out_features] * Weight[out_features,in_features] + // grad_weight[out_features,in_features] = input[batch_size, in_features](T) * output[batch_size, out_features] + // ********************************************************************************** + CHECK_HIPBLASLT_ERROR(gemm_lt(HIPBLAS_OP_N, HIPBLAS_OP_N, &alpha, &beta, weight, output, grad_input, grad_bias2, dummy_gelu, false, false, false)); + CHECK_HIPBLASLT_ERROR(gemm_lt(HIPBLAS_OP_N, HIPBLAS_OP_T, &alpha, &beta, output, input, grad_weight, grad_bias2, dummy_gelu, true, false, false)); + grad_bias = output.sum(0, false); // ToDo: Check why HipBLASLt fail to get bgrad above so this step is not needed. + } else { + std::cout << "linear_gelu_linear_backward not implimented for non-MI300 GPU" << std::endl; + } + return {grad_input, grad_weight, grad_bias, grad_weight2, grad_bias2}; } - - -template int linear_bias_forward_cuda(at::Tensor input, at::Half *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace); - -template int linear_bias_forward_cuda(at::Tensor input, float *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace); - -template int linear_bias_forward_cuda(at::Tensor input, double *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace); - -template int linear_bias_backward_cuda(at::Half *input, at::Half *weight, at::Half *d_output, int in_features, int batch_size, int out_features, at::Half *d_weight, at::Half *d_bias, at::Half *d_input, void *lt_workspace) ; - -template int linear_bias_backward_cuda(float *input, float *weight, float *d_output, int in_features, int batch_size, int out_features, float *d_weight, float *d_bias, float *d_input, void *lt_workspace) ; - -template int linear_bias_backward_cuda(double *input, double *weight, double *d_output, int in_features, int batch_size, int out_features, double *d_weight, double *d_bias, double *d_input, void *lt_workspace) ; - - -template int linear_gelu_linear_forward_cuda(at::Half *input, at::Half *weight1, at::Half *bias1, at::Half *weight2, at::Half *bias2, int in_features, int hidden_features, int batch_size, int out_features, at::Half *output1, at::Half *output2, at::Half *gelu_in, void *lt_workspace) ; - -template int linear_gelu_linear_forward_cuda(float *input, float *weight1, float *bias1, float *weight2, float *bias2, int in_features, int hidden_features, int batch_size, int out_features, float *output1, float *output2, float *gelu_in, void *lt_workspace); - -template int linear_gelu_linear_forward_cuda(double *input, double *weight1, double *bias1, double *weight2, double *bias2, int in_features, int hidden_features, int batch_size, int out_features, double *output1, double *output2, double *gelu_in, void *lt_workspace) ; - -template int linear_gelu_linear_backward_cuda(at::Half *input, at::Half *gelu_in, at::Half *output1, at::Half *weight1, at::Half *weight2, at::Half *d_output1, at::Half *d_output2, int in_features, int batch_size, int hidden_features, int out_features, at::Half *d_weight1, at::Half *d_weight2, at::Half *d_bias1, at::Half *d_bias2, at::Half *d_input, void *lt_workspace); - -template int linear_gelu_linear_backward_cuda(float *input, float *gelu_in, float *output1, float *weight1, float *weight2, float *d_output1, float *d_output2, int in_features, int batch_size, int hidden_features, int out_features, float *d_weight1, float *d_weight2, float *d_bias1, float *d_bias2, float *d_input, void *lt_workspace); - -template int linear_gelu_linear_backward_cuda(double *input, double *gelu_in, double *output1, double *weight1, double *weight2, double *d_output1, double *d_output2, int in_features, int batch_size, int hidden_features, int out_features, double *d_weight1, double *d_weight2, double *d_bias1, double *d_bias2, double *d_input, void *lt_workspace); diff --git a/csrc/layer_norm_cuda.cpp b/csrc/layer_norm_cuda.cpp index 869870178..99037fb6b 100644 --- a/csrc/layer_norm_cuda.cpp +++ b/csrc/layer_norm_cuda.cpp @@ -142,7 +142,7 @@ void cuda_layer_norm( at::Tensor* beta, double epsilon); -#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) @@ -214,7 +214,7 @@ void cuda_layer_norm_gradient( at::Tensor* dout, at::Tensor* mean, at::Tensor* invvar, - at::Tensor* input, + at::Tensor* input_or_output, int n1, int n2, #ifdef VERSION_GE_1_1 @@ -227,38 +227,45 @@ void cuda_layer_norm_gradient( double epsilon, at::Tensor* grad_input, at::Tensor* grad_gamma, - at::Tensor* grad_beta + at::Tensor* grad_beta, + bool memory_efficient ); at::Tensor layer_norm_gradient( at::Tensor dout, - at::Tensor mean, + c10::optional mean_, at::Tensor invvar, - at::Tensor input, + at::Tensor input_or_output, #ifdef VERSION_GE_1_1 at::IntArrayRef normalized_shape, #else at::IntList normalized_shape, #endif - double epsilon) { + double epsilon, + bool memory_efficient) { CHECK_INPUT(dout); - CHECK_INPUT(mean); CHECK_INPUT(invvar); - CHECK_INPUT(input); + CHECK_INPUT(input_or_output); int n1,n2; - check_args(input,normalized_shape,n1,n2); - at::Tensor grad_input = at::empty_like(input); - cuda_layer_norm_gradient(&dout,&mean,&invvar,&input,n1,n2, - normalized_shape,NULL,NULL,epsilon, - &grad_input,NULL,NULL); + check_args(input_or_output,normalized_shape,n1,n2); + at::Tensor grad_input = at::empty_like(input_or_output); + if (mean_.has_value()) { + cuda_layer_norm_gradient(&dout,&mean_.value(),&invvar,&input_or_output,n1,n2, + normalized_shape,NULL,NULL,epsilon, + &grad_input,NULL,NULL,memory_efficient); + } else { + cuda_layer_norm_gradient(&dout,NULL,&invvar,&input_or_output,n1,n2, + normalized_shape,NULL,NULL,epsilon, + &grad_input,NULL,NULL,memory_efficient); + } return grad_input; } std::vector layer_norm_gradient_affine( at::Tensor dout, - at::Tensor mean, + c10::optional mean_, at::Tensor invvar, - at::Tensor input, + at::Tensor input_or_output, #ifdef VERSION_GE_1_1 at::IntArrayRef normalized_shape, #else @@ -266,21 +273,28 @@ std::vector layer_norm_gradient_affine( #endif at::Tensor gamma, at::Tensor beta, - double epsilon) { + double epsilon, + bool memory_efficient) { CHECK_INPUT(dout); - CHECK_INPUT(mean); CHECK_INPUT(invvar); - CHECK_INPUT(input); + CHECK_INPUT(input_or_output); CHECK_INPUT(gamma); CHECK_INPUT(beta); int n1,n2; - check_args(input,normalized_shape,gamma,beta,n1,n2); - at::Tensor grad_input = at::empty_like(input); + check_args(input_or_output,normalized_shape,gamma,beta,n1,n2); + at::Tensor grad_input = at::empty_like(input_or_output); at::Tensor grad_gamma = at::empty_like(gamma); at::Tensor grad_beta = at::empty_like(beta); - cuda_layer_norm_gradient(&dout,&mean,&invvar,&input,n1,n2, - normalized_shape,&gamma,&beta,epsilon, - &grad_input,&grad_gamma,&grad_beta); +// at::Tensor *mean = mean_.has_value() ? &mean_.value() : NULL; + if (mean_.has_value()) { + cuda_layer_norm_gradient(&dout,&mean_.value(),&invvar,&input_or_output,n1,n2, + normalized_shape,&gamma,&beta,epsilon, + &grad_input,&grad_gamma,&grad_beta,memory_efficient); + } else { + cuda_layer_norm_gradient(&dout,NULL,&invvar,&input_or_output,n1,n2, + normalized_shape,&gamma,&beta,epsilon, + &grad_input,&grad_gamma,&grad_beta,memory_efficient); + } return {grad_input, grad_gamma, grad_beta}; } @@ -298,7 +312,7 @@ void cuda_rms_norm( at::Tensor* gamma, double epsilon); -#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) @@ -364,7 +378,7 @@ std::vector rms_norm_affine_mixed_dtypes( void cuda_rms_norm_gradient( at::Tensor* dout, at::Tensor* invvar, - at::Tensor* input, + at::Tensor* input_or_output, int n1, int n2, #ifdef VERSION_GE_1_1 @@ -375,68 +389,71 @@ void cuda_rms_norm_gradient( at::Tensor* gamma, double epsilon, at::Tensor* grad_input, - at::Tensor* grad_gamma); + at::Tensor* grad_gamma, + bool memory_efficient); at::Tensor rms_norm_gradient( at::Tensor dout, at::Tensor invvar, - at::Tensor input, + at::Tensor input_or_output, #ifdef VERSION_GE_1_1 at::IntArrayRef normalized_shape, #else at::IntList normalized_shape, #endif - double epsilon) { + double epsilon, + bool memory_efficient) { CHECK_INPUT(dout); CHECK_INPUT(invvar); - CHECK_INPUT(input); + CHECK_INPUT(input_or_output); int n1,n2; - check_args(input,normalized_shape,n1,n2); - at::Tensor grad_input = at::empty_like(input); - cuda_rms_norm_gradient(&dout,&invvar,&input,n1,n2, + check_args(input_or_output,normalized_shape,n1,n2); + at::Tensor grad_input = at::empty_like(input_or_output); + cuda_rms_norm_gradient(&dout,&invvar,&input_or_output,n1,n2, normalized_shape,NULL,epsilon, - &grad_input,NULL); + &grad_input,NULL,memory_efficient); return grad_input; } std::vector rms_norm_gradient_affine( at::Tensor dout, at::Tensor invvar, - at::Tensor input, + at::Tensor input_or_output, #ifdef VERSION_GE_1_1 at::IntArrayRef normalized_shape, #else at::IntList normalized_shape, #endif at::Tensor gamma, - double epsilon) { + double epsilon, + bool memory_efficient) { CHECK_INPUT(dout); CHECK_INPUT(invvar); - CHECK_INPUT(input); + CHECK_INPUT(input_or_output); CHECK_INPUT(gamma); int n1,n2; - check_args(input,normalized_shape,gamma,n1,n2); - at::Tensor grad_input = at::empty_like(input); + check_args(input_or_output,normalized_shape,gamma,n1,n2); + at::Tensor grad_input = at::empty_like(input_or_output); at::Tensor grad_gamma = at::empty_like(gamma); - cuda_rms_norm_gradient(&dout,&invvar,&input,n1,n2, + cuda_rms_norm_gradient(&dout,&invvar,&input_or_output,n1,n2, normalized_shape,&gamma,epsilon, - &grad_input,&grad_gamma); + &grad_input,&grad_gamma,memory_efficient); return {grad_input, grad_gamma}; } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)"); - m.def("forward", &layer_norm, "LayerNorm forward (CUDA)"); - m.def("backward_affine", &layer_norm_gradient_affine, "LayerNorm backward (CUDA)"); - m.def("backward", &layer_norm_gradient, "LayerNorm backward (CUDA)"); + m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)", py::call_guard()); + m.def("forward", &layer_norm, "LayerNorm forward (CUDA)", py::call_guard()); + m.def("backward_affine", &layer_norm_gradient_affine, "LayerNorm backward (CUDA)", py::call_guard()); + m.def("backward", &layer_norm_gradient, "LayerNorm backward (CUDA)", py::call_guard()); - m.def("forward_affine_mixed_dtypes", &layer_norm_affine_mixed_dtypes, "LayerNorm forward with mixed dtypes (CUDA) compatible with Megatron's implementation"); + m.def("forward_affine_mixed_dtypes", &layer_norm_affine_mixed_dtypes, "LayerNorm forward with mixed dtypes (CUDA) compatible with Megatron's implementation", py::call_guard()); - m.def("rms_forward_affine", &rms_norm_affine, "RMSNorm forward (CUDA)"); - m.def("rms_forward", &rms_norm, "RMSNorm forward (CUDA)"); - m.def("rms_backward_affine", &rms_norm_gradient_affine, "RMSNorm backward (CUDA)"); - m.def("rms_backward", &rms_norm_gradient, "RMSNorm backward (CUDA)"); + m.def("rms_forward_affine", &rms_norm_affine, "RMSNorm forward (CUDA)", py::call_guard()); + m.def("rms_forward", &rms_norm, "RMSNorm forward (CUDA)", py::call_guard()); + m.def("rms_backward_affine", &rms_norm_gradient_affine, "RMSNorm backward (CUDA)", py::call_guard()); + m.def("rms_backward", &rms_norm_gradient, "RMSNorm backward (CUDA)", py::call_guard()); - m.def("rms_forward_affine_mixed_dtypes", &rms_norm_affine_mixed_dtypes, "RMSNorm forward with mixed dtypes (CUDA) compatible with Megatron's implementation"); -} + m.def("rms_forward_affine_mixed_dtypes", &rms_norm_affine_mixed_dtypes, "RMSNorm forward with mixed dtypes (CUDA) compatible with Megatron's implementation", py::call_guard()); +} \ No newline at end of file diff --git a/csrc/layer_norm_cuda_kernel.cu b/csrc/layer_norm_cuda_kernel.cu index 21366772c..706ec8162 100644 --- a/csrc/layer_norm_cuda_kernel.cu +++ b/csrc/layer_norm_cuda_kernel.cu @@ -7,6 +7,7 @@ #include #include "type_shim.h" +#include "static_switch.h" template __device__ void cuWelfordOnlineSum( @@ -73,7 +74,8 @@ void cuWelfordMuSigma2( const int i1, U& mu, U& sigma2, - U* buf, + U* buf, + const int GPU_WARP_SIZE, bool rms_only) { // Assumptions: @@ -85,6 +87,7 @@ void cuWelfordMuSigma2( U count = U(0); mu= U(0); sigma2 = U(0); + if (i1 < n1) { // one warp normalizes one n1 index, // synchronization is implicit @@ -103,6 +106,9 @@ void cuWelfordMuSigma2( } } } + + + for (; l < n2; ++l) { U curr = static_cast(lvals[l]); if (!rms_only) { @@ -111,16 +117,31 @@ void cuWelfordMuSigma2( cuRMSOnlineSum(curr, sigma2); } } + // intra-warp reductions - for (int l = 0; l <= 4; ++l) { - int srcLaneB = (threadIdx.x+(1<(muB,sigma2B,countB,mu,sigma2,count); - } else { - cuChanRMSOnlineSum(sigma2B, sigma2); + if(USE_ROCM){ + #pragma unroll + for (int stride = GPU_WARP_SIZE / 2; stride > 0; stride /= 2) { + U sigma2B = WARP_SHFL_DOWN(sigma2, stride); + if (!rms_only) { + U muB = WARP_SHFL_DOWN(mu, stride); + U countB = WARP_SHFL_DOWN(count, stride); + cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); + } else { + cuChanRMSOnlineSum(sigma2B, sigma2); + } + } + }else{ + for (int l = 0; l <= 4; ++l) { + int srcLaneB = (threadIdx.x+(1<(muB,sigma2B,countB,mu,sigma2,count); + } else { + cuChanRMSOnlineSum(sigma2B, sigma2); + } } } // threadIdx.x == 0 has correct values for each warp @@ -183,6 +204,7 @@ void cuWelfordMuSigma2( float& mu, float& sigma2, float* buf, + const int GPU_WARP_SIZE, bool rms_only) { // Assumptions: @@ -238,15 +260,30 @@ void cuWelfordMuSigma2( } } // intra-warp reductions - for (int l = 0; l <= 4; ++l) { - int srcLaneB = (threadIdx.x+(1< 0; stride /= 2) { + float sigma2B = WARP_SHFL_DOWN(sigma2, stride); + if (!rms_only) { + float muB = WARP_SHFL_DOWN(mu, stride); + float countB = WARP_SHFL_DOWN(count, stride); + cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); + } else { + cuChanRMSOnlineSum(sigma2B, sigma2); + } + } + } + else{ + for (int l = 0; l <= 4; ++l) { + int srcLaneB = (threadIdx.x+(1< U rsqrt(U v) { return U(1) / sqrt(v); } template<> float rsqrt(float v) { - return rsqrtf(v); + #if defined (USE_ROCM) + return 1/sqrtf(v); + #else + return rsqrtf(v); + #endif } template<> double rsqrt(double v) { return rsqrt(v); @@ -360,6 +401,7 @@ void cuApplyLayerNorm_( const U epsilon, const V* __restrict__ gamma, const V* __restrict__ beta, + const int GPU_WARP_SIZE, bool rms_only ) { @@ -371,7 +413,7 @@ void cuApplyLayerNorm_( SharedMemory shared; U* buf = shared.getPointer(); U mu,sigma2; - cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf,rms_only); + cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf,GPU_WARP_SIZE,rms_only); const T* lvals = vals + i1*n2; V* ovals = output_vals + i1*n2; @@ -418,10 +460,11 @@ void cuApplyLayerNorm( const int n2, const U epsilon, const V* __restrict__ gamma, - const V* __restrict__ beta + const V* __restrict__ beta, + const int warp_size ) { - cuApplyLayerNorm_(output_vals, mean, invvar, vals, n1, n2, epsilon, gamma, beta, false); + cuApplyLayerNorm_(output_vals, mean, invvar, vals, n1, n2, epsilon, gamma, beta, warp_size, false); } template __global__ @@ -432,12 +475,35 @@ void cuApplyRMSNorm( const int n1, const int n2, const U epsilon, - const V* __restrict__ gamma) + const V* __restrict__ gamma, + const int warp_size + ) { - cuApplyLayerNorm_(output_vals, NULL, invvar, vals, n1, n2, epsilon, gamma, NULL, true); + cuApplyLayerNorm_(output_vals, NULL, invvar, vals, n1, n2, epsilon, gamma, NULL, warp_size, true); } -template __device__ + +template __device__ +V clamp_by_magnitude(V curr_gamma, double eps) +{ + const V kMinGamma = V(eps); + if (curr_gamma >= 0) { + if (curr_gamma < kMinGamma) { + return kMinGamma; + } else { + return curr_gamma; + } + } else { + if (curr_gamma > -kMinGamma) { + return -kMinGamma; + } else { + return curr_gamma; + } + } +} + + +template __device__ void cuLoadWriteStridedInputs( const int i1_block, const int thr_load_row_off, @@ -446,34 +512,41 @@ void cuLoadWriteStridedInputs( const int row_stride, U* warp_buf1, U* warp_buf2, - const T* input, + const T* input_or_output, const V* dout, const int i1_end, const int n2, const U* __restrict__ mean, const U* __restrict__ invvar, + const V* __restrict__ gamma, + const V* __restrict__ beta, + const double eps, bool rms_only ) { int i1 = i1_block+thr_load_row_off; if (i1 < i1_end) { - U curr_mean; - if (!rms_only) { - curr_mean = mean[i1]; - } - U curr_invvar = invvar[i1]; for (int k = 0; k < blockDim.y; ++k) { int i2 = i2_off + k; int load_idx = i1*n2+i2; int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k; if (i2(input[load_idx]); + U c_h = static_cast(input_or_output[load_idx]); U curr_dout = static_cast(dout[load_idx]); if (!rms_only) { warp_buf1[write_idx] = curr_dout; - warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar; + if (MemoryEfficient) { + U curr_beta = static_cast(beta[i2]); + warp_buf2[write_idx] = curr_dout * (c_h - curr_beta) / static_cast(clamp_by_magnitude(gamma[i2], eps)); + } else { + warp_buf2[write_idx] = curr_dout * (c_h - mean[i1]) * invvar[i1]; + } } else { - warp_buf2[write_idx] = curr_dout * (curr_input) * curr_invvar; + if (MemoryEfficient) { + warp_buf2[write_idx] = curr_dout * (c_h) / static_cast(clamp_by_magnitude(gamma[i2], eps)); + } else { + warp_buf2[write_idx] = curr_dout * (c_h) * invvar[i1]; + } } } else { if (!rms_only) { @@ -493,7 +566,7 @@ void cuLoadWriteStridedInputs( } } -template __device__ +template __device__ void cuLoadAddStridedInputs( const int i1_block, const int thr_load_row_off, @@ -502,34 +575,41 @@ void cuLoadAddStridedInputs( const int row_stride, U* warp_buf1, U* warp_buf2, - const T* input, + const T* input_or_output, const V* dout, const int i1_end, const int n2, const U* __restrict__ mean, const U* __restrict__ invvar, + const V* __restrict__ gamma, + const V* __restrict__ beta, + const double eps, bool rms_only ) { int i1 = i1_block+thr_load_row_off; if (i1 < i1_end) { - U curr_mean; - if (!rms_only) { - curr_mean = mean[i1]; - } - U curr_invvar = invvar[i1]; for (int k = 0; k < blockDim.y; ++k) { int i2 = i2_off + k; int load_idx = i1*n2+i2; int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k; if (i2(input[load_idx]); + U c_h = static_cast(input_or_output[load_idx]); U curr_dout = static_cast(dout[load_idx]); if (!rms_only) { + U curr_beta = static_cast(beta[i2]); warp_buf1[write_idx] += curr_dout; - warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar; + if (MemoryEfficient) { + warp_buf2[write_idx] += curr_dout * (c_h - curr_beta) / static_cast(clamp_by_magnitude(gamma[i2], eps)); + } else { + warp_buf2[write_idx] += curr_dout * (c_h - mean[i1]) * invvar[i1]; + } } else { - warp_buf2[write_idx] += curr_dout * (curr_input) * curr_invvar; + if (MemoryEfficient) { + warp_buf2[write_idx] += curr_dout * (c_h) / static_cast(clamp_by_magnitude(gamma[i2], eps)); + } else { + warp_buf2[write_idx] += curr_dout * (c_h) * invvar[i1]; + } } } } @@ -537,17 +617,20 @@ void cuLoadAddStridedInputs( } -template __global__ +template __global__ void cuComputePartGradGammaBeta( const V* __restrict__ dout, - const T* __restrict__ input, + const T* __restrict__ input_or_output, const int n1, const int n2, const U* __restrict__ mean, const U* __restrict__ invvar, U epsilon, + const V* __restrict__ gamma, + const V* __restrict__ beta, U* part_grad_gamma, U* part_grad_beta, + const double eps, bool rms_only) { const int numsegs_n1 = (n1+blockDim.y*blockDim.y-1) / (blockDim.y*blockDim.y); @@ -565,9 +648,9 @@ void cuComputePartGradGammaBeta( U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride; // compute partial sums from strided inputs // do this to increase number of loads in flight - cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar, rms_only); + cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input_or_output,dout,i1_end,n2,mean,invvar,gamma,beta,eps, rms_only); for (int i1_block = i1_beg+blockDim.y*blockDim.y; i1_block < i1_end; i1_block+=blockDim.y*blockDim.y) { - cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar, rms_only); + cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input_or_output,dout,i1_end,n2,mean,invvar,gamma,beta,eps, rms_only); } __syncthreads(); // inter-warp reductions @@ -675,78 +758,108 @@ void cuComputeGradGammaBeta( } -template __global__ +template __global__ void cuComputeGradInput( const V* __restrict__ dout, - const T* __restrict__ input, + const T* __restrict__ input_or_output, const int n1, const int n2, const U* __restrict__ mean, const U* __restrict__ invvar, U epsilon, const V* gamma, + const V* beta, T* grad_input, + const double eps, bool rms_only) { for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) { U sum_loss1 = U(0); U sum_loss2 = U(0); - U c_mean; - if (!rms_only) { - c_mean = mean[i1]; - } - const U c_invvar = invvar[i1]; - const T* k_input = input + i1*n2; + const T* k_h = input_or_output + i1*n2; const V* k_dout = dout + i1*n2; + const U c_invvar = invvar[i1]; + const U c_mean = !MemoryEfficient ? mean[i1] : 0.; const int numx = blockDim.x * blockDim.y; const int thrx = threadIdx.x + threadIdx.y * blockDim.x; if (gamma != NULL) { int l = 4*thrx; for (; l+3 < n2; l+=4*numx) { for (int k = 0; k < 4; ++k) { - const U c_h = static_cast(k_input[l+k]); + const U c_h = static_cast(k_h[l+k]); const U c_loss = static_cast(k_dout[l+k]); if (!rms_only) { sum_loss1 += c_loss * gamma[l+k]; - sum_loss2 += c_loss * gamma[l+k] * (c_h - c_mean) * c_invvar; + if (MemoryEfficient) { + sum_loss2 += c_loss * (c_h - beta[l+k]); + } else { + sum_loss2 += c_loss * gamma[l+k] * (c_h - c_mean) * c_invvar; + } } else { - sum_loss2 += c_loss * gamma[l+k] * (c_h) * c_invvar; + if (MemoryEfficient) { + sum_loss2 += c_loss * c_h; + } else { + sum_loss2 += c_loss * gamma[l+k] * (c_h) * c_invvar; + } } } } for (; l < n2; ++l) { - const U c_h = static_cast(k_input[l]); + const U c_h = static_cast(k_h[l]); const U c_loss = static_cast(k_dout[l]); if (!rms_only) { sum_loss1 += c_loss * gamma[l]; - sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar; + if (MemoryEfficient) { + sum_loss2 += c_loss * (c_h - beta[l]); + } else { + sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar; + } } else { - sum_loss2 += c_loss * gamma[l] * (c_h) * c_invvar; + if (MemoryEfficient) { + sum_loss2 += c_loss * c_h; + } else { + sum_loss2 += c_loss * gamma[l] * (c_h) * c_invvar; + } } - } } else { int l = 4*thrx; for (; l+3 < n2; l+=4*numx) { for (int k = 0; k < 4; ++k) { - const U c_h = static_cast(k_input[l+k]); + const U c_h = static_cast(k_h[l+k]); const U c_loss = static_cast(k_dout[l+k]); if (!rms_only) { sum_loss1 += c_loss; - sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + if (MemoryEfficient) { + sum_loss2 += c_loss * c_h; + } else { + sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + } } else { - sum_loss2 += c_loss * (c_h) * c_invvar; + if (MemoryEfficient) { + sum_loss2 += c_loss * c_h; + } else { + sum_loss2 += c_loss * (c_h) * c_invvar; + } } } } for (; l < n2; ++l) { - const U c_h = static_cast(k_input[l]); + const U c_h = static_cast(k_h[l]); const U c_loss = static_cast(k_dout[l]); if (!rms_only) { sum_loss1 += c_loss; - sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + if (MemoryEfficient) { + sum_loss2 += c_loss * c_h; + } else { + sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + } } else { - sum_loss2 += c_loss * (c_h) * c_invvar; + if (MemoryEfficient) { + sum_loss2 += c_loss * c_h; + } else { + sum_loss2 += c_loss * (c_h) * c_invvar; + } } } } @@ -801,28 +914,46 @@ void cuComputeGradInput( T* k_grad_input = grad_input + i1*n2; if (gamma != NULL) { for (int l = thrx; l < n2; l+=numx) { - const U c_h = static_cast(k_input[l]); + const U c_h = static_cast(k_h[l]); const U c_loss = static_cast(k_dout[l]); - U f_grad_input = fH * c_loss * gamma[l]; + const U k_gamma = static_cast(clamp_by_magnitude(gamma[l], eps)); + U f_grad_input = fH * c_loss * k_gamma; if (!rms_only) { + const U k_beta = beta[l]; f_grad_input -= sum_loss1; - f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + if (MemoryEfficient) { + f_grad_input -= (c_h - k_beta) / k_gamma * sum_loss2; + } else { + f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + } } else { - f_grad_input -= (c_h) * c_invvar * sum_loss2; + if (MemoryEfficient) { + f_grad_input -= c_h / k_gamma * sum_loss2; + } else { + f_grad_input -= c_h * c_invvar * sum_loss2; + } } f_grad_input *= term1; k_grad_input[l] = static_cast(f_grad_input); } } else { for (int l = thrx; l < n2; l+=numx) { - const U c_h = static_cast(k_input[l]); + const U c_h = static_cast(k_h[l]); const U c_loss = static_cast(k_dout[l]); U f_grad_input = fH * c_loss; if (!rms_only) { f_grad_input -= sum_loss1; - f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + if (MemoryEfficient) { + f_grad_input -= c_h * sum_loss2; + } else { + f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + } } else { - f_grad_input -= (c_h) * c_invvar * sum_loss2; + if (MemoryEfficient) { + f_grad_input -= c_h * sum_loss2; + } else { + f_grad_input -= c_h * c_invvar * sum_loss2; + } } f_grad_input *= term1; k_grad_input[l] = static_cast(f_grad_input); @@ -848,15 +979,22 @@ void HostApplyLayerNorm( ) { auto stream = at::cuda::getCurrentCUDAStream().stream(); - const dim3 threads(32,4,1); + const int warp_size = at::cuda::warp_size(); + dim3 threads(warp_size ,4, 1); // MI100 wavefront/warp = 64 + #ifdef USE_ROCM + // Optimization for ROCm MI100 + threads.y = 1; + #endif + const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1); int nshared = threads.y > 1 ? threads.y*sizeof(U)+(threads.y/2)*sizeof(U) : 0; + cuApplyLayerNorm<<>>( - output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta); + output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta, warp_size); } template @@ -870,15 +1008,20 @@ void HostApplyRMSNorm( const V* gamma) { auto stream = at::cuda::getCurrentCUDAStream().stream(); - const dim3 threads(32,4,1); + const int warp_size = at::cuda::warp_size(); const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1); + dim3 threads(warp_size,4,1); + #ifdef USE_ROCM + // Optimization for ROCm MI100 + threads.y = 2; + #endif int nshared = threads.y > 1 ? threads.y*sizeof(U)+(threads.y/2)*sizeof(U) : 0; cuApplyRMSNorm<<>>( - output, invvar, input, n1, n2, U(epsilon), gamma); + output, invvar, input, n1, n2, U(epsilon), gamma, warp_size); } void cuda_layer_norm( @@ -947,7 +1090,7 @@ void HostLayerNormGradient( const V* dout, const U* mean, const U* invvar, - at::Tensor* input, + at::Tensor* input_or_output, int n1, int n2, const V* gamma, @@ -955,7 +1098,8 @@ void HostLayerNormGradient( double epsilon, T* grad_input, V* grad_gamma, - V* grad_beta + V* grad_beta, + bool memory_efficient ) { auto stream = at::cuda::getCurrentCUDAStream().stream(); @@ -971,21 +1115,27 @@ void HostLayerNormGradient( // note (mkozuki): I can hard code part_grad_gamma's dtype as float given that // the `cuda_layer_norm_gradient` doesn't support double. const auto part_grad_dtype = - (input->scalar_type() == at::ScalarType::Half || input->scalar_type() == at::ScalarType::BFloat16) ? + (input_or_output->scalar_type() == at::ScalarType::Half || input_or_output->scalar_type() == at::ScalarType::BFloat16) ? at::ScalarType::Float : - input->scalar_type(); - at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype(part_grad_dtype)); + input_or_output->scalar_type(); + at::Tensor part_grad_gamma = at::empty({part_size,n2}, input_or_output->options().dtype(part_grad_dtype)); at::Tensor part_grad_beta = at::empty_like(part_grad_gamma); - cuComputePartGradGammaBeta<<>>( - dout, - input->DATA_PTR(), - n1,n2, - mean, - invvar, - U(epsilon), - part_grad_gamma.DATA_PTR(), - part_grad_beta.DATA_PTR(), - false); + BOOL_SWITCH(memory_efficient, MemoryEfficient, [&]{ + auto kernel = &cuComputePartGradGammaBeta; + kernel<<>>( + dout, + input_or_output->DATA_PTR(), + n1,n2, + mean, + invvar, + U(epsilon), + gamma, + beta, + part_grad_gamma.DATA_PTR(), + part_grad_beta.DATA_PTR(), + epsilon, + false); + }); const dim3 threads3(32,8,1); const dim3 blocks3((n2+threads2.x-1)/threads2.x,1,1); @@ -1008,29 +1158,35 @@ void HostLayerNormGradient( threads1.y > 1 ? threads1.y*threads1.x*sizeof(U) : 0; - cuComputeGradInput<<>>( - dout, - input->DATA_PTR(), - n1,n2, - mean, - invvar, - U(epsilon), - gamma, - grad_input, - false); + BOOL_SWITCH(memory_efficient, MemoryEfficient, [&] { + auto kernel = cuComputeGradInput; + kernel<<>>( + dout, + input_or_output->DATA_PTR(), + n1,n2, + mean, + invvar, + U(epsilon), + gamma, + beta, + grad_input, + epsilon, + false); + }); } template void HostRMSNormGradient( const V* dout, const U* invvar, - at::Tensor* input, + at::Tensor* input_or_output, int n1, int n2, const V* gamma, double epsilon, T* grad_input, - V* grad_gamma) + V* grad_gamma, + bool memory_efficient) { auto stream = at::cuda::getCurrentCUDAStream().stream(); @@ -1044,20 +1200,27 @@ void HostRMSNormGradient( // note (mkozuki): I can hard code part_grad_gamma's dtype as float given that // the `cuda_layer_norm_gradient` doesn't support double. const auto part_grad_dtype = - (input->scalar_type() == at::ScalarType::Half || input->scalar_type() == at::ScalarType::BFloat16) ? + (input_or_output->scalar_type() == at::ScalarType::Half || input_or_output->scalar_type() == at::ScalarType::BFloat16) ? at::ScalarType::Float : - input->scalar_type(); - at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype(part_grad_dtype)); - cuComputePartGradGammaBeta<<>>( - dout, - input->DATA_PTR(), - n1,n2, - invvar, // unused - invvar, - U(epsilon), - part_grad_gamma.DATA_PTR(), - part_grad_gamma.DATA_PTR(), /* unused */ - true); + input_or_output->scalar_type(); + at::Tensor part_grad_gamma = at::empty({part_size,n2}, input_or_output->options().dtype(part_grad_dtype)); + BOOL_SWITCH(memory_efficient, MemoryEfficient, [&]{ + auto kernel = &cuComputePartGradGammaBeta; + kernel<<>>( + dout, + input_or_output->DATA_PTR(), + n1,n2, + invvar, /* unused */ + invvar, + U(epsilon), + gamma, + gamma, /* unused */ + part_grad_gamma.DATA_PTR(), + part_grad_gamma.DATA_PTR(), /* unused */ + epsilon, + true); + }); + const dim3 threads3(32,8,1); const dim3 blocks3((n2+threads2.x-1)/threads2.x,1,1); @@ -1080,23 +1243,28 @@ void HostRMSNormGradient( threads1.y > 1 ? threads1.y*threads1.x*sizeof(U) : 0; - cuComputeGradInput<<>>( - dout, - input->DATA_PTR(), - n1,n2, - invvar, /* unused */ - invvar, - U(epsilon), - gamma, - grad_input, - true); + BOOL_SWITCH(memory_efficient, MemoryEfficient, [&] { + auto kernel = cuComputeGradInput; + kernel<<>>( + dout, + input_or_output->DATA_PTR(), + n1,n2, + invvar, /* unused */ + invvar, + U(epsilon), + gamma, + gamma, /* unused */ + grad_input, + epsilon, + true); + }); } void cuda_layer_norm_gradient( at::Tensor* dout, at::Tensor* mean, at::Tensor* invvar, - at::Tensor* input, + at::Tensor* input_or_output, int n1, int n2, #ifdef VERSION_GE_1_1 @@ -1109,18 +1277,19 @@ void cuda_layer_norm_gradient( double epsilon, at::Tensor* grad_input, at::Tensor* grad_gamma, - at::Tensor* grad_beta) + at::Tensor* grad_beta, + bool memory_efficient) { using namespace at; // we can do away with `accscalar_t` as there're only three dtypes: fp32, fp16, bf16 DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( - input->scalar_type(), gamma == NULL ? input->scalar_type() : gamma->scalar_type(), "cuComputeGradInput", + input_or_output->scalar_type(), gamma == NULL ? input_or_output->scalar_type() : gamma->scalar_type(), "cuComputeGradInput", using accscalar_t = at::acc_type; HostLayerNormGradient( dout->DATA_PTR(), - mean->DATA_PTR(), + mean != NULL ? mean->DATA_PTR() : NULL, invvar->DATA_PTR(), - input, + input_or_output, n1,n2, // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta // if gamma Tensor is NULL on input. @@ -1129,14 +1298,15 @@ void cuda_layer_norm_gradient( epsilon, grad_input->DATA_PTR(), gamma != NULL ? grad_gamma->DATA_PTR() : NULL, - gamma != NULL ? grad_beta->DATA_PTR() : NULL); + gamma != NULL ? grad_beta->DATA_PTR() : NULL, + memory_efficient); ) } void cuda_rms_norm_gradient( at::Tensor* dout, at::Tensor* invvar, - at::Tensor* input, + at::Tensor* input_or_output, int n1, int n2, #ifdef VERSION_GE_1_1 @@ -1147,24 +1317,26 @@ void cuda_rms_norm_gradient( at::Tensor* gamma, double epsilon, at::Tensor* grad_input, - at::Tensor* grad_gamma) + at::Tensor* grad_gamma, + bool memory_efficient) { using namespace at; // we can do away with `accscalar_t` as there're only three dtypes: fp32, fp16, bf16 // DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( - input->scalar_type(), gamma == NULL ? input->scalar_type() : gamma->scalar_type(), "cuComputeGradInputRMS", + input_or_output->scalar_type(), gamma == NULL ? input_or_output->scalar_type() : gamma->scalar_type(), "cuComputeGradInputRMS", using accscalar_t = at::acc_type; HostRMSNormGradient( dout->DATA_PTR(), invvar->DATA_PTR(), - input, + input_or_output, n1,n2, // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta // if gamma Tensor is NULL on input. gamma != NULL ? gamma->DATA_PTR() : NULL, epsilon, grad_input->DATA_PTR(), - gamma != NULL ? grad_gamma->DATA_PTR() : NULL); + gamma != NULL ? grad_gamma->DATA_PTR() : NULL, + memory_efficient); ) -} +} \ No newline at end of file diff --git a/csrc/megatron/fused_bias_swiglu.cpp b/csrc/megatron/fused_bias_swiglu.cpp new file mode 100644 index 000000000..0f1cb8d5f --- /dev/null +++ b/csrc/megatron/fused_bias_swiglu.cpp @@ -0,0 +1,11 @@ +#include + +// Function declarations +torch::Tensor fused_bias_swiglu_forward(torch::Tensor input, torch::Tensor bias); +torch::Tensor fused_bias_swiglu_backward(torch::Tensor grad_output, torch::Tensor input, torch::Tensor bias); + +// Register functions for PyTorch extension +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &fused_bias_swiglu_forward, "Fused Bias SwiGLU Forward (CUDA)"); + m.def("backward", &fused_bias_swiglu_backward, "Fused Bias SwiGLU Backward (CUDA)"); +} \ No newline at end of file diff --git a/csrc/megatron/fused_bias_swiglu_cuda.cu b/csrc/megatron/fused_bias_swiglu_cuda.cu new file mode 100644 index 000000000..6f5e54961 --- /dev/null +++ b/csrc/megatron/fused_bias_swiglu_cuda.cu @@ -0,0 +1,143 @@ +#include +#include + +// Swish (SiLU) activation function: SiLU(x) = x * sigmoid(x) +__device__ __forceinline__ float silu(float x) { + return x / (1.0f + expf(-x)); +} + +// CUDA kernel for Fused Bias SwiGLU with chunking +template +__global__ void fused_bias_swiglu_kernel(const T* __restrict__ input, + const T* __restrict__ bias, + T* __restrict__ output, + int half_dim, + int max_index) { + int output_idx = blockIdx.x * blockDim.x + threadIdx.x; + int row_idx = output_idx / half_dim; + int input_idx = output_idx + row_idx * half_dim; + int col_idx = output_idx - row_idx * half_dim; + + if (output_idx < max_index) { + int other_chunk_idx = input_idx + half_dim; + int other_col_idx = col_idx + half_dim; + + T x1 = input[input_idx] + bias[col_idx]; + T x2 = input[other_chunk_idx] + bias[other_col_idx]; + output[output_idx] = silu(x1) * x2; + } +} + +// CUDA Kernel: Computes the backward pass for fused bias SwiGLU +template +__global__ void fused_bias_swiglu_backward_kernel( + const T* __restrict__ grad_output, + const T* __restrict__ input, + const T* __restrict__ bias, + T* __restrict__ grad_input, + int half_dim, int max_index) { + + int output_idx = blockIdx.x * blockDim.x + threadIdx.x; + int row_idx = output_idx / half_dim; + int input_idx = output_idx + row_idx * half_dim; + int col_idx = output_idx - row_idx * half_dim; + + if (output_idx < max_index) { + int other_chunk_idx = input_idx + half_dim; + int other_col_idx = col_idx + half_dim; + + T y1 = input[input_idx] + bias[col_idx]; + T y2 = input[other_chunk_idx] + bias[other_col_idx]; + + T sigmoid_y1 = 1.0f / (1.0f + expf(-y1)); + T silu_y1 = y1 * sigmoid_y1; + + T g = grad_output[output_idx]; + T d_y1 = g * sigmoid_y1 * (1.0f + y1 * (1.0f - sigmoid_y1)) * y2; + T d_y2 = g * silu_y1; + + grad_input[input_idx] += d_y1; + grad_input[other_chunk_idx] += d_y2; + } +} + +// PyTorch interface for CUDA kernel +torch::Tensor fused_bias_swiglu_forward(torch::Tensor input, torch::Tensor bias) { + int batch_size = input.size(0); + int hidden_dim = input.size(1); + int half_dim = hidden_dim / 2; + TORCH_CHECK(hidden_dim % 2 == 0, "Hidden dimension must be divisible by 2 for SwiGLU"); + TORCH_CHECK(input.is_cuda(), "Input must be on CUDA device"); + TORCH_CHECK(bias.is_cuda(), "Bias must be on CUDA device"); + + input = input.contiguous(); + bias = bias.contiguous(); + + auto output = torch::zeros({batch_size, hidden_dim / 2}, input.options()); + + cudaDeviceProp prop; + cudaGetDeviceProperties(&prop, 0); + int threads = prop.maxThreadsPerBlock; + int blocks = (batch_size * half_dim + threads - 1) / threads; + blocks = min(blocks, prop.maxGridSize[0]); + + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "fused_bias_swiglu_forward", [&] { + fused_bias_swiglu_kernel<<>>( + input.data_ptr(), bias.data_ptr(), output.data_ptr(), half_dim, half_dim * batch_size + ); + }); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + std::cerr << "CUDA kernel launch error: " << cudaGetErrorString(err) << std::endl; + } + + err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { + std::cerr << "CUDA kernel execution error: " << cudaGetErrorString(err) << std::endl; + } + + return output; +} + +// PyTorch interface for backward pass +torch::Tensor fused_bias_swiglu_backward( + torch::Tensor grad_output, torch::Tensor input, torch::Tensor bias) { + + int batch_size = input.size(0); + int hidden_dim = input.size(1); + int half_dim = hidden_dim / 2; + + TORCH_CHECK(hidden_dim % 2 == 0, "Hidden dimension must be divisible by 2 for SwiGLU"); + + auto grad_input = torch::zeros_like(input); + + cudaDeviceProp prop; + cudaGetDeviceProperties(&prop, 0); + int threads = prop.maxThreadsPerBlock; + int blocks = (batch_size * half_dim + threads - 1) / threads; + blocks = min(blocks, prop.maxGridSize[0]); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "fused_bias_swiglu_backward", [&] { + fused_bias_swiglu_backward_kernel<<>>( + grad_output.data_ptr(), + input.data_ptr(), + bias.data_ptr(), + grad_input.data_ptr(), + half_dim, half_dim * batch_size + ); + }); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + std::cerr << "CUDA kernel launch error: " << cudaGetErrorString(err) << std::endl; + } + + err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { + std::cerr << "CUDA kernel execution error: " << cudaGetErrorString(err) << std::endl; + } + + return grad_input; +} \ No newline at end of file diff --git a/csrc/megatron/fused_rotary_positional_embedding.cpp b/csrc/megatron/fused_rotary_positional_embedding.cpp new file mode 100644 index 000000000..782e4ec5d --- /dev/null +++ b/csrc/megatron/fused_rotary_positional_embedding.cpp @@ -0,0 +1,243 @@ +/* coding=utf-8 + * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +namespace fused_rope { + +torch::Tensor fwd_cuda(const torch::Tensor &input, const torch::Tensor &freqs, + const bool transpose_output); + +torch::Tensor bwd_cuda(const torch::Tensor &output_grads, + const torch::Tensor &freqs, const bool transpose_output); + +torch::Tensor fwd_cached_cuda(const torch::Tensor &input, + const torch::Tensor &cos, + const torch::Tensor &sin, + const bool transpose_output); + +torch::Tensor bwd_cached_cuda(const torch::Tensor &output_grads, + const torch::Tensor &cos, + const torch::Tensor &sin, + const bool transpose_output); + +torch::Tensor fwd_thd_cuda(const torch::Tensor &input, + const torch::Tensor &cu_seqlens, + const torch::Tensor &freqs); + +torch::Tensor bwd_thd_cuda(const torch::Tensor &output_grads, + const torch::Tensor &cu_seqlens, + const torch::Tensor &freqs); + +torch::Tensor fwd_2d_cuda(const torch::Tensor &input, + const torch::Tensor &cos_h, + const torch::Tensor &sin_h, + const torch::Tensor &cos_w, + const torch::Tensor &sin_w); + +torch::Tensor bwd_2d_cuda(const torch::Tensor &output_grads, + const torch::Tensor &cos_h, + const torch::Tensor &sin_h, + const torch::Tensor &cos_w, + const torch::Tensor &sin_w); + +torch::Tensor fwd(const at::Tensor &input, const at::Tensor &freqs, + const bool transpose_output) { + TORCH_CHECK(input.dim() == 4, "expected 4D tensor"); + TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor"); + TORCH_CHECK(input.size(0) == freqs.size(0), + "expected input and freqs tensor have the same sequence length"); + TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1, + "expected the second and third dims of the freqs tensor equal 1"); + TORCH_CHECK(input.size(3) >= freqs.size(3), + "expected the last dim of the input tensor equals or is " + "greater than the freqs tensor"); + TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float, + "Dtype of the freqs tensor must be float"); + + return fwd_cuda(input, freqs, transpose_output); +} + +torch::Tensor bwd(const torch::Tensor &output_grads, const at::Tensor &freqs, + const bool transpose_output) { + TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor"); + TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor"); + TORCH_CHECK( + output_grads.size(0) == freqs.size(0), + "expected output_grads and freqs tensor have the same sequence length"); + TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1, + "expected the second and third dims of the freqs tensor equal 1"); + TORCH_CHECK(output_grads.size(3) >= freqs.size(3), + "expected the last dim of the output_grads tensor equals or is " + "greater than the freqs tensor"); + TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float, + "Dtype of the freqs tensor must be float"); + + return bwd_cuda(output_grads, freqs, transpose_output); +} + +torch::Tensor fwd_cached(const at::Tensor &input, const at::Tensor &cos, + const at::Tensor &sin, const bool transpose_output) { + TORCH_CHECK(input.dim() == 4, "expected 4D tensor"); + TORCH_CHECK(cos.dim() == 4, "expected 4D tensor"); + TORCH_CHECK(sin.dim() == 4, "expected 4D tensor"); + TORCH_CHECK(input.size(0) == cos.size(0), + "expected input and cos tensor have the same sequence length"); + TORCH_CHECK(input.size(0) == sin.size(0), + "expected input and sin tensor have the same sequence length"); + TORCH_CHECK(cos.size(1) == 1 && cos.size(2) == 1, + "expected the second and third dims of the cos tensor equal 1"); + TORCH_CHECK(sin.size(1) == 1 && sin.size(2) == 1, + "expected the second and third dims of the sin tensor equal 1"); + TORCH_CHECK(cos.size(3) == sin.size(3), + "expected cos and sin tensor have the same last dim"); + TORCH_CHECK(input.size(3) >= cos.size(3), + "expected the last dim of the input tensor equals or is " + "greater than the cos tensor"); + TORCH_CHECK(cos.scalar_type() == sin.scalar_type(), + "expected cos and sin tensor have the same dtype"); + + return fwd_cached_cuda(input, cos, sin, transpose_output); +} + +torch::Tensor bwd_cached(const torch::Tensor &output_grads, + const at::Tensor &cos, const at::Tensor &sin, + const bool transpose_output) { + TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor"); + TORCH_CHECK(cos.dim() == 4, "expected 4D tensor"); + TORCH_CHECK(sin.dim() == 4, "expected 4D tensor"); + TORCH_CHECK( + output_grads.size(0) == cos.size(0), + "expected output_grads and cos tensor have the same sequence length"); + TORCH_CHECK( + output_grads.size(0) == sin.size(0), + "expected output_grads and sin tensor have the same sequence length"); + TORCH_CHECK(cos.size(1) == 1 && cos.size(2) == 1, + "expected the second and third dims of the cos tensor equal 1"); + TORCH_CHECK(sin.size(1) == 1 && sin.size(2) == 1, + "expected the second and third dims of the sin tensor equal 1"); + TORCH_CHECK(cos.size(3) == sin.size(3), + "expected cos and sin tensor have the same last dim"); + TORCH_CHECK(output_grads.size(3) >= cos.size(3), + "expected the last dim of the output_grads tensor equals or is " + "greater than the cos tensor"); + TORCH_CHECK(cos.scalar_type() == sin.scalar_type(), + "expected cos and sin tensor have the same dtype"); + + return bwd_cached_cuda(output_grads, cos, sin, transpose_output); +} + +torch::Tensor fwd_thd(const torch::Tensor &input, + const torch::Tensor &cu_seqlens, + const torch::Tensor &freqs) { + TORCH_CHECK(input.dim() == 3, "expected 3D tensor"); + TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor"); + TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor"); + TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1, + "expected the second and third dims of the freqs tensor equal 1"); + TORCH_CHECK(input.size(2) >= freqs.size(3), + "expected the last dim of the input tensor equals or is " + "greater than the freqs tensor"); + TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float, + "Dtype of the freqs tensor must be float"); + + return fwd_thd_cuda(input, cu_seqlens, freqs); +} + +torch::Tensor bwd_thd(const torch::Tensor &output_grads, + const torch::Tensor &cu_seqlens, + const torch::Tensor &freqs) { + TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); + TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor"); + TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor"); + TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1, + "expected the second and third dims of the freqs tensor equal 1"); + TORCH_CHECK(output_grads.size(2) >= freqs.size(3), + "expected the last dim of the output_grads tensor equals or is " + "greater than the freqs tensor"); + TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float, + "Dtype of the freqs tensor must be float"); + + return bwd_thd_cuda(output_grads, cu_seqlens, freqs); +} + +torch::Tensor fwd_2d(const torch::Tensor &input, const torch::Tensor &cos_h, + const torch::Tensor &sin_h, const torch::Tensor &cos_w, + const torch::Tensor &sin_w) { + TORCH_CHECK(input.dim() == 5, "expected input to be 5D tensor"); + TORCH_CHECK(cos_h.dim() == 4, "expected cos_h to be 4D tensor"); + TORCH_CHECK(sin_h.dim() == 4, "expected sin_h to be 4D tensor"); + TORCH_CHECK(cos_w.dim() == 4, "expected cos_w to be 4D tensor"); + TORCH_CHECK(sin_w.dim() == 4, "expected sin_w to be 4D tensor"); + TORCH_CHECK(cos_h.size(2) == 1, "expected third dim of cos_h/sin_h equals 1"); + TORCH_CHECK(input.size(1) <= cos_h.size(1), + "expected input's height <= cos_h/sin_h's"); + TORCH_CHECK(input.size(4) / 2 == cos_h.size(3), + "expected cos_h/sin_h's head dim equals input's head dim / 2"); + TORCH_CHECK(cos_w.size(2) == 1, "expected third dim of cos_w/sin_w equals 1"); + TORCH_CHECK(input.size(2) <= cos_w.size(1), + "expected input's width <= cos_w/sin_w's"); + TORCH_CHECK(input.size(4) / 2 == cos_w.size(3), + "expected cos_w/sin_w's head dim equals input's head dim / 2"); + + return fwd_2d_cuda(input, cos_h, sin_h, cos_w, sin_w); +} + +torch::Tensor bwd_2d(const torch::Tensor &output_grads, + const torch::Tensor &cos_h, const torch::Tensor &sin_h, + const torch::Tensor &cos_w, const torch::Tensor &sin_w) { + TORCH_CHECK(output_grads.dim() == 5, "expected output_grads to be 5D tensor"); + TORCH_CHECK(cos_h.dim() == 4, "expected cos_h to be 4D tensor"); + TORCH_CHECK(sin_h.dim() == 4, "expected sin_h to be 4D tensor"); + TORCH_CHECK(cos_w.dim() == 4, "expected cos_w to be 4D tensor"); + TORCH_CHECK(sin_w.dim() == 4, "expected sin_w to be 4D tensor"); + TORCH_CHECK(cos_h.size(2) == 1, "expected third dim of cos_h/sin_h equals 1"); + TORCH_CHECK(output_grads.size(1) <= cos_h.size(1), + "expected output_grads' height <= cos_h/sin_h's"); + TORCH_CHECK(output_grads.size(4) / 2 == cos_h.size(3), + "expected cos_h/sin_h's head dim equals output_grads' head dim / 2"); + TORCH_CHECK(cos_w.size(2) == 1, "expected third dim of cos_w/sin_w equals 1"); + TORCH_CHECK(output_grads.size(2) <= cos_w.size(1), + "expected output_grads' width <= cos_w/sin_w's"); + TORCH_CHECK(output_grads.size(4) / 2 == cos_w.size(3), + "expected cos_w/sin_w's head dim equals output_grads' head dim / 2"); + + return bwd_2d_cuda(output_grads, cos_h, sin_h, cos_w, sin_w); +} + +} // end namespace fused_rope + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &fused_rope::fwd, + "Fused Rotary Positional Embedding -- Forward."); + m.def("backward", &fused_rope::bwd, + "Fused Rotary Positional Embedding -- Backward."); + // cache sin/cos + m.def("forward_cached", &fused_rope::fwd_cached, + "Fused Rotary Positional Embedding Cached -- Forward."); + m.def("backward_cached", &fused_rope::bwd_cached, + "Fused Rotary Positional Embedding Cached -- Backward."); + // thd + m.def("forward_thd", &fused_rope::fwd_thd, + "Fused Rotary Positional Embedding for thd layout -- Forward."); + m.def("backward_thd", &fused_rope::bwd_thd, + "Fused Rotary Positional Embedding for thd layout -- Backward."); + // 2d + m.def("forward_2d", &fused_rope::fwd_2d, + "2D Fused Rotary Positional Embedding -- Forward."); + m.def("backward_2d", &fused_rope::bwd_2d, + "2D Fused Rotary Positional Embedding -- Backward."); +} diff --git a/csrc/megatron/fused_rotary_positional_embedding.h b/csrc/megatron/fused_rotary_positional_embedding.h new file mode 100644 index 000000000..1f031c338 --- /dev/null +++ b/csrc/megatron/fused_rotary_positional_embedding.h @@ -0,0 +1,486 @@ +/* coding=utf-8 + * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include + +namespace { + +template +__device__ void fused_rope_block_forward( + const scalar_t* src, const float* freqs, scalar_t* dst, + const int offset_block, const int offset_block_dst, const int h, + const int d, const int d2, const int stride_h, const int stride_d, + const int o_stride_h, const int o_stride_d) { + int s_id = blockIdx.x; +#pragma unroll + for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { + float v_cos, v_sin; + sincosf(freqs[s_id * d2 + d_id], &v_sin, &v_cos); +#pragma unroll + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { + int offset_src = offset_block + h_id * stride_h + d_id * stride_d; + int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d; + scalar_t v_src = src[offset_src]; + scalar_t v_src_rotate = (d_id + d2 / 2 < d2) + ? -src[offset_src + (d2 / 2) * stride_d] + : src[offset_src + (d2 / 2 - d2) * stride_d]; + dst[offset_dst] = + v_src * (scalar_t)v_cos + v_src_rotate * (scalar_t)v_sin; + } + } + + // copy the rest + if (d > d2) { +#pragma unroll + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { + int offset_head = offset_block + h_id * stride_h; + int offset_head_dst = offset_block_dst + h_id * o_stride_h; +#pragma unroll + for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) { + dst[offset_head_dst + d_id * o_stride_d] = + src[offset_head + d_id * stride_d]; + } + } + } +} + +template +__device__ void fused_rope_block_backward( + const scalar_t* src, const float* freqs, scalar_t* dst, + const int offset_block, const int offset_block_dst, const int h, + const int d, const int d2, const int stride_h, const int stride_d, + const int o_stride_h, const int o_stride_d) { + int s_id = blockIdx.x; +#pragma unroll + for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { + scalar_t v_cos = cosf(freqs[s_id * d2 + d_id]); + scalar_t v_sin = (d_id + d2 / 2 < d2) + ? sinf(freqs[s_id * d2 + d_id + d2 / 2]) + : -sinf(freqs[s_id * d2 + d_id + d2 / 2 - d2]); +#pragma unroll + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { + int offset_src = offset_block + h_id * stride_h + d_id * stride_d; + int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d; + scalar_t v_src = src[offset_src]; + scalar_t v_src_rotate = (d_id + d2 / 2 < d2) + ? src[offset_src + (d2 / 2) * stride_d] + : src[offset_src + (d2 / 2 - d2) * stride_d]; + dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin; + } + } + + // handle the tail + if (d > d2) { +#pragma unroll + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { + int offset_head = offset_block + h_id * stride_h; + int offset_head_dst = offset_block_dst + h_id * o_stride_h; +#pragma unroll + for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) { + dst[offset_head_dst + d_id * o_stride_d] = + src[offset_head + d_id * stride_d]; + } + } + } +} + +template +__global__ void fused_rope_forward(const int h, const int d, const int d2, + const int stride_s, const int stride_b, + const int stride_h, const int stride_d, + const int o_stride_s, const int o_stride_b, + const int o_stride_h, const int o_stride_d, + const scalar_t* src, const float* freqs, + scalar_t* dst) { + int s_id = blockIdx.x, b_id = blockIdx.y; + int offset_block = s_id * stride_s + b_id * stride_b; + int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b; + fused_rope_block_forward(src, freqs, dst, offset_block, offset_block_dst, h, + d, d2, stride_h, stride_d, o_stride_h, o_stride_d); +} + +template +__global__ void fused_rope_backward(const int h, const int d, const int d2, + const int stride_s, const int stride_b, + const int stride_h, const int stride_d, + const int o_stride_s, const int o_stride_b, + const int o_stride_h, const int o_stride_d, + const scalar_t* src, const float* freqs, + scalar_t* dst) { + int s_id = blockIdx.x, b_id = blockIdx.y; + int offset_block = s_id * stride_s + b_id * stride_b; + int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b; + fused_rope_block_backward(src, freqs, dst, offset_block, offset_block_dst, h, + d, d2, stride_h, stride_d, o_stride_h, o_stride_d); +} + +template +__device__ void fused_rope_cached_block_forward( + const scalar_t_0* src, const scalar_t_1* cos, const scalar_t_1* sin, + scalar_t_0* dst, const int s_id, const int offset_block, + const int offset_block_dst, const int h, const int d, const int d2, + const int stride_h, const int stride_d, const int o_stride_h, + const int o_stride_d) { +#pragma unroll + for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { + scalar_t_0 v_cos = cos[s_id * d2 + d_id]; + scalar_t_0 v_sin = sin[s_id * d2 + d_id]; +#pragma unroll + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { + int offset_src = offset_block + h_id * stride_h + d_id * stride_d; + int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d; + scalar_t_0 v_src = src[offset_src]; + scalar_t_0 v_src_rotate = + (d_id + d2 / 2 < d2) ? -src[offset_src + (d2 / 2) * stride_d] + : src[offset_src + (d2 / 2 - d2) * stride_d]; + dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin; + } + } + + // copy the rest + if (d > d2) { +#pragma unroll + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { + int offset_head = offset_block + h_id * stride_h; + int offset_head_dst = offset_block_dst + h_id * o_stride_h; +#pragma unroll + for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) { + dst[offset_head_dst + d_id * o_stride_d] = + src[offset_head + d_id * stride_d]; + } + } + } +} + +template +__device__ void fused_rope_cached_block_backward( + const scalar_t_0* src, const scalar_t_1* cos, const scalar_t_1* sin, + scalar_t_0* dst, const int s_id, const int offset_block, + const int offset_block_dst, const int h, const int d, const int d2, + const int stride_h, const int stride_d, const int o_stride_h, + const int o_stride_d) { +#pragma unroll + for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { + scalar_t_0 v_cos = cos[s_id * d2 + d_id]; + scalar_t_0 v_sin = (d_id + d2 / 2 < d2) + ? sin[s_id * d2 + d_id + d2 / 2] + : -sin[s_id * d2 + d_id + d2 / 2 - d2]; +#pragma unroll + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { + int offset_src = offset_block + h_id * stride_h + d_id * stride_d; + int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d; + scalar_t_0 v_src = src[offset_src]; + scalar_t_0 v_src_rotate = + (d_id + d2 / 2 < d2) ? src[offset_src + (d2 / 2) * stride_d] + : src[offset_src + (d2 / 2 - d2) * stride_d]; + dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin; + } + } + + // handle the tail + if (d > d2) { +#pragma unroll + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { + int offset_head = offset_block + h_id * stride_h; + int offset_head_dst = offset_block_dst + h_id * o_stride_h; +#pragma unroll + for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) { + dst[offset_head_dst + d_id * o_stride_d] = + src[offset_head + d_id * stride_d]; + } + } + } +} + +template +__global__ void fused_rope_cached_forward( + const int h, const int d, const int d2, const int stride_s, + const int stride_b, const int stride_h, const int stride_d, + const int o_stride_s, const int o_stride_b, const int o_stride_h, + const int o_stride_d, const scalar_t_0* src, const scalar_t_1* cos, + const scalar_t_1* sin, scalar_t_0* dst) { + int s_id = blockIdx.x, b_id = blockIdx.y; + int offset_block = s_id * stride_s + b_id * stride_b; + int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b; + fused_rope_cached_block_forward(src, cos, sin, dst, s_id, offset_block, + offset_block_dst, h, d, d2, stride_h, + stride_d, o_stride_h, o_stride_d); +} + +template +__global__ void fused_rope_cached_backward( + const int h, const int d, const int d2, const int stride_s, + const int stride_b, const int stride_h, const int stride_d, + const int o_stride_s, const int o_stride_b, const int o_stride_h, + const int o_stride_d, const scalar_t_0* src, const scalar_t_1* cos, + const scalar_t_1* sin, scalar_t_0* dst) { + int s_id = blockIdx.x, b_id = blockIdx.y; + int offset_block = s_id * stride_s + b_id * stride_b; + int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b; + fused_rope_cached_block_backward(src, cos, sin, dst, s_id, offset_block, + offset_block_dst, h, d, d2, stride_h, + stride_d, o_stride_h, o_stride_d); +} + +template +__global__ void fused_rope_thd_forward( + const int h, const int d, const int d2, const int stride_t, + const int stride_h, const int stride_d, const int o_stride_t, + const int o_stride_h, const int o_stride_d, const scalar_t* src, + const int* cu_seqlens, const float* freqs, scalar_t* dst) { + int s_id = blockIdx.x, b_id = blockIdx.y; + int t_id = s_id + cu_seqlens[b_id]; + if (t_id >= cu_seqlens[b_id + 1]) return; + int offset_block = t_id * stride_t; + int offset_block_dst = t_id * o_stride_t; + fused_rope_block_forward(src, freqs, dst, offset_block, offset_block_dst, h, + d, d2, stride_h, stride_d, o_stride_h, o_stride_d); +} + +template +__global__ void fused_rope_thd_backward( + const int h, const int d, const int d2, const int stride_t, + const int stride_h, const int stride_d, const int o_stride_t, + const int o_stride_h, const int o_stride_d, const scalar_t* src, + const int* cu_seqlens, const float* freqs, scalar_t* dst) { + int s_id = blockIdx.x, b_id = blockIdx.y; + int t_id = s_id + cu_seqlens[b_id]; + if (t_id >= cu_seqlens[b_id + 1]) return; + int offset_block = t_id * stride_t; + int offset_block_dst = t_id * o_stride_t; + fused_rope_block_backward(src, freqs, dst, offset_block, offset_block_dst, h, + d, d2, stride_h, stride_d, o_stride_h, o_stride_d); +} + +template +__global__ void fused_rope_2d_forward( + const int ih, const int iw, const int h, const int d, const int stride_b, + const int stride_ih, const int stride_iw, const int stride_h, + const int stride_d, const int o_stride_b, const int o_stride_s, + const int o_stride_h, const int o_stride_d, const scalar_t_0* src, + const scalar_t_1* cos_h, const scalar_t_1* sin_h, const scalar_t_1* cos_w, + const scalar_t_1* sin_w, scalar_t_0* dst) { + int ih_id = blockIdx.x, iw_id = blockIdx.y, b_id = blockIdx.z; + // apply to height + int offset_block = b_id * stride_b + ih_id * stride_ih + iw_id * stride_iw; + int offset_block_dst = b_id * o_stride_b + (ih_id * iw + iw_id) * o_stride_s; + int s_id = ih_id; // for cos_h and sin_h + fused_rope_cached_block_forward(src, cos_h, sin_h, dst, s_id, offset_block, + offset_block_dst, h, d / 2, d / 2, stride_h, + stride_d, o_stride_h, o_stride_d); + // apply to width + offset_block += d / 2 * stride_d; + offset_block_dst += d / 2 * o_stride_d; + s_id = iw_id; // for cos_w and sin_w + fused_rope_cached_block_forward(src, cos_w, sin_w, dst, s_id, offset_block, + offset_block_dst, h, d / 2, d / 2, stride_h, + stride_d, o_stride_h, o_stride_d); +} + +template +__global__ void fused_rope_2d_backward( + const int ih, const int iw, const int h, const int d, const int stride_b, + const int stride_ih, const int stride_iw, const int stride_h, + const int stride_d, const int o_stride_b, const int o_stride_s, + const int o_stride_h, const int o_stride_d, const scalar_t_0* src, + const scalar_t_1* cos_h, const scalar_t_1* sin_h, const scalar_t_1* cos_w, + const scalar_t_1* sin_w, scalar_t_0* dst) { + int ih_id = blockIdx.x, iw_id = blockIdx.y, b_id = blockIdx.z; + // apply to height + int offset_block = b_id * stride_b + ih_id * stride_ih + iw_id * stride_iw; + int offset_block_dst = b_id * o_stride_b + (ih_id * iw + iw_id) * o_stride_s; + int s_id = ih_id; // for cos_h and sin_h + fused_rope_cached_block_backward(src, cos_h, sin_h, dst, s_id, offset_block, + offset_block_dst, h, d / 2, d / 2, stride_h, + stride_d, o_stride_h, o_stride_d); + // apply to width + offset_block += d / 2 * stride_d; + offset_block_dst += d / 2 * o_stride_d; + s_id = iw_id; // for cos_w and sin_w + fused_rope_cached_block_backward(src, cos_w, sin_w, dst, s_id, offset_block, + offset_block_dst, h, d / 2, d / 2, stride_h, + stride_d, o_stride_h, o_stride_d); +} + +} // end of anonymous namespace + +template +void dispatch_fused_rope_forward(const int s, const int b, const int h, + const int d, const int d2, const int stride_s, + const int stride_b, const int stride_h, + const int stride_d, const int o_stride_s, + const int o_stride_b, const int o_stride_h, + const int o_stride_d, const scalar_t* input, + const float* freqs, scalar_t* output) { + auto stream = at::cuda::getCurrentCUDAStream(); + + int warps_per_block = h < 16 ? 4 : 8; + dim3 blocks(s, b); + dim3 threads(at::cuda::warp_size(), warps_per_block); + + fused_rope_forward<<>>( + h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, + o_stride_h, o_stride_d, input, freqs, output); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +void dispatch_fused_rope_backward(const int s, const int b, const int h, + const int d, const int d2, const int stride_s, + const int stride_b, const int stride_h, + const int stride_d, const int o_stride_s, + const int o_stride_b, const int o_stride_h, + const int o_stride_d, + const scalar_t* output_grads, + const float* freqs, scalar_t* input_grads) { + auto stream = at::cuda::getCurrentCUDAStream(); + + int warps_per_block = h < 16 ? 4 : 8; + dim3 blocks(s, b); + dim3 threads(at::cuda::warp_size(), warps_per_block); + + fused_rope_backward<<>>( + h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, + o_stride_h, o_stride_d, output_grads, freqs, input_grads); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +void dispatch_fused_rope_cached_forward( + const int s, const int b, const int h, const int d, const int d2, + const int stride_s, const int stride_b, const int stride_h, + const int stride_d, const int o_stride_s, const int o_stride_b, + const int o_stride_h, const int o_stride_d, const scalar_t_0* input, + const scalar_t_1* cos, const scalar_t_1* sin, scalar_t_0* output) { + auto stream = at::cuda::getCurrentCUDAStream(); + + int warps_per_block = h < 16 ? 4 : 8; + dim3 blocks(s, b); + dim3 threads(at::cuda::warp_size(), warps_per_block); + + fused_rope_cached_forward<<>>( + h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, + o_stride_h, o_stride_d, input, cos, sin, output); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +void dispatch_fused_rope_cached_backward( + const int s, const int b, const int h, const int d, const int d2, + const int stride_s, const int stride_b, const int stride_h, + const int stride_d, const int o_stride_s, const int o_stride_b, + const int o_stride_h, const int o_stride_d, const scalar_t_0* output_grads, + const scalar_t_1* cos, const scalar_t_1* sin, scalar_t_0* input_grads) { + auto stream = at::cuda::getCurrentCUDAStream(); + + int warps_per_block = h < 16 ? 4 : 8; + dim3 blocks(s, b); + dim3 threads(at::cuda::warp_size(), warps_per_block); + + fused_rope_cached_backward<<>>( + h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, + o_stride_h, o_stride_d, output_grads, cos, sin, input_grads); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +void dispatch_fused_rope_thd_forward(const int max_s, const int b, const int h, + const int d, const int d2, + const int stride_t, const int stride_h, + const int stride_d, const int o_stride_t, + const int o_stride_h, const int o_stride_d, + const scalar_t* input, + const int* cu_seqlens, const float* freqs, + scalar_t* output) { + auto stream = at::cuda::getCurrentCUDAStream(); + + int warps_per_block = h < 16 ? 4 : 8; + dim3 blocks(max_s, b); + dim3 threads(at::cuda::warp_size(), warps_per_block); + + fused_rope_thd_forward<<>>( + h, d, d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h, + o_stride_d, input, cu_seqlens, freqs, output); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +void dispatch_fused_rope_thd_backward( + const int max_s, const int b, const int h, const int d, const int d2, + const int stride_t, const int stride_h, const int stride_d, + const int o_stride_t, const int o_stride_h, const int o_stride_d, + const scalar_t* output_grads, const int* cu_seqlens, const float* freqs, + scalar_t* input_grads) { + auto stream = at::cuda::getCurrentCUDAStream(); + + int warps_per_block = h < 16 ? 4 : 8; + dim3 blocks(max_s, b); + dim3 threads(at::cuda::warp_size(), warps_per_block); + + fused_rope_thd_backward<<>>( + h, d, d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h, + o_stride_d, output_grads, cu_seqlens, freqs, input_grads); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +void dispatch_fused_rope_2d_forward( + const int b, const int ih, const int iw, const int h, const int d, + const int stride_b, const int stride_ih, const int stride_iw, + const int stride_h, const int stride_d, const int o_stride_b, + const int o_stride_s, const int o_stride_h, const int o_stride_d, + const scalar_t_0* input, const scalar_t_1* cos_h, const scalar_t_1* sin_h, + const scalar_t_1* cos_w, const scalar_t_1* sin_w, scalar_t_0* output) { + auto stream = at::cuda::getCurrentCUDAStream(); + + int warps_per_block = h < 16 ? 4 : 8; + dim3 blocks(ih, iw, b); + dim3 threads(at::cuda::warp_size(), warps_per_block); + + fused_rope_2d_forward<<>>( + ih, iw, h, d, stride_b, stride_ih, stride_iw, stride_h, stride_d, + o_stride_b, o_stride_s, o_stride_h, o_stride_d, input, cos_h, sin_h, + cos_w, sin_w, output); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +void dispatch_fused_rope_2d_backward( + const int b, const int ih, const int iw, const int h, const int d, + const int stride_b, const int stride_ih, const int stride_iw, + const int stride_h, const int stride_d, const int o_stride_b, + const int o_stride_s, const int o_stride_h, const int o_stride_d, + const scalar_t_0* output_grads, const scalar_t_1* cos_h, + const scalar_t_1* sin_h, const scalar_t_1* cos_w, const scalar_t_1* sin_w, + scalar_t_0* input_grads) { + auto stream = at::cuda::getCurrentCUDAStream(); + + int warps_per_block = h < 16 ? 4 : 8; + dim3 blocks(ih, iw, b); + dim3 threads(at::cuda::warp_size(), warps_per_block); + + fused_rope_2d_backward<<>>( + ih, iw, h, d, stride_b, stride_ih, stride_iw, stride_h, stride_d, + o_stride_b, o_stride_s, o_stride_h, o_stride_d, output_grads, cos_h, + sin_h, cos_w, sin_w, input_grads); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} diff --git a/csrc/megatron/fused_rotary_positional_embedding_cuda.cu b/csrc/megatron/fused_rotary_positional_embedding_cuda.cu new file mode 100644 index 000000000..8d1547ffe --- /dev/null +++ b/csrc/megatron/fused_rotary_positional_embedding_cuda.cu @@ -0,0 +1,421 @@ +/* coding=utf-8 + * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "fused_rotary_positional_embedding.h" +#include "type_shim.h" + +namespace fused_rope { + +torch::Tensor fwd_cuda(const torch::Tensor &input, const torch::Tensor &freqs, + const bool transpose_output) { + // input sizes: (s, b, h, d) + // s: sequence length + // b: batch size + // h: head num + // d: dim of each head + const int s = input.size(0); + const int b = input.size(1); + const int h = input.size(2); + const int d = input.size(3); + // input strides + const int stride_s = input.stride(0); + const int stride_b = input.stride(1); + const int stride_h = input.stride(2); + const int stride_d = input.stride(3); + // freqs' shape is always (s, 1, 1, d2), so the strides are same under + // different memory formats + const int d2 = freqs.size(3); + + // output + auto act_options = input.options().requires_grad(false); + torch::Tensor output; + if (transpose_output) { + output = torch::empty({b, s, h, d}, act_options).transpose(0, 1); + } else { + output = torch::empty({s, b, h, d}, act_options); + } + // output strides + const int o_stride_s = output.stride(0); + const int o_stride_b = output.stride(1); + const int o_stride_h = output.stride(2); + const int o_stride_d = output.stride(3); + + DISPATCH_FLOAT_HALF_AND_BFLOAT( + input.scalar_type(), 0, "dispatch_fused_rope_forward", + dispatch_fused_rope_forward( + s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, + o_stride_b, o_stride_h, o_stride_d, input.data_ptr(), + freqs.data_ptr(), output.data_ptr());); + return output; +} + +torch::Tensor bwd_cuda(const torch::Tensor &output_grads, + const torch::Tensor &freqs, + const bool transpose_output) { + // output_grads sizes: (s, b, h, d) + // s: sequence length + // b: batch size + // h: head num + // d: dim of each head + const int s = output_grads.size(0); + const int b = output_grads.size(1); + const int h = output_grads.size(2); + const int d = output_grads.size(3); + // output_grads strides + const int stride_s = output_grads.stride(0); + const int stride_b = output_grads.stride(1); + const int stride_h = output_grads.stride(2); + const int stride_d = output_grads.stride(3); + // freqs' shape is always (s, 1, 1, d2), so the strides are same under + // different memory formats + const int d2 = freqs.size(3); + + auto act_options = output_grads.options().requires_grad(false); + torch::Tensor input_grads; + if (transpose_output) { + input_grads = torch::empty({b, s, h, d}, act_options).transpose(0, 1); + } else { + input_grads = torch::empty({s, b, h, d}, act_options); + } + const int o_stride_s = input_grads.stride(0); + const int o_stride_b = input_grads.stride(1); + const int o_stride_h = input_grads.stride(2); + const int o_stride_d = input_grads.stride(3); + + DISPATCH_FLOAT_HALF_AND_BFLOAT( + output_grads.scalar_type(), 0, "dispatch_fused_rope_backward", + dispatch_fused_rope_backward( + s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, + o_stride_b, o_stride_h, o_stride_d, + output_grads.data_ptr(), freqs.data_ptr(), + input_grads.data_ptr());); + return input_grads; +} + +#define DISPATCH_FUSED_ROPE_TYPES(TYPE1, TYPE2, NAME, ...) \ + switch (TYPE1) { \ + case at::ScalarType::Float: { \ + using scalar_t_0 = float; \ + switch (TYPE2) { \ + case at::ScalarType::Float: { \ + using scalar_t_1 = float; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + TORCH_CHECK(false, #NAME, " not supported for '", toString(TYPE1), \ + "' with '", toString(TYPE2), "'"); \ + } \ + break; \ + } \ + case at::ScalarType::Half: { \ + using scalar_t_0 = at::Half; \ + switch (TYPE2) { \ + case at::ScalarType::Float: { \ + using scalar_t_1 = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: { \ + using scalar_t_1 = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + TORCH_CHECK(false, #NAME, " not supported for '", toString(TYPE1), \ + "' with '", toString(TYPE2), "'"); \ + } \ + break; \ + } \ + case at::ScalarType::BFloat16: { \ + using scalar_t_0 = at::BFloat16; \ + switch (TYPE2) { \ + case at::ScalarType::Float: { \ + using scalar_t_1 = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: { \ + using scalar_t_1 = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + TORCH_CHECK(false, #NAME, " not supported for '", toString(TYPE1), \ + "' with '", toString(TYPE2), "'"); \ + } \ + break; \ + } \ + default: \ + TORCH_CHECK(false, #NAME, " not supported for '", toString(TYPE1), \ + "' with '", toString(TYPE2), "'"); \ + } + +torch::Tensor fwd_cached_cuda(const torch::Tensor &input, + const torch::Tensor &cos, + const torch::Tensor &sin, + const bool transpose_output) { + // input sizes: (s, b, h, d) + // s: sequence length + // b: batch size + // h: head num + // d: dim of each head + const int s = input.size(0); + const int b = input.size(1); + const int h = input.size(2); + const int d = input.size(3); + // input strides + const int stride_s = input.stride(0); + const int stride_b = input.stride(1); + const int stride_h = input.stride(2); + const int stride_d = input.stride(3); + // cos/sin's shape is always (s, 1, 1, d2), so the strides are same under + // different memory formats + const int d2 = cos.size(3); + + // output + auto act_options = input.options().requires_grad(false); + torch::Tensor output; + if (transpose_output) { + output = torch::empty({b, s, h, d}, act_options).transpose(0, 1); + } else { + output = torch::empty({s, b, h, d}, act_options); + } + // output strides + const int o_stride_s = output.stride(0); + const int o_stride_b = output.stride(1); + const int o_stride_h = output.stride(2); + const int o_stride_d = output.stride(3); + + DISPATCH_FUSED_ROPE_TYPES( + input.scalar_type(), cos.scalar_type(), + "dispatch_fused_rope_cached_forward", + dispatch_fused_rope_cached_forward( + s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, + o_stride_b, o_stride_h, o_stride_d, input.data_ptr(), + cos.data_ptr(), sin.data_ptr(), + output.data_ptr());); + return output; +} + +torch::Tensor bwd_cached_cuda(const torch::Tensor &output_grads, + const torch::Tensor &cos, + const torch::Tensor &sin, + const bool transpose_output) { + // output_grads sizes: (s, b, h, d) + // s: sequence length + // b: batch size + // h: head num + // d: dim of each head + const int s = output_grads.size(0); + const int b = output_grads.size(1); + const int h = output_grads.size(2); + const int d = output_grads.size(3); + // output_grads strides + const int stride_s = output_grads.stride(0); + const int stride_b = output_grads.stride(1); + const int stride_h = output_grads.stride(2); + const int stride_d = output_grads.stride(3); + // cos/sin's shape is always (s, 1, 1, d2), so the strides are same under + // different memory formats + const int d2 = cos.size(3); + + auto act_options = output_grads.options().requires_grad(false); + torch::Tensor input_grads; + if (transpose_output) { + input_grads = torch::empty({b, s, h, d}, act_options).transpose(0, 1); + } else { + input_grads = torch::empty({s, b, h, d}, act_options); + } + const int o_stride_s = input_grads.stride(0); + const int o_stride_b = input_grads.stride(1); + const int o_stride_h = input_grads.stride(2); + const int o_stride_d = input_grads.stride(3); + + DISPATCH_FUSED_ROPE_TYPES( + output_grads.scalar_type(), cos.scalar_type(), + "dispatch_fused_rope_cached_backward", + dispatch_fused_rope_cached_backward( + s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, + o_stride_b, o_stride_h, o_stride_d, + output_grads.data_ptr(), cos.data_ptr(), + sin.data_ptr(), input_grads.data_ptr());); + return input_grads; +} + +torch::Tensor fwd_thd_cuda(const torch::Tensor &input, + const torch::Tensor &cu_seqlens, + const torch::Tensor &freqs) { + // input sizes: (t, h, d) + // t: cumulative sum of sequence lengths + // h: head num + // d: dim of each head + const int t = input.size(0); + const int h = input.size(1); + const int d = input.size(2); + // input strides + const int stride_t = input.stride(0); + const int stride_h = input.stride(1); + const int stride_d = input.stride(2); + // batch size + const int b = cu_seqlens.size(0) - 1; + // freqs' shape is (max_s, 1, 1, d2) + const int max_s = freqs.size(0); + const int d2 = freqs.size(3); + + // output + auto act_options = input.options().requires_grad(false); + auto output = torch::empty({t, h, d}, act_options); + // output strides + const int o_stride_t = output.stride(0); + const int o_stride_h = output.stride(1); + const int o_stride_d = output.stride(2); + + DISPATCH_FLOAT_HALF_AND_BFLOAT( + input.scalar_type(), 0, "dispatch_fused_rope_thd_forward", + dispatch_fused_rope_thd_forward( + max_s, b, h, d, d2, stride_t, stride_h, stride_d, o_stride_t, + o_stride_h, o_stride_d, input.data_ptr(), + cu_seqlens.data_ptr(), freqs.data_ptr(), + output.data_ptr());); + return output; +} + +torch::Tensor bwd_thd_cuda(const torch::Tensor &output_grads, + const torch::Tensor &cu_seqlens, + const torch::Tensor &freqs) { + // output_grads sizes: (t, h, d) + // t: cumulative sum of sequence lengths + // h: head num + // d: dim of each head + const int t = output_grads.size(0); + const int h = output_grads.size(1); + const int d = output_grads.size(2); + // output_grads strides + const int stride_t = output_grads.stride(0); + const int stride_h = output_grads.stride(1); + const int stride_d = output_grads.stride(2); + // batch size + const int b = cu_seqlens.size(0) - 1; + // freqs' shape is (max_s, 1, 1, d2) + const int max_s = freqs.size(0); + const int d2 = freqs.size(3); + + auto act_options = output_grads.options().requires_grad(false); + auto input_grads = torch::empty({t, h, d}, act_options); + const int o_stride_t = input_grads.stride(0); + const int o_stride_h = input_grads.stride(1); + const int o_stride_d = input_grads.stride(2); + + DISPATCH_FLOAT_HALF_AND_BFLOAT( + output_grads.scalar_type(), 0, "dispatch_fused_rope_thd_backward", + dispatch_fused_rope_thd_backward( + max_s, b, h, d, d2, stride_t, stride_h, stride_d, o_stride_t, + o_stride_h, o_stride_d, output_grads.data_ptr(), + cu_seqlens.data_ptr(), freqs.data_ptr(), + input_grads.data_ptr());); + return input_grads; +} + +torch::Tensor fwd_2d_cuda(const torch::Tensor &input, + const torch::Tensor &cos_h, + const torch::Tensor &sin_h, + const torch::Tensor &cos_w, + const torch::Tensor &sin_w) { + // input sizes: (b, ih, iw, h, d) + // b: batch size + // ih: image height + // iw: image width + // h: head num + // d: dim of each head + const int b = input.size(0); + const int ih = input.size(1); + const int iw = input.size(2); + const int h = input.size(3); + const int d = input.size(4); + // input strides + const int stride_b = input.stride(0); + const int stride_ih = input.stride(1); + const int stride_iw = input.stride(2); + const int stride_h = input.stride(3); + const int stride_d = input.stride(4); + + // output + auto act_options = input.options().requires_grad(false); + auto output = torch::empty({b, ih * iw, h, d}, act_options); + // output strides + const int o_stride_b = output.stride(0); + const int o_stride_s = output.stride(1); + const int o_stride_h = output.stride(2); + const int o_stride_d = output.stride(3); + + DISPATCH_FUSED_ROPE_TYPES( + input.scalar_type(), cos_h.scalar_type(), + "dispatch_fused_rope_2d_forward", + dispatch_fused_rope_2d_forward( + b, ih, iw, h, d, stride_b, stride_ih, stride_iw, stride_h, stride_d, + o_stride_b, o_stride_s, o_stride_h, o_stride_d, + input.data_ptr(), cos_h.data_ptr(), + sin_h.data_ptr(), cos_w.data_ptr(), + sin_w.data_ptr(), output.data_ptr());); + return output; +} + +torch::Tensor bwd_2d_cuda(const torch::Tensor &output_grads, + const torch::Tensor &cos_h, + const torch::Tensor &sin_h, + const torch::Tensor &cos_w, + const torch::Tensor &sin_w) { + // output_grads sizes: (b, ih, iw, h, d) + // b: batch size + // ih: image height + // iw: image width + // h: head num + // d: dim of each head + const int b = output_grads.size(0); + const int ih = output_grads.size(1); + const int iw = output_grads.size(2); + const int h = output_grads.size(3); + const int d = output_grads.size(4); + // output_grads strides + const int stride_b = output_grads.stride(0); + const int stride_ih = output_grads.stride(1); + const int stride_iw = output_grads.stride(2); + const int stride_h = output_grads.stride(3); + const int stride_d = output_grads.stride(4); + + auto act_options = output_grads.options().requires_grad(false); + auto input_grads = torch::empty({b, ih * iw, h, d}, act_options); + const int o_stride_b = input_grads.stride(0); + const int o_stride_s = input_grads.stride(1); + const int o_stride_h = input_grads.stride(2); + const int o_stride_d = input_grads.stride(3); + + DISPATCH_FUSED_ROPE_TYPES( + output_grads.scalar_type(), cos_h.scalar_type(), + "dispatch_fused_rope_2d_backward", + dispatch_fused_rope_2d_backward( + b, ih, iw, h, d, stride_b, stride_ih, stride_iw, stride_h, stride_d, + o_stride_b, o_stride_s, o_stride_h, o_stride_d, + output_grads.data_ptr(), cos_h.data_ptr(), + sin_h.data_ptr(), cos_w.data_ptr(), + sin_w.data_ptr(), input_grads.data_ptr());); + return input_grads; +} + +} // end namespace fused_rope diff --git a/csrc/megatron/fused_weight_gradient_dense.cpp b/csrc/megatron/fused_weight_gradient_dense.cpp index a14c2b216..8be329081 100644 --- a/csrc/megatron/fused_weight_gradient_dense.cpp +++ b/csrc/megatron/fused_weight_gradient_dense.cpp @@ -16,6 +16,6 @@ void wgrad_gemm_accum_fp16_cuda_stub( ); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("wgrad_gemm_accum_fp32", &wgrad_gemm_accum_fp32_cuda_stub, "wgrad gemm accum in fp32"); - m.def("wgrad_gemm_accum_fp16", &wgrad_gemm_accum_fp16_cuda_stub, "wgrad gemm accum in fp16"); + m.def("wgrad_gemm_accum_fp32", &wgrad_gemm_accum_fp32_cuda_stub, "wgrad gemm accum in fp32", py::call_guard()); + m.def("wgrad_gemm_accum_fp16", &wgrad_gemm_accum_fp16_cuda_stub, "wgrad gemm accum in fp16", py::call_guard()); } diff --git a/csrc/megatron/fused_weight_gradient_dense_16bit_prec_cuda.cu b/csrc/megatron/fused_weight_gradient_dense_16bit_prec_cuda.cu index 60d1e8d1f..24e5f0294 100644 --- a/csrc/megatron/fused_weight_gradient_dense_16bit_prec_cuda.cu +++ b/csrc/megatron/fused_weight_gradient_dense_16bit_prec_cuda.cu @@ -8,11 +8,34 @@ #include /* Includes, cuda */ -#include #include - #include "type_shim.h" +/* Includes, blaslt */ +#include + +#ifndef CHECK_CUDA_ERROR +#define CHECK_CUDA_ERROR(error) \ + if(error != cudaSuccess) \ + { \ + fprintf(stderr, \ + "Cuda error: '%s'(%d) at %s:%d\n", \ + cudaGetErrorString(error), \ + error, \ + __FILE__, \ + __LINE__); \ + exit(EXIT_FAILURE); \ + } +#endif +#ifndef CHECK_CUBLASLT_ERROR +#define CHECK_CUBLASLT_ERROR(error) \ + if(error != CUBLAS_STATUS_SUCCESS) \ + { \ + fprintf(stderr, "cuBLASLt error(Err=%d) at %s:%d\n", error, __FILE__, __LINE__); \ + fprintf(stderr, "\n"); \ + exit(EXIT_FAILURE); \ + } +#endif // BF16 inputs and BF16 accumulation void gemmex_wrapper_fp16( @@ -22,101 +45,221 @@ void gemmex_wrapper_fp16( int m, int n, int k, - const float* alpha, + int batch_count, + float& alpha, + float& beta, at::BFloat16* A, - int lda, at::BFloat16* B, - int ldb, - const float* beta, at::BFloat16* C, - int ldc) { - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, - transa, - transb, - m, - n, - k, - alpha, - A, - CUDA_R_16BF, - lda, - B, - CUDA_R_16BF, - ldb, - beta, - C, - CUDA_R_16BF, - ldc, - CUDA_R_32F, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + at::BFloat16* D, + void* d_workspace, + int64_t max_workspace_size, + cudaStream_t stream) +{ + cublasLtMatrixLayout_t matA, matB, matC, matD; + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matA, CUDA_R_16BF, m, k, m)); + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matB, CUDA_R_16BF, n, k, n)); + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matC, CUDA_R_16BF, m, n, m)); + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matD, CUDA_R_16BF, m, n, m)); + + cublasLtMatmulDesc_t matmul; + CHECK_CUBLASLT_ERROR(cublasLtMatmulDescCreate(&matmul, CUBLAS_COMPUTE_32F, CUDA_R_32F)); + CHECK_CUBLASLT_ERROR(cublasLtMatmulDescSetAttribute( + matmul, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa))); + CHECK_CUBLASLT_ERROR(cublasLtMatmulDescSetAttribute( + matmul, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb))); + + cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT; + CHECK_CUBLASLT_ERROR(cublasLtMatmulDescSetAttribute( + matmul, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue))); + + // Set User Preference attributes + cublasLtMatmulPreference_t pref; + CHECK_CUBLASLT_ERROR(cublasLtMatmulPreferenceCreate(&pref)); + CHECK_CUBLASLT_ERROR( + cublasLtMatmulPreferenceSetAttribute(pref, + CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &max_workspace_size, + sizeof(max_workspace_size))); + + const int request_solutions = 1; + cublasLtMatmulHeuristicResult_t heuristicResult[request_solutions]; + int returnedAlgoCount = 0; + CHECK_CUBLASLT_ERROR(cublasLtMatmulAlgoGetHeuristic(handle, + matmul, + matA, + matB, + matC, + matD, + pref, + request_solutions, + heuristicResult, + &returnedAlgoCount)); + + if(returnedAlgoCount == 0) + { + std::cerr << "No valid solution found!" << std::endl; + return; + } + + uint64_t workspace_size = 0; + for(int i = 0; i < returnedAlgoCount; i++) + workspace_size = max(workspace_size, heuristicResult[i].workspaceSize); + + CHECK_CUBLASLT_ERROR(cublasLtMatmul(handle, + matmul, + &alpha, + A, + matA, + B, + matB, + &beta, + C, + matC, + D, + matD, + &heuristicResult[0].algo, + d_workspace, + workspace_size, + stream)); + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutDestroy(matA)); + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutDestroy(matB)); + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutDestroy(matC)); + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutDestroy(matD)); + CHECK_CUBLASLT_ERROR(cublasLtMatmulDescDestroy(matmul)); + CHECK_CUBLASLT_ERROR(cublasLtMatmulPreferenceDestroy(pref)); + return; } // FP16 inputs and FP16 accumulation void gemmex_wrapper_fp16( - cublasHandle_t handle, + cublasLtHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, - const float* alpha, + int batch_count, + float& alpha, + float& beta, at::Half* A, - int lda, at::Half* B, - int ldb, - const float* beta, at::Half* C, - int ldc) { - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, - transa, - transb, - m, - n, - k, - alpha, - A, - CUDA_R_16F, - lda, - B, - CUDA_R_16F, - ldb, - beta, - C, - CUDA_R_16F, - ldc, - CUDA_R_32F, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + at::Half* D, + void* d_workspace, + int64_t max_workspace_size, + cudaStream_t stream) +{ + cublasLtMatrixLayout_t matA, matB, matC, matD; + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matA, CUDA_R_16F, m, k, m)); + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matB, CUDA_R_16F, n, k, n)); + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matC, CUDA_R_16F, m, n, m)); + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matD, CUDA_R_16F, m, n, m)); + + cublasLtMatmulDesc_t matmul; + CHECK_CUBLASLT_ERROR(cublasLtMatmulDescCreate(&matmul, CUBLAS_COMPUTE_32F, CUDA_R_32F)); + CHECK_CUBLASLT_ERROR(cublasLtMatmulDescSetAttribute( + matmul, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa))); + CHECK_CUBLASLT_ERROR(cublasLtMatmulDescSetAttribute( + matmul, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb))); + + cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT; + CHECK_CUBLASLT_ERROR(cublasLtMatmulDescSetAttribute( + matmul, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue))); + + // Set User Preference attributes + cublasLtMatmulPreference_t pref; + CHECK_CUBLASLT_ERROR(cublasLtMatmulPreferenceCreate(&pref)); + CHECK_CUBLASLT_ERROR( + cublasLtMatmulPreferenceSetAttribute(pref, + CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &max_workspace_size, + sizeof(max_workspace_size))); + + const int request_solutions = 1; + cublasLtMatmulHeuristicResult_t heuristicResult[request_solutions]; + int returnedAlgoCount = 0; + CHECK_CUBLASLT_ERROR(cublasLtMatmulAlgoGetHeuristic(handle, + matmul, + matA, + matB, + matC, + matD, + pref, + request_solutions, + heuristicResult, + &returnedAlgoCount)); + + if(returnedAlgoCount == 0) + { + std::cerr << "No valid solution found!" << std::endl; + return; + } + + uint64_t workspace_size = 0; + for(int i = 0; i < returnedAlgoCount; i++) + workspace_size = max(workspace_size, heuristicResult[i].workspaceSize); + + CHECK_CUBLASLT_ERROR(cublasLtMatmul(handle, + matmul, + &alpha, + A, + matA, + B, + matB, + &beta, + C, + matC, + D, + matD, + &heuristicResult[0].algo, + d_workspace, + workspace_size, + stream)); + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutDestroy(matA)); + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutDestroy(matB)); + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutDestroy(matC)); + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutDestroy(matD)); + CHECK_CUBLASLT_ERROR(cublasLtMatmulDescDestroy(matmul)); + CHECK_CUBLASLT_ERROR(cublasLtMatmulPreferenceDestroy(pref)); + return; } template -void wgrad_gemm_accum_fp16_cuda(T *input, T *d_output, T *d_weight, int in_dim, int hidden_dim, int out_dim) { - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - cudaStream_t stream; - cublasGetStream(handle, &stream); - const float alpha = 1.0; - const float beta = 1.0; +void wgrad_gemm_accum_fp16_cuda(T *input, T *d_output, T *d_weight,int in_dim, int hidden_dim, int out_dim) { + cublasLtHandle_t handle = at::cuda::getCurrentCUDABlasLtHandle(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + float alpha = 1.0; + float beta = 1.0; + const int batch_count = 1; + void* d_workspace = nullptr; + int64_t max_workspace_size = 32*1024*1024; + if (max_workspace_size > 0) { + at::Tensor workspace = at::empty({max_workspace_size}, at::TensorOptions().dtype(at::kByte).device(at::kCUDA)); + d_workspace = workspace.data_ptr(); + } gemmex_wrapper_fp16( handle, CUBLAS_OP_N, CUBLAS_OP_T, - in_dim, - out_dim, - hidden_dim, - &alpha, - input, - in_dim, - d_output, - out_dim, - &beta, - d_weight, - in_dim); + in_dim, //m + out_dim, //n + hidden_dim, //k + batch_count, + alpha, + beta, + input, //da + d_output, //db + d_weight, //dc + d_weight, //dd + d_workspace, + max_workspace_size, + stream); } template void wgrad_gemm_accum_fp16_cuda(at::Half *input, at::Half *d_output, at::Half *d_weight, int in_dim, int hidden_dim, int out_dim); -template void wgrad_gemm_accum_fp16_cuda(at::BFloat16 *input, at::BFloat16 *d_output, at::BFloat16 *d_weight, int in_dim, int hidden_dim, int out_dim); +template void wgrad_gemm_accum_fp16_cuda(at::BFloat16 *input, at::BFloat16 *d_output, at::BFloat16 *d_weight, int in_dim, int hidden_dim, int out_dim); void wgrad_gemm_accum_fp16_cuda_stub( at::Tensor &input, @@ -139,9 +282,9 @@ void wgrad_gemm_accum_fp16_cuda_stub( d_output_2d = d_output; } - const int hidden_dim = input_2d.size(0); - const int in_dim = input_2d.size(1); - const int out_dim = d_weight.size(0); + const int hidden_dim = input_2d.size(0); //k + const int in_dim = input_2d.size(1); //m + const int out_dim = d_weight.size(0); //n DISPATCH_HALF_AND_BFLOAT(input_2d.scalar_type(), "wgrad_gemm_accum_fp16", wgrad_gemm_accum_fp16_cuda( diff --git a/csrc/megatron/fused_weight_gradient_dense_cuda.cu b/csrc/megatron/fused_weight_gradient_dense_cuda.cu index dfaa1345d..f2f762eb5 100644 --- a/csrc/megatron/fused_weight_gradient_dense_cuda.cu +++ b/csrc/megatron/fused_weight_gradient_dense_cuda.cu @@ -8,161 +8,361 @@ #include /* Includes, cuda */ -#include -#include +#include #include "type_shim.h" +/* Includes, blaslt */ +#include + +#ifndef CHECK_CUDA_ERROR +#define CHECK_CUDA_ERROR(error) \ + if(error != cudaSuccess) \ + { \ + fprintf(stderr, \ + "Cuda error: '%s'(%d) at %s:%d\n", \ + cudaGetErrorString(error), \ + error, \ + __FILE__, \ + __LINE__); \ + exit(EXIT_FAILURE); \ + } +#endif + +#ifndef CHECK_CUBLASLT_ERROR +#define CHECK_CUBLASLT_ERROR(error) \ + if(error != CUBLAS_STATUS_SUCCESS) \ + { \ + fprintf(stderr, "cudaBLASLt error(Err=%d) at %s:%d\n", error, __FILE__, __LINE__); \ + fprintf(stderr, "\n"); \ + exit(EXIT_FAILURE); \ + } +#endif // BF16 Tensor core wrapper around cublas GEMMEx void gemmex_wrapper( - cublasHandle_t handle, + cublasLtHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, - const float* alpha, + int batch_count, + float& alpha, + float& beta, at::BFloat16* A, - int lda, - at::BFloat16* B, - int ldb, - const float* beta, - float* C, - int ldc) { - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, - transa, - transb, - m, - n, - k, - alpha, - A, - CUDA_R_16BF, - lda, - B, - CUDA_R_16BF, - ldb, - beta, - C, - CUDA_R_32F, - ldc, - CUDA_R_32F, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + at::BFloat16* B, + float* C, + float* D, + void* d_workspace, + int64_t max_workspace_size, + cudaStream_t stream) { + + cublasLtMatrixLayout_t matA, matB, matC, matD; + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matA, CUDA_R_16BF, m, k, m)); + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matB, CUDA_R_16BF, n, k, n)); + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matC, CUDA_R_32F, m, n, m)); + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matD, CUDA_R_32F, m, n, m)); + + cublasLtMatmulDesc_t matmul; + CHECK_CUBLASLT_ERROR(cublasLtMatmulDescCreate(&matmul, CUBLAS_COMPUTE_32F, CUDA_R_32F)); + CHECK_CUBLASLT_ERROR(cublasLtMatmulDescSetAttribute( + matmul, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa))); + CHECK_CUBLASLT_ERROR(cublasLtMatmulDescSetAttribute( + matmul, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb))); + + cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT; + CHECK_CUBLASLT_ERROR(cublasLtMatmulDescSetAttribute( + matmul, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue))); + + // Set User Preference attributes + cublasLtMatmulPreference_t pref; + CHECK_CUBLASLT_ERROR(cublasLtMatmulPreferenceCreate(&pref)); + CHECK_CUBLASLT_ERROR( + cublasLtMatmulPreferenceSetAttribute(pref, + CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &max_workspace_size, + sizeof(max_workspace_size))); + + const int request_solutions = 1; + cublasLtMatmulHeuristicResult_t heuristicResult[request_solutions]; + int returnedAlgoCount = 0; + CHECK_CUBLASLT_ERROR(cublasLtMatmulAlgoGetHeuristic(handle, + matmul, + matA, + matB, + matC, + matD, + pref, + request_solutions, + heuristicResult, + &returnedAlgoCount)); + + if(returnedAlgoCount == 0) + { + std::cerr << "No valid solution found!" << std::endl; + return; + } + + uint64_t workspace_size = 0; + for(int i = 0; i < returnedAlgoCount; i++) + workspace_size = max(workspace_size, heuristicResult[i].workspaceSize); + + CHECK_CUBLASLT_ERROR(cublasLtMatmul(handle, + matmul, + &alpha, + A, + matA, + B, + matB, + &beta, + C, + matC, + D, + matD, + &heuristicResult[0].algo, + d_workspace, + workspace_size, + stream)); + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutDestroy(matA)); + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutDestroy(matB)); + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutDestroy(matC)); + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutDestroy(matD)); + CHECK_CUBLASLT_ERROR(cublasLtMatmulDescDestroy(matmul)); + CHECK_CUBLASLT_ERROR(cublasLtMatmulPreferenceDestroy(pref)); + return; } // FP16 Tensor core wrapper around cublas GEMMEx void gemmex_wrapper( - cublasHandle_t handle, + cublasLtHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, - const float* alpha, + int batch_count, + float& alpha, + float& beta, at::Half* A, - int lda, at::Half* B, - int ldb, - const float* beta, - float* C, - int ldc) { - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, - transa, - transb, - m, - n, - k, - alpha, - A, - CUDA_R_16F, - lda, - B, - CUDA_R_16F, - ldb, - beta, - C, - CUDA_R_32F, - ldc, - CUDA_R_32F, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + float* C, + float* D, + void* d_workspace, + int64_t max_workspace_size, + cudaStream_t stream) { + cublasLtMatrixLayout_t matA, matB, matC, matD; + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matA, CUDA_R_16F, m, k, m)); + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matB, CUDA_R_16F, n, k, n)); + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matC, CUDA_R_32F, m, n, m)); + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matD, CUDA_R_32F, m, n, m)); + + cublasLtMatmulDesc_t matmul; + CHECK_CUBLASLT_ERROR(cublasLtMatmulDescCreate(&matmul, CUBLAS_COMPUTE_32F, CUDA_R_32F)); + CHECK_CUBLASLT_ERROR(cublasLtMatmulDescSetAttribute( + matmul, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa))); + CHECK_CUBLASLT_ERROR(cublasLtMatmulDescSetAttribute( + matmul, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb))); + + cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT; + CHECK_CUBLASLT_ERROR(cublasLtMatmulDescSetAttribute( + matmul, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue))); + + // Set User Preference attributes + cublasLtMatmulPreference_t pref; + CHECK_CUBLASLT_ERROR(cublasLtMatmulPreferenceCreate(&pref)); + CHECK_CUBLASLT_ERROR( + cublasLtMatmulPreferenceSetAttribute(pref, + CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &max_workspace_size, + sizeof(max_workspace_size))); + + const int request_solutions = 1; + cublasLtMatmulHeuristicResult_t heuristicResult[request_solutions]; + int returnedAlgoCount = 0; + CHECK_CUBLASLT_ERROR(cublasLtMatmulAlgoGetHeuristic(handle, + matmul, + matA, + matB, + matC, + matD, + pref, + request_solutions, + heuristicResult, + &returnedAlgoCount)); + + if(returnedAlgoCount == 0) + { + std::cerr << "No valid solution found!" << std::endl; + return; + } + + uint64_t workspace_size = 0; + for(int i = 0; i < returnedAlgoCount; i++) + workspace_size = max(workspace_size, heuristicResult[i].workspaceSize); + + CHECK_CUBLASLT_ERROR(cublasLtMatmul(handle, + matmul, + &alpha, + A, + matA, + B, + matB, + &beta, + C, + matC, + D, + matD, + &heuristicResult[0].algo, + d_workspace, + workspace_size, + stream)); + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutDestroy(matA)); + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutDestroy(matB)); + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutDestroy(matC)); + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutDestroy(matD)); + CHECK_CUBLASLT_ERROR(cublasLtMatmulDescDestroy(matmul)); + CHECK_CUBLASLT_ERROR(cublasLtMatmulPreferenceDestroy(pref)); + return; } + // FP32 wrapper around cublas GEMMEx void gemmex_wrapper( - cublasHandle_t handle, + cublasLtHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, - const float *alpha, - float *A, - int lda, - float *B, - int ldb, - const float *beta, - float *C, - int ldc) { - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, - transa, - transb, - m, - n, - k, - alpha, - A, - CUDA_R_32F, - lda, - B, - CUDA_R_32F, - ldb, - beta, - C, - CUDA_R_32F, - ldc, - CUDA_R_32F, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + int batch_count, + float& alpha, + float& beta, + float* A, + float* B, + float* C, + float* D, + void* d_workspace, + int64_t max_workspace_size, + cudaStream_t stream) { + cublasLtMatrixLayout_t matA, matB, matC, matD; + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matA, CUDA_R_32F, m, k, m)); + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matB, CUDA_R_32F, n, k, n)); + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matC, CUDA_R_32F, m, n, m)); + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matD, CUDA_R_32F, m, n, m)); + + cublasLtMatmulDesc_t matmul; + CHECK_CUBLASLT_ERROR(cublasLtMatmulDescCreate(&matmul, CUBLAS_COMPUTE_32F, CUDA_R_32F)); + CHECK_CUBLASLT_ERROR(cublasLtMatmulDescSetAttribute( + matmul, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa))); + CHECK_CUBLASLT_ERROR(cublasLtMatmulDescSetAttribute( + matmul, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb))); + + cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT; + CHECK_CUBLASLT_ERROR(cublasLtMatmulDescSetAttribute( + matmul, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue))); + + // Set User Preference attributes + cublasLtMatmulPreference_t pref; + CHECK_CUBLASLT_ERROR(cublasLtMatmulPreferenceCreate(&pref)); + CHECK_CUBLASLT_ERROR( + cublasLtMatmulPreferenceSetAttribute(pref, + CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &max_workspace_size, + sizeof(max_workspace_size))); + + const int request_solutions = 1; + cublasLtMatmulHeuristicResult_t heuristicResult[request_solutions]; + int returnedAlgoCount = 0; + CHECK_CUBLASLT_ERROR(cublasLtMatmulAlgoGetHeuristic(handle, + matmul, + matA, + matB, + matC, + matD, + pref, + request_solutions, + heuristicResult, + &returnedAlgoCount)); + + if(returnedAlgoCount == 0) + { + std::cerr << "No valid solution found!" << std::endl; + return; + } + + uint64_t workspace_size = 0; + for(int i = 0; i < returnedAlgoCount; i++) + workspace_size = max(workspace_size, heuristicResult[i].workspaceSize); + + CHECK_CUBLASLT_ERROR(cublasLtMatmul(handle, + matmul, + &alpha, + A, + matA, + B, + matB, + &beta, + C, + matC, + D, + matD, + &heuristicResult[0].algo, + d_workspace, + workspace_size, + stream)); + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutDestroy(matA)); + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutDestroy(matB)); + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutDestroy(matC)); + CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutDestroy(matD)); + CHECK_CUBLASLT_ERROR(cublasLtMatmulDescDestroy(matmul)); + CHECK_CUBLASLT_ERROR(cublasLtMatmulPreferenceDestroy(pref)); + return; } template void wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim) { - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - cudaStream_t stream; - cublasGetStream(handle, &stream); - const float alpha = 1.0; - const float beta = 1.0; + cublasLtHandle_t handle = at::cuda::getCurrentCUDABlasLtHandle(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + float alpha = 1.0; + float beta = 1.0; + const int batch_count = 1; + void* d_workspace = nullptr; + int64_t max_workspace_size = 32*1024*1024; + if(max_workspace_size > 0) { + at::Tensor workspace = at::empty({max_workspace_size}, at::TensorOptions().dtype(at::kByte).device(at::kCUDA)); + d_workspace = workspace.data_ptr(); + } gemmex_wrapper( handle, CUBLAS_OP_N, CUBLAS_OP_T, - in_dim, - out_dim, - hidden_dim, - &alpha, - input, - in_dim, - d_output, - out_dim, - &beta, - d_weight, - in_dim); + in_dim, //m + out_dim, //n + hidden_dim, //k + batch_count, + alpha, + beta, + input, //da + d_output, //db + d_weight, //dc + d_weight, //dd + d_workspace, + max_workspace_size, + stream); } template void wgrad_gemm_accum_fp32_cuda(at::Half *input, at::Half *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim); -template void wgrad_gemm_accum_fp32_cuda(at::BFloat16 *input, at::BFloat16 *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim); +template void wgrad_gemm_accum_fp32_cuda(at::BFloat16 *input, at::BFloat16 *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim); template void wgrad_gemm_accum_fp32_cuda(float *input, float *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim); void wgrad_gemm_accum_fp32_cuda_stub( at::Tensor &input, at::Tensor &d_output, - at::Tensor &d_weight -) { + at::Tensor &d_weight) +{ at::Tensor input_2d, d_output_2d; // input tensor: collapse to the first dim auto in_sizes = input.sizes(); @@ -179,9 +379,9 @@ void wgrad_gemm_accum_fp32_cuda_stub( d_output_2d = d_output; } - const int hidden_dim = input_2d.size(0); - const int in_dim = input_2d.size(1); - const int out_dim = d_weight.size(0); + const int hidden_dim = input_2d.size(0); //k + const int in_dim = input_2d.size(1); //m + const int out_dim = d_weight.size(0); //n DISPATCH_FLOAT_HALF_AND_BFLOAT(input_2d.scalar_type(), 0, "wgrad_gemm_accum_fp32", wgrad_gemm_accum_fp32_cuda( diff --git a/csrc/megatron/generic_scaled_masked_softmax.h b/csrc/megatron/generic_scaled_masked_softmax.h new file mode 100644 index 000000000..79fbc561d --- /dev/null +++ b/csrc/megatron/generic_scaled_masked_softmax.h @@ -0,0 +1,385 @@ +/* coding=utf-8 + * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace { + +template +struct Add { + __device__ __forceinline__ T operator()(T a, T b) const { + return a + b; + } +}; + +template +struct Max { + __device__ __forceinline__ T operator()(T a, T b) const { + return a < b ? b : a; + } +}; + +template +__device__ __forceinline__ T WARP_SHFL_DOWN_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) +{ +#if CUDA_VERSION >= 9000 + return __shfl_down_sync(mask, value, laneMask, width); +#else + return __shfl_down(value, laneMask, width); +#endif +} + +template class ReduceOp> +__device__ __forceinline__ acc_t warp_reduce_new(acc_t val) { + ReduceOp r; + #pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) + { + val = r(val, WARP_SHFL_DOWN_NATIVE(val, offset, WARP_SIZE)); + } + return val; +} + + +template +__global__ void scaled_masked_softmax_warp_backward_new( + output_t *gradInput, //[batches, attn_heads, q_len, k_len] + input_t *grad, + const input_t *output, //[batches, attn_heads, q_len, k_len] + acc_t scale, + int element_count) +{ + int threads_per_block = blockDim.x; + //the first element_count*2 elements are used for cache, the last 128 is used for reduction + extern __shared__ acc_t shared_data[]; + input_t *local_data = (input_t *)shared_data; + input_t *output_data = &local_data[element_count]; + // maximum shared cached 128, enough for 4096 elements reduction into 4096/32= 128 elements + acc_t *shared = (acc_t *)(&(local_data[element_count*2])); + + int num_reductions = (element_count - 1) / threads_per_block + 1; + + int offset = blockIdx.x * element_count; + + int local_idx = threadIdx.x; + int lane = threadIdx.x % C10_WARP_SIZE; + int wid = threadIdx.x / C10_WARP_SIZE; + int warps_per_thread_block = threads_per_block / C10_WARP_SIZE; + + // load the data to local data + acc_t val = 0.0; + for (int i = 0; i < num_reductions; i++){ + if (i*threads_per_block + local_idx < element_count){ + val = output[offset + i*threads_per_block + local_idx]; + output_data[i*threads_per_block + local_idx] = val; + local_data[i*threads_per_block + local_idx] = val * grad[offset + i*threads_per_block + local_idx]; + } + __syncthreads(); + } + + // find the sum + for (int i = local_idx; i < (element_count - 1) / C10_WARP_SIZE + 1; i += threads_per_block){ + shared[i] = 0.0; + } + __syncthreads(); + + #pragma unroll + for (int i = 0; i < num_reductions; i++){ + if (i*threads_per_block + local_idx < element_count){ + val = local_data[i*threads_per_block + local_idx]; + } + else{ + val = 0.0; + } + __syncthreads(); + val = warp_reduce_new(val); + if (lane==0 && wid + warps_per_thread_block * i < (element_count - 1) / C10_WARP_SIZE + 1) { + shared[wid + warps_per_thread_block*i] = val; + } + __syncthreads(); + } + + // final shared reduction + + int shared_mem_len = (element_count - 1) / C10_WARP_SIZE + 1; + int num_warps = (shared_mem_len - 1) / C10_WARP_SIZE + 1; + while ( shared_mem_len > 1 ){ + #pragma unroll + for (int i = 0; i < num_reductions; i++){ + if (i*threads_per_block + local_idx < shared_mem_len){ + val = shared[i*threads_per_block + local_idx]; + } + else{ + val = 0.0; + } + __syncthreads(); + val = warp_reduce_new(val); + if (lane==0) { + shared[wid + warps_per_thread_block * i] = val; + } + __syncthreads(); + } + shared_mem_len = num_warps; + num_warps = (shared_mem_len - 1) / C10_WARP_SIZE + 1; + } + val = shared[0]; + #pragma unroll + for (int i = local_idx; i < element_count; i += threads_per_block){ + gradInput[offset + i] = (output_t)(scale*(local_data[i] - output_data[i]*val)); + } +} + +} // end of anonymous namespace + +template +void dispatch_scaled_masked_softmax_backward_new( + output_t *grad_input, + input_t *grad, + const input_t *output, + const acc_t scale, + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads) +{ + if (key_seq_len == 0) + { + return; + } + else + { + int batch_count = batches * attn_heads * query_seq_len; + // use 128 threads per block to maximize gpu utilization + constexpr int threads_per_block = 128; + int num_warps = (key_seq_len - 1) / at::cuda::warp_size() + 1; + dim3 blocks(batch_count, 1, 1); + dim3 threads(threads_per_block, 1, 1); + + scaled_masked_softmax_warp_backward_new + <<>>(grad_input, grad, output, scale, key_seq_len); + } +} + +/* + * Extended softmax (from native aten pytorch) with following additional features + * 1) input scaling + * 2) Explicit masking + */ +template +__global__ void scaled_masked_softmax_warp_forward_new( + output_t *dst, + const input_t *src, + const uint8_t *mask, + const acc_t scale, + int query_len, // query_len + int attn_heads, + int element_count, // key_len + int pad_batches) // mask batch size +{ + // min threawds_per_block has to be bigger than 128 + int threads_per_block = blockDim.x; + // the first element_count is used for cache, the last 128 is used for reduction + extern __shared__ acc_t local_data[]; + // maximum shared cached 128, enough for 4096 elements reduction into 4096/32= 128 elements + acc_t *shared = &(local_data[element_count]); + // number of 1024 threads reductions + int num_reductions = (element_count - 1) / threads_per_block + 1; + + int offset = blockIdx.x * element_count; + int mask_offset; + int query_id = blockIdx.x % query_len; + if (pad_batches == 1){ + // broadcaste the mask tensor + mask_offset = query_id * element_count; + } + else{ + int mask_batch_id = blockIdx.x / attn_heads / query_len; + mask_offset = (mask_batch_id * query_len + query_id) * element_count; + } + + int local_idx = threadIdx.x; + int lane = threadIdx.x % C10_WARP_SIZE; + int wid = threadIdx.x / C10_WARP_SIZE; + int warps_per_thread_block = threads_per_block / C10_WARP_SIZE; + + // load the data to local data + for (int i = local_idx; i < element_count; i += threads_per_block) + { + // TODO, use the copy vector method + if (mask[mask_offset + i] == 1) + { + local_data[i] = -10000.0; + } + else + { + local_data[i] = src[offset + i] * scale; + } + } + + // first find the max value + for (int i = local_idx; i < (element_count - 1) / C10_WARP_SIZE + 1; i += threads_per_block){ + shared[i] = -10000.0; + } + __syncthreads(); + acc_t val = -10000.0; + #pragma unroll + for (int i = 0; i < num_reductions; i++){ + if (i*threads_per_block + local_idx < element_count){ + val = local_data[i*threads_per_block + local_idx]; + } + else{ + val = -10000.0; + } + __syncthreads(); + val = warp_reduce_new(val); + + if (lane==0 && wid + warps_per_thread_block * i < (element_count - 1) / C10_WARP_SIZE + 1) { + shared[wid + warps_per_thread_block*i] = val; + } + __syncthreads(); + } + + // final shared reduction + int shared_mem_len = (element_count - 1) / C10_WARP_SIZE + 1; + int num_warps = (shared_mem_len - 1) / C10_WARP_SIZE + 1; + while ( shared_mem_len > 1 ){ + #pragma unroll + for (int i = 0; i < num_reductions; i++){ + if (i*threads_per_block + local_idx < shared_mem_len){ + val = shared[i*threads_per_block + local_idx]; + } + else{ + val = -10000.0; + } + __syncthreads(); + val = warp_reduce_new(val); + if (lane==0) { + shared[wid + warps_per_thread_block * i] = val; + } + __syncthreads(); + } + shared_mem_len = num_warps; + num_warps = (shared_mem_len - 1) / C10_WARP_SIZE + 1; + } + + acc_t reduced_val = shared[0]; + if (reduced_val < -10000.0 + 0.1){ + // if everything is masked, pay attention to nothing + #pragma unroll + for (int i = local_idx; i < element_count; i += threads_per_block){ + dst[offset + i] = 0.0; + } + return; + } + + // update the values + #pragma unroll + for (int i = local_idx; i < element_count; i += threads_per_block){ + local_data[i] = std::exp(local_data[i] - reduced_val); + } + + // find the sum + for (int i = local_idx; i < (element_count - 1) / C10_WARP_SIZE + 1; i += threads_per_block){ + shared[i] = 0.0; + } + __syncthreads(); + + #pragma unroll + for (int i = 0; i < num_reductions; i++){ + if (i*threads_per_block + local_idx < element_count){ + val = local_data[i*threads_per_block + local_idx]; + } + else{ + val = 0.0; + } + __syncthreads(); + + val = warp_reduce_new(val); + if (lane==0 && wid + warps_per_thread_block * i < (element_count - 1) / C10_WARP_SIZE + 1) { + shared[wid + warps_per_thread_block*i] = val; + } + __syncthreads(); + } + + shared_mem_len = (element_count - 1) / C10_WARP_SIZE + 1; + num_warps = (shared_mem_len - 1) / C10_WARP_SIZE + 1; + while ( shared_mem_len > 1 ){ + #pragma unroll + for (int i = 0; i < num_reductions; i++){ + if (i*threads_per_block + local_idx < shared_mem_len){ + val = shared[i*threads_per_block + local_idx]; + } + else{ + val = 0.0; + } + __syncthreads(); + val = warp_reduce_new(val); + if (lane==0) { + shared[wid + warps_per_thread_block * i] = val; + } + __syncthreads(); + } + shared_mem_len = num_warps; + num_warps = (shared_mem_len - 1) / C10_WARP_SIZE + 1; + } + + reduced_val = shared[0]; + + #pragma unroll + for (int i = local_idx; i < element_count; i += threads_per_block){ + dst[offset + i] = local_data[i] / reduced_val; + } +} + + +template +void dispatch_scaled_masked_softmax_forward_new( + output_t *dst, + const input_t *src, + const uint8_t *mask, + const input_t scale, + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads, + int pad_batches) +{ + if (key_seq_len == 0) { + return; + } else { + int batch_count = batches * attn_heads * query_seq_len; + + // use 128 threads per block to maximize gpu utilization + constexpr int threads_per_block = 128; + + // calculate the needed shared memory + int num_warps = (key_seq_len - 1) / at::cuda::warp_size() + 1; + + dim3 blocks(batch_count, 1, 1); + dim3 threads(threads_per_block, 1, 1); + scaled_masked_softmax_warp_forward_new + <<>>(dst, src, mask, scale, query_seq_len, attn_heads, key_seq_len, pad_batches); + } +} diff --git a/csrc/megatron/generic_scaled_masked_softmax_cpu.cpp b/csrc/megatron/generic_scaled_masked_softmax_cpu.cpp new file mode 100644 index 000000000..87a04df91 --- /dev/null +++ b/csrc/megatron/generic_scaled_masked_softmax_cpu.cpp @@ -0,0 +1,83 @@ +/* coding=utf-8 + * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +namespace multihead_attn +{ + namespace fused_softmax + { + namespace generic_scaled_masked_softmax + { + + torch::Tensor fwd_cuda( + torch::Tensor const &input, + torch::Tensor const &mask, + float scale_factor); + + torch::Tensor bwd_cuda( + torch::Tensor const &output_grads, + torch::Tensor const &softmax_results, + float scale_factor); + + torch::Tensor fwd( + torch::Tensor const &input, + torch::Tensor const &mask, + float scale_factor) + { + TORCH_CHECK(input.dim() == 4, "expected 4D tensor"); + TORCH_CHECK((input.scalar_type() == at::ScalarType::Half) || + (input.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + TORCH_CHECK(mask.dim() == 4, "expected 4D tensor"); + + return fwd_cuda(input, mask, scale_factor); + } + + torch::Tensor bwd( + torch::Tensor const &output_grads, + torch::Tensor const &softmax_results, + float scale_factor) + { + + TORCH_CHECK(output_grads.dim() == 4, "expected 3D tensor"); + TORCH_CHECK(softmax_results.dim() == 4, "expected 3D tensor"); + + TORCH_CHECK((output_grads.scalar_type() == at::ScalarType::Half) || + (output_grads.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + TORCH_CHECK((softmax_results.scalar_type() == at::ScalarType::Half) || + (softmax_results.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + return bwd_cuda(output_grads, softmax_results, scale_factor); + } + + } // end namespace generic_scaled_masked_softmax + } // end namespace fused_softmax +} // end namespace multihead_attn + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", + &multihead_attn::fused_softmax::generic_scaled_masked_softmax::fwd, + "Self Multihead Attention scaled, time masked softmax -- Forward.", py::call_guard()); + + m.def("backward", + &multihead_attn::fused_softmax::generic_scaled_masked_softmax::bwd, + "Self Multihead Attention scaled, time masked softmax -- Backward.", py::call_guard()); +} diff --git a/csrc/megatron/generic_scaled_masked_softmax_cuda.cu b/csrc/megatron/generic_scaled_masked_softmax_cuda.cu new file mode 100644 index 000000000..93cd94b30 --- /dev/null +++ b/csrc/megatron/generic_scaled_masked_softmax_cuda.cu @@ -0,0 +1,114 @@ +/* coding=utf-8 + * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include "generic_scaled_masked_softmax.h" +#include "type_shim.h" + +namespace multihead_attn { +namespace fused_softmax { +namespace generic_scaled_masked_softmax { + +torch::Tensor fwd_cuda( + torch::Tensor const& input, + torch::Tensor const& mask, + float scale_factor) +{ + // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] + const int batches = input.size(0); + const int pad_batches = mask.size(0); + const int attn_heads = input.size(1); + const int query_seq_len = input.size(2); + const int key_seq_len = input.size(3); + TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches); + TORCH_INTERNAL_ASSERT(mask.size(1) == 1); + TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len); + TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len); + + // Output + auto act_options = input.options().requires_grad(false); + torch::Tensor softmax_results = + torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); + + // Softmax Intermediate Result Ptr + void* input_ptr = static_cast(input.data_ptr()); + void* mask_ptr = static_cast(mask.data_ptr()); + void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); + + DISPATCH_HALF_AND_BFLOAT( + input.scalar_type(), + "dispatch_scaled_masked_softmax_forward", + dispatch_scaled_masked_softmax_forward_new( + reinterpret_cast(softmax_results_ptr), + reinterpret_cast(input_ptr), + reinterpret_cast(mask_ptr), + scale_factor, + query_seq_len, + key_seq_len, + batches, + attn_heads, + pad_batches); + ); + return softmax_results; +} + +torch::Tensor bwd_cuda( + torch::Tensor const& output_grads_, + torch::Tensor const& softmax_results_, + float scale_factor) { + + auto output_grads = output_grads_.contiguous(); + auto softmax_results = softmax_results_.contiguous(); + + //output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] + const int batches = output_grads.size(0); + const int attn_heads = output_grads.size(1); + const int query_seq_len = output_grads.size(2); + const int key_seq_len = output_grads.size(3); + + auto act_options = output_grads.options(); + torch::Tensor input_grad = + torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); + + void* output_grads_ptr = static_cast(output_grads.data_ptr()); + + //Softmax Grad + DISPATCH_HALF_AND_BFLOAT( + output_grads_.scalar_type(), + "dispatch_scaled_masked_softmax_backward", + dispatch_scaled_masked_softmax_backward_new( + reinterpret_cast(static_cast(input_grad.data_ptr())), + reinterpret_cast(output_grads_ptr), + reinterpret_cast(softmax_results.data_ptr()), + scale_factor, + query_seq_len, + key_seq_len, + batches, + attn_heads); + ); + + //backward pass is completely in-place + return input_grad; +} +} +} +} diff --git a/csrc/megatron/scaled_masked_softmax.h b/csrc/megatron/scaled_masked_softmax.h index 78a29cf3b..2674e1f54 100644 --- a/csrc/megatron/scaled_masked_softmax.h +++ b/csrc/megatron/scaled_masked_softmax.h @@ -23,6 +23,7 @@ #include #include #include +#include namespace { @@ -90,6 +91,118 @@ __device__ __forceinline__ void warp_reduce(acc_t* sum) { } } + +/* + * Extended softmax (from native aten pytorch) with following additional features + * 1) input scaling + */ +template +__global__ void scaled_softmax_warp_forward( + output_t *dst, + const input_t *src, + const acc_t scale, + int micro_batch_size, + int element_count) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_forward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) + // gridDim/blockIdx = (seq_len, attn_heads, batches) + long int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + + long int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + src += thread_offset; + dst += thread_offset; + + // load data from global memory + acc_t elements[WARP_BATCH][WARP_ITERATIONS]; + input_t temp_data[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; + + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < batch_element_count) { + int itr_idx = i*element_count+it*WARP_SIZE; + copy_vector(temp_data, src + itr_idx); + + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[i][it + element] = (acc_t)temp_data[element] * scale; + } + } else { + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[i][it + element] = -std::numeric_limits::infinity(); + } + } + } + } + + // compute max_value + acc_t max_value[WARP_BATCH]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + max_value[i] = elements[i][0]; + #pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; + } + } + warp_reduce(max_value); + + acc_t sum[WARP_BATCH] { 0.0f }; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + elements[i][it] = std::exp((elements[i][it] - max_value[i])); + sum[i] += elements[i][it]; + } + } + warp_reduce(sum); + + // store result + output_t out[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = elements[i][it + element] / sum[i]; + } + copy_vector(dst + i * element_count + it * WARP_SIZE, out); + } else { + break; + } + } + } +} + + /* * Extended softmax (from native aten pytorch) with following additional features * 1) input scaling @@ -132,9 +245,11 @@ __global__ void scaled_masked_softmax_warp_forward( // there might be multiple batches per warp. compute the index within the batch int local_idx = threadIdx.x; - src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + long int thread_offset_src_dst = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + long int thread_offset_mask = pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + src += thread_offset_src_dst; + dst += thread_offset_src_dst; + mask += thread_offset_mask; // load data from global memory acc_t elements[WARP_BATCH][WARP_ITERATIONS]; @@ -182,6 +297,13 @@ __global__ void scaled_masked_softmax_warp_forward( } warp_reduce(max_value); + // compute scale value to account for full mask + acc_t scale_value[WARP_BATCH]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + scale_value[i] = (max_value[i] == -10000.0) ? 0.0 : 1.0; + } + acc_t sum[WARP_BATCH] { 0.0f }; #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { @@ -316,7 +438,8 @@ int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int att int log2_elements = log2_ceil(key_seq_len); const int next_power_of_two = 1 << log2_elements; - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + int warp_size = (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size(); + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; constexpr int threads_per_block = 128; @@ -326,6 +449,106 @@ int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int att return batches_per_block; } +template +void dispatch_scaled_softmax_forward( + output_t *dst, + const input_t *src, + const input_t scale, + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads) +{ + TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 16384 ); + if (key_seq_len == 0) { + return; + } else { + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + int batch_count = batches * attn_heads * query_seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. + int warp_size = (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size(); + + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0); + dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 1: // 2 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 2: // 4 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 3: // 8 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 4: // 16 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 5: // 32 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 6: // 64 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 7: // 128 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 8: // 256 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 9: // 512 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 10: // 1024 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 11: // 2048 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 12: // 4096 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 13: // 8192 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 14: // 16384 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + default: + break; + } + } +} + template void dispatch_scaled_masked_softmax_forward( output_t *dst, @@ -338,7 +561,7 @@ void dispatch_scaled_masked_softmax_forward( int attn_heads, int pad_batches) { - TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048 ); + TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 4096 ); if (key_seq_len == 0) { return; } else { @@ -347,7 +570,7 @@ void dispatch_scaled_masked_softmax_forward( int batch_count = batches * attn_heads * query_seq_len; // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + int warp_size = (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size(); // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; @@ -410,6 +633,10 @@ void dispatch_scaled_masked_softmax_forward( scaled_masked_softmax_warp_forward <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break; + case 12: // 4096 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; default: break; } @@ -427,7 +654,7 @@ void dispatch_scaled_masked_softmax_backward( int batches, int attn_heads) { - TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 2048 ); + TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 4096 ); if (key_seq_len == 0) { return; } else { @@ -436,7 +663,7 @@ void dispatch_scaled_masked_softmax_backward( int batch_count = batches * attn_heads * query_seq_len; // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + int warp_size = (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size(); // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; @@ -498,6 +725,10 @@ void dispatch_scaled_masked_softmax_backward( scaled_masked_softmax_warp_backward <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); break; + case 12: // 4096 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; default: break; } diff --git a/csrc/megatron/scaled_masked_softmax.cpp b/csrc/megatron/scaled_masked_softmax_cpu.cpp similarity index 99% rename from csrc/megatron/scaled_masked_softmax.cpp rename to csrc/megatron/scaled_masked_softmax_cpu.cpp index 6e5d35564..dd471a0bb 100644 --- a/csrc/megatron/scaled_masked_softmax.cpp +++ b/csrc/megatron/scaled_masked_softmax_cpu.cpp @@ -14,7 +14,6 @@ * limitations under the License. */ -#include #include #include diff --git a/csrc/megatron/scaled_masked_softmax_cuda.cu b/csrc/megatron/scaled_masked_softmax_cuda.cu index 12a364e44..053d071ed 100644 --- a/csrc/megatron/scaled_masked_softmax_cuda.cu +++ b/csrc/megatron/scaled_masked_softmax_cuda.cu @@ -18,7 +18,7 @@ #include #include #include -#include +//#include #include #include #include "scaled_masked_softmax.h" @@ -44,7 +44,7 @@ torch::Tensor fwd_cuda( const int attn_heads = input.size(1); const int query_seq_len = input.size(2); const int key_seq_len = input.size(3); - TORCH_INTERNAL_ASSERT(key_seq_len <= 2048); + TORCH_INTERNAL_ASSERT(key_seq_len <= 16384); TORCH_INTERNAL_ASSERT(query_seq_len > 1); TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches); TORCH_INTERNAL_ASSERT(mask.size(1) == 1); diff --git a/csrc/megatron/scaled_softmax_cpu.cpp b/csrc/megatron/scaled_softmax_cpu.cpp new file mode 100644 index 000000000..c8f6d28cc --- /dev/null +++ b/csrc/megatron/scaled_softmax_cpu.cpp @@ -0,0 +1,75 @@ +/* coding=utf-8 + * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +namespace multihead_attn { +namespace fused_softmax { +namespace scaled_softmax { + +torch::Tensor fwd_cuda( + torch::Tensor const& input, + float scale_factor); + +torch::Tensor bwd_cuda( + torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor); + +torch::Tensor fwd( + torch::Tensor const& input, + float scale_factor) { + TORCH_CHECK(input.dim() == 4, "expected 4D tensor"); + TORCH_CHECK((input.scalar_type() == at::ScalarType::Half) || + (input.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + return fwd_cuda(input, scale_factor); +} + +torch::Tensor bwd( + torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor) { + + TORCH_CHECK(output_grads.dim() == 4, "expected 3D tensor"); + TORCH_CHECK(softmax_results.dim() == 4, "expected 3D tensor"); + + TORCH_CHECK((output_grads.scalar_type() == at::ScalarType::Half) || + (output_grads.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + TORCH_CHECK((softmax_results.scalar_type() == at::ScalarType::Half) || + (softmax_results.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + return bwd_cuda(output_grads, softmax_results, scale_factor); +} + +} // end namespace scaled_softmax +} // end namespace fused_softmax +} // end namespace multihead_attn + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", + &multihead_attn::fused_softmax::scaled_softmax::fwd, + "Self Multihead Attention scaled, softmax -- Forward.", py::call_guard()); + m.def("backward", + &multihead_attn::fused_softmax::scaled_softmax::bwd, + "Self Multihead Attention scaled, softmax -- Backward.", py::call_guard()); +} + diff --git a/csrc/megatron/scaled_softmax_cuda.cu b/csrc/megatron/scaled_softmax_cuda.cu new file mode 100644 index 000000000..1bcaff36b --- /dev/null +++ b/csrc/megatron/scaled_softmax_cuda.cu @@ -0,0 +1,104 @@ +/* coding=utf-8 + * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include "scaled_masked_softmax.h" +#include "type_shim.h" + +namespace multihead_attn { +namespace fused_softmax { +namespace scaled_softmax { + +torch::Tensor fwd_cuda( + torch::Tensor const& input, + float scale_factor) +{ + // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] + const int batches = input.size(0); + const int attn_heads = input.size(1); + const int query_seq_len = input.size(2); + const int key_seq_len = input.size(3); + TORCH_INTERNAL_ASSERT(key_seq_len <= 16384); + TORCH_INTERNAL_ASSERT(query_seq_len > 1); + + // Output + auto act_options = input.options().requires_grad(false); + torch::Tensor softmax_results = + torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); + + // Softmax Intermediate Result Ptr + void* input_ptr = static_cast(input.data_ptr()); + void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); + + DISPATCH_HALF_AND_BFLOAT( + input.scalar_type(), + "dispatch_scaled_softmax_forward", + dispatch_scaled_softmax_forward( + reinterpret_cast(softmax_results_ptr), + reinterpret_cast(input_ptr), + scale_factor, + query_seq_len, + key_seq_len, + batches, + attn_heads); + ); + return softmax_results; +} + +torch::Tensor bwd_cuda( + torch::Tensor const& output_grads_, + torch::Tensor const& softmax_results_, + float scale_factor) { + + auto output_grads = output_grads_.contiguous(); + auto softmax_results = softmax_results_.contiguous(); + + //output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] + const int batches = output_grads.size(0); + const int attn_heads = output_grads.size(1); + const int query_seq_len = output_grads.size(2); + const int key_seq_len = output_grads.size(3); + + void* output_grads_ptr = static_cast(output_grads.data_ptr()); + + //Softmax Grad + DISPATCH_HALF_AND_BFLOAT( + output_grads_.scalar_type(), + "dispatch_scaled_masked_softmax_backward", + dispatch_scaled_masked_softmax_backward( + reinterpret_cast(output_grads_ptr), + reinterpret_cast(output_grads_ptr), + reinterpret_cast(softmax_results.data_ptr()), + scale_factor, + query_seq_len, + key_seq_len, + batches, + attn_heads); + ); + + //backward pass is completely in-place + return output_grads; +} +} +} +} + diff --git a/csrc/megatron/scaled_upper_triang_masked_softmax.h b/csrc/megatron/scaled_upper_triang_masked_softmax.h index 445e0d88c..562350af2 100644 --- a/csrc/megatron/scaled_upper_triang_masked_softmax.h +++ b/csrc/megatron/scaled_upper_triang_masked_softmax.h @@ -22,6 +22,7 @@ #include #include #include +#include namespace { @@ -340,7 +341,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward( int softmax_elements_stride, int attn_batches) { - TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048 ); + TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 16384 ); if (softmax_elements == 0) { return; } else { @@ -350,7 +351,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward( int batch_count = attn_batches * seq_len; // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + int warp_size = (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size(); // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; @@ -415,6 +416,18 @@ void dispatch_scaled_upper_triang_masked_softmax_forward( scaled_upper_triang_masked_softmax_warp_forward <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break; + case 12: // 4096 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 13: // 8192 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 14: // 16384 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; default: break; } @@ -431,7 +444,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward( int softmax_elements_stride, int attn_batches) { - TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 ); + TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 16384 ); if (softmax_elements == 0) { return; } else { @@ -441,7 +454,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward( int batch_count = attn_batches * seq_len; // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + int warp_size = (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size(); // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; @@ -506,6 +519,18 @@ void dispatch_scaled_upper_triang_masked_softmax_backward( scaled_upper_triang_masked_softmax_warp_backward <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); break; + case 12: // 4096 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 13: // 8192 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 14: // 16384 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; default: break; } diff --git a/csrc/megatron/scaled_upper_triang_masked_softmax.cpp b/csrc/megatron/scaled_upper_triang_masked_softmax_cpu.cpp similarity index 99% rename from csrc/megatron/scaled_upper_triang_masked_softmax.cpp rename to csrc/megatron/scaled_upper_triang_masked_softmax_cpu.cpp index 29754fc59..12cec7f67 100644 --- a/csrc/megatron/scaled_upper_triang_masked_softmax.cpp +++ b/csrc/megatron/scaled_upper_triang_masked_softmax_cpu.cpp @@ -14,7 +14,6 @@ * limitations under the License. */ -#include #include #include diff --git a/csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu b/csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu index a90a9344f..7cec7f8e3 100644 --- a/csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu +++ b/csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu @@ -18,7 +18,7 @@ #include #include #include -#include +//#include #include #include #include "scaled_upper_triang_masked_softmax.h" @@ -35,7 +35,7 @@ torch::Tensor fwd_cuda( // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] const int attn_batches = input.size(0); const int seq_len = input.size(1); - TORCH_INTERNAL_ASSERT(seq_len <= 2048); + TORCH_INTERNAL_ASSERT(seq_len <= 16384); // Output auto act_options = input.options().requires_grad(false); diff --git a/csrc/mlp.cpp b/csrc/mlp.cpp index 830d60628..adcd76e7a 100644 --- a/csrc/mlp.cpp +++ b/csrc/mlp.cpp @@ -66,7 +66,7 @@ std::vector mlp_forward(int use_bias, int activation, std::vector w_ptr; std::vector b_ptr; for (int i = 0; i < num_layers; i++) { @@ -121,7 +121,7 @@ std::vector mlp_backward( outputs.push_back(at::empty(inputs[i].sizes(), inputs[i].type())); // clone for testing now } - AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputs[0].type(), "mlp_backward", [&] { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputs[0].scalar_type(), "mlp_backward", [&] { std::vector w_ptr; for (int i = 0; i < num_layers; i++) { w_ptr.push_back(inputs[i + 1].data_ptr()); diff --git a/csrc/mlp_cuda.cu b/csrc/mlp_cuda.cu index f93f1df1a..1b67ad739 100644 --- a/csrc/mlp_cuda.cu +++ b/csrc/mlp_cuda.cu @@ -1,3 +1,5 @@ +// New MLP with denorm mitigation only for backprop + #include #include #include @@ -10,6 +12,9 @@ #include #include +#include +#include "type_shim.h" + #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000 // includes cublaslt #include @@ -70,7 +75,8 @@ cublasStatus_t mlp_gemm( int ldb, const float* beta, double* C, - int ldc) { + int ldc, + int flag) { return cublasGemmEx( handle, transa, @@ -89,7 +95,7 @@ cublasStatus_t mlp_gemm( C, CUDA_R_64F, ldc, - CUDA_R_64F, + CUBLAS_COMPUTE_64F, CUBLAS_GEMM_DEFAULT); } @@ -108,7 +114,8 @@ cublasStatus_t mlp_gemm( int ldb, const float* beta, float* C, - int ldc) { + int ldc, + int flag) { return cublasGemmEx( handle, transa, @@ -127,7 +134,7 @@ cublasStatus_t mlp_gemm( C, CUDA_R_32F, ldc, - CUDA_R_32F, + CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT); } @@ -146,7 +153,8 @@ cublasStatus_t mlp_gemm( int ldb, float* beta, at::Half* C, - int ldc) { + int ldc, + int flag) { return cublasGemmEx( handle, transa, @@ -165,7 +173,7 @@ cublasStatus_t mlp_gemm( C, CUDA_R_16F, ldc, - CUDA_R_32F, + CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP); } #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000 @@ -1317,7 +1325,8 @@ int mlp_fp( ifeat, &zero, output, - ofeat); + ofeat, + int(0)); // Do nothing for forward prop if (cublas_status != CUBLAS_STATUS_SUCCESS) { printf("GEMM fprop failed with %d\n", cublas_status); @@ -1413,7 +1422,17 @@ int mlp_bp( // Get the stream from cublas handle to reuse for biasReLU kernel. cudaStream_t stream; cublasGetStream(handle, &stream); - + int flag = 0; + #ifdef USE_ROCM + #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR) + #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242) + #if USE_GEMM_FLAGS_FP16_ALT_IMPL + #ifdef BACKWARD_PASS_GUARD + flag = at::BACKWARD_PASS_GUARD_CLASS::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0; + #endif + #endif + #endif + int* y_offsets = (int*)malloc(num_layers * sizeof(int)); get_y_offsets(batch_size, num_layers, output_features, y_offsets); @@ -1532,7 +1551,8 @@ int mlp_bp( yfeat, &zero, dx, - xfeat); + xfeat, + flag); // if (cublas_status != CUBLAS_STATUS_SUCCESS) { printf("GEMM dgrad failed with %d\n", cublas_status); @@ -1555,7 +1575,8 @@ int mlp_bp( yfeat, &zero, dweight, - xfeat); + xfeat, + flag); // if (cublas_status != CUBLAS_STATUS_SUCCESS) { printf("GEMM wgrad failed with %d\n", cublas_status); @@ -1675,4 +1696,3 @@ template size_t get_mlp_bp_workspace_in_bytes( int batch_size, int num_layers, const int* output_features); - diff --git a/csrc/multi_tensor_adagrad.cu b/csrc/multi_tensor_adagrad.cu index 699681bce..7bdb621a0 100644 --- a/csrc/multi_tensor_adagrad.cu +++ b/csrc/multi_tensor_adagrad.cu @@ -90,7 +90,7 @@ void multi_tensor_adagrad_cuda( using namespace at; // Assume single type across p,g,h now - DISPATCH_DOUBLE_FLOAT_AND_HALF( + DISPATCH_DOUBLE_FLOAT_AND_HALF_AND_BFLOAT16( tensor_lists[0][0].scalar_type(), 0, "adagrad", multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, AdagradFunctor(), epsilon, lr, diff --git a/csrc/multi_tensor_adam.cu b/csrc/multi_tensor_adam.cu index 2a648c0dc..012e94458 100644 --- a/csrc/multi_tensor_adam.cu +++ b/csrc/multi_tensor_adam.cu @@ -10,7 +10,7 @@ #include "type_shim.h" #include "multi_tensor_apply.cuh" -#define BLOCK_SIZE 512 +#define BLOCK_SIZE 1024 #define ILP 4 typedef enum{ @@ -20,11 +20,11 @@ typedef enum{ using MATH_T = float; -template +template struct AdamFunctor { __device__ __forceinline__ void operator()( - int chunk_size, + index_t chunk_size, volatile int* noop_gmem, TensorListMetadata<4>& tl, const float beta1, @@ -40,13 +40,13 @@ struct AdamFunctor // if(*noop_gmem == 1) // return; - int tensor_loc = tl.block_to_tensor[blockIdx.x]; + index_t tensor_loc = tl.block_to_tensor[blockIdx.x]; // potentially use to pass in list of scalar // int tensor_num = tl.start_tensor_this_launch + tensor_loc; - int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; + index_t chunk_idx = tl.block_to_chunk[blockIdx.x]; + index_t n = tl.sizes[tensor_loc]; T* g = (T*)tl.addresses[0][tensor_loc]; g += chunk_idx*chunk_size; @@ -54,16 +54,16 @@ struct AdamFunctor T* p = (T*)tl.addresses[1][tensor_loc]; p += chunk_idx*chunk_size; - T* m = (T*)tl.addresses[2][tensor_loc]; + FULL_T* m = (FULL_T*)tl.addresses[2][tensor_loc]; m += chunk_idx*chunk_size; - T* v = (T*)tl.addresses[3][tensor_loc]; + FULL_T* v = (FULL_T*)tl.addresses[3][tensor_loc]; v += chunk_idx*chunk_size; n -= chunk_idx*chunk_size; // see note in multi_tensor_scale_kernel.cu - for(int i_start = 0; + for(index_t i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP) { @@ -126,6 +126,236 @@ struct AdamFunctor } }; +template +struct AdamCapturableFunctor +{ + __device__ __forceinline__ void operator()( + int chunk_size, + volatile int* noop_gmem, + TensorListMetadata<4>& tl, + const float beta1, + const float beta2, + const int* step, + const int bias_correction, + const float epsilon, + const float* lr, + adamMode_t mode, + const float decay, + const float* inv_scale) + { + if(*noop_gmem == 1) + return; + + float beta1_correction = 1.0f, beta2_correction = 1.0f; + if (bias_correction == 1) { + beta1_correction = 1 - pow(beta1, *step); + beta2_correction = 1 - pow(beta2, *step); + } + + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + + // potentially use to pass in list of scalar + // int tensor_num = tl.start_tensor_this_launch + tensor_loc; + + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; + + T* g = (T*)tl.addresses[0][tensor_loc]; + g += chunk_idx*chunk_size; + + T* p = (T*)tl.addresses[1][tensor_loc]; + p += chunk_idx*chunk_size; + + FULL_T* m = (FULL_T*)tl.addresses[2][tensor_loc]; + m += chunk_idx*chunk_size; + + FULL_T* v = (FULL_T*)tl.addresses[3][tensor_loc]; + v += chunk_idx*chunk_size; + + n -= chunk_idx*chunk_size; + + // see note in multi_tensor_scale_kernel.cu + for(int i_start = 0; + i_start < n && i_start < chunk_size; + i_start += blockDim.x*ILP) + { + MATH_T r_g[ILP]; + MATH_T r_p[ILP]; + MATH_T r_m[ILP]; + MATH_T r_v[ILP]; +#pragma unroll + for(int ii = 0; ii < ILP; ii++) + { + int i = i_start + threadIdx.x + ii*blockDim.x; + if(i < n && i < chunk_size) + { + r_g[ii] = static_cast(g[i]) * (*inv_scale); + g[i] = static_cast(r_g[ii]); + r_p[ii] = static_cast(p[i]); + r_m[ii] = static_cast(m[i]); + r_v[ii] = static_cast(v[i]); + } else { + r_g[ii] = MATH_T(0); + r_p[ii] = MATH_T(0); + r_m[ii] = MATH_T(0); + r_v[ii] = MATH_T(0); + } + } +#pragma unroll + for(int ii = 0; ii < ILP; ii++) + { + if(mode == ADAM_MODE_0) { // L2 + r_g[ii] = r_g[ii] + (decay * r_p[ii]); + r_m[ii] = beta1 * r_m[ii] + (1-beta1) * r_g[ii]; + r_v[ii] = beta2 * r_v[ii] + (1-beta2) * r_g[ii] * r_g[ii]; + MATH_T next_m_unbiased = r_m[ii] / beta1_correction; + MATH_T next_v_unbiased = r_v[ii] / beta2_correction; + MATH_T denom = sqrtf(next_v_unbiased) + epsilon; + MATH_T update = next_m_unbiased / denom; + r_p[ii] = r_p[ii] - (*lr * update); + } + else { // weight decay + r_m[ii] = beta1 * r_m[ii] + (1-beta1) * r_g[ii]; + r_v[ii] = beta2 * r_v[ii] + (1-beta2) * r_g[ii] * r_g[ii]; + MATH_T next_m_unbiased = r_m[ii] / beta1_correction; + MATH_T next_v_unbiased = r_v[ii] / beta2_correction; + MATH_T denom = sqrtf(next_v_unbiased) + epsilon; + MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]); + r_p[ii] = r_p[ii] - (*lr * update); + } + } +#pragma unroll + for(int ii = 0; ii < ILP; ii++) + { + int i = i_start + threadIdx.x + ii*blockDim.x; + if(i < n && i < chunk_size) + { + p[i] = static_cast(r_p[ii]); + m[i] = static_cast(r_m[ii]); + v[i] = static_cast(r_v[ii]); + } + } + } + } +}; + +template +struct AdamCapturableMasterFunctor +{ + __device__ __forceinline__ void operator()( + int chunk_size, + volatile int* noop_gmem, + TensorListMetadata<5>& tl, + const float beta1, + const float beta2, + const int* step, + const int bias_correction, + const float epsilon, + const float* lr, + adamMode_t mode, + const float decay, + const float* inv_scale) + { + if(*noop_gmem == 1) + return; + + float beta1_correction = 1.0f, beta2_correction = 1.0f; + if (bias_correction == 1) { + beta1_correction = 1 - pow(beta1, *step); + beta2_correction = 1 - pow(beta2, *step); + } + + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + + // potentially use to pass in list of scalar + // int tensor_num = tl.start_tensor_this_launch + tensor_loc; + + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; + + T* g = (T*)tl.addresses[0][tensor_loc]; + g += chunk_idx*chunk_size; + + T* p = (T*)tl.addresses[1][tensor_loc]; + p += chunk_idx*chunk_size; + + FULL_T* m = (FULL_T*)tl.addresses[2][tensor_loc]; + m += chunk_idx*chunk_size; + + FULL_T* v = (FULL_T*)tl.addresses[3][tensor_loc]; + v += chunk_idx*chunk_size; + + FULL_T* p_master = (FULL_T*)tl.addresses[4][tensor_loc]; + p_master += chunk_idx*chunk_size; + + n -= chunk_idx*chunk_size; + + // see note in multi_tensor_scale_kernel.cu + for(int i_start = 0; + i_start < n && i_start < chunk_size; + i_start += blockDim.x*ILP) + { + MATH_T r_g[ILP]; + MATH_T r_p[ILP]; + MATH_T r_m[ILP]; + MATH_T r_v[ILP]; +#pragma unroll + for(int ii = 0; ii < ILP; ii++) + { + int i = i_start + threadIdx.x + ii*blockDim.x; + if(i < n && i < chunk_size) + { + r_g[ii] = static_cast(g[i]) * (*inv_scale); + g[i] = static_cast(r_g[ii]); + r_p[ii] = static_cast(p_master[i]); + r_m[ii] = static_cast(m[i]); + r_v[ii] = static_cast(v[i]); + } else { + r_g[ii] = MATH_T(0); + r_p[ii] = MATH_T(0); + r_m[ii] = MATH_T(0); + r_v[ii] = MATH_T(0); + } + } +#pragma unroll + for(int ii = 0; ii < ILP; ii++) + { + if(mode == ADAM_MODE_0) { // L2 + r_g[ii] = r_g[ii] + (decay * r_p[ii]); + r_m[ii] = beta1 * r_m[ii] + (1-beta1) * r_g[ii]; + r_v[ii] = beta2 * r_v[ii] + (1-beta2) * r_g[ii] * r_g[ii]; + MATH_T next_m_unbiased = r_m[ii] / beta1_correction; + MATH_T next_v_unbiased = r_v[ii] / beta2_correction; + MATH_T denom = sqrtf(next_v_unbiased) + epsilon; + MATH_T update = next_m_unbiased / denom; + r_p[ii] = r_p[ii] - (*lr * update); + } + else { // weight decay + r_m[ii] = beta1 * r_m[ii] + (1-beta1) * r_g[ii]; + r_v[ii] = beta2 * r_v[ii] + (1-beta2) * r_g[ii] * r_g[ii]; + MATH_T next_m_unbiased = r_m[ii] / beta1_correction; + MATH_T next_v_unbiased = r_v[ii] / beta2_correction; + MATH_T denom = sqrtf(next_v_unbiased) + epsilon; + MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]); + r_p[ii] = r_p[ii] - (*lr * update); + } + } +#pragma unroll + for(int ii = 0; ii < ILP; ii++) + { + int i = i_start + threadIdx.x + ii*blockDim.x; + if(i < n && i < chunk_size) + { + p[i] = static_cast(r_p[ii]); + p_master[i] = static_cast(r_p[ii]); + m[i] = static_cast(r_m[ii]); + v[i] = static_cast(r_v[ii]); + } + } + } + } +}; + void multi_tensor_adam_cuda( int chunk_size, at::Tensor noop_flag, @@ -148,6 +378,42 @@ void multi_tensor_adam_cuda( bias_correction2 = 1 - std::pow(beta2, step); } + size_t max_size = 0; + bool requires_64bit_indexing = false; + for (auto it = tensor_lists.begin(); it != tensor_lists.end(); it++) { + for (auto it2 = it->begin(); it2 != it->end(); it2++) { + if (it2->numel() > max_size) { + max_size = it2->numel(); + if (max_size >= INT_MAX) { + requires_64bit_indexing = true; + break; + } + } + } + if (requires_64bit_indexing) { + break; + } + } + + if (requires_64bit_indexing) { + // Assume single type across p,g,m1,m2 now + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( + tensor_lists[0][0].scalar_type(), 0, "adam", + multi_tensor_apply<4>( + (int64_t) BLOCK_SIZE, + (int64_t) chunk_size, + noop_flag, + tensor_lists, + AdamFunctor(), + beta1, + beta2, + bias_correction1, + bias_correction2, + epsilon, + lr, + (adamMode_t) mode, + weight_decay); ) + } else { // Assume single type across p,g,m1,m2 now DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( tensor_lists[0][0].scalar_type(), 0, "adam", @@ -156,7 +422,7 @@ void multi_tensor_adam_cuda( chunk_size, noop_flag, tensor_lists, - AdamFunctor(), + AdamFunctor(), beta1, beta2, bias_correction1, @@ -165,7 +431,83 @@ void multi_tensor_adam_cuda( lr, (adamMode_t) mode, weight_decay); ) + } + AT_CUDA_CHECK(cudaGetLastError()); +} + +void multi_tensor_adam_capturable_cuda( + int chunk_size, + at::Tensor noop_flag, + std::vector> tensor_lists, + at::Tensor lr, + const float beta1, + const float beta2, + const float epsilon, + at::Tensor step, + const int mode, + const int bias_correction, + const float weight_decay, + at::Tensor inv_scale) +{ + using namespace at; + + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( + tensor_lists[0][0].scalar_type(), 0, "adam", + multi_tensor_apply<4>( + BLOCK_SIZE, + chunk_size, + noop_flag, + tensor_lists, + AdamCapturableFunctor(), + beta1, + beta2, + step.data_ptr(), + bias_correction, + epsilon, + lr.data_ptr(), + (adamMode_t) mode, + weight_decay, + inv_scale.data_ptr()); ) + + AT_CUDA_CHECK(cudaGetLastError()); + +} + +void multi_tensor_adam_capturable_master_cuda( + int chunk_size, + at::Tensor noop_flag, + std::vector> tensor_lists, + at::Tensor lr, + const float beta1, + const float beta2, + const float epsilon, + at::Tensor step, + const int mode, + const int bias_correction, + const float weight_decay, + at::Tensor inv_scale) +{ + using namespace at; + + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( + tensor_lists[0][0].scalar_type(), 0, "adam", + multi_tensor_apply<5>( + BLOCK_SIZE, + chunk_size, + noop_flag, + tensor_lists, + AdamCapturableMasterFunctor(), + beta1, + beta2, + step.data_ptr(), + bias_correction, + epsilon, + lr.data_ptr(), + (adamMode_t) mode, + weight_decay, + inv_scale.data_ptr()); ) AT_CUDA_CHECK(cudaGetLastError()); } + diff --git a/csrc/multi_tensor_apply.cuh b/csrc/multi_tensor_apply.cuh index 1e7a7d202..3bed46eae 100644 --- a/csrc/multi_tensor_apply.cuh +++ b/csrc/multi_tensor_apply.cuh @@ -13,13 +13,13 @@ // TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson) -constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30}; -constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320}; +constexpr int depth_to_max_tensors[6] = {110, 64, 48, 36, 30, 24}; +constexpr int depth_to_max_blocks[6] = {2560, 2560, 2560, 2560, 2560, 2560}; template struct TensorListMetadata { void* addresses[n][depth_to_max_tensors[n-1]]; - int sizes[depth_to_max_tensors[n-1]]; + int64_t sizes[depth_to_max_tensors[n-1]]; unsigned char block_to_tensor[depth_to_max_blocks[n-1]]; int block_to_chunk[depth_to_max_blocks[n-1]]; // I fear this needs to be a full int. int start_tensor_this_launch; @@ -27,8 +27,11 @@ template struct TensorListMetadata template +#ifdef USE_ROCM +__launch_bounds__(1024) +#endif __global__ void multi_tensor_apply_kernel( - int chunk_size, + int64_t chunk_size, volatile int* noop_flag, T tl, U callable, @@ -40,8 +43,8 @@ __global__ void multi_tensor_apply_kernel( template void multi_tensor_apply( - int block_size, - int chunk_size, + int64_t block_size, + int64_t chunk_size, const at::Tensor& noop_flag, const std::vector>& tensor_lists, T callable, @@ -85,9 +88,9 @@ void multi_tensor_apply( tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr(); loc_tensor_info++; - int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1)/chunk_size; + auto chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1)/chunk_size; - for(int chunk = 0; chunk < chunks_this_tensor; chunk++) + for(auto chunk = 0; chunk < chunks_this_tensor; chunk++) { // std::cout << chunks_this_tensor << std::endl; tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1; @@ -130,4 +133,4 @@ void multi_tensor_apply( } } } -} +} \ No newline at end of file diff --git a/csrc/multi_tensor_apply_base.cuh b/csrc/multi_tensor_apply_base.cuh new file mode 100644 index 000000000..6a34c406e --- /dev/null +++ b/csrc/multi_tensor_apply_base.cuh @@ -0,0 +1,147 @@ +#include +#include +#include +#include +#include +#include "compat.h" + +#include + +// #include + +// This header is the one-stop shop for all your multi-tensor apply needs. + + +// TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson) +constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30}; +constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320}; + +template struct TensorListMetadata +{ + void* addresses[n][depth_to_max_tensors[n-1]]; + int sizes[depth_to_max_tensors[n-1]]; + unsigned char block_to_tensor[depth_to_max_blocks[n-1]]; + int block_to_chunk[depth_to_max_blocks[n-1]]; // I fear this needs to be a full int. + int start_tensor_this_launch; +}; + + +template +#ifdef USE_ROCM +__launch_bounds__(1024) +#endif +__global__ void multi_tensor_apply_kernel( + int chunk_size, + volatile int* noop_flag, + T tl, + U callable, + ArgTypes... args) +{ + // Hand the chunk information to the user-supplied functor to process however it likes. + callable(chunk_size, noop_flag, tl, args...); +} + +template +void multi_tensor_apply( + int block_size, + int chunk_size, + const at::Tensor& noop_flag, + const std::vector>& tensor_lists, + T callable, + ArgTypes... args) +{ + TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth"); + int len0 = tensor_lists[0].size(); + TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0"); + auto ref_device = tensor_lists[0][0].device(); + TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda"); + for (int l = 0; l < tensor_lists.size(); l++) // No range-based for because I need indices + { + TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists"); + for(int t = 0; t < tensor_lists[l].size(); t++) + { + // TODO: Print which tensor fails. + bool contiguous_memory = (tensor_lists[l][t].is_sparse()) ? tensor_lists[l][t]._values().is_contiguous() : tensor_lists[l][t].is_contiguous(); +#ifdef VERSION_GE_1_5 + contiguous_memory = (contiguous_memory || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast) || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast3d)); +#endif + TORCH_CHECK(contiguous_memory, "A tensor was not contiguous."); + TORCH_CHECK(tensor_lists[l][t].device() == ref_device, "A tensor was not on the same device as the first tensor"); + TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch"); + } + } + + int ntensors = tensor_lists[0].size(); + + TensorListMetadata tl; + + const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0])); + auto stream = at::cuda::getCurrentCUDAStream(); + + tl.start_tensor_this_launch = 0; + int loc_block_info = 0; + int loc_tensor_info = 0; + for(int t = 0; t < ntensors; t++) + { + tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel(); + // skip empty tensors + if (tl.sizes[loc_tensor_info] == 0) { + continue; + } + for(int d = 0; d < depth; d++) { + if (tensor_lists[d][t].is_sparse()) { + at::Tensor dst = at::zeros(tensor_lists[d][t].sizes(), tensor_lists[d][t].options().layout(at::kStrided)); + dst.add_(tensor_lists[d][t]); + tl.addresses[d][loc_tensor_info] = dst.data_ptr(); + } else { + tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr(); + } + } + loc_tensor_info++; + + int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1)/chunk_size; + + for(int chunk = 0; chunk < chunks_this_tensor; chunk++) + { + // std::cout << chunks_this_tensor << std::endl; + tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1; + tl.block_to_chunk[loc_block_info] = chunk; + loc_block_info++; + + bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth-1] && + chunk == chunks_this_tensor - 1); + bool blocks_full = (loc_block_info == depth_to_max_blocks[depth-1]); + bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1); + if(tensors_full || blocks_full || last_chunk) + { + // using accscalar_t = acc_type; + multi_tensor_apply_kernel<<>>( + chunk_size, + noop_flag.DATA_PTR(), + tl, + callable, + args...); + + AT_CUDA_CHECK(cudaGetLastError()); + + // Reset. The control flow possibilities here make my brain hurt. + loc_block_info = 0; + if(chunk == chunks_this_tensor - 1) + { + // std::cout << "Hit case 1 " << cond1 << " " << cond2 << " " << cond3 << std::endl; + loc_tensor_info = 0; + tl.start_tensor_this_launch = t + 1; + } + else + { + // std::cout << "Hit case 2 " << cond1 << " " << cond2 << " " << cond3 << std::endl; + tl.sizes[0] = tl.sizes[loc_tensor_info-1]; + for(int d = 0; d < depth; d++) + tl.addresses[d][0] = tl.addresses[d][loc_tensor_info-1]; + loc_tensor_info = 1; + tl.start_tensor_this_launch = t; + } + } + } + } +} diff --git a/csrc/multi_tensor_axpby_kernel.cu b/csrc/multi_tensor_axpby_kernel.cu index 021df27d7..87f536bf9 100644 --- a/csrc/multi_tensor_axpby_kernel.cu +++ b/csrc/multi_tensor_axpby_kernel.cu @@ -10,7 +10,7 @@ #include "type_shim.h" #include "multi_tensor_apply.cuh" -#define BLOCK_SIZE 512 +#define BLOCK_SIZE 1024 #define ILP 4 template @@ -138,9 +138,9 @@ void multi_tensor_axpby_cuda( // If build times suffer, think about where to put this dispatch, // and what logic should be moved out of multi_tensor_apply. - DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_axpby_cuda", - DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_axpby_cuda", - DISPATCH_FLOAT_AND_HALF(tensor_lists[2][0].scalar_type(), 2, "multi_tensor_axpby_cuda", + DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_axpby_cuda", + DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_axpby_cuda", + DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[2][0].scalar_type(), 2, "multi_tensor_axpby_cuda", multi_tensor_apply<3>( BLOCK_SIZE, chunk_size, diff --git a/csrc/multi_tensor_l2norm_kernel.cu b/csrc/multi_tensor_l2norm_kernel.cu index 2bf615e94..66189112f 100644 --- a/csrc/multi_tensor_l2norm_kernel.cu +++ b/csrc/multi_tensor_l2norm_kernel.cu @@ -9,7 +9,7 @@ #include #include "type_shim.h" -#include "multi_tensor_apply.cuh" +#include "multi_tensor_apply_base.cuh" #define BLOCK_SIZE 512 #define ILP 4 @@ -195,7 +195,11 @@ struct MaxNormFunctor }; -__global__ void cleanup( +__global__ void +#ifdef USE_ROCM +__launch_bounds__(1024) +#endif +cleanup( float* output, float* output_per_tensor, float* ret, @@ -232,7 +236,11 @@ __global__ void cleanup( } } -__global__ void cleanup_v2( +__global__ void +#ifdef USE_ROCM +__launch_bounds__(1024) +#endif +cleanup_v2( float* output, float* output_per_tensor, float* ret, @@ -323,7 +331,7 @@ std::tuple multi_tensor_l2norm_cuda( ret_per_tensor = at::empty({0}, float_options); } - DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda", + DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda", multi_tensor_apply<1>( BLOCK_SIZE, chunk_size, @@ -392,7 +400,7 @@ void multi_tensor_norm_out_cuda( output_per_tensor = at::zeros({ntensors*max_chunks_per_tensor}, float_options); if (norm_type == 0) { - DISPATCH_FLOAT_AND_HALF( + DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16( tensor_lists[0][0].scalar_type(), 0, "multi_tensor_maxnorm_cuda", multi_tensor_apply<1>( BLOCK_SIZE, @@ -406,7 +414,7 @@ void multi_tensor_norm_out_cuda( max_chunks_per_tensor);) } else { - DISPATCH_FLOAT_AND_HALF( + DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16( tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda", multi_tensor_apply<1>( BLOCK_SIZE, diff --git a/csrc/multi_tensor_l2norm_kernel_mp.cu b/csrc/multi_tensor_l2norm_kernel_mp.cu index 987f76f51..d023c6d97 100644 --- a/csrc/multi_tensor_l2norm_kernel_mp.cu +++ b/csrc/multi_tensor_l2norm_kernel_mp.cu @@ -9,7 +9,7 @@ #include #include "type_shim.h" -#include "multi_tensor_apply.cuh" +#include "multi_tensor_apply_base.cuh" #define BLOCK_SIZE 512 #define ILP 4 diff --git a/csrc/multi_tensor_l2norm_scale_kernel.cu b/csrc/multi_tensor_l2norm_scale_kernel.cu index f60e96090..f856a5202 100644 --- a/csrc/multi_tensor_l2norm_scale_kernel.cu +++ b/csrc/multi_tensor_l2norm_scale_kernel.cu @@ -9,7 +9,7 @@ #include #include "type_shim.h" -#include "multi_tensor_apply.cuh" +#include "multi_tensor_apply_base.cuh" #define BLOCK_SIZE 512 #define ILP 4 diff --git a/csrc/multi_tensor_lamb.cu b/csrc/multi_tensor_lamb.cu index 3137fcd21..54a05a71c 100644 --- a/csrc/multi_tensor_lamb.cu +++ b/csrc/multi_tensor_lamb.cu @@ -10,7 +10,7 @@ #include "type_shim.h" #include "multi_tensor_apply.cuh" -#define BLOCK_SIZE 512 +#define BLOCK_SIZE 1024 #define ILP 4 template @@ -372,7 +372,7 @@ void multi_tensor_lamb_cuda( // We now in-place modify grad to store update before compute its norm // Generally this is not a issue since people modify grad in step() method all the time // We can also grab list of empty tensor to avoid this, but I'd like to save space/cpu code - DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1", + DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1", multi_tensor_apply<4>( BLOCK_SIZE, chunk_size, @@ -395,7 +395,7 @@ void multi_tensor_lamb_cuda( std::vector> grad_param_list(tensor_lists.begin(), tensor_lists.begin()+2); - DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2", + DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2", multi_tensor_apply<2>( BLOCK_SIZE, chunk_size, diff --git a/csrc/multi_tensor_lamb_mp.cu b/csrc/multi_tensor_lamb_mp.cu index b52ebd9ce..a213c1816 100644 --- a/csrc/multi_tensor_lamb_mp.cu +++ b/csrc/multi_tensor_lamb_mp.cu @@ -10,7 +10,7 @@ #include "type_shim.h" #include "multi_tensor_apply.cuh" -#define BLOCK_SIZE 512 +#define BLOCK_SIZE 1024 #define ILP 4 template diff --git a/csrc/multi_tensor_lamb_stage_1.cu b/csrc/multi_tensor_lamb_stage_1.cu index 6ad7649bc..1d5e398a3 100644 --- a/csrc/multi_tensor_lamb_stage_1.cu +++ b/csrc/multi_tensor_lamb_stage_1.cu @@ -10,7 +10,7 @@ #include "type_shim.h" #include "multi_tensor_apply.cuh" -#define BLOCK_SIZE 512 +#define BLOCK_SIZE 1024 #define ILP 4 // Step 1 computes the 'update' value of regular Adam optimizer. @@ -128,9 +128,9 @@ void multi_tensor_lamb_stage1_cuda( float next_step = float(step+1); float beta1_correction = 1.0f - std::pow(beta1, next_step); float beta2_correction = 1.0f - std::pow(beta2, next_step); - DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1", - DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_1", - DISPATCH_FLOAT_AND_HALF(tensor_lists[4][0].scalar_type(), 2, "lamb_stage_1", + DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1", + DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_1", + DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[4][0].scalar_type(), 2, "lamb_stage_1", multi_tensor_apply<5>( BLOCK_SIZE, chunk_size, diff --git a/csrc/multi_tensor_lamb_stage_2.cu b/csrc/multi_tensor_lamb_stage_2.cu index 90970666c..e1999effd 100644 --- a/csrc/multi_tensor_lamb_stage_2.cu +++ b/csrc/multi_tensor_lamb_stage_2.cu @@ -10,7 +10,7 @@ #include "type_shim.h" #include "multi_tensor_apply.cuh" -#define BLOCK_SIZE 512 +#define BLOCK_SIZE 1024 #define ILP 4 using MATH_T = float; @@ -105,8 +105,8 @@ void multi_tensor_lamb_stage2_cuda( using namespace at; - DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2", - DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_2", + DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2", + DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_2", multi_tensor_apply<2>( BLOCK_SIZE, chunk_size, diff --git a/csrc/multi_tensor_lars.cu b/csrc/multi_tensor_lars.cu new file mode 100644 index 000000000..bc9bbee2f --- /dev/null +++ b/csrc/multi_tensor_lars.cu @@ -0,0 +1,354 @@ +#include +#include +#include +#include + +#include "type_shim.h" +#include "compat.h" +#include "multi_tensor_apply.cuh" + +#include +#include + +#define BLOCK_SIZE 512 +#define ILP 4 + +/** + * Perform fused SGD on multiple buffers + * N: number of tensors + * tl[0] : gradients + * tl[1] : weights + * tl[2] : momentum buffers + * tl[3] : fp16 weights (if appropriate) + * wd : weight_decay (scalar) + * momentum : momentum (scalar) + * dampening : momentum dampening (scalar) + * lr : learning rate (scalar) + * nesterov : enable nesterov (bool) + * first run : necessary for proper momentum handling & init + * wd_after_momentum : apply weight decay _after_ momentum instead of before + **/ + +template +struct LARSFunctor +{ + __device__ __forceinline__ void operator()( + int chunk_size, + volatile int* noop_gmem, + TensorListMetadata& tl, + float *grad_norms, + float *param_norms, + float lr, + float trust_coefficient, + float epsilon, + float weight_decay, + float momentum, + float dampening, + bool nesterov, + bool first_run, + bool wd_after_momentum, + float scale, + const bool is_skipped) { + + // Early exit if we don't need to do anything + if (*noop_gmem) return; + + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; + + n -= chunk_idx * chunk_size; + //n = min(n, chunk_size); + + T_grad* grad_in = (T_grad*) tl.addresses[0][tensor_loc]; + grad_in += chunk_idx * chunk_size; + + T_weight* weight_in = (T_weight*) tl.addresses[1][tensor_loc]; + weight_in += chunk_idx * chunk_size; + + T_weight* mom_in = (T_weight*)tl.addresses[2][tensor_loc]; + mom_in += chunk_idx*chunk_size; + + at::Half *model_weights_out = nullptr; + if(N == 4) + { + model_weights_out = (at::Half*)tl.addresses[3][tensor_loc]; + model_weights_out += chunk_idx*chunk_size; + } + + float scaled_lr; + if (is_skipped) { + scaled_lr = lr; + } + else { + int tensor_offset = tl.start_tensor_this_launch + tensor_loc; + float p_norm = param_norms[tensor_offset]; + float trust_ratio = 1.0; + float g_norm = grad_norms[tensor_offset]; + if (g_norm > 0.0f && p_norm > 0.0f) { + trust_ratio = trust_coefficient * p_norm / (g_norm + p_norm * weight_decay + epsilon); + } + scaled_lr = lr * trust_ratio; + } + + // Non-divergent exit condition for the __syncthreads + float incoming_grads[ILP]; + float incoming_weights[ILP]; + float incoming_moms[ILP]; + for(int i_start = 0; + i_start < n && i_start < chunk_size; + i_start += blockDim.x*ILP) + { + #pragma unroll + for(int ii = 0; ii < ILP; ii++) + { + incoming_grads[ii] = 0; + incoming_weights[ii] = 0; + incoming_moms[ii] = 0; + int i = i_start + threadIdx.x + ii*blockDim.x; + if(i < n && i < chunk_size) + { + incoming_grads[ii] = static_cast(grad_in[i]); + incoming_weights[ii] = static_cast(weight_in[i]); + incoming_moms[ii] = static_cast(mom_in[i]); + } + } + + // note for clarification to future michael: + // From a pure memory dependency perspective, there's likely no point unrolling + // the write loop, since writes just fire off once their LDGs arrive. + // Put another way, the STGs are dependent on the LDGs, but not on each other. + // There is still compute ILP benefit from unrolling the loop though. + #pragma unroll + for(int ii = 0; ii < ILP; ii++) + { + int i = i_start + threadIdx.x + ii*blockDim.x; + if(i < n && i < chunk_size) + { + // apply weight decay before momentum + incoming_grads[ii] += weight_decay * incoming_weights[ii]; + incoming_moms[ii] = incoming_moms[ii] * momentum - scaled_lr * incoming_grads[ii]; + + // adjust the weight and write out + if (nesterov) { + incoming_weights[ii] += incoming_moms[ii] * momentum - scaled_lr * incoming_grads[ii]; + } else { + incoming_weights[ii] += incoming_moms[ii]; + } + + weight_in[i] = static_cast(incoming_weights[ii]); + + // if necessary, write out an fp16 copy of the weights + if(N == 4) + model_weights_out[i] = static_cast(weight_in[i]); + + // also write out the new momentum + //if(momentum != 0.f) + mom_in[i] = static_cast(incoming_moms[ii]); + } + } + } + } +}; + +void multi_tensor_lars_cuda( + int chunk_size, + at::Tensor noop_flag, + std::vector> tensor_lists, + at::Tensor grad_norms, + at::Tensor param_norms, + float lr, + float trust_coefficient, + float epsilon, + float weight_decay, + float momentum, + float dampening, + bool nesterov, + bool first_run, + bool wd_after_momentum, + float scale, + const bool is_skipped) +{ + auto num_tensors = tensor_lists.size(); + auto grad_type = tensor_lists[0][0].scalar_type(); + auto weight_type = tensor_lists[1][0].scalar_type(); + + if(num_tensors == 4) { + for(int i = 0; i < tensor_lists[3].size(); i++) { + TORCH_CHECK(tensor_lists[3][i].scalar_type() == at::ScalarType::Half, + "Additional output tensors should always be fp16."); + } + } + + TORCH_CHECK(noop_flag.device() == tensor_lists[0][0].device(), "expected noop flag to be on the same device as tensors"); + + // We have 3 possibilities to handle here, in terms of + // grad_type, param_type, momentum_type, requires_fp16_copy + // 1. fp16, fp16, fp16, No + // 2. fp32, fp32, fp32, No + // 3. fp16, fp32, fp32, Yes + // 4. fp32, fp32, fp32, Yes // this is the materialize_master_grads=True case + // 5. bfp16, bfp16, bfp16, No + // 6. bfp16, fp32, fp32, Yes + // It's easier to hardcode these possibilities than to use + // switches etc. to handle the cross-product of cases where + // we don't want the majority of them. + + // Case 1. fp16, fp16, fp16, No + if(grad_type == at::ScalarType::Half && + weight_type == at::ScalarType::Half && + num_tensors == 3) + { + multi_tensor_apply<3>( + BLOCK_SIZE, + chunk_size, + noop_flag, + tensor_lists, + LARSFunctor<3, at::Half, at::Half>(), + grad_norms.DATA_PTR(), + param_norms.DATA_PTR(), + lr, + trust_coefficient, + epsilon, + weight_decay, + momentum, + dampening, + nesterov, + first_run, + wd_after_momentum, + scale, + is_skipped); + } + // Case 2. fp32, fp32, fp32, No + else if(grad_type == at::ScalarType::Float && + weight_type == at::ScalarType::Float && + num_tensors == 3) + { + multi_tensor_apply<3>( + BLOCK_SIZE, + chunk_size, + noop_flag, + tensor_lists, + LARSFunctor<3, float, float>(), + grad_norms.DATA_PTR(), + param_norms.DATA_PTR(), + lr, + trust_coefficient, + epsilon, + weight_decay, + momentum, + dampening, + nesterov, + first_run, + wd_after_momentum, + scale, + is_skipped); + } + // Case 3. fp16, fp32, fp32, Yes + else if(grad_type == at::ScalarType::Half && + weight_type == at::ScalarType::Float && + num_tensors == 4) + { + multi_tensor_apply<4>( + BLOCK_SIZE, + chunk_size, + noop_flag, + tensor_lists, + LARSFunctor<4, at::Half, float>(), + grad_norms.DATA_PTR(), + param_norms.DATA_PTR(), + lr, + trust_coefficient, + epsilon, + weight_decay, + momentum, + dampening, + nesterov, + first_run, + wd_after_momentum, + scale, + is_skipped); + } + // Case 4. fp32, fp32, fp32, Yes + else if(grad_type == at::ScalarType::Float && + weight_type == at::ScalarType::Float && + num_tensors == 4) + { + multi_tensor_apply<4>( + BLOCK_SIZE, + chunk_size, + noop_flag, + tensor_lists, + LARSFunctor<4, float, float>(), + grad_norms.DATA_PTR(), + param_norms.DATA_PTR(), + lr, + trust_coefficient, + epsilon, + weight_decay, + momentum, + dampening, + nesterov, + first_run, + wd_after_momentum, + scale, + is_skipped); + } + // Case 5. bfp16, bfp16, bfp16, No + else if(grad_type == at::ScalarType::BFloat16 && + weight_type == at::ScalarType::BFloat16 && + num_tensors == 3) + { + multi_tensor_apply<3>( + BLOCK_SIZE, + chunk_size, + noop_flag, + tensor_lists, + LARSFunctor<3, at::BFloat16, at::BFloat16>(), + grad_norms.DATA_PTR(), + param_norms.DATA_PTR(), + lr, + trust_coefficient, + epsilon, + weight_decay, + momentum, + dampening, + nesterov, + first_run, + wd_after_momentum, + scale, + is_skipped); + } + // Case 6. bfp16, fp32, fp32, Yes + else if(grad_type == at::ScalarType::BFloat16 && + weight_type == at::ScalarType::Float && + num_tensors == 4) + { + multi_tensor_apply<4>( + BLOCK_SIZE, + chunk_size, + noop_flag, + tensor_lists, + LARSFunctor<4, at::BFloat16, float>(), + grad_norms.DATA_PTR(), + param_norms.DATA_PTR(), + lr, + trust_coefficient, + epsilon, + weight_decay, + momentum, + dampening, + nesterov, + first_run, + wd_after_momentum, + scale, + is_skipped); + } + else + { + AT_ERROR("multi_tensor_lars only supports some combinations of gradient & weight types. Given: ", + "gradient: ", grad_type, ", weight: ", weight_type, ", num_lists: ", num_tensors); + } + + AT_CUDA_CHECK(cudaGetLastError()); +} diff --git a/csrc/multi_tensor_novograd.cu b/csrc/multi_tensor_novograd.cu index 2decc06b8..4da815d72 100644 --- a/csrc/multi_tensor_novograd.cu +++ b/csrc/multi_tensor_novograd.cu @@ -10,7 +10,7 @@ #include "type_shim.h" #include "multi_tensor_apply.cuh" -#define BLOCK_SIZE 512 +#define BLOCK_SIZE 1024 #define ILP 4 typedef enum{ @@ -164,7 +164,7 @@ void multi_tensor_novograd_cuda( multi_tensor_norm_out_cuda(chunk_size, noop_flag, grad_list, grad_norms, beta2, (1.0f - beta2), norm_type); // Assume single type across p,g,m1,m2 now - DISPATCH_DOUBLE_FLOAT_AND_HALF( + DISPATCH_DOUBLE_FLOAT_AND_HALF_AND_BFLOAT16( tensor_lists[0][0].scalar_type(), 0, "novograd", multi_tensor_apply<3>( BLOCK_SIZE, diff --git a/csrc/multi_tensor_scale_kernel.cu b/csrc/multi_tensor_scale_kernel.cu index 629ee9420..5386f4df3 100644 --- a/csrc/multi_tensor_scale_kernel.cu +++ b/csrc/multi_tensor_scale_kernel.cu @@ -12,7 +12,7 @@ #include "type_shim.h" #include "multi_tensor_apply.cuh" -#define BLOCK_SIZE 512 +#define BLOCK_SIZE 1024 #define ILP 4 template @@ -121,8 +121,8 @@ void multi_tensor_scale_cuda( // If build times suffer, think about where to put this dispatch, // and what logic should be moved out of multi_tensor_apply. - DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_scale_cuda", - DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_scale_cuda", + DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_scale_cuda", + DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_scale_cuda", multi_tensor_apply<2>( BLOCK_SIZE, chunk_size, diff --git a/csrc/multi_tensor_sgd_kernel.cu b/csrc/multi_tensor_sgd_kernel.cu index 42a7406be..5d1f685ab 100644 --- a/csrc/multi_tensor_sgd_kernel.cu +++ b/csrc/multi_tensor_sgd_kernel.cu @@ -8,7 +8,7 @@ #include #include -#define BLOCK_SIZE 512 +#define BLOCK_SIZE 1024 #define ILP 4 /** @@ -168,6 +168,8 @@ void multi_tensor_sgd_cuda( // 2. fp32, fp32, fp32, No // 3. fp16, fp32, fp32, Yes // 4. fp32, fp32, fp32, Yes // this is the materialize_master_grads=True case + // 5. bfp16, bfp16, bfp16, No + // 6. bfp16, fp32, fp32, Yes // It's easier to hardcode these possibilities than to use // switches etc. to handle the cross-product of cases where // we don't want the majority of them. @@ -270,6 +272,46 @@ void multi_tensor_sgd_cuda( wd_after_momentum, scale); } + // Case 5. bfp16, bfp16, bfp16, No + else if(grad_type == at::ScalarType::BFloat16 && + weight_type == at::ScalarType::BFloat16 && + num_tensors == 3) + { + multi_tensor_apply<3>( + BLOCK_SIZE, + chunk_size, + noop_flag, + tensor_lists, + SGDFunctor<3, at::BFloat16, at::BFloat16>(), + wd, + momentum, + dampening, + lr, + nesterov, + first_run, + wd_after_momentum, + scale); + } + // Case 6. bfp16, fp32, fp32, Yes + else if(grad_type == at::ScalarType::BFloat16 && + weight_type == at::ScalarType::Float && + num_tensors == 4) + { + multi_tensor_apply<4>( + BLOCK_SIZE, + chunk_size, + noop_flag, + tensor_lists, + SGDFunctor<4, at::BFloat16, float>(), + wd, + momentum, + dampening, + lr, + nesterov, + first_run, + wd_after_momentum, + scale); + } else { AT_ERROR("multi_tensor_sgd only supports some combinations of gradient & weight types. Given: ", diff --git a/csrc/static_switch.h b/csrc/static_switch.h new file mode 100644 index 000000000..74bcf325d --- /dev/null +++ b/csrc/static_switch.h @@ -0,0 +1,25 @@ +// From +// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h + +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() \ No newline at end of file diff --git a/csrc/type_shim.h b/csrc/type_shim.h index a805941c7..17f48eabc 100644 --- a/csrc/type_shim.h +++ b/csrc/type_shim.h @@ -1,6 +1,9 @@ #include #include "compat.h" + +#ifndef TYPE_SHIM +#define TYPE_SHIM // Forward/backward compatiblity hack around // https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288 // pending more future-proof guidance from upstream. @@ -14,6 +17,43 @@ // //operator at::ScalarType(){ return payload.; }; // }; + +// hipify local to this source file until torch-hipify includes this mapping +#ifndef HIPBLAS_V2 +#define CUBLAS_COMPUTE_16F HIPBLAS_C_16F +#else +#define CUBLAS_COMPUTE_16F HIPBLAS_COMPUTE_16F +#endif + +// until we use hiblas v2 +// however hipblas v1 is still using its custom type +#ifndef HIPBLAS_V2 +#define HIPBLAS_COMPUTE_64F HIPBLAS_R_64F +#define HIPBLAS_COMPUTE_32F HIPBLAS_R_32F + +#define HIPBLASLT_COMPUTE_F64 HIPBLAS_R_64F +#define HIPBLASLT_COMPUTE_F32 HIPBLAS_R_32F + +#define HIP_R_16F HIPBLAS_R_16F +#define HIP_R_32F HIPBLAS_R_32F +#define HIP_R_64F HIPBLAS_R_64F +#define HIP_C_16F HIPBLAS_C_16F +#define HIP_C_32F HIPBLAS_C_32F +#define HIP_C_64F HIPBLAS_C_64F +#define HIP_R_8I HIPBLAS_R_8I +#define HIP_R_8U HIPBLAS_R_8U +#define HIP_R_32I HIPBLAS_R_32I +#define HIP_R_32U HIPBLAS_R_32U +#define HIP_C_8I HIPBLAS_C_8I +#define HIP_C_8U HIPBLAS_C_8U +#define HIP_C_32I HIPBLAS_C_32I +#define HIP_C_32U HIPBLAS_C_32U +#define HIP_R_16BF HIPBLAS_R_16B +#define HIP_C_16BF HIPBLAS_C_16B +#endif + + + #define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \ switch(TYPE) \ { \ @@ -163,6 +203,66 @@ AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ } +// TODO: We might have come up with an optimal set of dispatch macros by +// changing the signature to have an integer suffix of number of types +// to dispatch for as defined in upstream (e.g AT_DISPATCH_FLOATING_TYPES_AND2) +// Refactor once all the extension ops are enabled. +#define DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(TYPE, LEVEL, NAME, ...) \ + switch(TYPE) \ + { \ + case at::ScalarType::Float: \ + { \ + using scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_##LEVEL = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t_##LEVEL = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + + +#define DISPATCH_DOUBLE_FLOAT_AND_HALF_AND_BFLOAT16(TYPE, LEVEL, NAME, ...) \ + switch(TYPE) \ + { \ + case at::ScalarType::Double: \ + { \ + using scalar_t_##LEVEL = double; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Float: \ + { \ + using scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_##LEVEL = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t_##LEVEL = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } #define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \ switch(TYPE) \ @@ -327,15 +427,16 @@ __device__ __forceinline__ T reduce_block_into_lanes { int tid = threadIdx.x + threadIdx.y*blockDim.x; int blockSize = blockDim.x*blockDim.y; // blockSize is intended to be a multiple of 32. + auto double_warp_size = warpSize * 2; - if(blockSize >= 64) + if(blockSize >= double_warp_size) { x[tid] = val; __syncthreads(); } #pragma unroll - for(int i = (blockSize >> 1); i >= 64; i >>= 1) + for(int i = (blockSize >> 1); i >= double_warp_size; i >>= 1) { if(tid < i) x[tid] = x[tid] + x[tid+i]; @@ -344,17 +445,22 @@ __device__ __forceinline__ T reduce_block_into_lanes T final; - if(tid < 32) + if(tid < warpSize) { - if(blockSize >= 64) - final = x[tid] + x[tid+32]; + if(blockSize >= double_warp_size) + final = x[tid] + x[tid + warpSize]; else final = val; // __SYNCWARP(); #pragma unroll - for(int i = 16; i >= lanes; i >>= 1) + for(int i = warpSize / 2; i >= lanes; i >>= 1) { +#ifdef USE_ROCM + final = final + __shfl_down(final, i); +#else final = final + __shfl_down_sync(0xffffffff, final, i); +#endif + } } if(share_result) @@ -377,15 +483,16 @@ __device__ __forceinline__ T reduce_block_into_lanes_max_op { int tid = threadIdx.x + threadIdx.y*blockDim.x; int blockSize = blockDim.x*blockDim.y; // blockSize is intended to be a multiple of 32. + auto double_warp_size = warpSize * 2; - if(blockSize >= 64) + if(blockSize >= double_warp_size) { x[tid] = val; __syncthreads(); } #pragma unroll - for(int i = (blockSize >> 1); i >= 64; i >>= 1) + for(int i = (blockSize >> 1); i >= double_warp_size; i >>= 1) { if(tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid+i])); @@ -394,17 +501,22 @@ __device__ __forceinline__ T reduce_block_into_lanes_max_op T final; - if(tid < 32) + if(tid < warpSize) { - if(blockSize >= 64) - final = fmaxf(fabsf(x[tid]), fabsf(x[tid+32])); + if(blockSize >= double_warp_size) + final = fmaxf(fabsf(x[tid]), fabsf(x[tid + warpSize])); else final = val; // __SYNCWARP(); #pragma unroll - for(int i = 16; i >= lanes; i >>= 1) + for(int i = 16; i >= lanes; i >>= 1) { +#ifdef USE_ROCM + final = fmaxf(fabsf(final), fabsf(__shfl_down(final, i))); +#else final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i))); +#endif + } } if(share_result) @@ -417,3 +529,5 @@ __device__ __forceinline__ T reduce_block_into_lanes_max_op return final; } + +#endif // TYPE_SHIM diff --git a/csrc/welford.cu b/csrc/welford.cu index 374a3845b..fabee1999 100644 --- a/csrc/welford.cu +++ b/csrc/welford.cu @@ -2,6 +2,7 @@ #include #include #include +#include #include #include @@ -11,6 +12,11 @@ #include "type_shim.h" #include "compat.h" +#if defined USE_ROCM +#define SHFL_DOWN(mask,val,i) __shfl_down(val, i) +#else +#define SHFL_DOWN __shfl_down_sync +#endif __device__ __forceinline__ int lastpow2(int n) { @@ -39,15 +45,12 @@ __host__ __forceinline__ int h_last_pow2(unsigned int n) { return n - (n >> 1); } - -#define WARP_SIZE 32 - template __device__ __forceinline__ T warp_reduce_sum(T val) { #pragma unroll - for(int i = WARP_SIZE/2; i > 0; i >>= 1) - val = val + __shfl_down_sync(0xffffffff, val, i); + for(int i = C10_WARP_SIZE/2; i > 0; i >>= 1) + val = val + SHFL_DOWN(0xffffffff, val, i); return val; } @@ -56,25 +59,26 @@ __device__ __forceinline__ T reduce_block(T *x, T val) { int tid = threadIdx.y*blockDim.x + threadIdx.x; int blockSize = blockDim.x * blockDim.y; + int lane = tid % C10_WARP_SIZE; + int wid = tid / C10_WARP_SIZE; - if (blockSize > 32) { + if (blockSize > C10_WARP_SIZE) { val = warp_reduce_sum(val); - if (tid % WARP_SIZE == 0) - x[tid/WARP_SIZE] = val; + if (lane == 0) + x[wid] = val; __syncthreads(); - val = (tid < blockSize / WARP_SIZE? x[tid%WARP_SIZE] : T(0)); + val = (tid < blockSize / C10_WARP_SIZE? x[lane] : T(0)); } - if(tid/WARP_SIZE==0) val = warp_reduce_sum(val); + if(wid==0) val = warp_reduce_sum(val); return val; } #define ELEMENTS_PER_ITER 4 // enables concurrency within each thread to hide latency #define ELEMENTS_PER_THREAD 16 -#define OPTIMAL_TILE_W 32 #define MAX_H_BLOCK 128 #define MAX_BLOCK_SIZE 512 @@ -88,7 +92,7 @@ __host__ void flexible_launch_configs( dim3 &block, dim3 &grid, const bool coop_flag = false) { - int block_x = std::min(h_last_pow2(stride), OPTIMAL_TILE_W); + int block_x = std::min(h_last_pow2(stride), at::cuda::warp_size()); int block_y = std::min(h_last_pow2(div_ru(reduction , ELEMENTS_PER_THREAD)), MAX_BLOCK_SIZE / block_x); if (block_x * block_y != MAX_BLOCK_SIZE) { @@ -128,10 +132,10 @@ template __device__ __forceinline__ void warp_reduce_mean_m2n(T &mean, T &m2n, int &num) { #pragma unroll - for(int i = WARP_SIZE/2; i > 0; i >>= 1) { - auto num_new = __shfl_down_sync(0xffffffff, num, i); - auto mean_new = __shfl_down_sync(0xffffffff, mean, i); - auto m2n_new = __shfl_down_sync(0xffffffff, m2n, i); + for(int i = C10_WARP_SIZE/2; i > 0; i >>= 1) { + auto num_new = SHFL_DOWN(0xffffffff, num, i); + auto mean_new = SHFL_DOWN(0xffffffff, mean, i); + auto m2n_new = SHFL_DOWN(0xffffffff, m2n, i); welford_merge_element(num, mean, m2n, num_new, mean_new, m2n_new); } } @@ -146,10 +150,10 @@ __device__ void welford_reduce_mean_m2n( int block_size, int thread_id) { - int lane = thread_id % WARP_SIZE; - int wid = thread_id / WARP_SIZE; + int lane = thread_id % C10_WARP_SIZE; + int wid = thread_id / C10_WARP_SIZE; - if (block_size > 32) { + if (block_size > C10_WARP_SIZE) { warp_reduce_mean_m2n(mean, m2n, num); if (lane == 0) { x[wid*2] = mean; @@ -159,9 +163,9 @@ __device__ void welford_reduce_mean_m2n( __syncthreads(); if (wid == 0) { - mean = (thread_id < block_size / WARP_SIZE)? x[lane*2] : T(0); - m2n = (thread_id < block_size / WARP_SIZE)? x[lane*2+1] : T(0); - num = (thread_id < block_size / WARP_SIZE)? count[lane] : int(0); + mean = (thread_id < block_size / C10_WARP_SIZE)? x[lane*2] : T(0); + m2n = (thread_id < block_size / C10_WARP_SIZE)? x[lane*2+1] : T(0); + num = (thread_id < block_size / C10_WARP_SIZE)? count[lane] : int(0); } } @@ -256,6 +260,9 @@ __device__ __forceinline__ void merge_block_vertical(T& sum_dy, // welford kernel calculating mean/biased_variance/unbiased_variance template +#ifdef USE_ROCM +__launch_bounds__(MAX_BLOCK_SIZE) +#endif __global__ void welford_kernel( const scalar_t* __restrict__ input, outscalar_t* __restrict__ out_mean, @@ -282,8 +289,8 @@ __global__ void welford_kernel( } } - static __shared__ int s_mem[160]; - accscalar_t* s_mem_ac = (accscalar_t*) &s_mem[32]; + static __shared__ int s_mem[C10_WARP_SIZE]; + static __shared__ accscalar_t s_mem_ac[C10_WARP_SIZE*2]; welford_reduce_mean_m2n(s_mem_ac, s_mem, x_mean, m_2_n, count, block_size, thread_id); @@ -295,6 +302,9 @@ __global__ void welford_kernel( // elementwise BN kernel template +#ifdef USE_ROCM +__launch_bounds__(MAX_BLOCK_SIZE) +#endif __global__ void batchnorm_forward_kernel( const scalar_t* __restrict__ input, const accscalar_t* __restrict__ mean, @@ -322,6 +332,9 @@ __global__ void batchnorm_forward_kernel( // Breaking the grad_input to two step to support sync BN, which requires all // reduce of the intermediate results across processes. template +#ifdef USE_ROCM +__launch_bounds__(MAX_BLOCK_SIZE) +#endif __global__ void reduce_bn_kernel( const scalar_t* __restrict__ input, const scalar_t* __restrict__ grad_output, @@ -334,7 +347,7 @@ __global__ void reduce_bn_kernel( const int bs, const int fs, const int ss) { - static __shared__ int s_mem[64]; + static __shared__ int s_mem[C10_WARP_SIZE]; //int total_item_num = bs * ss; int thread_id = threadIdx.y*blockDim.x + threadIdx.x; @@ -386,6 +399,9 @@ __global__ void reduce_bn_kernel( // elementwise backward BN kernel template +#ifdef USE_ROCM +__launch_bounds__(MAX_BLOCK_SIZE) +#endif __global__ void batchnorm_backward_kernel( const scalar_t* __restrict__ grad_output, const scalar_t* __restrict__ input, @@ -425,6 +441,9 @@ template typename accscalar_t, typename outscalar_t, int PARALLEL_LOADS> +#ifdef USE_ROCM +__launch_bounds__(MAX_BLOCK_SIZE) +#endif __global__ void welford_kernel_c_last( const scalar_t* __restrict__ input, @@ -566,6 +585,9 @@ welford_kernel_c_last( // parallel welford kernel to further reduce mean / biased_var // into mean / unbiased_var / inv_std across multiple processes. template +#ifdef USE_ROCM +__launch_bounds__(MAX_BLOCK_SIZE) +#endif __global__ void welford_kernel_parallel( const scalar_t* __restrict__ mean, const scalar_t* __restrict__ var_biased, @@ -599,6 +621,9 @@ template < typename accscalar_t, typename layerscalar_t, int PARALLEL_LOADS> +#ifdef USE_ROCM +__launch_bounds__(MAX_BLOCK_SIZE) +#endif __global__ void batchnorm_forward_c_last_kernel( const scalar_t* __restrict__ input, const scalar_t* __restrict__ z, @@ -649,6 +674,9 @@ template < typename accscalar_t, typename layerscalar_t, int PARALLEL_LOADS> +#ifdef USE_ROCM +__launch_bounds__(MAX_BLOCK_SIZE) +#endif __global__ void relu_backward_c_last_kernel( const scalar_t* __restrict__ grad_output, const scalar_t* __restrict__ input, @@ -699,6 +727,9 @@ template typename accscalar_t, typename layerscalar_t, int PARALLEL_LOADS> +#ifdef USE_ROCM +__launch_bounds__(MAX_BLOCK_SIZE) +#endif __global__ void reduce_bn_c_last_kernel( const scalar_t* __restrict__ input, const scalar_t* __restrict__ grad_output, @@ -852,6 +883,9 @@ template < typename accscalar_t, typename layerscalar_t, int PARALLEL_LOADS> +#ifdef USE_ROCM +__launch_bounds__(MAX_BLOCK_SIZE) +#endif __global__ void batchnorm_backward_c_last_kernel( const scalar_t* __restrict__ grad_output, const scalar_t* __restrict__ input, @@ -912,7 +946,7 @@ std::vector welford_mean_var_CUDA(const at::Tensor input) { at::Tensor out_var_biased = at::empty({feature_size}, input.options().dtype(scalar_type)); at::Tensor out_mean = at::empty({feature_size}, input.options().dtype(scalar_type)); - int block_y = min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE / 32)); + int block_y = min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE / at::cuda::warp_size())); int block_x = max(1, min(MAX_BLOCK_SIZE / block_y, h_last_pow2(space_size))); const dim3 block(block_x, block_y); const dim3 grid(feature_size); @@ -948,7 +982,7 @@ at::Tensor batchnorm_forward_CUDA( auto space_size = get_tensor_spatial_size(input); - int block_x = max(32, min(MAX_BLOCK_SIZE, h_last_pow2(space_size)/4)); + int block_x = max(at::cuda::warp_size(), min(MAX_BLOCK_SIZE, h_last_pow2(space_size)/4)); int block_y = max(1, min(MAX_BLOCK_SIZE/block_x, h_last_pow2(batch_size)/4)); const dim3 block(block_x, block_y); int grid_z = max(1, min(65535, h_last_pow2(space_size)/4/block_x)); @@ -1021,7 +1055,7 @@ std::vector reduce_bn_CUDA( auto space_size = get_tensor_spatial_size(input); - int block_y = min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE/ 32)); + int block_y = min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE/ at::cuda::warp_size())); int block_x = max(1, min(MAX_BLOCK_SIZE/ block_y, h_last_pow2(space_size))); const dim3 block(block_x, block_y); const dim3 grid(feature_size); @@ -1088,7 +1122,7 @@ at::Tensor batchnorm_backward_CUDA( auto space_size = get_tensor_spatial_size(input); - int block_x = max(32, min(MAX_BLOCK_SIZE, h_last_pow2(space_size)/4)); + int block_x = max(at::cuda::warp_size(), min(MAX_BLOCK_SIZE, h_last_pow2(space_size)/4)); int block_y = max(1, min(MAX_BLOCK_SIZE/block_x, h_last_pow2(batch_size)/4)); const dim3 block(block_x, block_y); int grid_z = max(1, min(65535, h_last_pow2(space_size)/4/block_x)); diff --git a/op_builder/__init__.py b/op_builder/__init__.py new file mode 100644 index 000000000..726ec6f4d --- /dev/null +++ b/op_builder/__init__.py @@ -0,0 +1,56 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# Portions of this code were adapted from DeepSpeed: +# https://github.com/microsoft/DeepSpeed +# Modified for ROCm Apex + +import sys +import os +import pkgutil +import importlib + +from .builder import get_default_compute_capabilities, OpBuilder + +__apex__ = True + +# List of all available op builders from apex op_builder +try: + import apex.op_builder # noqa: F401 # type: ignore + op_builder_dir = "apex.op_builder" +except ImportError: + op_builder_dir = "op_builder" + +__op_builders__ = [] + +this_module = sys.modules[__name__] + + +def builder_closure(member_name): + if op_builder_dir == "op_builder": + # during installation time cannot get builder due to torch not installed, + # return closure instead + def _builder(): + from apex.op_builder.all_ops import BuilderUtils + builder = BuilderUtils().create_op_builder(member_name) + return builder + + return _builder + else: + # during runtime, return op builder class directly + from apex.op_builder.all_ops import BuilderUtils + builder = BuilderUtils().get_op_builder(member_name) + return builder + +# this is for the import statement such as 'from apex.op_builder import FusedLayerNormBuilder' to work +# reflect builder names and add builder closure, such as 'apex.op_builder.FusedLayerNormBuilder()' creates op builder +for _, module_name, _ in pkgutil.iter_modules([os.path.dirname(this_module.__file__)]): + if module_name != 'all_ops' and module_name != 'builder': + module = importlib.import_module(f".{module_name}", package=op_builder_dir) + for member_name in module.__dir__(): + if member_name.endswith('Builder') and member_name != "OpBuilder" and member_name != "CUDAOpBuilder" and member_name != "CPUOpBuilder": + # assign builder name to variable with same name + # the following is equivalent to i.e. TransformerBuilder = "TransformerBuilder" + this_module.__dict__[member_name] = builder_closure(member_name) \ No newline at end of file diff --git a/op_builder/all_ops.py b/op_builder/all_ops.py new file mode 100644 index 000000000..e18dbdd71 --- /dev/null +++ b/op_builder/all_ops.py @@ -0,0 +1,87 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# Portions of this code were adapted from DeepSpeed: +# https://github.com/microsoft/DeepSpeed +# Modified for ROCm Apex + +import os +import pkgutil +import importlib + +class BuilderUtils: + def op_builder_dir(self): + try: + # is op_builder from apex or a 3p version? this should only succeed if it's apex + # if successful this also means we're doing a local install and not JIT compile path + from op_builder import __apex__ + return "op_builder" + except ImportError: + return "apex.op_builder" + + # dict that holds class name <--> class type mapping i.e. + # 'AsyncIOBuilder': + # this dict will be filled at init stage + class_dict = None + + def _lazy_init_class_dict(self): + if self.class_dict is not None: + return + else: + self.class_dict = {} + # begin initialize for create_op_builder() + # put all valid class name <--> class type mapping into class_dict + op_builder_dir = self.op_builder_dir() + op_builder_module = importlib.import_module(op_builder_dir) + op_builder_absolute_path = os.path.dirname(op_builder_module.__file__) + for _, module_name, _ in pkgutil.iter_modules([op_builder_absolute_path]): + # avoid self references, + # skip sub_directories which contains ops for other backend(cpu, npu, etc.). + if module_name != 'all_ops' and module_name != 'builder' and not os.path.isdir( + os.path.join(op_builder_absolute_path, module_name)): + module = importlib.import_module("{}.{}".format(op_builder_dir, module_name)) + for member_name in module.__dir__(): + if member_name.endswith( + 'Builder' + ) and member_name != "OpBuilder" and member_name != "CUDAOpBuilder" and member_name != "CPUOpBuilder": # avoid abstract classes + if not member_name in self.class_dict: + self.class_dict[member_name] = getattr(module, member_name) + # end initialize for create_op_builder() + + # create an instance of op builder and return, name specified by class_name + def create_op_builder(self, class_name): + self._lazy_init_class_dict() + if class_name in self.class_dict: + return self.class_dict[class_name]() + else: + return None + + # return an op builder class, name specified by class_name + def get_op_builder(self, class_name): + self._lazy_init_class_dict() + if class_name in self.class_dict: + return self.class_dict[class_name] + else: + return None + +# List of all available ops + +# append all builder names into __op_builders__ +builder_utils = BuilderUtils() +op_builder_dir = builder_utils.op_builder_dir() +op_builder_module = importlib.import_module(op_builder_dir) +__op_builders__ = [] + +for _, module_name, _ in pkgutil.iter_modules([os.path.dirname(op_builder_module.__file__)]): + # avoid self references + if module_name != 'all_ops' and module_name != 'builder': + module = importlib.import_module("{}.{}".format(op_builder_dir, module_name)) + for member_name in module.__dir__(): + if member_name.endswith('Builder'): + # append builder to __op_builders__ list + builder = builder_utils.create_op_builder(member_name) + __op_builders__.append(builder) + +ALL_OPS = {op.name: op for op in __op_builders__ if op is not None} \ No newline at end of file diff --git a/op_builder/amp_C.py b/op_builder/amp_C.py new file mode 100644 index 000000000..41f029fcb --- /dev/null +++ b/op_builder/amp_C.py @@ -0,0 +1,45 @@ +from .builder import CUDAOpBuilder + +import sys + + +class AmpCBuilder(CUDAOpBuilder): + BUILD_VAR = 'APEX_BUILD_AMP_C' + INCLUDE_FLAG = "APEX_BUILD_CUDA_OPS" + NAME = "amp_C" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'apex.{self.NAME}' + + def sources(self): + return ['csrc/amp_C_frontend.cpp', + 'csrc/multi_tensor_sgd_kernel.cu', + 'csrc/multi_tensor_scale_kernel.cu', + 'csrc/multi_tensor_axpby_kernel.cu', + 'csrc/multi_tensor_l2norm_kernel.cu', + 'csrc/multi_tensor_l2norm_kernel_mp.cu', + 'csrc/multi_tensor_l2norm_scale_kernel.cu', + 'csrc/multi_tensor_lamb_stage_1.cu', + 'csrc/multi_tensor_lamb_stage_2.cu', + 'csrc/multi_tensor_adam.cu', + 'csrc/multi_tensor_adagrad.cu', + 'csrc/multi_tensor_novograd.cu', + 'csrc/multi_tensor_lars.cu', + 'csrc/multi_tensor_lamb.cu', + 'csrc/multi_tensor_lamb_mp.cu'] + + def include_paths(self): + return ['csrc/'] + + def cxx_args(self): + args = super().cxx_args() + return args + self.version_dependent_macros() + + def nvcc_args(self): + nvcc_flags = ['-O3'] + self.version_dependent_macros() + if not self.is_rocm_pytorch(): + nvcc_flags += ['-lineinfo', '--use_fast_math'] + return nvcc_flags \ No newline at end of file diff --git a/op_builder/apex_C.py b/op_builder/apex_C.py new file mode 100644 index 000000000..b02526e77 --- /dev/null +++ b/op_builder/apex_C.py @@ -0,0 +1,25 @@ +from .builder import CPUOpBuilder + +import sys + + +class ApexCBuilder(CPUOpBuilder): + BUILD_VAR = 'APEX_BUILD_APEX_C' + INCLUDE_FLAG = "APEX_BUILD_CPP_OPS" + NAME = "apex_C" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'apex.{self.NAME}' + + def sources(self): + return ["csrc/flatten_unflatten.cpp"] + + def include_paths(self): + return ['csrc/' ] + + def libraries_args(self): + args = super().libraries_args() + return args \ No newline at end of file diff --git a/op_builder/bnp.py b/op_builder/bnp.py new file mode 100644 index 000000000..f7fbe1abd --- /dev/null +++ b/op_builder/bnp.py @@ -0,0 +1,33 @@ +from .builder import CUDAOpBuilder + +import sys + + +class BnpBuilder(CUDAOpBuilder): + BUILD_VAR = 'APEX_BUILD_BNP' + INCLUDE_FLAG = "APEX_BUILD_CUDA_OPS" + NAME = "bnp" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'apex.{self.NAME}' + + def sources(self): + return ['contrib/csrc/groupbn/batch_norm.cu', + 'contrib/csrc/groupbn/ipc.cu', + 'contrib/csrc/groupbn/interface.cpp', + 'contrib/csrc/groupbn/batch_norm_add_relu.cu'] + + def include_paths(self): + return ['contrib/csrc', 'csrc'] + + def cxx_args(self): + return self.version_dependent_macros() + + def nvcc_args(self): + return ['-DCUDA_HAS_FP16=1', + '-D__CUDA_NO_HALF_OPERATORS__', + '-D__CUDA_NO_HALF_CONVERSIONS__', + '-D__CUDA_NO_HALF2_OPERATORS__'] + self.version_dependent_macros() \ No newline at end of file diff --git a/op_builder/builder.py b/op_builder/builder.py new file mode 100644 index 000000000..20553bd58 --- /dev/null +++ b/op_builder/builder.py @@ -0,0 +1,929 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# Portions of this code were adapted from DeepSpeed: +# https://github.com/microsoft/DeepSpeed +# Modified for ROCm Apex + +import os +import re +import sys +import time +import importlib +from pathlib import Path +import subprocess +import shlex +import shutil +import tempfile +import distutils.ccompiler +import distutils.log +import distutils.sysconfig +from distutils.errors import CompileError, LinkError +from abc import ABC, abstractmethod +from typing import List + +YELLOW = '\033[93m' +END = '\033[0m' +WARNING = f"{YELLOW} [WARNING] {END}" + +DEFAULT_TORCH_EXTENSION_PATH = "/tmp/torch_extensions" +DEFAULT_COMPUTE_CAPABILITIES = "6.0;6.1;7.0" + +try: + import torch +except ImportError: + print(f"{WARNING} unable to import torch, please install it if you want to pre-compile any apex ops.") +else: + TORCH_MAJOR = int(torch.__version__.split('.')[0]) + TORCH_MINOR = int(torch.__version__.split('.')[1]) + + +class MissingCUDAException(Exception): + pass + + +class CUDAMismatchException(Exception): + pass + + +def installed_cuda_version(name=""): + import torch.utils.cpp_extension + cuda_home = torch.utils.cpp_extension.CUDA_HOME + if cuda_home is None: + raise MissingCUDAException("CUDA_HOME does not exist, unable to compile CUDA op(s)") + # Ensure there is not a cuda version mismatch between torch and nvcc compiler + output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"], universal_newlines=True) + output_split = output.split() + release_idx = output_split.index("release") + release = output_split[release_idx + 1].replace(',', '').split(".") + # Ignore patch versions, only look at major + minor + cuda_major, cuda_minor = release[:2] + return int(cuda_major), int(cuda_minor) + + +def get_default_compute_capabilities(): + compute_caps = DEFAULT_COMPUTE_CAPABILITIES + # Update compute capability according to: https://en.wikipedia.org/wiki/CUDA#GPUs_supported + import torch.utils.cpp_extension + if torch.utils.cpp_extension.CUDA_HOME is not None: + if installed_cuda_version()[0] == 11: + if installed_cuda_version()[1] >= 0: + compute_caps += ";8.0" + if installed_cuda_version()[1] >= 1: + compute_caps += ";8.6" + if installed_cuda_version()[1] >= 8: + compute_caps += ";9.0" + elif installed_cuda_version()[0] == 12: + compute_caps += ";8.0;8.6;9.0" + if installed_cuda_version()[1] >= 8: + compute_caps += ";10.0;12.0" + return compute_caps + + +# list compatible minor CUDA versions - so that for example pytorch built with cuda-11.0 can be used +# to build apex and system-wide installed cuda 11.2 +cuda_minor_mismatch_ok = { + 10: ["10.0", "10.1", "10.2"], + 11: ["11.0", "11.1", "11.2", "11.3", "11.4", "11.5", "11.6", "11.7", "11.8"], + 12: ["12.0", "12.1", "12.2", "12.3", "12.4", "12.5", "12.6", + "12.8"], # There does not appear to be a CUDA Toolkit 12.7 +} + + +def assert_no_cuda_mismatch(name=""): + cuda_major, cuda_minor = installed_cuda_version(name) + sys_cuda_version = f'{cuda_major}.{cuda_minor}' + torch_cuda_version = ".".join(torch.version.cuda.split('.')[:2]) + # This is a show-stopping error, should probably not proceed past this + if sys_cuda_version != torch_cuda_version: + if (cuda_major in cuda_minor_mismatch_ok and sys_cuda_version in cuda_minor_mismatch_ok[cuda_major] + and torch_cuda_version in cuda_minor_mismatch_ok[cuda_major]): + print(f"Installed CUDA version {sys_cuda_version} does not match the " + f"version torch was compiled with {torch.version.cuda} " + "but since the APIs are compatible, accepting this combination") + return True + elif os.getenv("APEX_SKIP_CUDA_CHECK", "0") == "1": + print( + f"{WARNING} Apex Op Builder: Installed CUDA version {sys_cuda_version} does not match the " + f"version torch was compiled with {torch.version.cuda}." + "Detected `APEX_SKIP_CUDA_CHECK=1`: Allowing this combination of CUDA, but it may result in unexpected behavior." + ) + return True + raise CUDAMismatchException( + f">- Apex Op Builder: Installed CUDA version {sys_cuda_version} does not match the " + f"version torch was compiled with {torch.version.cuda}, unable to compile " + "cuda/cpp extensions without a matching cuda version.") + return True + + +class OpBuilder(ABC): + _rocm_version = None + _rocm_gpu_arch = None + _rocm_wavefront_size = None + _is_rocm_pytorch = None + _is_sycl_enabled = None + _loaded_ops = {} + + def __init__(self, name): + self.name = name + self.jit_mode = False + self.build_for_cpu = False + self.enable_bf16 = False + self.error_log = None + + @abstractmethod + def absolute_name(self): + ''' + Returns absolute build path for cases where the op is pre-installed, e.g., apex.ops.adam.cpu_adam + will be installed as something like: apex/ops/adam/cpu_adam.so + ''' + pass + + @abstractmethod + def sources(self): + ''' + Returns list of source files for your op, relative to root of apex package + ''' + pass + + def hipify_extension(self): + pass + + def sycl_extension(self): + pass + + @staticmethod + def validate_torch_version(torch_info): + install_torch_version = torch_info['version'] + current_torch_version = ".".join(torch.__version__.split('.')[:2]) + if install_torch_version != current_torch_version: + raise RuntimeError("PyTorch version mismatch! apex ops were compiled and installed " + "with a different version than what is being used at runtime. " + f"Please re-install apex or switch torch versions. " + f"Install torch version={install_torch_version}, " + f"Runtime torch version={current_torch_version}") + + @staticmethod + def validate_torch_op_version(torch_info): + if not OpBuilder.is_rocm_pytorch(): + current_cuda_version = ".".join(torch.version.cuda.split('.')[:2]) + install_cuda_version = torch_info['cuda_version'] + if install_cuda_version != current_cuda_version: + raise RuntimeError("CUDA version mismatch! apex ops were compiled and installed " + "with a different version than what is being used at runtime. " + f"Please re-install apex or switch torch versions. " + f"Install CUDA version={install_cuda_version}, " + f"Runtime CUDA version={current_cuda_version}") + else: + current_hip_version = ".".join(torch.version.hip.split('.')[:2]) + install_hip_version = torch_info['hip_version'] + if install_hip_version != current_hip_version: + raise RuntimeError("HIP version mismatch! apex ops were compiled and installed " + "with a different version than what is being used at runtime. " + f"Please re-install apex or switch torch versions. " + f"Install HIP version={install_hip_version}, " + f"Runtime HIP version={current_hip_version}") + + @staticmethod + def is_rocm_pytorch(): + if OpBuilder._is_rocm_pytorch is not None: + return OpBuilder._is_rocm_pytorch + + _is_rocm_pytorch = False + try: + import torch + except ImportError: + pass + else: + if TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 5): + _is_rocm_pytorch = hasattr(torch.version, 'hip') and torch.version.hip is not None + if _is_rocm_pytorch: + from torch.utils.cpp_extension import ROCM_HOME + _is_rocm_pytorch = ROCM_HOME is not None + OpBuilder._is_rocm_pytorch = _is_rocm_pytorch + return OpBuilder._is_rocm_pytorch + + @staticmethod + def is_sycl_enabled(): + if OpBuilder._is_sycl_enabled is not None: + return OpBuilder._is_sycl_enabled + + _is_sycl_enabled = False + try: + result = subprocess.run(["c2s", "--version"], capture_output=True) + except: + pass + else: + _is_sycl_enabled = True + + OpBuilder._is_sycl_enabled = _is_sycl_enabled + return OpBuilder._is_sycl_enabled + + @staticmethod + def installed_rocm_version(): + if OpBuilder._rocm_version: + return OpBuilder._rocm_version + + ROCM_MAJOR = '0' + ROCM_MINOR = '0' + ROCM_VERSION_DEV_RAW = "" + if OpBuilder.is_rocm_pytorch(): + from torch.utils.cpp_extension import ROCM_HOME + rocm_ver_file = Path(ROCM_HOME).joinpath(".info/version") + if rocm_ver_file.is_file(): + with open(rocm_ver_file, 'r') as file: + ROCM_VERSION_DEV_RAW = file.read() + elif "rocm" in torch.__version__: + ROCM_VERSION_DEV_RAW = torch.__version__.split("rocm")[1] + if ROCM_VERSION_DEV_RAW != "": + ROCM_MAJOR = ROCM_VERSION_DEV_RAW.split('.')[0] + ROCM_MINOR = ROCM_VERSION_DEV_RAW.split('.')[1] + else: + # Look in /usr/include/rocm-version.h + rocm_ver_file = Path("/usr/include/rocm_version.h") + if rocm_ver_file.is_file(): + with open(rocm_ver_file, 'r') as file: + for ln in file.readlines(): + if "#define ROCM_VERSION_MAJOR" in ln: + ROCM_MAJOR = re.findall(r'\S+', ln)[2] + elif "#define ROCM_VERSION_MINOR" in ln: + ROCM_MINOR = re.findall(r'\S+', ln)[2] + if ROCM_MAJOR == '0': + assert False, "Could not detect ROCm version" + + OpBuilder._rocm_version = (int(ROCM_MAJOR), int(ROCM_MINOR)) + return OpBuilder._rocm_version + + @staticmethod + def get_rocm_gpu_arch(): + if OpBuilder._rocm_gpu_arch: + return OpBuilder._rocm_gpu_arch + rocm_info = Path("/opt/rocm/bin/rocminfo") + if (not rocm_info.is_file()): + rocm_info = Path("rocminfo") + rocm_gpu_arch_cmd = str(rocm_info) + " | grep -o -m 1 'gfx.*'" + try: + result = subprocess.check_output(rocm_gpu_arch_cmd, shell=True) + rocm_gpu_arch = result.decode('utf-8').strip() + except subprocess.CalledProcessError: + rocm_gpu_arch = "" + OpBuilder._rocm_gpu_arch = rocm_gpu_arch + return OpBuilder._rocm_gpu_arch + + @staticmethod + def get_rocm_wavefront_size(): + if OpBuilder._rocm_wavefront_size: + return OpBuilder._rocm_wavefront_size + + rocm_info = Path("/opt/rocm/bin/rocminfo") + if (not rocm_info.is_file()): + rocm_info = Path("rocminfo") + rocm_wavefront_size_cmd = str( + rocm_info) + " | grep -Eo -m1 'Wavefront Size:[[:space:]]+[0-9]+' | grep -Eo '[0-9]+'" + try: + result = subprocess.check_output(rocm_wavefront_size_cmd, shell=True) + rocm_wavefront_size = result.decode('utf-8').strip() + except subprocess.CalledProcessError: + rocm_wavefront_size = "32" + OpBuilder._rocm_wavefront_size = rocm_wavefront_size + return OpBuilder._rocm_wavefront_size + + def include_paths(self): + ''' + Returns list of include paths, relative to root of apex package + ''' + return [] + + def nvcc_args(self): + ''' + Returns optional list of compiler flags to forward to nvcc when building CUDA sources + ''' + return [] + + def cxx_args(self): + ''' + Returns optional list of compiler flags to forward to the build + ''' + return [] + + def is_compatible(self, verbose=False): + ''' + Check if all non-python dependencies are satisfied to build this op + ''' + return True + + def extra_ldflags(self): + return [] + + def has_function(self, funcname, libraries, library_dirs=None, verbose=False): + ''' + Test for existence of a function within a tuple of libraries. + + This is used as a smoke test to check whether a certain library is available. + As a test, this creates a simple C program that calls the specified function, + and then distutils is used to compile that program and link it with the specified libraries. + Returns True if both the compile and link are successful, False otherwise. + ''' + tempdir = None # we create a temporary directory to hold various files + filestderr = None # handle to open file to which we redirect stderr + oldstderr = None # file descriptor for stderr + try: + # Echo compile and link commands that are used. + if verbose: + distutils.log.set_verbosity(1) + + # Create a compiler object. + compiler = distutils.ccompiler.new_compiler(verbose=verbose) + + # Configure compiler and linker to build according to Python install. + distutils.sysconfig.customize_compiler(compiler) + + # Create a temporary directory to hold test files. + tempdir = tempfile.mkdtemp() + + # Define a simple C program that calls the function in question + prog = "void %s(void); int main(int argc, char** argv) { %s(); return 0; }" % (funcname, funcname) + + # Write the test program to a file. + filename = os.path.join(tempdir, 'test.c') + with open(filename, 'w') as f: + f.write(prog) + + # Redirect stderr file descriptor to a file to silence compile/link warnings. + if not verbose: + filestderr = open(os.path.join(tempdir, 'stderr.txt'), 'w') + oldstderr = os.dup(sys.stderr.fileno()) + os.dup2(filestderr.fileno(), sys.stderr.fileno()) + + # Workaround for behavior in distutils.ccompiler.CCompiler.object_filenames() + # Otherwise, a local directory will be used instead of tempdir + drive, driveless_filename = os.path.splitdrive(filename) + root_dir = driveless_filename[0] if os.path.isabs(driveless_filename) else '' + output_dir = os.path.join(drive, root_dir) + + # Attempt to compile the C program into an object file. + cflags = shlex.split(os.environ.get('CFLAGS', "")) + objs = compiler.compile([filename], output_dir=output_dir, extra_preargs=self.strip_empty_entries(cflags)) + + # Attempt to link the object file into an executable. + # Be sure to tack on any libraries that have been specified. + ldflags = shlex.split(os.environ.get('LDFLAGS', "")) + compiler.link_executable(objs, + os.path.join(tempdir, 'a.out'), + extra_preargs=self.strip_empty_entries(ldflags), + libraries=libraries, + library_dirs=library_dirs) + + # Compile and link succeeded + return True + + except CompileError: + return False + + except LinkError: + return False + + except: + return False + + finally: + # Restore stderr file descriptor and close the stderr redirect file. + if oldstderr is not None: + os.dup2(oldstderr, sys.stderr.fileno()) + if filestderr is not None: + filestderr.close() + + # Delete the temporary directory holding the test program and stderr files. + if tempdir is not None: + shutil.rmtree(tempdir) + + def strip_empty_entries(self, args): + ''' + Drop any empty strings from the list of compile and link flags + ''' + return [x for x in args if len(x) > 0] + + def cpu_arch(self): + try: + from cpuinfo import get_cpu_info + except ImportError as e: + cpu_info = self._backup_cpuinfo() + if cpu_info is None: + return "-march=native" + + try: + cpu_info = get_cpu_info() + except Exception as e: + self.warning(f"{self.name} attempted to use py-cpuinfo but failed (exception type: {type(e)}, {e}), " + "falling back to lscpu to get this information.") + cpu_info = self._backup_cpuinfo() + if cpu_info is None: + return "-march=native" + + if cpu_info['arch'].startswith('PPC_'): + # gcc does not provide -march on PowerPC, use -mcpu instead + return '-mcpu=native' + return '-march=native' + + def get_cuda_compile_flag(self): + try: + if not self.is_rocm_pytorch(): + assert_no_cuda_mismatch(self.name) + return "-D__ENABLE_CUDA__" + except MissingCUDAException: + print(f"{WARNING} {self.name} cuda is missing or is incompatible with installed torch, " + "only cpu ops can be compiled!") + return '-D__DISABLE_CUDA__' + return '-D__DISABLE_CUDA__' + + def _backup_cpuinfo(self): + # Construct cpu_info dict from lscpu that is similar to what py-cpuinfo provides + if not self.command_exists('lscpu'): + self.warning(f"{self.name} attempted to query 'lscpu' after failing to use py-cpuinfo " + "to detect the CPU architecture. 'lscpu' does not appear to exist on " + "your system, will fall back to use -march=native and non-vectorized execution.") + return None + result = subprocess.check_output(['lscpu']) + result = result.decode('utf-8').strip().lower() + + cpu_info = {} + cpu_info['arch'] = None + cpu_info['flags'] = "" + if 'genuineintel' in result or 'authenticamd' in result: + cpu_info['arch'] = 'X86_64' + if 'avx512' in result: + cpu_info['flags'] += 'avx512,' + elif 'avx512f' in result: + cpu_info['flags'] += 'avx512f,' + if 'avx2' in result: + cpu_info['flags'] += 'avx2' + elif 'ppc64le' in result: + cpu_info['arch'] = "PPC_" + + return cpu_info + + def simd_width(self): + try: + from cpuinfo import get_cpu_info + except ImportError as e: + cpu_info = self._backup_cpuinfo() + if cpu_info is None: + return '-D__SCALAR__' + + try: + cpu_info = get_cpu_info() + except Exception as e: + self.warning(f"{self.name} attempted to use py-cpuinfo but failed (exception type: {type(e)}, {e}), " + "falling back to lscpu to get this information.") + cpu_info = self._backup_cpuinfo() + if cpu_info is None: + return '-D__SCALAR__' + + if cpu_info['arch'] == 'X86_64': + if 'avx512' in cpu_info['flags'] or 'avx512f' in cpu_info['flags']: + return '-D__AVX512__' + elif 'avx2' in cpu_info['flags']: + return '-D__AVX256__' + return '-D__SCALAR__' + + def command_exists(self, cmd): + if '|' in cmd: + cmds = cmd.split("|") + else: + cmds = [cmd] + valid = False + for cmd in cmds: + safe_cmd = ["bash", "-c", f"type {cmd}"] + result = subprocess.Popen(safe_cmd, stdout=subprocess.PIPE) + valid = valid or result.wait() == 0 + + if not valid and len(cmds) > 1: + print(f"{WARNING} {self.name} requires one of the following commands '{cmds}', but it does not exist!") + elif not valid and len(cmds) == 1: + print(f"{WARNING} {self.name} requires the '{cmd}' command, but it does not exist!") + return valid + + def warning(self, msg): + self.error_log = f"{msg}" + print(f"{WARNING} {msg}") + + def apex_src_path(self, code_path): + if os.path.isabs(code_path): + return code_path + else: + return os.path.join(Path(__file__).parent.parent.absolute(), code_path) + + def builder(self): + from torch.utils.cpp_extension import CppExtension + include_dirs = [os.path.abspath(x) for x in self.strip_empty_entries(self.include_paths())] + return CppExtension(name=self.absolute_name(), + sources=self.strip_empty_entries(self.sources()), + include_dirs=include_dirs, + extra_compile_args={'cxx': self.strip_empty_entries(self.cxx_args())}, + extra_link_args=self.strip_empty_entries(self.extra_ldflags())) + + def load(self, verbose=True): + if self.name in __class__._loaded_ops: + return __class__._loaded_ops[self.name] + + from apex.git_version_info import installed_ops, torch_info + if installed_ops.get(self.name, False): + # Ensure the op we're about to load was compiled with the same + # torch/cuda versions we are currently using at runtime. + self.validate_torch_version(torch_info) + if torch.cuda.is_available() and isinstance(self, CUDAOpBuilder): + self.validate_torch_op_version(torch_info) + + op_module = importlib.import_module(self.absolute_name()) + __class__._loaded_ops[self.name] = op_module + return op_module + else: + return self.jit_load(verbose) + + def jit_load(self, verbose=True): + if not self.is_compatible(verbose): + raise RuntimeError( + f"Unable to JIT load the {self.name} op due to it not being compatible due to hardware/software issue. {self.error_log}" + ) + try: + import ninja # noqa: F401 # type: ignore + except ImportError: + raise RuntimeError(f"Unable to JIT load the {self.name} op due to ninja not being installed.") + + if isinstance(self, CUDAOpBuilder) and not self.is_rocm_pytorch(): + self.build_for_cpu = not torch.cuda.is_available() + + self.jit_mode = True + from torch.utils.cpp_extension import load + + start_build = time.time() + sources = [os.path.abspath(self.apex_src_path(path)) for path in self.sources()] + extra_include_paths = [os.path.abspath(self.apex_src_path(path)) for path in self.include_paths()] + + # Torch will try and apply whatever CCs are in the arch list at compile time, + # we have already set the intended targets ourselves we know that will be + # needed at runtime. This prevents CC collisions such as multiple __half + # implementations. Stash arch list to reset after build. + torch_arch_list = None + if "TORCH_CUDA_ARCH_LIST" in os.environ: + torch_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST") + os.environ["TORCH_CUDA_ARCH_LIST"] = "" + + nvcc_args = self.strip_empty_entries(self.nvcc_args()) + cxx_args = self.strip_empty_entries(self.cxx_args()) + + cxx_args.append("-UC10_USE_GLOG") + nvcc_args.append("-UC10_USE_GLOG") + if isinstance(self, CUDAOpBuilder): + if not self.build_for_cpu and self.enable_bf16: + cxx_args.append("-DBF16_AVAILABLE") + nvcc_args.append("-DBF16_AVAILABLE") + nvcc_args.append("-U__CUDA_NO_BFLOAT16_OPERATORS__") + nvcc_args.append("-U__CUDA_NO_BFLOAT162_OPERATORS__") + nvcc_args.append("-U__CUDA_NO_BFLOAT16_CONVERSIONS__") + + if self.is_rocm_pytorch(): + cxx_args.append("-D__HIP_PLATFORM_AMD__=1") + cxx_args.append("-DUSE_ROCM") + os.environ["PYTORCH_ROCM_ARCH"] = self.get_rocm_gpu_arch() + cxx_args.append('-DROCM_WAVEFRONT_SIZE=%s' % self.get_rocm_wavefront_size()) + + op_module = load(name=self.name, + sources=self.strip_empty_entries(sources), + extra_include_paths=self.strip_empty_entries(extra_include_paths), + extra_cflags=cxx_args, + extra_cuda_cflags=nvcc_args, + extra_ldflags=self.strip_empty_entries(self.extra_ldflags()), + with_cuda=True if (isinstance(self, CUDAOpBuilder) and not self.build_for_cpu) else None, + verbose=verbose) + + build_duration = time.time() - start_build + if verbose: + print(f"Time to load {self.name} op: {build_duration} seconds") + + # Reset arch list so we are not silently removing it for other possible use cases + if torch_arch_list: + os.environ["TORCH_CUDA_ARCH_LIST"] = torch_arch_list + + __class__._loaded_ops[self.name] = op_module + + return op_module + + +class CUDAOpBuilder(OpBuilder): + + def compute_capability_args(self, cross_compile_archs=None): + """ + Returns nvcc compute capability compile flags. + + 1. `TORCH_CUDA_ARCH_LIST` takes priority over `cross_compile_archs`. + 2. If neither is set default compute capabilities will be used + 3. Under `jit_mode` compute capabilities of all visible cards will be used plus PTX + + Format: + + - `TORCH_CUDA_ARCH_LIST` may use ; or whitespace separators. Examples: + + TORCH_CUDA_ARCH_LIST="6.1;7.5;8.6;9.0;10.0" pip install ... + TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6 9.0 10.0+PTX" pip install ... + + - `cross_compile_archs` uses ; separator. + + """ + ccs = [] + if self.jit_mode: + # Compile for underlying architectures since we know those at runtime + for i in range(torch.cuda.device_count()): + CC_MAJOR, CC_MINOR = torch.cuda.get_device_capability(i) + cc = f"{CC_MAJOR}.{CC_MINOR}" + if cc not in ccs: + ccs.append(cc) + ccs = sorted(ccs) + ccs[-1] += '+PTX' + else: + # Cross-compile mode, compile for various architectures + # env override takes priority + cross_compile_archs_env = os.environ.get('TORCH_CUDA_ARCH_LIST', None) + if cross_compile_archs_env is not None: + if cross_compile_archs is not None: + print( + f"{WARNING} env var TORCH_CUDA_ARCH_LIST={cross_compile_archs_env} overrides cross_compile_archs={cross_compile_archs}" + ) + cross_compile_archs = cross_compile_archs_env.replace(' ', ';') + else: + if cross_compile_archs is None: + cross_compile_archs = get_default_compute_capabilities() + ccs = cross_compile_archs.split(';') + + ccs = self.filter_ccs(ccs) + if len(ccs) == 0: + raise RuntimeError( + f"Unable to load {self.name} op due to no compute capabilities remaining after filtering") + + args = [] + self.enable_bf16 = True + for cc in ccs: + num = cc[0] + cc[1].split('+')[0] + args.append(f'-gencode=arch=compute_{num},code=sm_{num}') + if cc[1].endswith('+PTX'): + args.append(f'-gencode=arch=compute_{num},code=compute_{num}') + + if int(cc[0]) <= 7: + self.enable_bf16 = False + + return args + + def filter_ccs(self, ccs: List[str]): + """ + Prune any compute capabilities that are not compatible with the builder. Should log + which CCs have been pruned. + """ + return [cc.split('.') for cc in ccs] + + def version_dependent_macros(self): + # Fix from apex that might be relevant for us as well, related to https://github.com/NVIDIA/apex/issues/456 + version_ge_1_1 = [] + if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 0): + version_ge_1_1 = ['-DVERSION_GE_1_1'] + version_ge_1_3 = [] + if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2): + version_ge_1_3 = ['-DVERSION_GE_1_3'] + version_ge_1_5 = [] + if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4): + version_ge_1_5 = ['-DVERSION_GE_1_5'] + + version_dependent_macro_args = version_ge_1_1 + version_ge_1_3 + version_ge_1_5 + if self.is_rocm_pytorch() and (self.torch_version()[0] >= 6): + version_dependent_macro_args += ["-DHIPBLAS_V2"] + + return version_dependent_macro_args + + def is_compatible(self, verbose=False): + return super().is_compatible(verbose) + + def builder(self): + try: + if not self.is_rocm_pytorch(): + assert_no_cuda_mismatch(self.name) + self.build_for_cpu = False + except MissingCUDAException: + self.build_for_cpu = True + + if self.build_for_cpu: + from torch.utils.cpp_extension import CppExtension as ExtensionBuilder + else: + from torch.utils.cpp_extension import CUDAExtension as ExtensionBuilder + include_dirs = [os.path.abspath(x) for x in self.strip_empty_entries(self.include_paths())] + compile_args = {'cxx': self.strip_empty_entries(self.cxx_args())} if self.build_for_cpu else \ + {'cxx': self.strip_empty_entries(self.cxx_args()), \ + 'nvcc': self.strip_empty_entries(self.nvcc_args())} + + if not self.build_for_cpu and self.enable_bf16: + compile_args['cxx'].append("-DBF16_AVAILABLE") + compile_args['nvcc'].append("-DBF16_AVAILABLE") + + if self.is_rocm_pytorch(): + compile_args['cxx'].append("-D__HIP_PLATFORM_AMD__=1") + #cxx compiler args are required to compile cpp files + compile_args['cxx'].append('-DROCM_WAVEFRONT_SIZE=%s' % self.get_rocm_wavefront_size()) + #nvcc compiler args are required to compile hip files + compile_args['nvcc'].append('-DROCM_WAVEFRONT_SIZE=%s' % self.get_rocm_wavefront_size()) + if self.get_rocm_gpu_arch(): + os.environ["PYTORCH_ROCM_ARCH"] = self.get_rocm_gpu_arch() + + cuda_ext = ExtensionBuilder(name=self.absolute_name(), + sources=self.strip_empty_entries(self.sources()), + include_dirs=include_dirs, + libraries=self.strip_empty_entries(self.libraries_args()), + extra_compile_args=compile_args, + extra_link_args=self.strip_empty_entries(self.extra_ldflags())) + + if self.is_rocm_pytorch(): + # hip converts paths to absolute, this converts back to relative + sources = cuda_ext.sources + curr_file = Path(__file__).parent.parent # ds root + for i in range(len(sources)): + src = Path(sources[i]) + if src.is_absolute(): + sources[i] = str(src.relative_to(curr_file)) + else: + sources[i] = str(src) + cuda_ext.sources = sources + return cuda_ext + + def hipify_extension(self): + if self.is_rocm_pytorch(): + from torch.utils.hipify import hipify_python + hipify_python.hipify( + project_directory=os.getcwd(), + output_directory=os.getcwd(), + header_include_dirs=self.include_paths(), + includes=[os.path.join(os.getcwd(), '*')], + extra_files=[os.path.abspath(s) for s in self.sources()], + show_detailed=True, + is_pytorch_extension=True, + hipify_extra_files_only=True, + ) + + def cxx_args(self): + if sys.platform == "win32": + return ['-O2'] + else: + return ['-O3', '-std=c++17', '-g', '-Wno-reorder'] + + def nvcc_args(self): + if self.build_for_cpu: + return [] + args = ['-O3'] + if self.is_rocm_pytorch(): + ROCM_MAJOR, ROCM_MINOR = self.installed_rocm_version() + args += [ + '-std=c++17', '-U__HIP_NO_HALF_OPERATORS__', '-U__HIP_NO_HALF_CONVERSIONS__', + '-U__HIP_NO_HALF2_OPERATORS__', + '-DUSE_ROCM', + '-DROCM_VERSION_MAJOR=%s' % ROCM_MAJOR, + '-DROCM_VERSION_MINOR=%s' % ROCM_MINOR + ] + else: + try: + nvcc_threads = int(os.getenv("APEX_NVCC_THREADS", "")) + if nvcc_threads <= 0: + raise ValueError("") + except ValueError: + nvcc_threads = min(os.cpu_count(), 8) + + cuda_major, cuda_minor = installed_cuda_version() + if cuda_major > 10: + if cuda_major == 12 and cuda_minor >= 5: + std_lib = '-std=c++20' + else: + std_lib = '-std=c++17' + else: + std_lib = '-std=c++14' + args += [ + '-allow-unsupported-compiler' if sys.platform == "win32" else '', '--use_fast_math', std_lib, + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', + f'--threads={nvcc_threads}' + ] + if os.environ.get('APEX_DEBUG_CUDA_BUILD', '0') == '1': + args.append('--ptxas-options=-v') + args += self.compute_capability_args() + return args + + def libraries_args(self): + if self.build_for_cpu: + return [] + + if sys.platform == "win32": + return ['cublas', 'curand'] + else: + return [] + + def backward_pass_guard_args(self): + torch_dir = torch.__path__[0] + context_file = os.path.join(torch_dir, "include", "ATen", "Context.h") + if os.path.exists(context_file): + lines = open(context_file, 'r').readlines() + found_Backward_Pass_Guard = False + found_ROCmBackward_Pass_Guard = False + for line in lines: + if "BackwardPassGuard" in line: + # BackwardPassGuard has been renamed to ROCmBackwardPassGuard + # https://github.com/pytorch/pytorch/pull/71881/commits/4b82f5a67a35406ffb5691c69e6b4c9086316a43 + if "ROCmBackwardPassGuard" in line: + found_ROCmBackward_Pass_Guard = True + else: + found_Backward_Pass_Guard = True + break + backward_pass_guard_args = [] + if found_Backward_Pass_Guard: + backward_pass_guard_args += ['-DBACKWARD_PASS_GUARD'] + ['-DBACKWARD_PASS_GUARD_CLASS=BackwardPassGuard'] + if found_ROCmBackward_Pass_Guard: + backward_pass_guard_args += ['-DBACKWARD_PASS_GUARD'] + ['-DBACKWARD_PASS_GUARD_CLASS=ROCmBackwardPassGuard'] + return backward_pass_guard_args + + def aten_atomic_args(self): + torch_dir = torch.__path__[0] + if os.path.exists(os.path.join(torch_dir, "include", "ATen", "Atomic.cuh")): + return ['-DATEN_ATOMIC_HEADER'] + else: + return [] + + def generator_args(self): + generator_flag = [] + torch_dir = torch.__path__[0] + if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")): + generator_flag = ["-DOLD_GENERATOR_PATH"] + return generator_flag + + def nvcc_threads_args(self): + cuda_major, cuda_minor = installed_cuda_version() + if cuda_major >= 11 and cuda_minor >= 2: + return ["--threads", "4"] + return [] + + def nccl_args(self): + nccl_library = ["-lnccl"] + if self.is_rocm_pytorch(): + nccl_library = ["-lrccl"] + return nccl_library + + def nccl_version(self): + return torch.cuda.nccl.version()[0:2] + + def torch_version(self): + return (TORCH_MAJOR, TORCH_MINOR) + + def is_supported(self): + return super().is_supported() + +class CPUOpBuilder(CUDAOpBuilder): + + def get_cuda_lib64_path(self): + import torch + if not self.is_rocm_pytorch(): + CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.CUDA_HOME, "lib64") + if not os.path.exists(CUDA_LIB64): + CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.CUDA_HOME, "lib") + else: + CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.ROCM_HOME, "lib") + return CUDA_LIB64 + + def extra_ldflags(self): + if self.build_for_cpu: + return ['-fopenmp'] + + if not self.is_rocm_pytorch(): + ld_flags = ['-lcurand'] + if not self.build_for_cpu: + ld_flags.append(f'-L{self.get_cuda_lib64_path()}') + return ld_flags + + return [] + + def cxx_args(self): + args = [] + if not self.build_for_cpu: + CUDA_LIB64 = self.get_cuda_lib64_path() + + args += super().cxx_args() + args += [ + f'-L{CUDA_LIB64}', + '-lcudart', + '-lcublas', + '-g', + ] + + CPU_ARCH = self.cpu_arch() + SIMD_WIDTH = self.simd_width() + CUDA_ENABLE = self.get_cuda_compile_flag() + args += [ + CPU_ARCH, + '-fopenmp', + SIMD_WIDTH, + CUDA_ENABLE, + ] + + return args diff --git a/op_builder/distributed_adam.py b/op_builder/distributed_adam.py new file mode 100644 index 000000000..ef453bee9 --- /dev/null +++ b/op_builder/distributed_adam.py @@ -0,0 +1,33 @@ +from .builder import CUDAOpBuilder + +import sys + + +class DistributedAdamBuilder(CUDAOpBuilder): + BUILD_VAR = 'APEX_BUILD_DISTRIBUTED_ADAM' + INCLUDE_FLAG = "APEX_BUILD_CUDA_OPS" + NAME = "distributed_adam_cuda" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'apex.{self.NAME}' + + def sources(self): + return ['contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp', + 'contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu'] + + def include_paths(self): + return ['contrib/csrc/', + 'csrc'] + + def cxx_args(self): + args = super().cxx_args() + return args + self.version_dependent_macros() + + def nvcc_args(self): + nvcc_flags = ['-O3'] + self.version_dependent_macros() + if not self.is_rocm_pytorch(): + nvcc_flags += ['--use_fast_math'] + return nvcc_flags \ No newline at end of file diff --git a/op_builder/distributed_lamb.py b/op_builder/distributed_lamb.py new file mode 100644 index 000000000..74d77d129 --- /dev/null +++ b/op_builder/distributed_lamb.py @@ -0,0 +1,33 @@ +from .builder import CUDAOpBuilder + +import sys + + +class DistributedLambBuilder(CUDAOpBuilder): + BUILD_VAR = 'APEX_BUILD_DISTRIBUTED_LAMB' + INCLUDE_FLAG = "APEX_BUILD_CUDA_OPS" + NAME = "distributed_lamb_cuda" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'apex.{self.NAME}' + + def sources(self): + return ['contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp', + 'contrib/csrc/optimizers/multi_tensor_distopt_lamb_kernel.cu'] + + def include_paths(self): + return ['contrib/csrc/', + 'csrc'] + + def cxx_args(self): + args = super().cxx_args() + return args + self.version_dependent_macros() + + def nvcc_args(self): + nvcc_flags = ['-O3'] + self.version_dependent_macros() + if not self.is_rocm_pytorch(): + nvcc_flags += ['--use_fast_math'] + return nvcc_flags \ No newline at end of file diff --git a/op_builder/fast_multihead_attn.py b/op_builder/fast_multihead_attn.py new file mode 100644 index 000000000..0f2f8b52f --- /dev/null +++ b/op_builder/fast_multihead_attn.py @@ -0,0 +1,50 @@ +from .builder import CUDAOpBuilder + +import sys + + +class FastMultiheadAttnBuilder(CUDAOpBuilder): + BUILD_VAR = 'APEX_BUILD_FAST_MULTIHEAD_ATTN' + INCLUDE_FLAG = "APEX_BUILD_CUDA_OPS" + NAME = "fast_multihead_attn" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'apex.{self.NAME}' + + def sources(self): + return ['contrib/csrc/multihead_attn/multihead_attn_frontend.cpp', + 'contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu', + "contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu", + "contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu", + "contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu", + "contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu", + "contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu", + "contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu", + "contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu"] + + def include_paths(self): + return ['csrc/', + 'contrib/csrc/', + 'contrib/csrc/multihead_attn'] + + def cxx_args(self): + args = super().cxx_args() + return args + self.version_dependent_macros() + self.generator_args() + + def nvcc_args(self): + nvcc_flags = ['-O3'] + self.version_dependent_macros() + self.generator_args() + if not self.is_rocm_pytorch(): + nvcc_flags += ['-U__CUDA_NO_HALF_OPERATORS__', + '-U__CUDA_NO_HALF_CONVERSIONS__', + '--expt-relaxed-constexpr', + '--expt-extended-lambda', + '--use_fast_math'] + self.compute_capability_args() + else: + nvcc_flags += ['-I/opt/rocm/include/hiprand', + '-I/opt/rocm/include/rocrand', + '-U__HIP_NO_HALF_OPERATORS__', + '-U__HIP_NO_HALF_CONVERSIONS__'] + self.backward_pass_guard_args() + return nvcc_flags \ No newline at end of file diff --git a/op_builder/focal_loss.py b/op_builder/focal_loss.py new file mode 100644 index 000000000..98a21330a --- /dev/null +++ b/op_builder/focal_loss.py @@ -0,0 +1,33 @@ +from .builder import CUDAOpBuilder + +import sys + + +class FocalLossBuilder(CUDAOpBuilder): + BUILD_VAR = 'APEX_BUILD_FOCAL_LOSS' + INCLUDE_FLAG = "APEX_BUILD_CUDA_OPS" + NAME = "focal_loss_cuda" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'apex.{self.NAME}' + + def sources(self): + return ['contrib/csrc/focal_loss/focal_loss_cuda.cpp', + 'contrib/csrc/focal_loss/focal_loss_cuda_kernel.cu'] + + def include_paths(self): + return ['contrib/csrc/' ] + + def cxx_args(self): + args = super().cxx_args() + return args + self.version_dependent_macros() + + def nvcc_args(self): + if self.is_rocm_pytorch(): + nvcc_flags = ['-O3'] + self.version_dependent_macros() + else: + nvcc_flags = ['-O3', '--ftz=false', '--use_fast_math'] + return nvcc_flags \ No newline at end of file diff --git a/op_builder/fused_adam.py b/op_builder/fused_adam.py new file mode 100644 index 000000000..f335368d8 --- /dev/null +++ b/op_builder/fused_adam.py @@ -0,0 +1,33 @@ +from .builder import CUDAOpBuilder + +import sys + + +class FusedAdamBuilder(CUDAOpBuilder): + BUILD_VAR = 'APEX_BUILD_FUSED_ADAM' + INCLUDE_FLAG = "APEX_BUILD_CUDA_OPS" + NAME = "fused_adam_cuda" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'apex.{self.NAME}' + + def sources(self): + return ['contrib/csrc/optimizers/fused_adam_cuda.cpp', + 'contrib/csrc/optimizers/fused_adam_cuda_kernel.cu'] + + def include_paths(self): + return ['contrib/csrc/', + 'csrc'] + + def cxx_args(self): + args = super().cxx_args() + return args + self.version_dependent_macros() + + def nvcc_args(self): + nvcc_flags = ['-O3'] + self.version_dependent_macros() + if not self.is_rocm_pytorch(): + nvcc_flags += ['--use_fast_math'] + return nvcc_flags \ No newline at end of file diff --git a/op_builder/fused_bias_swiglu.py b/op_builder/fused_bias_swiglu.py new file mode 100644 index 000000000..4a7d13881 --- /dev/null +++ b/op_builder/fused_bias_swiglu.py @@ -0,0 +1,57 @@ +from .builder import CUDAOpBuilder +import sys +import os + +class FusedBiasSwiGLUBuilder(CUDAOpBuilder): + BUILD_VAR = 'APEX_BUILD_FUSED_BIAS_SWIGLU' + INCLUDE_FLAG = "APEX_BUILD_CUDA_OPS" + NAME = "fused_bias_swiglu" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'apex.{self.NAME}' + + def sources(self): + return [ + "csrc/megatron/fused_bias_swiglu.cpp", + "csrc/megatron/fused_bias_swiglu_cuda.cu" + ] + + def include_paths(self): + return ['csrc'] + + def cxx_args(self): + args = super().cxx_args() + return args + self.version_dependent_macros() + + def nvcc_args(self): + nvcc_flags = [ + '-O3', + '-U__CUDA_NO_HALF_OPERATORS__', + '-U__CUDA_NO_HALF_CONVERSIONS__' + ] + self.version_dependent_macros() + if not self.is_rocm_pytorch(): + nvcc_flags.extend( + [ + '--expt-relaxed-constexpr', + '--expt-extended-lambda' + ]) + else: + # Handle ROCm arch flags + amdgpu_targets = os.environ.get('PYTORCH_ROCM_ARCH', '') + if not amdgpu_targets: + print("Warning: PYTORCH_ROCM_ARCH environment variable is empty.") + print("Using default architecture. Set this variable for specific GPU targets.") + print("Example: export PYTORCH_ROCM_ARCH=gfx906") + amdgpu_targets = "gfx906" + try: + for amdgpu_target in amdgpu_targets.split(';'): + if amdgpu_target: + nvcc_flags += [f'--offload-arch={amdgpu_target}'] + except Exception as e: + print(f"Warning: Error processing PYTORCH_ROCM_ARCH: {e}") + print("Falling back to default architecture gfx906") + nvcc_flags += ['--offload-arch=gfx906'] + return nvcc_flags \ No newline at end of file diff --git a/op_builder/fused_conv_bias_relu.py b/op_builder/fused_conv_bias_relu.py new file mode 100644 index 000000000..997cfb32d --- /dev/null +++ b/op_builder/fused_conv_bias_relu.py @@ -0,0 +1,36 @@ +from .builder import CUDAOpBuilder +import sys + + +class FusedConvBiasReluBuilder(CUDAOpBuilder): + BUILD_VAR = 'APEX_BUILD_FUSED_CONV_BIAS_RELU' + INCLUDE_FLAG = "APEX_BUILD_CUDA_OPS" + NAME = "fused_conv_bias_relu" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'apex.{self.NAME}' + + def sources(self): + if self.is_rocm_pytorch(): + return ["contrib/csrc/conv_bias_relu/conv_bias_relu_rocm.cpp"] + else: + return ["contrib/csrc/conv_bias_relu/conv_bias_relu.cpp"] + + def include_paths(self): + paths = ['contrib/csrc/'] + if not self.is_rocm_pytorch(): + paths.append("apex/contrib/csrc/cudnn-frontend/include") + return paths + + def cxx_args(self): + args = super().cxx_args() + return args + self.generator_args() + self.version_dependent_macros() + + def libraries_args(self): + if self.is_rocm_pytorch(): + return super().libraries_args() + ['MIOpen'] + else: + return super().libraries_args() \ No newline at end of file diff --git a/op_builder/fused_dense.py b/op_builder/fused_dense.py new file mode 100644 index 000000000..4d40eef6d --- /dev/null +++ b/op_builder/fused_dense.py @@ -0,0 +1,28 @@ +from .builder import CUDAOpBuilder + +import sys + + +class FusedDenseBuilder(CUDAOpBuilder): + BUILD_VAR = 'APEX_BUILD_FUSED_DENSE' + INCLUDE_FLAG = "APEX_BUILD_CUDA_OPS" + NAME = "fused_dense_cuda" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'apex.{self.NAME}' + + def sources(self): + return ['csrc/fused_dense_base.cpp', 'csrc/fused_dense_cuda.cu'] + + def include_paths(self): + return ['csrc'] + + def cxx_args(self): + args = super().cxx_args() + return args + self.version_dependent_macros() + + def nvcc_args(self): + return ['-O3'] + self.version_dependent_macros() \ No newline at end of file diff --git a/op_builder/fused_index_mul_2d.py b/op_builder/fused_index_mul_2d.py new file mode 100644 index 000000000..d04564e15 --- /dev/null +++ b/op_builder/fused_index_mul_2d.py @@ -0,0 +1,34 @@ +from .builder import CUDAOpBuilder + +import sys + + +class FusedIndexMul2dBuilder(CUDAOpBuilder): + BUILD_VAR = 'APEX_BUILD_FUSED_INDEX_MUL_2D' + INCLUDE_FLAG = "APEX_BUILD_CUDA_OPS" + NAME = "fused_index_mul_2d" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'apex.{self.NAME}' + + def sources(self): + return ['contrib/csrc/index_mul_2d/index_mul_2d_cuda.cpp', + 'contrib/csrc/index_mul_2d/index_mul_2d_cuda_kernel.cu'] + + def include_paths(self): + return ['contrib/csrc/'] + + def cxx_args(self): + args = super().cxx_args() + return args + self.version_dependent_macros() + + def nvcc_args(self): + nvcc_flags = ['-O3'] + self.version_dependent_macros() + if not self.is_rocm_pytorch(): + nvcc_flags += ['--use_fast_math', '--ftz=false'] + else: + nvcc_flags += self.aten_atomic_args() + return nvcc_flags \ No newline at end of file diff --git a/op_builder/fused_lamb.py b/op_builder/fused_lamb.py new file mode 100644 index 000000000..02a0b6fe7 --- /dev/null +++ b/op_builder/fused_lamb.py @@ -0,0 +1,34 @@ +from .builder import CUDAOpBuilder + +import sys + + +class FusedLambBuilder(CUDAOpBuilder): + BUILD_VAR = 'APEX_BUILD_FUSED_LAMB' + INCLUDE_FLAG = "APEX_BUILD_CUDA_OPS" + NAME = "fused_lamb_cuda" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'apex.{self.NAME}' + + def sources(self): + return ['contrib/csrc/optimizers/fused_lamb_cuda.cpp', + 'contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu', + 'csrc/multi_tensor_l2norm_kernel.cu'] + + def include_paths(self): + return ['contrib/csrc/', + 'csrc'] + + def cxx_args(self): + args = super().cxx_args() + return args + self.version_dependent_macros() + + def nvcc_args(self): + nvcc_flags = ['-O3'] + self.version_dependent_macros() + if not self.is_rocm_pytorch(): + nvcc_flags += ['--use_fast_math'] + return nvcc_flags \ No newline at end of file diff --git a/op_builder/fused_layer_norm.py b/op_builder/fused_layer_norm.py new file mode 100644 index 000000000..66130f17b --- /dev/null +++ b/op_builder/fused_layer_norm.py @@ -0,0 +1,31 @@ +from .builder import CUDAOpBuilder + +import sys + + +class FusedLayerNormBuilder(CUDAOpBuilder): + BUILD_VAR = 'APEX_BUILD_FUSED_LAYER_NORM' + INCLUDE_FLAG = "APEX_BUILD_CUDA_OPS" + NAME = "fused_layer_norm_cuda" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'apex.{self.NAME}' + + def sources(self): + return ['csrc/layer_norm_cuda.cpp', 'csrc/layer_norm_cuda_kernel.cu'] + + def include_paths(self): + return ['csrc'] + + def cxx_args(self): + args = super().cxx_args() + return args + self.version_dependent_macros() + + def nvcc_args(self): + nvcc_flags = ['-O3'] + self.version_dependent_macros() + if not self.is_rocm_pytorch(): + nvcc_flags.extend(['--use_fast_math', '-maxrregcount=50']) + return nvcc_flags \ No newline at end of file diff --git a/op_builder/fused_rope.py b/op_builder/fused_rope.py new file mode 100644 index 000000000..c87f14b84 --- /dev/null +++ b/op_builder/fused_rope.py @@ -0,0 +1,40 @@ +from .builder import CUDAOpBuilder + +import sys + + +class FusedRopeBuilder(CUDAOpBuilder): + BUILD_VAR = 'APEX_BUILD_FUSED_ROPE' + INCLUDE_FLAG = "APEX_BUILD_CUDA_OPS" + NAME = "fused_rotary_positional_embedding" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'apex.{self.NAME}' + + def sources(self): + return ["csrc/megatron/fused_rotary_positional_embedding.cpp", + "csrc/megatron/fused_rotary_positional_embedding_cuda.cu"] + + def include_paths(self): + return ['csrc', 'csrc/megatron'] + + def cxx_args(self): + args = super().cxx_args() + return args + self.version_dependent_macros() + + def nvcc_args(self): + nvcc_flags = [ + '-O3', + '-U__CUDA_NO_HALF_OPERATORS__', + '-U__CUDA_NO_HALF_CONVERSIONS__' + ] + self.version_dependent_macros() + if not self.is_rocm_pytorch(): + nvcc_flags.extend( + [ + '--expt-relaxed-constexpr', + '--expt-extended-lambda' + ]) + return nvcc_flags \ No newline at end of file diff --git a/op_builder/fused_weight_gradient_mlp.py b/op_builder/fused_weight_gradient_mlp.py new file mode 100644 index 000000000..b6d595385 --- /dev/null +++ b/op_builder/fused_weight_gradient_mlp.py @@ -0,0 +1,42 @@ +from .builder import CUDAOpBuilder + +class FusedWeightGradientMlpCudaBuilder(CUDAOpBuilder): + BUILD_VAR = 'APEX_BUILD_FUSED_WEIGHT_GRADIENT_MLP' + INCLUDE_FLAG = "APEX_BUILD_CUDA_OPS" + NAME = "fused_weight_gradient_mlp_cuda" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'apex.{self.NAME}' + + def sources(self): + return [ + "csrc/megatron/fused_weight_gradient_dense.cpp", + "csrc/megatron/fused_weight_gradient_dense_cuda.cu", + "csrc/megatron/fused_weight_gradient_dense_16bit_prec_cuda.cu", + ] + + def include_paths(self): + # Both csrc and csrc/megatron are included in the original extension + return ['csrc', 'csrc/megatron'] + + def cxx_args(self): + args = super().cxx_args() + return args + self.version_dependent_macros() + + def nvcc_args(self): + nvcc_flags = [ + '-O3', + '-U__CUDA_NO_HALF_OPERATORS__', + '-U__CUDA_NO_HALF_CONVERSIONS__' + ] + self.version_dependent_macros() + if not self.is_rocm_pytorch(): + nvcc_flags.extend( + [ + '--expt-relaxed-constexpr', + '--expt-extended-lambda', + "--use_fast_math" + ]) + self.compute_capability_args() + return nvcc_flags \ No newline at end of file diff --git a/op_builder/generic_scaled_masked_softmax_cuda.py b/op_builder/generic_scaled_masked_softmax_cuda.py new file mode 100644 index 000000000..a0fb2d5fc --- /dev/null +++ b/op_builder/generic_scaled_masked_softmax_cuda.py @@ -0,0 +1,39 @@ +from .builder import CUDAOpBuilder + +class GenericScaledMaskedSoftmaxCudaBuilder(CUDAOpBuilder): + BUILD_VAR = 'APEX_BUILD_GENERIC_SCALED_MASKED_SOFTMAX_CUDA' + INCLUDE_FLAG = "APEX_BUILD_CUDA_OPS" + NAME = "generic_scaled_masked_softmax_cuda" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'apex.{self.NAME}' + + def sources(self): + return [ + "csrc/megatron/generic_scaled_masked_softmax_cpu.cpp", + "csrc/megatron/generic_scaled_masked_softmax_cuda.cu" + ] + + def include_paths(self): + return ['csrc'] + + def cxx_args(self): + args = super().cxx_args() + return args + self.version_dependent_macros() + + def nvcc_args(self): + nvcc_flags = [ + '-O3', + '-U__CUDA_NO_HALF_OPERATORS__', + '-U__CUDA_NO_HALF_CONVERSIONS__' + ] + self.version_dependent_macros() + if not self.is_rocm_pytorch(): + nvcc_flags.extend( + [ + '--expt-relaxed-constexpr', + '--expt-extended-lambda' + ]) + return nvcc_flags \ No newline at end of file diff --git a/op_builder/mlp.py b/op_builder/mlp.py new file mode 100644 index 000000000..c6a177721 --- /dev/null +++ b/op_builder/mlp.py @@ -0,0 +1,32 @@ +from .builder import CUDAOpBuilder + +import sys + + +class MlpBuilder(CUDAOpBuilder): + BUILD_VAR = 'APEX_BUILD_MLP' + INCLUDE_FLAG = "APEX_BUILD_CUDA_OPS" + NAME = "mlp_cuda" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'apex.{self.NAME}' + + def sources(self): + return ['csrc/mlp.cpp', + 'csrc/mlp_cuda.cu'] + + def include_paths(self): + return ['csrc'] + + def cxx_args(self): + args = super().cxx_args() + return args + self.version_dependent_macros() + + def nvcc_args(self): + nvcc_flags = ['-O3'] + self.version_dependent_macros() + if self.is_rocm_pytorch(): + nvcc_flags.extend(self.backward_pass_guard_args()) + return nvcc_flags \ No newline at end of file diff --git a/op_builder/nccl_allocator.py b/op_builder/nccl_allocator.py new file mode 100644 index 000000000..320e76476 --- /dev/null +++ b/op_builder/nccl_allocator.py @@ -0,0 +1,36 @@ +from .builder import CUDAOpBuilder + +import sys + + +class NCCLAllocatorBuilder(CUDAOpBuilder): + BUILD_VAR = 'APEX_BUILD_NCCL_ALLOCATOR' + INCLUDE_FLAG = "APEX_BUILD_CUDA_OPS" + NAME = "_apex_nccl_allocator" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'apex.{self.NAME}' + + def sources(self): + return ["contrib/csrc/nccl_allocator/NCCLAllocator.cpp"] + + def include_paths(self): + return ['contrib/csrc/'] + + def cxx_args(self): + args = super().cxx_args() + return args + self.version_dependent_macros() + self.generator_args() + + def nvcc_args(self): + return self.nccl_args() + + def is_compatible(self, verbose=False): + torch_version = self.torch_version() + if torch_version >= (2, 6): + available_nccl_version = self.nccl_version() + if available_nccl_version >= (2, 19): + return True + return False \ No newline at end of file diff --git a/op_builder/nccl_p2p.py b/op_builder/nccl_p2p.py new file mode 100644 index 000000000..37772572e --- /dev/null +++ b/op_builder/nccl_p2p.py @@ -0,0 +1,26 @@ +from .builder import CUDAOpBuilder + +import sys + + +class NCCLP2PBuilder(CUDAOpBuilder): + BUILD_VAR = 'APEX_BUILD_NCCL_P2P' + INCLUDE_FLAG = "APEX_BUILD_CUDA_OPS" + NAME = "nccl_p2p_cuda" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'apex.{self.NAME}' + + def sources(self): + return ["contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu", + "contrib/csrc/nccl_p2p/nccl_p2p.cpp"] + + def include_paths(self): + return ['contrib/csrc/'] + + def cxx_args(self): + args = super().cxx_args() + return args + self.version_dependent_macros() + self.generator_args() \ No newline at end of file diff --git a/op_builder/peer_memory.py b/op_builder/peer_memory.py new file mode 100644 index 000000000..c869f0be6 --- /dev/null +++ b/op_builder/peer_memory.py @@ -0,0 +1,26 @@ +from .builder import CUDAOpBuilder + +import sys + + +class PeerMemoryBuilder(CUDAOpBuilder): + BUILD_VAR = 'APEX_BUILD_PEER_MEMORY' + INCLUDE_FLAG = "APEX_BUILD_CUDA_OPS" + NAME = "peer_memory_cuda" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'apex.{self.NAME}' + + def sources(self): + return ["contrib/csrc/peer_memory/peer_memory_cuda.cu", + "contrib/csrc/peer_memory/peer_memory.cpp"] + + def include_paths(self): + return ['contrib/csrc/'] + + def cxx_args(self): + args = super().cxx_args() + return args + self.version_dependent_macros() + self.generator_args() \ No newline at end of file diff --git a/op_builder/scaled_masked_softmax_cuda.py b/op_builder/scaled_masked_softmax_cuda.py new file mode 100644 index 000000000..1013ef8d2 --- /dev/null +++ b/op_builder/scaled_masked_softmax_cuda.py @@ -0,0 +1,40 @@ +from .builder import CUDAOpBuilder + +class ScaledMaskedSoftmaxCudaBuilder(CUDAOpBuilder): + BUILD_VAR = 'APEX_BUILD_SCALED_MASKED_SOFTMAX_CUDA' + INCLUDE_FLAG = "APEX_BUILD_CUDA_OPS" + NAME = "scaled_masked_softmax_cuda" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'apex.{self.NAME}' + + def sources(self): + return [ + "csrc/megatron/scaled_masked_softmax_cpu.cpp", + "csrc/megatron/scaled_masked_softmax_cuda.cu" + ] + + def include_paths(self): + # Both csrc and csrc/megatron are included in the original extension + return ['csrc', 'csrc/megatron'] + + def cxx_args(self): + args = super().cxx_args() + return args + self.version_dependent_macros() + + def nvcc_args(self): + nvcc_flags = [ + '-O3', + '-U__CUDA_NO_HALF_OPERATORS__', + '-U__CUDA_NO_HALF_CONVERSIONS__' + ] + self.version_dependent_macros() + if not self.is_rocm_pytorch(): + nvcc_flags.extend( + [ + '--expt-relaxed-constexpr', + '--expt-extended-lambda' + ]) + return nvcc_flags \ No newline at end of file diff --git a/op_builder/scaled_softmax_cuda.py b/op_builder/scaled_softmax_cuda.py new file mode 100644 index 000000000..f29543963 --- /dev/null +++ b/op_builder/scaled_softmax_cuda.py @@ -0,0 +1,41 @@ +from .builder import CUDAOpBuilder + +import sys + +class ScaledSoftmaxCudaBuilder(CUDAOpBuilder): + BUILD_VAR = 'APEX_BUILD_SCALED_SOFTMAX_CUDA' + INCLUDE_FLAG = "APEX_BUILD_CUDA_OPS" + NAME = "scaled_softmax_cuda" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'apex.{self.NAME}' + + def sources(self): + return [ + "csrc/megatron/scaled_softmax_cpu.cpp", + "csrc/megatron/scaled_softmax_cuda.cu" + ] + + def include_paths(self): + return ['csrc'] + + def cxx_args(self): + args = super().cxx_args() + return args + self.version_dependent_macros() + + def nvcc_args(self): + nvcc_flags = [ + '-O3', + '-U__CUDA_NO_HALF_OPERATORS__', + '-U__CUDA_NO_HALF_CONVERSIONS__' + ] + self.version_dependent_macros() + if not self.is_rocm_pytorch(): + nvcc_flags.extend( + [ + '--expt-relaxed-constexpr', + '--expt-extended-lambda' + ]) + return nvcc_flags \ No newline at end of file diff --git a/op_builder/scaled_upper_triang_masked_softmax_cuda.py b/op_builder/scaled_upper_triang_masked_softmax_cuda.py new file mode 100644 index 000000000..3c2273ad9 --- /dev/null +++ b/op_builder/scaled_upper_triang_masked_softmax_cuda.py @@ -0,0 +1,39 @@ +from .builder import CUDAOpBuilder + +class ScaledUpperTriangMaskedSoftmaxCudaBuilder(CUDAOpBuilder): + BUILD_VAR = 'APEX_BUILD_SCALED_UPPER_TRIANG_MASKED_SOFTMAX_CUDA' + INCLUDE_FLAG = "APEX_BUILD_CUDA_OPS" + NAME = "scaled_upper_triang_masked_softmax_cuda" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'apex.{self.NAME}' + + def sources(self): + return [ + "csrc/megatron/scaled_upper_triang_masked_softmax_cpu.cpp", + "csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu" + ] + + def include_paths(self): + return ['csrc'] + + def cxx_args(self): + args = super().cxx_args() + return args + self.version_dependent_macros() + + def nvcc_args(self): + nvcc_flags = [ + '-O3', + '-U__CUDA_NO_HALF_OPERATORS__', + '-U__CUDA_NO_HALF_CONVERSIONS__' + ] + self.version_dependent_macros() + if not self.is_rocm_pytorch(): + nvcc_flags.extend( + [ + '--expt-relaxed-constexpr', + '--expt-extended-lambda' + ]) + return nvcc_flags \ No newline at end of file diff --git a/op_builder/syncbn.py b/op_builder/syncbn.py new file mode 100644 index 000000000..251c33e01 --- /dev/null +++ b/op_builder/syncbn.py @@ -0,0 +1,28 @@ +from .builder import CUDAOpBuilder + +import sys + + +class SyncBnBuilder(CUDAOpBuilder): + BUILD_VAR = 'APEX_BUILD_SYNCBN' + INCLUDE_FLAG = "APEX_BUILD_CUDA_OPS" + NAME = "syncbn" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'apex.{self.NAME}' + + def sources(self): + return ['csrc/syncbn.cpp', 'csrc/welford.cu'] + + def include_paths(self): + return ['csrc'] + + def cxx_args(self): + args = super().cxx_args() + return args + self.version_dependent_macros() + + def nvcc_args(self): + return ['-O3'] + self.version_dependent_macros() \ No newline at end of file diff --git a/op_builder/transducer_joint.py b/op_builder/transducer_joint.py new file mode 100644 index 000000000..c17f60f7b --- /dev/null +++ b/op_builder/transducer_joint.py @@ -0,0 +1,33 @@ +from .builder import CUDAOpBuilder +import sys + + +class TransducerJointBuilder(CUDAOpBuilder): + BUILD_VAR = 'APEX_BUILD_TRANSDUCER_JOINT' + INCLUDE_FLAG = "APEX_BUILD_CUDA_OPS" + NAME = "transducer_joint_cuda" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'apex.{self.NAME}' + + def sources(self): + return ["contrib/csrc/transducer/transducer_joint.cpp", + "contrib/csrc/transducer/transducer_joint_kernel.cu"] + + def include_paths(self): + return ['contrib/csrc/', + #it uses philox.cuh from contrib/csrc/multihead_attn + 'contrib/csrc/multihead_attn'] + + def cxx_args(self): + args = super().cxx_args() + return args + self.version_dependent_macros() + self.generator_args() + + def nvcc_args(self): + nvcc_flags = ['-O3'] + self.version_dependent_macros() + self.generator_args() + if not self.is_rocm_pytorch(): + nvcc_flags += self.nvcc_threads_args() + return nvcc_flags \ No newline at end of file diff --git a/op_builder/transducer_loss.py b/op_builder/transducer_loss.py new file mode 100644 index 000000000..53ae4eaac --- /dev/null +++ b/op_builder/transducer_loss.py @@ -0,0 +1,31 @@ +from .builder import CUDAOpBuilder +import sys + + +class TransducerLossBuilder(CUDAOpBuilder): + BUILD_VAR = 'APEX_BUILD_TRANSDUCER_LOSS' + INCLUDE_FLAG = "APEX_BUILD_CUDA_OPS" + NAME = "transducer_loss_cuda" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'apex.{self.NAME}' + + def sources(self): + return ["contrib/csrc/transducer/transducer_loss.cpp", + "contrib/csrc/transducer/transducer_loss_kernel.cu"] + + def include_paths(self): + return ['contrib/csrc/' ] + + def cxx_args(self): + args = super().cxx_args() + return args + self.version_dependent_macros() + + def nvcc_args(self): + nvcc_flags = ['-O3'] + self.version_dependent_macros() + if not self.is_rocm_pytorch(): + nvcc_flags += self.nvcc_threads_args() + return nvcc_flags \ No newline at end of file diff --git a/op_builder/xentropy.py b/op_builder/xentropy.py new file mode 100644 index 000000000..84f3ddf12 --- /dev/null +++ b/op_builder/xentropy.py @@ -0,0 +1,29 @@ +from .builder import CUDAOpBuilder + +import sys + + +class XentropyBuilder(CUDAOpBuilder): + BUILD_VAR = 'APEX_BUILD_XENTROPY' + INCLUDE_FLAG = "APEX_BUILD_CUDA_OPS" + NAME = "xentropy_cuda" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'apex.{self.NAME}' + + def sources(self): + return ['contrib/csrc/xentropy/interface.cpp', + 'contrib/csrc/xentropy/xentropy_kernel.cu'] + + def include_paths(self): + return ['csrc', 'contrib/csrc/' ] + + def cxx_args(self): + args = super().cxx_args() + return args + self.version_dependent_macros() + + def nvcc_args(self): + return ['-O3'] + self.version_dependent_macros() \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..f29f03dd1 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,7 @@ +[build-system] +requires = [ + "setuptools", + "wheel", +] +build-backend = "setuptools.build_meta" + diff --git a/requirements.txt b/requirements.txt index fd202d9b7..d527b4498 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,12 @@ cxxfilt>=0.2.0 tqdm>=4.28.1 -numpy>=1.15.3 +numpy PyYAML>=5.1 pytest>=3.5.1 packaging>=14.0 +matplotlib>=3.8 +pandas>=2.2.2 +py-cpuinfo +build +ninja +wheel \ No newline at end of file diff --git a/scripts/clean.py b/scripts/clean.py new file mode 100644 index 000000000..be7e69798 --- /dev/null +++ b/scripts/clean.py @@ -0,0 +1,16 @@ +import torch.utils.cpp_extension +import shutil +import os +import sys + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) +from op_builder.all_ops import ALL_OPS + +torch_ext_directory = torch.utils.cpp_extension._get_build_directory("", False) + +install_ops = dict.fromkeys(ALL_OPS.keys(), False) +for op_name, builder in ALL_OPS.items(): + path = os.path.join(torch_ext_directory, op_name) + if os.path.exists(path): + print ("removing torch extension", op_name, "at", torch_ext_directory) + shutil.rmtree(path) \ No newline at end of file diff --git a/scripts/jit_module.py b/scripts/jit_module.py new file mode 100644 index 000000000..7c4ef48c4 --- /dev/null +++ b/scripts/jit_module.py @@ -0,0 +1,242 @@ +""" +Script to test JIT module for Apex. +""" + +import sys +import os + +class JitModule: + + def __init__(self): + self.op_builder_folder = "op_builder" + self.compatability_folder = "compatibility" + + def get_module_name(self, builder_file_name): + #open builder file and read the NAME attribute + with open(os.path.join(self.op_builder_folder, f"{builder_file_name}.py"), "r") as f: + contents = f.read() + for line in contents.split("\n"): + if "NAME = " in line: + return line.split("NAME = ")[1].strip()[1:-1] + return None + + + def create_loader_class_name(self, module_name): + parts = module_name.split("_") + new_name = "" + for part in parts: + new_name += part.capitalize() + return f"_{new_name}Module" + + + def create_builder_class_name(self, module_name): + parts = module_name.split("_") + new_name = "" + for part in parts: + new_name += part.capitalize() + return f"{new_name}Builder" + + + def create_build_var(self, module_name): + return f"APEX_BUILD_{module_name.upper()}" + + + def check_if_builder_module_exists(self, module_name): + if os.path.exists(os.path.join(self.op_builder_folder, f"{module_name}.py")): + return True + else: + return False + + def check_if_loader_module_exists(self, module_name): + if os.path.exists(os.path.join(self.compatability_folder, f"{module_name}.py")): + return True + else: + return False + + + def findBuilderClassName(self, builder_name): + #read file contents of op_builder/builder_name.py + with open(os.path.join(self.op_builder_folder, f"{builder_name}.py"), "r") as f: + contents = f.read() + #find the class name that inherits from CPUOpBuilder or CUDAOpBuilder + for line in contents.split("\n"): + if "class" in line: + return line.split("class")[1].split("(")[0].strip() + return None + + + def create_loader(self, builder_name): + module_name = self.get_module_name(builder_name) or builder_name + #check if a loader module in compatability folder + is_loader_exists = self.check_if_loader_module_exists(module_name) + if is_loader_exists: + print(f"Loader module {module_name} exists") + return + + #create loader class name to use in loader module + loader_class_name = self.create_loader_class_name(module_name) + + #find builder class name to use in the loader + builder_class_name = self.findBuilderClassName(builder_name) + + #create a loader module in compatability folder + with open(os.path.join(self.compatability_folder, f"{module_name}.py"), "w") as f: + f.write(f"import sys\n") + f.write(f"import importlib\n") + f.write(f"\n") + f.write(f"class {loader_class_name}:\n") + f.write(f" def __init__(self):\n") + f.write(f" self._loaded_module = None\n") + f.write(f" self._loading = False\n") + f.write(f"\n") + f.write(f" def _load_module(self):\n") + f.write(f" if self._loaded_module is None and not self._loading:\n") + f.write(f" self._loading = True\n") + f.write(f" try:\n") + f.write(f" apex_op_builder = importlib.import_module('apex.op_builder')\n") + f.write(f" builder = getattr(apex_op_builder, '{builder_class_name}')\n") + f.write(f" self._loaded_module = builder().load()\n") + f.write(f" except Exception as e:\n") + f.write(f" self._loading = False\n") + f.write(f" raise ImportError('Failed to load " + builder_name + " :' + str(e))\n") + f.write(f" finally:\n") + f.write(f" self._loading = False\n") + f.write(f" return self._loaded_module\n") + f.write(f"\n") + f.write(f" def __getattr__(self, name):\n") + f.write(f" if name.startswith('_'):\n") + f.write(f" raise AttributeError(f'module {module_name} has no attribute ' + name)\n") + f.write(f" return getattr(self._load_module(), name)\n") #dynamic loading of the module + f.write(f"\n") + f.write(f" def __dir__(self):\n") + f.write(f" try:\n") + f.write(f" return dir(self._load_module())\n") + f.write(f" except:\n") + f.write(f" return []\n") + f.write(f"\n") + f.write(f" def __repr__(self):\n") + f.write(f" return ''\n") + f.write(f"\n") + f.write(f"sys.modules[__name__] = {loader_class_name}()\n") + + print(f"Loader module {module_name} created") + + + def create_builder(self, module_name): + #Interactively prompt for builder details and create the builder module. + if_cuda_module = input("Is this a CUDA module? (Y/n) ").strip() or "y" + sources = input("Enter the sources (comma separated). Press Enter to skip ").strip() + + + if if_cuda_module == "y": + class_name = "CUDAOpBuilder" + include_flag = "APEX_BUILD_CUDA_OPS" + else: + class_name = "CPUOpBuilder" + include_flag = "APEX_BUILD_CPU_OPS" + + builder_class_name = self.create_builder_class_name(module_name) + build_var = self.create_build_var(module_name) + + if len(sources) == 0: + sources_list = [] + sources_list_string = "[]" + else: + sources_list = sources.split(",") + sources_list_string = "[" + ",".join(["'" + source.strip() + "'" for source in sources_list]) + "]" + print(f"sources_list_string: {sources_list_string}") + + include_paths = [] + for source in sources_list: + if "csrc" in source and "csrc" not in include_paths: + include_paths.append("csrc") + elif "contrib/csrc" in source and "contrib/csrc" not in include_paths: + include_paths.append("contrib/csrc") + include_paths_string = "[" + ",".join(["'" + path.strip() + "'" for path in include_paths]) + "]" + + with open(os.path.join(self.op_builder_folder, f"{module_name}.py"), "w") as f: + if if_cuda_module == "y": + f.write(f"from .builder import CUDAOpBuilder\n") + else: + f.write(f"from .builder import CPUOpBuilder\n") + f.write(f"\n") + f.write(f"class {builder_class_name}({class_name}):\n") + f.write(f" # Required. The environment variable to indicate prebuilding the module when installing apex e.g. APEX_BUILD_FUSED_BIAS_SWIGLU for fused_bias_swiglu\n") + f.write(f" BUILD_VAR = \"{build_var}\"\n") + f.write(f" # Required. Either APEX_BUILD_CUDA_OPS or APEX_BUILD_CPU_OPS to indicate whether the module will be built for gpu or cpu\n") + f.write(f" INCLUDE_FLAG = \"{include_flag}\"\n") + f.write(f" # Required. Name of module e.g. fused_bias_swiglu\n") + f.write(f" NAME = \"{module_name}\"\n") + f.write(f"\n") + f.write(f" def __init__(self):\n") + f.write(f" super().__init__(name=self.NAME)\n") + f.write(f"\n") + f.write(f" # Required to override. Return the namespace where the module will be installed.\n") + f.write(f" def absolute_name(self):\n") + f.write(f" return f'apex.{{self.NAME}}'\n") + f.write(f"\n") + f.write(f" # Required to override. Return the list of source files to be compiled\n") + f.write(f" # Please mention the full path of the source files\n") + f.write(f" # e.g. ['csrc/fused_dense_base.cpp', 'csrc/fused_dense_cuda.cu']\n") + f.write(f" def sources(self):\n") + f.write(f" return {sources_list_string}\n") + f.write(f"\n") + f.write(f" # Required to override. Return the list of include directories\n") + f.write(f" # Please mention the full path of the include directories\n") + f.write(f" # e.g. ['csrc', 'contrib/csrc']\n") + f.write(f" def include_paths(self):\n") + f.write(f" return {include_paths_string}\n") + f.write(f"\n") + f.write(f" # Optional. Return a list of extra compiler flags for the C++ compiler when building C++ sources (e.g. optimization level, preprocessor macros).\n") + f.write(f" def cxx_args(self):\n") + f.write(f" return super().cxx_args() + self.generator_args() + self.version_dependent_macros()\n") + f.write(f"\n") + f.write(f" # Optional. Return a list of extra compiler flags for nvcc when building CUDA sources (e.g. -O3, architecture flags, preprocessor macros).\n") + f.write(f" def nvcc_args(self):\n") + f.write(f" return super().nvcc_args() + ['-O3'] + self.version_dependent_macros()\n") + f.write(f"\n") + f.write(f" # Optional. Return True if this module can be installed and loaded given the environment (e.g. minimum torch version supported).\n") + f.write(f" def is_compatible(self, verbose=False):\n") + f.write(f" return True\n") + f.write(f"\n") + f.write(f" # Optional. Return list of libraries to compile against e.g. MIOpen.\n") + f.write(f" def libraries_args(self):\n") + f.write(f" return super().libraries_args()\n") + + print(f"Builder module {module_name} created") + + + def add_jit_module(self, builder_name): + #check if builder module exists + is_builder_exists = self.check_if_builder_module_exists(builder_name) + if not is_builder_exists: + self.create_builder(builder_name) + else: + print(f"Builder module {builder_name} already exists") + + #get module name from builder name + module_name = self.get_module_name(builder_name) + if module_name is None: + print(f"Module name for builder {builder_name} not found") + return + + #if the loader module does not exist, create it + if not self.check_if_loader_module_exists(builder_name): + self.create_loader(builder_name) + + +def main(): + jit_module = JitModule() + if len(sys.argv) > 1: + module_name = sys.argv[1] + else: + module_name = input("What is the name of the module? ").strip() + if not module_name: + print("No module name provided.") + sys.exit(1) + success = jit_module.add_jit_module(module_name) + if success: + print("JIT module added") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/setup.py b/setup.py index 2b96214ed..febfe94a9 100644 --- a/setup.py +++ b/setup.py @@ -1,11 +1,29 @@ -import torch -from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME -from setuptools import setup, find_packages -import subprocess - import sys import warnings import os +import glob +from packaging.version import parse, Version + +from setuptools import setup, find_packages, Distribution +import subprocess + +import torch +from torch.utils.cpp_extension import ( + BuildExtension, + CppExtension, + CUDAExtension, + CUDA_HOME, + ROCM_HOME, + load, + ) + +import typing +import shlex + +sys.path.insert(0, os.path.abspath(os.path.dirname(__file__))) + +from op_builder.all_ops import ALL_OPS +import shutil # ninja build does not work unless include_dirs are abs path this_dir = os.path.dirname(os.path.abspath(__file__)) @@ -18,59 +36,52 @@ def get_cuda_bare_metal_version(cuda_dir): release = output[release_idx].split(".") bare_metal_major = release[0] bare_metal_minor = release[1][0] - return raw_output, bare_metal_major, bare_metal_minor - -def check_cuda_torch_binary_vs_bare_metal(cuda_dir): - raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir) - torch_binary_major = torch.version.cuda.split(".")[0] - torch_binary_minor = torch.version.cuda.split(".")[1] - - print("\nCompiling cuda extensions with") - print(raw_output + "from " + cuda_dir + "/bin\n") - - if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor): - raise RuntimeError( - "Cuda extensions are being compiled with a version of Cuda that does " - "not match the version used to compile Pytorch binaries. " - "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda) - + "In some cases, a minor-version mismatch will not cause later errors: " - "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. " - "You can try commenting out this check (at your own risk)." - ) +def get_rocm_bare_metal_version(rocm_dir): + raw_output = subprocess.check_output([rocm_dir + "/bin/hipcc", "--version"], universal_newlines=True) + output = raw_output.split() + release_idx = output.index("version:") + 1 + release = output[release_idx].split(".") + bare_metal_major = release[0] + bare_metal_minor = release[1][0] + return raw_output, bare_metal_major, bare_metal_minor -def raise_if_cuda_home_none(global_option: str) -> None: - if CUDA_HOME is not None: - return - raise RuntimeError( - f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " - "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " - "only images whose names contain 'devel' will provide nvcc." - ) +def get_apex_version(): + cwd = os.path.dirname(os.path.abspath(__file__)) + apex_version_file = os.path.join(cwd, "version.txt") + if os.path.exists(apex_version_file): + with open(apex_version_file) as f: + apex_version = f.read().strip() + else: + raise RuntimeError("version.txt file is missing") + if os.getenv("BUILD_VERSION"): + apex_version = os.getenv("BUILD_VERSION") + if os.getenv("DESIRED_CUDA"): + apex_version += "+" + os.getenv("DESIRED_CUDA") + if os.getenv("APEX_COMMIT"): + apex_version += ".git"+os.getenv("APEX_COMMIT")[:8] + return apex_version -def append_nvcc_threads(nvcc_extra_args): - _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME) - if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2: - return nvcc_extra_args + ["--threads", "4"] - return nvcc_extra_args +print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) +TORCH_MAJOR = int(torch.__version__.split('.')[0]) +TORCH_MINOR = int(torch.__version__.split('.')[1]) +print("\n\ntorch.version.hip = {}\n\n".format(torch.version.hip)) +ROCM_MAJOR = int(torch.version.hip.split('.')[0]) +ROCM_MINOR = int(torch.version.hip.split('.')[1]) -def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int) -> bool: - cudnn_available = torch.backends.cudnn.is_available() - cudnn_version = torch.backends.cudnn.version() if cudnn_available else None - if not (cudnn_available and (cudnn_version >= required_cudnn_version)): - warnings.warn( - f"Skip `{global_option}` as it requires cuDNN {required_cudnn_version} or later, " - f"but {'cuDNN is not available' if not cudnn_available else cudnn_version}" - ) - return False - return True +def check_if_rocm_pytorch(): + is_rocm_pytorch = False + if TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 5): + is_rocm_pytorch = True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False + return is_rocm_pytorch +IS_ROCM_PYTORCH = check_if_rocm_pytorch() -if not torch.cuda.is_available(): +if not torch.cuda.is_available() and not IS_ROCM_PYTORCH: # https://github.com/NVIDIA/apex/issues/486 # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(), # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command). @@ -91,612 +102,234 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6" else: os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5" - -print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) -TORCH_MAJOR = int(torch.__version__.split(".")[0]) -TORCH_MINOR = int(torch.__version__.split(".")[1]) +elif not torch.cuda.is_available() and IS_ROCM_PYTORCH: + print('\nWarning: Torch did not find available GPUs on this system.\n', + 'If your intention is to cross-compile, this is not an error.\n' + 'By default, Apex will cross-compile for the same gfx targets\n' + 'used by default in ROCm PyTorch\n') if TORCH_MAJOR == 0 and TORCH_MINOR < 4: raise RuntimeError( "Apex requires Pytorch 0.4 or newer.\nThe latest stable release can be obtained from https://pytorch.org/" ) -cmdclass = {} -ext_modules = [] - +# cmdclass = {} extras = {} -if "--cpp_ext" in sys.argv or "--cuda_ext" in sys.argv: - if TORCH_MAJOR == 0: - raise RuntimeError( - "--cpp_ext requires Pytorch 1.0 or later, " "found torch.__version__ = {}".format(torch.__version__) - ) +if not IS_ROCM_PYTORCH: + _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) +else: + _, bare_metal_version, bare_metal_minor = get_rocm_bare_metal_version(ROCM_HOME) -if "--cpp_ext" in sys.argv: - sys.argv.remove("--cpp_ext") - ext_modules.append(CppExtension("apex_C", ["csrc/flatten_unflatten.cpp"])) - - -# Set up macros for forward/backward compatibility hack around -# https://github.com/pytorch/pytorch/commit/4404762d7dd955383acee92e6f06b48144a0742e -# and -# https://github.com/NVIDIA/apex/issues/456 -# https://github.com/pytorch/pytorch/commit/eb7b39e02f7d75c26d8a795ea8c7fd911334da7e#diff-4632522f237f1e4e728cb824300403ac -version_ge_1_1 = [] -if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 0): - version_ge_1_1 = ["-DVERSION_GE_1_1"] -version_ge_1_3 = [] -if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2): - version_ge_1_3 = ["-DVERSION_GE_1_3"] -version_ge_1_5 = [] -if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4): - version_ge_1_5 = ["-DVERSION_GE_1_5"] -version_dependent_macros = version_ge_1_1 + version_ge_1_3 + version_ge_1_5 - -if "--distributed_adam" in sys.argv: - sys.argv.remove("--distributed_adam") - raise_if_cuda_home_none("--distributed_adam") - ext_modules.append( - CUDAExtension( - name="distributed_adam_cuda", - sources=[ - "apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp", - "apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu", - ], - include_dirs=[os.path.join(this_dir, "csrc")], - extra_compile_args={ - "cxx": ["-O3"] + version_dependent_macros, - "nvcc": append_nvcc_threads(["-O3", "--use_fast_math"] + version_dependent_macros), - }, - ) - ) -if "--distributed_lamb" in sys.argv: - sys.argv.remove("--distributed_lamb") - raise_if_cuda_home_none("--distributed_lamb") - ext_modules.append( - CUDAExtension( - name="distributed_lamb_cuda", - sources=[ - "apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp", - "apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb_kernel.cu", - ], - include_dirs=[os.path.join(this_dir, "csrc")], - extra_compile_args={ - "cxx": ["-O3"] + version_dependent_macros, - "nvcc": append_nvcc_threads(["-O3", "--use_fast_math"] + version_dependent_macros), - }, - ) - ) +# ***************************** Op builder ********************** -if "--cuda_ext" in sys.argv: - sys.argv.remove("--cuda_ext") - raise_if_cuda_home_none("--cuda_ext") - check_cuda_torch_binary_vs_bare_metal(CUDA_HOME) - - ext_modules.append( - CUDAExtension( - name="amp_C", - sources=[ - "csrc/amp_C_frontend.cpp", - "csrc/multi_tensor_sgd_kernel.cu", - "csrc/multi_tensor_scale_kernel.cu", - "csrc/multi_tensor_axpby_kernel.cu", - "csrc/multi_tensor_l2norm_kernel.cu", - "csrc/multi_tensor_l2norm_kernel_mp.cu", - "csrc/multi_tensor_l2norm_scale_kernel.cu", - "csrc/multi_tensor_lamb_stage_1.cu", - "csrc/multi_tensor_lamb_stage_2.cu", - "csrc/multi_tensor_adam.cu", - "csrc/multi_tensor_adagrad.cu", - "csrc/multi_tensor_novograd.cu", - "csrc/multi_tensor_lamb.cu", - "csrc/multi_tensor_lamb_mp.cu", - ], - extra_compile_args={ - "cxx": ["-O3"] + version_dependent_macros, - "nvcc": append_nvcc_threads( - [ - "-lineinfo", - "-O3", - # '--resource-usage', - "--use_fast_math", - ] - + version_dependent_macros - ), - }, - ) - ) - ext_modules.append( - CUDAExtension( - name="syncbn", - sources=["csrc/syncbn.cpp", "csrc/welford.cu"], - extra_compile_args={ - "cxx": ["-O3"] + version_dependent_macros, - "nvcc": append_nvcc_threads(["-O3"] + version_dependent_macros), - }, - ) - ) +def get_env_if_set(key, default: typing.Any = ""): + """ + Returns an environment variable if it is set and not "", + otherwise returns a default value. In contrast, the fallback + parameter of os.environ.get() is skipped if the variable is set to "". + """ + return os.environ.get(key, None) or default - ext_modules.append( - CUDAExtension( - name="fused_layer_norm_cuda", - sources=["csrc/layer_norm_cuda.cpp", "csrc/layer_norm_cuda_kernel.cu"], - extra_compile_args={ - "cxx": ["-O3"] + version_dependent_macros, - "nvcc": append_nvcc_threads(["-maxrregcount=50", "-O3", "--use_fast_math"] + version_dependent_macros), - }, - ) - ) +def command_exists(cmd): + if sys.platform == "win32": + safe_cmd = shlex.split(f'{cmd}') + result = subprocess.Popen(safe_cmd, stdout=subprocess.PIPE) + return result.wait() == 1 + else: + safe_cmd = shlex.split(f"bash -c type {cmd}") + result = subprocess.Popen(safe_cmd, stdout=subprocess.PIPE) + return result.wait() == 0 + +BUILD_OP_DEFAULT = 0 +BUILD_CPP_OPS = int(get_env_if_set('APEX_BUILD_CPP_OPS', BUILD_OP_DEFAULT)) +BUILD_CUDA_OPS = int(get_env_if_set('APEX_BUILD_CUDA_OPS', BUILD_OP_DEFAULT)) +build_flags = { + "APEX_BUILD_CPP_OPS" : BUILD_CPP_OPS, + "APEX_BUILD_CUDA_OPS" : BUILD_CUDA_OPS, + } + +if BUILD_CPP_OPS or BUILD_CUDA_OPS: + if TORCH_MAJOR == 0: + raise RuntimeError("--cpp_ext requires Pytorch 1.0 or later, " + "found torch.__version__ = {}".format(torch.__version__) + ) - ext_modules.append( - CUDAExtension( - name="mlp_cuda", - sources=["csrc/mlp.cpp", "csrc/mlp_cuda.cu"], - extra_compile_args={ - "cxx": ["-O3"] + version_dependent_macros, - "nvcc": append_nvcc_threads(["-O3"] + version_dependent_macros), - }, - ) - ) - ext_modules.append( - CUDAExtension( - name="fused_dense_cuda", - sources=["csrc/fused_dense.cpp", "csrc/fused_dense_cuda.cu"], - extra_compile_args={ - "cxx": ["-O3"] + version_dependent_macros, - "nvcc": append_nvcc_threads(["-O3"] + version_dependent_macros), - }, - ) - ) +def is_env_set(key): + """ + Checks if an environment variable is set and not "". + """ + return bool(os.environ.get(key, None)) - ext_modules.append( - CUDAExtension( - name="scaled_upper_triang_masked_softmax_cuda", - sources=[ - "csrc/megatron/scaled_upper_triang_masked_softmax.cpp", - "csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu", - ], - include_dirs=[os.path.join(this_dir, "csrc")], - extra_compile_args={ - "cxx": ["-O3"] + version_dependent_macros, - "nvcc": append_nvcc_threads( - [ - "-O3", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - ] - + version_dependent_macros - ), - }, - ) - ) +def get_op_build_env_name(op_name): + assert hasattr(ALL_OPS[op_name], 'BUILD_VAR'), \ + f"{op_name} is missing BUILD_VAR field" + return ALL_OPS[op_name].BUILD_VAR - ext_modules.append( - CUDAExtension( - name="scaled_masked_softmax_cuda", - sources=["csrc/megatron/scaled_masked_softmax.cpp", "csrc/megatron/scaled_masked_softmax_cuda.cu"], - include_dirs=[os.path.join(this_dir, "csrc")], - extra_compile_args={ - "cxx": ["-O3"] + version_dependent_macros, - "nvcc": append_nvcc_threads( - [ - "-O3", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - ] - + version_dependent_macros - ), - }, - ) - ) - # Check, if CUDA11 is installed for compute capability 8.0 - _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME) - if int(bare_metal_major) >= 11: - cc_flag = [] - cc_flag.append("-gencode") - cc_flag.append("arch=compute_80,code=sm_80") - if int(bare_metal_minor) > 0: - cc_flag.append("-gencode") - cc_flag.append("arch=compute_86,code=sm_86") - ext_modules.append( - CUDAExtension( - name="fused_weight_gradient_mlp_cuda", - include_dirs=[os.path.join(this_dir, "csrc")], - sources=[ - "csrc/megatron/fused_weight_gradient_dense.cpp", - "csrc/megatron/fused_weight_gradient_dense_cuda.cu", - "csrc/megatron/fused_weight_gradient_dense_16bit_prec_cuda.cu", - ], - extra_compile_args={ - "cxx": ["-O3"] + version_dependent_macros, - "nvcc": append_nvcc_threads( - [ - "-O3", - "-gencode", - "arch=compute_70,code=sm_70", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "--use_fast_math", - ] - + version_dependent_macros - + cc_flag - ), - }, - ) - ) - -if "--permutation_search" in sys.argv: - sys.argv.remove("--permutation_search") - - if CUDA_HOME is None: - raise RuntimeError("--permutation_search was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") - else: - cc_flag = ['-Xcompiler', '-fPIC', '-shared'] - ext_modules.append( - CUDAExtension(name='permutation_search_cuda', - sources=['apex/contrib/sparsity/permutation_search_kernels/CUDA_kernels/permutation_search_kernels.cu'], - include_dirs=[os.path.join(this_dir, 'apex', 'contrib', 'sparsity', 'permutation_search_kernels', 'CUDA_kernels')], - extra_compile_args={'cxx': ['-O3'] + version_dependent_macros, - 'nvcc':['-O3'] + version_dependent_macros + cc_flag})) - -if "--bnp" in sys.argv: - sys.argv.remove("--bnp") - raise_if_cuda_home_none("--bnp") - ext_modules.append( - CUDAExtension( - name="bnp", - sources=[ - "apex/contrib/csrc/groupbn/batch_norm.cu", - "apex/contrib/csrc/groupbn/ipc.cu", - "apex/contrib/csrc/groupbn/interface.cpp", - "apex/contrib/csrc/groupbn/batch_norm_add_relu.cu", - ], - include_dirs=[os.path.join(this_dir, "csrc")], - extra_compile_args={ - "cxx": [] + version_dependent_macros, - "nvcc": append_nvcc_threads( - [ - "-DCUDA_HAS_FP16=1", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - ] - + version_dependent_macros - ), - }, - ) - ) +def op_build_enabled(op_name): + env_var = get_op_build_env_name(op_name) + return int(get_env_if_set(env_var, BUILD_OP_DEFAULT)) -if "--xentropy" in sys.argv: - sys.argv.remove("--xentropy") - raise_if_cuda_home_none("--xentropy") - ext_modules.append( - CUDAExtension( - name="xentropy_cuda", - sources=["apex/contrib/csrc/xentropy/interface.cpp", "apex/contrib/csrc/xentropy/xentropy_kernel.cu"], - include_dirs=[os.path.join(this_dir, "csrc")], - extra_compile_args={ - "cxx": ["-O3"] + version_dependent_macros, - "nvcc": append_nvcc_threads(["-O3"] + version_dependent_macros), - }, - ) - ) +def is_op_build_included(op_name): + #check if operation has BUILD_FLAG defined + assert hasattr(ALL_OPS[op_name], 'INCLUDE_FLAG'), \ + f"{op_name} is missing INCLUDE_FLAG field" + include_flag = ALL_OPS[op_name].INCLUDE_FLAG + return get_env_if_set(include_flag, False) -if "--focal_loss" in sys.argv: - sys.argv.remove("--focal_loss") - raise_if_cuda_home_none("--focal_loss") - ext_modules.append( - CUDAExtension( - name='focal_loss_cuda', - sources=[ - 'apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp', - 'apex/contrib/csrc/focal_loss/focal_loss_cuda_kernel.cu', - ], - include_dirs=[os.path.join(this_dir, 'csrc')], - extra_compile_args={ - 'cxx': ['-O3'] + version_dependent_macros, - 'nvcc':['-O3', '--use_fast_math', '--ftz=false'] + version_dependent_macros, - }, - ) - ) +ext_modules = [] +install_ops = dict.fromkeys(ALL_OPS.keys(), False) + +for op_name, builder in ALL_OPS.items(): + op_compatible = builder.is_compatible() + build_enabled = op_build_enabled(op_name) or is_op_build_included(op_name) + + # If op is requested but not available, throw an error. + if build_enabled and not op_compatible: + env_var = get_op_build_env_name(op_name) + builder.warning(f"Skip pre-compile of incompatible {op_name}; One can disable {op_name} with {env_var}=0") + continue + + # If op is compatible but install is not build enabled (JIT mode). + if IS_ROCM_PYTORCH and op_compatible and not build_enabled: + builder.hipify_extension() + + # If op build enabled, add builder to extensions. + # Also check if corresponding flags are checked + if build_enabled and op_compatible: + install_ops[op_name] = True + ext_modules.append(builder.builder()) + +print(f'Install Ops={install_ops}') + +# Write out version/git info. +git_hash_cmd = shlex.split("bash -c \"git rev-parse --short HEAD\"") +git_branch_cmd = shlex.split("bash -c \"git rev-parse --abbrev-ref HEAD\"") +if command_exists('git') and not is_env_set('APEX_BUILD_STRING'): + try: + result = subprocess.check_output(git_hash_cmd) + git_hash = result.decode('utf-8').strip() + result = subprocess.check_output(git_branch_cmd) + git_branch = result.decode('utf-8').strip() + except subprocess.CalledProcessError: + git_hash = "unknown" + git_branch = "unknown" +else: + git_hash = "unknown" + git_branch = "unknown" + +# Parse the apex version string from version.txt. +version_str = get_apex_version() +version_str += f'+{git_hash}' + +torch_version = ".".join([str(TORCH_MAJOR), str(TORCH_MINOR)]) +bf16_support = False +# Set cuda_version to 0.0 if cpu-only. +cuda_version = "0.0" +nccl_version = "0.0" +# Set hip_version to 0.0 if cpu-only. +hip_version = "0.0" +if torch.version.cuda is not None: + cuda_version = ".".join(torch.version.cuda.split('.')[:2]) + if sys.platform != "win32": + if isinstance(torch.cuda.nccl.version(), int): + # This will break if minor version > 9. + nccl_version = ".".join(str(torch.cuda.nccl.version())[:2]) + else: + nccl_version = ".".join(map(str, torch.cuda.nccl.version()[:2])) + if hasattr(torch.cuda, 'is_bf16_supported') and torch.cuda.is_available(): + bf16_support = torch.cuda.is_bf16_supported() +if hasattr(torch.version, 'hip') and torch.version.hip is not None: + hip_version = ".".join(torch.version.hip.split('.')[:2]) +torch_info = { + "version": torch_version, + "bf16_support": bf16_support, + "cuda_version": cuda_version, + "nccl_version": nccl_version, + "hip_version": hip_version +} + +print(f"version={version_str}, git_hash={git_hash}, git_branch={git_branch}") +with open('apex/git_version_info_installed.py', 'w') as fd: + fd.write(f"version='{version_str}'\n") + fd.write(f"git_hash='{git_hash}'\n") + fd.write(f"git_branch='{git_branch}'\n") + fd.write(f"installed_ops={install_ops}\n") + fd.write(f"build_flags={build_flags}\n") + fd.write(f"torch_info={torch_info}\n") -if "--deprecated_fused_adam" in sys.argv: - sys.argv.remove("--deprecated_fused_adam") - raise_if_cuda_home_none("--deprecated_fused_adam") - ext_modules.append( - CUDAExtension( - name="fused_adam_cuda", - sources=[ - "apex/contrib/csrc/optimizers/fused_adam_cuda.cpp", - "apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu", - ], - include_dirs=[os.path.join(this_dir, "csrc")], - extra_compile_args={ - "cxx": ["-O3"] + version_dependent_macros, - "nvcc": append_nvcc_threads(["-O3", "--use_fast_math"] + version_dependent_macros), - }, - ) - ) +if "--cpp_ext" in sys.argv: + sys.argv.remove("--cpp_ext") -if "--deprecated_fused_lamb" in sys.argv: - sys.argv.remove("--deprecated_fused_lamb") - raise_if_cuda_home_none("--deprecated_fused_lamb") - ext_modules.append( - CUDAExtension( - name="fused_lamb_cuda", - sources=[ - "apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp", - "apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu", - "csrc/multi_tensor_l2norm_kernel.cu", - ], - include_dirs=[os.path.join(this_dir, "csrc")], - extra_compile_args={ - "cxx": ["-O3"] + version_dependent_macros, - "nvcc": append_nvcc_threads(["-O3", "--use_fast_math"] + version_dependent_macros), - }, - ) - ) +if "--cuda_ext" in sys.argv: + sys.argv.remove("--cuda_ext") -# Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h -# See https://github.com/pytorch/pytorch/pull/70650 -generator_flag = [] -torch_dir = torch.__path__[0] -if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")): - generator_flag = ["-DOLD_GENERATOR_PATH"] - -if "--fast_layer_norm" in sys.argv: - sys.argv.remove("--fast_layer_norm") - raise_if_cuda_home_none("--fast_layer_norm") - # Check, if CUDA11 is installed for compute capability 8.0 - cc_flag = [] - _, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME) - if int(bare_metal_major) >= 11: - cc_flag.append("-gencode") - cc_flag.append("arch=compute_80,code=sm_80") - - ext_modules.append( - CUDAExtension( - name="fast_layer_norm", - sources=[ - "apex/contrib/csrc/layer_norm/ln_api.cpp", - "apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu", - "apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu", - ], - extra_compile_args={ - "cxx": ["-O3"] + version_dependent_macros + generator_flag, - "nvcc": append_nvcc_threads( - [ - "-O3", - "-gencode", - "arch=compute_70,code=sm_70", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT16_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT162_OPERATORS__", - "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", - "-I./apex/contrib/csrc/layer_norm/", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "--use_fast_math", - ] - + version_dependent_macros - + generator_flag - + cc_flag - ), - }, - include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/layer_norm")], - ) - ) +with open('requirements.txt') as f: + required = f.read().splitlines() -if "--fmha" in sys.argv: - sys.argv.remove("--fmha") - raise_if_cuda_home_none("--fmha") - # Check, if CUDA11 is installed for compute capability 8.0 - cc_flag = [] - _, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME) - if int(bare_metal_major) < 11: - raise RuntimeError("--fmha only supported on SM80") - cc_flag.append("-gencode") - cc_flag.append("arch=compute_80,code=sm_80") - - ext_modules.append( - CUDAExtension( - name="fmhalib", - sources=[ - "apex/contrib/csrc/fmha/fmha_api.cpp", - "apex/contrib/csrc/fmha/src/fmha_noloop_reduce.cu", - "apex/contrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu", - "apex/contrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu", - "apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu", - "apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu", - "apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_128_64_kernel.sm80.cu", - "apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_256_64_kernel.sm80.cu", - "apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_384_64_kernel.sm80.cu", - "apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu", - ], - extra_compile_args={ - "cxx": ["-O3"] + version_dependent_macros + generator_flag, - "nvcc": append_nvcc_threads( - [ - "-O3", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "--use_fast_math", - ] - + version_dependent_macros - + generator_flag - + cc_flag - ), - }, - include_dirs=[ - os.path.join(this_dir, "apex/contrib/csrc"), - os.path.join(this_dir, "apex/contrib/csrc/fmha/src"), - ], - ) - ) +# Find python files in compatibility folder +compatibility_dir = os.path.join(this_dir, 'compatibility') +py_modules = [] +if os.path.exists(compatibility_dir): + for file in os.listdir(compatibility_dir): + if file.endswith('.py') and file != '__init__.py': + module_name = f"{file[:-3]}" + py_modules.append(module_name) -if "--fast_multihead_attn" in sys.argv: - sys.argv.remove("--fast_multihead_attn") - raise_if_cuda_home_none("--fast_multihead_attn") - - # Check, if CUDA11 is installed for compute capability 8.0 - cc_flag = [] - _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME) - if int(bare_metal_major) >= 11: - cc_flag.append("-gencode") - cc_flag.append("arch=compute_80,code=sm_80") - if int(bare_metal_minor) > 0: - cc_flag.append("-gencode") - cc_flag.append("arch=compute_86,code=sm_86") - - subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/multihead_attn/cutlass"]) - ext_modules.append( - CUDAExtension( - name="fast_multihead_attn", - sources=[ - "apex/contrib/csrc/multihead_attn/multihead_attn_frontend.cpp", - "apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu", - "apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu", - "apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu", - "apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu", - "apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu", - "apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu", - "apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu", - "apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu", - ], - extra_compile_args={ - "cxx": ["-O3"] + version_dependent_macros + generator_flag, - "nvcc": append_nvcc_threads( - [ - "-O3", - "-gencode", - "arch=compute_70,code=sm_70", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "--use_fast_math", - ] - + version_dependent_macros - + generator_flag - + cc_flag - ), - }, - include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass")], - ) - ) + #copy outside temporarily + src_file = os.path.join(compatibility_dir, file) + dst_file = os.path.join(this_dir, file) + shutil.copy2(src_file, dst_file) +else: + print("Warning: compatibility folder not found") -if "--transducer" in sys.argv: - sys.argv.remove("--transducer") - raise_if_cuda_home_none("--transducer") - ext_modules.append( - CUDAExtension( - name="transducer_joint_cuda", - sources=[ - "apex/contrib/csrc/transducer/transducer_joint.cpp", - "apex/contrib/csrc/transducer/transducer_joint_kernel.cu", - ], - extra_compile_args={ - "cxx": ["-O3"] + version_dependent_macros + generator_flag, - "nvcc": append_nvcc_threads(["-O3"] + version_dependent_macros + generator_flag), - }, - include_dirs=[os.path.join(this_dir, "csrc"), os.path.join(this_dir, "apex/contrib/csrc/multihead_attn")], - ) - ) - ext_modules.append( - CUDAExtension( - name="transducer_loss_cuda", - sources=[ - "apex/contrib/csrc/transducer/transducer_loss.cpp", - "apex/contrib/csrc/transducer/transducer_loss_kernel.cu", - ], - include_dirs=[os.path.join(this_dir, "csrc")], - extra_compile_args={ - "cxx": ["-O3"] + version_dependent_macros, - "nvcc": append_nvcc_threads(["-O3"] + version_dependent_macros), - }, - ) - ) +class BinaryDistribution(Distribution): + """Force wheel to be platform-specific even without ext_modules.""" + def has_ext_modules(self): + return True -# note (mkozuki): Now `--fast_bottleneck` option (i.e. apex/contrib/bottleneck) depends on `--peer_memory` and `--nccl_p2p`. -if "--fast_bottleneck" in sys.argv: - sys.argv.remove("--fast_bottleneck") - raise_if_cuda_home_none("--fast_bottleneck") - if check_cudnn_version_and_warn("--fast_bottleneck", 8400): - subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/cudnn-frontend/"]) - ext_modules.append( - CUDAExtension( - name="fast_bottleneck", - sources=["apex/contrib/csrc/bottleneck/bottleneck.cpp"], - include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/cudnn-frontend/include")], - extra_compile_args={"cxx": ["-O3"] + version_dependent_macros + generator_flag}, - ) - ) - -if "--peer_memory" in sys.argv: - sys.argv.remove("--peer_memory") - raise_if_cuda_home_none("--peer_memory") - ext_modules.append( - CUDAExtension( - name="peer_memory_cuda", - sources=[ - "apex/contrib/csrc/peer_memory/peer_memory_cuda.cu", - "apex/contrib/csrc/peer_memory/peer_memory.cpp", - ], - extra_compile_args={"cxx": ["-O3"] + version_dependent_macros + generator_flag}, - ) - ) -if "--nccl_p2p" in sys.argv: - sys.argv.remove("--nccl_p2p") - raise_if_cuda_home_none("--nccl_p2p") - ext_modules.append( - CUDAExtension( - name="nccl_p2p_cuda", - sources=[ - "apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu", - "apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp", - ], - extra_compile_args={"cxx": ["-O3"] + version_dependent_macros + generator_flag}, - ) - ) +# Resolve symlinks for packaging - auto-detect symlinks in apex folder +def resolve_symlinks_in_dir(base_dir): + """Find and resolve all symlink directories inside a directory.""" + symbolic_link_folders = [] + for entry in os.listdir(base_dir): + entry_path = os.path.join(base_dir, entry) + if os.path.islink(entry_path) and os.path.isdir(os.path.realpath(entry_path)): + target = os.path.realpath(entry_path) + symbolic_link_folders.append([entry_path, target]) -if "--fused_conv_bias_relu" in sys.argv: - sys.argv.remove("--fused_conv_bias_relu") - raise_if_cuda_home_none("--fused_conv_bias_relu") - if check_cudnn_version_and_warn("--fused_conv_bias_relu", 8400): - subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/cudnn-frontend/"]) - ext_modules.append( - CUDAExtension( - name="fused_conv_bias_relu", - sources=["apex/contrib/csrc/conv_bias_relu/conv_bias_relu.cpp"], - include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/cudnn-frontend/include")], - extra_compile_args={"cxx": ["-O3"] + version_dependent_macros + generator_flag}, - ) - ) + print(f"Symbolic link folders: {symbolic_link_folders}") + for entry_path, target in symbolic_link_folders: + print(f"Resolving symlink {entry_path} -> {target}") + os.unlink(entry_path) + shutil.copytree(target, entry_path) + +resolve_symlinks_in_dir(os.path.join(this_dir, 'apex')) setup( name="apex", - version="0.1", + version=get_apex_version(), packages=find_packages( - exclude=("build", "csrc", "include", "tests", "dist", "docs", "tests", "examples", "apex.egg-info",) + exclude=("build", "include", "tests", "dist", "docs", "tests", "examples", "apex.egg-info", "op_builder", "compatibility") ), description="PyTorch Extensions written by NVIDIA", ext_modules=ext_modules, - cmdclass={"build_ext": BuildExtension} if ext_modules else {}, + cmdclass={'build_ext': BuildExtension} if ext_modules else {}, extras_require=extras, + install_requires=required, + include_package_data=True, + py_modules=py_modules, + distclass=BinaryDistribution ) + +#delete the temporarily copied compatibility files +for py_module in py_modules: + path = dst_file = os.path.join(this_dir, py_module + ".py") + if os.path.exists(path): + os.remove(path) \ No newline at end of file diff --git a/tests/L0/run_amp/test_add_param_group.py b/tests/L0/run_amp/test_add_param_group.py index d3e90c433..62f775349 100644 --- a/tests/L0/run_amp/test_add_param_group.py +++ b/tests/L0/run_amp/test_add_param_group.py @@ -11,14 +11,14 @@ from torch.nn import Parameter from utils import common_init, HALF, FLOAT,\ - ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT + ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT, common_reset class MyModel(torch.nn.Module): - def __init__(self, unique): + def __init__(self, unique, dtype=torch.float16): super(MyModel, self).__init__() self.weight0 = Parameter(unique + torch.arange(2, device='cuda', dtype=torch.float32)) - self.weight1 = Parameter(1. + unique + torch.arange(2, device='cuda', dtype=torch.float16)) + self.weight1 = Parameter(1. + unique + torch.arange(2, device='cuda', dtype=dtype)) @staticmethod def ops(input, weight0, weight1): @@ -37,7 +37,7 @@ def setUp(self): common_init(self) def tearDown(self): - pass + common_reset(self) def zero_grad(self, models, optimizer, how_to_zero): if how_to_zero == "none": @@ -51,11 +51,15 @@ def zero_grad(self, models, optimizer, how_to_zero): optimizer.zero_grad() def test_add_param_group(self): - for opt_level in ("O0", "O1", "O2", "O3"): + for opt_level in ("O0", "O1", "O2", "O3", "O4", "O5"): for zero_before_add in (True, False): for try_accumulation in (True, False): - model0 = MyModel(1) - model1 = MyModel(2) + if opt_level in {"O4", "O5"}: + model0 = MyModel(1, torch.bfloat16) + model1 = MyModel(2, torch.bfloat16) + else: + model0 = MyModel(1) + model1 = MyModel(2) optimizer = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}], momentum=0.125) @@ -89,8 +93,12 @@ def test_add_param_group(self): [param.data.clone() for param in model1.parameters()] for how_to_zero in "none", "model", "optimizer": - model0 = MyModel(1) - model1 = MyModel(2) + if opt_level in {"O4", "O5"}: + model0 = MyModel(1, torch.bfloat16) + model1 = MyModel(2, torch.bfloat16) + else: + model0 = MyModel(1) + model1 = MyModel(2) optimizer = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}], momentum=0.125) @@ -139,10 +147,15 @@ def test_add_param_group(self): [param.data.clone() for param in model1.parameters()] for reference, final in zip(reference_params, final_params): + # TODO: remove the conversion once allclose supports bfloat16 type. + if final.dtype == torch.bfloat16: + final = final.float() self.assertTrue(torch.allclose(reference.to(final.dtype), final), "opt_level = {}, how_to_zero = {}, zero_before_add = {}".format( opt_level, how_to_zero, zero_before_add)) + if opt_level != "O0": + _amp_state.handle._deactivate() if __name__ == '__main__': unittest.main() diff --git a/tests/L0/run_amp/test_basic_casts.py b/tests/L0/run_amp/test_basic_casts.py index 5d4d81d1a..7ec254e42 100644 --- a/tests/L0/run_amp/test_basic_casts.py +++ b/tests/L0/run_amp/test_basic_casts.py @@ -9,7 +9,9 @@ import torch.nn.functional as F from utils import common_init, HALF, FLOAT,\ - ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT + ALWAYS_HALF, ALWAYS_BFLOAT16, ALWAYS_FLOAT, MATCH_INPUT, common_reset + +from apex.testing.common_utils import skipIfRocm def run_layer_test(test_case, fns, expected, input_shape, test_backward=True): for fn, typ in it.product(fns, expected.keys()): @@ -20,124 +22,242 @@ def run_layer_test(test_case, fns, expected, input_shape, test_backward=True): y.float().sum().backward() test_case.assertEqual(x.grad.type(), MATCH_INPUT[typ]) -class TestBasicCasts(unittest.TestCase): - def setUp(self): - self.handle = amp.init(enabled=True) - common_init(self) - - def tearDown(self): - self.handle._deactivate() - - def test_linear_is_half(self): +class _TestBasicCasts(unittest.TestCase): + def _test_linear(self, expected): m = nn.Linear(self.h, self.h) f = ft.partial(F.linear, weight=m.weight, bias=m.bias) - run_layer_test(self, [m, f], ALWAYS_HALF, (self.b, self.h)) + run_layer_test(self, [m, f], expected, (self.b, self.h)) - def test_conv2d_is_half(self): + def _test_conv2d(self, expected): m = nn.Conv2d(self.c, self.c, self.k) f = ft.partial(F.conv2d, weight=m.weight, bias=m.bias) - run_layer_test(self, [m, f], ALWAYS_HALF, (self.b, self.c, self.h, self.h)) + run_layer_test(self, [m, f], expected, (self.b, self.c, self.h, self.h)) - def test_softmax_is_float(self): + def _test_softmax(self, expected): m = nn.Softmax(dim=1) f = ft.partial(F.softmax, dim=1) - run_layer_test(self, [m, f], ALWAYS_FLOAT, (self.b, self.h)) + run_layer_test(self, [m, f], expected, (self.b, self.h)) - def test_group_norm_is_float(self): + def _test_group_norm(self, expected): m = nn.GroupNorm(num_groups=4, num_channels=self.c) - run_layer_test(self, [m], ALWAYS_FLOAT, (self.b, self.c, self.h, self.h)) + run_layer_test(self, [m], expected, (self.b, self.c, self.h, self.h)) - def test_mse_loss_is_float(self): + def _test_mse_loss(self, expected): shape = (self.b, self.h) target = torch.randn(shape) mod = nn.MSELoss() m = lambda x: mod(x, target) f = ft.partial(F.mse_loss, target=target) - run_layer_test(self, [m], ALWAYS_FLOAT, shape) + run_layer_test(self, [m], expected, shape) - def test_relu_is_match(self): - run_layer_test(self, [nn.ReLU(), F.relu], MATCH_INPUT, (self.b, self.h)) + def _test_relu(self, expected): + run_layer_test(self, [nn.ReLU(), F.relu], expected, (self.b, self.h)) - def test_batch_norm_is_match(self): + def _test_batch_norm(self, expected): m = nn.BatchNorm2d(num_features=self.c) f = ft.partial(F.batch_norm, running_mean=m.running_mean, running_var=m.running_var, weight=m.weight, bias=m.bias, training=True) - run_layer_test(self, [m], MATCH_INPUT, (self.b, self.c, self.h, self.h)) + run_layer_test(self, [m], expected, (self.b, self.c, self.h, self.h)) # Test forward-only for BN inference m.eval() f = ft.partial(F.batch_norm, running_mean=m.running_mean, running_var=m.running_var, weight=m.weight, bias=m.bias, training=False) - run_layer_test(self, [m, f], MATCH_INPUT, (self.b, self.c, self.h, self.h), + run_layer_test(self, [m, f], expected, (self.b, self.c, self.h, self.h), test_backward=False) +class TestBasicCastsHalf(_TestBasicCasts): + def setUp(self): + self.handle = amp.init(enabled=True, patch_type=torch.half) + common_init(self) + + def tearDown(self): + self.handle._deactivate() + common_reset(self) + + def test_linear_is_half(self): + self._test_linear(ALWAYS_HALF) + + def test_conv2d_is_half(self): + self._test_conv2d(ALWAYS_HALF) + + def test_softmax_is_float(self): + self._test_softmax(ALWAYS_FLOAT) + + def test_group_norm_is_float(self): + self._test_group_norm(ALWAYS_FLOAT) + + def test_mse_loss_is_float(self): + self._test_mse_loss(ALWAYS_FLOAT) + + def test_relu_is_match(self): + self._test_relu(MATCH_INPUT) + + def test_batch_norm_is_match(self): + self._test_batch_norm(MATCH_INPUT) + +class TestBasicCastsBFloat16(_TestBasicCasts): + def setUp(self): + self.handle = amp.init(enabled=True, patch_type=torch.bfloat16) + common_init(self) + + def tearDown(self): + self.handle._deactivate() + common_reset(self) + + @skipIfRocm + def test_linear_is_bfloat16(self): + self._test_linear(ALWAYS_BFLOAT16) + + @skipIfRocm + def test_conv2d_is_bfloat16(self): + self._test_conv2d(ALWAYS_BFLOAT16) + + def test_softmax_is_float(self): + self._test_softmax(ALWAYS_FLOAT) + + def test_group_norm_is_float(self): + self._test_group_norm(ALWAYS_FLOAT) + + def test_mse_loss_is_float(self): + self._test_mse_loss(ALWAYS_FLOAT) + + def test_relu_is_match(self): + self._test_relu(MATCH_INPUT) + + def test_batch_norm_is_match(self): + self._test_batch_norm(MATCH_INPUT) + class TestBannedMethods(unittest.TestCase): def setUp(self): - self.handle = amp.init(enabled=True) + self.handle = amp.init(enabled=True, patch_type=torch.half) common_init(self) def tearDown(self): self.handle._deactivate() + common_reset(self) - def bce_common(self, assertion): + def bce_common(self, assertion, dtype=torch.half): shape = (self.b, self.h) target = torch.rand(shape) mod = nn.BCELoss() m = lambda x: mod(x, target) f = ft.partial(F.binary_cross_entropy, target=target) for fn in [m, f]: - x = torch.rand(shape, dtype=torch.half) + x = torch.rand(shape, dtype=dtype) assertion(fn, x) def test_bce_raises_by_default(self): assertion = lambda fn, x: self.assertRaises(NotImplementedError, fn, x) - self.bce_common(assertion) + self.bce_common(assertion, dtype=torch.half) + + # handle with bfloat16 as patch_type + self.handle._deactivate() + self.handle = amp.init(enabled=True, patch_type=torch.bfloat16) + self.bce_common(assertion, dtype=torch.bfloat16) def test_bce_is_float_with_allow_banned(self): self.handle._deactivate() - self.handle = amp.init(enabled=True, allow_banned=True) + self.handle = amp.init(enabled=True, allow_banned=True, patch_type=torch.half) assertion = lambda fn, x: self.assertEqual(fn(x).type(), FLOAT) - self.bce_common(assertion) + self.bce_common(assertion, dtype=torch.half) -class TestTensorCasts(unittest.TestCase): - def setUp(self): - self.handle = amp.init(enabled=True) - common_init(self) - - def tearDown(self): + # handle with bfloat16 as patch_type self.handle._deactivate() + self.handle = amp.init(enabled=True, allow_banned=True, patch_type=torch.bfloat16) + self.bce_common(assertion, dtype=torch.bfloat16) - def test_matmul_method_is_half(self): +class _TestTensorCasts(unittest.TestCase): + def _test_matmul_method(self, expected): other = torch.randn(self.h, self.h) lhs = lambda x: x.matmul(other) rhs = lambda x: other.matmul(x) - run_layer_test(self, [lhs, rhs], ALWAYS_HALF, (self.h, self.h)) + run_layer_test(self, [lhs, rhs], expected, (self.h, self.h)) - def test_matmul_op_is_half(self): + def _test_matmul_op(self, expected): other = torch.randn(self.h, self.h) lhs = lambda x: x @ other rhs = lambda x: other @ x - run_layer_test(self, [lhs, rhs], ALWAYS_HALF, (self.h, self.h)) + run_layer_test(self, [lhs, rhs], expected, (self.h, self.h)) - def test_pow_method_is_float(self): + def _test_pow_method(self, expected): fn = lambda x: x.pow(2.) - run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h)) + run_layer_test(self, [fn], expected, (self.b, self.h)) - def test_pow_op_is_float(self): + def _test_pow_op(self, expected): fn = lambda x: x ** 2. - run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h)) + run_layer_test(self, [fn], expected, (self.b, self.h)) - def test_cpu_is_float(self): + def _test_cpu(self, expected): fn = lambda x: x.cpu() + run_layer_test(self, [fn], expected, (self.b, self.h)) + + def _test_sum(self, expected): + fn = lambda x: x.sum() + run_layer_test(self, [fn], expected, (self.b, self.h)) + + # TODO: maybe more tests on disabled casting? + +class TestTensorCastsHalf(_TestTensorCasts): + def setUp(self): + self.handle = amp.init(enabled=True, patch_type=torch.half) + common_init(self) + + def tearDown(self): + self.handle._deactivate() + common_reset(self) + + def test_matmul_method_is_half(self): + self._test_matmul_method(ALWAYS_HALF) + + def test_matmul_op_is_half(self): + self._test_matmul_op(ALWAYS_HALF) + + def test_pow_method_is_float(self): + self._test_pow_method(ALWAYS_FLOAT) + + def test_pow_op_is_float(self): + self._test_pow_op(ALWAYS_FLOAT) + + def test_cpu_is_float(self): always_cpu_float = {torch.float: 'torch.FloatTensor', torch.half: 'torch.FloatTensor'} - run_layer_test(self, [fn], always_cpu_float, (self.b, self.h)) + self._test_cpu(always_cpu_float) def test_sum_is_float(self): - fn = lambda x: x.sum() - run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h)) + self._test_sum(ALWAYS_FLOAT) + +class TestTensorCastsBFloat16(_TestTensorCasts): + def setUp(self): + self.handle = amp.init(enabled=True, patch_type=torch.bfloat16) + common_init(self) + + def tearDown(self): + self.handle._deactivate() + common_reset(self) + + @skipIfRocm + def test_matmul_method_is_bfloat16(self): + self._test_matmul_method(ALWAYS_BFLOAT16) + + @skipIfRocm + def test_matmul_op_is_bfloat16(self): + self._test_matmul_op(ALWAYS_BFLOAT16) + + def test_pow_method_is_float(self): + self._test_pow_method(ALWAYS_FLOAT) + + def test_pow_op_is_float(self): + self._test_pow_op(ALWAYS_FLOAT) + + def test_cpu_is_float(self): + always_cpu_float = {torch.float: 'torch.FloatTensor', + torch.bfloat16: 'torch.FloatTensor'} + self._test_cpu(always_cpu_float) + + def test_sum_is_float(self): + self._test_sum(ALWAYS_FLOAT) - # TODO: maybe more tests on disabled casting? if __name__ == '__main__': unittest.main() diff --git a/tests/L0/run_amp/test_cache.py b/tests/L0/run_amp/test_cache.py index b58d2665f..c5b33ade0 100644 --- a/tests/L0/run_amp/test_cache.py +++ b/tests/L0/run_amp/test_cache.py @@ -10,7 +10,7 @@ import torch.nn.functional as F from utils import common_init, HALF, FLOAT,\ - ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT + ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT, common_reset def get_reference_grad(i, w, ops): # Creating new tensors ensures, among other things, that the new tensors are not in the cache. @@ -65,14 +65,14 @@ def setUp(self): common_init(self) def tearDown(self): - pass + common_reset(self) - def train_eval_train_test(self, module, t): + def train_eval_train_test(self, module, t, opt_level): model = module(t).cuda() optimizer = torch.optim.SGD(model.parameters(), lr=1.0) _amp_state.allow_incoming_model_not_fp32 = True - model, optimizer = amp.initialize(model, optimizer, opt_level="O1", verbosity=0) + model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level, verbosity=0) _amp_state.allow_incoming_model_not_fp32 = False def training_step(): @@ -93,6 +93,8 @@ def training_step(): # but I'm keeping this in case we want different tolerances for fp16 and fp32 checks. if model.weight.grad.type() == "torch.cuda.HalfTensor": self.assertTrue(torch.allclose(model.weight.grad.float(), reference_grad)) + elif model.weight.grad.type() == "torch.cuda.BFloat16Tensor": + self.assertTrue(torch.allclose(model.weight.grad.float(), reference_grad)) elif model.weight.grad.type() == "torch.cuda.FloatTensor": self.assertTrue(torch.allclose(model.weight.grad.float(), reference_grad)) else: @@ -115,22 +117,41 @@ def training_step(): # I could easily have these as a set of for loops in a single test, # instead of going for granularity. def test_whitelist_module_fp16_weight(self): - self.train_eval_train_test(WhitelistModule, torch.float16) + self.train_eval_train_test(WhitelistModule, torch.float16, "O1") def test_whitelist_module_fp32_weight(self): - self.train_eval_train_test(WhitelistModule, torch.float32) + self.train_eval_train_test(WhitelistModule, torch.float32, "O1") def test_blacklist_module_fp16_weight(self): - self.train_eval_train_test(BlacklistModule, torch.float16) + self.train_eval_train_test(BlacklistModule, torch.float16, "O1") def test_blacklist_module_fp32_weight(self): - self.train_eval_train_test(BlacklistModule, torch.float32) + self.train_eval_train_test(BlacklistModule, torch.float32, "O1") def test_promote_module_fp16_weight(self): - self.train_eval_train_test(PromoteModule, torch.float16) + self.train_eval_train_test(PromoteModule, torch.float16, "O1") + + def test_promote_module_fp32_weight(self): + self.train_eval_train_test(PromoteModule, torch.float32, "O1") + + # opt_level = O4 + def test_whitelist_module_bfp16_weight(self): + self.train_eval_train_test(WhitelistModule, torch.bfloat16, "O4") + + def test_whitelist_module_fp32_weight(self): + self.train_eval_train_test(WhitelistModule, torch.float32, "O4") + + def test_blacklist_module_bfp16_weight(self): + self.train_eval_train_test(BlacklistModule, torch.bfloat16, "O4") + + def test_blacklist_module_fp32_weight(self): + self.train_eval_train_test(BlacklistModule, torch.float32, "O4") + + def test_promote_module_bfp16_weight(self): + self.train_eval_train_test(PromoteModule, torch.bfloat16, "O4") def test_promote_module_fp32_weight(self): - self.train_eval_train_test(PromoteModule, torch.float32) + self.train_eval_train_test(PromoteModule, torch.float32, "O4") if __name__ == '__main__': diff --git a/tests/L0/run_amp/test_checkpointing.py b/tests/L0/run_amp/test_checkpointing.py index 921985cd7..ff7ee884d 100644 --- a/tests/L0/run_amp/test_checkpointing.py +++ b/tests/L0/run_amp/test_checkpointing.py @@ -7,9 +7,9 @@ from apex import amp - from utils import common_init, FLOAT - +from apex.testing.common_utils import skipFlakyTest +from apex.amp import _amp_state class MyModel(torch.nn.Module): def __init__(self): @@ -28,7 +28,7 @@ def forward(self, x): class TestCheckpointing(unittest.TestCase): def setUp(self): self.initial_lr = 1e-3 - self.test_opt_levels = ("O0", "O1", "O2", "O3") + self.test_opt_levels = ("O0", "O1", "O2", "O3", "O4", "O5") def seed(self): torch.manual_seed(2809) @@ -44,7 +44,7 @@ def check_state_dict_fp32(self, state_dict): 'Parameter in state_dict not FLOAT') def train_step(self, model, optimizer, data, loss_ids): - optimizer.zero_grad() + optimizer.zero_grad() output = model(data) @@ -69,7 +69,7 @@ def compare_models(self, modelA, modelB, test_setup=''): msg='Parameters in state_dices not equal.' + 'key: {}\nparam: {}\nrestored: {}\ndiff: {} for {}'.format( key, paramA, paramB, paramA - paramB, test_setup)) - + def test_restoring(self): nb_epochs = 10 nb_epochs_restore = nb_epochs // 2 @@ -102,12 +102,12 @@ def test_restoring(self): if opt_level == res_opt_level: # train for nb_epochs and restore after nb_epochs_restore for epoch in range(nb_epochs): - + x = torch.randn(16, 3, 24, 24, device='cuda') output = self.train_step( model, optimizer, x, range(num_losses)) # Initialize model one step before comparing. - # Otherwise the batchnorm layers will be updated + # Otherwise the batchnorm layers will be updated # additionally in restore_model if epoch == (nb_epochs_restore - 1): # Load model and optimizer @@ -126,6 +126,7 @@ def test_restoring(self): lr=self.initial_lr) if amp_before_load: + _amp_state.handle._deactivate() restore_model, restore_optimizer = amp.initialize( restore_model, restore_optimizer, @@ -139,6 +140,7 @@ def test_restoring(self): # amp.load_state_dict(checkpoint['amp']) if not amp_before_load: + _amp_state.handle._deactivate() restore_model, restore_optimizer = amp.initialize( restore_model, restore_optimizer, @@ -156,11 +158,14 @@ def test_restoring(self): torch.allclose(output.float(), restore_output.float()), 'Output of reference and restored models differ for ' + test_setup) self.compare_models(model, restore_model, test_setup) + _amp_state.handle._deactivate() # if opt_level != res_opt_level else: # skip tests for different opt_levels + _amp_state.handle._deactivate() continue + @skipFlakyTest def test_loss_scale_decrease(self): num_losses = 3 nb_decrease_loss_scales = [0, 1, 2] @@ -170,10 +175,10 @@ def test_loss_scale_decrease(self): nb_decrease_loss_scales_tmp = list(nb_decrease_loss_scales) model = MyModel().to('cuda') - + optimizer = optim.SGD(model.parameters(), lr=self.initial_lr) - + model, optimizer = amp.initialize( model, optimizer, opt_level=opt_level, num_losses=num_losses, verbosity=0) @@ -181,26 +186,26 @@ def test_loss_scale_decrease(self): if amp._amp_state.opt_properties.loss_scale != 'dynamic': #print('Static loss scale set. Skipping opt_level.') continue - + # force to skip some updates to decrease the loss_scale initial_loss_scales = [] for idx in range(num_losses): initial_loss_scales.append( amp._amp_state.loss_scalers[idx].loss_scale()) - + for _ in range(len(nb_decrease_loss_scales)): x = torch.randn(16, 3, 24, 24, device='cuda') for idx in range(num_losses): while nb_decrease_loss_scales_tmp[idx] > 0: optimizer.zero_grad() output = model(x * 2**17) - loss = output.mean() - + loss = output.mean() + with amp.scale_loss(loss, optimizer, loss_id=idx) as scaled_loss: scaled_loss.backward(retain_graph=True) optimizer.step() nb_decrease_loss_scales_tmp[idx] -= 1 - + # Check loss scales afterwards updated_loss_scales = [] for idx in range(num_losses): @@ -220,6 +225,11 @@ def test_loss_scale_decrease(self): self.assertEqual(scaler['loss_scale'], init_ls / 2**factor) unskipped_target = 0 self.assertEqual(scaler['unskipped'], unskipped_target) + + if opt_level != "O0": + _amp_state.handle._deactivate() + + def test_state_dict(self): for opt_level in self.test_opt_levels: @@ -228,7 +238,7 @@ def test_state_dict(self): continue model = MyModel().to('cuda') - optimizer = optim.Adam(model.parameters(), lr=1e-3) + optimizer = optim.Adam(model.parameters(), lr=1e-3, capturable=True) model, optimizer = amp.initialize( model, optimizer, opt_level=opt_level, verbosity=0) @@ -236,12 +246,13 @@ def test_state_dict(self): state_dict = model.state_dict() for key in state_dict: self.assertFalse('Half' in state_dict[key].type()) + self.assertFalse('BFloat16' in state_dict[key].type()) # Check, if model is still trainable # Create dummy data data = torch.randn(10, 3, 4, 4, device='cuda') target = torch.randn(10, 6, 4, 4, device='cuda') - + # Get initnial loss optimizer.zero_grad() output = model(data) @@ -262,6 +273,10 @@ def test_state_dict(self): self.assertTrue(loss.item() < last_loss) last_loss = loss.item() + if opt_level != "O0": + _amp_state.handle._deactivate() + + if __name__=='__main__': unittest.main() - + diff --git a/tests/L0/run_amp/test_fused_sgd.py b/tests/L0/run_amp/test_fused_sgd.py index 7f592128d..480cd1132 100644 --- a/tests/L0/run_amp/test_fused_sgd.py +++ b/tests/L0/run_amp/test_fused_sgd.py @@ -13,7 +13,6 @@ from utils import common_init, HALF, FLOAT,\ ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT - try: import amp_C disabled = False @@ -181,7 +180,7 @@ def test_2models2losses1optimizer(self): self.assertTrue(torch.allclose(model, reference)) self.assertTrue(torch.allclose(model, master.to(model.dtype))) - if opt_level == "O1": + if opt_level != "O0": _amp_state.handle._deactivate() @unittest.skipIf(disabled, "amp_C is unavailable") @@ -342,7 +341,7 @@ def test_3models2losses1optimizer(self): self.assertTrue(torch.allclose(model, reference)) self.assertTrue(torch.allclose(model, master.to(model.dtype))) - if opt_level == "O1": + if opt_level != "O0": _amp_state.handle._deactivate() @unittest.skipIf(disabled, "amp_C is unavailable") @@ -537,7 +536,7 @@ def what_got_skipped(which_iter, which_backward): self.assertTrue(torch.allclose(model, reference)) self.assertTrue(torch.allclose(model, master.to(model.dtype))) - if opt_level == "O1": + if opt_level != "O0": _amp_state.handle._deactivate() @unittest.skipIf(disabled, "amp_C is unavailable") @@ -787,7 +786,7 @@ def what_got_skipped(which_iter, which_backward, which_model): self.assertTrue(torch.allclose(model, reference)) self.assertTrue(torch.allclose(model, master.to(model.dtype))) - if opt_level == "O1": + if opt_level != "O0": _amp_state.handle._deactivate() if __name__ == '__main__': diff --git a/tests/L0/run_amp/test_larc.py b/tests/L0/run_amp/test_larc.py index f4f3e838f..9dddfd93c 100644 --- a/tests/L0/run_amp/test_larc.py +++ b/tests/L0/run_amp/test_larc.py @@ -6,7 +6,8 @@ from apex import amp from apex.parallel.LARC import LARC -from utils import common_init +from utils import common_init, common_reset +from apex.amp import _amp_state class MyModel(torch.nn.Module): @@ -26,7 +27,7 @@ def setUp(self): common_init(self) def tearDown(self): - pass + common_reset(self) def test_larc_mixed_precision(self): for opt_level in ["O0", "O1", "O2", "O3"]: @@ -48,6 +49,10 @@ def test_larc_mixed_precision(self): scaled_loss.backward() optimizer.step() + if opt_level != "O0": + _amp_state.handle._deactivate() + + if __name__ == "__main__": unittest.main() diff --git a/tests/L0/run_amp/test_multi_tensor_axpby.py b/tests/L0/run_amp/test_multi_tensor_axpby.py index 0b439bb8d..4921378a2 100644 --- a/tests/L0/run_amp/test_multi_tensor_axpby.py +++ b/tests/L0/run_amp/test_multi_tensor_axpby.py @@ -10,7 +10,7 @@ from math import floor from utils import common_init, HALF, FLOAT,\ - ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT + ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT, common_reset try: import amp_C @@ -35,11 +35,11 @@ def setUp(self): self.b = 8.0 self.xval = 4.0 self.yval = 16.0 - self.overflow_buf = torch.cuda.IntTensor(1).zero_() + self.overflow_buf = torch.tensor(1, dtype=torch.int, device='cuda').zero_() self.ref = torch.full((1,), 136.0, device="cuda", dtype=torch.float32) def tearDown(self): - pass + common_reset(self) # The tensor creation here is written for convenience, not speed. def axpby(self, sizea, sizeb, applier, repeat_tensors, @@ -69,7 +69,10 @@ def to_fmt(t, tp): applier(multi_tensor_axpby, self.overflow_buf, [x_list, y_list, out_list], self.a, self.b, -1) - self.assertTrue(all([torch.allclose(out, self.ref.to(out_type)) for out in out_list]), + # TODO: Remove this workaround for bfloat16 after torch.allcose() support bfloat16 + if out_type == torch.bfloat16: + out_list = [out.float() for out in out_list] + self.assertTrue(all([torch.allclose(out, self.ref.to(out.dtype)) for out in out_list]), msg="{} {} {} {} {} {} {}".format(sizea, sizeb, repeat_tensors, x_type, y_type, out_type, inplace)) self.assertTrue(self.overflow_buf.item() == 0, @@ -119,9 +122,9 @@ def test_fuzz(self): for sizea, sizeb in input_size_pairs: for applier in appliers: for repeat in repeat_tensors: - for x_type in (torch.float32, torch.float16): - for y_type in (torch.float32, torch.float16): - for out_type in (torch.float32, torch.float16): + for x_type in (torch.float32, torch.float16, torch.bfloat16): + for y_type in (torch.float32, torch.float16, torch.bfloat16): + for out_type in (torch.float32, torch.float16, torch.bfloat16): for inplace in (True, False): if inplace is True and (y_type is not out_type): continue diff --git a/tests/L0/run_amp/test_multi_tensor_l2norm.py b/tests/L0/run_amp/test_multi_tensor_l2norm.py index ed3cbd195..bb28e52d2 100644 --- a/tests/L0/run_amp/test_multi_tensor_l2norm.py +++ b/tests/L0/run_amp/test_multi_tensor_l2norm.py @@ -9,7 +9,7 @@ import torch.nn.functional as F from utils import common_init, HALF, FLOAT,\ - ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT + ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT, common_reset try: import amp_C @@ -26,10 +26,10 @@ class TestMultiTensorL2Norm(unittest.TestCase): def setUp(self): common_init(self) self.val = 4.0 - self.overflow_buf = torch.cuda.IntTensor(1).zero_() + self.overflow_buf = torch.tensor(1, dtype=torch.int, device='cuda').zero_() def tearDown(self): - pass + common_reset(self) # The tensor creation here is written for convenience, not speed. def l2norm(self, sizea, sizeb, applier, repeat_tensors, in_type, per_tensor): @@ -67,7 +67,7 @@ def test_fuzz(self): (33333, 555), (555, 33333)) appliers = ( - MultiTensorApply(2048*32), + MultiTensorApply(2048*32), MultiTensorApply(333), MultiTensorApply(33333)) repeat_tensors = ( diff --git a/tests/L0/run_amp/test_multi_tensor_scale.py b/tests/L0/run_amp/test_multi_tensor_scale.py index 22da2490c..f97109c9e 100644 --- a/tests/L0/run_amp/test_multi_tensor_scale.py +++ b/tests/L0/run_amp/test_multi_tensor_scale.py @@ -9,11 +9,11 @@ import torch.nn.functional as F from utils import common_init, HALF, FLOAT,\ - ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT + ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT, common_reset try: import amp_C - from amp_C import multi_tensor_scale + from amp_C import multi_tensor_scale from apex.multi_tensor_apply import MultiTensorApply disabled = False except ImportError as err: @@ -26,11 +26,11 @@ class TestMultiTensorScale(unittest.TestCase): def setUp(self): common_init(self) self.scale = 4.0 - self.overflow_buf = torch.cuda.IntTensor(1).zero_() + self.overflow_buf = torch.tensor(1, dtype=torch.int, device='cuda').zero_() self.ref = torch.cuda.FloatTensor([1.0]) def tearDown(self): - pass + common_reset(self) # The tensor creation here is written for convenience, not speed. def downscale(self, sizea, sizeb, applier, repeat_tensors, in_type, out_type, inplace=False): @@ -49,9 +49,12 @@ def downscale(self, sizea, sizeb, applier, repeat_tensors, in_type, out_type, in applier(multi_tensor_scale, self.overflow_buf, [in_list, out_list], 1./self.scale) - self.assertTrue(all([torch.allclose(out, self.ref.to(out_type)) for out in out_list])) + # TODO: Remove this workaround for bfloat16 after torch.allcose() support bfloat16 + if out_type == torch.bfloat16: + out_list = [out.float() for out in out_list] + self.assertTrue(all([torch.allclose(out, self.ref.to(out.dtype)) for out in out_list])) self.assertTrue(self.overflow_buf.item() == 0) - + def find_inf(self, sizea, sizeb, applier, repeat_tensors, in_type, out_type, t, ind, val, inplace=False): self.overflow_buf.zero_() a = torch.cuda.FloatTensor(sizea).fill_(self.scale) @@ -79,7 +82,7 @@ def find_inf(self, sizea, sizeb, applier, repeat_tensors, in_type, out_type, t, # @unittest.skipIf(disabled, "amp_C is unavailable") # def test_fp16_to_fp16(self): # self.downscale(self.fp16, self.fp16, self.fp16_ref) - # + # # @unittest.skipIf(disabled, "amp_C is unavailable") # def test_fp32_to_fp16(self): # self.downscale(self.fp32, self.fp16, self.fp16_ref) @@ -96,7 +99,7 @@ def test_fuzz(self): (33333, 555), (555, 33333)) appliers = ( - MultiTensorApply(2048*32), + MultiTensorApply(2048*32), MultiTensorApply(333), MultiTensorApply(33333)) repeat_tensors = ( @@ -106,8 +109,8 @@ def test_fuzz(self): for sizea, sizeb in input_size_pairs: for applier in appliers: for repeat in repeat_tensors: - for in_type in (torch.float32, torch.float16): - for out_type in (torch.float32, torch.float16): + for in_type in (torch.float32, torch.float16, torch.bfloat16): + for out_type in (torch.float32, torch.float16, torch.bfloat16): for inplace in (True, False): if inplace is True and (out_type is not in_type): continue diff --git a/tests/L0/run_amp/test_multiple_models_optimizers_losses.py b/tests/L0/run_amp/test_multiple_models_optimizers_losses.py index 068c84537..78a144a7d 100644 --- a/tests/L0/run_amp/test_multiple_models_optimizers_losses.py +++ b/tests/L0/run_amp/test_multiple_models_optimizers_losses.py @@ -11,7 +11,7 @@ from torch.nn import Parameter from utils import common_init, HALF, FLOAT,\ - ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT + ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT, common_reset class MyModel(torch.nn.Module): def __init__(self, unique): @@ -40,7 +40,7 @@ def setUp(self): common_init(self) def tearDown(self): - pass + common_reset(self) def test_2models2losses1optimizer(self): model0 = MyModel(1) @@ -164,7 +164,7 @@ def test_2models2losses1optimizer(self): self.assertTrue(torch.allclose(model, reference)) self.assertTrue(torch.allclose(model, master.to(model.dtype))) - if opt_level == "O1": + if opt_level != "O0": _amp_state.handle._deactivate() def test_3models2losses1optimizer(self): @@ -320,7 +320,7 @@ def test_3models2losses1optimizer(self): self.assertTrue(torch.allclose(model, reference)) self.assertTrue(torch.allclose(model, master.to(model.dtype))) - if opt_level == "O1": + if opt_level != "O0": _amp_state.handle._deactivate() def test_2models2losses2optimizers(self): @@ -510,7 +510,7 @@ def what_got_skipped(which_iter, which_backward): self.assertTrue(torch.allclose(model, reference)) self.assertTrue(torch.allclose(model, master.to(model.dtype))) - if opt_level == "O1": + if opt_level != "O0": _amp_state.handle._deactivate() def test_3models2losses2optimizers(self): @@ -755,7 +755,7 @@ def what_got_skipped(which_iter, which_backward, which_model): self.assertTrue(torch.allclose(model, reference)) self.assertTrue(torch.allclose(model, master.to(model.dtype))) - if opt_level == "O1": + if opt_level != "O0": _amp_state.handle._deactivate() if __name__ == '__main__': diff --git a/tests/L0/run_amp/test_promotion.py b/tests/L0/run_amp/test_promotion.py index f5ef30c12..9e308574c 100644 --- a/tests/L0/run_amp/test_promotion.py +++ b/tests/L0/run_amp/test_promotion.py @@ -7,18 +7,18 @@ from torch import nn import torch.nn.functional as F -from utils import common_init, HALF, FLOAT, DTYPES +from utils import common_init, HALF, FLOAT, DTYPES, DTYPES2, MATCH_INPUT, common_reset -class TestPromotion(unittest.TestCase): - def setUp(self): - self.handle = amp.init(enabled=True) - common_init(self) - - def tearDown(self): - self.handle._deactivate() - - def run_binary_promote_test(self, fns, input_shape, x_inplace=False): - type_pairs = it.product(DTYPES, DTYPES) +class _TestPromotion(unittest.TestCase): + def run_binary_promote_test(self, fns, input_shape, lp_type, x_inplace=False): + if lp_type == torch.half: + dtypes = DTYPES + elif lp_type == torch.bfloat16: + dtypes = DTYPES2 + else: + raise RuntimeError("Creating test class with invalid low_precision type. \ + Supported types are torch.half and torch.bfloat16") + type_pairs = it.product(dtypes, dtypes) for fn, (xtype, ytype) in it.product(fns, type_pairs): x = torch.randn(input_shape, dtype=xtype).requires_grad_() x_leaf = x @@ -35,41 +35,80 @@ def run_binary_promote_test(self, fns, input_shape, x_inplace=False): if xtype == torch.float or ytype == torch.float: self.assertEqual(out.type(), FLOAT) else: - self.assertEqual(out.type(), HALF) + self.assertEqual(out.type(), MATCH_INPUT[lp_type]) out.float().sum().backward() self.assertEqual(x_leaf.grad.dtype, xtype) + def _test_cat_matches_widest(self, lp_type): + shape = self.b + ys = [torch.randn(shape, dtype=lp_type) for _ in range(5)] + x_float = torch.randn(shape) + out = torch.cat(ys + [x_float]) + self.assertEqual(out.type(), FLOAT) + x_lp = torch.randn(shape, dtype=lp_type) + out = torch.cat(ys + [x_lp]) + self.assertEqual(out.type(), MATCH_INPUT[lp_type]) + + def _test_inplace_exp_is_error_for_lp(self, lp_type): + xs = torch.randn(self.b) + xs.exp_() + self.assertEqual(xs.type(), FLOAT) + xs = torch.randn(self.b, dtype=lp_type) + with self.assertRaises(NotImplementedError): + xs.exp_() + +class TestPromotionHalf(_TestPromotion): + def setUp(self): + self.handle = amp.init(enabled=True, patch_type=torch.half) + common_init(self) + + def tearDown(self): + self.handle._deactivate() + common_reset(self) + def test_atan2_matches_widest(self): fns = [lambda x, y : torch.atan2(x, y), lambda x, y : x.atan2(y)] - self.run_binary_promote_test(fns, (self.b,)) + self.run_binary_promote_test(fns, (self.b,), torch.half) def test_mul_matches_widest(self): fns = [lambda x, y : torch.mul(x, y), lambda x, y: x.mul(y)] - self.run_binary_promote_test(fns, (self.b,)) + self.run_binary_promote_test(fns, (self.b,), torch.half) def test_cat_matches_widest(self): - shape = self.b - ys = [torch.randn(shape, dtype=torch.half) for _ in range(5)] - x_float = torch.randn(shape) - out = torch.cat(ys + [x_float]) - self.assertEqual(out.type(), FLOAT) - x_half = torch.randn(shape, dtype=torch.half) - out = torch.cat(ys + [x_half]) - self.assertEqual(out.type(), HALF) + self._test_cat_matches_widest(torch.half) def test_inplace_exp_is_error_for_half(self): - xs = torch.randn(self.b) - xs.exp_() - self.assertEqual(xs.type(), FLOAT) - xs = torch.randn(self.b, dtype=torch.half) - with self.assertRaises(NotImplementedError): - xs.exp_() + self._test_inplace_exp_is_error_for_lp(torch.half) + + def test_inplace_add_matches_self(self): + fn = lambda x, y: x.add_(y) + self.run_binary_promote_test([fn], (self.b,), torch.half, x_inplace=True) + +class TestPromotionBFloat16(_TestPromotion): + def setUp(self): + self.handle = amp.init(enabled=True, patch_type=torch.bfloat16) + common_init(self) + + def tearDown(self): + self.handle._deactivate() + common_reset(self) + + def test_mul_matches_widest(self): + fns = [lambda x, y : torch.mul(x, y), + lambda x, y: x.mul(y)] + self.run_binary_promote_test(fns, (self.b,), torch.bfloat16) + + def test_cat_matches_widest(self): + self._test_cat_matches_widest(torch.bfloat16) + + def test_inplace_exp_is_error_for_bfloat16(self): + self._test_inplace_exp_is_error_for_lp(torch.bfloat16) def test_inplace_add_matches_self(self): fn = lambda x, y: x.add_(y) - self.run_binary_promote_test([fn], (self.b,), x_inplace=True) + self.run_binary_promote_test([fn], (self.b,), torch.bfloat16, x_inplace=True) if __name__ == '__main__': unittest.main() diff --git a/tests/L0/run_amp/test_rnn.py b/tests/L0/run_amp/test_rnn.py index c49a5f003..02fb301d3 100644 --- a/tests/L0/run_amp/test_rnn.py +++ b/tests/L0/run_amp/test_rnn.py @@ -5,7 +5,8 @@ import torch from torch import nn -from utils import common_init, HALF +from utils import common_init, HALF, common_reset +from apex.testing.common_utils import skipIfRocm class TestRnnCells(unittest.TestCase): def setUp(self): @@ -14,6 +15,7 @@ def setUp(self): def tearDown(self): self.handle._deactivate() + common_reset(self) def run_cell_test(self, cell, state_tuple=False): shape = (self.b, self.h) @@ -38,7 +40,7 @@ def run_cell_test(self, cell, state_tuple=False): outputs[-1].float().sum().backward() for i, x in enumerate(xs): self.assertEqual(x.grad.dtype, x.dtype) - + def test_rnn_cell_is_half(self): cell = nn.RNNCell(self.h, self.h) self.run_cell_test(cell) @@ -58,6 +60,7 @@ def setUp(self): def tearDown(self): self.handle._deactivate() + common_reset(self) def run_rnn_test(self, rnn, layers, bidir, state_tuple=False): for typ in [torch.float, torch.half]: @@ -73,6 +76,7 @@ def run_rnn_test(self, rnn, layers, bidir, state_tuple=False): output[-1, :, :].float().sum().backward() self.assertEqual(x.grad.dtype, x.dtype) + @skipIfRocm def test_rnn_is_half(self): configs = [(1, False), (2, False), (2, True)] for layers, bidir in configs: @@ -80,6 +84,7 @@ def test_rnn_is_half(self): nonlinearity='relu', bidirectional=bidir) self.run_rnn_test(rnn, layers, bidir) + @skipIfRocm def test_gru_is_half(self): configs = [(1, False), (2, False), (2, True)] for layers, bidir in configs: @@ -87,6 +92,7 @@ def test_gru_is_half(self): bidirectional=bidir) self.run_rnn_test(rnn, layers, bidir) + @skipIfRocm def test_lstm_is_half(self): configs = [(1, False), (2, False), (2, True)] for layers, bidir in configs: @@ -94,6 +100,7 @@ def test_lstm_is_half(self): bidirectional=bidir) self.run_rnn_test(rnn, layers, bidir, state_tuple=True) + @skipIfRocm def test_rnn_packed_sequence(self): num_layers = 2 rnn = nn.RNN(input_size=self.h, hidden_size=self.h, num_layers=num_layers) diff --git a/tests/L0/run_amp/utils.py b/tests/L0/run_amp/utils.py index 7aa20c369..781e03336 100644 --- a/tests/L0/run_amp/utils.py +++ b/tests/L0/run_amp/utils.py @@ -2,15 +2,21 @@ HALF = 'torch.cuda.HalfTensor' FLOAT = 'torch.cuda.FloatTensor' +BFLOAT16 = 'torch.cuda.BFloat16Tensor' DTYPES = [torch.half, torch.float] +DTYPES2 = [torch.bfloat16, torch.float] + ALWAYS_HALF = {torch.float: HALF, torch.half: HALF} +ALWAYS_BFLOAT16 = {torch.bfloat16: BFLOAT16, + torch.float: BFLOAT16} ALWAYS_FLOAT = {torch.float: FLOAT, torch.half: FLOAT} MATCH_INPUT = {torch.float: FLOAT, - torch.half: HALF} + torch.half: HALF, + torch.bfloat16: BFLOAT16} def common_init(test_case): test_case.h = 64 @@ -18,4 +24,9 @@ def common_init(test_case): test_case.c = 16 test_case.k = 3 test_case.t = 10 - torch.set_default_tensor_type(torch.cuda.FloatTensor) + torch.set_default_device('cuda') + torch.set_default_dtype(torch.float) + + +def common_reset(test_case): + torch.set_default_device('cpu') diff --git a/tests/L0/run_fp16util/test_fp16util.py b/tests/L0/run_fp16util/test_fp16util.py index eecddbc01..b6cba9824 100644 --- a/tests/L0/run_fp16util/test_fp16util.py +++ b/tests/L0/run_fp16util/test_fp16util.py @@ -73,3 +73,6 @@ def test_output_is_half(self): out_tensor = self.fp16_model(self.in_tensor) assert out_tensor.dtype == torch.half + +if __name__ == '__main__': + unittest.main() diff --git a/apex/contrib/test/fused_dense/test_fused_dense.py b/tests/L0/run_fused_dense/test_fused_dense.py similarity index 64% rename from apex/contrib/test/fused_dense/test_fused_dense.py rename to tests/L0/run_fused_dense/test_fused_dense.py index 301ebf6b5..6490f703c 100644 --- a/apex/contrib/test/fused_dense/test_fused_dense.py +++ b/tests/L0/run_fused_dense/test_fused_dense.py @@ -8,14 +8,14 @@ class FusedDenseTest(unittest.TestCase): def setUp(self, seed=0): torch.manual_seed(seed) - #torch.cuda.manual_seed_all(seed) + # torch.cuda.manual_seed_all(seed) self.seq_length = 512 self.sequences = 3 self.hidden_dim = 1024 self.ref_inputs = torch.randn(self.sequences*self.seq_length, self.hidden_dim, - dtype=torch.float16, device=torch.device("cuda")).int().half().requires_grad_(True) + dtype=torch.float16, device=torch.device("cuda")).half().requires_grad_(True) self.tst_inputs = self.ref_inputs.clone().detach().requires_grad_(True) self.dense = fused_dense.FusedDense(1024, 3072) @@ -32,13 +32,12 @@ def test_fused_dense(self) : dx_ref = torch.matmul(dy, self.dense.weight.clone()) db_ref = dy.sum(0, False) - - self.assertTrue(torch.allclose(self.ref_inputs, self.tst_inputs, atol=1e-5, rtol=1e-5)) - self.assertTrue(torch.allclose(y_ref, y_tst, atol=1e-3, rtol=1e-3, equal_nan=True)) - self.assertTrue(torch.allclose(dw_ref, self.dense.weight.grad, atol=1e-3, rtol=1e-3, equal_nan=True)) - self.assertTrue(torch.allclose(dx_ref, self.tst_inputs.grad, atol=1e-3, rtol=1e-3, equal_nan=True)) - self.assertTrue(torch.allclose(db_ref, self.dense.bias.grad, atol=1e-3, rtol=1e-3, equal_nan=True)) + self.assertTrue(torch.allclose(self.ref_inputs, self.tst_inputs, atol=1e-5, rtol=1e-5)) + self.assertTrue(torch.allclose(y_ref, y_tst, atol=1e-3, rtol=1e-3, equal_nan=True)) + self.assertTrue(torch.allclose(dw_ref, self.dense.weight.grad, atol=1e-3, rtol=1e-3, equal_nan=True)) + self.assertTrue(torch.allclose(dx_ref, self.tst_inputs.grad, atol=1e-3, rtol=1e-3, equal_nan=True)) + self.assertTrue(torch.allclose(db_ref, self.dense.bias.grad, atol=1e-3, rtol=1e-3, equal_nan=True)) if __name__ == '__main__': - unittest.main() + unittest.main() \ No newline at end of file diff --git a/tests/L0/run_fused_dense/test_gelu.py b/tests/L0/run_fused_dense/test_gelu.py new file mode 100644 index 000000000..913fec7ab --- /dev/null +++ b/tests/L0/run_fused_dense/test_gelu.py @@ -0,0 +1,42 @@ +from apex import fused_dense +import torch +import torch.nn.functional as F +import unittest + + +class FusedDenseGeluDenseTest(unittest.TestCase): + + def test_fused_dense_gelu_dense(self) : + seed = 0 + torch.manual_seed(seed) + batch_size = 4 + in_features = 3 + intermediate_features = 3 + out_features = 2 + + #tst_dtype = torch.float8_e4m3 + # tst_dtype = torch.float8_e5m2 + tst_dtype = torch.float16 + + I = torch.randn(batch_size, in_features, dtype=tst_dtype, device='cuda').requires_grad_(True) + + denseGelu = fused_dense.FusedDenseGeluDense(in_features, intermediate_features, out_features) + denseGelu.to(dtype=tst_dtype) + denseGelu.cuda() + + #get weight and bias from the denseGelu module + W1 = denseGelu.weight1 + b1 = denseGelu.bias1 + W2 = denseGelu.weight2 + b2 = denseGelu.bias2 + + y_tst = denseGelu(I.clone().detach().requires_grad_(True)) + + C1 = torch.matmul(I, W1.t())+b1 + gelu_output = F.gelu(C1) + y_ref = torch.matmul(gelu_output, W2.t())+b2 + torch.testing.assert_close(y_ref, y_tst, atol=1e-3, rtol=1e-3, equal_nan=True) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/L0/run_fused_layer_norm/test_fused_layer_norm.py b/tests/L0/run_fused_layer_norm/test_fused_layer_norm.py index 2150366fd..61b64849a 100644 --- a/tests/L0/run_fused_layer_norm/test_fused_layer_norm.py +++ b/tests/L0/run_fused_layer_norm/test_fused_layer_norm.py @@ -1,252 +1,234 @@ -import itertools -import unittest - import torch +from apex.normalization import FusedLayerNorm +from apex.normalization import FusedRMSNorm +from apex.normalization import MixedFusedLayerNorm +from apex.normalization import MixedFusedRMSNorm -import apex - - -class TestFusedLayerNorm(unittest.TestCase): - dtype = torch.float - elementwise_affine = False - normalized_shape = [32, 16] - rtol, atol = None, None - fwd_thresholds = dict(rtol=None, atol=None) - bwd_thresholds = dict(rtol=None, atol=None) - mixed_fused = False - - def setUp(self): - # bias and weight are set to 0 and 1 respectively, so no need to copy parameters from cpu module to the gpu one - if not self.mixed_fused: - self.module_cpu_ = apex.normalization.FusedLayerNorm( - normalized_shape=self.normalized_shape, elementwise_affine=self.elementwise_affine).cpu() - self.module_cuda_ = apex.normalization.FusedLayerNorm( - normalized_shape=self.normalized_shape, elementwise_affine=self.elementwise_affine).to(device="cuda", dtype=self.dtype) - else: - assert self.elementwise_affine - self.module_cpu_ = apex.normalization.MixedFusedLayerNorm( - normalized_shape=self.normalized_shape).cpu() - self.module_cuda_ = apex.normalization.MixedFusedLayerNorm( - normalized_shape=self.normalized_shape).to(device="cuda", dtype=self.dtype) +from torch.testing._internal import common_utils +from torch.testing._internal.common_device_type import instantiate_device_type_tests + +from itertools import product + +def _prep_inputs(batch_size, normalized_shape, dtype): + shape = (batch_size, *normalized_shape) + fused = torch.randn(shape).cuda().requires_grad_(True) + with torch.no_grad(): + native = fused.clone().to(dtype).requires_grad_(True) + return native, fused + +autocast_dtypes = (torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,) + +class TestFusedLayerNorm(common_utils.TestCase): + + def _test_fused_layer_norm( + self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient, + fwd_thresholds=dict(rtol=None, atol=None), bwd_thresholds=dict(rtol=None, atol=None) + ): + normalized_shape = [32, 16] + + if not mixed_fused: + module_cpu_ = FusedLayerNorm( + normalized_shape=normalized_shape, elementwise_affine=elementwise_affine, memory_efficient=memory_efficient + ).cpu() + module_cuda_ = FusedLayerNorm( + normalized_shape=normalized_shape, elementwise_affine=elementwise_affine, memory_efficient=memory_efficient + ).to(device="cuda", dtype=dtype) + else: + assert elementwise_affine + module_cpu_ = MixedFusedLayerNorm( + normalized_shape=normalized_shape, memory_efficient=memory_efficient + ).cpu() + module_cuda_ = MixedFusedLayerNorm( + normalized_shape=normalized_shape, memory_efficient=memory_efficient + ).to(device="cuda", dtype=dtype) - def _check_same_output(self, batch_size, contiguous): torch.cuda.manual_seed(42) if contiguous: - input_shape = [batch_size] + self.normalized_shape + input_shape = [batch_size] + normalized_shape input_ = torch.randn(input_shape, device="cpu").requires_grad_(True) - input_cuda_ = input_.to(device="cuda", dtype=self.dtype).detach().requires_grad_(True) + input_cuda_ = input_.to(device="cuda", dtype=dtype).detach().requires_grad_(True) self.assertTrue(input_.is_contiguous()) self.assertTrue(input_cuda_.is_contiguous()) else: - input_shape = [batch_size] + self.normalized_shape - input_shape = [batch_size * 3] + [self.normalized_shape[0] * 5, self.normalized_shape[1] * 3] + input_shape = [batch_size] + normalized_shape + input_shape = [batch_size * 3] + [normalized_shape[0] * 5, normalized_shape[1] * 3] input_src_ = torch.randn(input_shape, device="cpu") input_ = input_src_[::3, ::5, ::3].detach().requires_grad_(True) - input_cuda_ = input_src_.to(device="cuda", dtype=self.dtype)[::3, ::5, ::3].detach().requires_grad_(True) + input_cuda_ = input_src_.to(device="cuda", dtype=dtype)[::3, ::5, ::3].detach().requires_grad_(True) # make sure that tensors are NOT contiguous. self.assertFalse(input_.is_contiguous()) self.assertFalse(input_cuda_.is_contiguous()) - out_cpu_ = self.module_cpu_(input_) + out_cpu_ = module_cpu_(input_) gO = torch.rand_like(out_cpu_) out_cpu_.backward(gO) - out_cuda_ = self.module_cuda_(input_cuda_) - gO = gO.to(device="cuda", dtype=self.dtype) + out_cuda_ = module_cuda_(input_cuda_) + + gO = gO.to(device="cuda", dtype=dtype) out_cuda_.backward(gO) self.assertFalse(out_cpu_.is_cuda) self.assertTrue(out_cuda_.is_cuda) - # TODO (mkozuki): `torch.testing.assert_allclose` is deprecated. - # Use `torch.testing.assert_close`. - # See https://github.com/pytorch/pytorch/issues/61844 - torch.testing.assert_allclose( - out_cpu_.to(device="cuda", dtype=self.dtype), out_cuda_, **self.fwd_thresholds) - torch.testing.assert_allclose( - input_.grad.to(device="cuda", dtype=self.dtype), input_cuda_.grad, **self.bwd_thresholds) - - def _test_same_output(self, batch_size): - for contiguous in (True, False): - with self.subTest(contiguous=contiguous): - self._check_same_output(batch_size, contiguous) - - def test_layer_norm(self): - self._test_same_output(16) - - def test_large_batch(self): - self._test_same_output(65536) - - -class TestFusedRMSNorm(unittest.TestCase): - dtype = torch.float - elementwise_affine = False - normalized_shape = [32, 16] - rtol, atol = None, None - fwd_thresholds = dict(rtol=None, atol=None) - bwd_thresholds = dict(rtol=None, atol=None) - mixed_fused = False - - def setUp(self): - # bias and weight are set to 0 and 1 respectively, so no need to copy parameters from cpu module to the gpu one - if not self.mixed_fused: - self.module_cpu_ = apex.normalization.FusedRMSNorm( - normalized_shape=self.normalized_shape, elementwise_affine=self.elementwise_affine).cpu() - self.module_cuda_ = apex.normalization.FusedRMSNorm( - normalized_shape=self.normalized_shape, elementwise_affine=self.elementwise_affine).to(device="cuda", dtype=self.dtype) + torch.testing.assert_close( + out_cpu_.to(device="cuda", dtype=dtype), out_cuda_, **fwd_thresholds) + torch.testing.assert_close( + input_.grad.to(device="cuda", dtype=dtype), input_cuda_.grad, **bwd_thresholds) + + def _test_fused_rms_norm( + self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient, + fwd_thresholds=dict(rtol=None, atol=None), bwd_thresholds=dict(rtol=None, atol=None) + ): + + normalized_shape = [32, 16] + + if not mixed_fused: + module_cpu_ = FusedRMSNorm( + normalized_shape=normalized_shape, elementwise_affine=elementwise_affine, memory_efficient=memory_efficient + ).cpu() + module_cuda_ = FusedRMSNorm( + normalized_shape=normalized_shape, elementwise_affine=elementwise_affine, memory_efficient=memory_efficient + ).to(device="cuda", dtype=dtype) else: - assert self.elementwise_affine - self.module_cpu_ = apex.normalization.MixedFusedRMSNorm( - normalized_shape=self.normalized_shape).cpu() - self.module_cuda_ = apex.normalization.MixedFusedRMSNorm( - normalized_shape=self.normalized_shape).to(device="cuda", dtype=self.dtype) + assert elementwise_affine + module_cpu_ = MixedFusedRMSNorm( + normalized_shape=normalized_shape).cpu() + module_cuda_ = MixedFusedRMSNorm( + normalized_shape=normalized_shape).to(device="cuda", dtype=dtype) - def _check_same_output(self, batch_size, contiguous): torch.cuda.manual_seed(42) if contiguous: - input_shape = [batch_size] + self.normalized_shape + input_shape = [batch_size] + normalized_shape input_ = torch.randn(input_shape, device="cpu").requires_grad_(True) - input_cuda_ = input_.to(device="cuda", dtype=self.dtype).detach().requires_grad_(True) + input_cuda_ = input_.to(device="cuda", dtype=dtype).detach().requires_grad_(True) self.assertTrue(input_.is_contiguous()) self.assertTrue(input_cuda_.is_contiguous()) else: - input_shape = [batch_size] + self.normalized_shape - input_shape = [batch_size * 3] + [self.normalized_shape[0] * 5, self.normalized_shape[1] * 3] + input_shape = [batch_size] + normalized_shape + input_shape = [batch_size * 3] + [normalized_shape[0] * 5, normalized_shape[1] * 3] input_src_ = torch.randn(input_shape, device="cpu") input_ = input_src_[::3, ::5, ::3].detach().requires_grad_(True) - input_cuda_ = input_src_.to(device="cuda", dtype=self.dtype)[::3, ::5, ::3].detach().requires_grad_(True) + input_cuda_ = input_src_.to(device="cuda", dtype=dtype)[::3, ::5, ::3].detach().requires_grad_(True) # make sure that tensors are NOT contiguous. self.assertFalse(input_.is_contiguous()) self.assertFalse(input_cuda_.is_contiguous()) - out_cpu_ = self.module_cpu_(input_) + out_cpu_ = module_cpu_(input_) gO = torch.rand_like(out_cpu_) out_cpu_.backward(gO) - out_cuda_ = self.module_cuda_(input_cuda_) - # TODO (mkozuki): `torch.testing.assert_allclose` is deprecated. - # Use `torch.testing.assert_close`. - # See https://github.com/pytorch/pytorch/issues/61844 - torch.testing.assert_allclose( - out_cpu_.to(device="cuda", dtype=self.dtype), out_cuda_.clone().detach(), **self.fwd_thresholds) - gO = gO.to(device="cuda", dtype=self.dtype) + out_cuda_ = module_cuda_(input_cuda_) + + torch.testing.assert_close( + out_cpu_.to(device="cuda", dtype=dtype), out_cuda_.clone().detach(), **fwd_thresholds) + gO = gO.to(device="cuda", dtype=dtype) out_cuda_.backward(gO) self.assertFalse(out_cpu_.is_cuda) self.assertTrue(out_cuda_.is_cuda) - torch.testing.assert_allclose( - input_.grad.to(device="cuda", dtype=self.dtype), input_cuda_.grad, **self.bwd_thresholds) - if self.elementwise_affine: - torch.testing.assert_allclose(self.module_cpu_.weight.grad.to(device="cuda", dtype=self.dtype), - self.module_cuda_.weight.grad, **self.bwd_thresholds) - - def _test_same_output(self, batch_size): - for contiguous in (True, False): - with self.subTest(contiguous=contiguous): - self._check_same_output(batch_size, contiguous) - - def test_layer_norm(self): - self._test_same_output(16) - - def test_large_batch(self): - self._test_same_output(65536) - - -class TestFusedLayerNormElemWise(TestFusedLayerNorm): - elementwise_affine = True - -class TestMixedFusedLayerNormElemWise(TestFusedLayerNorm): - elementwise_affine = True - mixed_fused = True - -class TestFusedLayerNormElemWiseHalf(TestFusedLayerNormElemWise): - dtype = torch.half - - def test_large_batch(self): - self.skipTest("Skip to save time") - -class TestFusedLayerNormElemWiseBFloat16(TestFusedLayerNormElemWise): - dtype = torch.bfloat16 - # NOTE (mkozuki): [BFloat16 Layer Norm flakiness] - # Use thresholds larger than those used in pytorch, see - # https://github.com/pytorch/pytorch/blob/72274e2a2fd55019ec860e1743dbdc5b0c5a5624/torch/testing/_asserts.py#L26 - fwd_thresholds = dict(rtol=1.6e-2, atol=3e-4) - bwd_thresholds = dict(rtol=1.6e-2, atol=3e-3) - - def test_large_batch(self): - self.skipTest("Skip to save time") - - -class TestFusedRMSNormElemWise(TestFusedRMSNorm): - bwd_thresholds = dict(rtol=2e-3, atol=2e-4) - elementwise_affine = True - -class TestMixedFusedRMSNormElemWise(TestFusedRMSNorm): - bwd_thresholds = dict(rtol=2e-3, atol=2e-4) - elementwise_affine = True - mixed_fused = True - -class TestFusedRMSNormElemWiseHalf(TestFusedRMSNormElemWise): - dtype = torch.half - bwd_thresholds = dict(rtol=1.6e-2, atol=3e-3) - - def test_large_batch(self): - self.skipTest("Skip to save time") - - -class TestFusedLayerNormElemWiseBFloat16(TestFusedLayerNormElemWise): - dtype = torch.bfloat16 - # NOTE (mkozuki): [BFloat16 Layer Norm flakiness] - # Use thresholds larger than those used in pytorch, see - # https://github.com/pytorch/pytorch/blob/72274e2a2fd55019ec860e1743dbdc5b0c5a5624/torch/testing/_asserts.py#L26 - fwd_thresholds = dict(rtol=1.6e-2, atol=3e-4) - bwd_thresholds = dict(rtol=1.6e-2, atol=3e-3) - - def test_large_batch(self): - self.skipTest("Skip to save time") - - -def _prep_layers(normalized_shape, elementwise_affine, dtype): - native = torch.nn.LayerNorm( - normalized_shape=normalized_shape, elementwise_affine=elementwise_affine - ).to(device="cuda", dtype=dtype) - fused = apex.normalization.FusedLayerNorm( - normalized_shape=normalized_shape, elementwise_affine=elementwise_affine - ).cuda() - return native, fused - - -def _prep_rms_layers(normalized_shape, elementwise_affine, dtype): - native = apex.normalization.FusedRMSNorm( - normalized_shape=normalized_shape, elementwise_affine=elementwise_affine + torch.testing.assert_close( + input_.grad.to(device="cuda", dtype=dtype), input_cuda_.grad, **bwd_thresholds) + if elementwise_affine: + torch.testing.assert_close(module_cpu_.weight.grad.to(device="cuda", dtype=dtype), + module_cuda_.weight.grad, **bwd_thresholds) + + # layer norm tests + @common_utils.parametrize( + "batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient", + list(product((16, 65536), (True, False), (False,), (False,), (torch.float,), (True, False))) ) - fused = apex.normalization.FusedRMSNorm( - normalized_shape=normalized_shape, elementwise_affine=elementwise_affine - ).cuda() - return native, fused - - -def _prep_inputs(batch_size, normalized_shape, dtype): - shape = (batch_size, *normalized_shape) - fused = torch.randn(shape).cuda().requires_grad_(True) - with torch.no_grad(): - native = fused.clone().to(dtype).requires_grad_(True) - return native, fused - + def test_layer_norm_regular(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient): + self._test_fused_layer_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient) + + @common_utils.parametrize( + "batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient", + list(product((16, 65536), (True, False), (True,), (False,), (torch.float,), (True, False))) + ) + def test_layer_norm_elemwise(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient): + self._test_fused_layer_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient) -autocast_dtypes = (torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,) + @common_utils.parametrize( + "batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient", + list(product((16, 65536), (True, False), (True,), (True,), (torch.float,), (True, False))) + ) + def test_layer_norm_mixed(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient): + self._test_fused_layer_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient) + + @common_utils.parametrize( + "batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient", + list(product((16,), (True, False), (True,), (False,), (torch.half,), (True, False))) + ) + def test_layer_norm_half(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient): + self._test_fused_layer_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient, + fwd_thresholds=dict(rtol=1e-3, atol=1e-3), bwd_thresholds=dict(rtol=1e-3, atol=1e-3)) + + @common_utils.parametrize( + "batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient", + list(product((16,), (True, False), (True,), (False,), (torch.bfloat16,), (True, False))) + ) + def test_layer_norm_bfloat16(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient): + self._test_fused_layer_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient, + fwd_thresholds=dict(rtol=1.6e-2, atol=3e-4), bwd_thresholds=dict(rtol=1.6e-2, atol=3e-3)) + + # rms norm tests + @common_utils.parametrize( + "batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient", + list(product((16, 65536), (True, False), (False,), (False,), (torch.float,), (True, False))) + ) + def test_rms_norm_regular(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient): + self._test_fused_rms_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient) -class TestAutocastFusedLayerNorm(unittest.TestCase): - bf16_fwd_thresholds = dict(rtol=1.6e-2, atol=3e-4) - bf16_bwd_thresholds = dict(rtol=1.6e-2, atol=3e-3) + @common_utils.parametrize( + "batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient", + list(product((16, 65536), (True, False), (True,), (False,), (torch.float,), (True, False))) + ) + def test_rms_norm_elemwise(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient): + self._test_fused_rms_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient, + bwd_thresholds=dict(rtol=2e-3, atol=2e-4)) - def setUp(self): - self.batch_size = 16 - self.normalized_shape = [32, 16] + @common_utils.parametrize( + "batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient", + list(product((16, 65536), (True, False), (True,), (True,), (torch.float,), (True, False))) + ) + def test_rms_norm_mixed(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient): + self._test_fused_rms_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient, + bwd_thresholds=dict(rtol=2e-3, atol=2e-4)) + + @common_utils.parametrize( + "batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient", + list(product((16,), (True, False), (True,), (False,), (torch.half,), (True, False))) + ) + def test_rms_norm_half(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient): + self._test_fused_rms_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient, + bwd_thresholds = dict(rtol=1.6e-2, atol=3e-3)) + + @common_utils.parametrize( + "batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient", + list(product((16,), (True, False), (True,), (False,), (torch.bfloat16,), (True, False))) + ) + def test_rms_norm_bfloat16(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient): + self._test_fused_rms_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient, + fwd_thresholds=dict(rtol=1.6e-2, atol=3e-4), bwd_thresholds=dict(rtol=1.6e-2, atol=3e-2)) - def _run_test(self, dtype, elementwise_affine): - native, fused = _prep_layers(self.normalized_shape, elementwise_affine, dtype) - native_x, fused_x = _prep_inputs(self.batch_size, self.normalized_shape, dtype) + @common_utils.parametrize( + "dtype, elementwise_affine, memory_efficient", + list(product(autocast_dtypes, (True, False), (True, False))) + ) + def test_autocast_fused_layer_norm(self, dtype, elementwise_affine, memory_efficient): + bf16_fwd_thresholds = dict(rtol=1.6e-2, atol=3e-4) + bf16_bwd_thresholds = dict(rtol=1.6e-2, atol=3e-3) + batch_size = 16 + normalized_shape = [32, 16] + native = torch.nn.LayerNorm( + normalized_shape=normalized_shape, elementwise_affine=elementwise_affine + ).to(device="cuda", dtype=dtype) + fused = FusedLayerNorm( + normalized_shape=normalized_shape, elementwise_affine=elementwise_affine, memory_efficient=memory_efficient + ).cuda() + native_x, fused_x = _prep_inputs(batch_size, normalized_shape, dtype) expected = native(native_x) - with torch.cuda.amp.autocast(dtype=dtype): + with torch.amp.autocast('cuda', dtype=dtype): actual = fused(fused_x) - tols = {'rtol': None, 'atol': None} if dtype == torch.half else TestAutocastFusedLayerNorm.bf16_fwd_thresholds - torch.testing.assert_allclose(actual, expected, **tols) + tols = {'rtol': None, 'atol': None} if dtype == torch.half else bf16_fwd_thresholds + # original tests used torch.testing.assert_allclose, which disables dtype checking by default. + # link to issue here: https://github.com/pytorch/pytorch/issues/61844 + torch.testing.assert_close(actual, expected, **tols, check_dtype=False) g_native = torch.rand_like(expected) with torch.no_grad(): @@ -254,31 +236,35 @@ def _run_test(self, dtype, elementwise_affine): expected.backward(g_native) actual.backward(g_fused) - tols = {'rtol': None, 'atol': None} if dtype == torch.half else TestAutocastFusedLayerNorm.bf16_bwd_thresholds - torch.testing.assert_allclose(native_x.grad, fused_x.grad, **tols) - - def test_autocast(self): - for (dtype, elementwise_affine) in itertools.product(autocast_dtypes, (True, False)): - with self.subTest(f"{dtype}-{elementwise_affine}"): - self._run_test(dtype, elementwise_affine) - -class TestAutocastFusedRMSNorm(unittest.TestCase): - bf16_fwd_thresholds = dict(rtol=1.6e-2, atol=3e-4) - bf16_bwd_thresholds = dict(rtol=1.6e-2, atol=3e-3) - - def setUp(self): - self.batch_size = 16 - self.normalized_shape = [32, 16] - - def _run_test(self, dtype, elementwise_affine): - native, fused = _prep_rms_layers(self.normalized_shape, elementwise_affine, dtype) - native_x, fused_x = _prep_inputs(self.batch_size, self.normalized_shape, dtype) + if dtype != torch.half: + tols = bf16_bwd_thresholds + elif memory_efficient: + tols = {'rtol': 1e-3, 'atol': 1e-4} + else: + tols = {'rtol': None, 'atol': None} + torch.testing.assert_close(native_x.grad, fused_x.grad, **tols, check_dtype=False) + @common_utils.parametrize( + "dtype, elementwise_affine, memory_efficient", + list(product(autocast_dtypes, (True, False), (True, False))) + ) + def test_autocast_fused_rms_norm(self, dtype, elementwise_affine, memory_efficient): + bf16_fwd_thresholds = dict(rtol=1.6e-2, atol=3e-4) + bf16_bwd_thresholds = dict(rtol=1.6e-2, atol=3e-3) + batch_size = 16 + normalized_shape = [32, 16] + native = FusedRMSNorm( + normalized_shape=normalized_shape, elementwise_affine=elementwise_affine, memory_efficient=memory_efficient, + ).to(dtype=dtype) + fused = FusedRMSNorm( + normalized_shape=normalized_shape, elementwise_affine=elementwise_affine, memory_efficient=memory_efficient, + ).cuda() + native_x, fused_x = _prep_inputs(batch_size, normalized_shape, dtype) expected = native(native_x.cpu()) - with torch.cuda.amp.autocast(dtype=dtype): + with torch.amp.autocast('cuda', dtype=dtype): actual = fused(fused_x) - tols = {'rtol': None, 'atol': None} if dtype == torch.half else TestAutocastFusedRMSNorm.bf16_fwd_thresholds - torch.testing.assert_allclose(actual, expected.detach().clone().cuda(), **tols) + tols = {'rtol': None, 'atol': None} if dtype == torch.half else bf16_fwd_thresholds + torch.testing.assert_close(actual, expected.detach().clone().cuda(), **tols, check_dtype=False) g_native = torch.rand_like(expected) with torch.no_grad(): @@ -286,10 +272,100 @@ def _run_test(self, dtype, elementwise_affine): expected.backward(g_native) actual.backward(g_fused) - tols = {'rtol': None, 'atol': None} if dtype == torch.half else TestAutocastFusedRMSNorm.bf16_bwd_thresholds - torch.testing.assert_allclose(native_x.grad.cuda(), fused_x.grad, **tols) - - def test_autocast(self): - for (dtype, elementwise_affine) in itertools.product(autocast_dtypes, (True, False)): - with self.subTest(f"{dtype}-{elementwise_affine}"): - self._run_test(dtype, elementwise_affine) + tols = {'rtol': 1e-3, 'atol': 1e-3} if dtype == torch.half else bf16_bwd_thresholds + torch.testing.assert_close(native_x.grad.cuda(), fused_x.grad, **tols, check_dtype=False) + + def _verify_export(self, fused, fused_x): + # check that export() is working + import io + f = io.BytesIO() + torch.onnx.export(fused, (fused_x,), f, + input_names=['x_in'], + opset_version=18, + ) + # Load the ONNX model + import onnx + model_onnx = onnx.load_from_string(f.getvalue()) + # Get string representation + onnx_str = onnx.helper.printable_graph(model_onnx.graph) + + assert 'x_in' in onnx_str + assert 'ReduceMean' in onnx_str or 'LayerNormalization' in onnx_str + + def test_rms_export(self): + batch_size = 16 + normalized_shape = [32, 16] + fused = FusedRMSNorm( + normalized_shape=normalized_shape, elementwise_affine=True + ).cuda() + fused_m = MixedFusedRMSNorm( + normalized_shape=normalized_shape + ).cuda() + native_x, fused_x = _prep_inputs(batch_size, normalized_shape, torch.float32) + self._verify_export(fused, fused_x) + self._verify_export(fused_m, fused_x) + + def test_layer_norm_export(self): + batch_size = 16 + normalized_shape = [32, 16] + fused = FusedLayerNorm( + normalized_shape=normalized_shape, elementwise_affine=True + ).cuda() + fused_m = MixedFusedLayerNorm( + normalized_shape=normalized_shape + ).cuda() + native_x, fused_x = _prep_inputs(batch_size, normalized_shape, torch.float32) + self._verify_export(fused, fused_x) + self._verify_export(fused_m, fused_x) + + @common_utils.parametrize("elementwise_affine", (True, False)) + def test_compile_fused_layer_norm(self, elementwise_affine): + batch_size = 16 + normalized_shape = [32, 16] + eager_mod = FusedLayerNorm( + normalized_shape=normalized_shape, elementwise_affine=elementwise_affine + ).cuda() + compiled_mod = torch.compile(fullgraph=True)(eager_mod) + input_shape = [batch_size] + normalized_shape + eager_x = torch.randn(input_shape, device="cuda").requires_grad_(True) + compiled_x = eager_x.detach().clone().requires_grad_(True) + + expected = eager_mod(eager_x) + actual = compiled_mod(compiled_x) + torch.testing.assert_close(actual, expected.detach()) + + g_eager = torch.rand_like(expected) + with torch.no_grad(): + g_compiled = g_eager.detach().clone() + expected.backward(g_eager) + actual.backward(g_compiled) + + torch.testing.assert_close(eager_x.grad, compiled_x.grad) + + @common_utils.parametrize("elementwise_affine", (True, False)) + def test_compile_fused_rms_norm(self, elementwise_affine): + batch_size = 16 + normalized_shape = [32, 16] + eager_mod = FusedRMSNorm( + normalized_shape=normalized_shape, elementwise_affine=elementwise_affine + ).cuda() + compiled_mod = torch.compile(fullgraph=True)(eager_mod) + input_shape = [batch_size] + normalized_shape + eager_x = torch.randn(input_shape, device="cuda").requires_grad_(True) + compiled_x = eager_x.detach().clone().requires_grad_(True) + + expected = eager_mod(eager_x) + actual = compiled_mod(compiled_x) + torch.testing.assert_close(actual, expected.detach()) + + g_eager = torch.rand_like(expected) + with torch.no_grad(): + g_compiled = g_eager.detach().clone() + expected.backward(g_eager) + actual.backward(g_compiled) + + torch.testing.assert_close(eager_x.grad, compiled_x.grad) + +instantiate_device_type_tests(TestFusedLayerNorm, globals(), only_for=("cuda",)) +if __name__ == "__main__": + common_utils.run_tests() \ No newline at end of file diff --git a/tests/L0/run_mlp/test_mlp.py b/tests/L0/run_mlp/test_mlp.py index 9ccda566d..09ebddee1 100644 --- a/tests/L0/run_mlp/test_mlp.py +++ b/tests/L0/run_mlp/test_mlp.py @@ -7,6 +7,7 @@ from torch import nn from apex.mlp import MLP +from apex.testing.common_utils import skipFlakyTest batch_size = 1024 mlp_sizes = [480, 1024, 1024, 512, 256, 1] @@ -17,6 +18,7 @@ class TestMLP(unittest.TestCase): def test_creation(self): MLP(mlp_sizes) + @skipFlakyTest def test_numeric(self): mlp = MLP(mlp_sizes).cuda() @@ -37,7 +39,7 @@ def test_numeric(self): np.testing.assert_allclose( mlp_out.detach().cpu().numpy(), ref_out.detach().cpu().numpy(), - atol=1e-7, rtol=1e-5) + atol=1e-5, rtol=1e-5) # Use mean value as scalar loss. Multiply 10 to make it big enough not zero out mlp_out.mean().mul(10.).backward() @@ -45,12 +47,13 @@ def test_numeric(self): np.testing.assert_allclose( test_input.grad.detach().cpu().numpy(), ref_input.grad.detach().cpu().numpy(), - atol=0, rtol=1e-5) + atol=1e-5, rtol=1e-5) np.testing.assert_allclose( mlp.biases[0].grad.detach().cpu().numpy(), ref_mlp[0].bias.grad.detach().cpu().numpy(), - atol=1e-7, rtol=1e-5) + atol=1e-5, rtol=1e-5) + @skipFlakyTest def test_no_bias(self): for use_activation in ['none', 'relu', 'sigmoid']: mlp = MLP(mlp_sizes, bias=False, activation=use_activation).cuda() @@ -74,7 +77,7 @@ def test_no_bias(self): np.testing.assert_allclose( mlp_out.detach().cpu().numpy(), ref_out.detach().cpu().numpy(), - atol=1e-7, rtol=1e-5) + atol=1e-5, rtol=1e-5) # Use mean value as scalar loss. Multiply 10 to make it big enough not zero out mlp_out.mean().mul(10.).backward() @@ -82,12 +85,13 @@ def test_no_bias(self): np.testing.assert_allclose( test_input.grad.detach().cpu().numpy(), ref_input.grad.detach().cpu().numpy(), - atol=0, rtol=100) + atol=1e-5, rtol=100) np.testing.assert_allclose( mlp.weights[0].grad.detach().cpu().numpy(), ref_mlp[0].weight.grad.detach().cpu().numpy(), - atol=1e-7, rtol=100) + atol=1e-5, rtol=100) + @skipFlakyTest def test_with_bias(self): for use_activation in ['none', 'relu', 'sigmoid']: mlp = MLP(mlp_sizes, bias=True, activation=use_activation).cuda() @@ -112,7 +116,7 @@ def test_with_bias(self): np.testing.assert_allclose( mlp_out.detach().cpu().numpy(), ref_out.detach().cpu().numpy(), - atol=1e-7, rtol=1e-5) + atol=1e-5, rtol=1e-5) # Use mean value as scalar loss. Multiply 10 to make it big enough not zero out mlp_out.mean().mul(10.).backward() @@ -120,16 +124,17 @@ def test_with_bias(self): np.testing.assert_allclose( test_input.grad.detach().cpu().numpy(), ref_input.grad.detach().cpu().numpy(), - atol=0, rtol=1) + atol=1e-5, rtol=1) np.testing.assert_allclose( mlp.weights[0].grad.detach().cpu().numpy(), ref_mlp[0].weight.grad.detach().cpu().numpy(), - atol=1e-7, rtol=1) + atol=1e-5, rtol=1) np.testing.assert_allclose( mlp.biases[0].grad.detach().cpu().numpy(), ref_mlp[0].bias.grad.detach().cpu().numpy(), - atol=1e-7, rtol=1e-5) + atol=1e-5, rtol=1e-5) + @skipFlakyTest def test_no_grad(self): mlp = MLP(mlp_sizes).cuda() @@ -150,7 +155,7 @@ def test_no_grad(self): np.testing.assert_allclose( mlp_out.detach().cpu().numpy(), ref_out.detach().cpu().numpy(), - atol=1e-7, rtol=1e-5) + atol=1e-5, rtol=1e-5) # Use mean value as scalar loss. Multiply 10 to make it big enough not zero out mlp_out.mean().mul(10.).backward() @@ -158,8 +163,7 @@ def test_no_grad(self): np.testing.assert_allclose( mlp.weights[0].grad.detach().cpu().numpy(), ref_mlp[0].weight.grad.detach().cpu().numpy(), - atol=1e-7, rtol=1e-5) - + atol=1e-5, rtol=1e-5) def test_performance_half(self): mlp = MLP(mlp_sizes).cuda().half() @@ -190,7 +194,7 @@ def test_performance_half(self): mlp.zero_grad() test_loss.backward() - torch.cuda.profiler.start() + #torch.cuda.profiler.start() torch.cuda.synchronize() start_time = time() for _ in range(num_iters): @@ -212,7 +216,7 @@ def test_performance_half(self): torch.cuda.synchronize() stop_time = time() print(F"C++ MLP time {(stop_time - start_time) * 1000. / num_iters:.4f} ms") - torch.cuda.profiler.stop() + #torch.cuda.profiler.stop() if __name__ == '__main__': unittest.main() diff --git a/tests/L0/run_optimizers/test_adam.py b/tests/L0/run_optimizers/test_adam.py new file mode 100644 index 000000000..9fd00cbea --- /dev/null +++ b/tests/L0/run_optimizers/test_adam.py @@ -0,0 +1,254 @@ +import copy +import math +import random +import unittest + +import torch +import torch.nn.functional as F +from torch import nn +from torch.testing._internal.common_device_type import largeTensorTest + +try: + import apex +except ImportError as e: + HAS_APEX = False +else: + HAS_APEX = True + + +class Model(torch.nn.Module): + def __init__(self): + super(Model, self).__init__() + self.conv1 = nn.Conv2d(1, 6, 5) + self.relu1 = nn.ReLU() + self.pool1 = nn.MaxPool2d(2) + self.conv2 = nn.Conv2d(6, 16, 5) + self.relu2 = nn.ReLU() + self.pool2 = nn.MaxPool2d(2) + self.fc1 = nn.Linear(256, 120) + self.relu3 = nn.ReLU() + self.fc2 = nn.Linear(120, 84) + self.relu4 = nn.ReLU() + self.fc3 = nn.Linear(84, 10) + self.relu5 = nn.ReLU() + + def forward(self, x): + y = self.conv1(x) + y = self.relu1(y) + y = self.pool1(y) + y = self.conv2(y) + y = self.relu2(y) + y = self.pool2(y) + y = y.reshape(y.shape[0], -1) + y = self.fc1(y) + y = self.relu3(y) + y = self.fc2(y) + y = self.relu4(y) + y = self.fc3(y) + y = self.relu5(y) + return y + + +@unittest.skipIf(not HAS_APEX, "`apex` is not found.") +class AdamTest(unittest.TestCase): + def setUp(self, seed=0): + super().setUp() + torch.manual_seed(seed) + + self.model = Model().cuda() + self.model_ = Model().cuda() + self.model_.load_state_dict(copy.deepcopy(self.model.state_dict())) + + self.lr = 0.00001 + params = [p for p in self.model.parameters() if p.requires_grad] + self.optimizer = torch.optim.Adam(params, lr=self.lr) + + def testGradScaler(self): + params_ = [p for p in self.model_.parameters() if p.requires_grad] + optimizer_ = apex.optimizers.FusedAdam(params_, lr=self.lr, capturable=False) + scaler = torch.amp.GradScaler('cuda', enabled=True) + scaler_ = torch.amp.GradScaler('cuda', enabled=True) + + for i in range(100): + x = torch.rand([32, 1, 28, 28]).cuda().to(memory_format=torch.channels_last) + x_ = x.clone() + gt = torch.rand([32, 10]).cuda() + gt_ = gt.clone() + + # Reference + with torch.amp.autocast('cuda', enabled=True): + y = self.model(x) + loss = ((gt - y) ** 2).mean() + + scaler.scale(loss).backward() + scaler.step(self.optimizer) + scaler.update() + + # DUT + with torch.amp.autocast('cuda', enabled=True): + y = self.model_(x) + loss_ = ((gt_ - y) ** 2).mean() + + scaler_.scale(loss_).backward() + scaler_.step(optimizer_) + scaler_.update() + + for module in zip(self.model.modules(), self.model_.modules()): + m = module[0] + m_ = module[1] + if isinstance(m, nn.Conv2d) or isinstance(m_, nn.Linear): + torch.testing.assert_close(m.weight, m_.weight, atol=1e-3, rtol=1e-3, equal_nan=True) + torch.testing.assert_close(m.weight.grad, m_.weight.grad, atol=1e-3, rtol=1e-3, equal_nan=True) + + # Init for next iteration + self.optimizer.zero_grad() + optimizer_.zero_grad() + + self.model_.load_state_dict(copy.deepcopy(self.model.state_dict())) + + def testGradScalerCapturable(self): + params_ = [p for p in self.model_.parameters() if p.requires_grad] + optimizer_ = apex.optimizers.FusedAdam(params_, lr=self.lr, capturable=True) + scaler = torch.amp.GradScaler('cuda', enabled=True) + scaler_ = torch.amp.GradScaler('cuda', enabled=True) + + for i in range(100): + x = torch.rand([32, 1, 28, 28]).cuda().to(memory_format=torch.channels_last) + x_ = x.clone() + gt = torch.rand([32, 10]).cuda() + gt_ = gt.clone() + + # Reference + with torch.amp.autocast('cuda', enabled=True): + y = self.model(x) + loss = ((gt - y) ** 2).mean() + + scaler.scale(loss).backward() + scaler.step(self.optimizer) + scaler.update() + + # DUT + with torch.amp.autocast('cuda', enabled=True): + y = self.model_(x) + loss_ = ((gt_ - y) ** 2).mean() + + scaler_.scale(loss_).backward() + scaler_.step(optimizer_) + scaler_.update() + + for module in zip(self.model.modules(), self.model_.modules()): + m = module[0] + m_ = module[1] + if isinstance(m, nn.Conv2d) or isinstance(m_, nn.Linear): + torch.testing.assert_close(m.weight, m_.weight, atol=1e-3, rtol=1e-3, equal_nan=True) + torch.testing.assert_close(m.weight.grad, m_.weight.grad, atol=1e-3, rtol=1e-3, equal_nan=True) + + # Init for next iteration + self.optimizer.zero_grad() + optimizer_.zero_grad() + + self.model_.load_state_dict(copy.deepcopy(self.model.state_dict())) + + def testGradScalerCapturableMaster(self): + # Cast conv layers to FP16 + for m in self.model_.modules(): + if m.__class__ in [torch.nn.Conv2d]: + m.half() + params_ = [p for p in self.model_.parameters() if p.requires_grad] + optimizer_ = apex.optimizers.FusedAdam(params_, lr=self.lr, capturable=True, master_weights=True) + scaler = torch.amp.GradScaler('cuda', enabled=True) + scaler_ = torch.amp.GradScaler('cuda', enabled=True) + + for i in range(100): + x = torch.rand([32, 1, 28, 28]).cuda().to(memory_format=torch.channels_last) + x_ = x.clone() + gt = torch.rand([32, 10]).cuda() + gt_ = gt.clone() + + # Reference + with torch.amp.autocast('cuda', enabled=True): + y = self.model(x) + loss = ((gt - y) ** 2).mean() + + scaler.scale(loss).backward() + scaler.step(self.optimizer) + scaler.update() + + # DUT + with torch.amp.autocast('cuda', enabled=True): + y = self.model_(x) + loss_ = ((gt_ - y) ** 2).mean() + + scaler_.scale(loss_).backward() + scaler_.step(optimizer_) + scaler_.update() + + for module in zip(self.model.modules(), self.model_.modules()): + m = module[0] + m_ = module[1] + if isinstance(m, nn.Conv2d) or isinstance(m_, nn.Linear): + torch.testing.assert_close(m.weight, m_.weight.float(), atol=1e-3, rtol=1e-3, equal_nan=True) + torch.testing.assert_close(m.weight.grad, m_.weight.grad.float(), atol=1e-3, rtol=1e-3, equal_nan=True) + + # Init for next iteration + self.optimizer.zero_grad() + optimizer_.zero_grad() + + self.model_.load_state_dict(copy.deepcopy(self.model.state_dict())) + + def testNative(self): + params_ = [p for p in self.model_.parameters() if p.requires_grad] + optimizer_ = apex.optimizers.FusedAdam(params_, lr=self.lr, capturable=False) + + for i in range(100): + x = torch.rand([32, 1, 28, 28]).cuda().to(memory_format=torch.channels_last) + x_ = x.clone() + gt = torch.rand([32, 10]).cuda() + gt_ = gt.clone() + + # Reference + y = self.model(x) + loss = ((gt - y) ** 2).mean() + + loss.backward() + self.optimizer.step() + + # DUT + y = self.model_(x) + loss_ = ((gt_ - y) ** 2).mean() + + loss_.backward() + optimizer_.step() + + for module in zip(self.model.modules(), self.model_.modules()): + m = module[0] + m_ = module[1] + if isinstance(m, nn.Conv2d) or isinstance(m_, nn.Linear): + torch.testing.assert_close(m.weight, m_.weight, atol=1e-3, rtol=1e-3, equal_nan=True) + torch.testing.assert_close(m.weight.grad, m_.weight.grad, atol=1e-3, rtol=1e-3, equal_nan=True) + + # Init for next iteration + self.optimizer.zero_grad() + optimizer_.zero_grad() + + self.model_.load_state_dict(copy.deepcopy(self.model.state_dict())) + + @largeTensorTest('60GB', 'cuda') + def testLargeTensor(self): + t = torch.zeros(2359332864, dtype=torch.half, device='cuda') + t2 = torch.zeros(2359332864, dtype=torch.half, device='cuda') + grad = torch.randn_like(t) + t.grad = grad + t2.grad = grad + params = [t] + params2 = [t2] + optimizer = apex.optimizers.FusedAdam(params, lr=self.lr) + optimizer.step() + optimizer2 = torch.optim.Adam(params2, lr=self.lr) + torch.testing.assert_close(t, t2) + torch.cuda.synchronize() + + +if __name__ == '__main__': + unittest.main() + diff --git a/tests/L0/run_optimizers/test_fused_optimizer.py b/tests/L0/run_optimizers/test_fused_optimizer.py index e4c86ef9f..3a969d3a2 100644 --- a/tests/L0/run_optimizers/test_fused_optimizer.py +++ b/tests/L0/run_optimizers/test_fused_optimizer.py @@ -6,6 +6,8 @@ import apex +from apex.testing.common_utils import skipIfRocm + class TestFusedOptimizer(unittest.TestCase): def setUp(self, max_abs_diff=1e-3, max_rel_diff=1, iters=7): @@ -91,6 +93,8 @@ class TestFusedAdam(TestFusedOptimizer): def setUp(self): super().setUp() self.options = {'lr':5e-4, 'betas':(0.9, 0.999), 'eps':1e-08, + 'weight_decay': 0, 'amsgrad': False, "capturable": True} + self.tst_options = {'lr':5e-4, 'betas':(0.9, 0.999), 'eps':1e-08, 'weight_decay': 0, 'amsgrad': False} self.ref_optim = torch.optim.Adam self.fused_optim = apex.optimizers.FusedAdam @@ -100,9 +104,11 @@ def test_float(self): # NOTE(mkozuki): Current threshold values look too small for BFloat16. # TODO(mkozuki): Refactor `TestFusedOptimizer` + @unittest.skip("NaN issue observed on ROCm as of 12/1/2021. The failing unit test is introduced by a PyTorch commit sometime in between rocm/pytorch:rocm4.3.1_ubuntu18.04_py3.6_pytorch_1.9.0 and 2021/12/01. Please refer to https://github.com/ROCmSoftwarePlatform/apex/issues/63") def test_half(self): self.gen_single_type_test(param_type=torch.float16, skip_assert=True) + @skipIfRocm def test_bfloat16(self): self.gen_single_type_test(param_type=torch.bfloat16, skip_assert=True) @@ -175,11 +181,14 @@ def test_fp16_output(self): def test_adam_option(self): nelem = 1 adam_option = {'lr':0.01, 'betas':(0.6, 0.9), 'eps':3e-06, + 'weight_decay':0, 'amsgrad':False, 'capturable':True} + + adam_option_tst = {'lr':0.01, 'betas':(0.6, 0.9), 'eps':3e-06, 'weight_decay':0, 'amsgrad':False} tensor = torch.rand(nelem, dtype=torch.float, device='cuda') ref_param, tst_param, ref_optim, tst_optim = \ - self.gen_param_optim([tensor], adam_option) + self.gen_param_optim([tensor], adam_option, adam_option_tst) for i in range(self.iters): self.gen_grad(ref_param, tst_param) diff --git a/tests/L0/run_optimizers/test_fused_optimizer_channels_last.py b/tests/L0/run_optimizers/test_fused_optimizer_channels_last.py new file mode 100644 index 000000000..7db329bce --- /dev/null +++ b/tests/L0/run_optimizers/test_fused_optimizer_channels_last.py @@ -0,0 +1,112 @@ +from itertools import product +import random +import unittest + +import torch + +import apex + +# NHWC +class TestFusedOptimizerChannelsLast(unittest.TestCase): + def setUp(self, max_abs_diff=1e-3, max_rel_diff=1, iters=7): + self.max_abs_diff = max_abs_diff + self.max_rel_diff = max_rel_diff + self.iters = iters + torch.manual_seed(9876) + + def tearDown(self): + pass + + def gen_param_optim(self, tensors, options, device, tst_options=None): + + # Adding this to make backward compatible with existing tests. Just in + # case "tst_options" are not provided, it gets a copy of options + # which contains the parameters for the reference optimizer + if tst_options == None: + tst_options = options + + ref_param = [] + tst_param = [] + for tensor in tensors: + input = tensor.clone().contiguous(memory_format=torch.channels_last).to(device) # channels_last + ref_input = tensor.clone().contiguous().to(device) + + self.assertTrue(input.is_contiguous(memory_format=torch.channels_last)) + self.assertTrue(ref_input.is_contiguous(memory_format=torch.contiguous_format)) + + tst_param.append(torch.nn.Parameter(input)) + ref_param.append(torch.nn.Parameter(ref_input)) + + ref_optim = self.ref_optim(ref_param, **options) + tst_optim = self.fused_optim(tst_param, **tst_options) + return (ref_param, tst_param, ref_optim, tst_optim) + + def gen_grad(self, ref_param, tst_param): + for p_ref, p_tst in zip(ref_param, tst_param): + p_ref.grad = torch.rand_like(p_ref) + p_tst.grad = p_ref.grad.clone() #### p_tst is =torch.channels_last but p_tst.grad is torch.contiguous_format + + self.assertTrue(p_tst.grad.is_contiguous(memory_format=torch.contiguous_format)) + self.assertTrue(p_ref.grad.is_contiguous(memory_format=torch.contiguous_format)) + + + def get_max_diff(self, ref_param, tst_param): + max_abs_diff = max_rel_diff = 0 + for p_ref, p_tst in zip(ref_param, tst_param): + self.assertTrue(p_ref.is_contiguous(memory_format=torch.contiguous_format)) + self.assertTrue(p_tst.is_contiguous(memory_format=torch.channels_last)) + max_abs_diff_p = (p_ref - p_tst).abs().max().item() + max_rel_diff_p = ((p_ref - p_tst) / p_ref).abs().max().item() + + if max_abs_diff_p > max_abs_diff: max_abs_diff = max_abs_diff_p + if max_rel_diff_p > max_rel_diff: max_rel_diff = max_rel_diff_p + + return max_abs_diff, max_rel_diff + + def gen_single_type_test(self, param_type=torch.float, device='cuda', *, skip_assert: bool = False): + # nelem = 278011 + + # Some ref and test optimizers may require different set of options. + # This is a quick workaround to add that functionality while making + # minimum changes in existing code. + # If there is no "tst_options" field provided, safe to initialize + # the test optimizer with the parameters of reference optimizer. + if not hasattr(self, 'tst_options'): + self.tst_options = self.options + + tensor = torch.rand([3,4,2,3], dtype=param_type, device=device) + ref_param, tst_param, ref_optim, tst_optim = \ + self.gen_param_optim([tensor], self.options, device, self.tst_options) + + for i in range(self.iters): + self.gen_grad(ref_param, tst_param) + ref_optim.step() + tst_optim.step() + if skip_assert: + return + max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param) + self.assertLessEqual(max_abs_diff, self.max_abs_diff) + self.assertLessEqual(max_rel_diff, self.max_rel_diff) + +class TestFusedSGDChannelLast(TestFusedOptimizerChannelsLast): + def __init__(self, *args, **kwargs): + super(TestFusedSGDChannelLast, self).__init__(*args, **kwargs) + self.options = {"lr": .25, "momentum": .125} + self.ref_optim = torch.optim.SGD + self.fused_optim = apex.optimizers.FusedSGD + + def test_float(self): + self.gen_single_type_test(param_type=torch.float) + + def test_half(self): + self.gen_single_type_test(param_type=torch.float16) + + @unittest.skipIf(torch.cuda.device_count()<2, "more than 1 GPU required") + def test_multi_device(self): + devices = ("cuda:0", "cuda:1") + for current_dev, tensor_dev in product(devices, devices): + with torch.cuda.device(current_dev): + self.gen_single_type_test(param_type=torch.float, device=tensor_dev) + +if __name__ == '__main__': + unittest.main() diff --git a/tests/L0/run_optimizers/test_lamb.py b/tests/L0/run_optimizers/test_lamb.py index 4900fe5af..c6ef9aa95 100644 --- a/tests/L0/run_optimizers/test_lamb.py +++ b/tests/L0/run_optimizers/test_lamb.py @@ -285,6 +285,7 @@ def test_float(self): def test_half(self): self.gen_single_type_test(param_type=torch.float16) + @unittest.skip("Skipped the test since it failed the accuracy test on the PyTorch as of 8/1/2022. Please refer to https://github.com/ROCmSoftwarePlatform/apex/issues/83") @unittest.skipIf(torch.cuda.device_count()<2, "more than 1 GPU required") def test_multi_device(self): devices = ("cuda:0", "cuda:1") diff --git a/tests/L0/run_rocm.sh b/tests/L0/run_rocm.sh new file mode 100755 index 000000000..32405e7ab --- /dev/null +++ b/tests/L0/run_rocm.sh @@ -0,0 +1,2 @@ +#!/bin/bash +APEX_TEST_WITH_ROCM=1 APEX_SKIP_FLAKY_TEST=1 python run_test.py diff --git a/tests/L0/run_test.py b/tests/L0/run_test.py index 32db6f564..ed84fe956 100644 --- a/tests/L0/run_test.py +++ b/tests/L0/run_test.py @@ -13,20 +13,30 @@ import unittest import sys +from apex.testing.common_utils import TEST_WITH_ROCM +from apex.testing.common_utils import SKIP_FLAKY_TEST TEST_ROOT = os.path.dirname(os.path.abspath(__file__)) + +#the tests that are allowed TEST_DIRS = [ "run_amp", "run_fp16util", "run_optimizers", "run_fused_layer_norm", "run_mlp", - "run_transformer", + "run_fused_dense", + "run_transformer", ] + +#the tests that are run by default DEFAULT_TEST_DIRS = [ + "run_amp", + "run_fp16util", "run_optimizers", "run_fused_layer_norm", "run_mlp", + "run_fused_dense", "run_transformer", ] @@ -56,9 +66,7 @@ def main(args): suite = unittest.TestLoader().discover(test_dir) print("\nExecuting tests from " + test_dir) - result = runner.run(suite) - if not result.wasSuccessful(): errcode = 1 @@ -68,3 +76,4 @@ def main(args): if __name__ == '__main__': args = parse_args() main(args) + diff --git a/tests/L0/run_transformer/test_fused_bias_swiglu.py b/tests/L0/run_transformer/test_fused_bias_swiglu.py new file mode 100644 index 000000000..e7c2e4793 --- /dev/null +++ b/tests/L0/run_transformer/test_fused_bias_swiglu.py @@ -0,0 +1,55 @@ +import torch +import fused_bias_swiglu +from torch.testing._internal import common_utils +import torch.nn.functional as F + + +class TestFusedBiasSwiGLU(common_utils.TestCase): + + def swiglu(self, y): + y_1, y_2 = torch.chunk(y, 2, -1) + return F.silu(y_1) * y_2 + + def bias_swiglu(self, y, bias): + y = y + bias + return self.swiglu(y) + + def swiglu_back(self, g, y): + y_1, y_2 = torch.chunk(y, 2, -1) + return torch.cat( + (g * torch.sigmoid(y_1) * (1 + y_1 * (1 - torch.sigmoid(y_1))) * y_2, g * F.silu(y_1)), -1 + ) + + def bias_swiglu_back(self, g, y, bias): + y = y + bias + return self.swiglu_back(g, y) + + def test_fused_bias_swiglu(self): + # Inputs + batch_size, hidden_dim = 16, 512 + dtypes = [torch.float32, torch.float64, torch.float16] + + for dtype in dtypes: + print(f"Testing with data type: {dtype}") + input = torch.randn(batch_size, hidden_dim, device="cuda", dtype=dtype) + bias = torch.randn(hidden_dim, device="cuda", dtype=dtype) + + try: + actual = fused_bias_swiglu.forward(input, bias) + expected = self.bias_swiglu(input, bias) + + self.assertEqual(actual, expected, atol=1e-3, rtol=1e-3) + + grad_output = torch.randn(batch_size, hidden_dim // 2, device="cuda", dtype=dtype) # Output gradient + actual_grad = fused_bias_swiglu.backward(grad_output, input, bias) + expected_grad = self.bias_swiglu_back(grad_output, input, bias) + self.assertEqual(actual_grad, expected_grad, atol=1e-3, rtol=1e-3) + + print(f"Test succeeded for data type: {dtype}") + except AssertionError as e: + print(f"Test failed for data type: {dtype}") + print(e) + + +if __name__ == "__main__": + common_utils.run_tests() \ No newline at end of file diff --git a/tests/L0/run_transformer/test_fused_rope.py b/tests/L0/run_transformer/test_fused_rope.py new file mode 100644 index 000000000..f578867a4 --- /dev/null +++ b/tests/L0/run_transformer/test_fused_rope.py @@ -0,0 +1,351 @@ +"""Test for fused RoPE functions. + +Ref: https://github.com/NVIDIA/Megatron-LM/blob/40becfc96c4144985458ac0e0fae45dbb111fbd2/megatron/fused_kernels/tests/test_fused_kernels.py +""" # NOQA + +import itertools + +import torch +from torch.testing._internal import common_utils +from apex.transformer.functional import ( + fused_apply_rotary_pos_emb, + fused_apply_rotary_pos_emb_cached, + fused_apply_rotary_pos_emb_thd, + fused_apply_rotary_pos_emb_2d, +) +from apex.transformer.functional.fused_rope import AITER_ROPE_BACKEND +ERROR_TOLERANCE=1e-3 +if AITER_ROPE_BACKEND: + ERROR_TOLERANCE=1e-2 + +def _rotate_half(x: torch.Tensor) -> torch.Tensor: + """Change sign so the last dimension becomes [-odd, +even] + + Args: + x (Tensor): Input tensor + + Returns: + Tensor: Tensor rotated half + """ + + x1, x2 = torch.chunk(x, 2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + +# Copied from Megatron-Core for testing. +# https://github.com/NVIDIA/Megatron-LM/blob/5f2877d85cb26e47ce6dcdae4b80adf376abf4e8/megatron/core/models/common/embeddings/rotary_pos_embedding.py#L139 +def apply_rotary_pos_emb(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + """Apply rotary positional embedding to input tensor T. + + check https://kexue.fm/archives/8265 for detailed formulas + + Args: + t (Tensor): Input tensor T is of shape [seq_length, ... , dim] + freqs (Tensor): Rotary Positional embedding tensor freq is of shape [seq_length, ..., dim] + + Returns: + Tensor: The input tensor after applying RoPE + """ + rot_dim = freqs.shape[-1] + + # ideally t_pass is empty so rotary pos embedding is applied to all tensor t + t, t_pass = t[..., :rot_dim], t[..., rot_dim:] + + # first part is cosine component + # second part is sine component, need to change signs with _rotate_half method + cos_ = torch.cos(freqs).to(t.dtype) + sin_ = torch.sin(freqs).to(t.dtype) + + t = (t * cos_) + (_rotate_half(t) * sin_) + return torch.cat((t, t_pass), dim=-1) + + +def apply_rotary_pos_emb_thd( + t: torch.Tensor, cu_seqlens: torch.Tensor, freqs: torch.Tensor +) -> torch.Tensor: + """A baseline implementation of applying RoPE for `thd` format. + + Args: + t (Tensor): Input tensor T is of shape [t, h, d] + cu_seqlens(Tensor): Cumulative sum of sequence lengths in a batch for `t`, + with shape [b + 1] and dtype torch.int32. + freqs (Tensor): Rotary Positional embedding tensor freq is of shape [max_s, 1, 1, d] + + Returns: + Tensor: Shape [t, h, d]. The input tensor after applying RoPE. + """ + seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + return torch.cat( + [ + apply_rotary_pos_emb(x.unsqueeze(1), freqs[: x.size(0)]) + for x in torch.split(t, seqlens) + ] + ).squeeze(1) + + +def apply_rotary_pos_emb_2d(q, img_h, img_w, cos_h, sin_h, cos_w, sin_w): + q = q.view(q.shape[0], img_h, img_w, q.shape[2], q.shape[3]) + q1, q2 = q.chunk(2, dim=-1) + cos_h = cos_h[:, :img_h].unsqueeze(2) # [1, H, 1, 1, D//2] + sin_h = sin_h[:, :img_h].unsqueeze(2) # [1, H, 1, 1, D//2] + q1 = (q1 * cos_h) + (_rotate_half(q1) * sin_h) + cos_w = cos_w[:, :img_w].unsqueeze(1) # [1, 1, W, 1, D//2] + sin_w = sin_w[:, :img_w].unsqueeze(1) # [1, 1, W, 1, D//2] + q2 = (q2 * cos_w) + (_rotate_half(q2) * sin_w) + return torch.cat([q1, q2], dim=-1).view(q.shape[0], -1, q.shape[3], q.shape[4]) + + +class TestFusedRoPE(common_utils.TestCase): + def setUp(self): + super().setUp() + self.batch_size = 2 + self.head_num = 64 + self.seq_length = [2048, 4096] + self.hidden_size = [128, 256] + self.rotary_percent = [0.5, 1.0] + self.dtype = [torch.float32, torch.bfloat16, torch.float16] + self.transpose = [None, (0, 1), (2, 3)] + self.transpose_output_memory = [False, True] + self.loss_func = [self._overlapping_grad, self._non_overlapping_grad] + self.cached = [False, True] + self.device = torch.cuda.current_device() + # for 2D RoPE + self.img_h = [32, 64] + self.img_w = [32, 64] + + def tearDown(self) -> None: + torch.cuda.empty_cache() + super().tearDown() + + def _overlapping_grad(self, output) -> torch.Tensor: + return output.sum() * 2 + + def _non_overlapping_grad(self, output) -> torch.Tensor: + t = torch.ones_like(output) + return torch.sum(output * t) + + def test_forward_backward(self): + for ( + dtype, + seq_length, + hidden_size, + rotary_percent, + transpose, + transpose_output_memory, + loss_func, + cached, + ) in itertools.product( + self.dtype, + self.seq_length, + self.hidden_size, + self.rotary_percent, + self.transpose, + self.transpose_output_memory, + self.loss_func, + self.cached, + ): + t = torch.rand( + (seq_length, self.batch_size, self.head_num, hidden_size), + dtype=dtype, + device=self.device, + ) + if transpose: + t = t.transpose(*transpose).contiguous().transpose(*transpose) + t.requires_grad = True + + emb = torch.rand( + (seq_length, 1, 1, int(hidden_size * rotary_percent)), + dtype=torch.float32, + device=self.device, + ) + + # unfused + output_unfused = apply_rotary_pos_emb(t, emb) + loss_unfused = loss_func(output_unfused) + loss_unfused.backward() + grad_unfused = t.grad.detach().clone() + t.grad = None + + # fused + if cached: + cos, sin = emb.cos(), emb.sin() + output_fused = fused_apply_rotary_pos_emb_cached( + t, cos, sin, transpose_output_memory=transpose_output_memory + ) + else: + output_fused = fused_apply_rotary_pos_emb( + t, emb, transpose_output_memory=transpose_output_memory + ) + loss_fused = loss_func(output_fused) + loss_fused.backward() + grad_fused = t.grad.detach().clone() + t.grad = None + + self.assertEqual( + output_unfused, + output_fused, + msg=f"{dtype=}, {seq_length=}, {hidden_size=}, {rotary_percent=}, " + f"{transpose=}, {transpose_output_memory=}, loss_func={loss_func.__name__}", + atol=ERROR_TOLERANCE, + rtol=ERROR_TOLERANCE, + ) + self.assertEqual( + grad_unfused, + grad_fused, + msg=f"{dtype=}, {seq_length=}, {hidden_size=}, {rotary_percent=}, " + f"{transpose=}, {transpose_output_memory=}, loss_func={loss_func.__name__}", + atol=ERROR_TOLERANCE, + rtol=ERROR_TOLERANCE, + ) + assert ( + output_fused.transpose(0, 1).is_contiguous() is transpose_output_memory + ) + + def test_thd_forward_backward(self): + cu_seqlens = torch.tensor( + [0, 400, 542, 711, 727, 752, 1270, 1426, 1450, 1954, 2044, 2048], + dtype=torch.int32, + device=self.device, + ) + for ( + dtype, + hidden_size, + rotary_percent, + transpose, + loss_func, + ) in itertools.product( + self.dtype, + self.hidden_size, + self.rotary_percent, + [None, [1, 2]], + self.loss_func, + ): + t = torch.rand( + (cu_seqlens[-1], self.head_num, hidden_size), + dtype=dtype, + device=self.device, + ) + if transpose: + t = t.transpose(*transpose).contiguous().transpose(*transpose) + t.requires_grad = True + + emb = torch.rand( + (cu_seqlens[-1], 1, 1, int(hidden_size * rotary_percent)), + dtype=torch.float32, + device=self.device, + ) + + # unfused + output_unfused = apply_rotary_pos_emb_thd(t, cu_seqlens, emb) + loss_unfused = loss_func(output_unfused) + loss_unfused.backward() + grad_unfused = t.grad.detach().clone() + t.grad = None + + # fused + output_fused = fused_apply_rotary_pos_emb_thd( + t, + cu_seqlens, + emb, + ) + loss_fused = loss_func(output_fused) + loss_fused.backward() + grad_fused = t.grad.detach().clone() + t.grad = None + + self.assertEqual( + output_unfused, + output_fused, + msg=f"{dtype=}, {cu_seqlens=}, {hidden_size=}, {rotary_percent=}, " + f"{transpose=}, loss_func={loss_func.__name__}", + atol=ERROR_TOLERANCE, + rtol=ERROR_TOLERANCE, + ) + self.assertEqual( + grad_unfused, + grad_fused, + msg=f"{dtype=}, {cu_seqlens=}, {hidden_size=}, {rotary_percent=}, " + f"{transpose=}, loss_func={loss_func.__name__}", + atol=ERROR_TOLERANCE, + rtol=ERROR_TOLERANCE, + ) + + def test_2d_forward_backward(self): + for ( + dtype, + img_h, + img_w, + hidden_size, + transpose, + loss_func, + margin, + ) in itertools.product( + self.dtype, + self.img_h, + self.img_w, + self.hidden_size, + self.transpose, + self.loss_func, + [0, 3], + ): + t = torch.rand( + (self.batch_size, img_h * img_w, self.head_num, hidden_size), + dtype=dtype, + device=self.device, + ) + if transpose: + t = t.transpose(*transpose).contiguous().transpose(*transpose) + t.requires_grad = True + + emb_h = torch.rand( + (1, img_h + margin, 1, hidden_size // 2), + dtype=torch.float32, + device=self.device, + ) + cos_h, sin_h = emb_h.cos().to(dtype), emb_h.sin().to(dtype) + + emb_w = torch.rand( + (1, img_w + margin, 1, hidden_size // 2), + dtype=torch.float32, + device=self.device, + ) + cos_w, sin_w = emb_w.cos().to(dtype), emb_w.sin().to(dtype) + + # unfused + output_unfused = apply_rotary_pos_emb_2d( + t, img_h, img_w, cos_h, sin_h, cos_w, sin_w + ) + loss_unfused = loss_func(output_unfused) + loss_unfused.backward() + grad_unfused = t.grad.detach().clone() + t.grad = None + + # fused + output_fused = fused_apply_rotary_pos_emb_2d( + t, img_h, img_w, cos_h, sin_h, cos_w, sin_w + ) + loss_fused = loss_func(output_fused) + loss_fused.backward() + grad_fused = t.grad.detach().clone() + t.grad = None + + self.assertEqual( + output_unfused, + output_fused, + msg=f"{dtype=}, {img_h=}, {img_w=}, {hidden_size=}, " + f"{transpose=}, loss_func={loss_func.__name__}", + atol=ERROR_TOLERANCE, + rtol=ERROR_TOLERANCE, + ) + self.assertEqual( + grad_unfused, + grad_fused, + msg=f"{dtype=}, {img_h=}, {img_w=}, {hidden_size=}, " + f"{transpose=}, loss_func={loss_func.__name__}", + atol=ERROR_TOLERANCE, + rtol=ERROR_TOLERANCE, + ) + + +if __name__ == "__main__": + common_utils.run_tests() \ No newline at end of file diff --git a/tests/L0/run_transformer/test_layers.py b/tests/L0/run_transformer/test_layers.py index b3b2eb2fc..9f3066907 100644 --- a/tests/L0/run_transformer/test_layers.py +++ b/tests/L0/run_transformer/test_layers.py @@ -398,6 +398,8 @@ def _row_parallel_linear_test_impl( chunks=tensor_model_parallel_world_size, dim=0, )[parallel_state.get_tensor_model_parallel_rank()], + atol=1e-4, + rtol=1e-3 ) parallel_state.destroy_model_parallel() diff --git a/tests/L0/run_transformer/test_pipeline_parallel_fwd_bwd.py b/tests/L0/run_transformer/test_pipeline_parallel_fwd_bwd.py index a409c40f2..1b044b845 100644 --- a/tests/L0/run_transformer/test_pipeline_parallel_fwd_bwd.py +++ b/tests/L0/run_transformer/test_pipeline_parallel_fwd_bwd.py @@ -212,7 +212,8 @@ def _forward_backward_test_impl( params = list(model_module.parameters()) rank = params[0].get_device() offset = pipeline_model_parallel_world_size - param_id = rank // data_parallel_size + vm_id * offset + param_id = parallel_state.get_pipeline_model_parallel_rank() + vm_id * pipeline_model_parallel_world_size + # param_id = rank // data_parallel_size + vm_id * offset target_params = target_model[param_id] self.assertEqual(params[0].cpu(), target_params[0]) diff --git a/tests/distributed/amp_master_params/amp_master_params.py b/tests/distributed/amp_master_params/amp_master_params.py index 4af5092f7..4b3a80498 100644 --- a/tests/distributed/amp_master_params/amp_master_params.py +++ b/tests/distributed/amp_master_params/amp_master_params.py @@ -9,6 +9,7 @@ # FOR DISTRIBUTED: Parse for the local_rank argument, which will be supplied # automatically by torch.distributed.launch. parser.add_argument("--local_rank", default=0, type=int) +parser.add_argument("--opt_level", default="O2", type=str) args = parser.parse_args() # FOR DISTRIBUTED: If we are running under torch.distributed.launch, @@ -42,7 +43,7 @@ model = torch.nn.Linear(D_in, D_out).cuda() optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) -model, optimizer = amp.initialize(model, optimizer, opt_level="O2") +model, optimizer = amp.initialize(model, optimizer, opt_level=args.opt_level) if args.distributed: # FOR DISTRIBUTED: After amp.initialize, wrap the model with diff --git a/tests/distributed/amp_master_params/compare.py b/tests/distributed/amp_master_params/compare.py index e5cbf20c1..b8047752a 100644 --- a/tests/distributed/amp_master_params/compare.py +++ b/tests/distributed/amp_master_params/compare.py @@ -14,6 +14,9 @@ model_params_rank1, master_params_rank0, master_params_rank1): + # converting model params to float is a hack since allclose doesn't support bfloat16 yet. + model_rank0 = model_rank0.float() + model_rank1 = model_rank1.float() assert torch.allclose(model_rank0, model_rank1), "Model param mismatch" assert torch.allclose(master_rank0, master_rank1), "Master param mismatch" # Some debugging/investigation assistance code: @@ -23,6 +26,6 @@ # print(maxval.item(), maxind.item(), offending_val_half.item(), offending_val_float.item(), # offending_val_float.half().item()) # rtol needs to be > 2^-11 because of denormals... - assert torch.allclose(model_rank0, master_rank0.half(), rtol=.005), "Model-master mismatch" + assert torch.allclose(model_rank0, master_rank0, rtol=.005), "Model-master mismatch" print("OK: Model and master params match across ranks.") diff --git a/tests/distributed/run_rocm_distributed.sh b/tests/distributed/run_rocm_distributed.sh new file mode 100755 index 000000000..322137bbd --- /dev/null +++ b/tests/distributed/run_rocm_distributed.sh @@ -0,0 +1,49 @@ +#!/bin/bash -x +set -e + +# To run the test on 2 gpus +export WORLD_SIZE=2 + +torchrun=`dirname \`which python\``/torchrun + +# Test with opt_level="O2" +echo "running opt_level O2" +# python -m torch.distributed.launch --nproc_per_node=2 amp_master_params/amp_master_params.py --opt_level "O2" +python $torchrun --nproc_per_node=2 amp_master_params/amp_master_params.py --opt_level "O2" +python amp_master_params/compare.py + +# delete the model files +echo -e "O2 test completed. Deleting model files\n" +rm rank0model.pth +rm rank1model.pth +rm rank0master.pth +rm rank1master.pth + + +# Test with opt_level="O5" +#echo "running opt_level O5" +#python -m torch.distributed.launch --nproc_per_node=2 amp_master_params/amp_master_params.py --opt_level "O5" +#python amp_master_params/compare.py + +## delete the model files +#echo "O5 test completed. Deleting model files" +#rm rank0model.pth +#rm rank1model.pth +#rm rank0master.pth +#rm rank1master.pth + +## Run the Sync BN Tests. +echo "Running syncbn tests" +python -m torch.distributed.launch --nproc_per_node=2 synced_batchnorm/two_gpu_unit_test.py +python -m torch.distributed.launch --nproc_per_node=2 synced_batchnorm/two_gpu_unit_test.py --fp16 +python -m torch.distributed.launch --nproc_per_node=2 synced_batchnorm/two_gpu_test_different_batch_size.py --apex +echo "Running syncbn python only tests" +python synced_batchnorm/python_single_gpu_unit_test.py +echo "Running syncbn batchnorm1d tests" +python synced_batchnorm/test_batchnorm1d.py +#beware, you need a system with at least 4 gpus to test group_size&1 | tee ../../$LOG_FILE +cd ../../ + +cd apex/contrib/test +PYTHONUNBUFFERED=1 python run_rocm_extensions.py 2>&1 | tee ../../../$LOG_FILE2 +cd ../../../ + +torchrun --nproc_per_node 8 apex/contrib/peer_memory/peer_halo_exchange_module_tests.py 2>&1 | tee -a $LOG_FILE2 + +cd tests/distributed/synced_batchnorm +sh unit_test.sh 2>&1 | tee -a ../../../$LOG_FILE2 +cd ../../../ + +#explicitly load the builder and build the remaining extensions +python tests/jit_build/load_extra_extensions.py 2>&1 | tee $LOG_FILE + +FAILED_TESTS=$(python tests/jit_build/count_failed_unit_tests.py $LOG_FILE) +FAILED_TESTS2=$(python tests/jit_build/count_failed_unit_tests.py $LOG_FILE2) +BUILT_SO_COUNT=$(python tests/jit_build/count_built_so.py) +TORCH_EXTENSIONS_COUNT=$(python tests/jit_build/count_torch_extensions.py) + +echo "Failed L0 tests = $FAILED_TESTS" +echo "Failed contrib tests = $FAILED_TESTS2" +echo ".so count = $BUILT_SO_COUNT" +echo "JIT torch extensions count = $TORCH_EXTENSIONS_COUNT" + +echo "$FAILED_TESTS $FAILED_TESTS2 $BUILT_SO_COUNT $TORCH_EXTENSIONS_COUNT" \ No newline at end of file diff --git a/tests/jit_build/scripts/run.sh b/tests/jit_build/scripts/run.sh new file mode 100644 index 000000000..aeb41fadd --- /dev/null +++ b/tests/jit_build/scripts/run.sh @@ -0,0 +1,25 @@ +#parse the arguments +JIT_CONDITION="$2" + +echo $(pwd) + +WORKSPACE_DIR=/myworkspace +mkdir -p $WORKSPACE_DIR + +cd $WORKSPACE_DIR +git clone https://github.com/rocm/apex.git --recursive +cd apex +git checkout Refactor_build +git submodule update --init --recursive + +sh tests/jit_build/build.sh "condition" $JIT_CONDITION + +# Capture the output from run_tests.sh +TEST_RESULTS=$(sh tests/jit_build/run_tests.sh "condition" $JIT_CONDITION | tail -1) + +# Parse the returned values +read FAILED_TESTS FAILED_TESTS2 BUILT_SO_COUNT TORCH_EXTENSIONS_COUNT <<< "$TEST_RESULTS" + +MULTIPLE_RESULTS_FILE="../results_jit_unit_test.csv" +#echo "condition,failed unit tests" > "$MULTIPLE_RESULTS_FILE" +echo "$JIT_CONDITION,$FAILED_TESTS,$FAILED_TESTS2,$BUILT_SO_COUNT,$TORCH_EXTENSIONS_COUNT" >> "$MULTIPLE_RESULTS_FILE" \ No newline at end of file diff --git a/tests/test_extension_import.py b/tests/test_extension_import.py new file mode 100644 index 000000000..e5fc8ebfd --- /dev/null +++ b/tests/test_extension_import.py @@ -0,0 +1,248 @@ +import unittest +import os +import subprocess +import sys +import site +import ast +from apex.op_builder.all_ops import ALL_OPS + + +class TestExtensionImport(unittest.TestCase): + + def __init__(self, *args, **kwargs): + super(TestExtensionImport, self).__init__(*args, **kwargs) + + self.jit_info_file = "apex/git_version_info_installed.py" + + #find the absolute path of this file + current_file_path = os.path.abspath(__file__) + + #get the absolute path of the parent folder of this file + #tests folder + parent_folder_path = os.path.dirname(current_file_path) + #apex folder + parent_folder_path = os.path.dirname(parent_folder_path) + self.parent_folder_path = parent_folder_path + + def is_jit_modules_mode(self): + """ + This method checks if the file git_version_info_installed.py exists + """ + jit_file_path = os.path.join(site.getsitepackages()[0], self.jit_info_file) + #print ("jit_file_path", jit_file_path) + mode = os.path.exists(jit_file_path) + print ("jit_mode", mode) + return mode + + def get_extensions_list_from_setup(self): + """ + This method reads setup.py and gets the list of extensions from the setup.py file + """ + + #get setup.py file contents + setup_path = os.path.join(self.parent_folder_path, "setup.py") + + #read setup_path contents + with open(setup_path, 'r') as f: + setup_contents = f.readlines() + + #print ("length", len(setup_contents)) + #get the list of extensions from setup.py + extensions = [] + line_index = 0 + found = 0 + while line_index < len(setup_contents): + line = setup_contents[line_index] + if "CUDAExtension" in line: + found += 1 + if found == 1: + continue + #print ("extension", line, line_index) + + if "name"in line: + name_line = line.strip() + else: + #get the next line + line_index += 1 + name_line = setup_contents[line_index].strip() + + #extract the name part + if "name" in name_line: + if "'" in name_line: + name = name_line[name_line.find("name") + 6 : name_line.rfind("'")] + else: + name = name_line[name_line.find("name") + 6 : name_line.rfind('"')] + extensions.append(name) + + line_index += 1 + + return extensions + + + def get_jit_modules(self): + """ + This method reads the jit file and extracts installed_ops dictionary + """ + jit_info_path = os.path.join(site.getsitepackages()[0], self.jit_info_file) + with open(jit_info_path, 'r') as f: + lines = f.readlines() + for line in lines: + if "installed_ops" in line: + ops_list = line[len("installed_ops") + 1 : ] + ops_list = ast.literal_eval(ops_list) + #print ("op_list", ops_list) + return list(ops_list.keys()) + return {} + + def get_environment(self): + """ + This method retrieves the environment for testing import + otherwise get ImportError: libc10.so: cannot open shared object file: No such file or directory + """ + # Get current environment and ensure CUDA/PyTorch libraries are available + env = os.environ.copy() + + ld_library_path = env.get('LD_LIBRARY_PATH', '') + extra_paths = [] + + # Add PyTorch library path + try: + import torch + torch_lib_path = os.path.join(os.path.dirname(torch.__file__), 'lib') + if os.path.exists(torch_lib_path): + extra_paths.append(torch_lib_path) + except ImportError: + pass + + # Add ROCm library path if present + rocm_path = os.environ.get('ROCM_PATH', '/opt/rocm') + rocm_lib = os.path.join(rocm_path, 'lib') + if os.path.exists(rocm_lib): + extra_paths.append(rocm_lib) + + # Add common CUDA library paths (only those that exist) + for path in ['/usr/local/cuda/lib64', '/usr/local/cuda/lib', + '/opt/conda/lib', '/usr/lib/x86_64-linux-gnu']: + if os.path.isdir(path): + extra_paths.append(path) + + # Update LD_LIBRARY_PATH + if ld_library_path: + env['LD_LIBRARY_PATH'] = ':'.join(extra_paths) + ':' + ld_library_path + else: + env['LD_LIBRARY_PATH'] = ':'.join(extra_paths) + return env + + + def check_extension_import(self, extension_name, env): + """ + Check if an extension can be imported successfully using subprocess + Returns True if import successful, False if ImportError occurs + """ + try: + + # Run Python subprocess to test the import + result = subprocess.run([ + sys.executable, '-c', + 'import ' + extension_name + ], capture_output=True, text=True, timeout=30, env=env) + print ("result.stdout", result.stdout, result.stderr) + # Check if subprocess completed successfully + if result.returncode != 0 and "Error" in result.stderr: + return False, result.stderr + else: + return True, "" + + except subprocess.TimeoutExpired: + print(f"Import test timed out for {extension_name}") + return False, "Timeout" + except Exception as e: + print(f"Error testing import for {extension_name}: {e}") + return False, str(e) + + def check_jit_extension_import(self, extension_name, env): + all_ops = dict.fromkeys(ALL_OPS.keys(), False) + #get the builder for that extension + builder = ALL_OPS[extension_name] + builder_name = type(builder).__name__ + #print ("----builder_name-----", builder_name) + + #increase timeout + timeout = 60 * 60 + try: + # Run Python subprocess to test the import + result = subprocess.run([ + sys.executable, '-c', + 'from apex.op_builder import ' + builder_name + + '\n' + builder_name + "().load()" + ], capture_output=True, text=True, timeout=timeout, env=env) + print ("result.stdout", result.stdout, result.stderr) + # Check if subprocess completed successfully + if result.returncode != 0 and "Error" in result.stderr: + return False, result.stderr + else: + return True, "" + + except subprocess.TimeoutExpired: + print(f"Import test timed out for {extension_name}") + return False, "Timeout" + except Exception as e: + print(f"Error testing import for {extension_name}: {e}") + return False, str(e) + + + def test_extensions_import(self): + #check the extensions mode + jit_mode = self.is_jit_modules_mode() + + if not jit_mode: + #get the list of extensions from setup.py + extensions = self.get_extensions_list_from_setup() + else: + extensions = self.get_jit_modules() + + #get environment + env = self.get_environment() + + #import all the extensions + results = [] + for extension in extensions: + print ("checking extension", extension) + with self.subTest(extension=extension): + if not jit_mode: + success, error_message = self.check_extension_import(extension, env) + else: + success, error_message = self.check_jit_extension_import(extension, env) + #self.assertTrue(success, f"Failed to import extension: {extension}") + results.append((extension, success, error_message)) + + # Sort results by success status (True first, then False) + sorted_results = sorted(results, key=lambda x: (not x[1], x[0])) + + #save results to a extension_import_results.txt file + results_file_path = os.path.join(self.parent_folder_path, "extension_import_results.csv") + with open(results_file_path, 'w') as f: + f.write("Extension,Success,Error Message\n") + for extension, success, error_message in results: + f.write(f"{extension},{success},{error_message}\n") + + #print the results as a table + print("\nExtension Import Results:") + print("-" * 60) + print(f"{'Extension':<30} {'Success':<10} {'Error Message':<20}") + print("-" * 60) + for extension, success, error_message in sorted_results: + error_display = error_message[:17] + "..." if len(error_message) > 20 else error_message + print(f"{extension:<30} {success:<10} {error_display:<20}") + print("-" * 60) + + # Fail the test if any extensions failed to import + failed_extensions = [ext for ext, success, _ in results if not success] + self.assertEqual( + len(failed_extensions), 0, + f"{len(failed_extensions)} extension(s) failed to import: {', '.join(failed_extensions)}" + ) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/third_party/aiter b/third_party/aiter new file mode 160000 index 000000000..56824f8f2 --- /dev/null +++ b/third_party/aiter @@ -0,0 +1 @@ +Subproject commit 56824f8f221584862216bea0ac738c232f538e4c diff --git a/version.txt b/version.txt new file mode 100644 index 000000000..1cac385c6 --- /dev/null +++ b/version.txt @@ -0,0 +1 @@ +1.11.0