@@ -373,16 +373,20 @@ void MulOp::inferShapes() { getResult().setType(getLhs().getType()); }
373373// ===----------------------------------------------------------------------===//
374374
375375llvm::LogicalResult ReturnOp::verify () {
376- // We know that the parent operation is a function, because of the 'HasParent'
377- // trait attached to the operation definition.
378- auto function = cast<FuncOp>((*this )->getParentOp ());
376+ // Parent can be FuncOp or GPUFuncOp; both implement FunctionOpInterface.
377+ auto *parent = (*this )->getParentOp ();
378+ auto function = dyn_cast<FunctionOpInterface>(parent);
379+ if (!function)
380+ return emitOpError () << " must be enclosed in a function-like op" ;
381+
379382
380383 // / ReturnOps can only have a single optional operand.
381384 if (getNumOperands () > 1 )
382385 return emitOpError () << " expects at most 1 return operand" ;
383386
384387 // The operand number and types must match the function signature.
385- const auto &results = function.getFunctionType ().getResults ();
388+ auto funcType = llvm::cast<FunctionType>(function.getFunctionType ());
389+ const auto &results = funcType.getResults ();
386390 if (getNumOperands () != results.size ())
387391 return emitOpError () << " does not return the same number of values ("
388392 << getNumOperands () << " ) as the enclosing function ("
@@ -489,6 +493,79 @@ llvm::LogicalResult MatMulOp::verify() {
489493 return mlir::success ();
490494}
491495
496+ // ===----------------------------------------------------------------------===//
497+ // LaunchGpuOp
498+ // ===----------------------------------------------------------------------===//
499+
500+ void LaunchGpuOp::build (mlir::OpBuilder &builder, mlir::OperationState &state,
501+ StringRef callee, ArrayRef<mlir::Value> arguments) {
502+ // Generic call always returns an unranked Tensor initially.
503+ state.addTypes (UnrankedTensorType::get (builder.getF32Type ()));
504+ state.addOperands (arguments);
505+ state.addAttribute (" callee" ,
506+ mlir::SymbolRefAttr::get (builder.getContext (), callee));
507+ state.addAttribute (" grid" , builder.getI64ArrayAttr ({1 , 1 , 1 }));
508+ }
509+
510+ // / Return the callee of the generic call operation, this is required by the
511+ // / call interface.
512+ CallInterfaceCallable LaunchGpuOp::getCallableForCallee () {
513+ return (*this )->getAttrOfType <SymbolRefAttr>(" callee" );
514+ }
515+
516+ // / Set the callee for the generic call operation, this is required by the call
517+ // / interface.
518+ void LaunchGpuOp::setCalleeFromCallable (CallInterfaceCallable callee) {
519+ (*this )->setAttr (" callee" , cast<SymbolRefAttr>(callee));
520+ }
521+
522+ // / Get the argument operands to the called function, this is required by the
523+ // / call interface.
524+ Operation::operand_range LaunchGpuOp::getArgOperands () { return getInputs (); }
525+
526+ // / Get the argument operands to the called function as a mutable range, this is
527+ // / required by the call interface.
528+ MutableOperandRange LaunchGpuOp::getArgOperandsMutable () {
529+ return getInputsMutable ();
530+ }
531+
532+
533+ // ===----------------------------------------------------------------------===//
534+ // GPUFuncOp
535+ // ===----------------------------------------------------------------------===//
536+
537+ void GPUFuncOp::build (mlir::OpBuilder &builder, mlir::OperationState &state,
538+ llvm::StringRef name, mlir::FunctionType type,
539+ llvm::ArrayRef<mlir::NamedAttribute> attrs) {
540+ // FunctionOpInterface provides a convenient `build` method that will populate
541+ // the state of our GPUFuncOp, and create an entry block.
542+ buildWithEntryBlock (builder, state, name, type, attrs, type.getInputs ());
543+ }
544+
545+ mlir::ParseResult GPUFuncOp::parse (mlir::OpAsmParser &parser,
546+ mlir::OperationState &result) {
547+ // Dispatch to the FunctionOpInterface provided utility method that parses the
548+ // function operation.
549+ auto buildFuncType =
550+ [](mlir::Builder &builder, llvm::ArrayRef<mlir::Type> argTypes,
551+ llvm::ArrayRef<mlir::Type> results,
552+ mlir::function_interface_impl::VariadicFlag,
553+ std::string &) { return builder.getFunctionType (argTypes, results); };
554+
555+ return mlir::function_interface_impl::parseFunctionOp (
556+ parser, result, /* allowVariadic=*/ false ,
557+ getFunctionTypeAttrName (result.name ), buildFuncType,
558+ getArgAttrsAttrName (result.name ), getResAttrsAttrName (result.name ));
559+ }
560+
561+ void GPUFuncOp::print (mlir::OpAsmPrinter &p) {
562+ // Dispatch to the FunctionOpInterface provided utility method that prints the
563+ // function operation.
564+ mlir::function_interface_impl::printFunctionOp (
565+ p, *this , /* isVariadic=*/ false , getFunctionTypeAttrName (),
566+ getArgAttrsAttrName (), getResAttrsAttrName ());
567+ }
568+
492569// ===----------------------------------------------------------------------===//
493570// TableGen'd op method definitions
494571// ===----------------------------------------------------------------------===//
0 commit comments