diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
index d5a46b1b34312..c446dd67f37c5 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
@@ -67,6 +67,8 @@ def Tosa_AddShapeOp : Tosa_ElementwiseShapeOp<"add_shape", [Pure]> {
   );
 
   let results = (outs Tosa_Shape:$output);
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index c420a4c9596ff..2ea8481cb0a51 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -889,33 +889,149 @@ void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
 // Operator Folders.
 //===----------------------------------------------------------------------===//
 
-template <typename IntFolder, typename FloatFolder>
-static DenseElementsAttr binaryFolder(DenseElementsAttr lhs,
-                                      DenseElementsAttr rhs,
-                                      RankedTensorType returnTy) {
-  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
-    auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
-    auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
-    if (lETy != rETy)
-      return {};
+/// Constant-folds an elementwise binary operation over `lhs` and `rhs`.
+///
+/// `Folder` must provide two static overloads, `fold(APFloat, APFloat)` and
+/// `fold(APInt, APInt, bool isUnsigned)`, each returning a `FailureOr`; a
+/// failed result means "do not fold".
+///
+/// Splat operands with a float or integer element type are folded from their
+/// splat values. Otherwise, operands are folded element by element only when
+/// `foldDenseValues` is set. Returns a null attribute whenever nothing could
+/// be folded.
+template <typename Folder>
+static DenseElementsAttr
+binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs, ShapedType returnTy,
+             bool foldDenseValues = false) {
+  if (!lhs || !rhs)
+    return {};
 
-    if (llvm::isa<IntegerType>(lETy)) {
-      APInt l = lhs.getSplatValue<APInt>();
-      APInt r = rhs.getSplatValue<APInt>();
-      auto result = IntFolder()(l, r);
-      return DenseElementsAttr::get(returnTy, result);
+  const auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
+  const auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
+  if (lETy != rETy)
+    return {};
+
+  if (lhs.isSplat() && rhs.isSplat()) {
+    if (isa<FloatType>(lETy)) {
+      const APFloat l = lhs.getSplatValue<APFloat>();
+      const APFloat r = rhs.getSplatValue<APFloat>();
+      const auto maybeResult = Folder::fold(l, r);
+      if (failed(maybeResult))
+        return {};
+      return DenseElementsAttr::get(returnTy, maybeResult.value());
     }
 
-    if (llvm::isa<FloatType>(lETy)) {
-      APFloat l = lhs.getSplatValue<APFloat>();
-      APFloat r = rhs.getSplatValue<APFloat>();
-      auto result = FloatFolder()(l, r);
-      return DenseElementsAttr::get(returnTy, result);
+    if (const auto lIntTy = llvm::dyn_cast<IntegerType>(lETy)) {
+      const APInt l = lhs.getSplatValue<APInt>();
+      const APInt r = rhs.getSplatValue<APInt>();
+      const auto maybeResult = Folder::fold(l, r, lIntTy.isUnsigned());
+      if (failed(maybeResult))
+        return {};
+      return DenseElementsAttr::get(returnTy, maybeResult.value());
     }
   }
 
+  if (foldDenseValues) {
+    // The element-wise path reads both operands as APInt and pairs them up
+    // positionally, so it needs integer/index elements and equal element
+    // counts; otherwise leave the op unfolded.
+    if (!lETy.isIntOrIndex() || lhs.getNumElements() != rhs.getNumElements())
+      return {};
+
+    SmallVector<APInt> resultValues;
+    resultValues.reserve(lhs.getNumElements());
+    for (auto [l, r] :
+         llvm::zip(lhs.getValues<APInt>(), rhs.getValues<APInt>())) {
+      // Values are treated as signed on this path.
+      const auto maybeResult = Folder::fold(l, r, /*isUnsigned=*/false);
+      if (failed(maybeResult))
+        return {};
+      resultValues.push_back(maybeResult.value());
+    }
+    return DenseElementsAttr::get(returnTy, resultValues);
+  }
+
+  // Folding arbitrarily sized tensor operations is not supported
   return {};
 }
 
+/// Folder for integer and floating-point addition. The integer overload
+/// refuses to fold (returns failure) when the sum overflows the operand
+/// width, for the signedness selected by `isUnsigned`.
+struct AddFoldAdaptor {
+  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
+                               const bool isUnsigned) {
+    // APInt's overflow-reporting adds work at the operands' own width, so no
+    // widening to a fixed 64 bits is needed.
+    bool overflow = false;
+    const APInt result =
+        isUnsigned ? lhs.uadd_ov(rhs, overflow) : lhs.sadd_ov(rhs, overflow);
+    if (overflow)
+      return failure();
+    return result;
+  }
+
+  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
+    return lhs + rhs;
+  }
+};
+
+/// Folder for integer and floating-point subtraction. The integer overload
+/// refuses to fold (returns failure) when the difference overflows or
+/// underflows the operand width, for the signedness selected by `isUnsigned`.
+struct SubFoldAdaptor {
+  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
+                               const bool isUnsigned) {
+    // APInt's overflow-reporting subtracts work at the operands' own width,
+    // so no widening to a fixed 64 bits is needed.
+    bool overflow = false;
+    const APInt result =
+        isUnsigned ? lhs.usub_ov(rhs, overflow) : lhs.ssub_ov(rhs, overflow);
+    if (overflow)
+      return failure();
+    return result;
+  }
+
+  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
+    return lhs - rhs;
+  }
+};
+
+/// Folder for `lhs > rhs`; the result is a 1-bit APInt.
+struct FoldGreaterAdaptor {
+  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
+                               const bool isUnsigned) {
+    return isUnsigned ? APInt(1, lhs.ugt(rhs)) : APInt(1, lhs.sgt(rhs));
+  }
+
+  static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) {
+    return APInt(1, lhs > rhs);
+  }
+};
+
+/// Folder for `lhs >= rhs`; the result is a 1-bit APInt.
+struct FoldGreaterEqualAdaptor {
+  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
+                               const bool isUnsigned) {
+    return isUnsigned ? APInt(1, lhs.uge(rhs)) : APInt(1, lhs.sge(rhs));
+  }
+
+  static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) {
+    return APInt(1, lhs >= rhs);
+  }
+};
+
+/// Folder for `lhs == rhs`; the result is a 1-bit APInt.
+struct FoldEqualAdaptor {
+  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
+                               const bool isUnsigned) {
+    return APInt(1, lhs == rhs);
+  }
+
+  static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) {
+    return APInt(1, lhs == rhs);
+  }
+};
+
 static bool isSplatZero(Type elemType, DenseElementsAttr val) {
   if (llvm::isa<FloatType>(elemType))
@@ -963,8 +1079,7 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
   if (!lhsAttr || !rhsAttr)
     return {};
 
-  return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
-                                                            resultTy);
+  return binaryFolder<AddFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
 }
 
 OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
@@ -1145,38 +1260,9 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
   if (!lhsAttr || !rhsAttr)
     return {};
 
-  return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
-                                                              resultTy);
+  return binaryFolder<SubFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
 }
 
-namespace {
-template <typename Cmp>
-struct ComparisonFold {
-  ComparisonFold() = default;
-  APInt operator()(const APInt &l, const APInt &r) {
-    return APInt(1, Cmp()(l, r));
-  }
-
-  APInt operator()(const APFloat &l, const APFloat &r) {
-    return APInt(1, Cmp()(l, r));
-  }
-};
-
-struct APIntFoldGreater {
-  APIntFoldGreater() = default;
-  APInt operator()(const APInt &l, const APInt &r) {
-    return APInt(1, l.sgt(r));
-  }
-};
-
-struct APIntFoldGreaterEqual {
-  APIntFoldGreaterEqual() = default;
-  APInt operator()(const APInt &l, const APInt &r) {
-    return APInt(1, l.sge(r));
-  }
-};
-} // namespace
-
 OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
   auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
   auto lhsAttr =
@@ -1187,8 +1273,7 @@ OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
   if (!lhsAttr || !rhsAttr)
     return {};
 
-  return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
-      lhsAttr, rhsAttr, resultTy);
+  return binaryFolder<FoldGreaterAdaptor>(lhsAttr, rhsAttr, resultTy);
 }
 
 OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
@@ -1201,9 +1286,7 @@ OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
   if (!lhsAttr || !rhsAttr)
     return {};
 
-  return binaryFolder<APIntFoldGreaterEqual,
-                      ComparisonFold<std::greater_equal<APFloat>>>(
-      lhsAttr, rhsAttr, resultTy);
+  return binaryFolder<FoldGreaterEqualAdaptor>(lhsAttr, rhsAttr, resultTy);
 }
 
 OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
@@ -1226,9 +1309,7 @@ OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
   if (!lhsAttr || !rhsAttr)
     return {};
 
