Skip to content

Commit 45a5e01

Browse files
committed
Tested on an NVIDIA GeForce RTX 4090 GPU with CUDA 12.x.
1 parent b2e34d0 commit 45a5e01

File tree

7 files changed

+544
-497
lines changed

7 files changed

+544
-497
lines changed

mlir/cuda-tile/.gitignore

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,18 @@
11
*.ptx
22
*.cubin
33
*.fatbin
4+
*.bc
5+
*.ll
6+
*.o
7+
*.s
8+
*.so
9+
*.dylib
10+
*.a
11+
*.dll
12+
*.obj
13+
*.exe
14+
*.log
15+
*.cache
16+
*.tmp
17+
*.bin
18+
*.out

mlir/cuda-tile/README.md

Lines changed: 270 additions & 464 deletions
Large diffs are not rendered by default.

mlir/cuda-tile/Toy/cuda_wrapper/cuda_shim.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ cuda_shim_load_module_from_file(uint64_t file_path_ptr,
330330
uint64_t /*file_path_nbytes*/) {
331331
auto file_path_cstr =
332332
reinterpret_cast<const char *>(asHostCPtr(file_path_ptr));
333-
// fprintf(stdout, "%s", file_path_cstr);
333+
debug_print("Loading CUDA module from file: %s\n", file_path_cstr);
334334
CUmodule module = nullptr;
335335
ScopedContext scopedContext;
336336
CUDA_REPORT_IF_ERROR(cuModuleLoad(&module, file_path_cstr));
@@ -519,7 +519,7 @@ extern "C" void cuda_shim_ctx_synchronize(void) { mgpuCtxSynchronize(); }
519519

520520
// only for debugging
521521
extern "C" void cuda_debug_dump_float(uint64_t dptr, int n) {
522-
auto *p = reinterpret_cast<const float*>(static_cast<uintptr_t>(dptr));
522+
auto *p = reinterpret_cast<const float *>(static_cast<uintptr_t>(dptr));
523523
for (uint32_t i = 0; i < n; ++i) {
524524
fprintf(stderr, "i=%u v=%f\n", i, p[i]);
525525
}

mlir/cuda-tile/Toy/include/cuda_shim/CudaShimBuilder.hpp

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "mlir/Dialect/Arith/IR/Arith.h"
1010
#include "mlir/Dialect/Func/IR/FuncOps.h"
1111
#include "mlir/Dialect/MemRef/IR/MemRef.h"
12+
#include "mlir/ExecutionEngine/ExecutionEngine.h"
1213
#include "mlir/IR/Builders.h"
1314
#include "mlir/IR/BuiltinAttributes.h"
1415
#include "mlir/IR/BuiltinOps.h"
@@ -296,3 +297,134 @@ inline unsigned long getNbytes(mlir::Type tensorType) {
296297
ranked_tensor_type.getElementTypeBitWidth(),
297298
8);
298299
}
// ---------------------------------------------------------------------------
// C ABI of the CUDA shim runtime (implemented in cuda_shim.cpp).
// Device pointers and opaque handles (modules, streams, events) cross the
// boundary as uint64_t so JIT-compiled code needs no CUDA headers.
// ---------------------------------------------------------------------------
extern "C" {
// Load a module from a PTX or CUBIN image in memory. The driver API's
// cuModuleLoadDataEx accepts both formats and auto-detects which one it got.
uint64_t cuda_shim_load_module_from_image(uint64_t image_ptr,
                                          uint64_t image_nbytes);
// As above, but JIT-compiles PTX at the requested optimization level.
uint64_t cuda_shim_load_module_jit_from_image(uint64_t image_ptr,
                                              uint64_t image_nbytes,
                                              int opt_level);

// Load a module from a file path (NUL-terminated string; length is unused).
uint64_t cuda_shim_load_module_from_file(uint64_t file_path_ptr,
                                         uint64_t /*file_path_nbytes*/);

void cuda_shim_unload_module(uint64_t module_handle);

// ----------------------------- Memory --------------------------------------
uint64_t cuda_shim_malloc(uint64_t nbytes, uint64_t stream,
                          bool is_host_shared);

void cuda_shim_free(uint64_t dptr, uint64_t stream);

void cuda_shim_memset32(uint64_t dptr, uint32_t value, uint64_t count_dwords,
                        uint64_t stream);
// NOTE(review): the parameter is named count_dwords, but for a 16-bit memset
// it presumably counts 16-bit elements — confirm against cuda_shim.cpp.
void cuda_shim_memset16(uint64_t dptr, uint32_t value, uint64_t count_dwords,
                        uint64_t stream);

// ----------------------------- Streams -------------------------------------
uint64_t cuda_shim_stream_create(void);

void cuda_shim_stream_destroy(uint64_t stream);

void cuda_shim_stream_synchronize(uint64_t stream);

// ----------------------------- Events --------------------------------------
uint64_t cuda_shim_event_create(void);

void cuda_shim_event_destroy(uint64_t ev);

void cuda_shim_event_record(uint64_t ev, uint64_t stream);

void cuda_shim_event_synchronize(uint64_t ev);

void cuda_shim_stream_wait_event(uint64_t stream, uint64_t ev);

// ----------------------------- Memcpy (raw ABI) ----------------------------
// Host pointers are passed as uint64_t as well. This is the key of 2A.
void cuda_shim_memcpy_h2d(uint64_t dst_dptr, uint64_t src_hptr,
                          uint64_t nbytes);

void cuda_shim_memcpy_d2h(uint64_t dst_hptr, uint64_t src_dptr,
                          uint64_t nbytes);

// ----------------------------- Launch --------------------------------------
// Launch a kernel from a loaded module with a packed argument buffer:
// arg_data_ptr points at the raw argument bytes, arg_sizes_ptr at an array of
// per-argument byte sizes, num_args is the argument count.
void cuda_shim_launch_packed(uint64_t module_handle, uint64_t kernel_name_ptr,
                             uint32_t gridX, uint32_t gridY, uint32_t gridZ,
                             uint32_t blockX, uint32_t blockY, uint32_t blockZ,
                             uint32_t sharedMemBytes, uint64_t stream,
                             uint64_t arg_data_ptr, uint64_t arg_sizes_ptr,
                             uint32_t num_args);

// Convenience variant: takes no grid dimensions (presumably a single-block
// launch — confirm in cuda_shim.cpp), shared=0, stream optional.
void cuda_shim_launch_block_packed(uint64_t module_handle,
                                   uint64_t kernel_name_ptr, uint32_t blockX,
                                   uint32_t blockY, uint32_t blockZ,
                                   uint64_t stream, uint64_t arg_data_ptr,
                                   uint64_t arg_sizes_ptr, uint32_t num_args);

// Optional global sync (avoid in async pipelines; prefer event/stream sync).
void cuda_shim_ctx_synchronize(void);

// Only for debugging: dump n floats starting at dptr to stderr.
void cuda_debug_dump_float(uint64_t dptr, int n);
}
371+
372+
// Build the ORC symbol map that exposes every cuda_shim_* entry point to
// JIT-compiled code. Keep this list in sync with the extern "C" block above:
// a shim that is declared but not registered here fails to resolve at JIT
// link time.
static inline llvm::orc::SymbolMap
buildCudaShimSymbolMap(llvm::orc::MangleAndInterner interner) {

  using llvm::JITSymbolFlags;
  using llvm::orc::ExecutorAddr;
  using llvm::orc::ExecutorSymbolDef;
  using llvm::orc::SymbolMap;

  SymbolMap syms;

  // Register one exported symbol under its unmangled C name.
  auto add = [&](const char *name, void *addr) {
    syms[interner(name)] =
        ExecutorSymbolDef::fromPtr(addr, JITSymbolFlags::Exported);
  };

  // ---- ctx ----
  add("cuda_shim_ctx_synchronize", (void *)&cuda_shim_ctx_synchronize);

  // ---- module ----
  add("cuda_shim_load_module_from_image",
      (void *)&cuda_shim_load_module_from_image);
  add("cuda_shim_load_module_jit_from_image",
      (void *)&cuda_shim_load_module_jit_from_image);
  add("cuda_shim_load_module_from_file",
      (void *)&cuda_shim_load_module_from_file);
  add("cuda_shim_unload_module", (void *)&cuda_shim_unload_module);

  // ---- memory ----
  add("cuda_shim_malloc", (void *)&cuda_shim_malloc);
  add("cuda_shim_free", (void *)&cuda_shim_free);
  // Fix: these two were declared in the shim ABI but never registered, so any
  // lowered IR calling them would fail symbol resolution in the JIT.
  add("cuda_shim_memset32", (void *)&cuda_shim_memset32);
  add("cuda_shim_memset16", (void *)&cuda_shim_memset16);

  // ---- memcpy ----
  add("cuda_shim_memcpy_h2d", (void *)&cuda_shim_memcpy_h2d);
  add("cuda_shim_memcpy_d2h", (void *)&cuda_shim_memcpy_d2h);

  // ---- stream ----
  add("cuda_shim_stream_create", (void *)&cuda_shim_stream_create);
  add("cuda_shim_stream_destroy", (void *)&cuda_shim_stream_destroy);
  add("cuda_shim_stream_synchronize", (void *)&cuda_shim_stream_synchronize);

  // ---- event ----
  add("cuda_shim_event_create", (void *)&cuda_shim_event_create);
  add("cuda_shim_event_destroy", (void *)&cuda_shim_event_destroy);
  add("cuda_shim_event_record", (void *)&cuda_shim_event_record);
  add("cuda_shim_event_synchronize", (void *)&cuda_shim_event_synchronize);
  add("cuda_shim_stream_wait_event", (void *)&cuda_shim_stream_wait_event);

  // ---- launch ----
  add("cuda_shim_launch_packed", (void *)&cuda_shim_launch_packed);
  add("cuda_shim_launch_block_packed", (void *)&cuda_shim_launch_block_packed);

  // ---- debug ----
  // Fix: also missing before; register so debug dumps work from JIT'd code.
  add("cuda_debug_dump_float", (void *)&cuda_debug_dump_float);

  return syms;
}
425+
426+
// Install the cuda_shim_* symbols into an ExecutionEngine so that JIT'd code
// can call into the shim runtime. Call once after the engine is created.
static inline void registerCudaShimSymbols(mlir::ExecutionEngine &engine) {
  auto symbolProvider = [](llvm::orc::MangleAndInterner interner) {
    return buildCudaShimSymbolMap(interner);
  };
  engine.registerSymbols(symbolProvider);
}

mlir/cuda-tile/Toy/include/toy/Passes.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ std::unique_ptr<mlir::Pass> createGpuOutlinePass(std::string grid = "1,1,1");
3434

3535
std::unique_ptr<mlir::Pass> createCudaTileLoweringPass();
3636

37-
std::unique_ptr<mlir::Pass>
38-
createEmbedCudaTileBinaryPass(std::string tileirasExe = "tileiras",
39-
std::string gpuName = "sm_120");
37+
std::unique_ptr<mlir::Pass> createEmbedCudaTileBinaryPass(
38+
std::string tileirasExe = "tileiras", std::string gpuName = "sm_120",
39+
std::string cubinOrPtxPath = "", bool useCache = true);
4040

4141
} // namespace toy
4242
} // namespace mlir

