Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 23 additions & 11 deletions hopper/epilogue_fwd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ struct CollectiveEpilogueFwd {
static constexpr bool PackGQA = PackGQA_;
static constexpr bool Split = Split_;
static constexpr bool Use_smem = !(Split && !Varlen);
static constexpr bool Use_TMA_O = ArchTag::kMinComputeCapability >= 90 && !Varlen && !Split && !PackGQA;
static constexpr bool Use_TMA_O = ArchTag::kMinComputeCapability >= 90 && !Varlen && !Split && !PackGQA && sizeof(Element) <= 2;

static_assert(ArchTag::kMinComputeCapability >= 80);
static_assert(ArchTag::kMinComputeCapability >= 90 || CUTE_STATIC_V(size(ClusterShape{})) == 1);
static_assert(sizeof(Element) <= 2);
static_assert(sizeof(Element) <= 4);

static constexpr int kBlockM = get<0>(TileShape_MNK_PV{});
static constexpr int kHeadDimV = get<1>(TileShape_MNK_PV{});
Expand All @@ -48,7 +48,7 @@ struct CollectiveEpilogueFwd {
using GmemTiledCopyOTMA = cute::SM90_TMA_STORE;

// These are for storing the output tensor without TMA (e.g., for setting output to zero)
static constexpr int kGmemElemsPerStore = sizeof(cute::uint128_t) / sizeof(Element);
static constexpr int kGmemElemsPerStore = kBlockM >= 32 ? sizeof(cute::uint128_t) / sizeof(Element) : sizeof(cute::uint64_t) / sizeof(Element);
static_assert(kHeadDimV % kGmemElemsPerStore == 0, "Headdim must be a multiple of kGmemElemsPerStore");
// We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). We want each thread to have 4 elements
// in the M direction and 2 elements in the K direction. In the case of PackGQA, this reduces the number of times
Expand All @@ -57,23 +57,28 @@ struct CollectiveEpilogueFwd {
static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element);
static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerStore;
// If PackGQA, we split the work of compute O_ptr among threads in the same row, so we need this to within a warp
static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0);
static_assert(!PackGQA || cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0);
static_assert(NumEpilogueThreads % kGmemThreadsPerRow == 0, "NumEpilogueThreads must be a multiple of kGmemThreadsPerRow");
using GmemLayoutAtom = Layout<Shape <Int<NumEpilogueThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
Stride<Int<kGmemThreadsPerRow>, _1>>;
static_assert(kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0, "kBlockM must be a multiple of NumEpilogueThreads / kGmemThreadsPerRow");
using GmemTileCopyAtomO = std::conditional_t<
kBlockM >= 32,
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>,
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<64>, Element>>;
using GmemTiledCopyO = decltype(
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
make_tiled_copy(GmemTileCopyAtomO{},
GmemLayoutAtom{},
Layout<Shape<_1, Int<kGmemElemsPerStore>>>{})); // Val layout, 8 or 16 vals per store
Layout<Shape<_1, Int<kGmemElemsPerStore>>>{}));

using SmemLayoutAtomOTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<0>(TileShape_MNK_PV{})), decltype(cute::get<1>(TileShape_MNK_PV{}))>());
using SmemLayoutOTMA = decltype(tile_to_shape(SmemLayoutAtomOTMA{}, select<0, 1>(TileShape_MNK_PV{})));
static constexpr int kSwizzle = kBlockKGmem == 128 ? 4 : (kBlockKGmem == 64 ? 3 : (kBlockKGmem == 32 ? 2 : 1));
static constexpr int kSwizzleBase = sizeof(Element) == 4 ? 2 : (sizeof(Element) == 2 ? 3 : 4);
static constexpr int kSwizzle = sizeof(Element) == 4 ? 2 : (kBlockKGmem == 128 ? 4 : (kBlockKGmem == 64 ? 3 : (kBlockKGmem == 32 ? 2 : 1)));
static constexpr int kSwizzleBase = sizeof(Element) == 4 ? 3 : (sizeof(Element) == 2 ? 3 : 4);
static constexpr int kSwizzleShift = sizeof(Element) == 4 ? 2 : (sizeof(Element) == 2 ? 3 : 4);
using SmemLayoutAtomO = decltype(
composition(Swizzle<kSwizzle, kSwizzleBase, kSwizzleBase>{},
composition(Swizzle<kSwizzle, kSwizzleBase, kSwizzleShift>{},
Layout<Shape<_8, Int<kBlockKGmem>>,
Stride<Int<kBlockKGmem>, _1>>{}));
using SmemLayoutOSTS = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 1>(TileShape_MNK_PV{})));
Expand Down Expand Up @@ -238,8 +243,15 @@ struct CollectiveEpilogueFwd {
// If we will possibly need tOrO in FP32, we'd want to permute tOrO before type conversion.
// Otherwise we can permute after conversion.
if constexpr (NeedFP8Permute && Split) { flash::permute_output_fp8_Vcolmajor(tOrO); }
Tensor tOrO_out = make_tensor_like<Element>(tOrO);
flash::convert_type_out(tOrO, tOrO_out);
auto tOrO_out = [&] {
if constexpr (cute::is_same_v<Element, float>) {
return tOrO;
} else {
Tensor out = make_tensor_like<Element>(tOrO);
flash::convert_type_out(tOrO, out);
return out;
}
}();
if constexpr (NeedFP8Permute && !Split) { flash::permute_output_fp8_Vcolmajor(tOrO_out); }

// Make sure all WGs have finished reading V
Expand Down
32 changes: 21 additions & 11 deletions hopper/flash_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -705,7 +705,7 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql
int64_t num_splits,
std::optional<bool> pack_gqa_,
int64_t sm_margin,
std::optional<const at::Tensor> learnable_sink_
std::optional<at::Tensor> learnable_sink_
) {

auto dprops = at::cuda::getCurrentDeviceProperties();
Expand Down Expand Up @@ -857,11 +857,19 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql
TORCH_CHECK(head_size_v % alignment == 0, "head_size_v should be a multiple of " + std::to_string(alignment));

auto opts = q.options();
auto out_type = q_type == at::ScalarType::Float8_e4m3fn ? at::ScalarType::BFloat16 : q_type;
auto default_out_type = q_type == at::ScalarType::Float8_e4m3fn ? at::ScalarType::BFloat16 : q_type;
auto out_type = out_.has_value() ? out_.value().scalar_type() : default_out_type;
at::Tensor out;
if (out_.has_value()) {
out = out_.value();
TORCH_CHECK(out.scalar_type() == out_type, "For FP16/BF16 input, output must have the same dtype as inputs. For FP8 input, output must have dtype BF16");
if (q_type == at::ScalarType::BFloat16) {
TORCH_CHECK(
out_type == at::ScalarType::BFloat16 || out_type == at::ScalarType::Float,
"For BF16 input, output must have dtype BF16 or FP32"
);
} else {
TORCH_CHECK(out_type == default_out_type, "Output tensor must match the selected output dtype");
}
CHECK_DEVICE(out);
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
if (!is_varlen_q) {
Expand Down Expand Up @@ -912,6 +920,7 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql
attention_chunk,
softcap,
sm_margin);
params.is_fp32 = out_type == at::ScalarType::Float;
params.total_q = total_q;
params.total_k = total_k;
params.b_k = batch_size_k;
Expand Down Expand Up @@ -1115,7 +1124,7 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql
out_accum = torch::empty({params.num_splits, num_heads, total_q, head_size_v}, opts.dtype(outaccum_type));
softmax_lse_accum = torch::empty({params.num_splits, num_heads, total_q}, opts.dtype(at::kFloat));
}
params.is_fp32 = false;
params.is_fp32 = out_type == at::ScalarType::Float;
params.oaccum_ptr = out_accum.data_ptr();
params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
params.oaccum_split_stride = out_accum.stride(0);
Expand Down Expand Up @@ -1181,10 +1190,8 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql
auto stream = at::cuda::getCurrentCUDAStream().stream();
run_mha_fwd(params, stream);
if (params.num_splits > 1) {
if (out_type == at::ScalarType::BFloat16) {
// Since we want output in BF16. Otherwise fwd_combine will output to FP16
params.is_bf16 = true;
}
params.is_fp32 = out_type == at::ScalarType::Float;
params.is_bf16 = out_type == at::ScalarType::BFloat16;
// Unless there's seqused_q, for the purpose of attn_combine, we can just treat it as batch=1
// and seqlen = total_q, and don't need to dispatch to Varlen there.
// However, with dynamic split, each row needs to know which batch it belongs to
Expand Down Expand Up @@ -1300,7 +1307,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_bwd(
double softcap,
bool deterministic,
int64_t sm_margin,
std::optional<const at::Tensor> learnable_sink_,
std::optional<at::Tensor> learnable_sink_,
std::optional<at::Tensor> dsink_
) {

Expand Down Expand Up @@ -1749,7 +1756,8 @@ TORCH_LIBRARY(flash_attn_3, m) {
"Tensor? scheduler_metadata = None,"
"int num_splits = 0,"
"bool? pack_gqa = None,"
"int sm_margin = 0) -> (Tensor(out!), Tensor, Tensor, Tensor)");
"int sm_margin = 0,"
"Tensor? learnable_sink = None) -> (Tensor(out!), Tensor, Tensor, Tensor)");
m.def("bwd("
"Tensor dout,"
"Tensor q,"
Expand All @@ -1772,7 +1780,9 @@ TORCH_LIBRARY(flash_attn_3, m) {
"int window_size_right = -1,"
"float softcap = 0.0,"
"bool deterministic = False,"
"int sm_margin = 0) -> (Tensor, Tensor, Tensor, Tensor, Tensor)");
"int sm_margin = 0,"
"Tensor? learnable_sink = None,"
"Tensor? dsink = None) -> (Tensor, Tensor, Tensor, Tensor, Tensor)");
m.def("fwd_combine("
"Tensor out_partial,"
"Tensor lse_partial,"
Expand Down
Empty file added hopper/flash_attn_3/__init__.py
Empty file.
Loading