Skip to content

Commit 96a122c

Browse files
committed
Tested with cuda 12.x cubin worked
1 parent dbf81a7 commit 96a122c

5 files changed

Lines changed: 285 additions & 27 deletions

File tree

mlir/cuda-tile/Toy/include/cuda_shim/CudaShimBuilder.hpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -174,19 +174,19 @@ class CudaShimRegistry {
174174
},
175175
{})};
176176

177-
// case CudaShimFn::LaunchBlockPacked:
178-
// return {"cuda_shim_launch_block_packed",
179-
// rewriter.getFunctionType(
180-
// {
181-
// i64, // module_handle
182-
// i64, // kernel_name_ptr
183-
// i32, i32, i32, // block
184-
// i64, // stream
185-
// i64, // arg_data_ptr
186-
// i64, // arg_sizes_ptr
187-
// i32 // num_args
188-
// },
189-
// {})};
177+
case CudaShimFn::LaunchBlockPacked:
178+
return {"cuda_shim_launch_block_packed",
179+
rewriter.getFunctionType(
180+
{
181+
i64, // module_handle
182+
i64, // kernel_name_ptr
183+
i32, i32, i32, // block
184+
i64, // stream
185+
i64, // arg_data_ptr
186+
i64, // arg_sizes_ptr
187+
i32 // num_args
188+
},
189+
{})};
190190

191191
// ===== Context =====
192192
case CudaShimFn::CtxSynchronize:

mlir/cuda-tile/Toy/mlir/LowerToAffineLoops.cpp

Lines changed: 108 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "mlir/IR/DialectRegistry.h"
2424
#include "mlir/IR/Operation.h"
2525
#include "mlir/IR/PatternMatch.h"
26+
#include "mlir/IR/Types.h"
2627
#include "mlir/IR/Value.h"
2728
#include "mlir/IR/ValueRange.h"
2829
#include "mlir/Support/LLVM.h"
@@ -467,6 +468,13 @@ createCallToCudaShimMalloc(mlir::PatternRewriter &rewriter, Location loc,
467468
return callee;
468469
}
469470