-  return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
-                      ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
-                                                              resultTy);
+  return binaryFolder<FoldEqualAdaptor>(lhsAttr, rhsAttr, resultTy);
 }
 
 OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
@@ -1650,3 +1731,22 @@ OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
 
   return {};
 }
+
+/// Folds `tosa.add_shape` when both operands are produced by
+/// `tosa.const_shape`, adding the shapes element by element. Folding is
+/// skipped when an operand has another (or no) defining op, or when an
+/// addition overflows.
+OpFoldResult tosa::AddShapeOp::fold(FoldAdaptor adaptor) {
+  // getDefiningOp<OpTy>() yields null for block arguments as well as for
+  // other producers, whereas dyn_cast on a null Operation* would assert.
+  auto input1ConstShape = getInput1().getDefiningOp<tosa::ConstShapeOp>();
+  auto input2ConstShape = getInput2().getDefiningOp<tosa::ConstShapeOp>();
+  if (!input1ConstShape || !input2ConstShape)
+    return {};
+
+  const auto input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
+  const auto input2Attr = cast<DenseElementsAttr>(input2ConstShape.getValues());
+
+  return binaryFolder<AddFoldAdaptor>(
+      input1Attr, input2Attr, input1Attr.getType(), /*foldDenseValues=*/true);
+}
diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir
index d477a2479e913..4860c98d960e2 100644
--- a/mlir/test/Dialect/Tosa/constant_folding.mlir
+++ b/mlir/test/Dialect/Tosa/constant_folding.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --test-single-fold %s | FileCheck %s
+// RUN: mlir-opt --split-input-file --test-single-fold %s | FileCheck %s
 
 // CHECK-LABEL: func @test_const
 func.func @test_const(%arg0 : index) -> tensor<4xi32> {
@@ -7,6 +7,8 @@ func.func @test_const(%arg0 : index) -> tensor<4xi32> {
   return %0 : tensor<4xi32>
 }
 
+// -----
+
 // CHECK-LABEL: func @test_const_i64
 func.func @test_const_i64(%arg0 : index) -> tensor<4xi64> {
   // CHECK: tosa.const
@@ -14,6 +16,8 @@ func.func @test_const_i64(%arg0 : index) -> tensor<4xi64> {
   return %0 : tensor<4xi64>
 }
 
+// -----
+
 // CHECK-LABEL: func
@try_fold_equal_with_unranked_tensor func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tensor<1xi32>) { // CHECK: tosa.equal @@ -21,3 +25,661 @@ func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tens %0 = tosa.equal %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi1> return } + +// ----- + +// CHECK-LABEL: @fold_add_zero_rhs_f32 +func.func @fold_add_zero_rhs_f32(%arg0: tensor) -> tensor { + %zero = "tosa.const"() {values = dense<0.0> : tensor} : () -> tensor + %add = tosa.add %arg0, %zero : (tensor, tensor) -> tensor + // CHECK: return %arg0 + return %add : tensor +} + +// ----- + +// CHECK-LABEL: @fold_add_zero_lhs_f32 +func.func @fold_add_zero_lhs_f32(%arg0: tensor) -> tensor { + %zero = "tosa.const"() {values = dense<0.0> : tensor} : () -> tensor + %add = tosa.add %zero, %arg0 : (tensor, tensor) -> tensor + // CHECK: return %arg0 + return %add : tensor +} + +// ----- + +// CHECK-LABEL: @fold_add_zero_rhs_i32 +func.func @fold_add_zero_rhs_i32(%arg0: tensor) -> tensor { + %zero = "tosa.const"() {values = dense<0> : tensor} : () -> tensor + %add = tosa.add %arg0, %zero : (tensor, tensor) -> tensor + // CHECK: return %arg0 + return %add : tensor +} + +// ----- + +// CHECK-LABEL: @fold_add_zero_lhs_i32 +func.func @fold_add_zero_lhs_i32(%arg0: tensor) -> tensor { + %zero = "tosa.const"() {values = dense<0> : tensor} : () -> tensor + %add = tosa.add %zero, %arg0 : (tensor, tensor) -> tensor + // CHECK: return %arg0 + return %add : tensor +} + +// ----- + +// CHECK-LABEL: @fold_add_splat_i32 +func.func @fold_add_splat_i32() -> tensor<10xi32> { + %one = "tosa.const"() {values = dense<1> : tensor<10xi32>} : () -> tensor<10xi32> + %two = "tosa.const"() {values = dense<2> : tensor<10xi32>} : () -> tensor<10xi32> + %add = tosa.add %one, %two : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32> + // CHECK: %[[THREE:.+]] = "tosa.const"() <{values = dense<3> : tensor<10xi32>} + // CHECK: return %[[THREE]] + return 
%add : tensor<10xi32> +} + +// ----- + +// CHECK-LABEL: @fold_add_splat_f32 +func.func @fold_add_splat_f32() -> tensor<10xf32> { + %one = "tosa.const"() {values = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32> + %two = "tosa.const"() {values = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32> + %add = tosa.add %one, %two : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> + // CHECK: %[[THREE:.+]] = "tosa.const"() <{values = dense<3.000000e+00> + // CHECK: return %[[THREE]] + return %add : tensor<10xf32> +} + +// ----- + +// CHECK-LABEL: @fold_add_splat_i32_positive_overflow +func.func @fold_add_splat_i32_positive_overflow() -> tensor<10xi32> { + %one = "tosa.const"() {values = dense<2147483647> : tensor<10xi32>} : () -> tensor<10xi32> + %two = "tosa.const"() {values = dense<1> : tensor<10xi32>} : () -> tensor<10xi32> + // CHECK: tosa.add + %add = tosa.add %one, %two : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32> + return %add : tensor<10xi32> +} + +// ----- + +// CHECK-LABEL: @fold_add_splat_i32_negative_overflow +func.func @fold_add_splat_i32_negative_overflow() -> tensor<10xi32> { + %one = "tosa.const"() {values = dense<-1> : tensor<10xi32>} : () -> tensor<10xi32> + %two = "tosa.const"() {values = dense<-2147483648> : tensor<10xi32>} : () -> tensor<10xi32> + // CHECK: tosa.add + %add = tosa.add %one, %two : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32> + return %add : tensor<10xi32> +} + +// ----- + +// CHECK-LABEL: @fold_add_splat_ui8 +func.func @fold_add_splat_ui8() -> tensor<10xui8> { + %one = "tosa.const"() {values = dense<1> : tensor<10xui8>} : () -> tensor<10xui8> + %two = "tosa.const"() {values = dense<254> : tensor<10xui8>} : () -> tensor<10xui8> + // CHECK: "tosa.const"() <{values = dense<255> : tensor<10xui8>}> : () -> tensor<10xui8> + %add = tosa.add %one, %two : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xui8> + return %add : tensor<10xui8> +} + +// ----- + +// CHECK-LABEL: @fold_add_splat_ui8_overflow +func.func 
@fold_add_splat_ui8_overflow() -> tensor<10xui8> { + %one = "tosa.const"() {values = dense<2> : tensor<10xui8>} : () -> tensor<10xui8> + %two = "tosa.const"() {values = dense<254> : tensor<10xui8>} : () -> tensor<10xui8> + // CHECK: tosa.add + %add = tosa.add %one, %two : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xui8> + return %add : tensor<10xui8> +} + +// ----- + +// CHECK-LABEL: @fold_div_zero_lhs_i32 +func.func @fold_div_zero_lhs_i32(%arg0: tensor) -> tensor { + %zero = "tosa.const"() {values = dense<0> : tensor} : () -> tensor + // CHECK: %[[ZERO:.+]] = "tosa.const"() <{values = dense<0> + %div = tosa.intdiv %zero, %arg0 : (tensor, tensor) -> tensor + // CHECK: return %[[ZERO]] + return %div : tensor +} + +// ----- + +// CHECK-LABEL: @fold_div_one_rhs_i32 +func.func @fold_div_one_rhs_i32(%arg0: tensor) -> tensor { + %one = "tosa.const"() {values = dense<1> : tensor} : () -> tensor + %div = tosa.intdiv %arg0, %one : (tensor, tensor) -> tensor + // CHECK: return %arg0 + return %div : tensor +} + +// ----- + +// CHECK-LABEL: @fold_div_splat_i32 +func.func @fold_div_splat_i32() -> tensor { + %lhs = "tosa.const"() {values = dense<10> : tensor} : () -> tensor + %rhs = "tosa.const"() {values = dense<-3> : tensor} : () -> tensor + // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<-3> + %div = tosa.intdiv %lhs, %rhs : (tensor, tensor) -> tensor + // CHECK: return %[[SPLAT]] + return %div : tensor +} + +// ----- + + +// CHECK-LABEL: @fold_mul_zero_rhs_f32 +func.func @fold_mul_zero_rhs_f32(%arg0: tensor) -> tensor { + %zero = "tosa.const"() {values = dense<0.0> : tensor} : () -> tensor + // CHECK: %[[ZERO:.+]] = "tosa.const"() <{values = dense<0.000000e+00> + %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %mul = tosa.mul %arg0, %zero, %shift : (tensor, tensor, tensor<1xi8>) -> tensor + // CHECK: return %[[ZERO]] + return %mul : tensor +} + +// ----- + +// CHECK-LABEL: @fold_mul_zero_lhs_f32 +func.func 
@fold_mul_zero_lhs_f32(%arg0: tensor) -> tensor { + %zero = "tosa.const"() {values = dense<0.0> : tensor} : () -> tensor + // CHECK: %[[ZERO:.+]] = "tosa.const"() <{values = dense<0.000000e+00> + %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %mul = tosa.mul %zero, %arg0, %shift : (tensor, tensor, tensor<1xi8>) -> tensor + // CHECK: return %[[ZERO]] + return %mul : tensor +} + +// ----- + +// CHECK-LABEL: @fold_mul_zero_rhs_i32 +func.func @fold_mul_zero_rhs_i32(%arg0: tensor) -> tensor { + %zero = "tosa.const"() {values = dense<0> : tensor} : () -> tensor + %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + // CHECK: %[[ZERO:.+]] = "tosa.const"() <{values = dense<0> + %mul = tosa.mul %arg0, %zero, %shift : (tensor, tensor, tensor<1xi8>) -> tensor + // CHECK: return %[[ZERO]] + return %mul : tensor +} + +// ----- + +// CHECK-LABEL: @fold_mul_zero_lhs_i32 +func.func @fold_mul_zero_lhs_i32(%arg0: tensor) -> tensor { + %zero = "tosa.const"() {values = dense<0> : tensor} : () -> tensor + %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + // CHECK: %[[ZERO:.+]] = "tosa.const"() <{values = dense<0> + %mul = tosa.mul %zero, %arg0, %shift : (tensor, tensor, tensor<1xi8>) -> tensor + // CHECK: return %[[ZERO]] + return %mul : tensor +} + +// ----- + +// CHECK-LABEL: @fold_mul_one_rhs_f32 +func.func @fold_mul_one_rhs_f32(%arg0: tensor) -> tensor { + %one = "tosa.const"() {values = dense<1.0> : tensor} : () -> tensor + %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %mul = tosa.mul %arg0, %one, %shift : (tensor, tensor, tensor<1xi8>) -> tensor + // CHECK: return %arg0 + return %mul : tensor +} + +// ----- + +// CHECK-LABEL: @fold_mul_one_lhs_f32 +func.func @fold_mul_one_lhs_f32(%arg0: tensor) -> tensor { + %one = "tosa.const"() {values = dense<1.0> : tensor} : () -> tensor + %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () 
-> tensor<1xi8> + %mul = tosa.mul %one, %arg0, %shift : (tensor, tensor, tensor<1xi8>) -> tensor + // CHECK: return %arg0 + return %mul : tensor +} + +// ----- + +// CHECK-LABEL: @fold_mul_one_rhs_i32 +func.func @fold_mul_one_rhs_i32(%arg0: tensor) -> tensor { + %one = "tosa.const"() {values = dense<64> : tensor} : () -> tensor + %shift = "tosa.const"() {values = dense<6> : tensor<1xi8>} : () -> tensor<1xi8> + %mul = tosa.mul %arg0, %one, %shift : (tensor, tensor, tensor<1xi8>) -> tensor + // CHECK: return %arg0 + return %mul : tensor +} + +// ----- + +// CHECK-LABEL: @fold_mul_one_lhs_i32 +func.func @fold_mul_one_lhs_i32(%arg0: tensor) -> tensor { + %one = "tosa.const"() {values = dense<64> : tensor} : () -> tensor + %shift = "tosa.const"() {values = dense<6> : tensor<1xi8>} : () -> tensor<1xi8> + %mul = tosa.mul %one, %arg0, %shift : (tensor, tensor, tensor<1xi8>) -> tensor + // CHECK: return %arg0 + return %mul : tensor +} + +// ----- + +// CHECK-LABEL: @fold_mul_splat_i8 +func.func @fold_mul_splat_i8() -> tensor<10xi32> { + %one = "tosa.const"() {values = dense<17> : tensor<10xi8>} : () -> tensor<10xi8> + %two = "tosa.const"() {values = dense<32> : tensor<10xi8>} : () -> tensor<10xi8> + %shift = "tosa.const"() {values = dense<3> : tensor<1xi8>} : () -> tensor<1xi8> + %mul = tosa.mul %one, %two, %shift : (tensor<10xi8>, tensor<10xi8>, tensor<1xi8>) -> tensor<10xi32> + // CHECK: %[[THREE:.+]] = "tosa.const"() <{values = dense<68> : tensor<10xi32>} + // CHECK: return %[[THREE]] + return %mul : tensor<10xi32> +} + +// ----- + +// CHECK-LABEL: @fold_mul_splat_f32 +func.func @fold_mul_splat_f32() -> tensor<10xf32> { + %one = "tosa.const"() {values = dense<3.0> : tensor<10xf32>} : () -> tensor<10xf32> + %two = "tosa.const"() {values = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32> + %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %mul = tosa.mul %one, %two, %shift : (tensor<10xf32>, tensor<10xf32>, tensor<1xi8>) -> 
tensor<10xf32> + // CHECK: %[[THREE:.+]] = "tosa.const"() <{values = dense<6.000000e+00> : tensor<10xf32>} + // CHECK: return %[[THREE]] + return %mul : tensor<10xf32> +} + +// ----- + +// CHECK-LABEL: @fold_sub_zero_rhs_f32 +func.func @fold_sub_zero_rhs_f32(%arg0: tensor) -> tensor { + %zero = "tosa.const"() {values = dense<0.0> : tensor} : () -> tensor + %sub = tosa.sub %arg0, %zero : (tensor, tensor) -> tensor + // CHECK: return %arg0 + return %sub : tensor +} + +// ----- + +// CHECK-LABEL: @fold_sub_zero_rhs_i32 +func.func @fold_sub_zero_rhs_i32(%arg0: tensor) -> tensor { + %zero = "tosa.const"() {values = dense<0> : tensor} : () -> tensor + %sub = tosa.sub %arg0, %zero : (tensor, tensor) -> tensor + // CHECK: return %arg0 + return %sub : tensor +} + +// ----- + +// CHECK-LABEL: @fold_sub_splat_i32 +func.func @fold_sub_splat_i32() -> tensor<10xi32> { + %one = "tosa.const"() {values = dense<1> : tensor<10xi32>} : () -> tensor<10xi32> + %two = "tosa.const"() {values = dense<2> : tensor<10xi32>} : () -> tensor<10xi32> + %sub = tosa.sub %one, %two : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32> + // CHECK: %[[THREE:.+]] = "tosa.const"() <{values = dense<-1> : tensor<10xi32>} + // CHECK: return %[[THREE]] + return %sub : tensor<10xi32> +} + +// ----- + +// CHECK-LABEL: @fold_sub_splat_f32 +func.func @fold_sub_splat_f32() -> tensor<10xf32> { + %one = "tosa.const"() {values = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32> + %two = "tosa.const"() {values = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32> + %sub = tosa.sub %one, %two : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> + // CHECK: %[[THREE:.+]] = "tosa.const"() <{values = dense<-1.000000e+00> : tensor<10xf32>} + // CHECK: return %[[THREE]] + return %sub : tensor<10xf32> +} + +// ----- + +// CHECK-LABEL: @fold_sub_splat_i32_positive_overflow +func.func @fold_sub_splat_i32_positive_overflow() -> tensor<10xi32> { + %one = "tosa.const"() {values = dense<2147483647> : tensor<10xi32>} : () -> 
tensor<10xi32> + %two = "tosa.const"() {values = dense<-1> : tensor<10xi32>} : () -> tensor<10xi32> + // CHECK: tosa.sub + %sub = tosa.sub %one, %two : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32> + return %sub : tensor<10xi32> +} + +// ----- + +// CHECK-LABEL: @fold_sub_splat_i32_negative_overflow +func.func @fold_sub_splat_i32_negative_overflow() -> tensor<10xi32> { + %one = "tosa.const"() {values = dense<-2147483648> : tensor<10xi32>} : () -> tensor<10xi32> + %two = "tosa.const"() {values = dense<1> : tensor<10xi32>} : () -> tensor<10xi32> + // CHECK: tosa.sub + %sub = tosa.sub %one, %two : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32> + return %sub : tensor<10xi32> +} + +// ----- + +// CHECK-LABEL: @fold_sub_splat_ui8 +func.func @fold_sub_splat_ui8() -> tensor<10xui8> { + %one = "tosa.const"() {values = dense<255> : tensor<10xui8>} : () -> tensor<10xui8> + %two = "tosa.const"() {values = dense<253> : tensor<10xui8>} : () -> tensor<10xui8> + // CHECK: "tosa.const"() <{values = dense<2> : tensor<10xui8>}> : () -> tensor<10xui8> + %sub = tosa.sub %one, %two : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xui8> + return %sub : tensor<10xui8> +} + +// ----- + +// CHECK-LABEL: @fold_sub_splat_ui8_overflow +func.func @fold_sub_splat_ui8_overflow() -> tensor<10xui8> { + %one = "tosa.const"() {values = dense<1> : tensor<10xui8>} : () -> tensor<10xui8> + %two = "tosa.const"() {values = dense<253> : tensor<10xui8>} : () -> tensor<10xui8> + // CHECK: tosa.sub + %sub = tosa.sub %one, %two : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xui8> + return %sub : tensor<10xui8> +} + +// ----- + +// CHECK-LABEL: @fold_greater_splat_f32 +func.func @fold_greater_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) { + %0 = "tosa.const"() {values = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32> + %1 = "tosa.const"() {values = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32> + %2 = "tosa.const"() {values = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32> + %3 = 
"tosa.const"() {values = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32> + %true = tosa.greater %0, %1 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1> + %false = tosa.greater %2, %3 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1> + // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} + // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} + // CHECK: return %[[TRUE]], %[[FALSE]] + return %true, %false : tensor<10xi1>, tensor<10xi1> +} + +// ----- + +// CHECK-LABEL: @fold_greater_splat_i32 +func.func @fold_greater_splat_i32() -> (tensor<10xi1>, tensor<10xi1>) { + %0 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32> + %1 = "tosa.const"() {values = dense<8> : tensor<10xi32>} : () -> tensor<10xi32> + %2 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32> + %3 = "tosa.const"() {values = dense<-12> : tensor<10xi32>} : () -> tensor<10xi32> + %false = tosa.greater %0, %1 : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1> + %true = tosa.greater %2, %3 : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1> + // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} + // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} + // CHECK: return %[[FALSE]], %[[TRUE]] + return %false, %true : tensor<10xi1>, tensor<10xi1> +} + +// ----- + +// CHECK-LABEL: @fold_greater_splat_ui8 +func.func @fold_greater_splat_ui8() -> (tensor<10xi1>, tensor<10xi1>, tensor<10xi1>) { + %0 = "tosa.const"() {values = dense<1> : tensor<10xui8>} : () -> tensor<10xui8> + %1 = "tosa.const"() {values = dense<1> : tensor<10xui8>} : () -> tensor<10xui8> + %2 = "tosa.const"() {values = dense<246> : tensor<10xui8>} : () -> tensor<10xui8> + %3 = "tosa.const"() {values = dense<245> : tensor<10xui8>} : () -> tensor<10xui8> + %true = tosa.greater %2, %3 : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xi1> + %false = tosa.greater %0, %1 : (tensor<10xui8>, 
tensor<10xui8>) -> tensor<10xi1> + %false2 = tosa.greater %0, %2 : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xi1> + // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} + // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} + // CHECK: return %[[TRUE]], %[[FALSE]], %[[FALSE]] + return %true, %false, %false2 : tensor<10xi1>, tensor<10xi1>, tensor<10xi1> +} + +// ----- + +// CHECK-LABEL: @fold_greater_eq_splat_f32 +func.func @fold_greater_eq_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) { + %0 = "tosa.const"() {values = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32> + %1 = "tosa.const"() {values = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32> + %2 = "tosa.const"() {values = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32> + %3 = "tosa.const"() {values = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32> + %true = tosa.greater_equal %0, %1 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1> + %false = tosa.greater_equal %2, %3 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1> + // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} + // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} + // CHECK: return %[[TRUE]], %[[FALSE]] + return %true, %false : tensor<10xi1>, tensor<10xi1> +} + +// ----- + +// CHECK-LABEL: @fold_greater_eq_splat_i32 +func.func @fold_greater_eq_splat_i32() -> (tensor<10xi1>, tensor<10xi1>) { + %0 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32> + %1 = "tosa.const"() {values = dense<8> : tensor<10xi32>} : () -> tensor<10xi32> + %2 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32> + %3 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32> + %true = tosa.greater_equal %2, %3 : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1> + %false = tosa.greater_equal %0, %1 : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1> + // CHECK-DAG: %[[TRUE:.+]] = 
"tosa.const"() <{values = dense : tensor<10xi1>} + // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} + // CHECK: return %[[TRUE]], %[[FALSE]] + return %true, %false : tensor<10xi1>, tensor<10xi1> +} + +// ----- + +// CHECK-LABEL: @fold_greater_eq_splat_ui8 +func.func @fold_greater_eq_splat_ui8() -> (tensor<10xi1>, tensor<10xi1>) { + %0 = "tosa.const"() {values = dense<1> : tensor<10xui8>} : () -> tensor<10xui8> + %1 = "tosa.const"() {values = dense<255> : tensor<10xui8>} : () -> tensor<10xui8> + %2 = "tosa.const"() {values = dense<245> : tensor<10xui8>} : () -> tensor<10xui8> + %3 = "tosa.const"() {values = dense<245> : tensor<10xui8>} : () -> tensor<10xui8> + %true = tosa.greater_equal %2, %3 : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xi1> + %false = tosa.greater_equal %0, %1 : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xi1> + // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} + // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} + // CHECK: return %[[TRUE]], %[[FALSE]] + return %true, %false : tensor<10xi1>, tensor<10xi1> +} + +// ----- + +// CHECK-LABEL: @fold_eq_splat_f32 +func.func @fold_eq_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) { + %0 = "tosa.const"() {values = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32> + %1 = "tosa.const"() {values = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32> + %2 = "tosa.const"() {values = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32> + %3 = "tosa.const"() {values = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32> + %true = tosa.equal %0, %1 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1> + %false = tosa.equal %2, %3 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1> + // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} + // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} + // CHECK: return %[[TRUE]], %[[FALSE]] + return %true, %false : tensor<10xi1>, 
tensor<10xi1> +} + +// ----- + +// CHECK-LABEL: @fold_eq_splat_i32 +func.func @fold_eq_splat_i32() -> (tensor<10xi1>, tensor<10xi1>) { + %0 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32> + %1 = "tosa.const"() {values = dense<8> : tensor<10xi32>} : () -> tensor<10xi32> + %2 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32> + %3 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32> + %true = tosa.equal %2, %3 : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1> + %false = tosa.equal %0, %1 : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1> + // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} + // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} + // CHECK: return %[[TRUE]], %[[FALSE]] + return %true, %false : tensor<10xi1>, tensor<10xi1> +} + +// ----- + +// CHECK-LABEL: @fold_eq_i32 +func.func @fold_eq_i32(%arg0 : tensor<10xi32>) -> (tensor<10xi1>) { + // CHECK: %[[TRUE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} + %0 = tosa.equal %arg0, %arg0 : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1> + // CHECK: return %[[TRUE]] + return %0 : tensor<10xi1> +} + +// ----- + +func.func @reshape_splat() -> tensor<6x5x4xi32> { + // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<42> : tensor<6x5x4xi32>} + %splat = "tosa.const"() {values = dense<42> : tensor<4x5x6xi32>} : () -> tensor<4x5x6xi32> + %const = tosa.const_shape {values = dense<[6, 5, 4]> : tensor<3xindex>} : () -> !tosa.shape<3> + %reshape = tosa.reshape %splat, %const : (tensor<4x5x6xi32>, !tosa.shape<3>) -> tensor<6x5x4xi32> + // CHECK: return %[[SPLAT]] + return %reshape : tensor<6x5x4xi32> +} + +// ----- + +// CHECK-LABEL: @slice_splat +func.func @slice_splat() -> tensor<1x1x1xi32> { + // CHECK: %[[SLICE:.+]] = "tosa.const"() <{values = dense<42> : tensor<1x1x1xi32>} + %splat = "tosa.const"() {values = dense<42> : tensor<4x5x6xi32>} : () -> 
tensor<4x5x6xi32> + %start = tosa.const_shape {values = dense<[1, 2, 3]> : tensor<3xindex>} : () -> !tosa.shape<3> + %size = tosa.const_shape {values = dense<[1, 1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> + %slice= tosa.slice %splat, %start, %size : (tensor<4x5x6xi32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x1x1xi32> + + // CHECK: return %[[SLICE]] + return %slice : tensor<1x1x1xi32> +} + +// ----- + +// CHECK-LABEL: @slice_singleton +func.func @slice_singleton() -> tensor<1x1xi32> { + %splat = "tosa.const"() {values = dense<[[0, 1, 2], [3, 4, 5], [6, 7 ,8]]> : tensor<3x3xi32>} : () -> tensor<3x3xi32> + // CHECK: %[[SLICE:.+]] = "tosa.const"() <{values = dense<4> : tensor<1x1xi32>} + %start = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2> + %size = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2> + %slice= tosa.slice %splat, %start, %size : (tensor<3x3xi32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x1xi32> + // CHECK: return %[[SLICE]] + return %slice : tensor<1x1xi32> +} + +// ----- + +// CHECK: func.func @cast_float_to_float +func.func @cast_float_to_float() -> tensor { + %splat = "tosa.const"() {values = dense<42.0> : tensor} : () -> tensor + // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<4.200000e+01> : tensor} + %cast = tosa.cast %splat : (tensor) -> tensor + // CHECK: return %[[SPLAT]] + return %cast : tensor +} + +// ----- + +// CHECK: func.func @cast_int_to_float +func.func @cast_int_to_float() -> tensor { + %splat = "tosa.const"() {values = dense<4> : tensor} : () -> tensor + // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<4.000000e+00> : tensor} + %cast = tosa.cast %splat : (tensor) -> tensor + // CHECK: return %[[SPLAT]] + return %cast : tensor +} + +// ----- + +// CHECK: func.func @cast_float_to_int +func.func @cast_float_to_int() -> tensor { + %splat = "tosa.const"() {values = dense<-4.0> : tensor} : () -> tensor + // CHECK: %[[SPLAT:.+]] = 
"tosa.const"() <{values = dense<-4> : tensor} + %cast = tosa.cast %splat : (tensor) -> tensor + // CHECK: return %[[SPLAT]] + return %cast : tensor +} + +// ----- + +// CHECK: func.func @cast_float_to_int_round +func.func @cast_float_to_int_round() -> tensor { + %splat = "tosa.const"() {values = dense<-3.5> : tensor} : () -> tensor + // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<-4> : tensor} + %cast = tosa.cast %splat : (tensor) -> tensor + // CHECK: return %[[SPLAT]] + return %cast : tensor +} + +// ----- + +// CHECK: func.func @cast_int_to_int_trunc +func.func @cast_int_to_int_trunc() -> tensor { + %splat = "tosa.const"() {values = dense<-1> : tensor} : () -> tensor + // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<-1> : tensor} + %cast = tosa.cast %splat : (tensor) -> tensor + // CHECK: return %[[SPLAT]] + return %cast : tensor +} + +// ----- + +// CHECK: func.func @cast_int_to_int_sign +func.func @cast_int_to_int_sign() -> tensor { + %splat = "tosa.const"() {values = dense<-1> : tensor} : () -> tensor + // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<-1> : tensor} + %cast = tosa.cast %splat : (tensor) -> tensor + // CHECK: return %[[SPLAT]] + return %cast : tensor +} + +// ----- + +// CHECK-LABEL: @reverse_splat +func.func @reverse_splat() -> tensor<10xi32> { + // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<42> : tensor<10xi32>} + %splat = "tosa.const"() {values = dense<42> : tensor<10xi32>} : () -> tensor<10xi32> + %reverse = tosa.reverse %splat { axis = 0 : i32 } : (tensor<10xi32>) -> tensor<10xi32> + // CHECK: return %[[SPLAT]] + return %reverse : tensor<10xi32> +} + +// ----- + +// CHECK-LABEL: @reverse_length_one +func.func @reverse_length_one(%arg0 : tensor<10x1xi32>) -> (tensor<10x1xi32>, tensor<10x1xi32>) { + %nofold = tosa.reverse %arg0 { axis = 0 : i32 } : (tensor<10x1xi32>) -> tensor<10x1xi32> + %fold = tosa.reverse %arg0 { axis = 1 : i32 } : (tensor<10x1xi32>) -> tensor<10x1xi32> + // CHECK: %[[NOFOLD:.+]] 
= tosa.reverse %arg0 {axis = 0 : i32} + // CHECK: return %[[NOFOLD]], %arg0 + return %nofold, %fold : tensor<10x1xi32>, tensor<10x1xi32> +} + +// ----- + +// no_shift_op_reorder checks that %arg1 won't be reorder with %0 +// by the folder pass. +// CHECK-LABEL: @no_shift_op_reorder +func.func @no_shift_op_reorder (%arg0 : tensor<44x1xi16>, %arg1 : tensor<1xi8>) -> tensor<44x57xi32> { + %0 = "tosa.const"() {values = dense<1> : tensor<44x57xi16>} : () -> tensor<44x57xi16> + // CHECK: tosa.mul %arg0, %0, %arg1 + %1 = tosa.mul %arg0, %0, %arg1 : (tensor<44x1xi16>, tensor<44x57xi16>, tensor<1xi8>) -> tensor<44x57xi32> + return %1 : tensor<44x57xi32> +} + +// ----- + +// CHECK-LABEL: @test_fold_add_shape +// CHECK: tosa.const_shape {values = dense<[2, 4, 6, 8, 10, 12]> : tensor<6xindex>} : () -> !tosa.shape<6> +func.func @test_fold_add_shape() -> !tosa.shape<6> { + %a = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6]> : tensor<6xindex>} : () -> !tosa.shape<6> + %b = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6]> : tensor<6xindex>} : () -> !tosa.shape<6> + %c = tosa.add_shape %a, %b : (!tosa.shape<6>, !tosa.shape<6>) -> !tosa.shape<6> + return %c : !tosa.shape<6> +} + +// ----- + +// CHECK-LABEL: @test_no_fold_add_shape_positive_overflow +// CHECK: tosa.add_shape +func.func @test_no_fold_add_shape_positive_overflow() -> !tosa.shape<6> { + %a = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 9223372036854775807]> : tensor<6xindex>} : () -> !tosa.shape<6> + %b = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 1]> : tensor<6xindex>} : () -> !tosa.shape<6> + %c = tosa.add_shape %a, %b : (!tosa.shape<6>, !tosa.shape<6>) -> !tosa.shape<6> + return %c : !tosa.shape<6> +} + +// ----- + +// CHECK-LABEL: @test_no_fold_add_shape_negative_overflow +// CHECK: tosa.add_shape +func.func @test_no_fold_add_shape_negative_overflow() -> !tosa.shape<6> { + %a = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, -9223372036854775808]> : tensor<6xindex>} : () -> !tosa.shape<6> + %b = 
tosa.const_shape {values = dense<[1, 2, 3, 4, 5, -1]> : tensor<6xindex>} : () -> !tosa.shape<6> + %c = tosa.add_shape %a, %b : (!tosa.shape<6>, !tosa.shape<6>) -> !tosa.shape<6> + return %c : !tosa.shape<6> +} diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/tosa-layerwise-constant-fold.mlir similarity index 63% rename from mlir/test/Dialect/Tosa/constant-op-fold.mlir rename to mlir/test/Dialect/Tosa/tosa-layerwise-constant-fold.mlir index b1fbcdcc53e2f..d95d267e8c907 100644 --- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir +++ b/mlir/test/Dialect/Tosa/tosa-layerwise-constant-fold.mlir @@ -1,6 +1,4 @@ // RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold %s | FileCheck %s - - // RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold="aggressive-reduce-constant=true" %s | FileCheck %s --check-prefix=AGGRESIVE // CHECK-LABEL: @armax_fold_dim_size_1 @@ -10,6 +8,8 @@ func.func @armax_fold_dim_size_1(%arg0: tensor<2x1x3xf32>) -> tensor<2x3xi32> { return %0 : tensor<2x3xi32> } +// ----- + // CHECK-LABEL: @argmax_dynamic_shape_no_fold_dim_size_1 func.func @argmax_dynamic_shape_no_fold_dim_size_1(%arg0: tensor) -> tensor { // CHECK: tosa.argmax @@ -17,6 +17,8 @@ func.func @argmax_dynamic_shape_no_fold_dim_size_1(%arg0: tensor) -> return %0 : tensor } +// ----- + // CHECK-LABEL: @transpose_fold func.func @transpose_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> { // CHECK: return %arg0 @@ -24,6 +26,8 @@ func.func @transpose_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> { return %1 : tensor<3x4xf32> } +// ----- + // CHECK-LABEL: @transpose_nofold func.func @transpose_nofold(%arg0: tensor<3x3xf32>) -> tensor<3x3xf32> { // CHECK: tosa.transpose @@ -31,6 +35,8 @@ func.func @transpose_nofold(%arg0: tensor<3x3xf32>) -> tensor<3x3xf32> { return %1 : tensor<3x3xf32> } +// ----- + // CHECK-LABEL: @transpose_nofold_shape func.func @transpose_nofold_shape(%arg0: tensor<3x4xf32>) -> tensor { // CHECK: tosa.transpose @@ -38,6 
+44,8 @@ func.func @transpose_nofold_shape(%arg0: tensor<3x4xf32>) -> tensor { return %1 : tensor } +// ----- + // CHECK-LABEL: @transpose_fold_splat func.func @transpose_fold_splat() -> tensor<3x2xf32> { %input = "tosa.const"() {values = dense<4.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32> @@ -48,6 +56,8 @@ func.func @transpose_fold_splat() -> tensor<3x2xf32> { return %1 : tensor<3x2xf32> } +// ----- + // CHECK-LABEL: @transpose_fold_2d_float func.func @transpose_fold_2d_float() -> tensor<3x2xf32> { %input = "tosa.const"() {values = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32> @@ -58,6 +68,8 @@ func.func @transpose_fold_2d_float() -> tensor<3x2xf32> { return %1 : tensor<3x2xf32> } +// ----- + // CHECK-LABEL: @transpose_fold_2d_bool func.func @transpose_fold_2d_bool() -> tensor<3x2xi1> { %input = "tosa.const"() {values = dense<[[true, false, false], [false, false, true]]> : tensor<2x3xi1>} : () -> tensor<2x3xi1> @@ -68,6 +80,8 @@ func.func @transpose_fold_2d_bool() -> tensor<3x2xi1> { return %1 : tensor<3x2xi1> } +// ----- + // CHECK-LABEL: @transpose_fold_4d_int func.func @transpose_fold_4d_int() -> tensor<3x1x4x2xi32> { %input = "tosa.const"() {values = dense<[[ @@ -85,6 +99,8 @@ func.func @transpose_fold_4d_int() -> tensor<3x1x4x2xi32> { return %1 : tensor<3x1x4x2xi32> } +// ----- + // CHECK-LABEL: @transpose_nofold_non_cst_input func.func @transpose_nofold_non_cst_input(%input: tensor<2x3xf32>) -> tensor<3x2xf32> { // CHECK: tosa.transpose @@ -92,6 +108,8 @@ func.func @transpose_nofold_non_cst_input(%input: tensor<2x3xf32>) -> tensor<3x2 return %1 : tensor<3x2xf32> } +// ----- + // CHECK-LABEL: @transpose_nofold_multi_users func.func @transpose_nofold_multi_users() -> (tensor<3x2xf32>, tensor<2x3xf32>) { %input = "tosa.const"() {values = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32> @@ -100,6 +118,8 @@ func.func @transpose_nofold_multi_users() -> (tensor<3x2xf32>, tensor<2x3xf32>) 
return %1, %input : tensor<3x2xf32>, tensor<2x3xf32> } +// ----- + // CHECK-LABEL: @transpose_nofold_quantized_types func.func @transpose_nofold_quantized_types() -> tensor<1x1x2x2x!quant.uniform:f32:3, {1.000000e-01,1.000000e-01}>> { %input = "tosa.const"() {values = dense<-127> : tensor<2x1x1x2xi8>} : () -> tensor<2x1x1x2x!quant.uniform:f32:3, {1.000000e-01,1.000000e-01}>> @@ -108,6 +128,8 @@ func.func @transpose_nofold_quantized_types() -> tensor<1x1x2x2x!quant.uniform:f32:3, {1.000000e-01,1.000000e-01}>> } +// ----- + // CHECK-LABEL: @transpose_fold_dense_resource func.func @transpose_fold_dense_resource() -> tensor<2x2xf32> { %0 = "tosa.const"() <{values = dense_resource : tensor<2x2xf32>}> : () -> tensor<2x2xf32> @@ -124,498 +146,6 @@ func.func @transpose_fold_dense_resource() -> tensor<2x2xf32> { } #-} -// ----- - -// CHECK-LABEL: @fold_add_zero_rhs_f32 -func.func @fold_add_zero_rhs_f32(%arg0: tensor) -> tensor { - %zero = "tosa.const"() {values = dense<0.0> : tensor} : () -> tensor - %add = tosa.add %arg0, %zero : (tensor, tensor) -> tensor - // CHECK: return %arg0 - return %add : tensor -} - -// ----- - -// CHECK-LABEL: @fold_add_zero_lhs_f32 -func.func @fold_add_zero_lhs_f32(%arg0: tensor) -> tensor { - %zero = "tosa.const"() {values = dense<0.0> : tensor} : () -> tensor - %add = tosa.add %zero, %arg0 : (tensor, tensor) -> tensor - // CHECK: return %arg0 - return %add : tensor -} - -// ----- - -// CHECK-LABEL: @fold_add_zero_rhs_i32 -func.func @fold_add_zero_rhs_i32(%arg0: tensor) -> tensor { - %zero = "tosa.const"() {values = dense<0> : tensor} : () -> tensor - %add = tosa.add %arg0, %zero : (tensor, tensor) -> tensor - // CHECK: return %arg0 - return %add : tensor -} - -// ----- - -// CHECK-LABEL: @fold_add_zero_lhs_i32 -func.func @fold_add_zero_lhs_i32(%arg0: tensor) -> tensor { - %zero = "tosa.const"() {values = dense<0> : tensor} : () -> tensor - %add = tosa.add %zero, %arg0 : (tensor, tensor) -> tensor - // CHECK: return %arg0 - return %add : tensor 
-} - -// ----- - -// CHECK-LABEL: @fold_add_splat_i32 -func.func @fold_add_splat_i32() -> tensor<10xi32> { - %one = "tosa.const"() {values = dense<1> : tensor<10xi32>} : () -> tensor<10xi32> - %two = "tosa.const"() {values = dense<2> : tensor<10xi32>} : () -> tensor<10xi32> - %add = tosa.add %one, %two : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32> - // CHECK: %[[THREE:.+]] = "tosa.const"() <{values = dense<3> : tensor<10xi32>} - // CHECK: return %[[THREE]] - return %add : tensor<10xi32> -} - -// ----- - -// CHECK-LABEL: @fold_add_splat_f32 -func.func @fold_add_splat_f32() -> tensor<10xf32> { - %one = "tosa.const"() {values = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32> - %two = "tosa.const"() {values = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32> - %add = tosa.add %one, %two : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> - // CHECK: %[[THREE:.+]] = "tosa.const"() <{values = dense<3.000000e+00> - // CHECK: return %[[THREE]] - return %add : tensor<10xf32> -} - -// ----- - -// CHECK-LABEL: @fold_div_zero_lhs_i32 -func.func @fold_div_zero_lhs_i32(%arg0: tensor) -> tensor { - %zero = "tosa.const"() {values = dense<0> : tensor} : () -> tensor - // CHECK: %[[ZERO:.+]] = "tosa.const"() <{values = dense<0> - %div = tosa.intdiv %zero, %arg0 : (tensor, tensor) -> tensor - // CHECK: return %[[ZERO]] - return %div : tensor -} - -// ----- - -// CHECK-LABEL: @fold_div_one_rhs_i32 -func.func @fold_div_one_rhs_i32(%arg0: tensor) -> tensor { - %one = "tosa.const"() {values = dense<1> : tensor} : () -> tensor - %div = tosa.intdiv %arg0, %one : (tensor, tensor) -> tensor - // CHECK: return %arg0 - return %div : tensor -} - -// ----- - -// CHECK-LABEL: @fold_div_splat_i32 -func.func @fold_div_splat_i32() -> tensor { - %lhs = "tosa.const"() {values = dense<10> : tensor} : () -> tensor - %rhs = "tosa.const"() {values = dense<-3> : tensor} : () -> tensor - // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<-3> - %div = tosa.intdiv %lhs, %rhs : (tensor, 
tensor) -> tensor - // CHECK: return %[[SPLAT]] - return %div : tensor -} - -// ----- - - -// CHECK-LABEL: @fold_mul_zero_rhs_f32 -func.func @fold_mul_zero_rhs_f32(%arg0: tensor) -> tensor { - %zero = "tosa.const"() {values = dense<0.0> : tensor} : () -> tensor - // CHECK: %[[ZERO:.+]] = "tosa.const"() <{values = dense<0.000000e+00> - %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> - %mul = tosa.mul %arg0, %zero, %shift : (tensor, tensor, tensor<1xi8>) -> tensor - // CHECK: return %[[ZERO]] - return %mul : tensor -} - -// ----- - -// CHECK-LABEL: @fold_mul_zero_lhs_f32 -func.func @fold_mul_zero_lhs_f32(%arg0: tensor) -> tensor { - %zero = "tosa.const"() {values = dense<0.0> : tensor} : () -> tensor - // CHECK: %[[ZERO:.+]] = "tosa.const"() <{values = dense<0.000000e+00> - %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> - %mul = tosa.mul %zero, %arg0, %shift : (tensor, tensor, tensor<1xi8>) -> tensor - // CHECK: return %[[ZERO]] - return %mul : tensor -} - -// ----- - -// CHECK-LABEL: @fold_mul_zero_rhs_i32 -func.func @fold_mul_zero_rhs_i32(%arg0: tensor) -> tensor { - %zero = "tosa.const"() {values = dense<0> : tensor} : () -> tensor - %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> - // CHECK: %[[ZERO:.+]] = "tosa.const"() <{values = dense<0> - %mul = tosa.mul %arg0, %zero, %shift : (tensor, tensor, tensor<1xi8>) -> tensor - // CHECK: return %[[ZERO]] - return %mul : tensor -} - -// ----- - -// CHECK-LABEL: @fold_mul_zero_lhs_i32 -func.func @fold_mul_zero_lhs_i32(%arg0: tensor) -> tensor { - %zero = "tosa.const"() {values = dense<0> : tensor} : () -> tensor - %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> - // CHECK: %[[ZERO:.+]] = "tosa.const"() <{values = dense<0> - %mul = tosa.mul %zero, %arg0, %shift : (tensor, tensor, tensor<1xi8>) -> tensor - // CHECK: return %[[ZERO]] - return %mul : tensor -} - -// ----- - -// 
CHECK-LABEL: @fold_mul_one_rhs_f32 -func.func @fold_mul_one_rhs_f32(%arg0: tensor) -> tensor { - %one = "tosa.const"() {values = dense<1.0> : tensor} : () -> tensor - %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> - %mul = tosa.mul %arg0, %one, %shift : (tensor, tensor, tensor<1xi8>) -> tensor - // CHECK: return %arg0 - return %mul : tensor -} - -// ----- - -// CHECK-LABEL: @fold_mul_one_lhs_f32 -func.func @fold_mul_one_lhs_f32(%arg0: tensor) -> tensor { - %one = "tosa.const"() {values = dense<1.0> : tensor} : () -> tensor - %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> - %mul = tosa.mul %one, %arg0, %shift : (tensor, tensor, tensor<1xi8>) -> tensor - // CHECK: return %arg0 - return %mul : tensor -} - -// ----- - -// CHECK-LABEL: @fold_mul_one_rhs_i32 -func.func @fold_mul_one_rhs_i32(%arg0: tensor) -> tensor { - %one = "tosa.const"() {values = dense<64> : tensor} : () -> tensor - %shift = "tosa.const"() {values = dense<6> : tensor<1xi8>} : () -> tensor<1xi8> - %mul = tosa.mul %arg0, %one, %shift : (tensor, tensor, tensor<1xi8>) -> tensor - // CHECK: return %arg0 - return %mul : tensor -} - -// ----- - -// CHECK-LABEL: @fold_mul_one_lhs_i32 -func.func @fold_mul_one_lhs_i32(%arg0: tensor) -> tensor { - %one = "tosa.const"() {values = dense<64> : tensor} : () -> tensor - %shift = "tosa.const"() {values = dense<6> : tensor<1xi8>} : () -> tensor<1xi8> - %mul = tosa.mul %one, %arg0, %shift : (tensor, tensor, tensor<1xi8>) -> tensor - // CHECK: return %arg0 - return %mul : tensor -} - -// ----- - -// CHECK-LABEL: @fold_mul_splat_i8 -func.func @fold_mul_splat_i8() -> tensor<10xi32> { - %one = "tosa.const"() {values = dense<17> : tensor<10xi8>} : () -> tensor<10xi8> - %two = "tosa.const"() {values = dense<32> : tensor<10xi8>} : () -> tensor<10xi8> - %shift = "tosa.const"() {values = dense<3> : tensor<1xi8>} : () -> tensor<1xi8> - %mul = tosa.mul %one, %two, %shift : (tensor<10xi8>, tensor<10xi8>, 
tensor<1xi8>) -> tensor<10xi32> - // CHECK: %[[THREE:.+]] = "tosa.const"() <{values = dense<68> : tensor<10xi32>} - // CHECK: return %[[THREE]] - return %mul : tensor<10xi32> -} - -// ----- - -// CHECK-LABEL: @fold_mul_splat_f32 -func.func @fold_mul_splat_f32() -> tensor<10xf32> { - %one = "tosa.const"() {values = dense<3.0> : tensor<10xf32>} : () -> tensor<10xf32> - %two = "tosa.const"() {values = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32> - %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> - %mul = tosa.mul %one, %two, %shift : (tensor<10xf32>, tensor<10xf32>, tensor<1xi8>) -> tensor<10xf32> - // CHECK: %[[THREE:.+]] = "tosa.const"() <{values = dense<6.000000e+00> : tensor<10xf32>} - // CHECK: return %[[THREE]] - return %mul : tensor<10xf32> -} - -// ----- - -// CHECK-LABEL: @fold_sub_zero_rhs_f32 -func.func @fold_sub_zero_rhs_f32(%arg0: tensor) -> tensor { - %zero = "tosa.const"() {values = dense<0.0> : tensor} : () -> tensor - %sub = tosa.sub %arg0, %zero : (tensor, tensor) -> tensor - // CHECK: return %arg0 - return %sub : tensor -} - -// ----- - -// CHECK-LABEL: @fold_sub_zero_rhs_i32 -func.func @fold_sub_zero_rhs_i32(%arg0: tensor) -> tensor { - %zero = "tosa.const"() {values = dense<0> : tensor} : () -> tensor - %sub = tosa.sub %arg0, %zero : (tensor, tensor) -> tensor - // CHECK: return %arg0 - return %sub : tensor -} - -// ----- - -// CHECK-LABEL: @fold_sub_splat_i32 -func.func @fold_sub_splat_i32() -> tensor<10xi32> { - %one = "tosa.const"() {values = dense<1> : tensor<10xi32>} : () -> tensor<10xi32> - %two = "tosa.const"() {values = dense<2> : tensor<10xi32>} : () -> tensor<10xi32> - %sub = tosa.sub %one, %two : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32> - // CHECK: %[[THREE:.+]] = "tosa.const"() <{values = dense<-1> : tensor<10xi32>} - // CHECK: return %[[THREE]] - return %sub : tensor<10xi32> -} - -// ----- - -// CHECK-LABEL: @fold_sub_splat_f32 -func.func @fold_sub_splat_f32() -> tensor<10xf32> { - 
%one = "tosa.const"() {values = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32> - %two = "tosa.const"() {values = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32> - %sub = tosa.sub %one, %two : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> - // CHECK: %[[THREE:.+]] = "tosa.const"() <{values = dense<-1.000000e+00> : tensor<10xf32>} - // CHECK: return %[[THREE]] - return %sub : tensor<10xf32> -} - -// ----- - -// CHECK-LABEL: @fold_greater_splat_f32 -func.func @fold_greater_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) { - %0 = "tosa.const"() {values = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32> - %1 = "tosa.const"() {values = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32> - %2 = "tosa.const"() {values = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32> - %3 = "tosa.const"() {values = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32> - %true = tosa.greater %0, %1 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1> - %false = tosa.greater %2, %3 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1> - // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} - // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} - // CHECK: return %[[TRUE]], %[[FALSE]] - return %true, %false : tensor<10xi1>, tensor<10xi1> -} - -// ----- - -// CHECK-LABEL: @fold_greater_splat_i32 -func.func @fold_greater_splat_i32() -> (tensor<10xi1>, tensor<10xi1>) { - %0 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32> - %1 = "tosa.const"() {values = dense<8> : tensor<10xi32>} : () -> tensor<10xi32> - %2 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32> - %3 = "tosa.const"() {values = dense<-12> : tensor<10xi32>} : () -> tensor<10xi32> - %false = tosa.greater %0, %1 : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1> - %true = tosa.greater %2, %3 : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1> - // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense : 
tensor<10xi1>} - // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} - // CHECK: return %[[FALSE]], %[[TRUE]] - return %false, %true : tensor<10xi1>, tensor<10xi1> -} - -// ----- - -// CHECK-LABEL: @fold_greater_eq_splat_f32 -func.func @fold_greater_eq_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) { - %0 = "tosa.const"() {values = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32> - %1 = "tosa.const"() {values = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32> - %2 = "tosa.const"() {values = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32> - %3 = "tosa.const"() {values = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32> - %true = tosa.greater_equal %0, %1 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1> - %false = tosa.greater_equal %2, %3 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1> - // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} - // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} - // CHECK: return %[[TRUE]], %[[FALSE]] - return %true, %false : tensor<10xi1>, tensor<10xi1> -} - -// ----- - -// CHECK-LABEL: @fold_greater_eq_splat_i32 -func.func @fold_greater_eq_splat_i32() -> (tensor<10xi1>, tensor<10xi1>) { - %0 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32> - %1 = "tosa.const"() {values = dense<8> : tensor<10xi32>} : () -> tensor<10xi32> - %2 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32> - %3 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32> - %true = tosa.greater_equal %2, %3 : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1> - %false = tosa.greater_equal %0, %1 : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1> - // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} - // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} - // CHECK: return %[[TRUE]], %[[FALSE]] - return %true, %false : tensor<10xi1>, tensor<10xi1> 
-} - -// ----- - -// CHECK-LABEL: @fold_eq_splat_f32 -func.func @fold_eq_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) { - %0 = "tosa.const"() {values = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32> - %1 = "tosa.const"() {values = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32> - %2 = "tosa.const"() {values = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32> - %3 = "tosa.const"() {values = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32> - %true = tosa.equal %0, %1 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1> - %false = tosa.equal %2, %3 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1> - // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} - // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} - // CHECK: return %[[TRUE]], %[[FALSE]] - return %true, %false : tensor<10xi1>, tensor<10xi1> -} - -// ----- - -// CHECK-LABEL: @fold_eq_splat_i32 -func.func @fold_eq_splat_i32() -> (tensor<10xi1>, tensor<10xi1>) { - %0 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32> - %1 = "tosa.const"() {values = dense<8> : tensor<10xi32>} : () -> tensor<10xi32> - %2 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32> - %3 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32> - %true = tosa.equal %2, %3 : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1> - %false = tosa.equal %0, %1 : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1> - // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} - // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} - // CHECK: return %[[TRUE]], %[[FALSE]] - return %true, %false : tensor<10xi1>, tensor<10xi1> -} - -// ----- - -// CHECK-LABEL: @fold_eq_i32 -func.func @fold_eq_i32(%arg0 : tensor<10xi32>) -> (tensor<10xi1>) { - // CHECK: %[[TRUE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} - %0 = tosa.equal %arg0, %arg0 : (tensor<10xi32>, 
tensor<10xi32>) -> tensor<10xi1> - // CHECK: return %[[TRUE]] - return %0 : tensor<10xi1> -} - -// ----- - -func.func @reshape_splat() -> tensor<6x5x4xi32> { - // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<42> : tensor<6x5x4xi32>} - %splat = "tosa.const"() {values = dense<42> : tensor<4x5x6xi32>} : () -> tensor<4x5x6xi32> - %const = tosa.const_shape {values = dense<[6, 5, 4]> : tensor<3xindex>} : () -> !tosa.shape<3> - %reshape = tosa.reshape %splat, %const : (tensor<4x5x6xi32>, !tosa.shape<3>) -> tensor<6x5x4xi32> - // CHECK: return %[[SPLAT]] - return %reshape : tensor<6x5x4xi32> -} - -// ----- - -// CHECK-LABEL: @slice_splat -func.func @slice_splat() -> tensor<1x1x1xi32> { - // CHECK: %[[SLICE:.+]] = "tosa.const"() <{values = dense<42> : tensor<1x1x1xi32>} - %splat = "tosa.const"() {values = dense<42> : tensor<4x5x6xi32>} : () -> tensor<4x5x6xi32> - %start = tosa.const_shape {values = dense<[1, 2, 3]> : tensor<3xindex>} : () -> !tosa.shape<3> - %size = tosa.const_shape {values = dense<[1, 1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> - %slice= tosa.slice %splat, %start, %size : (tensor<4x5x6xi32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x1x1xi32> - - // CHECK: return %[[SLICE]] - return %slice : tensor<1x1x1xi32> -} - -// ----- - -// CHECK-LABEL: @slice_singleton -func.func @slice_singleton() -> tensor<1x1xi32> { - %splat = "tosa.const"() {values = dense<[[0, 1, 2], [3, 4, 5], [6, 7 ,8]]> : tensor<3x3xi32>} : () -> tensor<3x3xi32> - // CHECK: %[[SLICE:.+]] = "tosa.const"() <{values = dense<4> : tensor<1x1xi32>} - %start = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2> - %size = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2> - %slice= tosa.slice %splat, %start, %size : (tensor<3x3xi32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x1xi32> - // CHECK: return %[[SLICE]] - return %slice : tensor<1x1xi32> -} - -// ----- - -// CHECK: func.func @cast_float_to_float -func.func 
@cast_float_to_float() -> tensor { - %splat = "tosa.const"() {values = dense<42.0> : tensor} : () -> tensor - // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<4.200000e+01> : tensor} - %cast = tosa.cast %splat : (tensor) -> tensor - // CHECK: return %[[SPLAT]] - return %cast : tensor -} - -// ----- - -// CHECK: func.func @cast_int_to_float -func.func @cast_int_to_float() -> tensor { - %splat = "tosa.const"() {values = dense<4> : tensor} : () -> tensor - // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<4.000000e+00> : tensor} - %cast = tosa.cast %splat : (tensor) -> tensor - // CHECK: return %[[SPLAT]] - return %cast : tensor -} - -// ----- - -// CHECK: func.func @cast_float_to_int -func.func @cast_float_to_int() -> tensor { - %splat = "tosa.const"() {values = dense<-4.0> : tensor} : () -> tensor - // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<-4> : tensor} - %cast = tosa.cast %splat : (tensor) -> tensor - // CHECK: return %[[SPLAT]] - return %cast : tensor -} - -// ----- - -// CHECK: func.func @cast_float_to_int_round -func.func @cast_float_to_int_round() -> tensor { - %splat = "tosa.const"() {values = dense<-3.5> : tensor} : () -> tensor - // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<-4> : tensor} - %cast = tosa.cast %splat : (tensor) -> tensor - // CHECK: return %[[SPLAT]] - return %cast : tensor -} - -// ----- - -// CHECK: func.func @cast_int_to_int_trunc -func.func @cast_int_to_int_trunc() -> tensor { - %splat = "tosa.const"() {values = dense<-1> : tensor} : () -> tensor - // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<-1> : tensor} - %cast = tosa.cast %splat : (tensor) -> tensor - // CHECK: return %[[SPLAT]] - return %cast : tensor -} - -// ----- - -// CHECK: func.func @cast_int_to_int_sign -func.func @cast_int_to_int_sign() -> tensor { - %splat = "tosa.const"() {values = dense<-1> : tensor} : () -> tensor - // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<-1> : tensor} - %cast = tosa.cast %splat : 
(tensor) -> tensor - // CHECK: return %[[SPLAT]] - return %cast : tensor -} - -// ----- - -// CHECK-LABEL: @reverse_splat -func.func @reverse_splat() -> tensor<10xi32> { - // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<42> : tensor<10xi32>} - %splat = "tosa.const"() {values = dense<42> : tensor<10xi32>} : () -> tensor<10xi32> - %reverse = tosa.reverse %splat { axis = 0 : i32 } : (tensor<10xi32>) -> tensor<10xi32> - // CHECK: return %[[SPLAT]] - return %reverse : tensor<10xi32> -} - -// ----- - -// CHECK-LABEL: @reverse_length_one -func.func @reverse_length_one(%arg0 : tensor<10x1xi32>) -> (tensor<10x1xi32>, tensor<10x1xi32>) { - %nofold = tosa.reverse %arg0 { axis = 0 : i32 } : (tensor<10x1xi32>) -> tensor<10x1xi32> - %fold = tosa.reverse %arg0 { axis = 1 : i32 } : (tensor<10x1xi32>) -> tensor<10x1xi32> - // CHECK: %[[NOFOLD:.+]] = tosa.reverse %arg0 {axis = 0 : i32} - // CHECK: return %[[NOFOLD]], %arg0 - return %nofold, %fold : tensor<10x1xi32>, tensor<10x1xi32> -} - // ----- func.func @reduce_sum_constant() -> tensor<1x3xi32> { @@ -1172,15 +702,3 @@ func.func @reduce_sum_constant_aggressive() -> tensor<2x3xi32> { %res1 = tosa.add %res0, %argmax1 : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> return %res1 : tensor<2x3xi32> } - -// ----- - -// no_shift_op_reorder checks that %arg1 won't be reorder with %0 -// by the folder pass. -// CHECK-LABEL: @no_shift_op_reorder -func.func @no_shift_op_reorder (%arg0 : tensor<44x1xi16>, %arg1 : tensor<1xi8>) -> tensor<44x57xi32> { - %0 = "tosa.const"() {values = dense<1> : tensor<44x57xi16>} : () -> tensor<44x57xi16> - // CHECK: tosa.mul %arg0, %0, %arg1 - %1 = tosa.mul %arg0, %0, %arg1 : (tensor<44x1xi16>, tensor<44x57xi16>, tensor<1xi8>) -> tensor<44x57xi32> - return %1 : tensor<44x57xi32> -}