From f50dd4ac084c6a9b88bb690ca76362c965a79235 Mon Sep 17 00:00:00 2001 From: Harsh Chauhan Date: Thu, 18 Jun 2026 18:30:15 +0530 Subject: [PATCH 1/3] initial softmax gpu kernel and tests added --- core/inc/SOFIE/ROperator_Softmax.hxx | 106 ++++++++++++++ .../TestCustomModelsFromONNXForAlpakaCuda.cxx | 138 ++++++++++++++++++ 2 files changed, 244 insertions(+) diff --git a/core/inc/SOFIE/ROperator_Softmax.hxx b/core/inc/SOFIE/ROperator_Softmax.hxx index 5626c0f..e5ddd1c 100644 --- a/core/inc/SOFIE/ROperator_Softmax.hxx +++ b/core/inc/SOFIE/ROperator_Softmax.hxx @@ -185,6 +185,112 @@ public: } return out.str(); } + + std::string Generate_GPU_Kernel_ALPAKA(std::string opName) override { + if (fShape.empty()) + throw std::runtime_error("SOFIE Softmax called to Generate_GPU_Kernel_ALPAKA without being initialized first"); + + opName = "op_" + opName; + std::string kname = "SoftmaxKernel_" + opName; + + size_t size = fShape.size(); + size_t axis = fAttrAxis < 0 ? size + fAttrAxis : fAttrAxis; + + std::string axis_size = fShape[axis].GetVal();// per-row reduction length + std::string inner_stride = UTILITY::ComputeStrideFromShape(fShape)[axis].GetVal();// stride along the axis + + //one thread per row, serial 3-pass softmax (max, exp+sum, normalize). 
+ std::string op; + op = "\n//------ SOFTMAX_KERNEL_ALPAKA\n"; + op += SP + "struct " + kname + " {\n"; + op += SP + SP + "template\n"; + op += SP + SP + "ALPAKA_FN_ACC void operator()(\n"; + op += SP + SP + SP + "TAcc const& acc,\n"; + op += SP + SP + SP + "T const* __restrict__ X,\n"; + op += SP + SP + SP + "T* __restrict__ Y,\n"; + op += SP + SP + SP + "std::size_t const numRows) const {\n\n"; + + op += SP + SP + SP + "auto const gid = alpaka::getIdx(acc)[0];\n"; + op += SP + SP + SP + "auto const grid_extent = alpaka::getWorkDiv(acc)[0];\n\n"; + + op += SP + SP + SP + "std::size_t const axis_size = " + axis_size + ";\n"; + op += SP + SP + SP + "std::size_t const inner_stride = " + inner_stride + ";\n"; + op += SP + SP + SP + "std::size_t const row_block = axis_size * inner_stride;\n\n"; + + op += SP + SP + SP + "for (std::size_t r = gid; r < numRows; r += grid_extent) {\n"; + op += SP + SP + SP + SP + "std::size_t const row_base = (r / inner_stride) * row_block + (r % inner_stride);\n\n"; + + op += SP + SP + SP + SP + "// pass 1: max\n"; + op += SP + SP + SP + SP + "T vmax = X[row_base];\n"; + op += SP + SP + SP + SP + "for (std::size_t l = 1; l < axis_size; ++l) {\n"; + op += SP + SP + SP + SP + SP + "T v = X[row_base + l * inner_stride];\n"; + op += SP + SP + SP + SP + SP + "if (v > vmax) vmax = v;\n"; + op += SP + SP + SP + SP + "}\n\n"; + + op += SP + SP + SP + SP + "// pass 2: exp(x - max), sum\n"; + op += SP + SP + SP + SP + "T sum = static_cast(0);\n"; + op += SP + SP + SP + SP + "for (std::size_t l = 0; l < axis_size; ++l) {\n"; + op += SP + SP + SP + SP + SP + "std::size_t const idx = row_base + l * inner_stride;\n"; + op += SP + SP + SP + SP + SP + "T e = alpaka::math::exp(acc, X[idx] - vmax);\n"; + op += SP + SP + SP + SP + SP + "Y[idx] = e;\n"; + op += SP + SP + SP + SP + SP + "sum += e;\n"; + op += SP + SP + SP + SP + "}\n\n"; + + op += SP + SP + SP + SP + "// pass 3: normalize\n"; + op += SP + SP + SP + SP + "T inv = static_cast(1) / sum;\n"; + op 
+= SP + SP + SP + SP + "for (std::size_t l = 0; l < axis_size; ++l) {\n"; + op += SP + SP + SP + SP + SP + "std::size_t const idx = row_base + l * inner_stride;\n"; + op += SP + SP + SP + SP + SP + "Y[idx] *= inv;\n"; + if (fLogSoftmax) + op += SP + SP + SP + SP + SP + "Y[idx] = alpaka::math::log(acc, Y[idx]);\n"; + op += SP + SP + SP + SP + "}\n"; + + op += SP + SP + SP + "}\n";// row loop end + op += SP + SP + "}\n";// operator() end + op += SP + "};\n"; + return op; + } + + std::string Generate_GPU_Kernel_Definitions_ALPAKA(std::string opName) override { + opName = "op_" + opName; + std::string kname = "SoftmaxKernel_" + opName; + return SP + kname + " softmaxKernel_" + opName + ";\n"; + } + + std::string Generate_GPU_ALPAKA(std::string opName) override { + if (fShape.empty()) + throw std::runtime_error("SOFIE Softmax called to Generate_GPU_ALPAKA without being initialized first"); + + opName = "op_" + opName; + std::string kname = "softmaxKernel_" + opName; + + size_t size = fShape.size(); + size_t axis = fAttrAxis < 0 ? 
size + fAttrAxis : fAttrAxis; + std::string axis_size = fShape[axis].GetVal(); + std::string length_str = ConvertDimShapeToLength(fShape); + + std::string num_rows; + if (IsInteger(length_str) && IsInteger(axis_size)) + num_rows = std::to_string(std::stoul(length_str) / std::stoul(axis_size)); + else + num_rows = "(" + length_str + ") / (" + axis_size + ")"; + + std::stringstream out; + out << "\n//------ SOFTMAX_GPU_ALPAKA\n"; + out << SP << "auto const elementsPerThread_" << opName << " = Vec::all(static_cast(1));\n"; + out << SP << "auto const elementsPerGrid_" << opName << " = Vec::all(Idx{" << num_rows << "});\n"; + out << SP << "auto const workDiv_" << opName << " = sofie_workdiv(elementsPerGrid_" << opName << ");\n"; + out << SP << "alpaka::exec(queue, workDiv_" << opName + << ", " << kname + << ", alpaka::getPtrNative(deviceBuf_" << fNX << ")" + << ", alpaka::getPtrNative(deviceBuf_" << fNY << ")" + << ", static_cast(" << num_rows << "));\n"; + return out.str(); + } + + std::vector GetStdLibs() override { + return { std::string("cmath") }; + } }; } // namespace SOFIE diff --git a/test/TestCustomModelsFromONNXForAlpakaCuda.cxx b/test/TestCustomModelsFromONNXForAlpakaCuda.cxx index fccacbe..993eb67 100644 --- a/test/TestCustomModelsFromONNXForAlpakaCuda.cxx +++ b/test/TestCustomModelsFromONNXForAlpakaCuda.cxx @@ -176,6 +176,15 @@ #include "Clip_FromONNX_GPU_ALPAKA.hxx" #include "Not_FromONNX_GPU_ALPAKA.hxx" +#include "Softmax1d_FromONNX_GPU_ALPAKA.hxx" +#include "input_models/references/Softmax1d.ref.hxx" +#include "Softmax2d_FromONNX_GPU_ALPAKA.hxx" +#include "input_models/references/Softmax2d.ref.hxx" +#include "Softmax3d_FromONNX_GPU_ALPAKA.hxx" +#include "input_models/references/Softmax3d.ref.hxx" +#include "Softmax4d_FromONNX_GPU_ALPAKA.hxx" +#include "input_models/references/Softmax4d.ref.hxx" + #include "GNN_model_FromONNX_GPU_ALPAKA.hxx" #include @@ -3161,3 +3170,132 @@ TEST_F(SofieAlpakaTest, Logic_BitwiseNot) for (std::size_t i = 0; i < N; ++i) 
EXPECT_EQ(res[i], ref[i]) << " index=" << i; } + +TEST_F(SofieAlpakaTest, Softmax1d) +{ + constexpr float TOLERANCE = DEFAULT_TOLERANCE; + std::vector input_vec({-1.f, 0.f, 1.f}); + const Idx N = static_cast(input_vec.size()); + + auto input_h = alpaka::allocBuf(host, Ext1D::all(N)); + float* input_ptr = reinterpret_cast(alpaka::getPtrNative(input_h)); + for (Idx i = 0; i < N; ++i) input_ptr[i] = input_vec[i]; + + auto input_d = alpaka::allocBuf(device, Ext1D::all(N)); + alpaka::memcpy(queue, input_d, input_h); + alpaka::wait(queue); + + auto result_h = alpaka::allocBuf(host, Ext1D::all(N)); + { + SOFIE_Softmax1d::Session session; + auto result = session.infer(input_d); + alpaka::wait(queue); + cudaDeviceSynchronize(); + alpaka::memcpy(queue, result_h, result); + alpaka::wait(queue); + } + + float* res = reinterpret_cast(alpaka::getPtrNative(result_h)); + float* ref = Softmax1d_ExpectedOutput::output; + for (Idx i = 0; i < N; ++i) + EXPECT_LE(std::abs(res[i] - ref[i]), TOLERANCE) << " index=" << i; +} + +TEST_F(SofieAlpakaTest, Softmax2d) +{ + constexpr float TOLERANCE = DEFAULT_TOLERANCE; + std::vector input_vec({-1.f, 0.f, 1.f}); + const Idx N = static_cast(input_vec.size()); + + auto input_h = alpaka::allocBuf(host, Ext1D::all(N)); + float* input_ptr = reinterpret_cast(alpaka::getPtrNative(input_h)); + for (Idx i = 0; i < N; ++i) input_ptr[i] = input_vec[i]; + + auto input_d = alpaka::allocBuf(device, Ext1D::all(N)); + alpaka::memcpy(queue, input_d, input_h); + alpaka::wait(queue); + + auto result_h = alpaka::allocBuf(host, Ext1D::all(N)); + { + SOFIE_Softmax2d::Session session; + auto result = session.infer(input_d); + alpaka::wait(queue); + cudaDeviceSynchronize(); + alpaka::memcpy(queue, result_h, result); + alpaka::wait(queue); + } + + float* res = reinterpret_cast(alpaka::getPtrNative(result_h)); + float* ref = Softmax2d_ExpectedOutput::output; + for (Idx i = 0; i < N; ++i) + EXPECT_LE(std::abs(res[i] - ref[i]), TOLERANCE) << " index=" << i; +} + 
+TEST_F(SofieAlpakaTest, Softmax3d) +{ + constexpr float TOLERANCE = DEFAULT_TOLERANCE; + std::vector input_vec({ + -0.8939f, -0.3674f, 0.1763f, 1.5804f, -0.4687f, 1.2253f, -1.3488f, -0.1000f, + -0.1262f, 0.4962f, 1.0870f, 0.6905f, -0.3451f, -1.6981f, -0.4688f, 0.4468f, + -0.5479f, 0.0650f, 1.0446f, -1.6249f, -0.7190f, -1.7520f, 3.7753f, -1.4939f}); + const Idx N = static_cast(input_vec.size()); + + auto input_h = alpaka::allocBuf(host, Ext1D::all(N)); + float* input_ptr = reinterpret_cast(alpaka::getPtrNative(input_h)); + for (Idx i = 0; i < N; ++i) input_ptr[i] = input_vec[i]; + + auto input_d = alpaka::allocBuf(device, Ext1D::all(N)); + alpaka::memcpy(queue, input_d, input_h); + alpaka::wait(queue); + + auto result_h = alpaka::allocBuf(host, Ext1D::all(N)); + { + SOFIE_Softmax3d::Session session; + auto result = session.infer(input_d); + alpaka::wait(queue); + cudaDeviceSynchronize(); + alpaka::memcpy(queue, result_h, result); + alpaka::wait(queue); + } + + float* res = reinterpret_cast(alpaka::getPtrNative(result_h)); + float* ref = Softmax3d_ExpectedOutput::output; + for (Idx i = 0; i < N; ++i) + EXPECT_LE(std::abs(res[i] - ref[i]), TOLERANCE) << " index=" << i; +} + +TEST_F(SofieAlpakaTest, Softmax4d) +{ + constexpr float TOLERANCE = DEFAULT_TOLERANCE; + std::vector input_vec({ + -0.5869f, -1.4272f, -0.1546f, 0.0096f, 0.1706f, 0.0388f, -0.3484f, -0.7829f, + 1.1138f, -0.5644f, -0.6264f, -1.1890f, 1.6741f, -0.7130f, 0.9592f, 1.7477f, + -0.4775f, 1.3407f, -0.3882f, -0.4560f, 1.0385f, -0.1669f, 0.5540f, -1.0790f, + -0.6153f, -0.6274f, -1.2304f, -0.6757f, 1.0178f, -0.2379f, -0.7912f, -0.0165f, + -0.5423f, 0.1459f, 1.3585f, -0.5005f, -0.2187f, -1.8181f, -0.6642f, 0.0287f, + -1.9103f, 0.7984f, -0.7860f, 1.5134f, 1.3873f, -0.6462f, -0.6354f, -0.1335f}); + const Idx N = static_cast(input_vec.size()); + + auto input_h = alpaka::allocBuf(host, Ext1D::all(N)); + float* input_ptr = reinterpret_cast(alpaka::getPtrNative(input_h)); + for (Idx i = 0; i < N; ++i) input_ptr[i] 
= input_vec[i]; + + auto input_d = alpaka::allocBuf(device, Ext1D::all(N)); + alpaka::memcpy(queue, input_d, input_h); + alpaka::wait(queue); + + auto result_h = alpaka::allocBuf(host, Ext1D::all(N)); + { + SOFIE_Softmax4d::Session session; + auto result = session.infer(input_d); + alpaka::wait(queue); + cudaDeviceSynchronize(); + alpaka::memcpy(queue, result_h, result); + alpaka::wait(queue); + } + + float* res = reinterpret_cast(alpaka::getPtrNative(result_h)); + float* ref = Softmax4d_ExpectedOutput::output; + for (Idx i = 0; i < N; ++i) + EXPECT_LE(std::abs(res[i] - ref[i]), TOLERANCE) << " index=" << i; +} From 2331e160d0cbbdd7c443d6ceebc17c95dc85488a Mon Sep 17 00:00:00 2001 From: Harsh Chauhan Date: Thu, 18 Jun 2026 23:28:13 +0530 Subject: [PATCH 2/3] block per row reduction --- core/inc/SOFIE/ROperator_Softmax.hxx | 112 ++++++++++++++++++--------- 1 file changed, 77 insertions(+), 35 deletions(-) diff --git a/core/inc/SOFIE/ROperator_Softmax.hxx b/core/inc/SOFIE/ROperator_Softmax.hxx index e5ddd1c..be65c86 100644 --- a/core/inc/SOFIE/ROperator_Softmax.hxx +++ b/core/inc/SOFIE/ROperator_Softmax.hxx @@ -186,6 +186,22 @@ public: return out.str(); } + // threads per row for the block-per-row kernel, picked from the (static) row + // length: next power of 2 >= axis_size, clamped to [32, 1024]. + // Dynamic axis falls back to 256. Kernel and launch both call this so they always agree. + size_t SoftmaxBlockSize() const { + size_t axis = fAttrAxis < 0 ? 
fShape.size() + fAttrAxis : fAttrAxis; + std::string as = fShape[axis].GetVal(); + if (!IsInteger(as)) + return 256; + size_t n = std::stoul(as); + size_t p = 1; + while (p < n) p <<= 1; // next power of 2 >= n + if (p < 32) p = 32; // at least one warp + if (p > 1024) p = 1024; // block-size cap + return p; + } + std::string Generate_GPU_Kernel_ALPAKA(std::string opName) override { if (fShape.empty()) throw std::runtime_error("SOFIE Softmax called to Generate_GPU_Kernel_ALPAKA without being initialized first"); @@ -199,7 +215,12 @@ public: std::string axis_size = fShape[axis].GetVal();// per-row reduction length std::string inner_stride = UTILITY::ComputeStrideFromShape(fShape)[axis].GetVal();// stride along the axis - //one thread per row, serial 3-pass softmax (max, exp+sum, normalize). + const size_t kBlock = SoftmaxBlockSize(); // threads per row (block) + std::string bs = std::to_string(kBlock); + + // block-per-row: a block of kBlock threads reduces one row cooperatively. + // each thread strides over the row (coalesced for the last-axis case), then + // two shared-memory tree reductions compute the row max and the sum. 
std::string op; op = "\n//------ SOFTMAX_KERNEL_ALPAKA\n"; op += SP + "struct " + kname + " {\n"; @@ -210,42 +231,60 @@ public: op += SP + SP + SP + "T* __restrict__ Y,\n"; op += SP + SP + SP + "std::size_t const numRows) const {\n\n"; - op += SP + SP + SP + "auto const gid = alpaka::getIdx(acc)[0];\n"; - op += SP + SP + SP + "auto const grid_extent = alpaka::getWorkDiv(acc)[0];\n\n"; + op += SP + SP + SP + "auto& sdata = alpaka::declareSharedVar(acc);\n"; + op += SP + SP + SP + "auto const row = alpaka::getIdx(acc)[0];\n"; + op += SP + SP + SP + "auto const tid = alpaka::getIdx(acc)[0];\n"; + op += SP + SP + SP + "if (row >= numRows) return;\n\n"; op += SP + SP + SP + "std::size_t const axis_size = " + axis_size + ";\n"; op += SP + SP + SP + "std::size_t const inner_stride = " + inner_stride + ";\n"; - op += SP + SP + SP + "std::size_t const row_block = axis_size * inner_stride;\n\n"; - - op += SP + SP + SP + "for (std::size_t r = gid; r < numRows; r += grid_extent) {\n"; - op += SP + SP + SP + SP + "std::size_t const row_base = (r / inner_stride) * row_block + (r % inner_stride);\n\n"; - - op += SP + SP + SP + SP + "// pass 1: max\n"; - op += SP + SP + SP + SP + "T vmax = X[row_base];\n"; - op += SP + SP + SP + SP + "for (std::size_t l = 1; l < axis_size; ++l) {\n"; - op += SP + SP + SP + SP + SP + "T v = X[row_base + l * inner_stride];\n"; - op += SP + SP + SP + SP + SP + "if (v > vmax) vmax = v;\n"; - op += SP + SP + SP + SP + "}\n\n"; - - op += SP + SP + SP + SP + "// pass 2: exp(x - max), sum\n"; - op += SP + SP + SP + SP + "T sum = static_cast(0);\n"; - op += SP + SP + SP + SP + "for (std::size_t l = 0; l < axis_size; ++l) {\n"; - op += SP + SP + SP + SP + SP + "std::size_t const idx = row_base + l * inner_stride;\n"; - op += SP + SP + SP + SP + SP + "T e = alpaka::math::exp(acc, X[idx] - vmax);\n"; - op += SP + SP + SP + SP + SP + "Y[idx] = e;\n"; - op += SP + SP + SP + SP + SP + "sum += e;\n"; - op += SP + SP + SP + SP + "}\n\n"; - - op += SP + SP + SP + 
SP + "// pass 3: normalize\n"; - op += SP + SP + SP + SP + "T inv = static_cast(1) / sum;\n"; - op += SP + SP + SP + SP + "for (std::size_t l = 0; l < axis_size; ++l) {\n"; - op += SP + SP + SP + SP + SP + "std::size_t const idx = row_base + l * inner_stride;\n"; - op += SP + SP + SP + SP + SP + "Y[idx] *= inv;\n"; + op += SP + SP + SP + "std::size_t const row_block = axis_size * inner_stride;\n"; + op += SP + SP + SP + "std::size_t const row_base = (row / inner_stride) * row_block + (row % inner_stride);\n\n"; + + // pass 1: row max (partial per thread, then shared-memory tree reduce) + op += SP + SP + SP + "// pass 1: row max\n"; + op += SP + SP + SP + "T tmax = X[row_base];\n"; + op += SP + SP + SP + "for (std::size_t l = tid; l < axis_size; l += " + bs + "u) {\n"; + op += SP + SP + SP + SP + "T v = X[row_base + l * inner_stride];\n"; + op += SP + SP + SP + SP + "if (v > tmax) tmax = v;\n"; + op += SP + SP + SP + "}\n"; + op += SP + SP + SP + "sdata[tid] = tmax;\n"; + op += SP + SP + SP + "alpaka::syncBlockThreads(acc);\n"; + op += SP + SP + SP + "for (std::size_t s = " + bs + "u / 2u; s >= 1u; s /= 2u) {\n"; + op += SP + SP + SP + SP + "if (tid < s && sdata[tid + s] > sdata[tid]) sdata[tid] = sdata[tid + s];\n"; + op += SP + SP + SP + SP + "alpaka::syncBlockThreads(acc);\n"; + op += SP + SP + SP + "}\n"; + op += SP + SP + SP + "T const vmax = sdata[0];\n"; + op += SP + SP + SP + "alpaka::syncBlockThreads(acc);\n\n"; + + // pass 2: exp(x - max) -> Y, sum (tree reduce) + op += SP + SP + SP + "// pass 2: exp(x - max) and sum\n"; + op += SP + SP + SP + "T tsum = static_cast(0);\n"; + op += SP + SP + SP + "for (std::size_t l = tid; l < axis_size; l += " + bs + "u) {\n"; + op += SP + SP + SP + SP + "std::size_t const idx = row_base + l * inner_stride;\n"; + op += SP + SP + SP + SP + "T e = alpaka::math::exp(acc, X[idx] - vmax);\n"; + op += SP + SP + SP + SP + "Y[idx] = e;\n"; + op += SP + SP + SP + SP + "tsum += e;\n"; + op += SP + SP + SP + "}\n"; + op += SP + SP + 
SP + "sdata[tid] = tsum;\n"; + op += SP + SP + SP + "alpaka::syncBlockThreads(acc);\n"; + op += SP + SP + SP + "for (std::size_t s = " + bs + "u / 2u; s >= 1u; s /= 2u) {\n"; + op += SP + SP + SP + SP + "if (tid < s) sdata[tid] += sdata[tid + s];\n"; + op += SP + SP + SP + SP + "alpaka::syncBlockThreads(acc);\n"; + op += SP + SP + SP + "}\n"; + op += SP + SP + SP + "T const sum = sdata[0];\n"; + op += SP + SP + SP + "alpaka::syncBlockThreads(acc);\n\n"; + + // pass 3: normalize + op += SP + SP + SP + "// pass 3: normalize\n"; + op += SP + SP + SP + "T const inv = static_cast(1) / sum;\n"; + op += SP + SP + SP + "for (std::size_t l = tid; l < axis_size; l += " + bs + "u) {\n"; + op += SP + SP + SP + SP + "std::size_t const idx = row_base + l * inner_stride;\n"; + op += SP + SP + SP + SP + "Y[idx] *= inv;\n"; if (fLogSoftmax) - op += SP + SP + SP + SP + SP + "Y[idx] = alpaka::math::log(acc, Y[idx]);\n"; - op += SP + SP + SP + SP + "}\n"; + op += SP + SP + SP + SP + "Y[idx] = alpaka::math::log(acc, Y[idx]);\n"; + op += SP + SP + SP + "}\n"; - op += SP + SP + SP + "}\n";// row loop end op += SP + SP + "}\n";// operator() end op += SP + "};\n"; return op; @@ -275,11 +314,14 @@ public: else num_rows = "(" + length_str + ") / (" + axis_size + ")"; + const size_t kBlock = SoftmaxBlockSize(); // must match the kernel's block size + std::stringstream out; out << "\n//------ SOFTMAX_GPU_ALPAKA\n"; - out << SP << "auto const elementsPerThread_" << opName << " = Vec::all(static_cast(1));\n"; - out << SP << "auto const elementsPerGrid_" << opName << " = Vec::all(Idx{" << num_rows << "});\n"; - out << SP << "auto const workDiv_" << opName << " = sofie_workdiv(elementsPerGrid_" << opName << ");\n"; + out << SP << "alpaka::WorkDivMembers workDiv_" << opName << "(\n"; + out << SP << SP << "Vec::all(static_cast(" << num_rows << ")),\n";// numBlocks = one block per row + out << SP << SP << "Vec::all(Idx{" << kBlock << "u}),\n";// threads per block + out << SP << SP << 
"Vec::all(Idx{1u}));\n"; out << SP << "alpaka::exec(queue, workDiv_" << opName << ", " << kname << ", alpaka::getPtrNative(deviceBuf_" << fNX << ")" From 2e53990f3e9f41d143a7ca1f4482349f60aa4dea Mon Sep 17 00:00:00 2001 From: Harsh Chauhan Date: Tue, 23 Jun 2026 23:48:14 +0530 Subject: [PATCH 3/3] online softmax --- core/inc/SOFIE/ROperator_Softmax.hxx | 73 +++++++++++++++------------- 1 file changed, 39 insertions(+), 34 deletions(-) diff --git a/core/inc/SOFIE/ROperator_Softmax.hxx b/core/inc/SOFIE/ROperator_Softmax.hxx index be65c86..06acc33 100644 --- a/core/inc/SOFIE/ROperator_Softmax.hxx +++ b/core/inc/SOFIE/ROperator_Softmax.hxx @@ -218,9 +218,11 @@ public: const size_t kBlock = SoftmaxBlockSize(); // threads per row (block) std::string bs = std::to_string(kBlock); - // block-per-row: a block of kBlock threads reduces one row cooperatively. - // each thread strides over the row (coalesced for the last-axis case), then - // two shared-memory tree reductions compute the row max and the sum. + // block-per-row online softmax (Milakov & Gimelshein 2018, arXiv:1805.02867). + // a block of kBlock threads reduces one row cooperatively; each thread strides + // over the row (coalesced for the last-axis case) keeping a running (max, sum) + // pair, then a single shared-memory tree reduction merges the pairs with the online operator + std::string op; op = "\n//------ SOFTMAX_KERNEL_ALPAKA\n"; op += SP + "struct " + kname + " {\n"; @@ -231,7 +233,10 @@ public: op += SP + SP + SP + "T* __restrict__ Y,\n"; op += SP + SP + SP + "std::size_t const numRows) const {\n\n"; - op += SP + SP + SP + "auto& sdata = alpaka::declareSharedVar(acc);\n"; + // running max and sum live in shared memory; declared before the early return + // and every thread in the block reaches the collective declaration. 
+ op += SP + SP + SP + "auto& smax = alpaka::declareSharedVar(acc);\n"; + op += SP + SP + SP + "auto& ssum = alpaka::declareSharedVar(acc);\n"; op += SP + SP + SP + "auto const row = alpaka::getIdx(acc)[0];\n"; op += SP + SP + SP + "auto const tid = alpaka::getIdx(acc)[0];\n"; op += SP + SP + SP + "if (row >= numRows) return;\n\n"; @@ -241,48 +246,47 @@ public: op += SP + SP + SP + "std::size_t const row_block = axis_size * inner_stride;\n"; op += SP + SP + SP + "std::size_t const row_base = (row / inner_stride) * row_block + (row % inner_stride);\n\n"; - // pass 1: row max (partial per thread, then shared-memory tree reduce) - op += SP + SP + SP + "// pass 1: row max\n"; - op += SP + SP + SP + "T tmax = X[row_base];\n"; + // fused pass: running max m and normalizer d over this thread's slice + // d is updated branchlessly; when the max does not move the correction will simply be exp(0)=1 + op += SP + SP + SP + "// fused pass: running (max, sum) per thread\n"; + op += SP + SP + SP + "T m = X[row_base];\n"; + op += SP + SP + SP + "T d = static_cast(0);\n"; op += SP + SP + SP + "for (std::size_t l = tid; l < axis_size; l += " + bs + "u) {\n"; - op += SP + SP + SP + SP + "T v = X[row_base + l * inner_stride];\n"; - op += SP + SP + SP + SP + "if (v > tmax) tmax = v;\n"; + op += SP + SP + SP + SP + "T x = X[row_base + l * inner_stride];\n"; + op += SP + SP + SP + SP + "T m_new = (x > m) ? 
x : m;\n"; + op += SP + SP + SP + SP + "d = d * alpaka::math::exp(acc, m - m_new) + alpaka::math::exp(acc, x - m_new);\n"; + op += SP + SP + SP + SP + "m = m_new;\n"; op += SP + SP + SP + "}\n"; - op += SP + SP + SP + "sdata[tid] = tmax;\n"; - op += SP + SP + SP + "alpaka::syncBlockThreads(acc);\n"; - op += SP + SP + SP + "for (std::size_t s = " + bs + "u / 2u; s >= 1u; s /= 2u) {\n"; - op += SP + SP + SP + SP + "if (tid < s && sdata[tid + s] > sdata[tid]) sdata[tid] = sdata[tid + s];\n"; - op += SP + SP + SP + SP + "alpaka::syncBlockThreads(acc);\n"; - op += SP + SP + SP + "}\n"; - op += SP + SP + SP + "T const vmax = sdata[0];\n"; + op += SP + SP + SP + "smax[tid] = m;\n"; + op += SP + SP + SP + "ssum[tid] = d;\n"; op += SP + SP + SP + "alpaka::syncBlockThreads(acc);\n\n"; - // pass 2: exp(x - max) -> Y, sum (tree reduce) - op += SP + SP + SP + "// pass 2: exp(x - max) and sum\n"; - op += SP + SP + SP + "T tsum = static_cast(0);\n"; - op += SP + SP + SP + "for (std::size_t l = tid; l < axis_size; l += " + bs + "u) {\n"; - op += SP + SP + SP + SP + "std::size_t const idx = row_base + l * inner_stride;\n"; - op += SP + SP + SP + SP + "T e = alpaka::math::exp(acc, X[idx] - vmax);\n"; - op += SP + SP + SP + SP + "Y[idx] = e;\n"; - op += SP + SP + SP + SP + "tsum += e;\n"; - op += SP + SP + SP + "}\n"; - op += SP + SP + SP + "sdata[tid] = tsum;\n"; - op += SP + SP + SP + "alpaka::syncBlockThreads(acc);\n"; + // single tree reduction merging (max, sum) pairs with the online operator + // (m_a, d_a) + (m_b, d_b) = (max, d_a*exp(m_a-max) + d_b*exp(m_b-max)) + op += SP + SP + SP + "// combined (max, sum) tree reduction\n"; op += SP + SP + SP + "for (std::size_t s = " + bs + "u / 2u; s >= 1u; s /= 2u) {\n"; - op += SP + SP + SP + SP + "if (tid < s) sdata[tid] += sdata[tid + s];\n"; + op += SP + SP + SP + SP + "if (tid < s) {\n"; + op += SP + SP + SP + SP + SP + "T m_a = smax[tid];\n"; + op += SP + SP + SP + SP + SP + "T m_b = smax[tid + s];\n"; + op += SP + SP + SP + SP + 
SP + "T m_r = (m_b > m_a) ? m_b : m_a;\n"; + op += SP + SP + SP + SP + SP + "ssum[tid] = ssum[tid] * alpaka::math::exp(acc, m_a - m_r) + ssum[tid + s] * alpaka::math::exp(acc, m_b - m_r);\n"; + op += SP + SP + SP + SP + SP + "smax[tid] = m_r;\n"; + op += SP + SP + SP + SP + "}\n"; op += SP + SP + SP + SP + "alpaka::syncBlockThreads(acc);\n"; op += SP + SP + SP + "}\n"; - op += SP + SP + SP + "T const sum = sdata[0];\n"; + op += SP + SP + SP + "T const vmax = smax[0];\n"; + op += SP + SP + SP + "T const sum = ssum[0];\n"; op += SP + SP + SP + "alpaka::syncBlockThreads(acc);\n\n"; - // pass 3: normalize - op += SP + SP + SP + "// pass 3: normalize\n"; + // normalize pass: recompute exp(x - max) and write Y once + op += SP + SP + SP + "// normalize pass\n"; op += SP + SP + SP + "T const inv = static_cast(1) / sum;\n"; op += SP + SP + SP + "for (std::size_t l = tid; l < axis_size; l += " + bs + "u) {\n"; op += SP + SP + SP + SP + "std::size_t const idx = row_base + l * inner_stride;\n"; - op += SP + SP + SP + SP + "Y[idx] *= inv;\n"; + op += SP + SP + SP + SP + "T e = alpaka::math::exp(acc, X[idx] - vmax) * inv;\n"; + op += SP + SP + SP + SP + "Y[idx] = e;\n"; if (fLogSoftmax) - op += SP + SP + SP + SP + "Y[idx] = alpaka::math::log(acc, Y[idx]);\n"; + op += SP + SP + SP + SP + "Y[idx] = alpaka::math::log(acc, e);\n"; op += SP + SP + SP + "}\n"; op += SP + SP + "}\n";// operator() end @@ -322,11 +326,12 @@ public: out << SP << SP << "Vec::all(static_cast(" << num_rows << ")),\n";// numBlocks = one block per row out << SP << SP << "Vec::all(Idx{" << kBlock << "u}),\n";// threads per block out << SP << SP << "Vec::all(Idx{1u}));\n"; - out << SP << "alpaka::exec(queue, workDiv_" << opName + out << SP << "auto task_" << opName << " = alpaka::createTaskKernel(workDiv_" << opName << ", " << kname << ", alpaka::getPtrNative(deviceBuf_" << fNX << ")" << ", alpaka::getPtrNative(deviceBuf_" << fNY << ")" << ", static_cast(" << num_rows << "));\n"; + out << SP << 
"alpaka::enqueue(queue, task_" << opName << ");\n"; return out.str(); }