471+
unsigned long getNbytes(Type tensorType) {
472+
auto ranked_tensor_type = llvm::cast<MemRefType>(tensorType);
473+
return llvm::divideCeil(ranked_tensor_type.getNumElements() *
474+
ranked_tensor_type.getElementTypeBitWidth(),
475+
8);
476+
}
477+
470478
struct LanchGpuLowering : public OpConversionPattern<toy::LaunchGpuOp> {
471479
using OpConversionPattern<toy::LaunchGpuOp>::OpConversionPattern;
472480

@@ -565,37 +573,126 @@ struct LanchGpuLowering : public OpConversionPattern<toy::LaunchGpuOp> {
565573
}
566574
cudaAllInputs.push_back(outputTensorAlloc);
567575
mlir::func::CallOp memcpyH2DCall;
576+
577+
// ---------- Build argSlots / argSizes from host side ----------
578+
auto argSlots =
579+
memref::AllocOp::create(rewriter, loc,
580+
MemRefType::get({(int64_t)cudaAllInputs.size()},
581+
rewriter.getI64Type()));
582+
583+
auto argSizes =
584+
memref::AllocOp::create(rewriter, loc,
585+
MemRefType::get({(int64_t)cudaAllInputs.size()},
586+
rewriter.getI64Type()));
587+
568588
for (auto [i, opr] : llvm::enumerate(cudaAllInputs)) {
569-
auto ranked_tensor_type = llvm::cast<MemRefType>(opr.getType());
570-
auto shape = ranked_tensor_type.getShape();
571-
auto elem_type = ranked_tensor_type.getElementType();
572-
auto nbytes = llvm::divideCeil(
573-
shape[0] * shape[1] * elem_type.getIntOrFloatBitWidth(), 8);
589+
auto nbytes = getNbytes(opr.getType());
574590
auto nbytesVal = arith::ConstantIntOp::create(rewriter, loc, nbytes, 64);
575591
auto device_ptr_callOp = createCallToCudaShimMalloc(
576592
rewriter, loc, registry, stream, nbytesVal, false);
577593

578594
devicePtrs.push_back(device_ptr_callOp.getResult(0));
579595

580596
auto host_ptr = getIndexFromValue(rewriter, loc, opr);
581-
registry.call(
582-
rewriter, launchGpuOp, CudaShimFn::MemcpyH2D,
583-
ValueRange{device_ptr_callOp.getResult(0), host_ptr, nbytesVal});
584-
if (i >= adaptor.getOperands().size()) {
597+
598+
if (i < adaptor.getOperands().size()) {
599+
registry.call(
600+
rewriter, launchGpuOp, CudaShimFn::MemcpyH2D,
601+
ValueRange{device_ptr_callOp.getResult(0), host_ptr, nbytesVal});
602+
} else {
585603
// this is the output tensor, we will add memcpy from device to host for
586-
// it after the kernel launch.
604+
// it after the kernel launch; we will move this call to run after the
605+
// kernel launch later.
587606
memcpyH2DCall = registry.call(
588607
rewriter, launchGpuOp, CudaShimFn::MemcpyD2H,
589-
ValueRange{device_ptr_callOp.getResult(0), host_ptr, nbytesVal});
608+
ValueRange{host_ptr, device_ptr_callOp.getResult(0), nbytesVal});
590609
}
610+
611+
// construct the argSlots and argSizes on the host side for the kernel launch.
612+
arith::ConstantIndexOp indexVal =
613+
arith::ConstantIndexOp::create(rewriter, loc, i);
614+
615+
memref::StoreOp::create(rewriter, loc, devicePtrs[i], argSlots,
616+
ValueRange{indexVal});
617+
618+
auto nElements = arith::ConstantIntOp::create(
619+
rewriter, loc, llvm::cast<MemRefType>(opr.getType()).getNumElements(),
620+
64);
621+
622+
// store the size of the argument to argSizes.
623+
memref::StoreOp::create(rewriter, loc, nElements, argSizes,
624+
ValueRange{indexVal});
625+
}
626+
627+
// create the block size for the kernel launch.
628+
auto gridAttr = launchGpuOp->getDiscardableAttr("grid");
629+
if (!gridAttr) {
630+
return rewriter.notifyMatchFailure(
631+
launchGpuOp, "expected 'grid' attribute to be present");
591632
}
633+
auto gridArrayAttr = llvm::dyn_cast<DenseI64ArrayAttr>(gridAttr);
634+
635+
if (!gridArrayAttr || gridArrayAttr.size() != 3) {
636+
return rewriter.notifyMatchFailure(
637+
launchGpuOp,
638+
"expected 'grid' attribute to be an array of 3 integers");
639+
}
640+
641+
// Because of the unsupported grid size limitation in the cuda tile,
642+
// we will just use 1 for all dimensions of the grid.
643+
auto blockX = gridArrayAttr[0];
644+
auto blockY = gridArrayAttr[1];
645+
auto blockZ = gridArrayAttr[2];
646+
647+
if (!blockX || !blockY || !blockZ) {
648+
return rewriter.notifyMatchFailure(
649+
launchGpuOp,
650+
"expected 'grid' attribute to be an array of 3 integers");
651+
}
652+
653+
arith::ConstantIntOp blockXVal =
654+
arith::ConstantIntOp::create(rewriter, loc, blockX, 32);
655+
arith::ConstantIntOp blockYVal =
656+
arith::ConstantIntOp::create(rewriter, loc, blockY, 32);
657+
arith::ConstantIntOp blockZVal =
658+
arith::ConstantIntOp::create(rewriter, loc, blockZ, 32);
659+
660+
// create the number of arguments for the kernel launch, which is the number
661+
// of input tensors + 1 (for the output tensor).
662+
auto numArgsVal =
663+
arith::ConstantIntOp::create(rewriter, loc, cudaAllInputs.size(), 32);
664+
665+
auto argSlotPtr = getIndexFromValue(rewriter, loc, argSlots);
666+
auto argSizePtr = getIndexFromValue(rewriter, loc, argSizes);
667+
668+
// create a call to the cuda shim function to launch the kernel.
669+
registry.call(rewriter, launchGpuOp, CudaShimFn::LaunchBlockPacked,
670+
ValueRange{load_cubin_callee.getResult(0), kname_loaded_index,
671+
blockXVal, blockYVal, blockZVal,
672+
stream.getResult(0), argSlotPtr, argSizePtr,
673+
numArgsVal});
674+
675+
auto sync =
676+
registry.call(rewriter, launchGpuOp, CudaShimFn::StreamSynchronize,
677+
ValueRange{stream.getResult(0)});
678+
679+
memcpyH2DCall->moveAfter(sync);
592680

593681
// add free after the kernel launch.
682+
memref::DeallocOp::create(rewriter, loc, argSlots);
683+
memref::DeallocOp::create(rewriter, loc, argSizes);
684+
594685
for (auto operand : llvm::reverse(devicePtrs)) {
595686
registry.call(rewriter, launchGpuOp, CudaShimFn::Free,
596687
ValueRange{operand, stream.getResult(0)});
597688
}
598689

690+
// clean up
691+
registry.call(rewriter, launchGpuOp, CudaShimFn::StreamDestroy,
692+
ValueRange{stream.getResult(0)});
693+
registry.call(rewriter, launchGpuOp, CudaShimFn::UnloadModule,
694+
ValueRange{load_cubin_callee.getResult(0)});
695+
599696
rewriter.replaceOp(launchGpuOp, outputTensorAlloc);
600697
return success();
601698
}

