2323#include " mlir/IR/DialectRegistry.h"
2424#include " mlir/IR/Operation.h"
2525#include " mlir/IR/PatternMatch.h"
26+ #include " mlir/IR/Value.h"
2627#include " mlir/IR/ValueRange.h"
2728#include " mlir/Support/LLVM.h"
2829#include " mlir/Support/TypeID.h"
2930#include " toy/Dialect.h"
3031#include " toy/Passes.h"
32+ #include " llvm/ADT/SmallVector.h"
3133#include " llvm/ADT/StringRef.h"
3234#include " llvm/Support/DebugLog.h"
3335
@@ -430,36 +432,55 @@ memref::GlobalOp createGlobalForStringAttr(mlir::PatternRewriter &rewriter,
430432 return global;
431433}
432434
435+ arith::IndexCastOp getIndexFromValue (mlir::PatternRewriter &rewriter,
436+ Location loc, Value value) {
437+ auto extractOp = memref::ExtractAlignedPointerAsIndexOp::create (
438+ rewriter, loc, rewriter.getIndexType (), value);
439+ auto indexCastOp = arith::IndexCastOp::create (
440+ rewriter, loc, rewriter.getI64Type (), extractOp.getResult ());
441+ return indexCastOp;
442+ }
443+
433444arith::IndexCastOp getIndexFromGlobalMemref (mlir::PatternRewriter &rewriter,
434445 Location loc,
435446 memref::GlobalOp global) {
436447
437448 auto getGlobalOp = memref::GetGlobalOp::create (
438449 rewriter, loc, global.getType (), global.getName ());
439- auto extractOp = memref::ExtractAlignedPointerAsIndexOp::create (
440- rewriter, loc, rewriter.getIndexType (), getGlobalOp.getResult ());
441450
442- auto indexCastOp = arith::IndexCastOp::create (
443- rewriter, loc, rewriter. getI64Type (), extractOp. getResult ());
451+ return getIndexFromValue (rewriter, loc, getGlobalOp. getResult ());
452+ }
444453
445- return indexCastOp;
454+ func::CallOp
455+ createCallToCudaShimMalloc (mlir::PatternRewriter &rewriter, Location loc,
456+ CudaShimRegistry ®istry, func::CallOp stream,
457+ arith::ConstantIntOp nbytesVal, bool isHostShared) {
458+ arith::ConstantIntOp isHostSharedVal;
459+ if (isHostShared) {
460+ isHostSharedVal = arith::ConstantIntOp::create (rewriter, loc, 1 , 1 );
461+ } else {
462+ isHostSharedVal = arith::ConstantIntOp::create (rewriter, loc, 0 , 1 );
463+ }
464+ auto sreamVal = stream.getResult (0 );
465+ auto callee = registry.call (rewriter, stream, CudaShimFn::Malloc,
466+ ValueRange{nbytesVal, sreamVal, isHostSharedVal});
467+ return callee;
446468}
447469
448- struct LanchGpuLowering : public ConversionPattern {
449- LanchGpuLowering (MLIRContext *ctx)
450- : ConversionPattern(toy::LaunchGpuOp::getOperationName(), 1 , ctx) {}
470+ struct LanchGpuLowering : public OpConversionPattern <toy::LaunchGpuOp> {
471+ using OpConversionPattern<toy::LaunchGpuOp>::OpConversionPattern;
451472
452473 LogicalResult
453- matchAndRewrite (Operation *op, ArrayRef<Value> operands ,
474+ matchAndRewrite (toy::LaunchGpuOp launchGpuOp, OpAdaptor adaptor ,
454475 ConversionPatternRewriter &rewriter) const final {
455- auto loc = op ->getLoc ();
456- CudaShimRegistry registry (op ->getParentOfType <ModuleOp>());
476+ auto loc = launchGpuOp ->getLoc ();
477+ CudaShimRegistry registry (launchGpuOp ->getParentOfType <ModuleOp>());
457478
458- toy::LaunchGpuOp launchGpuOp = llvm::cast<toy::LaunchGpuOp>(op);
459479 for (auto ranked_tensor_type : launchGpuOp->getOperands ()) {
460480 if (!llvm::isa<RankedTensorType>(ranked_tensor_type.getType ())) {
461- return rewriter.notifyMatchFailure (op, " expected operand to be a "
462- " ranked tensor type" );
481+ return rewriter.notifyMatchFailure (launchGpuOp,
482+ " expected operand to be a "
483+ " ranked tensor type" );
463484 }
464485 }
465486
@@ -494,15 +515,16 @@ struct LanchGpuLowering : public ConversionPattern {
494515 launchGpuOp->getDiscardableAttr (" cuda_binary_path" );
495516 if (!cudaBinaryPathAttr) {
496517 return rewriter.notifyMatchFailure (
497- op , " expected 'cuda_binary_path' attribute to be present" );
518+ launchGpuOp , " expected 'cuda_binary_path' attribute to be present" );
498519 }
499520
500521 auto cudaBinaryPathStr = llvm::dyn_cast<StringAttr>(cudaBinaryPathAttr);
501522 if (!cudaBinaryPathStr) {
502523 return rewriter.notifyMatchFailure (
503- op , " expected 'cuda_binary_path' attribute to be a string" );
524+ launchGpuOp , " expected 'cuda_binary_path' attribute to be a string" );
504525 }
505526
527+ // add the global memref for the cuda binary path and the kernel name.
506528 auto cuda_blob_memref = createGlobalForStringAttr (
507529 rewriter, launchGpuOp, " cuda_blob" , cudaBinaryPathStr);
508530
@@ -520,25 +542,61 @@ struct LanchGpuLowering : public ConversionPattern {
520542 // Added blob size.
521543 auto blob_size =
522544 llvm::cast<MemRefType>(cuda_blob_memref.getType ()).getShape ()[0 ];
523- arith::ConstantIndexOp blob_size_index =
524- arith::ConstantIndexOp::create (rewriter, loc, blob_size);
525-
526- // handle the input of the launch op, we will create a cuda allocation for
527- // each input tensor.
528- for (auto operand : launchGpuOp->getOperands ()) {
529- auto ranked_tensor_type = llvm::cast<RankedTensorType>(operand.getType ());
545+ auto blob_size_index =
546+ arith::ConstantIntOp::create (rewriter, loc, blob_size, 64 );
547+
548+ // create a call to the cuda shim function to load the cuda binary
549+ auto load_cubin_callee =
550+ registry.call (rewriter, launchGpuOp, CudaShimFn::LoadModuleFromFile,
551+ ValueRange{cuda_blob_index, blob_size_index});
552+
553+ // create a stream for the kernel launch, for simplicity we use the default
554+ // stream (0).
555+ auto stream =
556+ registry.call (rewriter, launchGpuOp, CudaShimFn::StreamCreate);
557+
558+ // we assume the number of output tensors is only 1, and it's the last
559+ // operand of the launch op.
560+ llvm::SmallVector<Value, 8 > devicePtrs;
561+ llvm::SmallVector<Value, 8 > cudaAllInputs;
562+
563+ for (auto operand : adaptor.getOperands ()) {
564+ cudaAllInputs.push_back (operand);
565+ }
566+ cudaAllInputs.push_back (outputTensorAlloc);
567+ mlir::func::CallOp memcpyH2DCall;
568+ for (auto [i, opr] : llvm::enumerate (cudaAllInputs)) {
569+ auto ranked_tensor_type = llvm::cast<MemRefType>(opr.getType ());
530570 auto shape = ranked_tensor_type.getShape ();
571+ auto elem_type = ranked_tensor_type.getElementType ();
572+ auto nbytes = llvm::divideCeil (
573+ shape[0 ] * shape[1 ] * elem_type.getIntOrFloatBitWidth (), 8 );
574+ auto nbytesVal = arith::ConstantIntOp::create (rewriter, loc, nbytes, 64 );
575+ auto device_ptr_callOp = createCallToCudaShimMalloc (
576+ rewriter, loc, registry, stream, nbytesVal, false );
577+
578+ devicePtrs.push_back (device_ptr_callOp.getResult (0 ));
579+
580+ auto host_ptr = getIndexFromValue (rewriter, loc, opr);
581+ registry.call (
582+ rewriter, launchGpuOp, CudaShimFn::MemcpyH2D,
583+ ValueRange{device_ptr_callOp.getResult (0 ), host_ptr, nbytesVal});
584+ if (i >= adaptor.getOperands ().size ()) {
585+ // this is the output tensor, we will add memcpy from device to host for
586+ // it after the kernel launch.
587+ memcpyH2DCall = registry.call (
588+ rewriter, launchGpuOp, CudaShimFn::MemcpyD2H,
589+ ValueRange{device_ptr_callOp.getResult (0 ), host_ptr, nbytesVal});
590+ }
531591 }
532592
533- auto nbytesVal = arith::ConstantIntOp::create (rewriter, loc, 1 , 64 );
534- auto streamVal = arith::ConstantIntOp::create (rewriter, loc, 0 , 64 );
535- auto isHostSharedVal = arith::ConstantIntOp::create (rewriter, loc, 0 , 1 );
536-
537- auto callee =
538- registry.call (rewriter, launchGpuOp, CudaShimFn::Malloc,
539- ValueRange{nbytesVal, streamVal, isHostSharedVal});
593+ // add free after the kernel launch.
594+ for (auto operand : llvm::reverse (devicePtrs)) {
595+ registry.call (rewriter, launchGpuOp, CudaShimFn::Free,
596+ ValueRange{operand, stream.getResult (0 )});
597+ }
540598
541- rewriter.replaceOp (op , outputTensorAlloc);
599+ rewriter.replaceOp (launchGpuOp , outputTensorAlloc);
542600 return success ();
543601 }
544602};
0 commit comments