diff --git a/mlir/include/mlir/Conversion/LinalgToRock/LinalgToRock.h b/mlir/include/mlir/Conversion/LinalgToRock/LinalgToRock.h index 199fca16da61..2eeaea4ce619 100644 --- a/mlir/include/mlir/Conversion/LinalgToRock/LinalgToRock.h +++ b/mlir/include/mlir/Conversion/LinalgToRock/LinalgToRock.h @@ -28,7 +28,7 @@ void populateLinalgToRockConversionPattern(RewritePatternSet &pattern, /// A tensor.insert_slice is said to be a rock.expand_strides bool isRockExpandStride(tensor::InsertSliceOp op); -} +} // namespace rock } // namespace mlir #endif diff --git a/mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphX.td b/mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphX.td index 4dcdcf598180..b96f7ca9c9bd 100644 --- a/mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphX.td +++ b/mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphX.td @@ -58,7 +58,8 @@ def MIGraphX_LiteralOp : MIGraphX_Op<"literal", class MIGraphX_ElementwiseBinaryOp traits = []> : MIGraphX_Op< name, !listconcat( - traits, [AllElementTypesMatch<["inA", "inB", "output"]>])>, + traits, [Elementwise, + AllElementTypesMatch<["inA", "inB", "output"]>])>, Arguments<(ins AnyMIXRShaped:$inA, AnyMIXRShaped:$inB)>, Results<(outs AnyMIXRShaped:$output)> { let summary = "Elementwise " # name # " of two shaped values with broadcast"; @@ -94,12 +95,14 @@ def MIGraphX_Equal : MIGraphX_ElementwiseBinaryOp<"equal"> { }]; } -def MIGraphX_ClipOp : - MIGraphX_Op<"clip">, - Arguments<(ins AnyMIXRShaped:$x, - AnyMIXRShaped:$minVals, - AnyMIXRShaped:$maxVals)>, - Results<(outs AnyMIXRShaped:$output)> { +def MIGraphX_ClipOp + : MIGraphX_Op< + "clip", [Elementwise, + AllElementTypesMatch<["x", "minVals", "maxVals", "output"]>, + AllShapesMatch<["x", "minVals", "maxVals", "output"]>]>, + Arguments<(ins AnyMIXRShaped:$x, AnyMIXRShaped:$minVals, + AnyMIXRShaped:$maxVals)>, + Results<(outs AnyMIXRShaped:$output)> { let summary = "Elementwise clip"; let description = [{ Elementwise clip: output = min(max(x, minVals), maxVals) @@ -113,7 +116,8 @@ def MIGraphX_ClipOp : // Note: when lowering to kernel calls, MIGraphX represents booleans as i8. // Keep that logic here. def MIGraphX_WhereOp - : MIGraphX_Op<"where", [AllElementTypesMatch<["inA", "inB", "output"]>, + : MIGraphX_Op<"where", [Elementwise, + AllElementTypesMatch<["inA", "inB", "output"]>, AllShapesMatch<["inA", "inB", "output", "cond"]>]>, Arguments<(ins MIXRShapedOf<[I8, SI8, UI8]>:$cond, AnyMIXRShaped:$inA, AnyMIXRShaped:$inB)>, @@ -130,10 +134,10 @@ def MIGraphX_WhereOp // Elementwise unary operations -def MIGraphX_ConvertOp : - MIGraphX_Op<"convert">, - Arguments<(ins AnyMIXRShaped:$inA)>, - Results<(outs AnyMIXRShaped:$output)> { +def MIGraphX_ConvertOp + : MIGraphX_Op<"convert", [Elementwise, AllShapesMatch<["inA", "output"]>]>, + Arguments<(ins AnyMIXRShaped:$inA)>, + Results<(outs AnyMIXRShaped:$output)> { let summary = "Elementwise type conversion"; let description = [{ Type conversion. Due to impedance mismatches between MIGraphX and Tosa, @@ -142,10 +146,12 @@ def MIGraphX_ConvertOp : let assemblyFormat = "$inA attr-dict `:` type($inA) `to` type($output)"; } -class MIGraphX_ElementwiseUnaryOp traits=[]> : - MIGraphX_Op, - Arguments<(ins AnyMIXRShaped:$inA)>, - Results<(outs AnyMIXRShaped:$output)> { +class MIGraphX_ElementwiseUnaryOp traits = []> + : MIGraphX_Op])>, + Arguments<(ins AnyMIXRShaped:$inA)>, + Results<(outs AnyMIXRShaped:$output)> { let summary = "Elementwise " # name; let assemblyFormat = [{ $inA attr-dict `:` type($inA) `->` type($output) diff --git a/mlir/lib/Dialect/MIGraphX/IR/MIGraphX.cpp b/mlir/lib/Dialect/MIGraphX/IR/MIGraphX.cpp index 44636b79d15c..148a5e829d9f 100644 --- a/mlir/lib/Dialect/MIGraphX/IR/MIGraphX.cpp +++ b/mlir/lib/Dialect/MIGraphX/IR/MIGraphX.cpp @@ -334,16 +334,15 @@ LogicalResult ReshapeOp::verify() { << outType.getRank() << ")"; // Check that there is only a single -1 value - int missingDims = llvm::count_if( - dimsAttr.getAsRange(), - [](IntegerAttr a) { return a.getInt() == -1; }); + int missingDims = + llvm::count_if(dimsAttr.getAsRange(), + [](IntegerAttr a) { return a.getInt() == -1; }); if (missingDims > 1) return emitOpError("expected at most one target dimension to be -1"); // Check how many zero dimensions there are - int numZeros = llvm::count_if( - dimsAttr.getAsRange(), - [](IntegerAttr a) { return a.getInt() == 0; }); + int numZeros = llvm::count_if(dimsAttr.getAsRange(), + [](IntegerAttr a) { return a.getInt() == 0; }); if (missingDims > 0 && numZeros > 0) return emitOpError("Cannot mix missing dimensions with zero dimension"); @@ -351,8 +350,8 @@ LogicalResult ReshapeOp::verify() { // Compare dimension values to output shape for (auto [dimVal, outDim] : llvm::zip(dimsAttr, outType.getShape())) { int64_t dimValue = cast(dimVal).getInt(); - // We cannot handle negative dims values that aren't -1 - if (dimValue < -1 ) { + // We cannot handle negative dims values that aren't -1 + if (dimValue < -1) { return emitOpError("Non -1 negative values are not supported"); } diff --git a/mlir/test/Dialect/MIGraphX/invalid.mlir b/mlir/test/Dialect/MIGraphX/invalid.mlir index 898bd003d47d..b7b73f5128c8 100644 --- a/mlir/test/Dialect/MIGraphX/invalid.mlir +++ b/mlir/test/Dialect/MIGraphX/invalid.mlir @@ -1,5 +1,23 @@ // RUN: rocmlir-opt %s -split-input-file -verify-diagnostics +// ---- migraphx.shaped type ---- + +// expected-error @+1 {{migraphx.shaped type has 1 elements in its shape but 2 strides defined}} +func.func @invalid_more_strides_than_shapes(%arg: !migraphx.shaped<1xf32, 1x1>) { + func.return +} + +// ----- + +// expected-error @+1 {{migraphx.shaped type has 2 elements in its shape but 1 strides defined}} +func.func @invalid_more_shapes_than_strides(%arg: !migraphx.shaped<1x1xf32, 1>) { + func.return +} + +// ----- + +// ---- migraphx.reshape ---- + func.func @mlir_reshape_inconsistent_dims(%arg0: !migraphx.shaped<4096x4096xf16, 0x1>) { // expected-error@+1 {{'migraphx.reshape' op dimValue: 64 inconsistent with result dimension 4096}} %0 = migraphx.reshape %arg0 {dims = [64, 128]} : <4096x4096xf16, 0x1> -> <4096x4096xf16, 16536x2> @@ -42,6 +60,10 @@ func.func @mlir_neg_one_with_zero(%arg0: !migraphx.shaped<2x4xf16, 0x1>) { return } +// ----- + +// ---- migraphx.equal ---- + func.func @func_equal(%arg0: !migraphx.shaped<1x36x384x64xi32, 884736x24576x64x1>) -> !migraphx.shaped<1x36x384x64xi32, 884736x24576x64x1> attributes{rock.kernel, rock.arch = ""} { %cst = migraphx.literal (dense<1> : tensor<1x36x384x64xi32>) : <1x36x384x64xi32, 884736x24576x64x1> %0 = migraphx.add %arg0, %cst : <1x36x384x64xi32, 884736x24576x64x1>, <1x36x384x64xi32, 884736x24576x64x1> -> <1x36x384x64xi32, 884736x24576x64x1> @@ -50,6 +72,104 @@ func.func @func_equal(%arg0: !migraphx.shaped<1x36x384x64xi32, 884736x24576x64x1 return %1 : !migraphx.shaped<1x36x384x64xi16, 884736x24576x64x1> } +// ----- + +// ---- migraphx.clip ---- + +func.func @clip_mismatched_element_types(%arg0: !migraphx.shaped<4x8xf32, 8x1>, %arg1: !migraphx.shaped<4x8xf16, 8x1>, %arg2: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { + // expected-error @+1 {{op failed to verify that all of {x, minVals, maxVals, output} have same element type}} + %0 = migraphx.clip %arg0, %arg1, %arg2 : <4x8xf32, 8x1>, <4x8xf16, 8x1>, <4x8xf32, 8x1> -> <4x8xf32, 8x1> + return %0 : !migraphx.shaped<4x8xf32, 8x1> +} + +// ----- + +func.func @clip_mismatched_shapes(%arg0: !migraphx.shaped<4x8xf32, 8x1>, %arg1: !migraphx.shaped<4x4xf32, 4x1>, %arg2: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { + // expected-error @+1 {{op failed to verify that all of {x, minVals, maxVals, output} have same shape}} + %0 = migraphx.clip %arg0, %arg1, %arg2 : <4x8xf32, 8x1>, <4x4xf32, 4x1>, <4x8xf32, 8x1> -> <4x8xf32, 8x1> + return %0 : !migraphx.shaped<4x8xf32, 8x1> +} + +// ----- + +// ---- migraphx.where ---- + +func.func @where_cond_not_bool(%arg0: !migraphx.shaped<4x4xf32, 4x1>, %arg1: !migraphx.shaped<4x4xf32, 4x1>, %arg2: !migraphx.shaped<4x4xf32, 4x1>) -> !migraphx.shaped<4x4xf32, 4x1> { + // expected-error @+1 {{'migraphx.where' op operand #0 must be !migraphx.shaped of 8-bit signless integer or 8-bit signed integer or 8-bit unsigned integer values, but got '!migraphx.shaped<4x4xf32, 4x1>'}} + %0 = migraphx.where %arg0, %arg1, %arg2 : <4x4xf32, 4x1>, <4x4xf32, 4x1>, <4x4xf32, 4x1> -> <4x4xf32, 4x1> + return %0 : !migraphx.shaped<4x4xf32, 4x1> +} + +// ----- + +func.func @where_mismatched_types(%arg0: !migraphx.shaped<4x4xi8, 4x1>, %arg1: !migraphx.shaped<4x4xf32, 4x1>, %arg2: !migraphx.shaped<4x4xf16, 4x1>) -> !migraphx.shaped<4x4xf32, 4x1> { + // expected-error @+1 {{op failed to verify that all of {inA, inB, output} have same element type}} + %0 = migraphx.where %arg0, %arg1, %arg2 : <4x4xi8, 4x1>, <4x4xf32, 4x1>, <4x4xf16, 4x1> -> <4x4xf32, 4x1> + return %0 : !migraphx.shaped<4x4xf32, 4x1> +} + +// ----- + +func.func @where_mismatched_shapes(%arg0: !migraphx.shaped<4x4xi8, 4x1>, %arg1: !migraphx.shaped<4x8xf32, 8x1>, %arg2: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { + // expected-error @+1 {{op failed to verify that all of {inA, inB, output, cond} have same shape}} + %0 = migraphx.where %arg0, %arg1, %arg2 : <4x4xi8, 4x1>, <4x8xf32, 8x1>, <4x8xf32, 8x1> -> <4x8xf32, 8x1> + return %0 : !migraphx.shaped<4x8xf32, 8x1> +} + +// ----- + +// ---- migraphx.convert ---- + +func.func @convert_shape_mismatch(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x4xf16, 4x1> { + // expected-error @+1 {{op failed to verify that all of {inA, output} have same shape}} + %0 = migraphx.convert %arg0 : <4x8xf32, 8x1> to <4x4xf16, 4x1> + return %0 : !migraphx.shaped<4x4xf16, 4x1> +} + +// ----- + +// ---- migraphx.abs ---- + +func.func @abs_shape_mismatch(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x4xf32, 4x1> { + // expected-error @+1 {{op failed to verify that all of {inA, output} have same shape}} + %0 = migraphx.abs %arg0 : <4x8xf32, 8x1> -> <4x4xf32, 4x1> + return %0 : !migraphx.shaped<4x4xf32, 4x1> +} + +// ----- + +// ---- migraphx.exp ---- + +func.func @exp_rank_mismatch(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<32xf32, 1> { + // expected-error @+1 {{op failed to verify that all of {inA, output} have same shape}} + %0 = migraphx.exp %arg0 : <4x8xf32, 8x1> -> <32xf32, 1> + return %0 : !migraphx.shaped<32xf32, 1> +} + +// ----- + +// ---- migraphx.relu ---- + +func.func @relu_shape_mismatch(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<2x8xf32, 8x1> { + // expected-error @+1 {{op failed to verify that all of {inA, output} have same shape}} + %0 = migraphx.relu %arg0 : <4x8xf32, 8x1> -> <2x8xf32, 8x1> + return %0 : !migraphx.shaped<2x8xf32, 8x1> +} + +// ----- + +// ---- migraphx.sigmoid ---- + +func.func @func_sigmoid_2d_i32(%arg0: !migraphx.shaped<4x8xi32, 8x1>) -> !migraphx.shaped<4x8xi32, 8x1> { + // expected-error @+1 {{only support floating point}} + %0 = migraphx.sigmoid %arg0 : <4x8xi32, 8x1> -> <4x8xi32, 8x1> + return %0 : !migraphx.shaped<4x8xi32, 8x1> +} + +// ----- + +// ---- migraphx.quant_dot ---- + // Test: Only scaleA provided (should fail - both scales required) func.func @quant_dot_only_scale_a( %arg0: !migraphx.shaped<1x16x512xf4E2M1FN, 8192x512x1>, @@ -228,6 +348,8 @@ func.func @migraphx_quant_dot_f4_n_scales(%arg0: !migraphx.shaped<1x16x512xf4E2M // ----- +// ---- migraphx.dot ---- + // CHECK-LABEL: func.func @dot_rank_less_than_2 func.func @dot_rank_less_than_2(%arg0: !migraphx.shaped<320xf16, 1>, %arg1: !migraphx.shaped<320x64xf16, 64x1>) -> !migraphx.shaped<64xf16, 1> { // expected-error @+1 {{expect operand to have rank greater or equal to 2}} @@ -273,51 +395,7 @@ func.func @dot_result_shape_mismatch(%arg0: !migraphx.shaped<2x3x4xf16, 12x4x1>, // ----- -// expected-error @+1 {{migraphx.shaped type has 1 elements in its shape but 2 strides defined}} -func.func @invalid_more_strides_than_shapes(%arg: !migraphx.shaped<1xf32, 1x1>) { - func.return -} - -// ----- - -// expected-error @+1 {{migraphx.shaped type has 2 elements in its shape but 1 strides defined}} -func.func @invalid_more_shapes_than_strides(%arg: !migraphx.shaped<1x1xf32, 1>) { - func.return -} - -// ----- - -func.func @func_sigmoid_2d_i32(%arg0: !migraphx.shaped<4x8xi32, 8x1>) -> !migraphx.shaped<4x8xi32, 8x1> { - // expected-error @+1 {{only support floating point}} - %0 = migraphx.sigmoid %arg0 : <4x8xi32, 8x1> -> <4x8xi32, 8x1> - return %0 : !migraphx.shaped<4x8xi32, 8x1> -} - -// ----- - -func.func @where_cond_not_bool(%arg0: !migraphx.shaped<4x4xf32, 4x1>, %arg1: !migraphx.shaped<4x4xf32, 4x1>, %arg2: !migraphx.shaped<4x4xf32, 4x1>) -> !migraphx.shaped<4x4xf32, 4x1> { - // expected-error @+1 {{'migraphx.where' op operand #0 must be !migraphx.shaped of 8-bit signless integer or 8-bit signed integer or 8-bit unsigned integer values, but got '!migraphx.shaped<4x4xf32, 4x1>'}} - %0 = migraphx.where %arg0, %arg1, %arg2 : <4x4xf32, 4x1>, <4x4xf32, 4x1>, <4x4xf32, 4x1> -> <4x4xf32, 4x1> - return %0 : !migraphx.shaped<4x4xf32, 4x1> -} - -// ----- - -func.func @where_mismatched_types(%arg0: !migraphx.shaped<4x4xi8, 4x1>, %arg1: !migraphx.shaped<4x4xf32, 4x1>, %arg2: !migraphx.shaped<4x4xf16, 4x1>) -> !migraphx.shaped<4x4xf32, 4x1> { - // expected-error @+1 {{op failed to verify that all of {inA, inB, output} have same element type}} - %0 = migraphx.where %arg0, %arg1, %arg2 : <4x4xi8, 4x1>, <4x4xf32, 4x1>, <4x4xf16, 4x1> -> <4x4xf32, 4x1> - return %0 : !migraphx.shaped<4x4xf32, 4x1> -} - -// ----- - -func.func @where_mismatched_shapes(%arg0: !migraphx.shaped<4x4xi8, 4x1>, %arg1: !migraphx.shaped<4x8xf32, 8x1>, %arg2: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { - // expected-error @+1 {{op failed to verify that all of {inA, inB, output, cond} have same shape}} - %0 = migraphx.where %arg0, %arg1, %arg2 : <4x4xi8, 4x1>, <4x8xf32, 8x1>, <4x8xf32, 8x1> -> <4x8xf32, 8x1> - return %0 : !migraphx.shaped<4x8xf32, 8x1> -} - -// ----- +// ---- migraphx.slice ---- func.func @invalid_attr_size_mismatch(%input: !migraphx.shaped<10x10xf32, 10x1>) { // expected-error @+1 {{op axes, starts, and ends must have the same size}} diff --git a/mlir/test/Dialect/MIGraphX/ops.mlir b/mlir/test/Dialect/MIGraphX/ops.mlir index f22240c07f9d..302a28b359f6 100644 --- a/mlir/test/Dialect/MIGraphX/ops.mlir +++ b/mlir/test/Dialect/MIGraphX/ops.mlir @@ -2,30 +2,186 @@ // RUN: rocmlir-opt %s | rocmlir-opt | FileCheck %s // RUN: rocmlir-opt -mlir-print-op-generic %s | rocmlir-opt | FileCheck %s -// CHECK-LABEL: func.func @migraphx_dot -// CHECK-NEXT: migraphx.dot -func.func @migraphx_dot(%arg0: !migraphx.shaped<1x16x512xf4E2M1FN, 8192x512x1>, %arg1: !migraphx.shaped<1x512x16xf4E2M1FN, 8192x16x1>) -> !migraphx.shaped<1x16x16xf32, 256x16x1> { - %0 = migraphx.dot %arg0, %arg1 : <1x16x512xf4E2M1FN, 8192x512x1>, <1x512x16xf4E2M1FN, 8192x16x1> -> <1x16x16xf32, 256x16x1> - return %0 : !migraphx.shaped<1x16x16xf32, 256x16x1> +// ---- migraphx.add ---- + +// CHECK-LABEL: func.func @migraphx_add +// CHECK-NEXT: migraphx.add +func.func @migraphx_add(%arg0: !migraphx.shaped<4x8xf32, 8x1>, %arg1: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { + %0 = migraphx.add %arg0, %arg1 : <4x8xf32, 8x1>, <4x8xf32, 8x1> -> <4x8xf32, 8x1> + return %0 : !migraphx.shaped<4x8xf32, 8x1> } +// ---- migraphx.sub ---- +// CHECK-LABEL: func.func @migraphx_sub +// CHECK-NEXT: migraphx.sub +func.func @migraphx_sub(%arg0: !migraphx.shaped<4x8xf32, 8x1>, %arg1: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { + %0 = migraphx.sub %arg0, %arg1 : <4x8xf32, 8x1>, <4x8xf32, 8x1> -> <4x8xf32, 8x1> + return %0 : !migraphx.shaped<4x8xf32, 8x1> +} -// CHECK-LABEL: func.func @migraphx_quant_dot_scaled -// CHECK-NEXT: migraphx.quant_dot -func.func @migraphx_quant_dot_scaled(%arg0: !migraphx.shaped<1x16x512xf4E2M1FN, 8192x512x1>, %arg1: !migraphx.shaped<1x512x16xf4E2M1FN, 8192x16x1>, %arg2: !migraphx.shaped<1x16x512xf8E8M0FNU, 8192x512x1>, %arg3: !migraphx.shaped<1x512x16xf8E8M0FNU, 8192x16x1>) -> !migraphx.shaped<1x16x16xf32, 256x16x1> { - %0 = migraphx.quant_dot - %arg0 scaled by %arg2, - %arg1 scaled by %arg3 - : !migraphx.shaped<1x16x512xf4E2M1FN, 8192x512x1> scaled by - !migraphx.shaped<1x16x512xf8E8M0FNU, 8192x512x1>, - !migraphx.shaped<1x512x16xf4E2M1FN, 8192x16x1> scaled by - !migraphx.shaped<1x512x16xf8E8M0FNU, 8192x16x1> - -> !migraphx.shaped<1x16x16xf32, 256x16x1> +// ---- migraphx.mul ---- + +// CHECK-LABEL: func.func @migraphx_mul +// CHECK-NEXT: migraphx.mul +func.func @migraphx_mul(%arg0: !migraphx.shaped<4x8xf32, 8x1>, %arg1: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { + %0 = migraphx.mul %arg0, %arg1 : <4x8xf32, 8x1>, <4x8xf32, 8x1> -> <4x8xf32, 8x1> + return %0 : !migraphx.shaped<4x8xf32, 8x1> +} + +// ---- migraphx.div ---- + +// CHECK-LABEL: func.func @migraphx_div +// CHECK-NEXT: migraphx.div +func.func @migraphx_div(%arg0: !migraphx.shaped<4x8xf32, 8x1>, %arg1: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { + %0 = migraphx.div %arg0, %arg1 : <4x8xf32, 8x1>, <4x8xf32, 8x1> -> <4x8xf32, 8x1> + return %0 : !migraphx.shaped<4x8xf32, 8x1> +} + +// ---- migraphx.pow ---- + +// CHECK-LABEL: func.func @migraphx_pow +// CHECK-NEXT: migraphx.pow +func.func @migraphx_pow(%arg0: !migraphx.shaped<4x8xf32, 8x1>, %arg1: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { + %0 = migraphx.pow %arg0, %arg1 : <4x8xf32, 8x1>, <4x8xf32, 8x1> -> <4x8xf32, 8x1> + return %0 : !migraphx.shaped<4x8xf32, 8x1> +} + +// ---- migraphx.greater ---- + +// CHECK-LABEL: func.func @migraphx_greater +// CHECK-NEXT: migraphx.greater +func.func @migraphx_greater(%arg0: !migraphx.shaped<4x8xf32, 8x1>, %arg1: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { + %0 = migraphx.greater %arg0, %arg1 : <4x8xf32, 8x1>, <4x8xf32, 8x1> -> <4x8xf32, 8x1> + return %0 : !migraphx.shaped<4x8xf32, 8x1> +} + +// ---- migraphx.equal ---- + +// CHECK-LABEL: func.func @migraphx_equal +// CHECK-NEXT: migraphx.equal +func.func @migraphx_equal(%arg0: !migraphx.shaped<4x8xi8, 8x1>, %arg1: !migraphx.shaped<4x8xi8, 8x1>) -> !migraphx.shaped<4x8xi8, 8x1> { + %0 = migraphx.equal %arg0, %arg1 : <4x8xi8, 8x1>, <4x8xi8, 8x1> -> <4x8xi8, 8x1> + return %0 : !migraphx.shaped<4x8xi8, 8x1> +} + +// ---- migraphx.clip ---- + +// CHECK-LABEL: func.func @migraphx_clip +// CHECK-NEXT: migraphx.clip +func.func @migraphx_clip(%arg0: !migraphx.shaped<4x8xf32, 8x1>, %arg1: !migraphx.shaped<4x8xf32, 8x1>, %arg2: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { + %0 = migraphx.clip %arg0, %arg1, %arg2 : <4x8xf32, 8x1>, <4x8xf32, 8x1>, <4x8xf32, 8x1> -> <4x8xf32, 8x1> + return %0 : !migraphx.shaped<4x8xf32, 8x1> +} + +// ---- migraphx.where ---- + +// CHECK-LABEL: func.func @migraphx_where +// CHECK-NEXT: migraphx.where +func.func @migraphx_where(%cond: !migraphx.shaped<4x8xi8, 8x1>, %arg0: !migraphx.shaped<4x8xf32, 8x1>, %arg1: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { + %0 = migraphx.where %cond, %arg0, %arg1 : <4x8xi8, 8x1>, <4x8xf32, 8x1>, <4x8xf32, 8x1> -> <4x8xf32, 8x1> + return %0 : !migraphx.shaped<4x8xf32, 8x1> +} + +// ---- migraphx.convert ---- + +// CHECK-LABEL: func.func @migraphx_convert +// CHECK-NEXT: migraphx.convert +func.func @migraphx_convert(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf16, 8x1> { + %0 = migraphx.convert %arg0 : <4x8xf32, 8x1> to <4x8xf16, 8x1> + return %0 : !migraphx.shaped<4x8xf16, 8x1> +} + +// ---- migraphx.abs ---- + +// CHECK-LABEL: func.func @migraphx_abs +// CHECK-NEXT: migraphx.abs +func.func @migraphx_abs(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { + %0 = migraphx.abs %arg0 : <4x8xf32, 8x1> -> <4x8xf32, 8x1> + return %0 : !migraphx.shaped<4x8xf32, 8x1> +} + +// ---- migraphx.ceil ---- + +// CHECK-LABEL: func.func @migraphx_ceil +// CHECK-NEXT: migraphx.ceil +func.func @migraphx_ceil(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { + %0 = migraphx.ceil %arg0 : <4x8xf32, 8x1> -> <4x8xf32, 8x1> + return %0 : !migraphx.shaped<4x8xf32, 8x1> +} + +// ---- migraphx.exp ---- + +// CHECK-LABEL: func.func @migraphx_exp +// CHECK-NEXT: migraphx.exp +func.func @migraphx_exp(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { + %0 = migraphx.exp %arg0 : <4x8xf32, 8x1> -> <4x8xf32, 8x1> + return %0 : !migraphx.shaped<4x8xf32, 8x1> +} + +// ---- migraphx.neg ---- + +// CHECK-LABEL: func.func @migraphx_neg +// CHECK-NEXT: migraphx.neg +func.func @migraphx_neg(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { + %0 = migraphx.neg %arg0 : <4x8xf32, 8x1> -> <4x8xf32, 8x1> + return %0 : !migraphx.shaped<4x8xf32, 8x1> +} + +// ---- migraphx.recip ---- + +// CHECK-LABEL: func.func @migraphx_recip +// CHECK-NEXT: migraphx.recip +func.func @migraphx_recip(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { + %0 = migraphx.recip %arg0 : <4x8xf32, 8x1> -> <4x8xf32, 8x1> + return %0 : !migraphx.shaped<4x8xf32, 8x1> +} + +// ---- migraphx.relu ---- + +// CHECK-LABEL: func.func @migraphx_relu +// CHECK-NEXT: migraphx.relu +func.func @migraphx_relu(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { + %0 = migraphx.relu %arg0 : <4x8xf32, 8x1> -> <4x8xf32, 8x1> + return %0 : !migraphx.shaped<4x8xf32, 8x1> +} + +// ---- migraphx.sigmoid ---- + +// CHECK-LABEL: func.func @migraphx_sigmoid +// CHECK-NEXT: migraphx.sigmoid +func.func @migraphx_sigmoid(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { + %0 = migraphx.sigmoid %arg0 : <4x8xf32, 8x1> -> <4x8xf32, 8x1> + return %0 : !migraphx.shaped<4x8xf32, 8x1> +} + +// ---- migraphx.sqrt ---- + +// CHECK-LABEL: func.func @migraphx_sqrt +// CHECK-NEXT: migraphx.sqrt +func.func @migraphx_sqrt(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { + %0 = migraphx.sqrt %arg0 : <4x8xf32, 8x1> -> <4x8xf32, 8x1> + return %0 : !migraphx.shaped<4x8xf32, 8x1> +} + +// ---- migraphx.tanh ---- + +// CHECK-LABEL: func.func @migraphx_tanh +// CHECK-NEXT: migraphx.tanh +func.func @migraphx_tanh(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { + %0 = migraphx.tanh %arg0 : <4x8xf32, 8x1> -> <4x8xf32, 8x1> + return %0 : !migraphx.shaped<4x8xf32, 8x1> +} + +// ---- migraphx.dot ---- + +// CHECK-LABEL: func.func @migraphx_dot +// CHECK-NEXT: migraphx.dot +func.func @migraphx_dot(%arg0: !migraphx.shaped<1x16x512xf4E2M1FN, 8192x512x1>, %arg1: !migraphx.shaped<1x512x16xf4E2M1FN, 8192x16x1>) -> !migraphx.shaped<1x16x16xf32, 256x16x1> { + %0 = migraphx.dot %arg0, %arg1 : <1x16x512xf4E2M1FN, 8192x512x1>, <1x512x16xf4E2M1FN, 8192x16x1> -> <1x16x16xf32, 256x16x1> return %0 : !migraphx.shaped<1x16x16xf32, 256x16x1> } -// Checking to see if the verifier allows for broadcast // CHECK-LABEL: func.func @migraphx_dot_no_batch_b // CHECK-NEXT: migraphx.dot func.func @migraphx_dot_no_batch_b(%arg0: !migraphx.shaped<3x2x2x2xf16, 8x4x2x1>, %arg1: !migraphx.shaped<2x2xf16, 2x1>) -> !migraphx.shaped<3x2x2x2xf16, 8x4x2x1> { @@ -46,3 +202,19 @@ func.func @migraphx_dot_leading_ones_b_rank4(%arg0: !migraphx.shaped<3x2x2x2xf16 %0 = migraphx.dot %arg0, %arg1 : <3x2x2x2xf16, 8x4x2x1>, <1x1x2x2xf16, 4x2x1x1> -> <3x2x2x2xf16, 8x4x2x1> return %0 : !migraphx.shaped<3x2x2x2xf16, 8x4x2x1> } + +// ---- migraphx.quant_dot ---- + +// CHECK-LABEL: func.func @migraphx_quant_dot_scaled +// CHECK-NEXT: migraphx.quant_dot +func.func @migraphx_quant_dot_scaled(%arg0: !migraphx.shaped<1x16x512xf4E2M1FN, 8192x512x1>, %arg1: !migraphx.shaped<1x512x16xf4E2M1FN, 8192x16x1>, %arg2: !migraphx.shaped<1x16x512xf8E8M0FNU, 8192x512x1>, %arg3: !migraphx.shaped<1x512x16xf8E8M0FNU, 8192x16x1>) -> !migraphx.shaped<1x16x16xf32, 256x16x1> { + %0 = migraphx.quant_dot + %arg0 scaled by %arg2, + %arg1 scaled by %arg3 + : !migraphx.shaped<1x16x512xf4E2M1FN, 8192x512x1> scaled by + !migraphx.shaped<1x16x512xf8E8M0FNU, 8192x512x1>, + !migraphx.shaped<1x512x16xf4E2M1FN, 8192x16x1> scaled by + !migraphx.shaped<1x512x16xf8E8M0FNU, 8192x16x1> + -> !migraphx.shaped<1x16x16xf32, 256x16x1> + return %0 : !migraphx.shaped<1x16x16xf32, 256x16x1> +}