Skip to content

Commit dbf81a7

Browse files
committed
Added the input and allocation for the cuda shim
1 parent b9cfa7e commit dbf81a7

File tree

1 file changed

+89
-31
lines changed

1 file changed

+89
-31
lines changed

mlir/cuda-tile/Toy/mlir/LowerToAffineLoops.cpp

Lines changed: 89 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,13 @@
2323
#include "mlir/IR/DialectRegistry.h"
2424
#include "mlir/IR/Operation.h"
2525
#include "mlir/IR/PatternMatch.h"
26+
#include "mlir/IR/Value.h"
2627
#include "mlir/IR/ValueRange.h"
2728
#include "mlir/Support/LLVM.h"
2829
#include "mlir/Support/TypeID.h"
2930
#include "toy/Dialect.h"
3031
#include "toy/Passes.h"
32+
#include "llvm/ADT/SmallVector.h"
3133
#include "llvm/ADT/StringRef.h"
3234
#include "llvm/Support/DebugLog.h"
3335

@@ -430,36 +432,55 @@ memref::GlobalOp createGlobalForStringAttr(mlir::PatternRewriter &rewriter,
430432
return global;
431433
}
432434

435+
/// Materializes the aligned base pointer of `value` (expected to be a memref
/// value) as an `i64`, suitable for passing across the C ABI of the CUDA shim.
/// First extracts the aligned pointer as an `index`, then casts it to i64.
arith::IndexCastOp getIndexFromValue(mlir::PatternRewriter &rewriter,
                                     Location loc, Value value) {
  auto alignedPtr = memref::ExtractAlignedPointerAsIndexOp::create(
      rewriter, loc, rewriter.getIndexType(), value);
  // The shim functions take raw pointers as i64, so cast index -> i64.
  return arith::IndexCastOp::create(rewriter, loc, rewriter.getI64Type(),
                                    alignedPtr.getResult());
}
443+
433444
/// Resolves `global` to an SSA value via memref.get_global and returns its
/// aligned base pointer as an `i64` (delegating to getIndexFromValue).
arith::IndexCastOp getIndexFromGlobalMemref(mlir::PatternRewriter &rewriter,
                                            Location loc,
                                            memref::GlobalOp global) {
  auto globalValue = memref::GetGlobalOp::create(
      rewriter, loc, global.getType(), global.getName());
  return getIndexFromValue(rewriter, loc, globalValue.getResult());
}
444453

445-
return indexCastOp;
454+
/// Emits a call to the CUDA shim's Malloc function.
///
/// \param rewriter     Rewriter used to create the ops.
/// \param loc          Location attached to the created constant.
/// \param registry     Shim registry that knows how to build the runtime call.
/// \param stream       Prior shim call whose result 0 is the stream handle.
/// \param nbytesVal    i64 constant holding the allocation size in bytes.
/// \param isHostShared Whether the allocation is host-visible (passed to the
///                     shim as an i1 constant).
/// \return The func.call op for the malloc; result 0 is the device pointer.
func::CallOp
createCallToCudaShimMalloc(mlir::PatternRewriter &rewriter, Location loc,
                           CudaShimRegistry &registry, func::CallOp stream,
                           arith::ConstantIntOp nbytesVal, bool isHostShared) {
  // Fold the host-shared flag directly into a single 1-bit constant instead
  // of duplicating the create call in both branches.
  auto isHostSharedVal =
      arith::ConstantIntOp::create(rewriter, loc, isHostShared ? 1 : 0, 1);
  auto streamVal = stream.getResult(0);
  return registry.call(rewriter, stream, CudaShimFn::Malloc,
                       ValueRange{nbytesVal, streamVal, isHostSharedVal});
}
447469

