Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 152 additions & 0 deletions core/inc/SOFIE/ROperator_TopK.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,158 @@ public:
out << SP << "}\n"; // end operator scope
return out.str();
}

// next power of two >= the axis length (the bitonic network needs a power-of-2 size)
size_t TopKPaddedAxis() const {
size_t axis = fAttrAxis < 0 ? fShapeX.size() + fAttrAxis : fAttrAxis;
size_t n = fShapeX[axis];
size_t p = 1; while (p < n) p <<= 1;
return p;
}
// threads per block: cover the paddedN/2 comparator pairs, capped at 1024, one warp
// kernel and launch both call this
size_t TopKBlockThreads() const {
size_t pairs = TopKPaddedAxis() / 2;
size_t bt = (pairs < 1024) ? pairs : 1024;
if (bt < 32) bt = 32;
return bt;
}

// We have one block per slice. Cache the row in shared memory, bitonic-sort it
// best-first (indices ride along for the tie-break), then write the first K.
std::string Generate_GPU_Kernel_ALPAKA(std::string /*opName*/) override {
if (fShapeX.empty())
throw std::runtime_error("SOFIE Operator TopK called to Generate without being initialized first");

size_t axis = fAttrAxis < 0 ? fShapeX.size() + fAttrAxis : fAttrAxis;
std::string NE = std::to_string(fShapeX[axis]); // real axis length
std::string PAD = std::to_string(TopKPaddedAxis()); // next power of two >= NE
std::string PADH = std::to_string(TopKPaddedAxis() / 2); // number of comparator pairs
std::string BT = std::to_string(TopKBlockThreads()); // threads per block
std::string K = std::to_string(fK);
std::string OP = fAttrLargest ? ">" : "<"; // best-first value comparator
std::string SENT = fAttrLargest ? "std::numeric_limits<T>::lowest()"
: "std::numeric_limits<T>::max()"; // padded slots never win
std::string kname = "TopKKernel_" + fNVal;

// shared-memory budget guard (caches the whole padded row)
size_t valBytes = (fType == "double" || fType == "int64_t") ? 8 : 4;
if (TopKPaddedAxis() * (valBytes + 8) > 48u * 1024u)
throw std::runtime_error("SOFIE TopK GPU: axis length " + NE +
" too long for shared-memory bitonic top-K");

std::string op;
op = "\n//------ TopK_KERNEL_ALPAKA (block-per-row bitonic)\n";
op += SP + "struct " + kname + " {\n";
op += SP + SP + "template<typename TAcc, typename T>\n";
op += SP + SP + "ALPAKA_FN_ACC void operator()(\n";
op += SP + SP + SP + "TAcc const& acc,\n";
op += SP + SP + SP + "T const* __restrict__ x,\n";
op += SP + SP + SP + "T* __restrict__ vals,\n";
op += SP + SP + SP + "int64_t* __restrict__ inds,\n";
op += SP + SP + SP + "std::size_t const numSlices,\n";
op += SP + SP + SP + "std::size_t const nAfter,\n";
op += SP + SP + SP + "std::size_t const strideXAxis,\n";
op += SP + SP + SP + "std::size_t const strideXBefore,\n";
op += SP + SP + SP + "std::size_t const strideYAxis,\n";
op += SP + SP + SP + "std::size_t const strideYBefore) const {\n\n";

// shared row buffers (values + indices), declared before the early return
op += SP + SP + SP + "auto& sv = alpaka::declareSharedVar<T[" + PAD + "], __COUNTER__>(acc);\n";
op += SP + SP + SP + "auto& si = alpaka::declareSharedVar<int64_t[" + PAD + "], __COUNTER__>(acc);\n";
op += SP + SP + SP + "auto const slice = alpaka::getIdx<alpaka::Grid, alpaka::Blocks>(acc)[0];\n";
op += SP + SP + SP + "auto const tid = alpaka::getIdx<alpaka::Block, alpaka::Threads>(acc)[0];\n";
op += SP + SP + SP + "if (slice >= numSlices) return;\n\n";

op += SP + SP + SP + "std::size_t const ib = slice / nAfter;\n";
op += SP + SP + SP + "std::size_t const jb = slice % nAfter;\n";
op += SP + SP + SP + "std::size_t const xbase = ib * strideXBefore + jb;\n";
op += SP + SP + SP + "std::size_t const ybase = ib * strideYBefore + jb;\n\n";

// 1) load row into shared, pad the tail with a sentinel that never wins
op += SP + SP + SP + "for (std::size_t l = tid; l < " + PAD + "u; l += " + BT + "u) {\n";
op += SP + SP + SP + SP + "if (l < " + NE + "u) { sv[l] = x[xbase + strideXAxis * l]; si[l] = (int64_t)l; }\n";
op += SP + SP + SP + SP + "else { sv[l] = " + SENT + "; si[l] = (int64_t)" + NE + "; }\n";
op += SP + SP + SP + "}\n";
op += SP + SP + SP + "alpaka::syncBlockThreads(acc);\n\n";

// 2) bitonic sort: kk = bitonic seq size, jj = compare distance, best ends at index 0.
// each thread owns ONE comparator pair (i, i^jj)
op += SP + SP + SP + "for (std::size_t kk = 2u; kk <= " + PAD + "u; kk <<= 1) {\n";
op += SP + SP + SP + SP + "for (std::size_t jj = kk >> 1; jj > 0u; jj >>= 1) {\n";
op += SP + SP + SP + SP + SP + "for (std::size_t t = tid; t < " + PADH + "u; t += " + BT + "u) {\n";
op += SP + SP + SP + SP + SP + SP + "std::size_t const i = ((t & ~(jj - 1u)) << 1) | (t & (jj - 1u));\n";
op += SP + SP + SP + SP + SP + SP + "std::size_t const p = i | jj;\n";
op += SP + SP + SP + SP + SP + SP + "T av = sv[i]; T bv = sv[p];\n";
op += SP + SP + SP + SP + SP + SP + "int64_t ai = si[i]; int64_t bi = si[p];\n";
op += SP + SP + SP + SP + SP + SP + "bool const firstFirst = (av " + OP + " bv) || (av == bv && ai < bi);\n";
op += SP + SP + SP + SP + SP + SP + "bool const dir = ((i & kk) == 0u);\n";
op += SP + SP + SP + SP + SP + SP + "bool const sw = (firstFirst != dir);\n";
op += SP + SP + SP + SP + SP + SP + "sv[i] = sw ? bv : av; sv[p] = sw ? av : bv;\n";
op += SP + SP + SP + SP + SP + SP + "si[i] = sw ? bi : ai; si[p] = sw ? ai : bi;\n";
op += SP + SP + SP + SP + SP + "}\n";
op += SP + SP + SP + SP + SP + "alpaka::syncBlockThreads(acc);\n";
op += SP + SP + SP + SP + "}\n";
op += SP + SP + SP + "}\n\n";

// 3) write top-K (already best-first)
op += SP + SP + SP + "for (std::size_t s = tid; s < " + K + "u; s += " + BT + "u) {\n";
op += SP + SP + SP + SP + "vals[ybase + strideYAxis * s] = sv[s];\n";
op += SP + SP + SP + SP + "inds[ybase + strideYAxis * s] = si[s];\n";
op += SP + SP + SP + "}\n";

op += SP + SP + "}\n";// end operator()
op += SP + "};\n";// end struct
return op;
}

