From c4a4ee5cf5ae06519945b0471125f5ac7ad98fd4 Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Thu, 1 Sep 2022 18:20:09 +0000 Subject: [PATCH 01/19] add build for --fast_layer_norm --- setup.py | 53 ++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 36 insertions(+), 17 deletions(-) diff --git a/setup.py b/setup.py index 03ae2b7bc..591145d98 100644 --- a/setup.py +++ b/setup.py @@ -362,23 +362,42 @@ def check_if_rocm_pytorch(): if "--fast_layer_norm" in sys.argv: sys.argv.remove("--fast_layer_norm") - raise_if_cuda_home_none("--fast_layer_norm") - # Check, if CUDA11 is installed for compute capability 8.0 - cc_flag = [] - _, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME) - if int(bare_metal_major) >= 11: - cc_flag.append("-gencode") - cc_flag.append("arch=compute_80,code=sm_80") - - if CUDA_HOME is None and not IS_ROCM_PYTORCH: - raise RuntimeError("--fast_layer_norm was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") - else: - # Check, if CUDA11 is installed for compute capability 8.0 - cc_flag = [] - _, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME) - if int(bare_metal_major) >= 11: - cc_flag.append('-gencode') - cc_flag.append('arch=compute_80,code=sm_80') + #raise_if_cuda_home_none("--fast_layer_norm") + ## Check, if CUDA11 is installed for compute capability 8.0 + #cc_flag = [] + #_, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME) + #if int(bare_metal_major) >= 11: + # cc_flag.append("-gencode") + # cc_flag.append("arch=compute_80,code=sm_80") + + #if CUDA_HOME is None and not IS_ROCM_PYTORCH: + # raise RuntimeError("--fast_layer_norm was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") + #else: + # # Check, if CUDA11 is installed for compute capability 8.0 + # cc_flag = [] + # _, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME) + # if int(bare_metal_major) >= 11: + # cc_flag.append('-gencode') + # cc_flag.append('arch=compute_80,code=sm_80') + + nvcc_args_fast_layer_norm = ['-maxrregcount=50', '-O3', '--use_fast_math'] + version_dependent_macros + hipcc_args_fast_layer_norm = ['-O3'] + version_dependent_macros + print ("INFO: Building fast layernorm extension.") + ext_modules.append( + CUDAExtension(name='fast_layer_norm_cuda', + sources=[#'apex/contrib/csrc/layer_norm/ln.h', + 'apex/contrib/csrc/layer_norm/ln_api.cpp', + #'apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh', + 'apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu', + 'apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu', + #'apex/contrib/csrc/layer_norm/ln_fwd_kernels.cuh', + #'apex/contrib/csrc/layer_norm/ln_kernel_traits.h', + #'apex/contrib/csrc/layer_norm/ln_utils.cuh' + ], + include_dirs=[os.path.join(this_dir, 'csrc'), + os.path.join(this_dir, 'apex/contrib/csrc/layer_norm')], + extra_compile_args={'cxx': ['-O3'] + version_dependent_macros, + 'nvcc': nvcc_args_fast_layer_norm if not IS_ROCM_PYTORCH else hipcc_args_fast_layer_norm})) if "--fmha" in sys.argv: sys.argv.remove("--fmha") From c1e3a7218498138647d32204eb5b12b31a9d1abf Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Tue, 6 Sep 2022 15:29:30 +0000 Subject: [PATCH 02/19] use HIP bfloat16 header --- apex/contrib/csrc/layer_norm/ln.h | 8 ++++++++ apex/contrib/csrc/layer_norm/ln_utils.cuh | 4 ++++ 2 files changed, 12 insertions(+) diff --git a/apex/contrib/csrc/layer_norm/ln.h b/apex/contrib/csrc/layer_norm/ln.h index 07392a192..1b1c81a61 100644 --- a/apex/contrib/csrc/layer_norm/ln.h +++ b/apex/contrib/csrc/layer_norm/ln.h @@ -2,7 +2,11 @@ #include #include +#ifdef USE_ROCM +#include +#else #include +#endif namespace layer_norm { @@ -121,7 +125,11 @@ extern BwdRegistry BWD_FUNCS; using fp32 = float; using fp16 = half; +#ifdef USE_ROCM +using bf16 = hip_bfloat16; +#else using bf16 = nv_bfloat16; +#endif //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/apex/contrib/csrc/layer_norm/ln_utils.cuh b/apex/contrib/csrc/layer_norm/ln_utils.cuh index e18d36de7..0183716d8 100644 --- a/apex/contrib/csrc/layer_norm/ln_utils.cuh +++ b/apex/contrib/csrc/layer_norm/ln_utils.cuh @@ -2,7 +2,11 @@ #include +#ifdef USE_ROCM +#include +#else #include +#endif #include #include "ln.h" From 09d7be8c192d4d13e6ec90ee72af70f0f0575a96 Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Tue, 6 Sep 2022 15:34:13 +0000 Subject: [PATCH 03/19] missing header --- apex/contrib/csrc/layer_norm/ln.h | 1 + 1 file changed, 1 insertion(+) diff --git a/apex/contrib/csrc/layer_norm/ln.h b/apex/contrib/csrc/layer_norm/ln.h index 1b1c81a61..415a9acd4 100644 --- a/apex/contrib/csrc/layer_norm/ln.h +++ b/apex/contrib/csrc/layer_norm/ln.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include #ifdef USE_ROCM From fc1345940c954a6d653ce850a325bcf7df4a116c Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Tue, 6 Sep 2022 21:21:13 +0000 Subject: [PATCH 04/19] warp size considerations, TODOs --- apex/contrib/csrc/layer_norm/hip_bfloat162.h | 287 ++++++++++++++++++ .../layer_norm/ln_bwd_semi_cuda_kernel.cu | 10 +- apex/contrib/csrc/layer_norm/ln_utils.cuh | 59 ++++ 3 files changed, 355 insertions(+), 1 deletion(-) create mode 100644 apex/contrib/csrc/layer_norm/hip_bfloat162.h diff --git a/apex/contrib/csrc/layer_norm/hip_bfloat162.h b/apex/contrib/csrc/layer_norm/hip_bfloat162.h new file mode 100644 index 000000000..2d1d79f8b --- /dev/null +++ b/apex/contrib/csrc/layer_norm/hip_bfloat162.h @@ -0,0 +1,287 @@ +/** + * MIT License + * + * Copyright (c) 2019 - 2021 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +/*!\file + * \brief hip_bfloat162.h provides struct for hip_bfloat162 typedef + */ + +#ifndef _hip_bfloat162_H_ +#define _hip_bfloat162_H_ + +#if __cplusplus < 201103L || !defined(__HIPCC__) + +// If this is a C compiler, C++ compiler below C++11, or a host-only compiler, we only +// include a minimal definition of hip_bfloat162 + +#include +/*! \brief Struct to represent a 16 bit brain floating point number. */ +typedef struct +{ + uint16_t data; + uint16_t data2; +} hip_bfloat162; + +#else // __cplusplus < 201103L || !defined(__HIPCC__) + +#include +#include +#include +#include +#include +#include +#include + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wshadow" +struct hip_bfloat162 +{ + uint16_t data; + uint16_t data2; + + enum truncate_t + { + truncate + }; + + __host__ __device__ hip_bfloat162() = default; + + // round upper 16 bits of IEEE float to convert to bfloat16 + explicit __host__ __device__ hip_bfloat162(float f) + : data(float_to_bfloat16(f)) + { + } + + explicit __host__ __device__ hip_bfloat162(float f, truncate_t) + : data(truncate_float_to_bfloat16(f)) + { + } + + // zero extend lower 16 bits of bfloat16 to convert to IEEE float + __host__ __device__ operator float() const + { + union + { + uint32_t int32; + float fp32; + } u = {uint32_t(data) << 16}; + return u.fp32; + } + + static __host__ __device__ hip_bfloat162 round_to_bfloat16(float f) + { + hip_bfloat162 output; + output.data = float_to_bfloat16(f); + return output; + } + + static __host__ __device__ hip_bfloat162 round_to_bfloat16(float f, truncate_t) + { + hip_bfloat162 output; + output.data = truncate_float_to_bfloat16(f); + return output; + } + +private: + static __host__ __device__ uint16_t float_to_bfloat16(float f) + { + union + { + float fp32; + uint32_t int32; + } u = {f}; + if(~u.int32 & 0x7f800000) + { + // When the exponent bits are not all 1s, then the value is zero, normal, + // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus + // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd). + // This causes the bfloat16's mantissa to be incremented by 1 if the 16 + // least significant bits of the float mantissa are greater than 0x8000, + // or if they are equal to 0x8000 and the least significant bit of the + // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when + // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already + // has the value 0x7f, then incrementing it causes it to become 0x00 and + // the exponent is incremented by one, which is the next higher FP value + // to the unrounded bfloat16 value. When the bfloat16 value is subnormal + // with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up + // to a normal value with an exponent of 0x01 and a mantissa of 0x00. + // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F, + // incrementing it causes it to become an exponent of 0xFF and a mantissa + // of 0x00, which is Inf, the next higher value to the unrounded value. + u.int32 += 0x7fff + ((u.int32 >> 16) & 1); // Round to nearest, round to even + } + else if(u.int32 & 0xffff) + { + // When all of the exponent bits are 1, the value is Inf or NaN. + // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero + // mantissa bit. Quiet NaN is indicated by the most significant mantissa + // bit being 1. Signaling NaN is indicated by the most significant + // mantissa bit being 0 but some other bit(s) being 1. If any of the + // lower 16 bits of the mantissa are 1, we set the least significant bit + // of the bfloat16 mantissa, in order to preserve signaling NaN in case + // the bloat16's mantissa bits are all 0. + u.int32 |= 0x10000; // Preserve signaling NaN + } + return uint16_t(u.int32 >> 16); + } + + // Truncate instead of rounding, preserving SNaN + static __host__ __device__ uint16_t truncate_float_to_bfloat16(float f) + { + union + { + float fp32; + uint32_t int32; + } u = {f}; + return uint16_t(u.int32 >> 16) | (!(~u.int32 & 0x7f800000) && (u.int32 & 0xffff)); + } +}; +#pragma clang diagnostic pop + +typedef struct +{ + uint16_t data; + uint16_t data2; +} hip_bfloat162_public; + +static_assert(std::is_standard_layout{}, + "hip_bfloat162 is not a standard layout type, and thus is " + "incompatible with C."); + +static_assert(std::is_trivial{}, + "hip_bfloat162 is not a trivial type, and thus is " + "incompatible with C."); + +static_assert(sizeof(hip_bfloat162) == sizeof(hip_bfloat162_public) + && offsetof(hip_bfloat162, data) == offsetof(hip_bfloat162_public, data), + "internal hip_bfloat162 does not match public hip_bfloat162"); + +inline std::ostream& operator<<(std::ostream& os, const hip_bfloat162& bf16) +{ + return os << float(bf16); +} +inline __host__ __device__ hip_bfloat162 operator+(hip_bfloat162 a) +{ + return a; +} +inline __host__ __device__ hip_bfloat162 operator-(hip_bfloat162 a) +{ + a.data ^= 0x8000; + return a; +} +inline __host__ __device__ hip_bfloat162 operator+(hip_bfloat162 a, hip_bfloat162 b) +{ + return hip_bfloat162(float(a) + float(b)); +} +inline __host__ __device__ hip_bfloat162 operator-(hip_bfloat162 a, hip_bfloat162 b) +{ + return hip_bfloat162(float(a) - float(b)); +} +inline __host__ __device__ hip_bfloat162 operator*(hip_bfloat162 a, hip_bfloat162 b) +{ + return hip_bfloat162(float(a) * float(b)); +} +inline __host__ __device__ hip_bfloat162 operator/(hip_bfloat162 a, hip_bfloat162 b) +{ + return hip_bfloat162(float(a) / float(b)); +} +inline __host__ __device__ bool operator<(hip_bfloat162 a, hip_bfloat162 b) +{ + return float(a) < float(b); +} +inline __host__ __device__ bool operator==(hip_bfloat162 a, hip_bfloat162 b) +{ + return float(a) == float(b); +} +inline __host__ __device__ bool operator>(hip_bfloat162 a, hip_bfloat162 b) +{ + return b < a; +} +inline __host__ __device__ bool operator<=(hip_bfloat162 a, hip_bfloat162 b) +{ + return !(a > b); +} +inline __host__ __device__ bool operator!=(hip_bfloat162 a, hip_bfloat162 b) +{ + return !(a == b); +} +inline __host__ __device__ bool operator>=(hip_bfloat162 a, hip_bfloat162 b) +{ + return !(a < b); +} +inline __host__ __device__ hip_bfloat162& operator+=(hip_bfloat162& a, hip_bfloat162 b) +{ + return a = a + b; +} +inline __host__ __device__ hip_bfloat162& operator-=(hip_bfloat162& a, hip_bfloat162 b) +{ + return a = a - b; +} +inline __host__ __device__ hip_bfloat162& operator*=(hip_bfloat162& a, hip_bfloat162 b) +{ + return a = a * b; +} +inline __host__ __device__ hip_bfloat162& operator/=(hip_bfloat162& a, hip_bfloat162 b) +{ + return a = a / b; +} +inline __host__ __device__ hip_bfloat162& operator++(hip_bfloat162& a) +{ + return a += hip_bfloat162(1.0f); +} +inline __host__ __device__ hip_bfloat162& operator--(hip_bfloat162& a) +{ + return a -= hip_bfloat162(1.0f); +} +inline __host__ __device__ hip_bfloat162 operator++(hip_bfloat162& a, int) +{ + hip_bfloat162 orig = a; + ++a; + return orig; +} +inline __host__ __device__ hip_bfloat162 operator--(hip_bfloat162& a, int) +{ + hip_bfloat162 orig = a; + --a; + return orig; +} + +namespace std +{ + constexpr __host__ __device__ bool isinf(hip_bfloat162 a) + { + return !(~a.data & 0x7f80) && !(a.data & 0x7f); + } + constexpr __host__ __device__ bool isnan(hip_bfloat162 a) + { + return !(~a.data & 0x7f80) && +(a.data & 0x7f); + } + constexpr __host__ __device__ bool iszero(hip_bfloat162 a) + { + return !(a.data & 0x7fff); + } +} + +#endif // __cplusplus < 201103L || !defined(__HIPCC__) + +#endif // _hip_bfloat162_H_ diff --git a/apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu b/apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu index 3893d4e0c..b9d5b5422 100644 --- a/apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu +++ b/apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu @@ -35,7 +35,7 @@ void launch_(LaunchParams &launch_params, const bool configure_params if( configure_params ) { int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + cudaError_t status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES); launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW; launch_params.barrier_size = 0; @@ -52,7 +52,11 @@ void launch_(LaunchParams &launch_params, const bool configure_params } if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) { + // hipify missing cudaFuncSetAttribute, cudaFuncAttributeMaxDynamicSharedMemorySize +#ifdef USE_ROCM +#else CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES)); +#endif } auto stream = launch_params.stream; auto ctas_per_col = launch_params.params.ctas_per_col; @@ -60,10 +64,14 @@ void launch_(LaunchParams &launch_params, const bool configure_params if( Kernel_traits::CTAS_PER_ROW == 1 ) { kernel<<>>(launch_params.params); } else { +#ifdef USE_ROCM + assert(0 && "hipLaunchCooperativeKernel TODO"); +#else dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col); dim3 block(Kernel_traits::THREADS_PER_CTA); void *params_ = (void *)&launch_params.params; cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, Kernel_traits::SMEM_BYTES, stream); +#endif } using Kernel_traits_f = layer_norm::Kernel_traits_finalize +#include "hip_bfloat162.h" #else #include #endif @@ -13,7 +14,11 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// +#ifdef USE_ROCM +constexpr uint32_t THREADS_PER_WARP = warpSize; +#else constexpr uint32_t THREADS_PER_WARP = 32; +#endif //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -92,7 +97,11 @@ struct Sum { template inline __device__ T warp_shuffle_xor(const T & x, uint32_t idx){ +#ifdef USE_ROCM + return __shfl_xor(x, idx); +#else return __shfl_xor_sync(uint32_t(-1), x, idx); +#endif } template<> @@ -102,7 +111,11 @@ inline __device__ float2 warp_shuffle_xor(const float2 & x, uint32_t idx template inline __device__ T warp_shuffle_down(const T & x, uint32_t idx){ +#ifdef USE_ROCM + return __shfl_down(x, idx); +#else return __shfl_down_sync(uint32_t(-1), x, idx); +#endif } template<> @@ -192,10 +205,17 @@ struct TypeToVec2 { using Type = half2; }; +#ifdef USE_ROCM +template<> +struct TypeToVec2 { + using Type = hip_bfloat162; +}; +#else template<> struct TypeToVec2 { using Type = nv_bfloat162; }; +#endif //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -245,6 +265,15 @@ struct Converter{ } }; +#ifdef USE_ROCM +template<> +struct Converter{ + static inline __device__ hip_bfloat162 convert(const float2 &x) { + hip_bfloat162 raw; + return raw; + } +}; +#else template<> struct Converter{ static inline __device__ nv_bfloat162 convert(const float2 &x) { @@ -262,6 +291,7 @@ struct Converter{ #endif } }; +#endif //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -335,10 +365,19 @@ struct InterCTASync { } inline __device__ void spin_wait_(int *barrier, int step, int expected) { +#ifdef USE_ROCM + atomicAdd(barrier, step); + for( int found = -1; found != expected; ) { + found = *barrier; + } +#else + // reduction operation on global and shared memory + // reduction with operand b and value in a, store result at location a, overwriting original value asm volatile("red.release.gpu.global.add.s32 [%0], %1;" ::"l"(barrier), "r"(step)); for( int found = -1; found != expected; ) { asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(found) : "l"(barrier)); } +#endif } inline __device__ void sync(){ @@ -403,7 +442,11 @@ struct Reducer : public Reducer { workspace[bidn_] = data; } inter_cta_.sync(); +#ifdef USE_ROCM + static_assert(CTAS_PER_ROW <= warpSize); +#else static_assert(CTAS_PER_ROW <= 32); +#endif T total = Zeros::get(); if(this->lane_ < CTAS_PER_ROW){ total = workspace[this->lane_]; @@ -429,7 +472,11 @@ struct Reducer { enum { SMEM_BYTES = 0 }; enum { WORKSPACE_BYTES_PER_GROUP = 0 }; +#ifdef USE_ROCM + enum { THREADS_PER_WARP = 64 }; +#else enum { THREADS_PER_WARP = 32 }; +#endif template inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) @@ -477,7 +524,11 @@ struct Reducer : public Reducer { enum { SMEM_BYTES = Base::SMEM_BYTES + WARPS_M * WARPS_N * sizeof(T) * 2 }; enum { WORKSPACE_BYTES_PER_GROUP = 0 }; +#ifdef USE_ROCM + enum { THREADS_PER_WARP = 64 }; +#else enum { THREADS_PER_WARP = 32 }; +#endif template inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) @@ -606,7 +657,11 @@ struct Stats { T m2 = Zeros::get(); // Assume CTA group size in N less than 32, such that we can finalize with a single warp. +#ifdef USE_ROCM + static_assert(CTAS_PER_ROW <= warpSize); +#else static_assert(CTAS_PER_ROW <= 32); +#endif // Every warp does the final reduction locally. if( lane_ < CTAS_PER_ROW ) { @@ -671,7 +726,11 @@ struct Stats { T m2 = Zeros::get(); // Assume that there are less than 32 warps, such that we can finalize with a single warp +#ifdef USE_ROCM + static_assert(WARPS_N <= warpSize); +#else static_assert(WARPS_N <= 32); +#endif if(lane < WARPS_N){ stats_t result = smem[lane]; n = N * THREADS_PER_WARP; From 1228e3c6394d424c6876c6dac0a42e5e194a61b3 Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Tue, 6 Sep 2022 21:57:13 +0000 Subject: [PATCH 05/19] work around some compiler errors --- apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh | 3 ++- .../csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu | 3 ++- apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu | 11 ++++++++++- setup.py | 2 +- 4 files changed, 15 insertions(+), 4 deletions(-) diff --git a/apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh b/apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh index 8595f5ed4..00a5fa153 100644 --- a/apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh +++ b/apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh @@ -18,6 +18,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) { enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW }; using compute_t = typename Ktraits::compute_t; + using input_t = typename Ktraits::input_t; using index_t = typename Ktraits::index_t; using Ivec = typename Ktraits::Ivec; using Ovec = typename Ktraits::Ovec; @@ -119,7 +120,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) { compute_t dy_tmp = dy[it * NUM_ELTS + jt]; compute_t y_tmp = y[it * NUM_ELTS + jt]; compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + mdy_local)); - dx[it].data.elt[jt] = dx_tmp; + dx[it].data.elt[jt] = input_t(dx_tmp); } dx[it].store_to(params.dx, idx); idx += Ktraits::VEC_COLS_PER_LDG; diff --git a/apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu b/apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu index b9d5b5422..8849907f3 100644 --- a/apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu +++ b/apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu @@ -54,8 +54,9 @@ void launch_(LaunchParams &launch_params, const bool configure_params if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) { // hipify missing cudaFuncSetAttribute, cudaFuncAttributeMaxDynamicSharedMemorySize #ifdef USE_ROCM + CHECK_CUDA(hipFuncSetAttribute((const void *)kernel, hipFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES)); #else - CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES)); + CHECK_CUDA(cudaFuncSetAttribute((const void *)kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES)); #endif } auto stream = launch_params.stream; diff --git a/apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu b/apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu index 62ff78ee3..bc45b6da8 100644 --- a/apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu +++ b/apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu @@ -51,7 +51,12 @@ void launch_(LaunchParams &launch_params, const bool configure_params } if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) { - CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD)); + // hipify missing cudaFuncSetAttribute, cudaFuncAttributeMaxDynamicSharedMemorySize +#ifdef USE_ROCM + CHECK_CUDA(hipFuncSetAttribute((const void *)kernel, hipFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES)); +#else + CHECK_CUDA(cudaFuncSetAttribute((const void *)kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES)); +#endif } auto stream = launch_params.stream; auto ctas_per_col = launch_params.params.ctas_per_col; @@ -59,10 +64,14 @@ void launch_(LaunchParams &launch_params, const bool configure_params if( Kernel_traits::CTAS_PER_ROW == 1 ) { kernel<<>>(launch_params.params); } else { +#ifdef USE_ROCM + assert(0 && "hipLaunchCooperativeKernel TODO"); +#else dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col); dim3 block(Kernel_traits::THREADS_PER_CTA); void *params_ = (void *)&launch_params.params; cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, Kernel_traits::SMEM_BYTES_FWD, stream); +#endif } } diff --git a/setup.py b/setup.py index 591145d98..f9be45d37 100644 --- a/setup.py +++ b/setup.py @@ -381,7 +381,7 @@ def check_if_rocm_pytorch(): # cc_flag.append('arch=compute_80,code=sm_80') nvcc_args_fast_layer_norm = ['-maxrregcount=50', '-O3', '--use_fast_math'] + version_dependent_macros - hipcc_args_fast_layer_norm = ['-O3'] + version_dependent_macros + hipcc_args_fast_layer_norm = ['-O3', '-U__HIP_NO_HALF_OPERATORS__', '-U__HIP_NO_HALF_CONVERSIONS__'] + version_dependent_macros print ("INFO: Building fast layernorm extension.") ext_modules.append( CUDAExtension(name='fast_layer_norm_cuda', From f2399d697f8fa6d41b31f294433941d4f4d00f2e Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Wed, 7 Sep 2022 01:06:10 +0000 Subject: [PATCH 06/19] it finally compiles --- .../layer_norm/ln_bwd_semi_cuda_kernel.cu | 124 +++++++++--------- .../csrc/layer_norm/ln_fwd_cuda_kernel.cu | 52 ++++---- .../csrc/layer_norm/ln_kernel_traits.h | 12 +- apex/contrib/csrc/layer_norm/ln_utils.cuh | 5 + 4 files changed, 107 insertions(+), 86 deletions(-) diff --git a/apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu b/apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu index 8849907f3..75363cbd4 100644 --- a/apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu +++ b/apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu @@ -88,13 +88,19 @@ void launch_(LaunchParams &launch_params, const bool configure_params kernel_f<<>>(launch_params.params); } +#ifdef USE_ROCM +constexpr bool is_rocm = true; +#else +constexpr bool is_rocm = false; +#endif + // Create backward launch function and register. Macro signature: // HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL REGISTER_BWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 768, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 768, fp16, fp16, fp16, fp32, 1, 4, 1, is_rocm ? 8 : 16, 4); REGISTER_BWD_LAUNCHER( 768, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 768, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 768, bf16, bf16, bf16, fp32, 1, 4, 1, is_rocm ? 8 : 16, 4); REGISTER_BWD_LAUNCHER( 768, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); REGISTER_BWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); @@ -103,11 +109,11 @@ REGISTER_BWD_LAUNCHER( 1024, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); REGISTER_BWD_LAUNCHER( 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); REGISTER_BWD_LAUNCHER( 1024, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER( 1536, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER( 1536, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, 1, 1, 4, is_rocm ? 8 : 16, 4); +REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp16, fp32, 1, 1, 4, is_rocm ? 4 : 8, 4); +REGISTER_BWD_LAUNCHER( 1536, fp16, fp32, fp16, fp32, 1, 1, 4, is_rocm ? 8 : 16, 4); +REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, bf16, fp32, 1, 1, 4, is_rocm ? 4 : 8, 4); +REGISTER_BWD_LAUNCHER( 1536, bf16, fp32, bf16, fp32, 1, 1, 4, is_rocm ? 8 : 16, 4); REGISTER_BWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER( 2048, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); @@ -115,23 +121,23 @@ REGISTER_BWD_LAUNCHER( 2048, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER( 2048, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER( 2048, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 2304, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER( 2304, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4); -REGISTER_BWD_LAUNCHER( 2304, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER( 2304, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4); -REGISTER_BWD_LAUNCHER( 2304, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); +REGISTER_BWD_LAUNCHER( 2304, fp32, fp32, fp32, fp32, 1, 1, 4, is_rocm ? 4 : 8, 4); +REGISTER_BWD_LAUNCHER( 2304, fp16, fp16, fp16, fp32, 1, 1, 4, is_rocm ? 2 : 4, 4); +REGISTER_BWD_LAUNCHER( 2304, fp16, fp32, fp16, fp32, 1, 1, 4, is_rocm ? 4 : 8, 4); +REGISTER_BWD_LAUNCHER( 2304, bf16, bf16, bf16, fp32, 1, 1, 4, is_rocm ? 2 : 4, 4); +REGISTER_BWD_LAUNCHER( 2304, bf16, fp32, bf16, fp32, 1, 1, 4, is_rocm ? 4 : 8, 4); REGISTER_BWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 3072, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 3072, fp16, fp16, fp16, fp32, 1, 1, 4, is_rocm ? 8 : 16, 4); REGISTER_BWD_LAUNCHER( 3072, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 3072, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 3072, bf16, bf16, bf16, fp32, 1, 1, 4, is_rocm ? 8 : 16, 4); REGISTER_BWD_LAUNCHER( 3072, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 3840, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER( 3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4); -REGISTER_BWD_LAUNCHER( 3840, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER( 3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4); -REGISTER_BWD_LAUNCHER( 3840, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); +REGISTER_BWD_LAUNCHER( 3840, fp32, fp32, fp32, fp32, 1, 1, 4, is_rocm ? 4 : 8, 4); +REGISTER_BWD_LAUNCHER( 3840, fp16, fp16, fp16, fp32, 1, 1, 4, is_rocm ? 2 : 4, 4); +REGISTER_BWD_LAUNCHER( 3840, fp16, fp32, fp16, fp32, 1, 1, 4, is_rocm ? 4 : 8, 4); +REGISTER_BWD_LAUNCHER( 3840, bf16, bf16, bf16, fp32, 1, 1, 4, is_rocm ? 2 : 4, 4); +REGISTER_BWD_LAUNCHER( 3840, bf16, fp32, bf16, fp32, 1, 1, 4, is_rocm ? 4 : 8, 4); REGISTER_BWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER( 4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); @@ -140,15 +146,15 @@ REGISTER_BWD_LAUNCHER( 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER( 4096, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 5120, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 5120, fp16, fp16, fp16, fp32, 1, 1, 4, is_rocm ? 8 : 16, 4); REGISTER_BWD_LAUNCHER( 5120, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 5120, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 5120, bf16, bf16, bf16, fp32, 1, 1, 4, is_rocm ? 8 : 16, 4); REGISTER_BWD_LAUNCHER( 5120, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER( 6144, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); +REGISTER_BWD_LAUNCHER( 6144, fp16, fp16, fp16, fp32, 1, 1, 8, is_rocm ? 8 : 16, 4); REGISTER_BWD_LAUNCHER( 6144, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER( 6144, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); +REGISTER_BWD_LAUNCHER( 6144, bf16, bf16, bf16, fp32, 1, 1, 8, is_rocm ? 8 : 16, 4); REGISTER_BWD_LAUNCHER( 6144, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4); @@ -158,34 +164,34 @@ REGISTER_BWD_LAUNCHER( 8192, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER( 8192, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(10240, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(10240, fp16, fp16, fp16, fp32, 2, 1, 4, is_rocm ? 8 : 16, 4); REGISTER_BWD_LAUNCHER(10240, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(10240, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(10240, bf16, bf16, bf16, fp32, 2, 1, 4, is_rocm ? 8 : 16, 4); REGISTER_BWD_LAUNCHER(10240, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(12288, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(12288, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(12288, fp16, fp16, fp16, fp32, 4, 1, 4, is_rocm ? 8 : 16, 4); REGISTER_BWD_LAUNCHER(12288, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 4, 1, 4, is_rocm ? 8 : 16, 4); REGISTER_BWD_LAUNCHER(12288, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(12800, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 5, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 5, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(12800, fp32, fp32, fp32, fp32, 5, 1, 4, is_rocm ? 8 : 16, 4); +REGISTER_BWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 5, 1, 4, is_rocm ? 4 : 8, 4); +REGISTER_BWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 5, 1, 4, is_rocm ? 8 : 16, 4); +REGISTER_BWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 5, 1, 4, is_rocm ? 4 : 8, 4); +REGISTER_BWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 5, 1, 4, is_rocm ? 8 : 16, 4); REGISTER_BWD_LAUNCHER(14336, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER(14336, fp16, fp16, fp16, fp32, 4, 1, 4, 8, 4); +REGISTER_BWD_LAUNCHER(14336, fp16, fp16, fp16, fp32, 4, 1, 4, is_rocm ? 4 : 8, 4); REGISTER_BWD_LAUNCHER(14336, fp16, fp32, fp16, fp32, 4, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER(14336, bf16, bf16, bf16, fp32, 4, 1, 4, 8, 4); +REGISTER_BWD_LAUNCHER(14336, bf16, bf16, bf16, fp32, 4, 1, 4, is_rocm ? 4 : 8, 4); REGISTER_BWD_LAUNCHER(14336, bf16, fp32, bf16, fp32, 4, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 4, 1, 4, 4, 4); -REGISTER_BWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 4, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 4, 1, 4, 4, 4); -REGISTER_BWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 4, 1, 4, 8, 4); +REGISTER_BWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 4, 1, 4, is_rocm ? 4 : 8, 4); +REGISTER_BWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 4, 1, 4, is_rocm ? 2 : 4, 4); +REGISTER_BWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 4, 1, 4, is_rocm ? 4 : 8, 4); +REGISTER_BWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 4, 1, 4, is_rocm ? 2 : 4, 4); +REGISTER_BWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 4, 1, 4, is_rocm ? 4 : 8, 4); REGISTER_BWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4); @@ -193,35 +199,35 @@ REGISTER_BWD_LAUNCHER(16384, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(18432, fp16, fp16, fp16, fp32, 4, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER(18432, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(18432, bf16, bf16, bf16, fp32, 4, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, is_rocm ? 8 : 16, 4); +REGISTER_BWD_LAUNCHER(18432, fp16, fp16, fp16, fp32, 4, 1, 4, is_rocm ? 4 : 8, 4); +REGISTER_BWD_LAUNCHER(18432, fp16, fp32, fp16, fp32, 4, 1, 4, is_rocm ? 8 : 16, 4); +REGISTER_BWD_LAUNCHER(18432, bf16, bf16, bf16, fp32, 4, 1, 4, is_rocm ? 4 : 8, 4); +REGISTER_BWD_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, is_rocm ? 8 : 16, 4); REGISTER_BWD_LAUNCHER(20480, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(20480, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(20480, fp16, fp16, fp16, fp32, 4, 1, 4, is_rocm ? 8 : 16, 4); REGISTER_BWD_LAUNCHER(20480, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(20480, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(20480, bf16, bf16, bf16, fp32, 4, 1, 4, is_rocm ? 8 : 16, 4); REGISTER_BWD_LAUNCHER(20480, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(24576, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(24576, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4); +REGISTER_BWD_LAUNCHER(24576, fp16, fp16, fp16, fp32, 4, 1, 8, is_rocm ? 8 : 16, 4); REGISTER_BWD_LAUNCHER(24576, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(24576, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4); +REGISTER_BWD_LAUNCHER(24576, bf16, bf16, bf16, fp32, 4, 1, 8, is_rocm ? 8 : 16, 4); REGISTER_BWD_LAUNCHER(24576, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(25600, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(25600, fp16, fp16, fp16, fp32, 5, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(25600, fp16, fp16, fp16, fp32, 5, 1, 4, is_rocm ? 8 : 16, 4); REGISTER_BWD_LAUNCHER(25600, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(25600, bf16, bf16, bf16, fp32, 5, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(25600, bf16, bf16, bf16, fp32, 5, 1, 4, is_rocm ? 8 : 16, 4); REGISTER_BWD_LAUNCHER(25600, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 8, 8, 4); -REGISTER_BWD_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 8, 4, 4); -REGISTER_BWD_LAUNCHER(30720, fp16, fp32, fp16, fp32, 4, 1, 8, 8, 4); -REGISTER_BWD_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 8, 4, 4); -REGISTER_BWD_LAUNCHER(30720, bf16, fp32, bf16, fp32, 4, 1, 8, 8, 4); +REGISTER_BWD_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 8, is_rocm ? 4 : 8, 4); +REGISTER_BWD_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 8, is_rocm ? 2 : 4, 4); +REGISTER_BWD_LAUNCHER(30720, fp16, fp32, fp16, fp32, 4, 1, 8, is_rocm ? 4 : 8, 4); +REGISTER_BWD_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 8, is_rocm ? 2 : 4, 4); +REGISTER_BWD_LAUNCHER(30720, bf16, fp32, bf16, fp32, 4, 1, 8, is_rocm ? 4 : 8, 4); REGISTER_BWD_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4); @@ -230,15 +236,15 @@ REGISTER_BWD_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(32768, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(40960, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(40960, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4); +REGISTER_BWD_LAUNCHER(40960, fp16, fp16, fp16, fp32, 4, 1, 8, is_rocm ? 8 : 16, 4); REGISTER_BWD_LAUNCHER(40960, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4); +REGISTER_BWD_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 8, is_rocm ? 8 : 16, 4); REGISTER_BWD_LAUNCHER(40960, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(49152, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(49152, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4); +REGISTER_BWD_LAUNCHER(49152, fp16, fp16, fp16, fp32, 8, 1, 8, is_rocm ? 8 : 16, 4); REGISTER_BWD_LAUNCHER(49152, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(49152, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4); +REGISTER_BWD_LAUNCHER(49152, bf16, bf16, bf16, fp32, 8, 1, 8, is_rocm ? 8 : 16, 4); REGISTER_BWD_LAUNCHER(49152, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(65536, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4); diff --git a/apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu b/apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu index bc45b6da8..bc7f5797b 100644 --- a/apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu +++ b/apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu @@ -34,7 +34,7 @@ void launch_(LaunchParams &launch_params, const bool configure_params if( configure_params ) { int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + cudaError_t status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD); launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW; launch_params.barrier_size = 0; @@ -76,11 +76,17 @@ void launch_(LaunchParams &launch_params, const bool configure_params } +#ifdef USE_ROCM +constexpr bool is_rocm = true; +#else +constexpr bool is_rocm = false; +#endif + REGISTER_FWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 768, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 768, fp16, fp16, fp16, fp32, 1, 4, 1, is_rocm ? 8 : 16); REGISTER_FWD_LAUNCHER( 768, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 768, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 768, bf16, bf16, bf16, fp32, 1, 4, 1, is_rocm ? 8 : 16); REGISTER_FWD_LAUNCHER( 768, bf16, fp32, bf16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16); @@ -102,21 +108,21 @@ REGISTER_FWD_LAUNCHER( 2048, bf16, bf16, bf16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 2048, bf16, fp32, bf16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 2304, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 2304, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2304, fp16, fp16, fp16, fp32, 1, 4, 1, is_rocm ? 8 : 16); REGISTER_FWD_LAUNCHER( 2304, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 2304, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2304, bf16, bf16, bf16, fp32, 1, 4, 1, is_rocm ? 8 : 16); REGISTER_FWD_LAUNCHER( 2304, bf16, fp32, bf16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 3072, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 3072, fp16, fp16, fp16, fp32, 1, 1, 4, is_rocm ? 8 : 16); REGISTER_FWD_LAUNCHER( 3072, fp16, fp32, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 3072, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 3072, bf16, bf16, bf16, fp32, 1, 1, 4, is_rocm ? 8 : 16); REGISTER_FWD_LAUNCHER( 3072, bf16, fp32, bf16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 3840, fp32, fp32, fp32, fp32, 1, 1, 4, 4); -REGISTER_FWD_LAUNCHER( 3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4); +REGISTER_FWD_LAUNCHER( 3840, fp16, fp16, fp16, fp32, 1, 1, 4, is_rocm ? 2 : 4); REGISTER_FWD_LAUNCHER( 3840, fp16, fp32, fp16, fp32, 1, 1, 4, 4); -REGISTER_FWD_LAUNCHER( 3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4); +REGISTER_FWD_LAUNCHER( 3840, bf16, bf16, bf16, fp32, 1, 1, 4, is_rocm ? 2 : 4); REGISTER_FWD_LAUNCHER( 3840, bf16, fp32, bf16, fp32, 1, 1, 4, 4); REGISTER_FWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16); @@ -126,9 +132,9 @@ REGISTER_FWD_LAUNCHER( 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 4096, bf16, fp32, bf16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 5120, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 5120, fp16, fp16, fp16, fp32, 1, 1, 4, is_rocm ? 8 : 16); REGISTER_FWD_LAUNCHER( 5120, fp16, fp32, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 5120, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 5120, bf16, bf16, bf16, fp32, 1, 1, 4, is_rocm ? 8 : 16); REGISTER_FWD_LAUNCHER( 5120, bf16, fp32, bf16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, 1, 1, 4, 16); @@ -156,21 +162,21 @@ REGISTER_FWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 2, 1, 4, 16); REGISTER_FWD_LAUNCHER(12288, bf16, fp32, bf16, fp32, 2, 1, 4, 16); REGISTER_FWD_LAUNCHER(12800, fp32, fp32, fp32, fp32, 2, 1, 4, 4); -REGISTER_FWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 2, 1, 4, 4); +REGISTER_FWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 2, 1, 4, is_rocm ? 2 : 4); REGISTER_FWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 2, 1, 4, 4); -REGISTER_FWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 2, 1, 4, 4); +REGISTER_FWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 2, 1, 4, is_rocm ? 2 : 4); REGISTER_FWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 2, 1, 4, 4); REGISTER_FWD_LAUNCHER(14336, fp32, fp32, fp32, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(14336, fp16, fp16, fp16, fp32, 2, 1, 4, 16); +REGISTER_FWD_LAUNCHER(14336, fp16, fp16, fp16, fp32, 2, 1, 4, is_rocm ? 8 : 16); REGISTER_FWD_LAUNCHER(14336, fp16, fp32, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(14336, bf16, bf16, bf16, fp32, 2, 1, 4, 16); +REGISTER_FWD_LAUNCHER(14336, bf16, bf16, bf16, fp32, 2, 1, 4, is_rocm ? 8 : 16); REGISTER_FWD_LAUNCHER(14336, bf16, fp32, bf16, fp32, 2, 1, 4, 8); REGISTER_FWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 2, 1, 4, 8); -REGISTER_FWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 2, 1, 4, 8); +REGISTER_FWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 2, 1, 4, is_rocm ? 4 : 8); REGISTER_FWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 2, 1, 4, 8); -REGISTER_FWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 2, 1, 4, 8); +REGISTER_FWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 2, 1, 4, is_rocm ? 4 : 8); REGISTER_FWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 2, 1, 4, 8); REGISTER_FWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 2, 1, 4, 16); @@ -178,11 +184,11 @@ REGISTER_FWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 2, 1, 4, 16); REGISTER_FWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 2, 1, 4, 16); REGISTER_FWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16); -REGISTER_FWD_LAUNCHER(18432, fp16, fp16, fp16, fp32, 2, 1, 4, 16); +REGISTER_FWD_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, is_rocm ? 8 : 16); +REGISTER_FWD_LAUNCHER(18432, fp16, fp16, fp16, fp32, 2, 1, 4, is_rocm ? 8 : 16); REGISTER_FWD_LAUNCHER(18432, fp16, fp32, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(18432, bf16, bf16, bf16, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, 16); +REGISTER_FWD_LAUNCHER(18432, bf16, bf16, bf16, fp32, 2, 1, 4, is_rocm ? 8 : 16); +REGISTER_FWD_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, is_rocm ? 8 : 16); REGISTER_FWD_LAUNCHER(20480, fp32, fp32, fp32, fp32, 2, 1, 4, 16); REGISTER_FWD_LAUNCHER(20480, fp16, fp16, fp16, fp32, 2, 1, 4, 16); @@ -197,9 +203,9 @@ REGISTER_FWD_LAUNCHER(24576, bf16, bf16, bf16, fp32, 2, 1, 4, 16); REGISTER_FWD_LAUNCHER(24576, bf16, fp32, bf16, fp32, 2, 1, 4, 16); REGISTER_FWD_LAUNCHER(25600, fp32, fp32, fp32, fp32, 4, 1, 4, 4); -REGISTER_FWD_LAUNCHER(25600, fp16, fp16, fp16, fp32, 2, 1, 4, 8); +REGISTER_FWD_LAUNCHER(25600, fp16, fp16, fp16, fp32, 2, 1, 4, is_rocm ? 4 : 8); REGISTER_FWD_LAUNCHER(25600, fp16, fp32, fp16, fp32, 4, 1, 4, 4); -REGISTER_FWD_LAUNCHER(25600, bf16, bf16, bf16, fp32, 2, 1, 4, 8); +REGISTER_FWD_LAUNCHER(25600, bf16, bf16, bf16, fp32, 2, 1, 4, is_rocm ? 4 : 8); REGISTER_FWD_LAUNCHER(25600, bf16, fp32, bf16, fp32, 4, 1, 4, 4); REGISTER_FWD_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 4, 4); diff --git a/apex/contrib/csrc/layer_norm/ln_kernel_traits.h b/apex/contrib/csrc/layer_norm/ln_kernel_traits.h index ed745c5ee..68e00e787 100644 --- a/apex/contrib/csrc/layer_norm/ln_kernel_traits.h +++ b/apex/contrib/csrc/layer_norm/ln_kernel_traits.h @@ -22,7 +22,11 @@ struct Kernel_traits_base { enum { HIDDEN_SIZE = HIDDEN_SIZE_ }; enum { THREADS_PER_CTA = THREADS_PER_CTA_ }; +#ifdef USE_ROCM + enum { THREADS_PER_WARP = 64 }; +#else enum { THREADS_PER_WARP = 32 }; +#endif }; @@ -84,10 +88,10 @@ template< typename output_t_, typename compute_t_, typename index_t_, - uint32_t HIDDEN_SIZE_, - uint32_t CTAS_PER_ROW_, - uint32_t WARPS_M_, - uint32_t WARPS_N_, + uint32_t HIDDEN_SIZE_, + uint32_t CTAS_PER_ROW_, + uint32_t WARPS_M_, + uint32_t WARPS_N_, uint32_t BYTES_PER_LDG_ = 16, typename Base = Kernel_traits_base< HIDDEN_SIZE_, diff --git a/apex/contrib/csrc/layer_norm/ln_utils.cuh b/apex/contrib/csrc/layer_norm/ln_utils.cuh index 5fdfb342f..e41fabd21 100644 --- a/apex/contrib/csrc/layer_norm/ln_utils.cuh +++ b/apex/contrib/csrc/layer_norm/ln_utils.cuh @@ -608,8 +608,13 @@ inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, T &n_a, int num_ac m2_a = m2_ab; } // Intra-warp broadcast (only lane 0 has valid stats). +#ifdef USE_ROCM + m_a = __shfl(m_a, 0); + m2_a = __shfl(m2_a, 0); +#else m_a = __shfl_sync(uint32_t(-1), m_a, 0); m2_a = __shfl_sync(uint32_t(-1), m2_a, 0); +#endif } //////////////////////////////////////////////////////////////////////////////////////////////////// From b30d3bad98a7c319612ee2a2839b151eafc17bff Mon Sep 17 00:00:00 2001 From: hubertlu-tw Date: Fri, 9 Sep 2022 22:05:53 +0000 Subject: [PATCH 07/19] Manually hipify cudaLaunchCooperativeKernel in --fast_layer_norm extension --- apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu | 8 ++++---- apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu b/apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu index 75363cbd4..ad668f4f3 100644 --- a/apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu +++ b/apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu @@ -65,12 +65,12 @@ void launch_(LaunchParams &launch_params, const bool configure_params if( Kernel_traits::CTAS_PER_ROW == 1 ) { kernel<<>>(launch_params.params); } else { -#ifdef USE_ROCM - assert(0 && "hipLaunchCooperativeKernel TODO"); -#else - dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col); + dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col); dim3 block(Kernel_traits::THREADS_PER_CTA); void *params_ = (void *)&launch_params.params; +#ifdef USE_ROCM + hipLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, Kernel_traits::SMEM_BYTES, stream); +#else cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, Kernel_traits::SMEM_BYTES, stream); #endif } diff --git a/apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu b/apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu index bc7f5797b..9966b4673 100644 --- a/apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu +++ b/apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu @@ -64,12 +64,12 @@ void launch_(LaunchParams &launch_params, const bool configure_params if( Kernel_traits::CTAS_PER_ROW == 1 ) { kernel<<>>(launch_params.params); } else { -#ifdef USE_ROCM - assert(0 && "hipLaunchCooperativeKernel TODO"); -#else - dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col); + dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col); dim3 block(Kernel_traits::THREADS_PER_CTA); void *params_ = (void *)&launch_params.params; +#ifdef USE_ROCM + hipLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, Kernel_traits::SMEM_BYTES_FWD, stream); +#else cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, Kernel_traits::SMEM_BYTES_FWD, stream); #endif } From 103e128768f1715249d7b546f1728a0677ac5484 Mon Sep 17 00:00:00 2001 From: hubertlu-tw Date: Fri, 9 Sep 2022 22:06:28 +0000 Subject: [PATCH 08/19] Update setup.py for --fast_layer_norm extension --- setup.py | 90 +++++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 63 insertions(+), 27 deletions(-) diff --git a/setup.py b/setup.py index f9be45d37..440ce68c9 100644 --- a/setup.py +++ b/setup.py @@ -42,6 +42,53 @@ def get_cuda_bare_metal_version(cuda_dir): return raw_output, bare_metal_major, bare_metal_minor +def check_cuda_torch_binary_vs_bare_metal(cuda_dir): + raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir) + torch_binary_major = torch.version.cuda.split(".")[0] + torch_binary_minor = torch.version.cuda.split(".")[1] + + print("\nCompiling cuda extensions with") + print(raw_output + "from " + cuda_dir + "/bin\n") + + if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor): + raise RuntimeError( + "Cuda extensions are being compiled with a version of Cuda that does " + "not match the version used to compile Pytorch binaries. " + "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda) + + "In some cases, a minor-version mismatch will not cause later errors: " + "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. " + "You can try commenting out this check (at your own risk)." + ) + + +def raise_if_cuda_home_none(global_option: str) -> None: + if CUDA_HOME is not None: + return + raise RuntimeError( + f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " + "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " + "only images whose names contain 'devel' will provide nvcc." + ) + + +def append_nvcc_threads(nvcc_extra_args): + _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME) + if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2: + return nvcc_extra_args + ["--threads", "4"] + return nvcc_extra_args + + +def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int) -> bool: + cudnn_available = torch.backends.cudnn.is_available() + cudnn_version = torch.backends.cudnn.version() if cudnn_available else None + if not (cudnn_available and (cudnn_version >= required_cudnn_version)): + warnings.warn( + f"Skip `{global_option}` as it requires cuDNN {required_cudnn_version} or later, " + f"but {'cuDNN is not available' if not cudnn_available else cudnn_version}" + ) + return False + + print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) TORCH_MAJOR = int(torch.__version__.split('.')[0]) TORCH_MINOR = int(torch.__version__.split('.')[1]) @@ -360,39 +407,28 @@ def check_if_rocm_pytorch(): if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")): generator_flag = ["-DOLD_GENERATOR_PATH"] -if "--fast_layer_norm" in sys.argv: - sys.argv.remove("--fast_layer_norm") - #raise_if_cuda_home_none("--fast_layer_norm") - ## Check, if CUDA11 is installed for compute capability 8.0 - #cc_flag = [] - #_, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME) - #if int(bare_metal_major) >= 11: - # cc_flag.append("-gencode") - # cc_flag.append("arch=compute_80,code=sm_80") - - #if CUDA_HOME is None and not IS_ROCM_PYTORCH: - # raise RuntimeError("--fast_layer_norm was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") - #else: - # # Check, if CUDA11 is installed for compute capability 8.0 - # cc_flag = [] - # _, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME) - # if int(bare_metal_major) >= 11: - # cc_flag.append('-gencode') - # cc_flag.append('arch=compute_80,code=sm_80') +if "--fast_layer_norm" in sys.argv or "--cuda_ext" in sys.argv: + if "--fast_layer_norm" in sys.argv: + sys.argv.remove("--fast_layer_norm") + + if not IS_ROCM_PYTORCH: + raise_if_cuda_home_none("--fast_layer_norm") + # Check, if CUDA11 is installed for compute capability 8.0 + cc_flag = [] + if not IS_ROCM_PYTORCH: + _, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME) + if int(bare_metal_major) >= 11: + cc_flag.append("-gencode") + cc_flag.append("arch=compute_80,code=sm_80") nvcc_args_fast_layer_norm = ['-maxrregcount=50', '-O3', '--use_fast_math'] + version_dependent_macros hipcc_args_fast_layer_norm = ['-O3', '-U__HIP_NO_HALF_OPERATORS__', '-U__HIP_NO_HALF_CONVERSIONS__'] + version_dependent_macros - print ("INFO: Building fast layernorm extension.") + ext_modules.append( - CUDAExtension(name='fast_layer_norm_cuda', - sources=[#'apex/contrib/csrc/layer_norm/ln.h', - 'apex/contrib/csrc/layer_norm/ln_api.cpp', - #'apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh', + CUDAExtension(name='fast_layer_norm', + sources=['apex/contrib/csrc/layer_norm/ln_api.cpp', 'apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu', 'apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu', - #'apex/contrib/csrc/layer_norm/ln_fwd_kernels.cuh', - #'apex/contrib/csrc/layer_norm/ln_kernel_traits.h', - #'apex/contrib/csrc/layer_norm/ln_utils.cuh' ], include_dirs=[os.path.join(this_dir, 'csrc'), os.path.join(this_dir, 'apex/contrib/csrc/layer_norm')], From 637f2e198513775cd2776a2d7b34c43d70d60ed0 Mon Sep 17 00:00:00 2001 From: aspanday Date: Tue, 14 Mar 2023 21:13:02 +0000 Subject: [PATCH 09/19] CTAS_PER_ROW > 1 requires reduction that seems to be making the accuracy tests to fail in fwd and bwd pass. Making CTAS_PER_ROW=1 for all cases doesnt seem to be affecting the performance and is allowing all fwd pass tests to pass. In addition bwd pass tests fail for hidden_size >= 8192. With CTAS_PER_ROW=1 for all cases in bwd, tests fail when hidden_size >= 16K. This commit allows fastlayernorm capabale to be used for inference cases. Additional debugging required for fastlayernorm bwd. --- .../layer_norm/ln_bwd_semi_cuda_kernel.cu | 289 ++++++++++++------ .../csrc/layer_norm/ln_fwd_cuda_kernel.cu | 167 +++++----- 2 files changed, 276 insertions(+), 180 deletions(-) diff --git a/apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu b/apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu index ad668f4f3..57356ba5b 100644 --- a/apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu +++ b/apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu @@ -65,7 +65,7 @@ void launch_(LaunchParams &launch_params, const bool configure_params if( Kernel_traits::CTAS_PER_ROW == 1 ) { kernel<<>>(launch_params.params); } else { - dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col); + dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col); dim3 block(Kernel_traits::THREADS_PER_CTA); void *params_ = (void *)&launch_params.params; #ifdef USE_ROCM @@ -157,99 +157,194 @@ REGISTER_BWD_LAUNCHER( 6144, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER( 6144, bf16, bf16, bf16, fp32, 1, 1, 8, is_rocm ? 8 : 16, 4); REGISTER_BWD_LAUNCHER( 6144, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 8192, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 8192, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 8192, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 8192, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4); - -REGISTER_BWD_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(10240, fp16, fp16, fp16, fp32, 2, 1, 4, is_rocm ? 8 : 16, 4); -REGISTER_BWD_LAUNCHER(10240, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(10240, bf16, bf16, bf16, fp32, 2, 1, 4, is_rocm ? 8 : 16, 4); -REGISTER_BWD_LAUNCHER(10240, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4); - -REGISTER_BWD_LAUNCHER(12288, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(12288, fp16, fp16, fp16, fp32, 4, 1, 4, is_rocm ? 8 : 16, 4); -REGISTER_BWD_LAUNCHER(12288, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 4, 1, 4, is_rocm ? 8 : 16, 4); -REGISTER_BWD_LAUNCHER(12288, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); - -REGISTER_BWD_LAUNCHER(12800, fp32, fp32, fp32, fp32, 5, 1, 4, is_rocm ? 8 : 16, 4); -REGISTER_BWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 5, 1, 4, is_rocm ? 4 : 8, 4); -REGISTER_BWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 5, 1, 4, is_rocm ? 8 : 16, 4); -REGISTER_BWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 5, 1, 4, is_rocm ? 4 : 8, 4); -REGISTER_BWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 5, 1, 4, is_rocm ? 8 : 16, 4); - -REGISTER_BWD_LAUNCHER(14336, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER(14336, fp16, fp16, fp16, fp32, 4, 1, 4, is_rocm ? 4 : 8, 4); -REGISTER_BWD_LAUNCHER(14336, fp16, fp32, fp16, fp32, 4, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER(14336, bf16, bf16, bf16, fp32, 4, 1, 4, is_rocm ? 4 : 8, 4); -REGISTER_BWD_LAUNCHER(14336, bf16, fp32, bf16, fp32, 4, 1, 4, 8, 4); - -REGISTER_BWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 4, 1, 4, is_rocm ? 4 : 8, 4); -REGISTER_BWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 4, 1, 4, is_rocm ? 2 : 4, 4); -REGISTER_BWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 4, 1, 4, is_rocm ? 4 : 8, 4); -REGISTER_BWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 4, 1, 4, is_rocm ? 2 : 4, 4); -REGISTER_BWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 4, 1, 4, is_rocm ? 4 : 8, 4); - -REGISTER_BWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(16384, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); - -REGISTER_BWD_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, is_rocm ? 8 : 16, 4); -REGISTER_BWD_LAUNCHER(18432, fp16, fp16, fp16, fp32, 4, 1, 4, is_rocm ? 4 : 8, 4); -REGISTER_BWD_LAUNCHER(18432, fp16, fp32, fp16, fp32, 4, 1, 4, is_rocm ? 8 : 16, 4); -REGISTER_BWD_LAUNCHER(18432, bf16, bf16, bf16, fp32, 4, 1, 4, is_rocm ? 4 : 8, 4); -REGISTER_BWD_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, is_rocm ? 8 : 16, 4); - -REGISTER_BWD_LAUNCHER(20480, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(20480, fp16, fp16, fp16, fp32, 4, 1, 4, is_rocm ? 8 : 16, 4); -REGISTER_BWD_LAUNCHER(20480, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(20480, bf16, bf16, bf16, fp32, 4, 1, 4, is_rocm ? 8 : 16, 4); -REGISTER_BWD_LAUNCHER(20480, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); - -REGISTER_BWD_LAUNCHER(24576, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(24576, fp16, fp16, fp16, fp32, 4, 1, 8, is_rocm ? 8 : 16, 4); -REGISTER_BWD_LAUNCHER(24576, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(24576, bf16, bf16, bf16, fp32, 4, 1, 8, is_rocm ? 8 : 16, 4); -REGISTER_BWD_LAUNCHER(24576, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4); - -REGISTER_BWD_LAUNCHER(25600, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(25600, fp16, fp16, fp16, fp32, 5, 1, 4, is_rocm ? 8 : 16, 4); -REGISTER_BWD_LAUNCHER(25600, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(25600, bf16, bf16, bf16, fp32, 5, 1, 4, is_rocm ? 8 : 16, 4); -REGISTER_BWD_LAUNCHER(25600, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4); - -REGISTER_BWD_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 8, is_rocm ? 4 : 8, 4); -REGISTER_BWD_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 8, is_rocm ? 2 : 4, 4); -REGISTER_BWD_LAUNCHER(30720, fp16, fp32, fp16, fp32, 4, 1, 8, is_rocm ? 4 : 8, 4); -REGISTER_BWD_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 8, is_rocm ? 2 : 4, 4); -REGISTER_BWD_LAUNCHER(30720, bf16, fp32, bf16, fp32, 4, 1, 8, is_rocm ? 4 : 8, 4); - -REGISTER_BWD_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(32768, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(32768, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4); - -REGISTER_BWD_LAUNCHER(40960, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(40960, fp16, fp16, fp16, fp32, 4, 1, 8, is_rocm ? 8 : 16, 4); -REGISTER_BWD_LAUNCHER(40960, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 8, is_rocm ? 8 : 16, 4); -REGISTER_BWD_LAUNCHER(40960, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4); - -REGISTER_BWD_LAUNCHER(49152, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(49152, fp16, fp16, fp16, fp32, 8, 1, 8, is_rocm ? 8 : 16, 4); -REGISTER_BWD_LAUNCHER(49152, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(49152, bf16, bf16, bf16, fp32, 8, 1, 8, is_rocm ? 8 : 16, 4); -REGISTER_BWD_LAUNCHER(49152, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4); - -REGISTER_BWD_LAUNCHER(65536, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(65536, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(65536, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(65536, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4); - +REGISTER_BWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 8192, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 8192, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); + +REGISTER_BWD_LAUNCHER(10240, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(10240, fp16, fp16, fp16, fp32, 1, 1, 4, is_rocm ? 8 : 16, 4); +REGISTER_BWD_LAUNCHER(10240, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(10240, bf16, bf16, bf16, fp32, 1, 1, 4, is_rocm ? 8 : 16, 4); +REGISTER_BWD_LAUNCHER(10240, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); + +REGISTER_BWD_LAUNCHER(12288, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(12288, fp16, fp16, fp16, fp32, 1, 1, 4, is_rocm ? 8 : 16, 4); +REGISTER_BWD_LAUNCHER(12288, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 1, 1, 4, is_rocm ? 8 : 16, 4); +REGISTER_BWD_LAUNCHER(12288, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); + +REGISTER_BWD_LAUNCHER(12800, fp32, fp32, fp32, fp32, 1, 1, 4, is_rocm ? 8 : 16, 4); +REGISTER_BWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 1, 1, 4, is_rocm ? 4 : 8, 4); +REGISTER_BWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 1, 1, 4, is_rocm ? 8 : 16, 4); +REGISTER_BWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 1, 1, 4, is_rocm ? 4 : 8, 4); +REGISTER_BWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 1, 1, 4, is_rocm ? 8 : 16, 4); + +REGISTER_BWD_LAUNCHER(14336, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4); +REGISTER_BWD_LAUNCHER(14336, fp16, fp16, fp16, fp32, 1, 1, 4, is_rocm ? 4 : 8, 4); +REGISTER_BWD_LAUNCHER(14336, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); +REGISTER_BWD_LAUNCHER(14336, bf16, bf16, bf16, fp32, 1, 1, 4, is_rocm ? 4 : 8, 4); +REGISTER_BWD_LAUNCHER(14336, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); + +REGISTER_BWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 1, 1, 4, is_rocm ? 4 : 8, 4); +REGISTER_BWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 1, 1, 4, is_rocm ? 2 : 4, 4); +REGISTER_BWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 1, 1, 4, is_rocm ? 4 : 8, 4); +REGISTER_BWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 1, 1, 4, is_rocm ? 2 : 4, 4); +REGISTER_BWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 1, 1, 4, is_rocm ? 4 : 8, 4); + +REGISTER_BWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(16384, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); + +REGISTER_BWD_LAUNCHER(18432, fp32, fp32, fp32, fp32, 1, 1, 4, is_rocm ? 8 : 16, 4); +REGISTER_BWD_LAUNCHER(18432, fp16, fp16, fp16, fp32, 1, 1, 4, is_rocm ? 4 : 8, 4); +REGISTER_BWD_LAUNCHER(18432, fp16, fp32, fp16, fp32, 1, 1, 4, is_rocm ? 8 : 16, 4); +REGISTER_BWD_LAUNCHER(18432, bf16, bf16, bf16, fp32, 1, 1, 4, is_rocm ? 4 : 8, 4); +REGISTER_BWD_LAUNCHER(18432, bf16, fp32, bf16, fp32, 1, 1, 4, is_rocm ? 8 : 16, 4); + +REGISTER_BWD_LAUNCHER(20480, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(20480, fp16, fp16, fp16, fp32, 1, 1, 4, is_rocm ? 8 : 16, 4); +REGISTER_BWD_LAUNCHER(20480, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(20480, bf16, bf16, bf16, fp32, 1, 1, 4, is_rocm ? 8 : 16, 4); +REGISTER_BWD_LAUNCHER(20480, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); + +REGISTER_BWD_LAUNCHER(24576, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); +REGISTER_BWD_LAUNCHER(24576, fp16, fp16, fp16, fp32, 1, 1, 8, is_rocm ? 8 : 16, 4); +REGISTER_BWD_LAUNCHER(24576, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); +REGISTER_BWD_LAUNCHER(24576, bf16, bf16, bf16, fp32, 1, 1, 8, is_rocm ? 8 : 16, 4); +REGISTER_BWD_LAUNCHER(24576, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); + +REGISTER_BWD_LAUNCHER(25600, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(25600, fp16, fp16, fp16, fp32, 1, 1, 4, is_rocm ? 8 : 16, 4); +REGISTER_BWD_LAUNCHER(25600, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(25600, bf16, bf16, bf16, fp32, 1, 1, 4, is_rocm ? 8 : 16, 4); +REGISTER_BWD_LAUNCHER(25600, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); + +REGISTER_BWD_LAUNCHER(30720, fp32, fp32, fp32, fp32, 1, 1, 8, is_rocm ? 4 : 8, 4); +REGISTER_BWD_LAUNCHER(30720, fp16, fp16, fp16, fp32, 1, 1, 8, is_rocm ? 2 : 4, 4); +REGISTER_BWD_LAUNCHER(30720, fp16, fp32, fp16, fp32, 1, 1, 8, is_rocm ? 4 : 8, 4); +REGISTER_BWD_LAUNCHER(30720, bf16, bf16, bf16, fp32, 1, 1, 8, is_rocm ? 2 : 4, 4); +REGISTER_BWD_LAUNCHER(30720, bf16, fp32, bf16, fp32, 1, 1, 8, is_rocm ? 4 : 8, 4); + +REGISTER_BWD_LAUNCHER(32768, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); +REGISTER_BWD_LAUNCHER(32768, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); +REGISTER_BWD_LAUNCHER(32768, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); +REGISTER_BWD_LAUNCHER(32768, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); +REGISTER_BWD_LAUNCHER(32768, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); + +REGISTER_BWD_LAUNCHER(40960, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); +REGISTER_BWD_LAUNCHER(40960, fp16, fp16, fp16, fp32, 1, 1, 8, is_rocm ? 8 : 16, 4); +REGISTER_BWD_LAUNCHER(40960, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); +REGISTER_BWD_LAUNCHER(40960, bf16, bf16, bf16, fp32, 1, 1, 8, is_rocm ? 8 : 16, 4); +REGISTER_BWD_LAUNCHER(40960, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); + +REGISTER_BWD_LAUNCHER(49152, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); +REGISTER_BWD_LAUNCHER(49152, fp16, fp16, fp16, fp32, 1, 1, 8, is_rocm ? 8 : 16, 4); +REGISTER_BWD_LAUNCHER(49152, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); +REGISTER_BWD_LAUNCHER(49152, bf16, bf16, bf16, fp32, 1, 1, 8, is_rocm ? 8 : 16, 4); +REGISTER_BWD_LAUNCHER(49152, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); + +REGISTER_BWD_LAUNCHER(65536, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); +REGISTER_BWD_LAUNCHER(65536, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); +REGISTER_BWD_LAUNCHER(65536, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); +REGISTER_BWD_LAUNCHER(65536, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); +REGISTER_BWD_LAUNCHER(65536, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); + +// REGISTER_BWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4); +// REGISTER_BWD_LAUNCHER( 8192, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4); +// REGISTER_BWD_LAUNCHER( 8192, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4); +// REGISTER_BWD_LAUNCHER( 8192, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4); +// REGISTER_BWD_LAUNCHER( 8192, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4); + +// REGISTER_BWD_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4); +// REGISTER_BWD_LAUNCHER(10240, fp16, fp16, fp16, fp32, 2, 1, 4, is_rocm ? 8 : 16, 4); +// REGISTER_BWD_LAUNCHER(10240, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4); +// REGISTER_BWD_LAUNCHER(10240, bf16, bf16, bf16, fp32, 2, 1, 4, is_rocm ? 8 : 16, 4); +// REGISTER_BWD_LAUNCHER(10240, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4); + +// REGISTER_BWD_LAUNCHER(12288, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); +// REGISTER_BWD_LAUNCHER(12288, fp16, fp16, fp16, fp32, 4, 1, 4, is_rocm ? 8 : 16, 4); +// REGISTER_BWD_LAUNCHER(12288, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); +// REGISTER_BWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 4, 1, 4, is_rocm ? 8 : 16, 4); +// REGISTER_BWD_LAUNCHER(12288, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); + +// REGISTER_BWD_LAUNCHER(12800, fp32, fp32, fp32, fp32, 5, 1, 4, is_rocm ? 8 : 16, 4); +// REGISTER_BWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 5, 1, 4, is_rocm ? 4 : 8, 4); +// REGISTER_BWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 5, 1, 4, is_rocm ? 8 : 16, 4); +// REGISTER_BWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 5, 1, 4, is_rocm ? 4 : 8, 4); +// REGISTER_BWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 5, 1, 4, is_rocm ? 8 : 16, 4); + +// REGISTER_BWD_LAUNCHER(14336, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4); +// REGISTER_BWD_LAUNCHER(14336, fp16, fp16, fp16, fp32, 4, 1, 4, is_rocm ? 4 : 8, 4); +// REGISTER_BWD_LAUNCHER(14336, fp16, fp32, fp16, fp32, 4, 1, 4, 8, 4); +// REGISTER_BWD_LAUNCHER(14336, bf16, bf16, bf16, fp32, 4, 1, 4, is_rocm ? 4 : 8, 4); +// REGISTER_BWD_LAUNCHER(14336, bf16, fp32, bf16, fp32, 4, 1, 4, 8, 4); + +// REGISTER_BWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 4, 1, 4, is_rocm ? 4 : 8, 4); +// REGISTER_BWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 4, 1, 4, is_rocm ? 2 : 4, 4); +// REGISTER_BWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 4, 1, 4, is_rocm ? 4 : 8, 4); +// REGISTER_BWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 4, 1, 4, is_rocm ? 2 : 4, 4); +// REGISTER_BWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 4, 1, 4, is_rocm ? 4 : 8, 4); + +// REGISTER_BWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); +// REGISTER_BWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4); +// REGISTER_BWD_LAUNCHER(16384, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); +// REGISTER_BWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4); +// REGISTER_BWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); + +// REGISTER_BWD_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, is_rocm ? 8 : 16, 4); +// REGISTER_BWD_LAUNCHER(18432, fp16, fp16, fp16, fp32, 4, 1, 4, is_rocm ? 4 : 8, 4); +// REGISTER_BWD_LAUNCHER(18432, fp16, fp32, fp16, fp32, 4, 1, 4, is_rocm ? 8 : 16, 4); +// REGISTER_BWD_LAUNCHER(18432, bf16, bf16, bf16, fp32, 4, 1, 4, is_rocm ? 4 : 8, 4); +// REGISTER_BWD_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, is_rocm ? 8 : 16, 4); + +// REGISTER_BWD_LAUNCHER(20480, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); +// REGISTER_BWD_LAUNCHER(20480, fp16, fp16, fp16, fp32, 4, 1, 4, is_rocm ? 8 : 16, 4); +// REGISTER_BWD_LAUNCHER(20480, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); +// REGISTER_BWD_LAUNCHER(20480, bf16, bf16, bf16, fp32, 4, 1, 4, is_rocm ? 8 : 16, 4); +// REGISTER_BWD_LAUNCHER(20480, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); + +// REGISTER_BWD_LAUNCHER(24576, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4); +// REGISTER_BWD_LAUNCHER(24576, fp16, fp16, fp16, fp32, 4, 1, 8, is_rocm ? 8 : 16, 4); +// REGISTER_BWD_LAUNCHER(24576, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4); +// REGISTER_BWD_LAUNCHER(24576, bf16, bf16, bf16, fp32, 4, 1, 8, is_rocm ? 8 : 16, 4); +// REGISTER_BWD_LAUNCHER(24576, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4); + +// REGISTER_BWD_LAUNCHER(25600, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4); +// REGISTER_BWD_LAUNCHER(25600, fp16, fp16, fp16, fp32, 5, 1, 4, is_rocm ? 8 : 16, 4); +// REGISTER_BWD_LAUNCHER(25600, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4); +// REGISTER_BWD_LAUNCHER(25600, bf16, bf16, bf16, fp32, 5, 1, 4, is_rocm ? 8 : 16, 4); +// REGISTER_BWD_LAUNCHER(25600, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4); + +// REGISTER_BWD_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 8, is_rocm ? 4 : 8, 4); +// REGISTER_BWD_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 8, is_rocm ? 2 : 4, 4); +// REGISTER_BWD_LAUNCHER(30720, fp16, fp32, fp16, fp32, 4, 1, 8, is_rocm ? 4 : 8, 4); +// REGISTER_BWD_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 8, is_rocm ? 2 : 4, 4); +// REGISTER_BWD_LAUNCHER(30720, bf16, fp32, bf16, fp32, 4, 1, 8, is_rocm ? 4 : 8, 4); + +// REGISTER_BWD_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4); +// REGISTER_BWD_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4); +// REGISTER_BWD_LAUNCHER(32768, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4); +// REGISTER_BWD_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4); +// REGISTER_BWD_LAUNCHER(32768, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4); + +// REGISTER_BWD_LAUNCHER(40960, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4); +// REGISTER_BWD_LAUNCHER(40960, fp16, fp16, fp16, fp32, 4, 1, 8, is_rocm ? 8 : 16, 4); +// REGISTER_BWD_LAUNCHER(40960, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4); +// REGISTER_BWD_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 8, is_rocm ? 8 : 16, 4); +// REGISTER_BWD_LAUNCHER(40960, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4); + +// REGISTER_BWD_LAUNCHER(49152, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4); +// REGISTER_BWD_LAUNCHER(49152, fp16, fp16, fp16, fp32, 8, 1, 8, is_rocm ? 8 : 16, 4); +// REGISTER_BWD_LAUNCHER(49152, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4); +// REGISTER_BWD_LAUNCHER(49152, bf16, bf16, bf16, fp32, 8, 1, 8, is_rocm ? 8 : 16, 4); +// REGISTER_BWD_LAUNCHER(49152, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4); + +// REGISTER_BWD_LAUNCHER(65536, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4); +// REGISTER_BWD_LAUNCHER(65536, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4); +// REGISTER_BWD_LAUNCHER(65536, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4); +// REGISTER_BWD_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4); +// REGISTER_BWD_LAUNCHER(65536, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4); \ No newline at end of file diff --git a/apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu b/apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu index 9966b4673..8a836a627 100644 --- a/apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu +++ b/apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu @@ -149,92 +149,93 @@ REGISTER_FWD_LAUNCHER( 8192, fp16, fp32, fp16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 8192, bf16, fp32, bf16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16); +REGISTER_FWD_LAUNCHER(10240, fp32, fp32, fp32, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER(10240, fp16, fp16, fp16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER(10240, fp16, fp32, fp16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER(10240, bf16, bf16, bf16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER(10240, bf16, fp32, bf16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER(12288, fp32, fp32, fp32, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(12288, fp16, fp16, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(12288, fp16, fp32, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(12288, bf16, fp32, bf16, fp32, 2, 1, 4, 16); - -REGISTER_FWD_LAUNCHER(12800, fp32, fp32, fp32, fp32, 2, 1, 4, 4); -REGISTER_FWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 2, 1, 4, is_rocm ? 2 : 4); -REGISTER_FWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 2, 1, 4, 4); -REGISTER_FWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 2, 1, 4, is_rocm ? 2 : 4); -REGISTER_FWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 2, 1, 4, 4); - -REGISTER_FWD_LAUNCHER(14336, fp32, fp32, fp32, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(14336, fp16, fp16, fp16, fp32, 2, 1, 4, is_rocm ? 8 : 16); -REGISTER_FWD_LAUNCHER(14336, fp16, fp32, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(14336, bf16, bf16, bf16, fp32, 2, 1, 4, is_rocm ? 8 : 16); -REGISTER_FWD_LAUNCHER(14336, bf16, fp32, bf16, fp32, 2, 1, 4, 8); - -REGISTER_FWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 2, 1, 4, 8); -REGISTER_FWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 2, 1, 4, is_rocm ? 4 : 8); -REGISTER_FWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 2, 1, 4, 8); -REGISTER_FWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 2, 1, 4, is_rocm ? 4 : 8); -REGISTER_FWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 2, 1, 4, 8); - -REGISTER_FWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 2, 1, 4, 16); - -REGISTER_FWD_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, is_rocm ? 8 : 16); -REGISTER_FWD_LAUNCHER(18432, fp16, fp16, fp16, fp32, 2, 1, 4, is_rocm ? 8 : 16); -REGISTER_FWD_LAUNCHER(18432, fp16, fp32, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(18432, bf16, bf16, bf16, fp32, 2, 1, 4, is_rocm ? 8 : 16); -REGISTER_FWD_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, is_rocm ? 8 : 16); - -REGISTER_FWD_LAUNCHER(20480, fp32, fp32, fp32, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(20480, fp16, fp16, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(20480, fp16, fp32, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(20480, bf16, bf16, bf16, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(20480, bf16, fp32, bf16, fp32, 2, 1, 4, 16); - -REGISTER_FWD_LAUNCHER(24576, fp32, fp32, fp32, fp32, 4, 1, 4, 16); -REGISTER_FWD_LAUNCHER(24576, fp16, fp16, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(24576, fp16, fp32, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(24576, bf16, bf16, bf16, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(24576, bf16, fp32, bf16, fp32, 2, 1, 4, 16); - -REGISTER_FWD_LAUNCHER(25600, fp32, fp32, fp32, fp32, 4, 1, 4, 4); -REGISTER_FWD_LAUNCHER(25600, fp16, fp16, fp16, fp32, 2, 1, 4, is_rocm ? 4 : 8); -REGISTER_FWD_LAUNCHER(25600, fp16, fp32, fp16, fp32, 4, 1, 4, 4); -REGISTER_FWD_LAUNCHER(25600, bf16, bf16, bf16, fp32, 2, 1, 4, is_rocm ? 4 : 8); -REGISTER_FWD_LAUNCHER(25600, bf16, fp32, bf16, fp32, 4, 1, 4, 4); - -REGISTER_FWD_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 4, 4); -REGISTER_FWD_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 4, 4); -REGISTER_FWD_LAUNCHER(30720, fp16, fp32, fp16, fp32, 4, 1, 4, 4); -REGISTER_FWD_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 4, 4); -REGISTER_FWD_LAUNCHER(30720, bf16, fp32, bf16, fp32, 4, 1, 4, 4); - -REGISTER_FWD_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 4, 16); -REGISTER_FWD_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 4, 16); -REGISTER_FWD_LAUNCHER(32768, fp16, fp32, fp16, fp32, 4, 1, 4, 16); -REGISTER_FWD_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 4, 16); -REGISTER_FWD_LAUNCHER(32768, bf16, fp32, bf16, fp32, 4, 1, 4, 16); - -REGISTER_FWD_LAUNCHER(40960, fp32, fp32, fp32, fp32, 4, 1, 4, 16); -REGISTER_FWD_LAUNCHER(40960, fp16, fp16, fp16, fp32, 4, 1, 4, 16); -REGISTER_FWD_LAUNCHER(40960, fp16, fp32, fp16, fp32, 4, 1, 4, 16); -REGISTER_FWD_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 4, 16); -REGISTER_FWD_LAUNCHER(40960, bf16, fp32, bf16, fp32, 4, 1, 4, 16); - -REGISTER_FWD_LAUNCHER(49152, fp32, fp32, fp32, fp32, 8, 1, 4, 16); -REGISTER_FWD_LAUNCHER(49152, fp16, fp16, fp16, fp32, 4, 1, 4, 16); -REGISTER_FWD_LAUNCHER(49152, fp16, fp32, fp16, fp32, 4, 1, 4, 16); -REGISTER_FWD_LAUNCHER(49152, bf16, bf16, bf16, fp32, 4, 1, 4, 16); -REGISTER_FWD_LAUNCHER(49152, bf16, fp32, bf16, fp32, 4, 1, 4, 16); - -REGISTER_FWD_LAUNCHER(65536, fp32, fp32, fp32, fp32, 8, 1, 4, 16); -REGISTER_FWD_LAUNCHER(65536, fp16, fp16, fp16, fp32, 8, 1, 4, 16); -REGISTER_FWD_LAUNCHER(65536, fp16, fp32, fp16, fp32, 8, 1, 4, 16); -REGISTER_FWD_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 4, 16); -REGISTER_FWD_LAUNCHER(65536, bf16, fp32, bf16, fp32, 8, 1, 4, 16); +REGISTER_FWD_LAUNCHER(12288, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(12288, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(12288, fp16, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(12288, bf16, fp32, bf16, fp32, 1, 1, 4, 16); + +REGISTER_FWD_LAUNCHER(12800, fp32, fp32, fp32, fp32, 1, 1, 4, 4); +REGISTER_FWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 1, 1, 4, is_rocm ? 2 : 4); +REGISTER_FWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 1, 1, 4, 4); +REGISTER_FWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 1, 1, 4, is_rocm ? 2 : 4); +REGISTER_FWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 1, 1, 4, 4); + +REGISTER_FWD_LAUNCHER(14336, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(14336, fp16, fp16, fp16, fp32, 1, 1, 4, is_rocm ? 8 : 16); +REGISTER_FWD_LAUNCHER(14336, fp16, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(14336, bf16, bf16, bf16, fp32, 1, 1, 4, is_rocm ? 8 : 16); +REGISTER_FWD_LAUNCHER(14336, bf16, fp32, bf16, fp32, 1, 1, 4, 8); + +REGISTER_FWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 1, 1, 4, 8); +REGISTER_FWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 1, 1, 4, is_rocm ? 4 : 8); +REGISTER_FWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 1, 1, 4, 8); +REGISTER_FWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 1, 1, 4, is_rocm ? 4 : 8); +REGISTER_FWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 1, 1, 4, 8); + +REGISTER_FWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(16384, fp16, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 1, 1, 4, 16); + +REGISTER_FWD_LAUNCHER(18432, fp32, fp32, fp32, fp32, 1, 1, 4, is_rocm ? 8 : 16); +REGISTER_FWD_LAUNCHER(18432, fp16, fp16, fp16, fp32, 1, 1, 4, is_rocm ? 8 : 16); +REGISTER_FWD_LAUNCHER(18432, fp16, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(18432, bf16, bf16, bf16, fp32, 1, 1, 4, is_rocm ? 8 : 16); +REGISTER_FWD_LAUNCHER(18432, bf16, fp32, bf16, fp32, 1, 1, 4, is_rocm ? 8 : 16); + +REGISTER_FWD_LAUNCHER(20480, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(20480, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(20480, fp16, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(20480, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(20480, bf16, fp32, bf16, fp32, 1, 1, 4, 16); + +REGISTER_FWD_LAUNCHER(24576, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(24576, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(24576, fp16, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(24576, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(24576, bf16, fp32, bf16, fp32, 1, 1, 4, 16); + +REGISTER_FWD_LAUNCHER(25600, fp32, fp32, fp32, fp32, 1, 1, 4, 4); +REGISTER_FWD_LAUNCHER(25600, fp16, fp16, fp16, fp32, 1, 1, 4, is_rocm ? 4 : 8); +REGISTER_FWD_LAUNCHER(25600, fp16, fp32, fp16, fp32, 1, 1, 4, 4); +REGISTER_FWD_LAUNCHER(25600, bf16, bf16, bf16, fp32, 1, 1, 4, is_rocm ? 4 : 8); +REGISTER_FWD_LAUNCHER(25600, bf16, fp32, bf16, fp32, 1, 1, 4, 4); + +REGISTER_FWD_LAUNCHER(30720, fp32, fp32, fp32, fp32, 1, 1, 4, 4); +REGISTER_FWD_LAUNCHER(30720, fp16, fp16, fp16, fp32, 1, 1, 4, 4); +REGISTER_FWD_LAUNCHER(30720, fp16, fp32, fp16, fp32, 1, 1, 4, 4); +REGISTER_FWD_LAUNCHER(30720, bf16, bf16, bf16, fp32, 1, 1, 4, 4); +REGISTER_FWD_LAUNCHER(30720, bf16, fp32, bf16, fp32, 1, 1, 4, 4); + +REGISTER_FWD_LAUNCHER(32768, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(32768, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(32768, fp16, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(32768, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(32768, bf16, fp32, bf16, fp32, 1, 1, 4, 16); + +REGISTER_FWD_LAUNCHER(40960, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(40960, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(40960, fp16, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(40960, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(40960, bf16, fp32, bf16, fp32, 1, 1, 4, 16); + +REGISTER_FWD_LAUNCHER(49152, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(49152, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(49152, fp16, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(49152, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(49152, bf16, fp32, bf16, fp32, 1, 1, 4, 16); + +REGISTER_FWD_LAUNCHER(65536, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(65536, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(65536, fp16, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(65536, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER(65536, bf16, fp32, bf16, fp32, 1, 1, 4, 16); From 03002d0207509eda1bdd291bd3a1f40164669631 Mon Sep 17 00:00:00 2001 From: aspanday Date: Tue, 14 Mar 2023 21:15:53 +0000 Subject: [PATCH 10/19] retaining the previous CTAS_PER_ROW values for fwd pass. --- .../csrc/layer_norm/ln_fwd_cuda_kernel.cu | 88 +++++++++++++++++++ 1 file changed, 88 insertions(+) diff --git a/apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu b/apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu index 8a836a627..92ef5f823 100644 --- a/apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu +++ b/apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu @@ -239,3 +239,91 @@ REGISTER_FWD_LAUNCHER(65536, fp16, fp32, fp16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER(65536, bf16, bf16, bf16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER(65536, bf16, fp32, bf16, fp32, 1, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(10240, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(10240, fp16, fp32, fp16, fp32, 1, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(10240, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(10240, bf16, fp32, bf16, fp32, 1, 1, 4, 16); + +// REGISTER_FWD_LAUNCHER(12288, fp32, fp32, fp32, fp32, 2, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(12288, fp16, fp16, fp16, fp32, 2, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(12288, fp16, fp32, fp16, fp32, 2, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 2, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(12288, bf16, fp32, bf16, fp32, 2, 1, 4, 16); + +// REGISTER_FWD_LAUNCHER(12800, fp32, fp32, fp32, fp32, 2, 1, 4, 4); +// REGISTER_FWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 2, 1, 4, is_rocm ? 2 : 4); +// REGISTER_FWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 2, 1, 4, 4); +// REGISTER_FWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 2, 1, 4, is_rocm ? 2 : 4); +// REGISTER_FWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 2, 1, 4, 4); + +// REGISTER_FWD_LAUNCHER(14336, fp32, fp32, fp32, fp32, 2, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(14336, fp16, fp16, fp16, fp32, 2, 1, 4, is_rocm ? 8 : 16); +// REGISTER_FWD_LAUNCHER(14336, fp16, fp32, fp16, fp32, 2, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(14336, bf16, bf16, bf16, fp32, 2, 1, 4, is_rocm ? 8 : 16); +// REGISTER_FWD_LAUNCHER(14336, bf16, fp32, bf16, fp32, 2, 1, 4, 8); + +// REGISTER_FWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 2, 1, 4, 8); +// REGISTER_FWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 2, 1, 4, is_rocm ? 4 : 8); +// REGISTER_FWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 2, 1, 4, 8); +// REGISTER_FWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 2, 1, 4, is_rocm ? 4 : 8); +// REGISTER_FWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 2, 1, 4, 8); + +// REGISTER_FWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 2, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 2, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 2, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 2, 1, 4, 16); + +// REGISTER_FWD_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, is_rocm ? 8 : 16); +// REGISTER_FWD_LAUNCHER(18432, fp16, fp16, fp16, fp32, 2, 1, 4, is_rocm ? 8 : 16); +// REGISTER_FWD_LAUNCHER(18432, fp16, fp32, fp16, fp32, 2, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(18432, bf16, bf16, bf16, fp32, 2, 1, 4, is_rocm ? 8 : 16); +// REGISTER_FWD_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, is_rocm ? 8 : 16); + +// REGISTER_FWD_LAUNCHER(20480, fp32, fp32, fp32, fp32, 2, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(20480, fp16, fp16, fp16, fp32, 2, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(20480, fp16, fp32, fp16, fp32, 2, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(20480, bf16, bf16, bf16, fp32, 2, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(20480, bf16, fp32, bf16, fp32, 2, 1, 4, 16); + +// REGISTER_FWD_LAUNCHER(24576, fp32, fp32, fp32, fp32, 4, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(24576, fp16, fp16, fp16, fp32, 2, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(24576, fp16, fp32, fp16, fp32, 2, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(24576, bf16, bf16, bf16, fp32, 2, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(24576, bf16, fp32, bf16, fp32, 2, 1, 4, 16); + +// REGISTER_FWD_LAUNCHER(25600, fp32, fp32, fp32, fp32, 4, 1, 4, 4); +// REGISTER_FWD_LAUNCHER(25600, fp16, fp16, fp16, fp32, 2, 1, 4, is_rocm ? 4 : 8); +// REGISTER_FWD_LAUNCHER(25600, fp16, fp32, fp16, fp32, 4, 1, 4, 4); +// REGISTER_FWD_LAUNCHER(25600, bf16, bf16, bf16, fp32, 2, 1, 4, is_rocm ? 4 : 8); +// REGISTER_FWD_LAUNCHER(25600, bf16, fp32, bf16, fp32, 4, 1, 4, 4); + +// REGISTER_FWD_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 4, 4); +// REGISTER_FWD_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 4, 4); +// REGISTER_FWD_LAUNCHER(30720, fp16, fp32, fp16, fp32, 4, 1, 4, 4); +// REGISTER_FWD_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 4, 4); +// REGISTER_FWD_LAUNCHER(30720, bf16, fp32, bf16, fp32, 4, 1, 4, 4); + +// REGISTER_FWD_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(32768, fp16, fp32, fp16, fp32, 4, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(32768, bf16, fp32, bf16, fp32, 4, 1, 4, 16); + +// REGISTER_FWD_LAUNCHER(40960, fp32, fp32, fp32, fp32, 4, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(40960, fp16, fp16, fp16, fp32, 4, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(40960, fp16, fp32, fp16, fp32, 4, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(40960, bf16, fp32, bf16, fp32, 4, 1, 4, 16); + +// REGISTER_FWD_LAUNCHER(49152, fp32, fp32, fp32, fp32, 8, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(49152, fp16, fp16, fp16, fp32, 4, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(49152, fp16, fp32, fp16, fp32, 4, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(49152, bf16, bf16, bf16, fp32, 4, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(49152, bf16, fp32, bf16, fp32, 4, 1, 4, 16); + +// REGISTER_FWD_LAUNCHER(65536, fp32, fp32, fp32, fp32, 8, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(65536, fp16, fp16, fp16, fp32, 8, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(65536, fp16, fp32, fp16, fp32, 8, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(65536, bf16, fp32, bf16, fp32, 8, 1, 4, 16); \ No newline at end of file From bdd5cdbf7c2ebcdc34925305e068aee12083eecc Mon Sep 17 00:00:00 2001 From: aspanday Date: Wed, 3 May 2023 14:17:38 +0000 Subject: [PATCH 11/19] removing checks for is_cuda. This interferes with the unittest. May need to revert this change later if needed. --- apex/contrib/csrc/layer_norm/ln_api.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/apex/contrib/csrc/layer_norm/ln_api.cpp b/apex/contrib/csrc/layer_norm/ln_api.cpp index 30e4a5fec..90c2245b5 100644 --- a/apex/contrib/csrc/layer_norm/ln_api.cpp +++ b/apex/contrib/csrc/layer_norm/ln_api.cpp @@ -89,9 +89,9 @@ std::vector ln_fwd(const at::Tensor &x, // BxSxhidden_size TORCH_CHECK(beta.scalar_type() == wtype); - TORCH_CHECK(x.is_cuda()) - TORCH_CHECK(gamma.is_cuda()) - TORCH_CHECK(beta.is_cuda()) + //TORCH_CHECK(x.is_cuda()) + //TORCH_CHECK(gamma.is_cuda()) + //TORCH_CHECK(beta.is_cuda()) TORCH_CHECK(x.is_contiguous()); auto sizes = x.sizes(); @@ -170,11 +170,11 @@ std::vector ln_bwd(const at::Tensor &dz, // BxSxhidden_size TORCH_CHECK(mu.dtype() == ctype); TORCH_CHECK(rsigma.dtype() == ctype); - TORCH_CHECK(x.is_cuda()); - TORCH_CHECK(dz.is_cuda()); - TORCH_CHECK(mu.is_cuda()); - TORCH_CHECK(rsigma.is_cuda()); - TORCH_CHECK(gamma.is_cuda()); + //TORCH_CHECK(x.is_cuda()); + //TORCH_CHECK(dz.is_cuda()); + //TORCH_CHECK(mu.is_cuda()); + //TORCH_CHECK(rsigma.is_cuda()); + //TORCH_CHECK(gamma.is_cuda()); TORCH_CHECK(x.is_contiguous()); TORCH_CHECK(dz.is_contiguous()); From b78abf9496f7fbd7dfdb3ce27dc24889245b28bd Mon Sep 17 00:00:00 2001 From: aspanday Date: Wed, 3 May 2023 14:20:46 +0000 Subject: [PATCH 12/19] adding appropriate assert messages. This reduces #warnings as well as helps with decoding impact of kernel parameters. --- apex/contrib/csrc/layer_norm/ln_utils.cuh | 28 +++++++++++------------ 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/apex/contrib/csrc/layer_norm/ln_utils.cuh b/apex/contrib/csrc/layer_norm/ln_utils.cuh index e41fabd21..609f2693b 100644 --- a/apex/contrib/csrc/layer_norm/ln_utils.cuh +++ b/apex/contrib/csrc/layer_norm/ln_utils.cuh @@ -15,7 +15,7 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// #ifdef USE_ROCM -constexpr uint32_t THREADS_PER_WARP = warpSize; +constexpr uint32_t THREADS_PER_WARP = 64; #else constexpr uint32_t THREADS_PER_WARP = 32; #endif @@ -151,43 +151,43 @@ struct BytesToType {}; template<> struct BytesToType<64> { using Type = uint16; - static_assert(sizeof(Type) == 64); + static_assert(sizeof(Type) == 64, "is sizeof(uint16) = 64?"); }; template<> struct BytesToType<32> { using Type = uint8; - static_assert(sizeof(Type) == 32); + static_assert(sizeof(Type) == 32, "is sizeof(uint8) = 32?"); }; template<> struct BytesToType<16> { using Type = uint4; - static_assert(sizeof(Type) == 16); + static_assert(sizeof(Type) == 16, "is sizeof(uint4) = 16?"); }; template<> struct BytesToType<8> { using Type = uint64_t; - static_assert(sizeof(Type) == 8); + static_assert(sizeof(Type) == 8, "is sizeof(uint64_t) = 8?"); }; template<> struct BytesToType<4> { using Type = uint32_t; - static_assert(sizeof(Type) == 4); + static_assert(sizeof(Type) == 4, "is sizeof(uint32_t) = 4?"); }; template<> struct BytesToType<2> { using Type = uint16_t; - static_assert(sizeof(Type) == 2); + static_assert(sizeof(Type) == 2, "is sizeof(uint16_t) = 2?"); }; template<> struct BytesToType<1> { using Type = uint8_t; - static_assert(sizeof(Type) == 1); + static_assert(sizeof(Type) == 1, "is sizeof(uint8_t) = 8?"); }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -443,9 +443,9 @@ struct Reducer : public Reducer { } inter_cta_.sync(); #ifdef USE_ROCM - static_assert(CTAS_PER_ROW <= warpSize); + static_assert(CTAS_PER_ROW <= warpSize, "CTAS_PER_ROW <= warpSize."); #else - static_assert(CTAS_PER_ROW <= 32); + static_assert(CTAS_PER_ROW <= 32, "CTAS_PER_ROW <= 32."); #endif T total = Zeros::get(); if(this->lane_ < CTAS_PER_ROW){ @@ -663,9 +663,9 @@ struct Stats { // Assume CTA group size in N less than 32, such that we can finalize with a single warp. #ifdef USE_ROCM - static_assert(CTAS_PER_ROW <= warpSize); + static_assert(CTAS_PER_ROW <= warpSize, "CTAS_PER_ROW <= warpSize."); #else - static_assert(CTAS_PER_ROW <= 32); + static_assert(CTAS_PER_ROW <= 32, "CTAS_PER_ROW <= 32."); #endif // Every warp does the final reduction locally. @@ -732,9 +732,9 @@ struct Stats { // Assume that there are less than 32 warps, such that we can finalize with a single warp #ifdef USE_ROCM - static_assert(WARPS_N <= warpSize); + static_assert(WARPS_N <= warpSize, "CTAS_PER_ROW <= warpSize."); #else - static_assert(WARPS_N <= 32); + static_assert(WARPS_N <= 32, "CTAS_PER_ROW <= 32."); #endif if(lane < WARPS_N){ stats_t result = smem[lane]; From f47022daabe07aeded57d5a11f5817a21a13865a Mon Sep 17 00:00:00 2001 From: aspanday Date: Wed, 3 May 2023 14:21:37 +0000 Subject: [PATCH 13/19] adding appropriate assert messages. This reduces #warnings as well as helps with decoding impact of kernel parameters. --- apex/contrib/csrc/layer_norm/ln_kernel_traits.h | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/apex/contrib/csrc/layer_norm/ln_kernel_traits.h b/apex/contrib/csrc/layer_norm/ln_kernel_traits.h index 68e00e787..bdd8d32ca 100644 --- a/apex/contrib/csrc/layer_norm/ln_kernel_traits.h +++ b/apex/contrib/csrc/layer_norm/ln_kernel_traits.h @@ -51,7 +51,7 @@ template< > struct Kernel_traits_finalize : public Base { enum { ROWS_PER_CTA = Base::THREADS_PER_CTA / Base::THREADS_PER_WARP }; - static_assert((int) ROWS_PER_CTA <= (int) Base::THREADS_PER_WARP); + static_assert((int) ROWS_PER_CTA <= (int) Base::THREADS_PER_WARP, "ROWS_PER_CTA (WARP_M) needs to be less than warpsize."); // Bytes per global load from the input. enum { BYTES_PER_LDG = BYTES_PER_LDG_ }; // Number of elements fetched by a global load. @@ -62,7 +62,7 @@ struct Kernel_traits_finalize : public Base { static_assert(Base::THREADS_PER_CTA == ROWS_PER_CTA * Base::THREADS_PER_WARP, "We assume one warp per row!"); // The total number of BYTES_PER_LDG-wide words in a hidden vector. enum { COLS = HIDDEN_SIZE_ * sizeof(compute_t_) / BYTES_PER_LDG }; - static_assert(COLS * BYTES_PER_LDG == HIDDEN_SIZE_ * sizeof(compute_t_)); + static_assert(COLS * BYTES_PER_LDG == HIDDEN_SIZE_ * sizeof(compute_t_), "HIDDEN_SIZE needs to be a multiple of elements per ldg (BYTES_PER_LDG / sizeof(ctype)) in REGISTER_BWD_LAUNCHER."); // Shared memory size to transpose the CTA result. enum { SMEM_BYTES_TRANSPOSE = Base::THREADS_PER_CTA * BYTES_PER_LDG }; @@ -75,7 +75,7 @@ struct Kernel_traits_finalize : public Base { using Reducer = layer_norm::Reducer; // Condition for the whole CTA to participate in syncthreads. - static_assert(COLS % Base::THREADS_PER_WARP == 0); + static_assert(COLS % Base::THREADS_PER_WARP == 0, "HIDDEN_SIZE needs to be a multiple of BYTES_PER_LDG * warpSize in REGISTER_BWD_LAUNCHER"); enum { CTAS = COLS / Base::THREADS_PER_WARP }; }; @@ -127,7 +127,7 @@ struct Kernel_traits : public Base { enum { BYTES_PER_ROW_PER_CTA = THREADS_PER_ROW * BYTES_PER_LDG }; // Multi-row per CTA not supported for multi-CTA => no smem for WGRAD needed enum { SMEM_BYTES_WGRAD = CTAS_PER_ROW > 1 ? 0 : ROWS_PER_CTA * COLS * sizeof(compute_t) }; - static_assert(WARPS_M == 1 || CTAS_PER_ROW == 1); + static_assert(WARPS_M == 1 || CTAS_PER_ROW == 1, "Multiple rows per CTA not supported for Multi-CTA."); using reduce_t = typename layer_norm::TypeToVec2::Type; using Reducer = layer_norm::Reducer; @@ -139,18 +139,19 @@ struct Kernel_traits : public Base { using Ovec = layer_norm::Vec; using Wvec = layer_norm::Vec; using Cvec = layer_norm::Vec; + //using Tvec = layer_norm::Vec; enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(input_t) }; // Assume that each thread can handle the same number of elements in the output and weights as in the input. - static_assert(sizeof(input_t) >= sizeof(output_t)); - static_assert(sizeof(input_t) >= sizeof(weight_t)); + static_assert(sizeof(input_t) >= sizeof(output_t), "input_t size should be >= output_t size."); + static_assert(sizeof(input_t) >= sizeof(weight_t), "input_t size should be >= weight_t size."); // The number of columns fetched per load from input: one per thread. enum { VEC_COLS_PER_LDG = CTAS_PER_ROW * THREADS_PER_ROW }; // The total number of vectorized loads/stores per hidden vector. enum { VEC_COLS = COLS / ELTS_PER_LDG }; // The number of loads per thread for the input. enum { LDGS = VEC_COLS / VEC_COLS_PER_LDG }; - static_assert(LDGS * VEC_COLS_PER_LDG == VEC_COLS); + static_assert(LDGS * VEC_COLS_PER_LDG == VEC_COLS, "LDGS needs to be a whole number. Check ln_kernel_trails.h for LDGS definition and update kernel params in REGISTER_BWD_LAUNCHER."); //static_assert(LDGS * BYTES_PER_ROW_PER_CTA * CTAS_PER_ROW == BYTES_PER_ROW, ""); using Stats = layer_norm::Stats; From 12cb79f51ba4a34d5cf98ca1717cf75b9f905e2f Mon Sep 17 00:00:00 2001 From: aspanday Date: Wed, 3 May 2023 14:33:17 +0000 Subject: [PATCH 14/19] adding appropriate assert messages. This reduces #warnings as well as helps with decoding impact of kernel parameters. --- .../csrc/layer_norm/ln_bwd_kernels.cuh | 32 +++++++++++++++---- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh b/apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh index 00a5fa153..f7177323d 100644 --- a/apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh +++ b/apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh @@ -2,6 +2,8 @@ namespace layer_norm { +//__device__ uint64_t llvm_amdgcn_s_memrealtime() __asm("llvm.amdgcn.s.memrealtime"); + template __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_kernel(layer_norm::BwdParams params) { @@ -41,13 +43,20 @@ void ln_bwd_kernel(layer_norm::BwdParams params) { const index_t r = bidm * Ktraits::ROWS_PER_CTA + warp_m; const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane; - static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW); + static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW, "LDGS needs to be a whole number. See ln_kernel_traits.h and update kernel parameters in REGISTER_BWD_LAUNCHER."); Cvec dzy_sum[LDGS]; Cvec dz_sum[LDGS]; - + //uint64_t timer_start = 0; + //uint64_t timer_stop = 0; + //Tvec time = 0; + //timer_start = llvm_amdgcn_s_memrealtime(); memset(dzy_sum, 0, sizeof(dzy_sum)); memset(dz_sum, 0, sizeof(dz_sum)); + //timer_stop - llvm_amdgcn_s_memrealtime(); + + //time[0].data.elt[0] = int64(timer_start - timer_stop); + //time.store_to(params.time, idx); compute_t * smem_wgrad = reinterpret_cast(smem_); char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD; @@ -143,7 +152,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) { // Assumption: blockSize divides hidden size. enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA }; - static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, ""); + static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, "NUM_RES needs to be a whole number. See ln_kernel_traits.h and update kernel parameters in REGISTER_BWD_LAUNCHER."); idx = warp_m * Ktraits::VEC_COLS + tid_r; #pragma unroll @@ -254,7 +263,7 @@ void ln_bwd_finalize_kernel(BwdParams params) void * smem_beta_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT]; - // More than one iter iff ROWS_PER_CTA < 32. + // More than one iter iff ROWS_PER_CTA < warpSize (IS_ROCM ---> =64, else =32). for( int w = warp; w < THREADS_PER_WARP; w += Kernel_traits::ROWS_PER_CTA ) { const int read_row = lane; const int read_col = w ^ read_row; @@ -273,12 +282,21 @@ void ln_bwd_finalize_kernel(BwdParams params) #pragma unroll for( int it = 0; it < NUM_ELT; it++ ) { compute_t b_i = dbeta_local.data.elt[it]; - compute_t g_i = dgamma_local.data.elt[it]; b_i = reducer.allreduce(b_i, sum); - g_i = reducer.allreduce(g_i, sum); + dbeta_local.data.elt[it] = b_i; + compute_t g_i = dgamma_local.data.elt[it]; + g_i = reducer.allreduce(g_i, sum); dgamma_local.data.elt[it] = g_i; - dbeta_local.data.elt[it] = b_i; + + + // compute_t b_i = dbeta_local.data.elt[it]; + // compute_t g_i = dgamma_local.data.elt[it]; + // b_i = reducer.allreduce(b_i, sum); + // g_i = reducer.allreduce(g_i, sum); + + // dgamma_local.data.elt[it] = g_i; + // dbeta_local.data.elt[it] = b_i; } // Leader stores the result at the current column. From f97305788e22f058df422a6f29927d0c5d75889c Mon Sep 17 00:00:00 2001 From: aspanday Date: Wed, 3 May 2023 14:33:44 +0000 Subject: [PATCH 15/19] adding appropriate assert messages. This reduces #warnings as well as helps with decoding impact of kernel parameters. --- .../csrc/layer_norm/ln_bwd_kernels.cuh | 28 ++++--------------- 1 file changed, 5 insertions(+), 23 deletions(-) diff --git a/apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh b/apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh index f7177323d..9a35ce4e0 100644 --- a/apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh +++ b/apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh @@ -2,8 +2,6 @@ namespace layer_norm { -//__device__ uint64_t llvm_amdgcn_s_memrealtime() __asm("llvm.amdgcn.s.memrealtime"); - template __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_kernel(layer_norm::BwdParams params) { @@ -47,16 +45,9 @@ void ln_bwd_kernel(layer_norm::BwdParams params) { Cvec dzy_sum[LDGS]; Cvec dz_sum[LDGS]; - //uint64_t timer_start = 0; - //uint64_t timer_stop = 0; - //Tvec time = 0; - //timer_start = llvm_amdgcn_s_memrealtime(); + memset(dzy_sum, 0, sizeof(dzy_sum)); memset(dz_sum, 0, sizeof(dz_sum)); - //timer_stop - llvm_amdgcn_s_memrealtime(); - - //time[0].data.elt[0] = int64(timer_start - timer_stop); - //time.store_to(params.time, idx); compute_t * smem_wgrad = reinterpret_cast(smem_); char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD; @@ -263,7 +254,7 @@ void ln_bwd_finalize_kernel(BwdParams params) void * smem_beta_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT]; - // More than one iter iff ROWS_PER_CTA < warpSize (IS_ROCM ---> =64, else =32). + // More than one iter iff ROWS_PER_CTA < 32. for( int w = warp; w < THREADS_PER_WARP; w += Kernel_traits::ROWS_PER_CTA ) { const int read_row = lane; const int read_col = w ^ read_row; @@ -282,21 +273,12 @@ void ln_bwd_finalize_kernel(BwdParams params) #pragma unroll for( int it = 0; it < NUM_ELT; it++ ) { compute_t b_i = dbeta_local.data.elt[it]; - b_i = reducer.allreduce(b_i, sum); - dbeta_local.data.elt[it] = b_i; - compute_t g_i = dgamma_local.data.elt[it]; + b_i = reducer.allreduce(b_i, sum); g_i = reducer.allreduce(g_i, sum); - dgamma_local.data.elt[it] = g_i; - - - // compute_t b_i = dbeta_local.data.elt[it]; - // compute_t g_i = dgamma_local.data.elt[it]; - // b_i = reducer.allreduce(b_i, sum); - // g_i = reducer.allreduce(g_i, sum); - // dgamma_local.data.elt[it] = g_i; - // dbeta_local.data.elt[it] = b_i; + dgamma_local.data.elt[it] = g_i; + dbeta_local.data.elt[it] = b_i; } // Leader stores the result at the current column. From 37feda6581a3232b90a559b45e7dc0fd2e37dc48 Mon Sep 17 00:00:00 2001 From: aspanday Date: Wed, 3 May 2023 14:36:51 +0000 Subject: [PATCH 16/19] adding hipFuncSetAttrbute when IS_ROCM is True --- apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu b/apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu index 92ef5f823..6efb8e17d 100644 --- a/apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu +++ b/apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu @@ -53,9 +53,9 @@ void launch_(LaunchParams &launch_params, const bool configure_params if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) { // hipify missing cudaFuncSetAttribute, cudaFuncAttributeMaxDynamicSharedMemorySize #ifdef USE_ROCM - CHECK_CUDA(hipFuncSetAttribute((const void *)kernel, hipFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES)); + CHECK_CUDA(hipFuncSetAttribute((const void *)kernel, hipFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD)); #else - CHECK_CUDA(cudaFuncSetAttribute((const void *)kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES)); + CHECK_CUDA(cudaFuncSetAttribute((const void *)kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD)); #endif } auto stream = launch_params.stream; @@ -64,7 +64,7 @@ void launch_(LaunchParams &launch_params, const bool configure_params if( Kernel_traits::CTAS_PER_ROW == 1 ) { kernel<<>>(launch_params.params); } else { - dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col); + dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col); dim3 block(Kernel_traits::THREADS_PER_CTA); void *params_ = (void *)&launch_params.params; #ifdef USE_ROCM @@ -82,6 +82,12 @@ constexpr bool is_rocm = true; constexpr bool is_rocm = false; #endif +// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG +REGISTER_FWD_LAUNCHER( 512, fp32, fp32, fp32, fp32, 1, 4, 2, 4); +REGISTER_FWD_LAUNCHER( 512, fp16, fp16, fp16, fp32, 1, 4, 2, is_rocm ? 4 : 4); +REGISTER_FWD_LAUNCHER( 512, fp16, fp32, fp16, fp32, 1, 4, 2, 4); +REGISTER_FWD_LAUNCHER( 512, bf16, bf16, bf16, fp32, 1, 4, 2, is_rocm ? 4 : 4); +REGISTER_FWD_LAUNCHER( 512, bf16, fp32, bf16, fp32, 1, 4, 2, 4); REGISTER_FWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 768, fp16, fp16, fp16, fp32, 1, 4, 1, is_rocm ? 8 : 16); From 57770358879a7b4cefde5d4c6f00a2557694398f Mon Sep 17 00:00:00 2001 From: aspanday Date: Wed, 3 May 2023 14:37:56 +0000 Subject: [PATCH 17/19] Updating REGISTER_BWD_KERNEL macros for various hidden_sizes to get reported performances. --- .../layer_norm/ln_bwd_semi_cuda_kernel.cu | 227 ++++++------------ 1 file changed, 68 insertions(+), 159 deletions(-) diff --git a/apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu b/apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu index 57356ba5b..abd78a57f 100644 --- a/apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu +++ b/apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu @@ -35,8 +35,13 @@ void launch_(LaunchParams &launch_params, const bool configure_params if( configure_params ) { int ctas_per_sm; +#ifdef USE_ROCM + hipError_t status_ = hipOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES); +#else cudaError_t status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES); +#endif launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW; launch_params.barrier_size = 0; launch_params.workspace_bytes = 0; @@ -97,29 +102,29 @@ constexpr bool is_rocm = false; // Create backward launch function and register. Macro signature: // HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL -REGISTER_BWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 768, fp16, fp16, fp16, fp32, 1, 4, 1, is_rocm ? 8 : 16, 4); -REGISTER_BWD_LAUNCHER( 768, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 768, bf16, bf16, bf16, fp32, 1, 4, 1, is_rocm ? 8 : 16, 4); -REGISTER_BWD_LAUNCHER( 768, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); - -REGISTER_BWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 1024, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 1024, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); - -REGISTER_BWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, 1, 1, 4, is_rocm ? 8 : 16, 4); -REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp16, fp32, 1, 1, 4, is_rocm ? 4 : 8, 4); -REGISTER_BWD_LAUNCHER( 1536, fp16, fp32, fp16, fp32, 1, 1, 4, is_rocm ? 8 : 16, 4); -REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, bf16, fp32, 1, 1, 4, is_rocm ? 4 : 8, 4); -REGISTER_BWD_LAUNCHER( 1536, bf16, fp32, bf16, fp32, 1, 1, 4, is_rocm ? 8 : 16, 4); - -REGISTER_BWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 2048, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 2048, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 2048, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 2048, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, 1, 4, 3, 8, 8); +REGISTER_BWD_LAUNCHER( 768, fp16, fp16, fp16, fp32, 1, 4, 3, is_rocm ? 8 : 16, 4); +REGISTER_BWD_LAUNCHER( 768, fp16, fp32, fp16, fp32, 1, 4, 3, 8, 8); +REGISTER_BWD_LAUNCHER( 768, bf16, bf16, bf16, fp32, 1, 4, 3, is_rocm ? 8 : 16, 4); +REGISTER_BWD_LAUNCHER( 768, bf16, fp32, bf16, fp32, 1, 4, 3, 8, 8); + +REGISTER_BWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, 1, 4, 4, 8, 4); +REGISTER_BWD_LAUNCHER( 1024, fp16, fp16, fp16, fp32, 1, 4, 4, 4, 4); +REGISTER_BWD_LAUNCHER( 1024, fp16, fp32, fp16, fp32, 1, 4, 4, 8, 4); +REGISTER_BWD_LAUNCHER( 1024, bf16, bf16, bf16, fp32, 1, 4, 4, 4, 4); +REGISTER_BWD_LAUNCHER( 1024, bf16, fp32, bf16, fp32, 1, 4, 4, 8, 4); + +REGISTER_BWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, 1, 4, 3, is_rocm ? 8 : 16, 4); +REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp16, fp32, 1, 4, 3, is_rocm ? 4 : 8, 4); +REGISTER_BWD_LAUNCHER( 1536, fp16, fp32, fp16, fp32, 1, 4, 3, is_rocm ? 8 : 16, 4); +REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, bf16, fp32, 1, 4, 3, is_rocm ? 4 : 8, 4); +REGISTER_BWD_LAUNCHER( 1536, bf16, fp32, bf16, fp32, 1, 4, 3, is_rocm ? 8 : 16, 4); + +REGISTER_BWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, 1, 4, 4, 8, 4); +REGISTER_BWD_LAUNCHER( 2048, fp16, fp16, fp16, fp32, 1, 4, 4, 4, 4); +REGISTER_BWD_LAUNCHER( 2048, fp16, fp32, fp16, fp32, 1, 4, 4, 8, 4); +REGISTER_BWD_LAUNCHER( 2048, bf16, bf16, bf16, fp32, 1, 4, 4, 4, 4); +REGISTER_BWD_LAUNCHER( 2048, bf16, fp32, bf16, fp32, 1, 4, 4, 8, 4); REGISTER_BWD_LAUNCHER( 2304, fp32, fp32, fp32, fp32, 1, 1, 4, is_rocm ? 4 : 8, 4); REGISTER_BWD_LAUNCHER( 2304, fp16, fp16, fp16, fp32, 1, 1, 4, is_rocm ? 2 : 4, 4); @@ -127,11 +132,11 @@ REGISTER_BWD_LAUNCHER( 2304, fp16, fp32, fp16, fp32, 1, 1, 4, is_rocm ? 4 : 8, 4 REGISTER_BWD_LAUNCHER( 2304, bf16, bf16, bf16, fp32, 1, 1, 4, is_rocm ? 2 : 4, 4); REGISTER_BWD_LAUNCHER( 2304, bf16, fp32, bf16, fp32, 1, 1, 4, is_rocm ? 4 : 8, 4); -REGISTER_BWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 3072, fp16, fp16, fp16, fp32, 1, 1, 4, is_rocm ? 8 : 16, 4); -REGISTER_BWD_LAUNCHER( 3072, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 3072, bf16, bf16, bf16, fp32, 1, 1, 4, is_rocm ? 8 : 16, 4); -REGISTER_BWD_LAUNCHER( 3072, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4); +REGISTER_BWD_LAUNCHER( 3072, fp16, fp16, fp16, fp32, 1, 1, 4, is_rocm ? 4 : 16, 4); +REGISTER_BWD_LAUNCHER( 3072, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); +REGISTER_BWD_LAUNCHER( 3072, bf16, bf16, bf16, fp32, 1, 1, 4, is_rocm ? 4 : 16, 4); +REGISTER_BWD_LAUNCHER( 3072, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); REGISTER_BWD_LAUNCHER( 3840, fp32, fp32, fp32, fp32, 1, 1, 4, is_rocm ? 4 : 8, 4); REGISTER_BWD_LAUNCHER( 3840, fp16, fp16, fp16, fp32, 1, 1, 4, is_rocm ? 2 : 4, 4); @@ -139,11 +144,11 @@ REGISTER_BWD_LAUNCHER( 3840, fp16, fp32, fp16, fp32, 1, 1, 4, is_rocm ? 4 : 8, 4 REGISTER_BWD_LAUNCHER( 3840, bf16, bf16, bf16, fp32, 1, 1, 4, is_rocm ? 2 : 4, 4); REGISTER_BWD_LAUNCHER( 3840, bf16, fp32, bf16, fp32, 1, 1, 4, is_rocm ? 4 : 8, 4); -REGISTER_BWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 4096, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 4096, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, 1, 1, 8, 8, 4); +REGISTER_BWD_LAUNCHER( 4096, fp16, fp16, fp16, fp32, 1, 1, 8, 4, 4); +REGISTER_BWD_LAUNCHER( 4096, fp16, fp32, fp16, fp32, 1, 1, 8, 8, 4); +REGISTER_BWD_LAUNCHER( 4096, bf16, bf16, bf16, fp32, 1, 1, 8, 4, 4); +REGISTER_BWD_LAUNCHER( 4096, bf16, fp32, bf16, fp32, 1, 1, 8, 8, 4); REGISTER_BWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER( 5120, fp16, fp16, fp16, fp32, 1, 1, 4, is_rocm ? 8 : 16, 4); @@ -151,17 +156,17 @@ REGISTER_BWD_LAUNCHER( 5120, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER( 5120, bf16, bf16, bf16, fp32, 1, 1, 4, is_rocm ? 8 : 16, 4); REGISTER_BWD_LAUNCHER( 5120, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER( 6144, fp16, fp16, fp16, fp32, 1, 1, 8, is_rocm ? 8 : 16, 4); -REGISTER_BWD_LAUNCHER( 6144, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER( 6144, bf16, bf16, bf16, fp32, 1, 1, 8, is_rocm ? 8 : 16, 4); -REGISTER_BWD_LAUNCHER( 6144, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); +REGISTER_BWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, 1, 1, 8, 8, 4); +REGISTER_BWD_LAUNCHER( 6144, fp16, fp16, fp16, fp32, 1, 1, 8, is_rocm ? 4 : 16, 4); +REGISTER_BWD_LAUNCHER( 6144, fp16, fp32, fp16, fp32, 1, 1, 8, 8, 4); +REGISTER_BWD_LAUNCHER( 6144, bf16, bf16, bf16, fp32, 1, 1, 8, is_rocm ? 4 : 16, 4); +REGISTER_BWD_LAUNCHER( 6144, bf16, fp32, bf16, fp32, 1, 1, 8, 8, 4); -REGISTER_BWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 8192, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 8192, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, 1, 1, 8, 8, 4); +REGISTER_BWD_LAUNCHER( 8192, fp16, fp16, fp16, fp32, 1, 1, 8, 4, 4); +REGISTER_BWD_LAUNCHER( 8192, fp16, fp32, fp16, fp32, 1, 1, 8, 8, 4); +REGISTER_BWD_LAUNCHER( 8192, bf16, bf16, bf16, fp32, 1, 1, 8, 4, 4); +REGISTER_BWD_LAUNCHER( 8192, bf16, fp32, bf16, fp32, 1, 1, 8, 8, 4); REGISTER_BWD_LAUNCHER(10240, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(10240, fp16, fp16, fp16, fp32, 1, 1, 4, is_rocm ? 8 : 16, 4); @@ -169,11 +174,11 @@ REGISTER_BWD_LAUNCHER(10240, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(10240, bf16, bf16, bf16, fp32, 1, 1, 4, is_rocm ? 8 : 16, 4); REGISTER_BWD_LAUNCHER(10240, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(12288, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(12288, fp16, fp16, fp16, fp32, 1, 1, 4, is_rocm ? 8 : 16, 4); -REGISTER_BWD_LAUNCHER(12288, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 1, 1, 4, is_rocm ? 8 : 16, 4); -REGISTER_BWD_LAUNCHER(12288, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(12288, fp32, fp32, fp32, fp32, 1, 1, 16, is_rocm ? 8 : 16, 4); +REGISTER_BWD_LAUNCHER(12288, fp16, fp16, fp16, fp32, 1, 1, 16, is_rocm ? 4 : 16, 4); +REGISTER_BWD_LAUNCHER(12288, fp16, fp32, fp16, fp32, 1, 1, 16, is_rocm ? 8 : 16, 4); +REGISTER_BWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 1, 1, 16, is_rocm ? 4 : 16, 4); +REGISTER_BWD_LAUNCHER(12288, bf16, fp32, bf16, fp32, 1, 1, 16, 8, 4); REGISTER_BWD_LAUNCHER(12800, fp32, fp32, fp32, fp32, 1, 1, 4, is_rocm ? 8 : 16, 4); REGISTER_BWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 1, 1, 4, is_rocm ? 4 : 8, 4); @@ -193,11 +198,11 @@ REGISTER_BWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 1, 1, 4, is_rocm ? 4 : 8, REGISTER_BWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 1, 1, 4, is_rocm ? 2 : 4, 4); REGISTER_BWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 1, 1, 4, is_rocm ? 4 : 8, 4); -REGISTER_BWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(16384, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 1, 1, 16, 8, 4); +REGISTER_BWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 1, 1, 16, 4, 4); +REGISTER_BWD_LAUNCHER(16384, fp16, fp32, fp16, fp32, 1, 1, 16, 8, 4); +REGISTER_BWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 1, 1, 16, 4, 4); +REGISTER_BWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 1, 1, 16, 8, 4); REGISTER_BWD_LAUNCHER(18432, fp32, fp32, fp32, fp32, 1, 1, 4, is_rocm ? 8 : 16, 4); REGISTER_BWD_LAUNCHER(18432, fp16, fp16, fp16, fp32, 1, 1, 4, is_rocm ? 4 : 8, 4); @@ -211,11 +216,11 @@ REGISTER_BWD_LAUNCHER(20480, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(20480, bf16, bf16, bf16, fp32, 1, 1, 4, is_rocm ? 8 : 16, 4); REGISTER_BWD_LAUNCHER(20480, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(24576, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(24576, fp16, fp16, fp16, fp32, 1, 1, 8, is_rocm ? 8 : 16, 4); -REGISTER_BWD_LAUNCHER(24576, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(24576, bf16, bf16, bf16, fp32, 1, 1, 8, is_rocm ? 8 : 16, 4); -REGISTER_BWD_LAUNCHER(24576, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); +REGISTER_BWD_LAUNCHER(24576, fp32, fp32, fp32, fp32, 1, 1, 8, is_rocm ? 8 : 16, 4); +REGISTER_BWD_LAUNCHER(24576, fp16, fp16, fp16, fp32, 1, 1, 8, is_rocm ? 4 : 16, 4); +REGISTER_BWD_LAUNCHER(24576, fp16, fp32, fp16, fp32, 1, 1, 8, is_rocm ? 8 : 16, 4); +REGISTER_BWD_LAUNCHER(24576, bf16, bf16, bf16, fp32, 1, 1, 8, is_rocm ? 4 : 16, 4); +REGISTER_BWD_LAUNCHER(24576, bf16, fp32, bf16, fp32, 1, 1, 8, 8, 4); REGISTER_BWD_LAUNCHER(25600, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(25600, fp16, fp16, fp16, fp32, 1, 1, 4, is_rocm ? 8 : 16, 4); @@ -247,104 +252,8 @@ REGISTER_BWD_LAUNCHER(49152, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(49152, bf16, bf16, bf16, fp32, 1, 1, 8, is_rocm ? 8 : 16, 4); REGISTER_BWD_LAUNCHER(49152, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(65536, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(65536, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(65536, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(65536, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(65536, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); - -// REGISTER_BWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4); -// REGISTER_BWD_LAUNCHER( 8192, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4); -// REGISTER_BWD_LAUNCHER( 8192, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4); -// REGISTER_BWD_LAUNCHER( 8192, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4); -// REGISTER_BWD_LAUNCHER( 8192, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4); - -// REGISTER_BWD_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4); -// REGISTER_BWD_LAUNCHER(10240, fp16, fp16, fp16, fp32, 2, 1, 4, is_rocm ? 8 : 16, 4); -// REGISTER_BWD_LAUNCHER(10240, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4); -// REGISTER_BWD_LAUNCHER(10240, bf16, bf16, bf16, fp32, 2, 1, 4, is_rocm ? 8 : 16, 4); -// REGISTER_BWD_LAUNCHER(10240, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4); - -// REGISTER_BWD_LAUNCHER(12288, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); -// REGISTER_BWD_LAUNCHER(12288, fp16, fp16, fp16, fp32, 4, 1, 4, is_rocm ? 8 : 16, 4); -// REGISTER_BWD_LAUNCHER(12288, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); -// REGISTER_BWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 4, 1, 4, is_rocm ? 8 : 16, 4); -// REGISTER_BWD_LAUNCHER(12288, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); - -// REGISTER_BWD_LAUNCHER(12800, fp32, fp32, fp32, fp32, 5, 1, 4, is_rocm ? 8 : 16, 4); -// REGISTER_BWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 5, 1, 4, is_rocm ? 4 : 8, 4); -// REGISTER_BWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 5, 1, 4, is_rocm ? 8 : 16, 4); -// REGISTER_BWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 5, 1, 4, is_rocm ? 4 : 8, 4); -// REGISTER_BWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 5, 1, 4, is_rocm ? 8 : 16, 4); - -// REGISTER_BWD_LAUNCHER(14336, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4); -// REGISTER_BWD_LAUNCHER(14336, fp16, fp16, fp16, fp32, 4, 1, 4, is_rocm ? 4 : 8, 4); -// REGISTER_BWD_LAUNCHER(14336, fp16, fp32, fp16, fp32, 4, 1, 4, 8, 4); -// REGISTER_BWD_LAUNCHER(14336, bf16, bf16, bf16, fp32, 4, 1, 4, is_rocm ? 4 : 8, 4); -// REGISTER_BWD_LAUNCHER(14336, bf16, fp32, bf16, fp32, 4, 1, 4, 8, 4); - -// REGISTER_BWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 4, 1, 4, is_rocm ? 4 : 8, 4); -// REGISTER_BWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 4, 1, 4, is_rocm ? 2 : 4, 4); -// REGISTER_BWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 4, 1, 4, is_rocm ? 4 : 8, 4); -// REGISTER_BWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 4, 1, 4, is_rocm ? 2 : 4, 4); -// REGISTER_BWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 4, 1, 4, is_rocm ? 4 : 8, 4); - -// REGISTER_BWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); -// REGISTER_BWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4); -// REGISTER_BWD_LAUNCHER(16384, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); -// REGISTER_BWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4); -// REGISTER_BWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); - -// REGISTER_BWD_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, is_rocm ? 8 : 16, 4); -// REGISTER_BWD_LAUNCHER(18432, fp16, fp16, fp16, fp32, 4, 1, 4, is_rocm ? 4 : 8, 4); -// REGISTER_BWD_LAUNCHER(18432, fp16, fp32, fp16, fp32, 4, 1, 4, is_rocm ? 8 : 16, 4); -// REGISTER_BWD_LAUNCHER(18432, bf16, bf16, bf16, fp32, 4, 1, 4, is_rocm ? 4 : 8, 4); -// REGISTER_BWD_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, is_rocm ? 8 : 16, 4); - -// REGISTER_BWD_LAUNCHER(20480, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); -// REGISTER_BWD_LAUNCHER(20480, fp16, fp16, fp16, fp32, 4, 1, 4, is_rocm ? 8 : 16, 4); -// REGISTER_BWD_LAUNCHER(20480, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); -// REGISTER_BWD_LAUNCHER(20480, bf16, bf16, bf16, fp32, 4, 1, 4, is_rocm ? 8 : 16, 4); -// REGISTER_BWD_LAUNCHER(20480, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); - -// REGISTER_BWD_LAUNCHER(24576, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4); -// REGISTER_BWD_LAUNCHER(24576, fp16, fp16, fp16, fp32, 4, 1, 8, is_rocm ? 8 : 16, 4); -// REGISTER_BWD_LAUNCHER(24576, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4); -// REGISTER_BWD_LAUNCHER(24576, bf16, bf16, bf16, fp32, 4, 1, 8, is_rocm ? 8 : 16, 4); -// REGISTER_BWD_LAUNCHER(24576, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4); - -// REGISTER_BWD_LAUNCHER(25600, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4); -// REGISTER_BWD_LAUNCHER(25600, fp16, fp16, fp16, fp32, 5, 1, 4, is_rocm ? 8 : 16, 4); -// REGISTER_BWD_LAUNCHER(25600, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4); -// REGISTER_BWD_LAUNCHER(25600, bf16, bf16, bf16, fp32, 5, 1, 4, is_rocm ? 8 : 16, 4); -// REGISTER_BWD_LAUNCHER(25600, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4); - -// REGISTER_BWD_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 8, is_rocm ? 4 : 8, 4); -// REGISTER_BWD_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 8, is_rocm ? 2 : 4, 4); -// REGISTER_BWD_LAUNCHER(30720, fp16, fp32, fp16, fp32, 4, 1, 8, is_rocm ? 4 : 8, 4); -// REGISTER_BWD_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 8, is_rocm ? 2 : 4, 4); -// REGISTER_BWD_LAUNCHER(30720, bf16, fp32, bf16, fp32, 4, 1, 8, is_rocm ? 4 : 8, 4); - -// REGISTER_BWD_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4); -// REGISTER_BWD_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4); -// REGISTER_BWD_LAUNCHER(32768, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4); -// REGISTER_BWD_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4); -// REGISTER_BWD_LAUNCHER(32768, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4); - -// REGISTER_BWD_LAUNCHER(40960, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4); -// REGISTER_BWD_LAUNCHER(40960, fp16, fp16, fp16, fp32, 4, 1, 8, is_rocm ? 8 : 16, 4); -// REGISTER_BWD_LAUNCHER(40960, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4); -// REGISTER_BWD_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 8, is_rocm ? 8 : 16, 4); -// REGISTER_BWD_LAUNCHER(40960, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4); - -// REGISTER_BWD_LAUNCHER(49152, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4); -// REGISTER_BWD_LAUNCHER(49152, fp16, fp16, fp16, fp32, 8, 1, 8, is_rocm ? 8 : 16, 4); -// REGISTER_BWD_LAUNCHER(49152, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4); -// REGISTER_BWD_LAUNCHER(49152, bf16, bf16, bf16, fp32, 8, 1, 8, is_rocm ? 8 : 16, 4); -// REGISTER_BWD_LAUNCHER(49152, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4); - -// REGISTER_BWD_LAUNCHER(65536, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4); -// REGISTER_BWD_LAUNCHER(65536, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4); -// REGISTER_BWD_LAUNCHER(65536, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4); -// REGISTER_BWD_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4); -// REGISTER_BWD_LAUNCHER(65536, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4); \ No newline at end of file +REGISTER_BWD_LAUNCHER(65536, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(65536, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(65536, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(65536, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER(65536, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); \ No newline at end of file From 6d58ce1f3f71dc47390b10d3ca4eeb7c034cd8df Mon Sep 17 00:00:00 2001 From: aspanday Date: Wed, 3 May 2023 14:46:27 +0000 Subject: [PATCH 18/19] Added option to pass normalized_shape for fast_layer_norm same as fused_layer_norm. Updated the framework for fast_layer_norm to be same as fused_layer_norm. Updated variables gamma and beta to unique variables gamma_ and beta_ to address some runtime errors. Added patch to call fused_layer_norm bwd when hidden_size > 12K since the tests fail for these hidden_sizes. This is currently only a patch and needs to be debugged. Potential for even better performance for fast_layer_norm. NOTE that fast_layer_norm fwd is still called for hidden_sizes > 12K. --- apex/contrib/layer_norm/layer_norm.py | 58 +++++++++++++++++++-------- 1 file changed, 41 insertions(+), 17 deletions(-) diff --git a/apex/contrib/layer_norm/layer_norm.py b/apex/contrib/layer_norm/layer_norm.py index 8a8d26d43..8af044579 100644 --- a/apex/contrib/layer_norm/layer_norm.py +++ b/apex/contrib/layer_norm/layer_norm.py @@ -1,48 +1,69 @@ +import importlib +import numbers + import torch from torch.nn import init +from torch.nn import functional as F from apex._autocast_utils import _cast_if_autocast_enabled import fast_layer_norm +import fused_layer_norm_cuda + +#global fused_layer_norm_cuda +#fused_layer_norm_cuda = None class FastLayerNormFN(torch.autograd.Function): @staticmethod - def forward(ctx, x, gamma, beta, epsilon): + def forward(ctx, x, gamma, beta, normalized_shape, epsilon): x = x.contiguous() - gamma = gamma.contiguous() - beta = beta.contiguous() + gamma_ = gamma.contiguous() + beta_ = beta.contiguous() hidden_size = gamma.numel() xmat = x.view((-1, hidden_size)) ymat, mu, rsigma = fast_layer_norm.ln_fwd(xmat, gamma, beta, epsilon) - ctx.save_for_backward(x, gamma, mu, rsigma) + ctx.normalized_shape = normalized_shape + ctx.eps = epsilon + ctx.save_for_backward(x, gamma_, beta_, mu, rsigma) return ymat.view(x.shape) @staticmethod def backward(ctx, dy): # assert dy.is_contiguous() - dy = dy.contiguous() # this happens! - x, gamma, mu, rsigma = ctx.saved_tensors - + dy_ = dy.contiguous() # this happens! + x, gamma, beta, mu, rsigma= ctx.saved_tensors + #x, gamma, mu, rsigma= ctx.saved_tensors hidden_size = gamma.numel() xmat = x.view((-1, hidden_size)) - dymat = dy.view(xmat.shape) - dxmat, dgamma, dbeta, _, _ = fast_layer_norm.ln_bwd(dymat, xmat, mu, rsigma, gamma) - dx = dxmat.view(x.shape) - return dx, dgamma, dbeta, None + dymat = dy_.view(xmat.shape) + dxmat = dgamma = dbeta = None + if hidden_size > 12288: + #dxmat = fused_layer_norm_cuda.backward(dy_, mu, rsigma, xmat, ctx.normalized_shape, ctx.eps) + dxmat, dgamma, dbeta = fused_layer_norm_cuda.backward_affine( + dymat, mu, rsigma, xmat, ctx.normalized_shape, gamma, beta, ctx.eps + ) + dx = dxmat.view(x.shape) + else: + dxmat, dgamma, dbeta, _, _ = fast_layer_norm.ln_bwd(dymat, xmat, mu, rsigma, gamma) + dx = dxmat.view(x.shape) + return dx, dgamma, dbeta, None, None -def _fast_layer_norm(x, weight, bias, epsilon): - args = _cast_if_autocast_enabled(x, weight, bias, epsilon) +def _fast_layer_norm(x, weight, bias, normalized_shape, epsilon): + args = _cast_if_autocast_enabled(x, weight, bias, normalized_shape, epsilon) with torch.cuda.amp.autocast(enabled=False): return FastLayerNormFN.apply(*args) class FastLayerNorm(torch.nn.Module): - def __init__(self, hidden_size, eps=1e-5): + def __init__(self, normalized_shape, eps=1e-5): super().__init__() self.epsilon = eps - self.weight = torch.nn.Parameter(torch.Tensor(hidden_size)) - self.bias = torch.nn.Parameter(torch.Tensor(hidden_size)) + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + self.normalized_shape = torch.Size(normalized_shape) + self.weight = torch.nn.Parameter(torch.Tensor(*normalized_shape)) + self.bias = torch.nn.Parameter(torch.Tensor(*normalized_shape)) self.reset_parameters() def reset_parameters(self): @@ -50,4 +71,7 @@ def reset_parameters(self): init.zeros_(self.bias) def forward(self, x): - return _fast_layer_norm(x, self.weight, self.bias, self.epsilon) + if not x.is_cuda: + return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.epsilon) + else: + return _fast_layer_norm(x, self.weight, self.bias, self.normalized_shape, self.epsilon) From ae9de3ebd01ef93e1d2a2af46be974a3e6c3c6b7 Mon Sep 17 00:00:00 2001 From: aspanday Date: Wed, 3 May 2023 14:48:48 +0000 Subject: [PATCH 19/19] Updated fast_layer_norm test to call fused_layer_norm test when hidden_sizes > 12K since fast_layer_norm patch calls fused_layer_norm for bwd when hideen_Sizes > 12K --- .../test/layer_norm/test_fast_layer_norm.py | 101 +++++++++++++++++- 1 file changed, 96 insertions(+), 5 deletions(-) diff --git a/apex/contrib/test/layer_norm/test_fast_layer_norm.py b/apex/contrib/test/layer_norm/test_fast_layer_norm.py index 089ec35d2..11a309bf9 100644 --- a/apex/contrib/test/layer_norm/test_fast_layer_norm.py +++ b/apex/contrib/test/layer_norm/test_fast_layer_norm.py @@ -6,6 +6,7 @@ import torch import fast_layer_norm as fln +import apex from apex.contrib.layer_norm.layer_norm import FastLayerNorm @@ -191,6 +192,89 @@ def check_err(x, relerr): for x, re in zip([z, mu, rs, dx, dg, db], [re_z, re_mu, re_rs, re_dx, re_dg, re_db]) ] +def test_withfused(S, B, hidden_size, itype=fp32, wtype=fp32, ctype=fp32): + + seed = 1243 + time = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + fwd_thresholds_fp16 = dict(rtol=1e-3, atol=1e-3) + bwd_thresholds_fp16 = dict(rtol=1e-3, atol=1e-3) + fwd_thresholds_fp32 = dict(rtol=1e-6, atol=1e-6) + bwd_thresholds_fp32 = dict(rtol=1e-6, atol=1e-6) + fwd_thresholds_bf16 = dict(rtol=1e-4, atol=1e-4) + bwd_thresholds_bf16 = dict(rtol=1e-4, atol=1e-4) + + otype = wtype + print("========================================================") + print(f"S={S} B={B} Hidden={hidden_size} {itype} {wtype}") + print("--------------------------------------------------------") + + x = torch.randn(S * B, hidden_size, dtype=itype, device=device) + gamma = torch.randn(hidden_size, dtype=wtype, device=device) * 0.2 + beta = torch.randn(hidden_size, dtype=wtype, device=device) * 0.2 + epsilon = 1e-5 + + x.requires_grad = True + gamma.requires_grad = True + beta.requires_grad = True + + mu_ref = x.mean(1, dtype=ctype, keepdim=True) + v = torch.square(x - mu_ref).mean(1, dtype=ctype, keepdim=True) + rs_ref = torch.rsqrt(v + epsilon) + y_ref = rs_ref * (x.to(ctype) - mu_ref) + z_ref = (gamma.unsqueeze(0) * (y_ref).to(otype) + beta.unsqueeze(0)).to(otype) + + mu_ref = mu_ref.flatten() + rs_ref = rs_ref.flatten() + + dz = torch.randn_like(z_ref) + + # z_ref.backward(dz) + # dx_ref = x.grad + # dgamma_ref = gamma.grad + # dbeta_ref = beta.grad + dx_ref, dg_ref, db_ref = backward_(dz, x, mu_ref, rs_ref, gamma) + + torch.cuda.manual_seed(seed) + x = torch.randn(S * B, hidden_size, dtype=itype, device=device) + x_cpu = x.detach().requires_grad_(True) + gamma_cpu = torch.randn(hidden_size, dtype=wtype, device=device) * 0.2 + beta_cpu = torch.randn(hidden_size, dtype=wtype, device=device) * 0.2 + x_cuda = x.to(device="cuda", dtype=itype).detach().requires_grad_(True) + gamma_cuda = gamma_cpu.to(device="cuda", dtype=wtype).detach().requires_grad_(True) + beta_cuda = beta_cpu.to(device="cuda", dtype=wtype).detach().requires_grad_(True) + x_cpu_ = x_cpu.contiguous() + x_cuda_ = x_cuda.contiguous() + + module_cpu_ = apex.normalization.MixedFusedLayerNorm(normalized_shape=[hidden_size]).to(device="cuda", dtype=itype) + module_cuda_ = FastLayerNorm(normalized_shape=[hidden_size]).to(device="cuda", dtype=itype) + + out_cpu_ = module_cpu_(x_cpu_) + gO = torch.rand_like(out_cpu_) + out_cpu_.backward(gO) + + # x_ = x_.to(device="cuda", dtype=itype) + out_cuda_ = module_cuda_(x_cuda_) + gO = gO.to(device="cuda", dtype=itype) + out_cuda_.backward(gO) + + if itype == fp16: + print(itype, "out test : ", + torch.testing.assert_close(out_cpu_.to(device="cuda", dtype=itype), out_cuda_, **fwd_thresholds_fp16)) + print(itype, "grad test : ", + torch.testing.assert_close(x_cpu_.grad.to(device="cuda", dtype=itype), x_cuda_.grad, **bwd_thresholds_fp16)) + elif itype == bf16: + print(itype, "out test : ", + torch.testing.assert_close(out_cpu_.to(device="cuda", dtype=itype), out_cuda_, **fwd_thresholds_bf16)) + print(itype, "grad test : ", + torch.testing.assert_close(x_cpu_.grad.to(device="cuda", dtype=itype), x_cuda_.grad, **bwd_thresholds_bf16)) + else: + print(itype, "out test : ", + torch.testing.assert_close(out_cpu_.to(device="cuda", dtype=itype), out_cuda_, **fwd_thresholds_fp32)) + print(itype, "grad test : ", + torch.testing.assert_close(x_cpu_.grad.to(device="cuda", dtype=itype), x_cuda_.grad, **bwd_thresholds_fp32)) + class TestFastLayerNorm(unittest.TestCase): def assertAll(self, l): @@ -232,11 +316,18 @@ def test_all_configs(self): for h in hidden_sizes: with self.subTest(f"hidden_size={h}"): - self.assertAll(test_(256, 2, h, fp32, fp32)) - self.assertAll(test_(256, 2, h, fp16, fp16)) - self.assertAll(test_(256, 2, h, fp32, fp16)) - self.assertAll(test_(256, 2, h, bf16, bf16)) - self.assertAll(test_(256, 2, h, fp32, bf16)) + if h > 12288: + self.assertAll(test_withfused(256, 2, h, fp32, fp32)) + self.assertAll(test_withfused(256, 2, h, fp16, fp16)) + self.assertAll(test_withfused(256, 2, h, fp32, fp16)) + # self.assertAll(test_withfused(256, 2, h, bf16, bf16)) + # self.assertAll(test_withfused(256, 2, h, fp32, bf16)) + else: + self.assertAll(test_(256, 2, h, fp32, fp32)) + self.assertAll(test_(256, 2, h, fp16, fp16)) + self.assertAll(test_(256, 2, h, fp32, fp16)) + # self.assertAll(test_(256, 2, h, bf16, bf16)) + # self.assertAll(test_(256, 2, h, fp32, bf16)) def test_run_benchmark(self): for (S, B, hidden_size, runs) in (