mlir/cuda-tile/cuda_shim/cuda_shim.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ cuda_shim_load_module_from_file(uint64_t file_path_ptr,
330330
uint64_t /*file_path_nbytes*/) {
331331
auto file_path_cstr =
332332
reinterpret_cast<const char *>(asHostCPtr(file_path_ptr));
333-
// fprintf(stdout, "%s", file_path_cstr);
333+
fprintf(stdout, "%s", file_path_cstr);
334334
CUmodule module = nullptr;
335335
ScopedContext scopedContext;
336336
CUDA_REPORT_IF_ERROR(cuModuleLoad(&module, file_path_cstr));

mlir/cuda-tile/sample/lowering-llvm.sh

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
#!/bin/bash
22

3-
./third_party/llvm/bin/mlir-opt sample/gpu-func.mlir \
3+
./third_party/llvm/bin/mlir-opt sample/test.mlir \
44
-canonicalize -cse \
5+
-lower-affine \
56
-convert-scf-to-cf \
67
-convert-arith-to-llvm \
78
-convert-math-to-llvm \
@@ -12,7 +13,7 @@
1213

1314
./third_party/llvm/bin/mlir-translate lowered-llvm-dialect.mlir --mlir-to-llvmir -o lowered.ll
1415

15-
clang++ -O2 lowered.ll cuda_shim/cuda_shim.cc \
16+
clang++ -g -O0 lowered.ll cuda_shim/cuda_shim.cc \
1617
-I/usr/local/cuda/include \
1718
-L/usr/lib/x86_64-linux-gnu \
1819
-lcuda -ldl -lpthread -o cuda_shim/a.out

mlir/cuda-tile/sample/test.mlir

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
module {
2+
func.func private @cuda_shim_unload_module(i64)
3+
func.func private @cuda_shim_stream_destroy(i64)
4+
func.func private @cuda_shim_free(i64, i64)
5+
func.func private @cuda_shim_stream_synchronize(i64)
6+
func.func private @cuda_shim_launch_block_packed(i64, i64, i32, i32, i32, i64, i64, i64, i32)
7+
func.func private @cuda_shim_memcpy_d2h(i64, i64, i64)
8+
func.func private @cuda_shim_memcpy_h2d(i64, i64, i64)
9+
func.func private @cuda_shim_malloc(i64, i64, i1) -> i64
10+
func.func private @cuda_shim_stream_create() -> i64
11+
func.func private @cuda_shim_load_module_from_file(i64, i64) -> i64
12+
func.func private @cuda_debug_dump_float(i64, i32)
13+
memref.global "private" constant @kname : memref<22xi8> = dense<[111, 117, 116, 108, 105, 110, 101, 100, 95, 103, 112, 117, 95, 107, 101, 114, 110, 101, 108, 95, 48, 0]>
14+
memref.global "private" constant @cuda_blob : memref<26xi8> = dense<[47, 116, 109, 112, 47, 99, 117, 100, 97, 95, 116, 105, 108, 101, 45, 57, 52, 100, 50, 56, 48, 46, 98, 105, 110, 0]>
15+
func.func @main() {
16+
%alloc = memref.alloc() : memref<2x4xf32>
17+
%alloc_0 = memref.alloc() : memref<2x4xf32>
18+
%alloc_1 = memref.alloc() : memref<2x4xf32>
19+
%alloc_2 = memref.alloc() : memref<2x4xf32>
20+
%c0 = arith.constant 0 : index
21+
%c1 = arith.constant 1 : index
22+
%c2 = arith.constant 2 : index
23+
%c3 = arith.constant 3 : index
24+
%cst = arith.constant 1.000000e+00 : f32
25+
affine.store %cst, %alloc_2[%c0, %c0] : memref<2x4xf32>
26+
%cst_3 = arith.constant 2.000000e+00 : f32
27+
affine.store %cst_3, %alloc_2[%c0, %c1] : memref<2x4xf32>
28+
%cst_4 = arith.constant 3.000000e+00 : f32
29+
affine.store %cst_4, %alloc_2[%c0, %c2] : memref<2x4xf32>
30+
%cst_5 = arith.constant 9.000000e+00 : f32
31+
affine.store %cst_5, %alloc_2[%c0, %c3] : memref<2x4xf32>
32+
%cst_6 = arith.constant 4.000000e+00 : f32
33+
affine.store %cst_6, %alloc_2[%c1, %c0] : memref<2x4xf32>
34+
%cst_7 = arith.constant 5.000000e+00 : f32
35+
affine.store %cst_7, %alloc_2[%c1, %c1] : memref<2x4xf32>
36+
%cst_8 = arith.constant 6.000000e+00 : f32
37+
affine.store %cst_8, %alloc_2[%c1, %c2] : memref<2x4xf32>
38+
%cst_9 = arith.constant 1.000000e+01 : f32
39+
affine.store %cst_9, %alloc_2[%c1, %c3] : memref<2x4xf32>
40+
%c0_10 = arith.constant 0 : index
41+
%c1_11 = arith.constant 1 : index
42+
%c2_12 = arith.constant 2 : index
43+
%c3_13 = arith.constant 3 : index
44+
%cst_14 = arith.constant 1.100000e+01 : f32
45+
affine.store %cst_14, %alloc_1[%c0_10, %c0_10] : memref<2x4xf32>
46+
%cst_15 = arith.constant 1.200000e+01 : f32
47+
affine.store %cst_15, %alloc_1[%c0_10, %c1_11] : memref<2x4xf32>
48+
%cst_16 = arith.constant 1.300000e+01 : f32
49+
affine.store %cst_16, %alloc_1[%c0_10, %c2_12] : memref<2x4xf32>
50+
%cst_17 = arith.constant 1.400000e+01 : f32
51+
affine.store %cst_17, %alloc_1[%c0_10, %c3_13] : memref<2x4xf32>
52+
%cst_18 = arith.constant 1.500000e+01 : f32
53+
affine.store %cst_18, %alloc_1[%c1_11, %c0_10] : memref<2x4xf32>
54+
%cst_19 = arith.constant 1.600000e+01 : f32
55+
affine.store %cst_19, %alloc_1[%c1_11, %c1_11] : memref<2x4xf32>
56+
%cst_20 = arith.constant 1.700000e+01 : f32
57+
affine.store %cst_20, %alloc_1[%c1_11, %c2_12] : memref<2x4xf32>
58+
%cst_21 = arith.constant 1.800000e+01 : f32
59+
affine.store %cst_21, %alloc_1[%c1_11, %c3_13] : memref<2x4xf32>
60+
%c0_22 = arith.constant 0 : index
61+
%c1_23 = arith.constant 1 : index
62+
%c2_24 = arith.constant 2 : index
63+
%c3_25 = arith.constant 3 : index
64+
%cst_26 = arith.constant 7.000000e+00 : f32
65+
affine.store %cst_26, %alloc_0[%c0_22, %c0_22] : memref<2x4xf32>
66+
%cst_27 = arith.constant 8.000000e+00 : f32
67+
affine.store %cst_27, %alloc_0[%c0_22, %c1_23] : memref<2x4xf32>
68+
%cst_28 = arith.constant 9.000000e+00 : f32
69+
affine.store %cst_28, %alloc_0[%c0_22, %c2_24] : memref<2x4xf32>
70+
%cst_29 = arith.constant 1.300000e+01 : f32
71+
affine.store %cst_29, %alloc_0[%c0_22, %c3_25] : memref<2x4xf32>
72+
%cst_30 = arith.constant 1.000000e+01 : f32
73+
affine.store %cst_30, %alloc_0[%c1_23, %c0_22] : memref<2x4xf32>
74+
%cst_31 = arith.constant 1.100000e+01 : f32
75+
affine.store %cst_31, %alloc_0[%c1_23, %c1_23] : memref<2x4xf32>
76+
%cst_32 = arith.constant 1.200000e+01 : f32
77+
affine.store %cst_32, %alloc_0[%c1_23, %c2_24] : memref<2x4xf32>
78+
%cst_33 = arith.constant 1.400000e+01 : f32
79+
affine.store %cst_33, %alloc_0[%c1_23, %c3_25] : memref<2x4xf32>
80+
%0 = memref.get_global @cuda_blob : memref<26xi8>
81+
%intptr = memref.extract_aligned_pointer_as_index %0 : memref<26xi8> -> index
82+
%1 = arith.index_cast %intptr : index to i64
83+
%2 = memref.get_global @kname : memref<22xi8>
84+
%intptr_34 = memref.extract_aligned_pointer_as_index %2 : memref<22xi8> -> index
85+
%3 = arith.index_cast %intptr_34 : index to i64
86+
%c26_i64 = arith.constant 26 : i64
87+
%4 = call @cuda_shim_load_module_from_file(%1, %c26_i64) : (i64, i64) -> i64
88+
%5 = call @cuda_shim_stream_create() : () -> i64
89+
%alloc_35 = memref.alloc() : memref<4xi64>
90+
%alloc_36 = memref.alloc() : memref<4xi64>
91+
%c32_i64 = arith.constant 32 : i64
92+
%false = arith.constant false
93+
%6 = call @cuda_shim_malloc(%c32_i64, %5, %false) : (i64, i64, i1) -> i64
94+
%intptr_37 = memref.extract_aligned_pointer_as_index %alloc_2 : memref<2x4xf32> -> index
95+
%7 = arith.index_cast %intptr_37 : index to i64
96+
call @cuda_shim_memcpy_h2d(%6, %7, %c32_i64) : (i64, i64, i64) -> ()
97+
%c0_38 = arith.constant 0 : index
98+
memref.store %6, %alloc_35[%c0_38] : memref<4xi64>
99+
%c8_i64 = arith.constant 8 : i64
100+
memref.store %c8_i64, %alloc_36[%c0_38] : memref<4xi64>
101+
%c32_i64_39 = arith.constant 32 : i64
102+
%false_40 = arith.constant false
103+
%8 = call @cuda_shim_malloc(%c32_i64_39, %5, %false_40) : (i64, i64, i1) -> i64
104+
%intptr_41 = memref.extract_aligned_pointer_as_index %alloc_0 : memref<2x4xf32> -> index
105+
%9 = arith.index_cast %intptr_41 : index to i64
106+
call @cuda_shim_memcpy_h2d(%8, %9, %c32_i64_39) : (i64, i64, i64) -> ()
107+
%c1_42 = arith.constant 1 : index
108+
memref.store %8, %alloc_35[%c1_42] : memref<4xi64>
109+
%c8_i64_43 = arith.constant 8 : i64
110+
memref.store %c8_i64_43, %alloc_36[%c1_42] : memref<4xi64>
111+
%c32_i64_44 = arith.constant 32 : i64
112+
%false_45 = arith.constant false
113+
%10 = call @cuda_shim_malloc(%c32_i64_44, %5, %false_45) : (i64, i64, i1) -> i64
114+
%intptr_46 = memref.extract_aligned_pointer_as_index %alloc_1 : memref<2x4xf32> -> index
115+
%11 = arith.index_cast %intptr_46 : index to i64
116+
call @cuda_shim_memcpy_h2d(%10, %11, %c32_i64_44) : (i64, i64, i64) -> ()
117+
%c2_47 = arith.constant 2 : index
118+
memref.store %10, %alloc_35[%c2_47] : memref<4xi64>
119+
%c8_i64_48 = arith.constant 8 : i64
120+
memref.store %c8_i64_48, %alloc_36[%c2_47] : memref<4xi64>
121+
%c32_i64_49 = arith.constant 32 : i64
122+
%false_50 = arith.constant false
123+
%12 = call @cuda_shim_malloc(%c32_i64_49, %5, %false_50) : (i64, i64, i1) -> i64
124+
%intptr_51 = memref.extract_aligned_pointer_as_index %alloc : memref<2x4xf32> -> index
125+
%13 = arith.index_cast %intptr_51 : index to i64
126+
%c3_52 = arith.constant 3 : index
127+
memref.store %12, %alloc_35[%c3_52] : memref<4xi64>
128+
%c8_i64_53 = arith.constant 8 : i64
129+
memref.store %c8_i64_53, %alloc_36[%c3_52] : memref<4xi64>
130+
%c1_i32 = arith.constant 1 : i32
131+
%c1_i32_54 = arith.constant 1 : i32
132+
%c1_i32_55 = arith.constant 1 : i32
133+
%c4_i32 = arith.constant 4 : i32
134+
%intptr_56 = memref.extract_aligned_pointer_as_index %alloc_35 : memref<4xi64> -> index
135+
%14 = arith.index_cast %intptr_56 : index to i64
136+
%intptr_57 = memref.extract_aligned_pointer_as_index %alloc_36 : memref<4xi64> -> index
137+
%15 = arith.index_cast %intptr_57 : index to i64
138+
call @cuda_shim_launch_block_packed(%4, %3, %c1_i32, %c1_i32_54, %c1_i32_55, %5, %14, %15, %c4_i32) : (i64, i64, i32, i32, i32, i64, i64, i64, i32) -> ()
139+
call @cuda_shim_stream_synchronize(%5) : (i64) -> ()
140+
call @cuda_shim_memcpy_d2h(%13, %12, %c32_i64_49) : (i64, i64, i64) -> ()
141+
memref.dealloc %alloc_35 : memref<4xi64>
142+
memref.dealloc %alloc_36 : memref<4xi64>
143+
call @cuda_shim_free(%12, %5) : (i64, i64) -> ()
144+
call @cuda_shim_free(%10, %5) : (i64, i64) -> ()
145+
call @cuda_shim_free(%8, %5) : (i64, i64) -> ()
146+
call @cuda_shim_free(%6, %5) : (i64, i64) -> ()
147+
call @cuda_shim_stream_destroy(%5) : (i64) -> ()
148+
call @cuda_shim_unload_module(%4) : (i64) -> ()
149+
150+
// toy.print %alloc : memref<2x4xf32>
151+
%ci8 = arith.constant 8 : i32
152+
func.call @cuda_debug_dump_float(%13, %ci8) : (i64, i32) -> ()
153+
154+
memref.dealloc %alloc_2 : memref<2x4xf32>
155+
memref.dealloc %alloc_1 : memref<2x4xf32>
156+
memref.dealloc %alloc_0 : memref<2x4xf32>
157+
memref.dealloc %alloc : memref<2x4xf32>
158+
return
159+
}
160+
}

0 commit comments

Comments
 (0)