std::string Generate_GPU_Kernel_Definitions_ALPAKA(std::string /*opName*/) override {
return SP + "TopKKernel_" + fNVal + " topKernel_" + fNVal + ";\n";
}

std::vector<std::string> GetStdLibs() override {
return { std::string("limits") };
}

// the geometry is computed here at codegen and passed as args matching the kernel signature.
std::string Generate_GPU_ALPAKA(std::string opName) override {
opName = "op_" + opName;
if (fShapeX.empty())
throw std::runtime_error("SOFIE Operator TopK called to Generate without being initialized first");

size_t axis = fAttrAxis < 0 ? fShapeX.size() + fAttrAxis : fAttrAxis;
size_t length = ConvertShapeToLength(fShapeX);
auto strideX = UTILITY::ComputeStrideFromShape(fShapeX);
auto strideY = UTILITY::ComputeStrideFromShape(fShapeY); // output is shorter along axis (K, not N_EL)
size_t n_after = strideX[axis];
size_t n_before = (axis > 0) ? length / strideX[axis-1] : 1;
size_t numSlices = n_before * n_after;
size_t strideX_axis = strideX[axis];
size_t strideY_axis = strideY[axis];
size_t strideX_before = (axis > 0) ? strideX[axis-1] : 0; // 0 is safe: i==0 when axis==0
size_t strideY_before = (axis > 0) ? strideY[axis-1] : 0;

std::stringstream out;
out << "\n//-- TopK_GPU_ALPAKA\n";
const size_t blockThreads = TopKBlockThreads();
out << SP << "alpaka::WorkDivMembers<Dim, Idx> workDiv_" << fNVal << "(\n";
out << SP << SP << "Vec::all(static_cast<Idx>(" << numSlices << ")),\n";
out << SP << SP << "Vec::all(Idx{" << blockThreads << "u}),\n";
out << SP << SP << "Vec::all(Idx{1u}));\n";
out << SP << "auto task_" << fNVal << " = alpaka::createTaskKernel<Acc>(workDiv_" << fNVal << ", topKernel_" << fNVal
<< ", alpaka::getPtrNative(deviceBuf_" << fNX << ")"
<< ", alpaka::getPtrNative(deviceBuf_" << fNVal << ")"
<< ", alpaka::getPtrNative(deviceBuf_" << fNInd << ")"
<< ", static_cast<std::size_t>(" << numSlices << "u)"
<< ", static_cast<std::size_t>(" << n_after << "u)"
<< ", static_cast<std::size_t>(" << strideX_axis << "u)"
<< ", static_cast<std::size_t>(" << strideX_before<< "u)"
<< ", static_cast<std::size_t>(" << strideY_axis << "u)"
<< ", static_cast<std::size_t>(" << strideY_before<< "u));\n";
out << SP << "alpaka::enqueue(queue, task_" << fNVal << ");\n";
return out.str();
}

};

} // nameSPace SOFIE
Expand Down
42 changes: 42 additions & 0 deletions test/TestCustomModelsFromONNXForAlpakaCuda.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,9 @@
#include "Clip_FromONNX_GPU_ALPAKA.hxx"
#include "Not_FromONNX_GPU_ALPAKA.hxx"

