From dd13fcf5bea079fec9c6e5bdb52ebf0dee1e7921 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Wed, 8 Apr 2026 22:20:30 +0000 Subject: [PATCH 1/6] Add Elementwise trait to MIGraphX dialect elementwise ops Mark all MIGraphX elementwise operations (unary, binary, clip, where, convert) with the MLIR Elementwise trait. This enables programmatic querying of elementwise ops via op.hasTrait() instead of exhaustive isa<> lists. Also add AllShapesMatch<["inA", "output"]> to unary ops and convert to enforce that elementwise unary operations preserve shape. Made-with: Cursor --- .../mlir/Dialect/MIGraphX/IR/MIGraphX.td | 32 ++-- mlir/test/Dialect/MIGraphX/invalid.mlir | 32 ++++ mlir/test/Dialect/MIGraphX/ops.mlir | 137 ++++++++++++++++++ 3 files changed, 185 insertions(+), 16 deletions(-) diff --git a/mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphX.td b/mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphX.td index 9e5006c2d6a6..274dcc33b31b 100644 --- a/mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphX.td +++ b/mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphX.td @@ -58,7 +58,8 @@ def MIGraphX_LiteralOp : MIGraphX_Op<"literal", class MIGraphX_ElementwiseBinaryOp traits = []> : MIGraphX_Op< name, !listconcat( - traits, [AllElementTypesMatch<["inA", "inB", "output"]>])>, + traits, [Elementwise, + AllElementTypesMatch<["inA", "inB", "output"]>])>, Arguments<(ins AnyMIXRShaped:$inA, AnyMIXRShaped:$inB)>, Results<(outs AnyMIXRShaped:$output)> { let summary = "Elementwise " # name # " of two shaped values with broadcast"; @@ -94,12 +95,10 @@ def MIGraphX_Equal : MIGraphX_ElementwiseBinaryOp<"equal"> { }]; } -def MIGraphX_ClipOp : - MIGraphX_Op<"clip">, - Arguments<(ins AnyMIXRShaped:$x, - AnyMIXRShaped:$minVals, - AnyMIXRShaped:$maxVals)>, - Results<(outs AnyMIXRShaped:$output)> { +def MIGraphX_ClipOp : MIGraphX_Op<"clip", [Elementwise]>, + Arguments<(ins AnyMIXRShaped:$x, AnyMIXRShaped:$minVals, + AnyMIXRShaped:$maxVals)>, + Results<(outs AnyMIXRShaped:$output)> { let summary = "Elementwise clip"; let description = [{ Elementwise clip: output = min(max(x, minVals), maxVals) @@ -113,7 +112,7 @@ def MIGraphX_ClipOp : // Note: when lowering to kernel calls, MIGraphX represents booleans as i8. // Keep that logic here. def MIGraphX_WhereOp - : MIGraphX_Op<"where", [AllElementTypesMatch<["inA", "inB", "output"]>, + : MIGraphX_Op<"where", [Elementwise, AllElementTypesMatch<["inA", "inB", "output"]>, AllShapesMatch<["inA", "inB", "output", "cond"]>]>, Arguments<(ins MIXRShapedOf<[I8, SI8, UI8]>:$cond, AnyMIXRShaped:$inA, AnyMIXRShaped:$inB)>, @@ -130,10 +129,10 @@ def MIGraphX_WhereOp // Elementwise unary operations -def MIGraphX_ConvertOp : - MIGraphX_Op<"convert">, - Arguments<(ins AnyMIXRShaped:$inA)>, - Results<(outs AnyMIXRShaped:$output)> { +def MIGraphX_ConvertOp : MIGraphX_Op<"convert", + [Elementwise, AllShapesMatch<["inA", "output"]>]>, + Arguments<(ins AnyMIXRShaped:$inA)>, + Results<(outs AnyMIXRShaped:$output)> { let summary = "Elementwise type conversion"; let description = [{ Type conversion. Due to impedance mismatches between MIGraphX and Tosa, @@ -142,10 +141,11 @@ def MIGraphX_ConvertOp : let assemblyFormat = "$inA attr-dict `:` type($inA) `to` type($output)"; } -class MIGraphX_ElementwiseUnaryOp traits=[]> : - MIGraphX_Op, - Arguments<(ins AnyMIXRShaped:$inA)>, - Results<(outs AnyMIXRShaped:$output)> { +class MIGraphX_ElementwiseUnaryOp traits = []> + : MIGraphX_Op])>, + Arguments<(ins AnyMIXRShaped:$inA)>, + Results<(outs AnyMIXRShaped:$output)> { let summary = "Elementwise " # name; let assemblyFormat = [{ $inA attr-dict `:` type($inA) `->` type($output) diff --git a/mlir/test/Dialect/MIGraphX/invalid.mlir b/mlir/test/Dialect/MIGraphX/invalid.mlir index e2c48ca3b5e5..49bf2db3edd3 100644 --- a/mlir/test/Dialect/MIGraphX/invalid.mlir +++ b/mlir/test/Dialect/MIGraphX/invalid.mlir @@ -316,3 +316,35 @@ func.func @where_mismatched_shapes(%arg0: !migraphx.shaped<4x4xi8, 4x1>, %arg1: %0 = migraphx.where %arg0, %arg1, %arg2 : <4x4xi8, 4x1>, <4x8xf32, 8x1>, <4x8xf32, 8x1> -> <4x8xf32, 8x1> return %0 : !migraphx.shaped<4x8xf32, 8x1> } + +// ----- + +func.func @abs_shape_mismatch(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x4xf32, 4x1> { + // expected-error @+1 {{op failed to verify that all of {inA, output} have same shape}} + %0 = migraphx.abs %arg0 : <4x8xf32, 8x1> -> <4x4xf32, 4x1> + return %0 : !migraphx.shaped<4x4xf32, 4x1> +} + +// ----- + +func.func @relu_shape_mismatch(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<2x8xf32, 8x1> { + // expected-error @+1 {{op failed to verify that all of {inA, output} have same shape}} + %0 = migraphx.relu %arg0 : <4x8xf32, 8x1> -> <2x8xf32, 8x1> + return %0 : !migraphx.shaped<2x8xf32, 8x1> +} + +// ----- + +func.func @convert_shape_mismatch(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x4xf16, 4x1> { + // expected-error @+1 {{op failed to verify that all of {inA, output} have same shape}} + %0 = migraphx.convert %arg0 : <4x8xf32, 8x1> to <4x4xf16, 4x1> + return %0 : !migraphx.shaped<4x4xf16, 4x1> +} + +// ----- + +func.func @exp_rank_mismatch(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<32xf32, 1> { + // expected-error @+1 {{op failed to verify that all of {inA, output} have same shape}} + %0 = migraphx.exp %arg0 : <4x8xf32, 8x1> -> <32xf32, 1> + return %0 : !migraphx.shaped<32xf32, 1> +} diff --git a/mlir/test/Dialect/MIGraphX/ops.mlir b/mlir/test/Dialect/MIGraphX/ops.mlir index f22240c07f9d..adf50d8bf4ac 100644 --- a/mlir/test/Dialect/MIGraphX/ops.mlir +++ b/mlir/test/Dialect/MIGraphX/ops.mlir @@ -46,3 +46,140 @@ func.func @migraphx_dot_leading_ones_b_rank4(%arg0: !migraphx.shaped<3x2x2x2xf16 %0 = migraphx.dot %arg0, %arg1 : <3x2x2x2xf16, 8x4x2x1>, <1x1x2x2xf16, 4x2x1x1> -> <3x2x2x2xf16, 8x4x2x1> return %0 : !migraphx.shaped<3x2x2x2xf16, 8x4x2x1> } + +// ---- Elementwise binary ops ---- + +// CHECK-LABEL: func.func @migraphx_add +// CHECK-NEXT: migraphx.add +func.func @migraphx_add(%arg0: !migraphx.shaped<4x8xf32, 8x1>, %arg1: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { + %0 = migraphx.add %arg0, %arg1 : <4x8xf32, 8x1>, <4x8xf32, 8x1> -> <4x8xf32, 8x1> + return %0 : !migraphx.shaped<4x8xf32, 8x1> +} + +// CHECK-LABEL: func.func @migraphx_sub +// CHECK-NEXT: migraphx.sub +func.func @migraphx_sub(%arg0: !migraphx.shaped<4x8xf32, 8x1>, %arg1: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { + %0 = migraphx.sub %arg0, %arg1 : <4x8xf32, 8x1>, <4x8xf32, 8x1> -> <4x8xf32, 8x1> + return %0 : !migraphx.shaped<4x8xf32, 8x1> +} + +// CHECK-LABEL: func.func @migraphx_mul +// CHECK-NEXT: migraphx.mul +func.func @migraphx_mul(%arg0: !migraphx.shaped<4x8xf32, 8x1>, %arg1: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { + %0 = migraphx.mul %arg0, %arg1 : <4x8xf32, 8x1>, <4x8xf32, 8x1> -> <4x8xf32, 8x1> + return %0 : !migraphx.shaped<4x8xf32, 8x1> +} + +// CHECK-LABEL: func.func @migraphx_div +// CHECK-NEXT: migraphx.div +func.func @migraphx_div(%arg0: !migraphx.shaped<4x8xf32, 8x1>, %arg1: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { + %0 = migraphx.div %arg0, %arg1 : <4x8xf32, 8x1>, <4x8xf32, 8x1> -> <4x8xf32, 8x1> + return %0 : !migraphx.shaped<4x8xf32, 8x1> +} + +// CHECK-LABEL: func.func @migraphx_pow +// CHECK-NEXT: migraphx.pow +func.func @migraphx_pow(%arg0: !migraphx.shaped<4x8xf32, 8x1>, %arg1: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { + %0 = migraphx.pow %arg0, %arg1 : <4x8xf32, 8x1>, <4x8xf32, 8x1> -> <4x8xf32, 8x1> + return %0 : !migraphx.shaped<4x8xf32, 8x1> +} + +// CHECK-LABEL: func.func @migraphx_greater +// CHECK-NEXT: migraphx.greater +func.func @migraphx_greater(%arg0: !migraphx.shaped<4x8xf32, 8x1>, %arg1: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { + %0 = migraphx.greater %arg0, %arg1 : <4x8xf32, 8x1>, <4x8xf32, 8x1> -> <4x8xf32, 8x1> + return %0 : !migraphx.shaped<4x8xf32, 8x1> +} + +// CHECK-LABEL: func.func @migraphx_equal +// CHECK-NEXT: migraphx.equal +func.func @migraphx_equal(%arg0: !migraphx.shaped<4x8xi8, 8x1>, %arg1: !migraphx.shaped<4x8xi8, 8x1>) -> !migraphx.shaped<4x8xi8, 8x1> { + %0 = migraphx.equal %arg0, %arg1 : <4x8xi8, 8x1>, <4x8xi8, 8x1> -> <4x8xi8, 8x1> + return %0 : !migraphx.shaped<4x8xi8, 8x1> +} + +// CHECK-LABEL: func.func @migraphx_clip +// CHECK-NEXT: migraphx.clip +func.func @migraphx_clip(%arg0: !migraphx.shaped<4x8xf32, 8x1>, %arg1: !migraphx.shaped<4x8xf32, 8x1>, %arg2: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { + %0 = migraphx.clip %arg0, %arg1, %arg2 : <4x8xf32, 8x1>, <4x8xf32, 8x1>, <4x8xf32, 8x1> -> <4x8xf32, 8x1> + return %0 : !migraphx.shaped<4x8xf32, 8x1> +} + +// CHECK-LABEL: func.func @migraphx_where +// CHECK-NEXT: migraphx.where +func.func @migraphx_where(%cond: !migraphx.shaped<4x8xi8, 8x1>, %arg0: !migraphx.shaped<4x8xf32, 8x1>, %arg1: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { + %0 = migraphx.where %cond, %arg0, %arg1 : <4x8xi8, 8x1>, <4x8xf32, 8x1>, <4x8xf32, 8x1> -> <4x8xf32, 8x1> + return %0 : !migraphx.shaped<4x8xf32, 8x1> +} + +// ---- Elementwise unary ops ---- + +// CHECK-LABEL: func.func @migraphx_abs +// CHECK-NEXT: migraphx.abs +func.func @migraphx_abs(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { + %0 = migraphx.abs %arg0 : <4x8xf32, 8x1> -> <4x8xf32, 8x1> + return %0 : !migraphx.shaped<4x8xf32, 8x1> +} + +// CHECK-LABEL: func.func @migraphx_ceil +// CHECK-NEXT: migraphx.ceil +func.func @migraphx_ceil(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { + %0 = migraphx.ceil %arg0 : <4x8xf32, 8x1> -> <4x8xf32, 8x1> + return %0 : !migraphx.shaped<4x8xf32, 8x1> +} + +// CHECK-LABEL: func.func @migraphx_exp +// CHECK-NEXT: migraphx.exp +func.func @migraphx_exp(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { + %0 = migraphx.exp %arg0 : <4x8xf32, 8x1> -> <4x8xf32, 8x1> + return %0 : !migraphx.shaped<4x8xf32, 8x1> +} + +// CHECK-LABEL: func.func @migraphx_neg +// CHECK-NEXT: migraphx.neg +func.func @migraphx_neg(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { + %0 = migraphx.neg %arg0 : <4x8xf32, 8x1> -> <4x8xf32, 8x1> + return %0 : !migraphx.shaped<4x8xf32, 8x1> +} + +// CHECK-LABEL: func.func @migraphx_recip +// CHECK-NEXT: migraphx.recip +func.func @migraphx_recip(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { + %0 = migraphx.recip %arg0 : <4x8xf32, 8x1> -> <4x8xf32, 8x1> + return %0 : !migraphx.shaped<4x8xf32, 8x1> +} + +// CHECK-LABEL: func.func @migraphx_relu +// CHECK-NEXT: migraphx.relu +func.func @migraphx_relu(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { + %0 = migraphx.relu %arg0 : <4x8xf32, 8x1> -> <4x8xf32, 8x1> + return %0 : !migraphx.shaped<4x8xf32, 8x1> +} + +// CHECK-LABEL: func.func @migraphx_sigmoid +// CHECK-NEXT: migraphx.sigmoid +func.func @migraphx_sigmoid(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { + %0 = migraphx.sigmoid %arg0 : <4x8xf32, 8x1> -> <4x8xf32, 8x1> + return %0 : !migraphx.shaped<4x8xf32, 8x1> +} + +// CHECK-LABEL: func.func @migraphx_sqrt +// CHECK-NEXT: migraphx.sqrt +func.func @migraphx_sqrt(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { + %0 = migraphx.sqrt %arg0 : <4x8xf32, 8x1> -> <4x8xf32, 8x1> + return %0 : !migraphx.shaped<4x8xf32, 8x1> +} + +// CHECK-LABEL: func.func @migraphx_tanh +// CHECK-NEXT: migraphx.tanh +func.func @migraphx_tanh(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { + %0 = migraphx.tanh %arg0 : <4x8xf32, 8x1> -> <4x8xf32, 8x1> + return %0 : !migraphx.shaped<4x8xf32, 8x1> +} + +// CHECK-LABEL: func.func @migraphx_convert +// CHECK-NEXT: migraphx.convert +func.func @migraphx_convert(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf16, 8x1> { + %0 = migraphx.convert %arg0 : <4x8xf32, 8x1> to <4x8xf16, 8x1> + return %0 : !migraphx.shaped<4x8xf16, 8x1> +} From 2144972b880bf4bc29a3b73a66d7f9b97f692d2c Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Wed, 8 Apr 2026 22:27:51 +0000 Subject: [PATCH 2/6] [NFC] Group MIGraphX dialect tests by op Reorganize ops.mlir and invalid.mlir so that all tests for a given op are contiguous. Each op now has a section header comment and tests are ordered following the TD file definition order. Previously the first split of invalid.mlir mixed reshape, equal, and quant_dot tests together, and sigmoid was separated from other unary op tests. Made-with: Cursor --- mlir/test/Dialect/MIGraphX/invalid.mlir | 182 ++++++++++++++---------- mlir/test/Dialect/MIGraphX/ops.mlir | 139 +++++++++++------- 2 files changed, 191 insertions(+), 130 deletions(-) diff --git a/mlir/test/Dialect/MIGraphX/invalid.mlir b/mlir/test/Dialect/MIGraphX/invalid.mlir index 49bf2db3edd3..bbf8f91ca760 100644 --- a/mlir/test/Dialect/MIGraphX/invalid.mlir +++ b/mlir/test/Dialect/MIGraphX/invalid.mlir @@ -1,5 +1,23 @@ // RUN: rocmlir-opt %s -split-input-file -verify-diagnostics +// ---- migraphx.shaped type ---- + +// expected-error @+1 {{migraphx.shaped type has 1 elements in its shape but 2 strides defined}} +func.func @invalid_more_strides_than_shapes(%arg: !migraphx.shaped<1xf32, 1x1>) { + func.return +} + +// ----- + +// expected-error @+1 {{migraphx.shaped type has 2 elements in its shape but 1 strides defined}} +func.func @invalid_more_shapes_than_strides(%arg: !migraphx.shaped<1x1xf32, 1>) { + func.return +} + +// ----- + +// ---- migraphx.reshape ---- + func.func @mlir_reshape_inconsistent_dims(%arg0: !migraphx.shaped<4096x4096xf16, 0x1>) { // expected-error@+1 {{'migraphx.reshape' op dimValue: 64 inconsistent with result dimension 4096}} %0 = migraphx.reshape %arg0 {dims = [64, 128]} : <4096x4096xf16, 0x1> -> <4096x4096xf16, 16536x2> @@ -42,6 +60,10 @@ func.func @mlir_neg_one_with_zero(%arg0: !migraphx.shaped<2x4xf16, 0x1>) { return } +// ----- + +// ---- migraphx.equal ---- + func.func @func_equal(%arg0: !migraphx.shaped<1x36x384x64xi32, 884736x24576x64x1>) -> !migraphx.shaped<1x36x384x64xi32, 884736x24576x64x1> attributes{kernel, arch = ""} { %cst = migraphx.literal (dense<1> : tensor<1x36x384x64xi32>) : <1x36x384x64xi32, 884736x24576x64x1> %0 = migraphx.add %arg0, %cst : <1x36x384x64xi32, 884736x24576x64x1>, <1x36x384x64xi32, 884736x24576x64x1> -> <1x36x384x64xi32, 884736x24576x64x1> @@ -50,6 +72,86 @@ func.func @func_equal(%arg0: !migraphx.shaped<1x36x384x64xi32, 884736x24576x64x1 return %1 : !migraphx.shaped<1x36x384x64xi16, 884736x24576x64x1> } +// ----- + +// ---- migraphx.where ---- + +func.func @where_cond_not_bool(%arg0: !migraphx.shaped<4x4xf32, 4x1>, %arg1: !migraphx.shaped<4x4xf32, 4x1>, %arg2: !migraphx.shaped<4x4xf32, 4x1>) -> !migraphx.shaped<4x4xf32, 4x1> { + // expected-error @+1 {{'migraphx.where' op operand #0 must be !migraphx.shaped of 8-bit signless integer or 8-bit signed integer or 8-bit unsigned integer values, but got '!migraphx.shaped<4x4xf32, 4x1>'}} + %0 = migraphx.where %arg0, %arg1, %arg2 : <4x4xf32, 4x1>, <4x4xf32, 4x1>, <4x4xf32, 4x1> -> <4x4xf32, 4x1> + return %0 : !migraphx.shaped<4x4xf32, 4x1> +} + +// ----- + +func.func @where_mismatched_types(%arg0: !migraphx.shaped<4x4xi8, 4x1>, %arg1: !migraphx.shaped<4x4xf32, 4x1>, %arg2: !migraphx.shaped<4x4xf16, 4x1>) -> !migraphx.shaped<4x4xf32, 4x1> { + // expected-error @+1 {{op failed to verify that all of {inA, inB, output} have same element type}} + %0 = migraphx.where %arg0, %arg1, %arg2 : <4x4xi8, 4x1>, <4x4xf32, 4x1>, <4x4xf16, 4x1> -> <4x4xf32, 4x1> + return %0 : !migraphx.shaped<4x4xf32, 4x1> +} + +// ----- + +func.func @where_mismatched_shapes(%arg0: !migraphx.shaped<4x4xi8, 4x1>, %arg1: !migraphx.shaped<4x8xf32, 8x1>, %arg2: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { + // expected-error @+1 {{op failed to verify that all of {inA, inB, output, cond} have same shape}} + %0 = migraphx.where %arg0, %arg1, %arg2 : <4x4xi8, 4x1>, <4x8xf32, 8x1>, <4x8xf32, 8x1> -> <4x8xf32, 8x1> + return %0 : !migraphx.shaped<4x8xf32, 8x1> +} + +// ----- + +// ---- migraphx.convert ---- + +func.func @convert_shape_mismatch(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x4xf16, 4x1> { + // expected-error @+1 {{op failed to verify that all of {inA, output} have same shape}} + %0 = migraphx.convert %arg0 : <4x8xf32, 8x1> to <4x4xf16, 4x1> + return %0 : !migraphx.shaped<4x4xf16, 4x1> +} + +// ----- + +// ---- migraphx.abs ---- + +func.func @abs_shape_mismatch(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x4xf32, 4x1> { + // expected-error @+1 {{op failed to verify that all of {inA, output} have same shape}} + %0 = migraphx.abs %arg0 : <4x8xf32, 8x1> -> <4x4xf32, 4x1> + return %0 : !migraphx.shaped<4x4xf32, 4x1> +} + +// ----- + +// ---- migraphx.exp ---- + +func.func @exp_rank_mismatch(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<32xf32, 1> { + // expected-error @+1 {{op failed to verify that all of {inA, output} have same shape}} + %0 = migraphx.exp %arg0 : <4x8xf32, 8x1> -> <32xf32, 1> + return %0 : !migraphx.shaped<32xf32, 1> +} + +// ----- + +// ---- migraphx.relu ---- + +func.func @relu_shape_mismatch(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<2x8xf32, 8x1> { + // expected-error @+1 {{op failed to verify that all of {inA, output} have same shape}} + %0 = migraphx.relu %arg0 : <4x8xf32, 8x1> -> <2x8xf32, 8x1> + return %0 : !migraphx.shaped<2x8xf32, 8x1> +} + +// ----- + +// ---- migraphx.sigmoid ---- + +func.func @func_sigmoid_2d_i32(%arg0: !migraphx.shaped<4x8xi32, 8x1>) -> !migraphx.shaped<4x8xi32, 8x1> { + // expected-error @+1 {{only support floating point}} + %0 = migraphx.sigmoid %arg0 : <4x8xi32, 8x1> -> <4x8xi32, 8x1> + return %0 : !migraphx.shaped<4x8xi32, 8x1> +} + +// ----- + +// ---- migraphx.quant_dot ---- + // Test: Only scaleA provided (should fail - both scales required) func.func @quant_dot_only_scale_a( %arg0: !migraphx.shaped<1x16x512xf4E2M1FN, 8192x512x1>, @@ -228,6 +330,8 @@ func.func @migraphx_quant_dot_f4_n_scales(%arg0: !migraphx.shaped<1x16x512xf4E2M // ----- +// ---- migraphx.dot ---- + // CHECK-LABEL: func.func @dot_rank_less_than_2 func.func @dot_rank_less_than_2(%arg0: !migraphx.shaped<320xf16, 1>, %arg1: !migraphx.shaped<320x64xf16, 64x1>) -> !migraphx.shaped<64xf16, 1> { // expected-error @+1 {{expect operand to have rank greater or equal to 2}} @@ -270,81 +374,3 @@ func.func @dot_result_shape_mismatch(%arg0: !migraphx.shaped<2x3x4xf16, 12x4x1>, %0 = migraphx.dot %arg0, %arg1 : <2x3x4xf16, 12x4x1>, <2x4x5xf16, 20x5x1> -> <2x3x4xf16, 12x4x1> return %0 : !migraphx.shaped<2x3x4xf16, 12x4x1> } - -// ----- - -// expected-error @+1 {{migraphx.shaped type has 1 elements in its shape but 2 strides defined}} -func.func @invalid_more_strides_than_shapes(%arg: !migraphx.shaped<1xf32, 1x1>) { - func.return -} - -// ----- - -// expected-error @+1 {{migraphx.shaped type has 2 elements in its shape but 1 strides defined}} -func.func @invalid_more_shapes_than_strides(%arg: !migraphx.shaped<1x1xf32, 1>) { - func.return -} - -// ----- - -func.func @func_sigmoid_2d_i32(%arg0: !migraphx.shaped<4x8xi32, 8x1>) -> !migraphx.shaped<4x8xi32, 8x1> { - // expected-error @+1 {{only support floating point}} - %0 = migraphx.sigmoid %arg0 : <4x8xi32, 8x1> -> <4x8xi32, 8x1> - return %0 : !migraphx.shaped<4x8xi32, 8x1> -} - -// ----- - -func.func @where_cond_not_bool(%arg0: !migraphx.shaped<4x4xf32, 4x1>, %arg1: !migraphx.shaped<4x4xf32, 4x1>, %arg2: !migraphx.shaped<4x4xf32, 4x1>) -> !migraphx.shaped<4x4xf32, 4x1> { - // expected-error @+1 {{'migraphx.where' op operand #0 must be !migraphx.shaped of 8-bit signless integer or 8-bit signed integer or 8-bit unsigned integer values, but got '!migraphx.shaped<4x4xf32, 4x1>'}} - %0 = migraphx.where %arg0, %arg1, %arg2 : <4x4xf32, 4x1>, <4x4xf32, 4x1>, <4x4xf32, 4x1> -> <4x4xf32, 4x1> - return %0 : !migraphx.shaped<4x4xf32, 4x1> -} - -// ----- - -func.func @where_mismatched_types(%arg0: !migraphx.shaped<4x4xi8, 4x1>, %arg1: !migraphx.shaped<4x4xf32, 4x1>, %arg2: !migraphx.shaped<4x4xf16, 4x1>) -> !migraphx.shaped<4x4xf32, 4x1> { - // expected-error @+1 {{op failed to verify that all of {inA, inB, output} have same element type}} - %0 = migraphx.where %arg0, %arg1, %arg2 : <4x4xi8, 4x1>, <4x4xf32, 4x1>, <4x4xf16, 4x1> -> <4x4xf32, 4x1> - return %0 : !migraphx.shaped<4x4xf32, 4x1> -} - -// ----- - -func.func @where_mismatched_shapes(%arg0: !migraphx.shaped<4x4xi8, 4x1>, %arg1: !migraphx.shaped<4x8xf32, 8x1>, %arg2: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { - // expected-error @+1 {{op failed to verify that all of {inA, inB, output, cond} have same shape}} - %0 = migraphx.where %arg0, %arg1, %arg2 : <4x4xi8, 4x1>, <4x8xf32, 8x1>, <4x8xf32, 8x1> -> <4x8xf32, 8x1> - return %0 : !migraphx.shaped<4x8xf32, 8x1> -} - -// ----- - -func.func @abs_shape_mismatch(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x4xf32, 4x1> { - // expected-error @+1 {{op failed to verify that all of {inA, output} have same shape}} - %0 = migraphx.abs %arg0 : <4x8xf32, 8x1> -> <4x4xf32, 4x1> - return %0 : !migraphx.shaped<4x4xf32, 4x1> -} - -// ----- - -func.func @relu_shape_mismatch(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<2x8xf32, 8x1> { - // expected-error @+1 {{op failed to verify that all of {inA, output} have same shape}} - %0 = migraphx.relu %arg0 : <4x8xf32, 8x1> -> <2x8xf32, 8x1> - return %0 : !migraphx.shaped<2x8xf32, 8x1> -} - -// ----- - -func.func @convert_shape_mismatch(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x4xf16, 4x1> { - // expected-error @+1 {{op failed to verify that all of {inA, output} have same shape}} - %0 = migraphx.convert %arg0 : <4x8xf32, 8x1> to <4x4xf16, 4x1> - return %0 : !migraphx.shaped<4x4xf16, 4x1> -} - -// ----- - -func.func @exp_rank_mismatch(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<32xf32, 1> { - // expected-error @+1 {{op failed to verify that all of {inA, output} have same shape}} - %0 = migraphx.exp %arg0 : <4x8xf32, 8x1> -> <32xf32, 1> - return %0 : !migraphx.shaped<32xf32, 1> -} diff --git a/mlir/test/Dialect/MIGraphX/ops.mlir b/mlir/test/Dialect/MIGraphX/ops.mlir index adf50d8bf4ac..302a28b359f6 100644 --- a/mlir/test/Dialect/MIGraphX/ops.mlir +++ b/mlir/test/Dialect/MIGraphX/ops.mlir @@ -2,52 +2,7 @@ // RUN: rocmlir-opt %s | rocmlir-opt | FileCheck %s // RUN: rocmlir-opt -mlir-print-op-generic %s | rocmlir-opt | FileCheck %s -// CHECK-LABEL: func.func @migraphx_dot -// CHECK-NEXT: migraphx.dot -func.func @migraphx_dot(%arg0: !migraphx.shaped<1x16x512xf4E2M1FN, 8192x512x1>, %arg1: !migraphx.shaped<1x512x16xf4E2M1FN, 8192x16x1>) -> !migraphx.shaped<1x16x16xf32, 256x16x1> { - %0 = migraphx.dot %arg0, %arg1 : <1x16x512xf4E2M1FN, 8192x512x1>, <1x512x16xf4E2M1FN, 8192x16x1> -> <1x16x16xf32, 256x16x1> - return %0 : !migraphx.shaped<1x16x16xf32, 256x16x1> -} - - - -// CHECK-LABEL: func.func @migraphx_quant_dot_scaled -// CHECK-NEXT: migraphx.quant_dot -func.func @migraphx_quant_dot_scaled(%arg0: !migraphx.shaped<1x16x512xf4E2M1FN, 8192x512x1>, %arg1: !migraphx.shaped<1x512x16xf4E2M1FN, 8192x16x1>, %arg2: !migraphx.shaped<1x16x512xf8E8M0FNU, 8192x512x1>, %arg3: !migraphx.shaped<1x512x16xf8E8M0FNU, 8192x16x1>) -> !migraphx.shaped<1x16x16xf32, 256x16x1> { - %0 = migraphx.quant_dot - %arg0 scaled by %arg2, - %arg1 scaled by %arg3 - : !migraphx.shaped<1x16x512xf4E2M1FN, 8192x512x1> scaled by - !migraphx.shaped<1x16x512xf8E8M0FNU, 8192x512x1>, - !migraphx.shaped<1x512x16xf4E2M1FN, 8192x16x1> scaled by - !migraphx.shaped<1x512x16xf8E8M0FNU, 8192x16x1> - -> !migraphx.shaped<1x16x16xf32, 256x16x1> - return %0 : !migraphx.shaped<1x16x16xf32, 256x16x1> -} - -// Checking to see if the verifier allows for broadcast -// CHECK-LABEL: func.func @migraphx_dot_no_batch_b -// CHECK-NEXT: migraphx.dot -func.func @migraphx_dot_no_batch_b(%arg0: !migraphx.shaped<3x2x2x2xf16, 8x4x2x1>, %arg1: !migraphx.shaped<2x2xf16, 2x1>) -> !migraphx.shaped<3x2x2x2xf16, 8x4x2x1> { - %0 = migraphx.dot %arg0, %arg1 : <3x2x2x2xf16, 8x4x2x1>, <2x2xf16, 2x1> -> <3x2x2x2xf16, 8x4x2x1> - return %0 : !migraphx.shaped<3x2x2x2xf16, 8x4x2x1> -} - -// CHECK-LABEL: func.func @migraphx_dot_leading_ones_b_rank3 -// CHECK-NEXT: migraphx.dot -func.func @migraphx_dot_leading_ones_b_rank3(%arg0: !migraphx.shaped<3x2x2x2xf16, 8x4x2x1>, %arg1: !migraphx.shaped<1x2x2xf16, 4x2x1>) -> !migraphx.shaped<3x2x2x2xf16, 8x4x2x1> { - %0 = migraphx.dot %arg0, %arg1 : <3x2x2x2xf16, 8x4x2x1>, <1x2x2xf16, 4x2x1> -> <3x2x2x2xf16, 8x4x2x1> - return %0 : !migraphx.shaped<3x2x2x2xf16, 8x4x2x1> -} - -// CHECK-LABEL: func.func @migraphx_dot_leading_ones_b_rank4 -// CHECK-NEXT: migraphx.dot -func.func @migraphx_dot_leading_ones_b_rank4(%arg0: !migraphx.shaped<3x2x2x2xf16, 8x4x2x1>, %arg1: !migraphx.shaped<1x1x2x2xf16, 4x2x1x1>) -> !migraphx.shaped<3x2x2x2xf16, 8x4x2x1> { - %0 = migraphx.dot %arg0, %arg1 : <3x2x2x2xf16, 8x4x2x1>, <1x1x2x2xf16, 4x2x1x1> -> <3x2x2x2xf16, 8x4x2x1> - return %0 : !migraphx.shaped<3x2x2x2xf16, 8x4x2x1> -} - -// ---- Elementwise binary ops ---- +// ---- migraphx.add ---- // CHECK-LABEL: func.func @migraphx_add // CHECK-NEXT: migraphx.add @@ -56,6 +11,8 @@ func.func @migraphx_add(%arg0: !migraphx.shaped<4x8xf32, 8x1>, %arg1: !migraphx. return %0 : !migraphx.shaped<4x8xf32, 8x1> } +// ---- migraphx.sub ---- + // CHECK-LABEL: func.func @migraphx_sub // CHECK-NEXT: migraphx.sub func.func @migraphx_sub(%arg0: !migraphx.shaped<4x8xf32, 8x1>, %arg1: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { @@ -63,6 +20,8 @@ func.func @migraphx_sub(%arg0: !migraphx.shaped<4x8xf32, 8x1>, %arg1: !migraphx. return %0 : !migraphx.shaped<4x8xf32, 8x1> } +// ---- migraphx.mul ---- + // CHECK-LABEL: func.func @migraphx_mul // CHECK-NEXT: migraphx.mul func.func @migraphx_mul(%arg0: !migraphx.shaped<4x8xf32, 8x1>, %arg1: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { @@ -70,6 +29,8 @@ func.func @migraphx_mul(%arg0: !migraphx.shaped<4x8xf32, 8x1>, %arg1: !migraphx. return %0 : !migraphx.shaped<4x8xf32, 8x1> } +// ---- migraphx.div ---- + // CHECK-LABEL: func.func @migraphx_div // CHECK-NEXT: migraphx.div func.func @migraphx_div(%arg0: !migraphx.shaped<4x8xf32, 8x1>, %arg1: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { @@ -77,6 +38,8 @@ func.func @migraphx_div(%arg0: !migraphx.shaped<4x8xf32, 8x1>, %arg1: !migraphx. return %0 : !migraphx.shaped<4x8xf32, 8x1> } +// ---- migraphx.pow ---- + // CHECK-LABEL: func.func @migraphx_pow // CHECK-NEXT: migraphx.pow func.func @migraphx_pow(%arg0: !migraphx.shaped<4x8xf32, 8x1>, %arg1: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { @@ -84,6 +47,8 @@ func.func @migraphx_pow(%arg0: !migraphx.shaped<4x8xf32, 8x1>, %arg1: !migraphx. return %0 : !migraphx.shaped<4x8xf32, 8x1> } +// ---- migraphx.greater ---- + // CHECK-LABEL: func.func @migraphx_greater // CHECK-NEXT: migraphx.greater func.func @migraphx_greater(%arg0: !migraphx.shaped<4x8xf32, 8x1>, %arg1: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { @@ -91,6 +56,8 @@ func.func @migraphx_greater(%arg0: !migraphx.shaped<4x8xf32, 8x1>, %arg1: !migra return %0 : !migraphx.shaped<4x8xf32, 8x1> } +// ---- migraphx.equal ---- + // CHECK-LABEL: func.func @migraphx_equal // CHECK-NEXT: migraphx.equal func.func @migraphx_equal(%arg0: !migraphx.shaped<4x8xi8, 8x1>, %arg1: !migraphx.shaped<4x8xi8, 8x1>) -> !migraphx.shaped<4x8xi8, 8x1> { @@ -98,6 +65,8 @@ func.func @migraphx_equal(%arg0: !migraphx.shaped<4x8xi8, 8x1>, %arg1: !migraphx return %0 : !migraphx.shaped<4x8xi8, 8x1> } +// ---- migraphx.clip ---- + // CHECK-LABEL: func.func @migraphx_clip // CHECK-NEXT: migraphx.clip func.func @migraphx_clip(%arg0: !migraphx.shaped<4x8xf32, 8x1>, %arg1: !migraphx.shaped<4x8xf32, 8x1>, %arg2: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { @@ -105,6 +74,8 @@ func.func @migraphx_clip(%arg0: !migraphx.shaped<4x8xf32, 8x1>, %arg1: !migraphx return %0 : !migraphx.shaped<4x8xf32, 8x1> } +// ---- migraphx.where ---- + // CHECK-LABEL: func.func @migraphx_where // CHECK-NEXT: migraphx.where func.func @migraphx_where(%cond: !migraphx.shaped<4x8xi8, 8x1>, %arg0: !migraphx.shaped<4x8xf32, 8x1>, %arg1: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { @@ -112,7 +83,16 @@ func.func @migraphx_where(%cond: !migraphx.shaped<4x8xi8, 8x1>, %arg0: !migraphx return %0 : !migraphx.shaped<4x8xf32, 8x1> } -// ---- Elementwise unary ops ---- +// ---- migraphx.convert ---- + +// CHECK-LABEL: func.func @migraphx_convert +// CHECK-NEXT: migraphx.convert +func.func @migraphx_convert(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf16, 8x1> { + %0 = migraphx.convert %arg0 : <4x8xf32, 8x1> to <4x8xf16, 8x1> + return %0 : !migraphx.shaped<4x8xf16, 8x1> +} + +// ---- migraphx.abs ---- // CHECK-LABEL: func.func @migraphx_abs // CHECK-NEXT: migraphx.abs @@ -121,6 +101,8 @@ func.func @migraphx_abs(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shap return %0 : !migraphx.shaped<4x8xf32, 8x1> } +// ---- migraphx.ceil ---- + // CHECK-LABEL: func.func @migraphx_ceil // CHECK-NEXT: migraphx.ceil func.func @migraphx_ceil(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { @@ -128,6 +110,8 @@ func.func @migraphx_ceil(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.sha return %0 : !migraphx.shaped<4x8xf32, 8x1> } +// ---- migraphx.exp ---- + // CHECK-LABEL: func.func @migraphx_exp // CHECK-NEXT: migraphx.exp func.func @migraphx_exp(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { @@ -135,6 +119,8 @@ func.func @migraphx_exp(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shap return %0 : !migraphx.shaped<4x8xf32, 8x1> } +// ---- migraphx.neg ---- + // CHECK-LABEL: func.func @migraphx_neg // CHECK-NEXT: migraphx.neg func.func @migraphx_neg(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { @@ -142,6 +128,8 @@ func.func @migraphx_neg(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shap return %0 : !migraphx.shaped<4x8xf32, 8x1> } +// ---- migraphx.recip ---- + // CHECK-LABEL: func.func @migraphx_recip // CHECK-NEXT: migraphx.recip func.func @migraphx_recip(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { @@ -149,6 +137,8 @@ func.func @migraphx_recip(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.sh return %0 : !migraphx.shaped<4x8xf32, 8x1> } +// ---- migraphx.relu ---- + // CHECK-LABEL: func.func @migraphx_relu // CHECK-NEXT: migraphx.relu func.func @migraphx_relu(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { @@ -156,6 +146,8 @@ func.func @migraphx_relu(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.sha return %0 : !migraphx.shaped<4x8xf32, 8x1> } +// ---- migraphx.sigmoid ---- + // CHECK-LABEL: func.func @migraphx_sigmoid // CHECK-NEXT: migraphx.sigmoid func.func @migraphx_sigmoid(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { @@ -163,6 +155,8 @@ func.func @migraphx_sigmoid(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx. return %0 : !migraphx.shaped<4x8xf32, 8x1> } +// ---- migraphx.sqrt ---- + // CHECK-LABEL: func.func @migraphx_sqrt // CHECK-NEXT: migraphx.sqrt func.func @migraphx_sqrt(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { @@ -170,6 +164,8 @@ func.func @migraphx_sqrt(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.sha return %0 : !migraphx.shaped<4x8xf32, 8x1> } +// ---- migraphx.tanh ---- + // CHECK-LABEL: func.func @migraphx_tanh // CHECK-NEXT: migraphx.tanh func.func @migraphx_tanh(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { @@ -177,9 +173,48 @@ func.func @migraphx_tanh(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.sha return %0 : !migraphx.shaped<4x8xf32, 8x1> } -// CHECK-LABEL: func.func @migraphx_convert -// CHECK-NEXT: migraphx.convert -func.func @migraphx_convert(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf16, 8x1> { - %0 = migraphx.convert %arg0 : <4x8xf32, 8x1> to <4x8xf16, 8x1> - return %0 : !migraphx.shaped<4x8xf16, 8x1> +// ---- migraphx.dot ---- + +// CHECK-LABEL: func.func @migraphx_dot +// CHECK-NEXT: migraphx.dot +func.func @migraphx_dot(%arg0: !migraphx.shaped<1x16x512xf4E2M1FN, 8192x512x1>, %arg1: !migraphx.shaped<1x512x16xf4E2M1FN, 8192x16x1>) -> !migraphx.shaped<1x16x16xf32, 256x16x1> { + %0 = migraphx.dot %arg0, %arg1 : <1x16x512xf4E2M1FN, 8192x512x1>, <1x512x16xf4E2M1FN, 8192x16x1> -> <1x16x16xf32, 256x16x1> + return %0 : !migraphx.shaped<1x16x16xf32, 256x16x1> +} + +// CHECK-LABEL: func.func @migraphx_dot_no_batch_b +// CHECK-NEXT: migraphx.dot +func.func @migraphx_dot_no_batch_b(%arg0: !migraphx.shaped<3x2x2x2xf16, 8x4x2x1>, %arg1: !migraphx.shaped<2x2xf16, 2x1>) -> !migraphx.shaped<3x2x2x2xf16, 8x4x2x1> { + %0 = migraphx.dot %arg0, %arg1 : <3x2x2x2xf16, 8x4x2x1>, <2x2xf16, 2x1> -> <3x2x2x2xf16, 8x4x2x1> + return %0 : !migraphx.shaped<3x2x2x2xf16, 8x4x2x1> +} + +// CHECK-LABEL: func.func @migraphx_dot_leading_ones_b_rank3 +// CHECK-NEXT: migraphx.dot +func.func @migraphx_dot_leading_ones_b_rank3(%arg0: !migraphx.shaped<3x2x2x2xf16, 8x4x2x1>, %arg1: !migraphx.shaped<1x2x2xf16, 4x2x1>) -> !migraphx.shaped<3x2x2x2xf16, 8x4x2x1> { + %0 = migraphx.dot %arg0, %arg1 : <3x2x2x2xf16, 8x4x2x1>, <1x2x2xf16, 4x2x1> -> <3x2x2x2xf16, 8x4x2x1> + return %0 : !migraphx.shaped<3x2x2x2xf16, 8x4x2x1> +} + +// CHECK-LABEL: func.func @migraphx_dot_leading_ones_b_rank4 +// CHECK-NEXT: migraphx.dot +func.func @migraphx_dot_leading_ones_b_rank4(%arg0: !migraphx.shaped<3x2x2x2xf16, 8x4x2x1>, %arg1: !migraphx.shaped<1x1x2x2xf16, 4x2x1x1>) -> !migraphx.shaped<3x2x2x2xf16, 8x4x2x1> { + %0 = migraphx.dot %arg0, %arg1 : <3x2x2x2xf16, 8x4x2x1>, <1x1x2x2xf16, 4x2x1x1> -> <3x2x2x2xf16, 8x4x2x1> + return %0 : !migraphx.shaped<3x2x2x2xf16, 8x4x2x1> +} + +// ---- migraphx.quant_dot ---- + +// CHECK-LABEL: func.func @migraphx_quant_dot_scaled +// CHECK-NEXT: migraphx.quant_dot +func.func @migraphx_quant_dot_scaled(%arg0: !migraphx.shaped<1x16x512xf4E2M1FN, 8192x512x1>, %arg1: !migraphx.shaped<1x512x16xf4E2M1FN, 8192x16x1>, %arg2: !migraphx.shaped<1x16x512xf8E8M0FNU, 8192x512x1>, %arg3: !migraphx.shaped<1x512x16xf8E8M0FNU, 8192x16x1>) -> !migraphx.shaped<1x16x16xf32, 256x16x1> { + %0 = migraphx.quant_dot + %arg0 scaled by %arg2, + %arg1 scaled by %arg3 + : !migraphx.shaped<1x16x512xf4E2M1FN, 8192x512x1> scaled by + !migraphx.shaped<1x16x512xf8E8M0FNU, 8192x512x1>, + !migraphx.shaped<1x512x16xf4E2M1FN, 8192x16x1> scaled by + !migraphx.shaped<1x512x16xf8E8M0FNU, 8192x16x1> + -> !migraphx.shaped<1x16x16xf32, 256x16x1> + return %0 : !migraphx.shaped<1x16x16xf32, 256x16x1> } From db63eb4c2ec45f1f0496947777f07002dafb1bb0 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Wed, 8 Apr 2026 22:27:51 +0000 Subject: [PATCH 3/6] [NFC] Group MIGraphX dialect tests by op Reorganize ops.mlir and invalid.mlir so that all tests for a given op are contiguous. Each op now has a section header comment and tests are ordered following the TD file definition order. Previously the first split of invalid.mlir mixed reshape, equal, and quant_dot tests together, and sigmoid was separated from other unary op tests. Made-with: Cursor --- .../include/mlir/Dialect/MIGraphX/IR/MIGraphX.td | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphX.td b/mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphX.td index 274dcc33b31b..e8c890f9ee23 100644 --- a/mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphX.td +++ b/mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphX.td @@ -112,7 +112,8 @@ def MIGraphX_ClipOp : MIGraphX_Op<"clip", [Elementwise]>, // Note: when lowering to kernel calls, MIGraphX represents booleans as i8. // Keep that logic here. def MIGraphX_WhereOp - : MIGraphX_Op<"where", [Elementwise, AllElementTypesMatch<["inA", "inB", "output"]>, + : MIGraphX_Op<"where", [Elementwise, + AllElementTypesMatch<["inA", "inB", "output"]>, AllShapesMatch<["inA", "inB", "output", "cond"]>]>, Arguments<(ins MIXRShapedOf<[I8, SI8, UI8]>:$cond, AnyMIXRShaped:$inA, AnyMIXRShaped:$inB)>, @@ -129,10 +130,10 @@ def MIGraphX_WhereOp // Elementwise unary operations -def MIGraphX_ConvertOp : MIGraphX_Op<"convert", - [Elementwise, AllShapesMatch<["inA", "output"]>]>, - Arguments<(ins AnyMIXRShaped:$inA)>, - Results<(outs AnyMIXRShaped:$output)> { +def MIGraphX_ConvertOp + : MIGraphX_Op<"convert", [Elementwise, AllShapesMatch<["inA", "output"]>]>, + Arguments<(ins AnyMIXRShaped:$inA)>, + Results<(outs AnyMIXRShaped:$output)> { let summary = "Elementwise type conversion"; let description = [{ Type conversion. Due to impedance mismatches between MIGraphX and Tosa, @@ -142,8 +143,9 @@ def MIGraphX_ConvertOp : MIGraphX_Op<"convert", } class MIGraphX_ElementwiseUnaryOp traits = []> - : MIGraphX_Op])>, + : MIGraphX_Op])>, Arguments<(ins AnyMIXRShaped:$inA)>, Results<(outs AnyMIXRShaped:$output)> { let summary = "Elementwise " # name; From 0e1d425c87f074d462cbd20e9ade8981e342a414 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Thu, 9 Apr 2026 12:39:17 +0000 Subject: [PATCH 4/6] address copilot comments --- .../mlir/Dialect/MIGraphX/IR/MIGraphX.td | 11 +++++++---- mlir/test/Dialect/MIGraphX/invalid.mlir | 18 ++++++++++++++++++ 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphX.td b/mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphX.td index e8c890f9ee23..a25792edd4d2 100644 --- a/mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphX.td +++ b/mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphX.td @@ -95,10 +95,13 @@ def MIGraphX_Equal : MIGraphX_ElementwiseBinaryOp<"equal"> { }]; } -def MIGraphX_ClipOp : MIGraphX_Op<"clip", [Elementwise]>, - Arguments<(ins AnyMIXRShaped:$x, AnyMIXRShaped:$minVals, - AnyMIXRShaped:$maxVals)>, - Results<(outs AnyMIXRShaped:$output)> { +def MIGraphX_ClipOp + : MIGraphX_Op<"clip", [Elementwise, + AllElementTypesMatch<["x", "minVals", "maxVals", "output"]>, + AllShapesMatch<["x", "minVals", "maxVals", "output"]>]>, + Arguments<(ins AnyMIXRShaped:$x, AnyMIXRShaped:$minVals, + AnyMIXRShaped:$maxVals)>, + Results<(outs AnyMIXRShaped:$output)> { let summary = "Elementwise clip"; let description = [{ Elementwise clip: output = min(max(x, minVals), maxVals) diff --git a/mlir/test/Dialect/MIGraphX/invalid.mlir b/mlir/test/Dialect/MIGraphX/invalid.mlir index bbf8f91ca760..ee00d94e2c43 100644 --- a/mlir/test/Dialect/MIGraphX/invalid.mlir +++ b/mlir/test/Dialect/MIGraphX/invalid.mlir @@ -74,6 +74,24 @@ func.func @func_equal(%arg0: !migraphx.shaped<1x36x384x64xi32, 884736x24576x64x1 // ----- +// ---- migraphx.clip ---- + +func.func @clip_mismatched_element_types(%arg0: !migraphx.shaped<4x8xf32, 8x1>, %arg1: !migraphx.shaped<4x8xf16, 8x1>, %arg2: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { + // expected-error @+1 {{op failed to verify that all of {x, minVals, maxVals, output} have same element type}} + %0 = migraphx.clip %arg0, %arg1, %arg2 : <4x8xf32, 8x1>, <4x8xf16, 8x1>, <4x8xf32, 8x1> -> <4x8xf32, 8x1> + return %0 : !migraphx.shaped<4x8xf32, 8x1> +} + +// ----- + +func.func @clip_mismatched_shapes(%arg0: !migraphx.shaped<4x8xf32, 8x1>, %arg1: !migraphx.shaped<4x4xf32, 4x1>, %arg2: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> { + // expected-error @+1 {{op failed to verify that all of {x, minVals, maxVals, output} have same shape}} + %0 = migraphx.clip %arg0, %arg1, %arg2 : <4x8xf32, 8x1>, <4x4xf32, 4x1>, <4x8xf32, 8x1> -> <4x8xf32, 8x1> + return %0 : !migraphx.shaped<4x8xf32, 8x1> +} + +// ----- + // ---- migraphx.where ---- func.func @where_cond_not_bool(%arg0: !migraphx.shaped<4x4xf32, 4x1>, %arg1: !migraphx.shaped<4x4xf32, 4x1>, %arg2: !migraphx.shaped<4x4xf32, 4x1>) -> !migraphx.shaped<4x4xf32, 4x1> { From a31028712dcf5a348dfca9c5e1b073a8d6cb8278 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Thu, 9 Apr 2026 13:25:00 +0000 Subject: [PATCH 5/6] Formatting --- mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphX.td | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphX.td b/mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphX.td index a25792edd4d2..37abd0823f62 100644 --- a/mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphX.td +++ b/mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphX.td @@ -96,9 +96,10 @@ def MIGraphX_Equal : MIGraphX_ElementwiseBinaryOp<"equal"> { } def MIGraphX_ClipOp - : MIGraphX_Op<"clip", [Elementwise, - AllElementTypesMatch<["x", "minVals", "maxVals", "output"]>, - AllShapesMatch<["x", "minVals", "maxVals", "output"]>]>, + : MIGraphX_Op< + "clip", [Elementwise, + AllElementTypesMatch<["x", "minVals", "maxVals", "output"]>, + AllShapesMatch<["x", "minVals", "maxVals", "output"]>]>, Arguments<(ins AnyMIXRShaped:$x, AnyMIXRShaped:$minVals, AnyMIXRShaped:$maxVals)>, Results<(outs AnyMIXRShaped:$output)> { From 34bfdbb361666fb463fd5389e2e7c98c9884707c Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Mon, 13 Apr 2026 13:36:40 +0000 Subject: [PATCH 6/6] remove changes