448-
struct LanchGpuLowering : public ConversionPattern {
449-
LanchGpuLowering(MLIRContext *ctx)
450-
: ConversionPattern(toy::LaunchGpuOp::getOperationName(), 1, ctx) {}
470+
struct LanchGpuLowering : public OpConversionPattern<toy::LaunchGpuOp> {
471+
using OpConversionPattern<toy::LaunchGpuOp>::OpConversionPattern;
451472

452473
LogicalResult
453-
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
474+
matchAndRewrite(toy::LaunchGpuOp launchGpuOp, OpAdaptor adaptor,
454475
ConversionPatternRewriter &rewriter) const final {
455-
auto loc = op->getLoc();
456-
CudaShimRegistry registry(op->getParentOfType<ModuleOp>());
476+
auto loc = launchGpuOp->getLoc();
477+
CudaShimRegistry registry(launchGpuOp->getParentOfType<ModuleOp>());
457478

458-
toy::LaunchGpuOp launchGpuOp = llvm::cast<toy::LaunchGpuOp>(op);
459479
for (auto ranked_tensor_type : launchGpuOp->getOperands()) {
460480
if (!llvm::isa<RankedTensorType>(ranked_tensor_type.getType())) {
461-
return rewriter.notifyMatchFailure(op, "expected operand to be a "
462-
"ranked tensor type");
481+
return rewriter.notifyMatchFailure(launchGpuOp,
482+
"expected operand to be a "
483+
"ranked tensor type");
463484
}
464485
}
465486

@@ -494,15 +515,16 @@ struct LanchGpuLowering : public ConversionPattern {
494515
launchGpuOp->getDiscardableAttr("cuda_binary_path");
495516
if (!cudaBinaryPathAttr) {
496517
return rewriter.notifyMatchFailure(
497-
op, "expected 'cuda_binary_path' attribute to be present");
518+
launchGpuOp, "expected 'cuda_binary_path' attribute to be present");
498519
}
499520

500521
auto cudaBinaryPathStr = llvm::dyn_cast<StringAttr>(cudaBinaryPathAttr);
501522
if (!cudaBinaryPathStr) {
502523
return rewriter.notifyMatchFailure(
503-
op, "expected 'cuda_binary_path' attribute to be a string");
524+
launchGpuOp, "expected 'cuda_binary_path' attribute to be a string");
504525
}
505526

527+
// add the global memref for the cuda binary path and the kernel name.
506528
auto cuda_blob_memref = createGlobalForStringAttr(
507529
rewriter, launchGpuOp, "cuda_blob", cudaBinaryPathStr);
508530

@@ -520,25 +542,61 @@ struct LanchGpuLowering : public ConversionPattern {
520542
// Added blob size.
521543
auto blob_size =
522544
llvm::cast<MemRefType>(cuda_blob_memref.getType()).getShape()[0];
523-
arith::ConstantIndexOp blob_size_index =
524-
arith::ConstantIndexOp::create(rewriter, loc, blob_size);
525-
526-
// handle the input of the launch op, we will create a cuda allocation for
527-
// each input tensor.
528-
for (auto operand : launchGpuOp->getOperands()) {
529-
auto ranked_tensor_type = llvm::cast<RankedTensorType>(operand.getType());
545+
auto blob_size_index =
546+
arith::ConstantIntOp::create(rewriter, loc, blob_size, 64);
547+
548+
// create a call to the cuda shim function to load the cuda binary
549+
auto load_cubin_callee =
550+
registry.call(rewriter, launchGpuOp, CudaShimFn::LoadModuleFromFile,
551+
ValueRange{cuda_blob_index, blob_size_index});
552+
553+
// create a stream for the kernel launch, for simplicity we use the default
554+
// stream (0).
555+
auto stream =
556+
registry.call(rewriter, launchGpuOp, CudaShimFn::StreamCreate);
557+
558+
// we assume the number of output tensors is only 1, and it's the last
559+
// operand of the launch op.
560+
llvm::SmallVector<Value, 8> devicePtrs;
561+
llvm::SmallVector<Value, 8> cudaAllInputs;
562+
563+
for (auto operand : adaptor.getOperands()) {
564+
cudaAllInputs.push_back(operand);
565+
}
566+
cudaAllInputs.push_back(outputTensorAlloc);
567+
mlir::func::CallOp memcpyH2DCall;
568+
for (auto [i, opr] : llvm::enumerate(cudaAllInputs)) {
569+
auto ranked_tensor_type = llvm::cast<MemRefType>(opr.getType());
530570
auto shape = ranked_tensor_type.getShape();
571+
auto elem_type = ranked_tensor_type.getElementType();
572+
auto nbytes = llvm::divideCeil(
573+
shape[0] * shape[1] * elem_type.getIntOrFloatBitWidth(), 8);
574+
auto nbytesVal = arith::ConstantIntOp::create(rewriter, loc, nbytes, 64);
575+
auto device_ptr_callOp = createCallToCudaShimMalloc(
576+
rewriter, loc, registry, stream, nbytesVal, false);
577+
578+
devicePtrs.push_back(device_ptr_callOp.getResult(0));
579+
580+
auto host_ptr = getIndexFromValue(rewriter, loc, opr);
581+
registry.call(
582+
rewriter, launchGpuOp, CudaShimFn::MemcpyH2D,
583+
ValueRange{device_ptr_callOp.getResult(0), host_ptr, nbytesVal});
584+
if (i >= adaptor.getOperands().size()) {
585+
// this is the output tensor, we will add memcpy from device to host for
586+
// it after the kernel launch.
587+
memcpyH2DCall = registry.call(
588+
rewriter, launchGpuOp, CudaShimFn::MemcpyD2H,
589+
ValueRange{device_ptr_callOp.getResult(0), host_ptr, nbytesVal});
590+
}
531591
}
532592

533-
auto nbytesVal = arith::ConstantIntOp::create(rewriter, loc, 1, 64);
534-
auto streamVal = arith::ConstantIntOp::create(rewriter, loc, 0, 64);
535-
auto isHostSharedVal = arith::ConstantIntOp::create(rewriter, loc, 0, 1);
536-
537-
auto callee =
538-
registry.call(rewriter, launchGpuOp, CudaShimFn::Malloc,
539-
ValueRange{nbytesVal, streamVal, isHostSharedVal});
593+
// add free after the kernel launch.
594+
for (auto operand : llvm::reverse(devicePtrs)) {
595+
registry.call(rewriter, launchGpuOp, CudaShimFn::Free,
596+
ValueRange{operand, stream.getResult(0)});
597+
}
540598

541-
rewriter.replaceOp(op, outputTensorAlloc);
599+
rewriter.replaceOp(launchGpuOp, outputTensorAlloc);
542600
return success();
543601
}
544602
};

0 commit comments

Comments
 (0)