mlir/cuda-tile/Toy/mlir/EmitCudaTile.cpp

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,12 @@
77
#include "toy/Dialect.h"
88
#include "llvm/ADT/SmallVector.h"
99
#include "llvm/ADT/StringRef.h"
10+
#include "llvm/Support/DebugLog.h"
1011
#include "llvm/Support/FileSystem.h"
1112
#include "llvm/Support/MemoryBuffer.h"
1213
#include "llvm/Support/Program.h"
1314
#include "llvm/Support/raw_ostream.h"
15+
#include <string>
1416
#include <system_error>
1517

1618
using namespace llvm;
@@ -84,9 +86,13 @@ struct EmbedCudaTileBinaryPass
8486

8587
std::string tileirasExe;
8688
std::string gpuName;
89+
std::string cubinOrPtxPath;
90+
bool useCache;
8791

88-
EmbedCudaTileBinaryPass(std::string tileirasExe, std::string gpuName)
89-
: tileirasExe(std::move(tileirasExe)), gpuName(std::move(gpuName)) {}
92+
EmbedCudaTileBinaryPass(std::string tileirasExe, std::string gpuName,
93+
std::string cubinOrPtxPath, bool useCache)
94+
: tileirasExe(std::move(tileirasExe)), gpuName(std::move(gpuName)),
95+
cubinOrPtxPath(std::move(cubinOrPtxPath)), useCache(useCache) {}
9096