#include "TopK_FromONNX_GPU_ALPAKA.hxx"
#include "input_models/references/TopK.ref.hxx"

#include "GNN_model_FromONNX_GPU_ALPAKA.hxx"

#include <alpaka/alpaka.hpp>
Expand Down Expand Up @@ -3161,3 +3164,42 @@ TEST_F(SofieAlpakaTest, Logic_BitwiseNot)
for (std::size_t i = 0; i < N; ++i)
EXPECT_EQ(res[i], ref[i]) << " index=" << i;
}

TEST_F(SofieAlpakaTest, TopK)
{
constexpr float TOLERANCE = DEFAULT_TOLERANCE;

// axis=-1, largest=1, sorted=1, k=5 (baked); input is a single 9-element row
std::vector<float> input {9.0, 8.0, 4.5, 1.7, 2.9, 3.2, 4.0, 2.6, 7.4};
constexpr std::size_t K = 5;

auto input_h = alpaka::allocBuf<float, Idx>(host, Ext1D::all(Idx{input.size()}));
float* input_ptr = reinterpret_cast<float*>(alpaka::getPtrNative(input_h));
for (Idx i = 0; i < input.size(); ++i) input_ptr[i] = input[i];

auto input_d = alpaka::allocBuf<float, Idx>(device, Ext1D::all(Idx{input.size()}));
alpaka::memcpy(queue, input_d, input_h);
alpaka::wait(queue);

auto values_h = alpaka::allocBuf<float, Idx>(host, Ext1D::all(Idx{K}));
auto indices_h = alpaka::allocBuf<int64_t, Idx>(host, Ext1D::all(Idx{K}));

{
SOFIE_TopK::Session<alpaka::TagGpuCudaRt> session;
auto [values, indices] = session.infer(input_d);
alpaka::wait(queue);
cudaDeviceSynchronize();

alpaka::memcpy(queue, values_h, values);
alpaka::memcpy(queue, indices_h, indices);
alpaka::wait(queue);
}

float* val = reinterpret_cast<float*>(alpaka::getPtrNative(values_h));
int64_t* idx = reinterpret_cast<int64_t*>(alpaka::getPtrNative(indices_h));

for (std::size_t i = 0; i < K; ++i) {
EXPECT_LE(std::abs(val[i] - TopK_ExpectedOutput::values[i]), TOLERANCE) << " value index=" << i;
EXPECT_EQ(idx[i], static_cast<int64_t>(TopK_ExpectedOutput::indexes[i])) << " index index=" << i;
}
}