Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mlir/include/mlir/Conversion/LinalgToRock/LinalgToRock.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ void populateLinalgToRockConversionPattern(RewritePatternSet &pattern,

/// A tensor.insert_slice is said to be a rock.expand_strides
bool isRockExpandStride(tensor::InsertSliceOp op);
}
} // namespace rock
} // namespace mlir

#endif
38 changes: 22 additions & 16 deletions mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphX.td
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ def MIGraphX_LiteralOp : MIGraphX_Op<"literal",
class MIGraphX_ElementwiseBinaryOp<string name, list<Trait> traits = []>
: MIGraphX_Op<
name, !listconcat(
traits, [AllElementTypesMatch<["inA", "inB", "output"]>])>,
traits, [Elementwise,
AllElementTypesMatch<["inA", "inB", "output"]>])>,
Arguments<(ins AnyMIXRShaped:$inA, AnyMIXRShaped:$inB)>,
Results<(outs AnyMIXRShaped:$output)> {
let summary = "Elementwise " # name # " of two shaped values with broadcast";
Expand Down Expand Up @@ -94,12 +95,14 @@ def MIGraphX_Equal : MIGraphX_ElementwiseBinaryOp<"equal"> {
}];
}

def MIGraphX_ClipOp :
MIGraphX_Op<"clip">,
Arguments<(ins AnyMIXRShaped:$x,
AnyMIXRShaped:$minVals,
AnyMIXRShaped:$maxVals)>,
Results<(outs AnyMIXRShaped:$output)> {
def MIGraphX_ClipOp
: MIGraphX_Op<
"clip", [Elementwise,
AllElementTypesMatch<["x", "minVals", "maxVals", "output"]>,
AllShapesMatch<["x", "minVals", "maxVals", "output"]>]>,
Arguments<(ins AnyMIXRShaped:$x, AnyMIXRShaped:$minVals,
AnyMIXRShaped:$maxVals)>,
Results<(outs AnyMIXRShaped:$output)> {
let summary = "Elementwise clip";
let description = [{
Elementwise clip: output = min(max(x, minVals), maxVals)
Expand All @@ -113,7 +116,8 @@ def MIGraphX_ClipOp :
// Note: when lowering to kernel calls, MIGraphX represents booleans as i8.
// Keep that logic here.
def MIGraphX_WhereOp
: MIGraphX_Op<"where", [AllElementTypesMatch<["inA", "inB", "output"]>,
: MIGraphX_Op<"where", [Elementwise,
AllElementTypesMatch<["inA", "inB", "output"]>,
AllShapesMatch<["inA", "inB", "output", "cond"]>]>,
Arguments<(ins MIXRShapedOf<[I8, SI8, UI8]>:$cond, AnyMIXRShaped:$inA,
AnyMIXRShaped:$inB)>,
Expand All @@ -130,10 +134,10 @@ def MIGraphX_WhereOp

// Elementwise unary operations

def MIGraphX_ConvertOp :
MIGraphX_Op<"convert">,
Arguments<(ins AnyMIXRShaped:$inA)>,
Results<(outs AnyMIXRShaped:$output)> {
def MIGraphX_ConvertOp
: MIGraphX_Op<"convert", [Elementwise, AllShapesMatch<["inA", "output"]>]>,
Arguments<(ins AnyMIXRShaped:$inA)>,
Results<(outs AnyMIXRShaped:$output)> {
let summary = "Elementwise type conversion";
let description = [{
Type conversion. Due to impedance mismatches between MIGraphX and Tosa,
Expand All @@ -142,10 +146,12 @@ def MIGraphX_ConvertOp :
let assemblyFormat = "$inA attr-dict `:` type($inA) `to` type($output)";
}

class MIGraphX_ElementwiseUnaryOp<string name, list<Trait> traits=[]> :
MIGraphX_Op<name, traits>,
Arguments<(ins AnyMIXRShaped:$inA)>,
Results<(outs AnyMIXRShaped:$output)> {
class MIGraphX_ElementwiseUnaryOp<string name, list<Trait> traits = []>
: MIGraphX_Op<name,
!listconcat(traits, [Elementwise,
AllShapesMatch<["inA", "output"]>])>,
Arguments<(ins AnyMIXRShaped:$inA)>,
Results<(outs AnyMIXRShaped:$output)> {
let summary = "Elementwise " # name;
let assemblyFormat = [{
$inA attr-dict `:` type($inA) `->` type($output)
Expand Down
15 changes: 7 additions & 8 deletions mlir/lib/Dialect/MIGraphX/IR/MIGraphX.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -334,25 +334,24 @@ LogicalResult ReshapeOp::verify() {
<< outType.getRank() << ")";

// Check that there is only a single -1 value
int missingDims = llvm::count_if(
dimsAttr.getAsRange<IntegerAttr>(),
[](IntegerAttr a) { return a.getInt() == -1; });
int missingDims =
llvm::count_if(dimsAttr.getAsRange<IntegerAttr>(),
[](IntegerAttr a) { return a.getInt() == -1; });
if (missingDims > 1)
return emitOpError("expected at most one target dimension to be -1");

// Check how many zero dimensions there are
int numZeros = llvm::count_if(
dimsAttr.getAsRange<IntegerAttr>(),
[](IntegerAttr a) { return a.getInt() == 0; });
int numZeros = llvm::count_if(dimsAttr.getAsRange<IntegerAttr>(),
[](IntegerAttr a) { return a.getInt() == 0; });

if (missingDims > 0 && numZeros > 0)
return emitOpError("Cannot mix missing dimensions with zero dimension");

// Compare dimension values to output shape
for (auto [dimVal, outDim] : llvm::zip(dimsAttr, outType.getShape())) {
int64_t dimValue = cast<IntegerAttr>(dimVal).getInt();
// We cannot handle negative dims values that aren't -1
if (dimValue < -1 ) {
// We cannot handle negative dims values that aren't -1
if (dimValue < -1) {
return emitOpError("Non -1 negative values are not supported");
}

Expand Down
168 changes: 123 additions & 45 deletions mlir/test/Dialect/MIGraphX/invalid.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,23 @@
// RUN: rocmlir-opt %s -split-input-file -verify-diagnostics

// ---- migraphx.shaped type ----

// expected-error @+1 {{migraphx.shaped type has 1 elements in its shape but 2 strides defined}}
func.func @invalid_more_strides_than_shapes(%arg: !migraphx.shaped<1xf32, 1x1>) {
func.return
}

// -----

// expected-error @+1 {{migraphx.shaped type has 2 elements in its shape but 1 strides defined}}
func.func @invalid_more_shapes_than_strides(%arg: !migraphx.shaped<1x1xf32, 1>) {
func.return
}

// -----

// ---- migraphx.reshape ----

func.func @mlir_reshape_inconsistent_dims(%arg0: !migraphx.shaped<4096x4096xf16, 0x1>) {
// expected-error@+1 {{'migraphx.reshape' op dimValue: 64 inconsistent with result dimension 4096}}
%0 = migraphx.reshape %arg0 {dims = [64, 128]} : <4096x4096xf16, 0x1> -> <4096x4096xf16, 16536x2>
Expand Down Expand Up @@ -42,6 +60,10 @@ func.func @mlir_neg_one_with_zero(%arg0: !migraphx.shaped<2x4xf16, 0x1>) {
return
}

// -----

// ---- migraphx.equal ----

func.func @func_equal(%arg0: !migraphx.shaped<1x36x384x64xi32, 884736x24576x64x1>) -> !migraphx.shaped<1x36x384x64xi32, 884736x24576x64x1> attributes{rock.kernel, rock.arch = ""} {
%cst = migraphx.literal (dense<1> : tensor<1x36x384x64xi32>) : <1x36x384x64xi32, 884736x24576x64x1>
%0 = migraphx.add %arg0, %cst : <1x36x384x64xi32, 884736x24576x64x1>, <1x36x384x64xi32, 884736x24576x64x1> -> <1x36x384x64xi32, 884736x24576x64x1>
Expand All @@ -50,6 +72,104 @@ func.func @func_equal(%arg0: !migraphx.shaped<1x36x384x64xi32, 884736x24576x64x1
return %1 : !migraphx.shaped<1x36x384x64xi16, 884736x24576x64x1>
}

// -----

// ---- migraphx.clip ----

func.func @clip_mismatched_element_types(%arg0: !migraphx.shaped<4x8xf32, 8x1>, %arg1: !migraphx.shaped<4x8xf16, 8x1>, %arg2: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> {
// expected-error @+1 {{op failed to verify that all of {x, minVals, maxVals, output} have same element type}}
%0 = migraphx.clip %arg0, %arg1, %arg2 : <4x8xf32, 8x1>, <4x8xf16, 8x1>, <4x8xf32, 8x1> -> <4x8xf32, 8x1>
return %0 : !migraphx.shaped<4x8xf32, 8x1>
}

// -----

func.func @clip_mismatched_shapes(%arg0: !migraphx.shaped<4x8xf32, 8x1>, %arg1: !migraphx.shaped<4x4xf32, 4x1>, %arg2: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> {
// expected-error @+1 {{op failed to verify that all of {x, minVals, maxVals, output} have same shape}}
%0 = migraphx.clip %arg0, %arg1, %arg2 : <4x8xf32, 8x1>, <4x4xf32, 4x1>, <4x8xf32, 8x1> -> <4x8xf32, 8x1>
return %0 : !migraphx.shaped<4x8xf32, 8x1>
}

// -----

// ---- migraphx.where ----

func.func @where_cond_not_bool(%arg0: !migraphx.shaped<4x4xf32, 4x1>, %arg1: !migraphx.shaped<4x4xf32, 4x1>, %arg2: !migraphx.shaped<4x4xf32, 4x1>) -> !migraphx.shaped<4x4xf32, 4x1> {
// expected-error @+1 {{'migraphx.where' op operand #0 must be !migraphx.shaped of 8-bit signless integer or 8-bit signed integer or 8-bit unsigned integer values, but got '!migraphx.shaped<4x4xf32, 4x1>'}}
%0 = migraphx.where %arg0, %arg1, %arg2 : <4x4xf32, 4x1>, <4x4xf32, 4x1>, <4x4xf32, 4x1> -> <4x4xf32, 4x1>
return %0 : !migraphx.shaped<4x4xf32, 4x1>
}

// -----

func.func @where_mismatched_types(%arg0: !migraphx.shaped<4x4xi8, 4x1>, %arg1: !migraphx.shaped<4x4xf32, 4x1>, %arg2: !migraphx.shaped<4x4xf16, 4x1>) -> !migraphx.shaped<4x4xf32, 4x1> {
// expected-error @+1 {{op failed to verify that all of {inA, inB, output} have same element type}}
%0 = migraphx.where %arg0, %arg1, %arg2 : <4x4xi8, 4x1>, <4x4xf32, 4x1>, <4x4xf16, 4x1> -> <4x4xf32, 4x1>
return %0 : !migraphx.shaped<4x4xf32, 4x1>
}

// -----

func.func @where_mismatched_shapes(%arg0: !migraphx.shaped<4x4xi8, 4x1>, %arg1: !migraphx.shaped<4x8xf32, 8x1>, %arg2: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> {
// expected-error @+1 {{op failed to verify that all of {inA, inB, output, cond} have same shape}}
%0 = migraphx.where %arg0, %arg1, %arg2 : <4x4xi8, 4x1>, <4x8xf32, 8x1>, <4x8xf32, 8x1> -> <4x8xf32, 8x1>
return %0 : !migraphx.shaped<4x8xf32, 8x1>
}

// -----

// ---- migraphx.convert ----

func.func @convert_shape_mismatch(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x4xf16, 4x1> {
// expected-error @+1 {{op failed to verify that all of {inA, output} have same shape}}
%0 = migraphx.convert %arg0 : <4x8xf32, 8x1> to <4x4xf16, 4x1>
return %0 : !migraphx.shaped<4x4xf16, 4x1>
}

// -----

// ---- migraphx.abs ----

func.func @abs_shape_mismatch(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x4xf32, 4x1> {
// expected-error @+1 {{op failed to verify that all of {inA, output} have same shape}}
%0 = migraphx.abs %arg0 : <4x8xf32, 8x1> -> <4x4xf32, 4x1>
return %0 : !migraphx.shaped<4x4xf32, 4x1>
}

// -----

// ---- migraphx.exp ----

func.func @exp_rank_mismatch(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<32xf32, 1> {
// expected-error @+1 {{op failed to verify that all of {inA, output} have same shape}}
%0 = migraphx.exp %arg0 : <4x8xf32, 8x1> -> <32xf32, 1>
return %0 : !migraphx.shaped<32xf32, 1>
}

// -----

// ---- migraphx.relu ----

func.func @relu_shape_mismatch(%arg0: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<2x8xf32, 8x1> {
// expected-error @+1 {{op failed to verify that all of {inA, output} have same shape}}
%0 = migraphx.relu %arg0 : <4x8xf32, 8x1> -> <2x8xf32, 8x1>
return %0 : !migraphx.shaped<2x8xf32, 8x1>
}

// -----

// ---- migraphx.sigmoid ----

func.func @func_sigmoid_2d_i32(%arg0: !migraphx.shaped<4x8xi32, 8x1>) -> !migraphx.shaped<4x8xi32, 8x1> {
// expected-error @+1 {{only support floating point}}
%0 = migraphx.sigmoid %arg0 : <4x8xi32, 8x1> -> <4x8xi32, 8x1>
return %0 : !migraphx.shaped<4x8xi32, 8x1>
}

// -----

// ---- migraphx.quant_dot ----

// Test: Only scaleA provided (should fail - both scales required)
func.func @quant_dot_only_scale_a(
%arg0: !migraphx.shaped<1x16x512xf4E2M1FN, 8192x512x1>,
Expand Down Expand Up @@ -228,6 +348,8 @@ func.func @migraphx_quant_dot_f4_n_scales(%arg0: !migraphx.shaped<1x16x512xf4E2M

// -----

// ---- migraphx.dot ----

// CHECK-LABEL: func.func @dot_rank_less_than_2
func.func @dot_rank_less_than_2(%arg0: !migraphx.shaped<320xf16, 1>, %arg1: !migraphx.shaped<320x64xf16, 64x1>) -> !migraphx.shaped<64xf16, 1> {
// expected-error @+1 {{expect operand to have rank greater or equal to 2}}
Expand Down Expand Up @@ -273,51 +395,7 @@ func.func @dot_result_shape_mismatch(%arg0: !migraphx.shaped<2x3x4xf16, 12x4x1>,

// -----

// expected-error @+1 {{migraphx.shaped type has 1 elements in its shape but 2 strides defined}}
func.func @invalid_more_strides_than_shapes(%arg: !migraphx.shaped<1xf32, 1x1>) {
func.return
}

// -----

// expected-error @+1 {{migraphx.shaped type has 2 elements in its shape but 1 strides defined}}
func.func @invalid_more_shapes_than_strides(%arg: !migraphx.shaped<1x1xf32, 1>) {
func.return
}

// -----

func.func @func_sigmoid_2d_i32(%arg0: !migraphx.shaped<4x8xi32, 8x1>) -> !migraphx.shaped<4x8xi32, 8x1> {
// expected-error @+1 {{only support floating point}}
%0 = migraphx.sigmoid %arg0 : <4x8xi32, 8x1> -> <4x8xi32, 8x1>
return %0 : !migraphx.shaped<4x8xi32, 8x1>
}

// -----

func.func @where_cond_not_bool(%arg0: !migraphx.shaped<4x4xf32, 4x1>, %arg1: !migraphx.shaped<4x4xf32, 4x1>, %arg2: !migraphx.shaped<4x4xf32, 4x1>) -> !migraphx.shaped<4x4xf32, 4x1> {
// expected-error @+1 {{'migraphx.where' op operand #0 must be !migraphx.shaped of 8-bit signless integer or 8-bit signed integer or 8-bit unsigned integer values, but got '!migraphx.shaped<4x4xf32, 4x1>'}}
%0 = migraphx.where %arg0, %arg1, %arg2 : <4x4xf32, 4x1>, <4x4xf32, 4x1>, <4x4xf32, 4x1> -> <4x4xf32, 4x1>
return %0 : !migraphx.shaped<4x4xf32, 4x1>
}

// -----

func.func @where_mismatched_types(%arg0: !migraphx.shaped<4x4xi8, 4x1>, %arg1: !migraphx.shaped<4x4xf32, 4x1>, %arg2: !migraphx.shaped<4x4xf16, 4x1>) -> !migraphx.shaped<4x4xf32, 4x1> {
// expected-error @+1 {{op failed to verify that all of {inA, inB, output} have same element type}}
%0 = migraphx.where %arg0, %arg1, %arg2 : <4x4xi8, 4x1>, <4x4xf32, 4x1>, <4x4xf16, 4x1> -> <4x4xf32, 4x1>
return %0 : !migraphx.shaped<4x4xf32, 4x1>
}

// -----

func.func @where_mismatched_shapes(%arg0: !migraphx.shaped<4x4xi8, 4x1>, %arg1: !migraphx.shaped<4x8xf32, 8x1>, %arg2: !migraphx.shaped<4x8xf32, 8x1>) -> !migraphx.shaped<4x8xf32, 8x1> {
// expected-error @+1 {{op failed to verify that all of {inA, inB, output, cond} have same shape}}
%0 = migraphx.where %arg0, %arg1, %arg2 : <4x4xi8, 4x1>, <4x8xf32, 8x1>, <4x8xf32, 8x1> -> <4x8xf32, 8x1>
return %0 : !migraphx.shaped<4x8xf32, 8x1>
}

// -----
// ---- migraphx.slice ----

func.func @invalid_attr_size_mismatch(%input: !migraphx.shaped<10x10xf32, 10x1>) {
// expected-error @+1 {{op axes, starts, and ends must have the same size}}
Expand Down
Loading
Loading