9197
void runOnOperation() override {
9298
ModuleOp top = getOperation();
@@ -126,13 +132,38 @@ struct EmbedCudaTileBinaryPass
126132
return;
127133
}
128134

129-
if (std::error_code ec =
130-
createTemporaryFile(cudaBinPath, "cuda_tile", "bin")) {
131-
op->emitError() << "failed to create temp out bin: " << ec.message();
132-
signalPassFailure();
135+
if (cubinOrPtxPath.empty()) {
136+
if (std::error_code ec =
137+
createTemporaryFile(cudaBinPath, "cuda_tile", "bin")) {
138+
op->emitError() << "failed to create temp out bin: " << ec.message();
139+
signalPassFailure();
140+
return;
141+
}
142+
} else {
143+
if (!useCache) {
144+
if (llvm::sys::fs::exists(cubinOrPtxPath)) {
145+
op->emitWarning() << "cuda binary file exist " << cubinOrPtxPath
146+
<< ", tileiras will overwrite it.";
147+
std::error_code ec = llvm::sys::fs::remove(cubinOrPtxPath);
148+
if (ec) {
149+
op->emitError() << "failed to remove existing cuda binary file: "
150+
<< ec.message();
151+
signalPassFailure();
152+
return;
153+
}
154+
}
155+
}
156+
cudaBinPath = cubinOrPtxPath;
157+
}
158+
159+
if (useCache && llvm::sys::fs::exists(cudaBinPath)) {
160+
LDBG() << "cuda binary file exist and will be reused: " << cudaBinPath
161+
<< "\n";
133162
return;
134163
}
135164

165+
// ! [FIXME]: please comment out this following code since this is only
166+
// for testing.
136167
if (failed(writeFileBytes(inPath, tilebcBytes))) {
137168
op->emitError() << "failed to write temp tilebc";
138169
signalPassFailure();
@@ -145,6 +176,8 @@ struct EmbedCudaTileBinaryPass
145176
}
146177
});
147178

179+
LDBG() << "cuda binary path: " << cudaBinPath << "\n";
180+
148181
top->walk([&](toy::LaunchGpuOp launchOp) {
149182
// ---- Step D: read cuda binary bytes ----
150183
auto binBytesOrErr = readFileBytes(cudaBinPath);
@@ -189,8 +222,10 @@ struct EmbedCudaTileBinaryPass
189222
namespace mlir::toy {
190223

191224
std::unique_ptr<mlir::Pass>
192-
createEmbedCudaTileBinaryPass(std::string tileirasExe, std::string gpuName) {
193-
return std::make_unique<EmbedCudaTileBinaryPass>(tileirasExe, gpuName);
225+
createEmbedCudaTileBinaryPass(std::string tileirasExe, std::string gpuName,
226+
std::string cubinOrPtxPath, bool useCache) {
227+
return std::make_unique<EmbedCudaTileBinaryPass>(tileirasExe, gpuName,
228+
cubinOrPtxPath, useCache);
194229
};
195230

196231
}; // namespace mlir::toy

0 commit comments

Comments
 (0)