Changes from all commits (23 commits)
c4a4ee5  add build for --fast_layer_norm (jeffdaily, Sep 1, 2022)
c1e3a72  use HIP bfloat16 header (jeffdaily, Sep 6, 2022)
09d7be8  missing <functional> header (jeffdaily, Sep 6, 2022)
fc13459  warp size considerations, TODOs (jeffdaily, Sep 6, 2022)
1228e3c  work around some compiler errors (jeffdaily, Sep 6, 2022)
f2399d6  it finally compiles (jeffdaily, Sep 7, 2022)
b30d3ba  Manually hipify cudaLaunchCooperativeKernel in --fast_layer_norm exte… (Sep 9, 2022)
103e128  Update setup.py for --fast_layer_norm extension (Sep 9, 2022)
f2908b7  Merge remote-tracking branch 'origin/master' into fastlayernorm (Sep 9, 2022)
d0e7073  Merge remote-tracking branch 'origin/master' into fastlayernorm (aspanday, Feb 20, 2023)
637f2e1  CTAS_PER_ROW > 1 requires reduction that seems to be making the accur… (aspanday, Mar 14, 2023)
03002d0  retaining the previous CTAS_PER_ROW values for fwd pass. (aspanday, Mar 14, 2023)
a4e4d78  Merge branch 'fastlayernorm' of https://github.com/ROCmSoftwarePlatfo… (aspanday, Mar 14, 2023)
bdd5cdb  removing checks for is_cuda. This interferes with the unittest. May n… (aspanday, May 3, 2023)
b78abf9  adding appropriate assert messages. This reduces #warnings as well as… (aspanday, May 3, 2023)
f47022d  adding appropriate assert messages. This reduces #warnings as well as… (aspanday, May 3, 2023)
12cb79f  adding appropriate assert messages. This reduces #warnings as well as… (aspanday, May 3, 2023)
f973057  adding appropriate assert messages. This reduces #warnings as well as… (aspanday, May 3, 2023)
37feda6  adding hipFuncSetAttrbute when IS_ROCM is True (aspanday, May 3, 2023)
5777035  Updating REGISTER_BWD_KERNEL macros for various hidden_sizes to get r… (aspanday, May 3, 2023)
6d58ce1  Added option to pass normalized_shape for fast_layer_norm same as fus… (aspanday, May 3, 2023)
ae9de3e  Updated fast_layer_norm test to call fused_layer_norm test when hidde… (aspanday, May 3, 2023)
fb79a52  Merge branch 'master' into fastlayernorm (hubertlu-tw, May 3, 2023)
287 changes: 287 additions & 0 deletions apex/contrib/csrc/layer_norm/hip_bfloat162.h
@@ -0,0 +1,287 @@
/**
* MIT License
*
* Copyright (c) 2019 - 2021 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/

/*!\file
* \brief hip_bfloat162.h provides the hip_bfloat162 struct and typedef
*/

#ifndef _hip_bfloat162_H_
#define _hip_bfloat162_H_

#if __cplusplus < 201103L || !defined(__HIPCC__)

// If this is a C compiler, C++ compiler below C++11, or a host-only compiler, we only
// include a minimal definition of hip_bfloat162

#include <stdint.h>
/*! \brief Struct to represent a 16 bit brain floating point number. */
typedef struct
{
uint16_t data;
uint16_t data2;
} hip_bfloat162;

#else // __cplusplus < 201103L || !defined(__HIPCC__)

#include <cmath>
#include <cstddef>
#include <cstdint>
#include <hip/hip_runtime.h>
#include <hip/hip_bfloat16.h>
#include <ostream>
#include <type_traits>

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wshadow"
struct hip_bfloat162
{
uint16_t data;
uint16_t data2;

enum truncate_t
{
truncate
};

__host__ __device__ hip_bfloat162() = default;

// round upper 16 bits of IEEE float to convert to bfloat16
explicit __host__ __device__ hip_bfloat162(float f)
: data(float_to_bfloat16(f))
{
}

explicit __host__ __device__ hip_bfloat162(float f, truncate_t)
: data(truncate_float_to_bfloat16(f))
{
}

// zero extend lower 16 bits of bfloat16 to convert to IEEE float
__host__ __device__ operator float() const
{
union
{
uint32_t int32;
float fp32;
} u = {uint32_t(data) << 16};
return u.fp32;
}

static __host__ __device__ hip_bfloat162 round_to_bfloat16(float f)
{
hip_bfloat162 output;
output.data = float_to_bfloat16(f);
return output;
}

static __host__ __device__ hip_bfloat162 round_to_bfloat16(float f, truncate_t)
{
hip_bfloat162 output;
output.data = truncate_float_to_bfloat16(f);
return output;
}

private:
static __host__ __device__ uint16_t float_to_bfloat16(float f)
{
union
{
float fp32;
uint32_t int32;
} u = {f};
if(~u.int32 & 0x7f800000)
{
// When the exponent bits are not all 1s, then the value is zero, normal,
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
// This causes the bfloat16's mantissa to be incremented by 1 if the 16
// least significant bits of the float mantissa are greater than 0x8000,
// or if they are equal to 0x8000 and the least significant bit of the
// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
// has the value 0x7f, then incrementing it causes it to become 0x00 and
// the exponent is incremented by one, which is the next higher FP value
// to the unrounded bfloat16 value. When the bfloat16 value is subnormal
// with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up
// to a normal value with an exponent of 0x01 and a mantissa of 0x00.
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
// incrementing it causes it to become an exponent of 0xFF and a mantissa
// of 0x00, which is Inf, the next higher value to the unrounded value.
u.int32 += 0x7fff + ((u.int32 >> 16) & 1); // Round to nearest, round to even
}
else if(u.int32 & 0xffff)
{
// When all of the exponent bits are 1, the value is Inf or NaN.
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
// mantissa bit. Quiet NaN is indicated by the most significant mantissa
// bit being 1. Signaling NaN is indicated by the most significant
// mantissa bit being 0 but some other bit(s) being 1. If any of the
// lower 16 bits of the mantissa are 1, we set the least significant bit
// of the bfloat16 mantissa, in order to preserve signaling NaN in case
// the bfloat16's mantissa bits are all 0.
u.int32 |= 0x10000; // Preserve signaling NaN
}
return uint16_t(u.int32 >> 16);
}

// Truncate instead of rounding, preserving SNaN
static __host__ __device__ uint16_t truncate_float_to_bfloat16(float f)
{
union
{
float fp32;
uint32_t int32;
} u = {f};
return uint16_t(u.int32 >> 16) | (!(~u.int32 & 0x7f800000) && (u.int32 & 0xffff));
}
};
#pragma clang diagnostic pop

typedef struct
{
uint16_t data;
uint16_t data2;
} hip_bfloat162_public;

static_assert(std::is_standard_layout<hip_bfloat162>{},
"hip_bfloat162 is not a standard layout type, and thus is "
"incompatible with C.");

static_assert(std::is_trivial<hip_bfloat162>{},
"hip_bfloat162 is not a trivial type, and thus is "
"incompatible with C.");

static_assert(sizeof(hip_bfloat162) == sizeof(hip_bfloat162_public)
&& offsetof(hip_bfloat162, data) == offsetof(hip_bfloat162_public, data),
"internal hip_bfloat162 does not match public hip_bfloat162");

inline std::ostream& operator<<(std::ostream& os, const hip_bfloat162& bf16)
{
return os << float(bf16);
}
inline __host__ __device__ hip_bfloat162 operator+(hip_bfloat162 a)
{
return a;
}
inline __host__ __device__ hip_bfloat162 operator-(hip_bfloat162 a)
{
a.data ^= 0x8000;
return a;
}
inline __host__ __device__ hip_bfloat162 operator+(hip_bfloat162 a, hip_bfloat162 b)
{
return hip_bfloat162(float(a) + float(b));
}
inline __host__ __device__ hip_bfloat162 operator-(hip_bfloat162 a, hip_bfloat162 b)
{
return hip_bfloat162(float(a) - float(b));
}
inline __host__ __device__ hip_bfloat162 operator*(hip_bfloat162 a, hip_bfloat162 b)
{
return hip_bfloat162(float(a) * float(b));
}
inline __host__ __device__ hip_bfloat162 operator/(hip_bfloat162 a, hip_bfloat162 b)
{
return hip_bfloat162(float(a) / float(b));
}
inline __host__ __device__ bool operator<(hip_bfloat162 a, hip_bfloat162 b)
{
return float(a) < float(b);
}
inline __host__ __device__ bool operator==(hip_bfloat162 a, hip_bfloat162 b)
{
return float(a) == float(b);
}
inline __host__ __device__ bool operator>(hip_bfloat162 a, hip_bfloat162 b)
{
return b < a;
}
inline __host__ __device__ bool operator<=(hip_bfloat162 a, hip_bfloat162 b)
{
return !(a > b);
}
inline __host__ __device__ bool operator!=(hip_bfloat162 a, hip_bfloat162 b)
{
return !(a == b);
}
inline __host__ __device__ bool operator>=(hip_bfloat162 a, hip_bfloat162 b)
{
return !(a < b);
}
inline __host__ __device__ hip_bfloat162& operator+=(hip_bfloat162& a, hip_bfloat162 b)
{
return a = a + b;
}
inline __host__ __device__ hip_bfloat162& operator-=(hip_bfloat162& a, hip_bfloat162 b)
{
return a = a - b;
}
inline __host__ __device__ hip_bfloat162& operator*=(hip_bfloat162& a, hip_bfloat162 b)
{
return a = a * b;
}
inline __host__ __device__ hip_bfloat162& operator/=(hip_bfloat162& a, hip_bfloat162 b)
{
return a = a / b;
}
inline __host__ __device__ hip_bfloat162& operator++(hip_bfloat162& a)
{
return a += hip_bfloat162(1.0f);
}
inline __host__ __device__ hip_bfloat162& operator--(hip_bfloat162& a)
{
return a -= hip_bfloat162(1.0f);
}
inline __host__ __device__ hip_bfloat162 operator++(hip_bfloat162& a, int)
{
hip_bfloat162 orig = a;
++a;
return orig;
}
inline __host__ __device__ hip_bfloat162 operator--(hip_bfloat162& a, int)
{
hip_bfloat162 orig = a;
--a;
return orig;
}

namespace std
{
constexpr __host__ __device__ bool isinf(hip_bfloat162 a)
{
return !(~a.data & 0x7f80) && !(a.data & 0x7f);
}
constexpr __host__ __device__ bool isnan(hip_bfloat162 a)
{
return !(~a.data & 0x7f80) && +(a.data & 0x7f);
}
constexpr __host__ __device__ bool iszero(hip_bfloat162 a)
{
return !(a.data & 0x7fff);
}
}

#endif // __cplusplus < 201103L || !defined(__HIPCC__)

#endif // _hip_bfloat162_H_
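
The round-to-nearest-even logic in float_to_bfloat16 is the subtlest part of this header, so below is a minimal host-only sketch (written for this review, not part of the PR) that replicates the bit manipulation and exercises the two tie-breaking cases:

#include <cstdint>
#include <cstdio>
#include <cstring>

// Host-only replica of float_to_bfloat16 above, for illustration.
static uint16_t to_bf16_bits(float f) {
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));      // type-pun without UB
    if (~u & 0x7f800000)                 // zero, normal, or subnormal
        u += 0x7fff + ((u >> 16) & 1);   // round to nearest, ties to even
    else if (u & 0xffff)                 // Inf/NaN with low mantissa bits set
        u |= 0x10000;                    // keep signaling NaN from becoming Inf
    return uint16_t(u >> 16);
}

int main() {
    // 1 + 2^-8 sits exactly halfway between bf16 values 0x3f80 and 0x3f81;
    // the bf16 mantissa LSB is even, so the tie rounds down.
    printf("%04x\n", to_bf16_bits(1.00390625f)); // prints 3f80
    // 1 + 3*2^-8 is also a halfway case, but the LSB is odd, so it rounds up.
    printf("%04x\n", to_bf16_bits(1.01171875f)); // prints 3f82
    return 0;
}

Both inputs sit exactly halfway between adjacent bfloat16 values; only the one whose bfloat16 mantissa LSB is odd rounds up, which is what the (u >> 16) & 1 term implements.
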
9 changes: 9 additions & 0 deletions apex/contrib/csrc/layer_norm/ln.h
@@ -1,8 +1,13 @@
#pragma once

#include <functional>
#include <unordered_map>
#include <cuda_fp16.h>
#ifdef USE_ROCM
#include <hip/hip_bfloat16.h>
#else
#include <cuda_bf16.h>
#endif

namespace layer_norm {

@@ -121,7 +126,11 @@ extern BwdRegistry BWD_FUNCS;

using fp32 = float;
using fp16 = half;
#ifdef USE_ROCM
using bf16 = hip_bfloat16;
#else
using bf16 = nv_bfloat16;
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

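
A note on the aliasing above: routing both backends through the single bf16 name is what lets the rest of the extension avoid per-backend #ifdefs. A standalone sketch of the pattern (the kernel below is hypothetical, not from this PR):

// Standalone sketch of the aliasing pattern ln.h uses (names assumed):
#ifdef USE_ROCM
#include <hip/hip_bfloat16.h>
using bf16 = hip_bfloat16;
#else
#include <cuda_bf16.h>
using bf16 = nv_bfloat16;
#endif

// Any kernel written against bf16 now compiles for either backend
// without further #ifdefs; both types round-trip through float.
__global__ void scale_bf16(bf16* data, float s, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) data[i] = bf16(float(data[i]) * s);
}
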
16 changes: 8 additions & 8 deletions apex/contrib/csrc/layer_norm/ln_api.cpp
@@ -89,9 +89,9 @@ std::vector<at::Tensor> ln_fwd(const at::Tensor &x, // BxSxhidden_size

TORCH_CHECK(beta.scalar_type() == wtype);

TORCH_CHECK(x.is_cuda())
TORCH_CHECK(gamma.is_cuda())
TORCH_CHECK(beta.is_cuda())
//TORCH_CHECK(x.is_cuda())
//TORCH_CHECK(gamma.is_cuda())
//TORCH_CHECK(beta.is_cuda())

TORCH_CHECK(x.is_contiguous());
auto sizes = x.sizes();
@@ -170,11 +170,11 @@ std::vector<at::Tensor> ln_bwd(const at::Tensor &dz, // BxSxhidden_size
TORCH_CHECK(mu.dtype() == ctype);
TORCH_CHECK(rsigma.dtype() == ctype);

TORCH_CHECK(x.is_cuda());
TORCH_CHECK(dz.is_cuda());
TORCH_CHECK(mu.is_cuda());
TORCH_CHECK(rsigma.is_cuda());
TORCH_CHECK(gamma.is_cuda());
//TORCH_CHECK(x.is_cuda());
//TORCH_CHECK(dz.is_cuda());
//TORCH_CHECK(mu.is_cuda());
//TORCH_CHECK(rsigma.is_cuda());
//TORCH_CHECK(gamma.is_cuda());

TORCH_CHECK(x.is_contiguous());
TORCH_CHECK(dz.is_contiguous());
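
On the commented-out is_cuda checks: commit bdd5cdb notes they interfered with the unit tests. If they are ever restored, a suggestion in the spirit of the later "assert messages" commits is to attach a message so failures self-diagnose. A hypothetical fragment for ln_api.cpp, not part of the PR:

// Hypothetical sketch: device checks with messages, so a unit-test
// failure reports which tensor is on the wrong device.
TORCH_CHECK(x.is_cuda(), "ln_fwd: x must be a GPU tensor, got device ", x.device());
TORCH_CHECK(gamma.is_cuda(), "ln_fwd: gamma must be a GPU tensor, got device ", gamma.device());
TORCH_CHECK(beta.is_cuda(), "ln_fwd: beta must be a GPU tensor, got device ", beta.device());
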
7 changes: 4 additions & 3 deletions apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh
@@ -18,6 +18,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };

using compute_t = typename Ktraits::compute_t;
using input_t = typename Ktraits::input_t;
using index_t = typename Ktraits::index_t;
using Ivec = typename Ktraits::Ivec;
using Ovec = typename Ktraits::Ovec;
@@ -40,7 +41,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
const index_t r = bidm * Ktraits::ROWS_PER_CTA + warp_m;
const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;

static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW);
static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW, "LDGS needs to be a whole number. See ln_kernel_traits.h and update kernel parameters in REGISTER_BWD_LAUNCHER.");

Cvec dzy_sum[LDGS];
Cvec dz_sum[LDGS];
@@ -119,7 +120,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
compute_t dy_tmp = dy[it * NUM_ELTS + jt];
compute_t y_tmp = y[it * NUM_ELTS + jt];
compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + mdy_local));
dx[it].data.elt[jt] = dx_tmp;
dx[it].data.elt[jt] = input_t(dx_tmp);
}
dx[it].store_to(params.dx, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
@@ -142,7 +143,7 @@

// Assumption: blockSize divides hidden size.
enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA };
static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, "");
static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, "NUM_RES needs to be a whole number. See ln_kernel_traits.h and update kernel parameters in REGISTER_BWD_LAUNCHER.");

idx = warp_m * Ktraits::VEC_COLS + tid_r;
#pragma unroll
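
The strengthened static_assert messages in this file encode divisibility constraints among the kernel-trait parameters. A small host-side sketch of the same arithmetic, with illustrative trait values that are assumed here rather than taken from REGISTER_BWD_LAUNCHER:

#include <cassert>
#include <cstdio>

int main() {
    const int COLS            = 4096;  // hidden size (assumed)
    const int THREADS_PER_ROW = 128;
    const int NUM_ELTS        = 4;     // elements per vectorized load
    const int CTAS_PER_ROW    = 1;
    const int THREADS_PER_CTA = 512;

    // LDGS is only meaningful if the division is exact, which is what the
    // first static_assert enforces at compile time.
    const int LDGS = COLS / (THREADS_PER_ROW * NUM_ELTS * CTAS_PER_ROW);
    assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW);

    // Likewise NUM_RES must be a whole number of residual elements per thread.
    const int NUM_RES = COLS / THREADS_PER_CTA;
    assert(NUM_RES * THREADS_PER_CTA == COLS);

    printf("LDGS=%d NUM_RES=%d\n", LDGS, NUM_RES);
    return 0;
}

If either assertion fails for a new hidden size, the kernel parameters in REGISTER_BWD_LAUNCHER need adjusting, exactly as the new static_assert messages advise.
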