From e4396313b534b4670925e64982b4292241f1c4fb Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Tue, 16 Dec 2025 21:38:13 +0000 Subject: [PATCH 1/3] [mlir][tosa] Separate layerwise folding and simple folder tests (NFC) This commit moves the 'simple' folder tests (invoked via `--canonicalize`) away from other layerwise constant folding tests (invoked via `--tosa-layerwise-constant-fold`) into a separate test file to help reduce confusion. Also rename the layerwise folding test file to reflect the the pass name that they are invoked by. Change-Id: I22bfa76480eddd8f850702986d79608d956e766a --- mlir/test/Dialect/Tosa/constant_folding.mlir | 510 ++++++++++++++++- ...mlir => tosa-layerwise-constant-fold.mlir} | 530 +----------------- 2 files changed, 533 insertions(+), 507 deletions(-) rename mlir/test/Dialect/Tosa/{constant-op-fold.mlir => tosa-layerwise-constant-fold.mlir} (63%) diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir index d477a2479e913..bf6e1ad23bcb9 100644 --- a/mlir/test/Dialect/Tosa/constant_folding.mlir +++ b/mlir/test/Dialect/Tosa/constant_folding.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt --test-single-fold %s | FileCheck %s +// RUN: mlir-opt --split-input-file --test-single-fold %s | FileCheck %s // CHECK-LABEL: func @test_const func.func @test_const(%arg0 : index) -> tensor<4xi32> { @@ -7,6 +7,8 @@ func.func @test_const(%arg0 : index) -> tensor<4xi32> { return %0 : tensor<4xi32> } +// ----- + // CHECK-LABEL: func @test_const_i64 func.func @test_const_i64(%arg0 : index) -> tensor<4xi64> { // CHECK: tosa.const @@ -14,6 +16,8 @@ func.func @test_const_i64(%arg0 : index) -> tensor<4xi64> { return %0 : tensor<4xi64> } +// ----- + // CHECK-LABEL: func @try_fold_equal_with_unranked_tensor func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tensor<1xi32>) { // CHECK: tosa.equal @@ -21,3 +25,507 @@ func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tens %0 = tosa.equal %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi1> return } + +// ----- + +// CHECK-LABEL: @fold_add_zero_rhs_f32 +func.func @fold_add_zero_rhs_f32(%arg0: tensor) -> tensor { + %zero = "tosa.const"() {values = dense<0.0> : tensor} : () -> tensor + %add = tosa.add %arg0, %zero : (tensor, tensor) -> tensor + // CHECK: return %arg0 + return %add : tensor +} + +// ----- + +// CHECK-LABEL: @fold_add_zero_lhs_f32 +func.func @fold_add_zero_lhs_f32(%arg0: tensor) -> tensor { + %zero = "tosa.const"() {values = dense<0.0> : tensor} : () -> tensor + %add = tosa.add %zero, %arg0 : (tensor, tensor) -> tensor + // CHECK: return %arg0 + return %add : tensor +} + +// ----- + +// CHECK-LABEL: @fold_add_zero_rhs_i32 +func.func @fold_add_zero_rhs_i32(%arg0: tensor) -> tensor { + %zero = "tosa.const"() {values = dense<0> : tensor} : () -> tensor + %add = tosa.add %arg0, %zero : (tensor, tensor) -> tensor + // CHECK: return %arg0 + return %add : tensor +} + +// ----- + +// CHECK-LABEL: @fold_add_zero_lhs_i32 +func.func @fold_add_zero_lhs_i32(%arg0: tensor) -> tensor { + %zero = "tosa.const"() {values = dense<0> : tensor} : () -> tensor + %add = tosa.add %zero, %arg0 : (tensor, tensor) -> tensor + // CHECK: return %arg0 + return %add : tensor +} + +// ----- + +// CHECK-LABEL: @fold_add_splat_i32 +func.func @fold_add_splat_i32() -> tensor<10xi32> { + %one = "tosa.const"() {values = dense<1> : tensor<10xi32>} : () -> tensor<10xi32> + %two = "tosa.const"() {values = dense<2> : tensor<10xi32>} : () -> tensor<10xi32> + %add = tosa.add %one, %two : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32> + // CHECK: %[[THREE:.+]] = "tosa.const"() <{values = dense<3> : tensor<10xi32>} + // CHECK: return %[[THREE]] + return %add : tensor<10xi32> +} + +// ----- + +// CHECK-LABEL: @fold_add_splat_f32 +func.func @fold_add_splat_f32() -> tensor<10xf32> { + %one = "tosa.const"() {values = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32> + %two = "tosa.const"() {values = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32> + %add = tosa.add %one, %two : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> + // CHECK: %[[THREE:.+]] = "tosa.const"() <{values = dense<3.000000e+00> + // CHECK: return %[[THREE]] + return %add : tensor<10xf32> +} + +// ----- + +// CHECK-LABEL: @fold_div_zero_lhs_i32 +func.func @fold_div_zero_lhs_i32(%arg0: tensor) -> tensor { + %zero = "tosa.const"() {values = dense<0> : tensor} : () -> tensor + // CHECK: %[[ZERO:.+]] = "tosa.const"() <{values = dense<0> + %div = tosa.intdiv %zero, %arg0 : (tensor, tensor) -> tensor + // CHECK: return %[[ZERO]] + return %div : tensor +} + +// ----- + +// CHECK-LABEL: @fold_div_one_rhs_i32 +func.func @fold_div_one_rhs_i32(%arg0: tensor) -> tensor { + %one = "tosa.const"() {values = dense<1> : tensor} : () -> tensor + %div = tosa.intdiv %arg0, %one : (tensor, tensor) -> tensor + // CHECK: return %arg0 + return %div : tensor +} + +// ----- + +// CHECK-LABEL: @fold_div_splat_i32 +func.func @fold_div_splat_i32() -> tensor { + %lhs = "tosa.const"() {values = dense<10> : tensor} : () -> tensor + %rhs = "tosa.const"() {values = dense<-3> : tensor} : () -> tensor + // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<-3> + %div = tosa.intdiv %lhs, %rhs : (tensor, tensor) -> tensor + // CHECK: return %[[SPLAT]] + return %div : tensor +} + +// ----- + + +// CHECK-LABEL: @fold_mul_zero_rhs_f32 +func.func @fold_mul_zero_rhs_f32(%arg0: tensor) -> tensor { + %zero = "tosa.const"() {values = dense<0.0> : tensor} : () -> tensor + // CHECK: %[[ZERO:.+]] = "tosa.const"() <{values = dense<0.000000e+00> + %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %mul = tosa.mul %arg0, %zero, %shift : (tensor, tensor, tensor<1xi8>) -> tensor + // CHECK: return %[[ZERO]] + return %mul : tensor +} + +// ----- + +// CHECK-LABEL: @fold_mul_zero_lhs_f32 +func.func @fold_mul_zero_lhs_f32(%arg0: tensor) -> tensor { + %zero = "tosa.const"() {values = dense<0.0> : tensor} : () -> tensor + // CHECK: %[[ZERO:.+]] = "tosa.const"() <{values = dense<0.000000e+00> + %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %mul = tosa.mul %zero, %arg0, %shift : (tensor, tensor, tensor<1xi8>) -> tensor + // CHECK: return %[[ZERO]] + return %mul : tensor +} + +// ----- + +// CHECK-LABEL: @fold_mul_zero_rhs_i32 +func.func @fold_mul_zero_rhs_i32(%arg0: tensor) -> tensor { + %zero = "tosa.const"() {values = dense<0> : tensor} : () -> tensor + %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + // CHECK: %[[ZERO:.+]] = "tosa.const"() <{values = dense<0> + %mul = tosa.mul %arg0, %zero, %shift : (tensor, tensor, tensor<1xi8>) -> tensor + // CHECK: return %[[ZERO]] + return %mul : tensor +} + +// ----- + +// CHECK-LABEL: @fold_mul_zero_lhs_i32 +func.func @fold_mul_zero_lhs_i32(%arg0: tensor) -> tensor { + %zero = "tosa.const"() {values = dense<0> : tensor} : () -> tensor + %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + // CHECK: %[[ZERO:.+]] = "tosa.const"() <{values = dense<0> + %mul = tosa.mul %zero, %arg0, %shift : (tensor, tensor, tensor<1xi8>) -> tensor + // CHECK: return %[[ZERO]] + return %mul : tensor +} + +// ----- + +// CHECK-LABEL: @fold_mul_one_rhs_f32 +func.func @fold_mul_one_rhs_f32(%arg0: tensor) -> tensor { + %one = "tosa.const"() {values = dense<1.0> : tensor} : () -> tensor + %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %mul = tosa.mul %arg0, %one, %shift : (tensor, tensor, tensor<1xi8>) -> tensor + // CHECK: return %arg0 + return %mul : tensor +} + +// ----- + +// CHECK-LABEL: @fold_mul_one_lhs_f32 +func.func @fold_mul_one_lhs_f32(%arg0: tensor) -> tensor { + %one = "tosa.const"() {values = dense<1.0> : tensor} : () -> tensor + %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %mul = tosa.mul %one, %arg0, %shift : (tensor, tensor, tensor<1xi8>) -> tensor + // CHECK: return %arg0 + return %mul : tensor +} + +// ----- + +// CHECK-LABEL: @fold_mul_one_rhs_i32 +func.func @fold_mul_one_rhs_i32(%arg0: tensor) -> tensor { + %one = "tosa.const"() {values = dense<64> : tensor} : () -> tensor + %shift = "tosa.const"() {values = dense<6> : tensor<1xi8>} : () -> tensor<1xi8> + %mul = tosa.mul %arg0, %one, %shift : (tensor, tensor, tensor<1xi8>) -> tensor + // CHECK: return %arg0 + return %mul : tensor +} + +// ----- + +// CHECK-LABEL: @fold_mul_one_lhs_i32 +func.func @fold_mul_one_lhs_i32(%arg0: tensor) -> tensor { + %one = "tosa.const"() {values = dense<64> : tensor} : () -> tensor + %shift = "tosa.const"() {values = dense<6> : tensor<1xi8>} : () -> tensor<1xi8> + %mul = tosa.mul %one, %arg0, %shift : (tensor, tensor, tensor<1xi8>) -> tensor + // CHECK: return %arg0 + return %mul : tensor +} + +// ----- + +// CHECK-LABEL: @fold_mul_splat_i8 +func.func @fold_mul_splat_i8() -> tensor<10xi32> { + %one = "tosa.const"() {values = dense<17> : tensor<10xi8>} : () -> tensor<10xi8> + %two = "tosa.const"() {values = dense<32> : tensor<10xi8>} : () -> tensor<10xi8> + %shift = "tosa.const"() {values = dense<3> : tensor<1xi8>} : () -> tensor<1xi8> + %mul = tosa.mul %one, %two, %shift : (tensor<10xi8>, tensor<10xi8>, tensor<1xi8>) -> tensor<10xi32> + // CHECK: %[[THREE:.+]] = "tosa.const"() <{values = dense<68> : tensor<10xi32>} + // CHECK: return %[[THREE]] + return %mul : tensor<10xi32> +} + +// ----- + +// CHECK-LABEL: @fold_mul_splat_f32 +func.func @fold_mul_splat_f32() -> tensor<10xf32> { + %one = "tosa.const"() {values = dense<3.0> : tensor<10xf32>} : () -> tensor<10xf32> + %two = "tosa.const"() {values = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32> + %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %mul = tosa.mul %one, %two, %shift : (tensor<10xf32>, tensor<10xf32>, tensor<1xi8>) -> tensor<10xf32> + // CHECK: %[[THREE:.+]] = "tosa.const"() <{values = dense<6.000000e+00> : tensor<10xf32>} + // CHECK: return %[[THREE]] + return %mul : tensor<10xf32> +} + +// ----- + +// CHECK-LABEL: @fold_sub_zero_rhs_f32 +func.func @fold_sub_zero_rhs_f32(%arg0: tensor) -> tensor { + %zero = "tosa.const"() {values = dense<0.0> : tensor} : () -> tensor + %sub = tosa.sub %arg0, %zero : (tensor, tensor) -> tensor + // CHECK: return %arg0 + return %sub : tensor +} + +// ----- + +// CHECK-LABEL: @fold_sub_zero_rhs_i32 +func.func @fold_sub_zero_rhs_i32(%arg0: tensor) -> tensor { + %zero = "tosa.const"() {values = dense<0> : tensor} : () -> tensor + %sub = tosa.sub %arg0, %zero : (tensor, tensor) -> tensor + // CHECK: return %arg0 + return %sub : tensor +} + +// ----- + +// CHECK-LABEL: @fold_sub_splat_i32 +func.func @fold_sub_splat_i32() -> tensor<10xi32> { + %one = "tosa.const"() {values = dense<1> : tensor<10xi32>} : () -> tensor<10xi32> + %two = "tosa.const"() {values = dense<2> : tensor<10xi32>} : () -> tensor<10xi32> + %sub = tosa.sub %one, %two : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32> + // CHECK: %[[THREE:.+]] = "tosa.const"() <{values = dense<-1> : tensor<10xi32>} + // CHECK: return %[[THREE]] + return %sub : tensor<10xi32> +} + +// ----- + +// CHECK-LABEL: @fold_sub_splat_f32 +func.func @fold_sub_splat_f32() -> tensor<10xf32> { + %one = "tosa.const"() {values = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32> + %two = "tosa.const"() {values = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32> + %sub = tosa.sub %one, %two : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> + // CHECK: %[[THREE:.+]] = "tosa.const"() <{values = dense<-1.000000e+00> : tensor<10xf32>} + // CHECK: return %[[THREE]] + return %sub : tensor<10xf32> +} + +// ----- + +// CHECK-LABEL: @fold_greater_splat_f32 +func.func @fold_greater_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) { + %0 = "tosa.const"() {values = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32> + %1 = "tosa.const"() {values = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32> + %2 = "tosa.const"() {values = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32> + %3 = "tosa.const"() {values = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32> + %true = tosa.greater %0, %1 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1> + %false = tosa.greater %2, %3 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1> + // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} + // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} + // CHECK: return %[[TRUE]], %[[FALSE]] + return %true, %false : tensor<10xi1>, tensor<10xi1> +} + +// ----- + +// CHECK-LABEL: @fold_greater_splat_i32 +func.func @fold_greater_splat_i32() -> (tensor<10xi1>, tensor<10xi1>) { + %0 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32> + %1 = "tosa.const"() {values = dense<8> : tensor<10xi32>} : () -> tensor<10xi32> + %2 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32> + %3 = "tosa.const"() {values = dense<-12> : tensor<10xi32>} : () -> tensor<10xi32> + %false = tosa.greater %0, %1 : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1> + %true = tosa.greater %2, %3 : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1> + // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} + // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} + // CHECK: return %[[FALSE]], %[[TRUE]] + return %false, %true : tensor<10xi1>, tensor<10xi1> +} + +// ----- + +// CHECK-LABEL: @fold_greater_eq_splat_f32 +func.func @fold_greater_eq_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) { + %0 = "tosa.const"() {values = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32> + %1 = "tosa.const"() {values = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32> + %2 = "tosa.const"() {values = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32> + %3 = "tosa.const"() {values = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32> + %true = tosa.greater_equal %0, %1 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1> + %false = tosa.greater_equal %2, %3 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1> + // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} + // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} + // CHECK: return %[[TRUE]], %[[FALSE]] + return %true, %false : tensor<10xi1>, tensor<10xi1> +} + +// ----- + +// CHECK-LABEL: @fold_greater_eq_splat_i32 +func.func @fold_greater_eq_splat_i32() -> (tensor<10xi1>, tensor<10xi1>) { + %0 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32> + %1 = "tosa.const"() {values = dense<8> : tensor<10xi32>} : () -> tensor<10xi32> + %2 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32> + %3 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32> + %true = tosa.greater_equal %2, %3 : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1> + %false = tosa.greater_equal %0, %1 : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1> + // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} + // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} + // CHECK: return %[[TRUE]], %[[FALSE]] + return %true, %false : tensor<10xi1>, tensor<10xi1> +} + +// ----- + +// CHECK-LABEL: @fold_eq_splat_f32 +func.func @fold_eq_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) { + %0 = "tosa.const"() {values = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32> + %1 = "tosa.const"() {values = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32> + %2 = "tosa.const"() {values = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32> + %3 = "tosa.const"() {values = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32> + %true = tosa.equal %0, %1 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1> + %false = tosa.equal %2, %3 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1> + // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} + // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} + // CHECK: return %[[TRUE]], %[[FALSE]] + return %true, %false : tensor<10xi1>, tensor<10xi1> +} + +// ----- + +// CHECK-LABEL: @fold_eq_splat_i32 +func.func @fold_eq_splat_i32() -> (tensor<10xi1>, tensor<10xi1>) { + %0 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32> + %1 = "tosa.const"() {values = dense<8> : tensor<10xi32>} : () -> tensor<10xi32> + %2 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32> + %3 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32> + %true = tosa.equal %2, %3 : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1> + %false = tosa.equal %0, %1 : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1> + // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} + // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} + // CHECK: return %[[TRUE]], %[[FALSE]] + return %true, %false : tensor<10xi1>, tensor<10xi1> +} + +// ----- + +// CHECK-LABEL: @fold_eq_i32 +func.func @fold_eq_i32(%arg0 : tensor<10xi32>) -> (tensor<10xi1>) { + // CHECK: %[[TRUE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} + %0 = tosa.equal %arg0, %arg0 : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1> + // CHECK: return %[[TRUE]] + return %0 : tensor<10xi1> +} + +// ----- + +func.func @reshape_splat() -> tensor<6x5x4xi32> { + // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<42> : tensor<6x5x4xi32>} + %splat = "tosa.const"() {values = dense<42> : tensor<4x5x6xi32>} : () -> tensor<4x5x6xi32> + %const = tosa.const_shape {values = dense<[6, 5, 4]> : tensor<3xindex>} : () -> !tosa.shape<3> + %reshape = tosa.reshape %splat, %const : (tensor<4x5x6xi32>, !tosa.shape<3>) -> tensor<6x5x4xi32> + // CHECK: return %[[SPLAT]] + return %reshape : tensor<6x5x4xi32> +} + +// ----- + +// CHECK-LABEL: @slice_splat +func.func @slice_splat() -> tensor<1x1x1xi32> { + // CHECK: %[[SLICE:.+]] = "tosa.const"() <{values = dense<42> : tensor<1x1x1xi32>} + %splat = "tosa.const"() {values = dense<42> : tensor<4x5x6xi32>} : () -> tensor<4x5x6xi32> + %start = tosa.const_shape {values = dense<[1, 2, 3]> : tensor<3xindex>} : () -> !tosa.shape<3> + %size = tosa.const_shape {values = dense<[1, 1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> + %slice= tosa.slice %splat, %start, %size : (tensor<4x5x6xi32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x1x1xi32> + + // CHECK: return %[[SLICE]] + return %slice : tensor<1x1x1xi32> +} + +// ----- + +// CHECK-LABEL: @slice_singleton +func.func @slice_singleton() -> tensor<1x1xi32> { + %splat = "tosa.const"() {values = dense<[[0, 1, 2], [3, 4, 5], [6, 7 ,8]]> : tensor<3x3xi32>} : () -> tensor<3x3xi32> + // CHECK: %[[SLICE:.+]] = "tosa.const"() <{values = dense<4> : tensor<1x1xi32>} + %start = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2> + %size = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2> + %slice= tosa.slice %splat, %start, %size : (tensor<3x3xi32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x1xi32> + // CHECK: return %[[SLICE]] + return %slice : tensor<1x1xi32> +} + +// ----- + +// CHECK: func.func @cast_float_to_float +func.func @cast_float_to_float() -> tensor { + %splat = "tosa.const"() {values = dense<42.0> : tensor} : () -> tensor + // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<4.200000e+01> : tensor} + %cast = tosa.cast %splat : (tensor) -> tensor + // CHECK: return %[[SPLAT]] + return %cast : tensor +} + +// ----- + +// CHECK: func.func @cast_int_to_float +func.func @cast_int_to_float() -> tensor { + %splat = "tosa.const"() {values = dense<4> : tensor} : () -> tensor + // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<4.000000e+00> : tensor} + %cast = tosa.cast %splat : (tensor) -> tensor + // CHECK: return %[[SPLAT]] + return %cast : tensor +} + +// ----- + +// CHECK: func.func @cast_float_to_int +func.func @cast_float_to_int() -> tensor { + %splat = "tosa.const"() {values = dense<-4.0> : tensor} : () -> tensor + // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<-4> : tensor} + %cast = tosa.cast %splat : (tensor) -> tensor + // CHECK: return %[[SPLAT]] + return %cast : tensor +} + +// ----- + +// CHECK: func.func @cast_float_to_int_round +func.func @cast_float_to_int_round() -> tensor { + %splat = "tosa.const"() {values = dense<-3.5> : tensor} : () -> tensor + // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<-4> : tensor} + %cast = tosa.cast %splat : (tensor) -> tensor + // CHECK: return %[[SPLAT]] + return %cast : tensor +} + +// ----- + +// CHECK: func.func @cast_int_to_int_trunc +func.func @cast_int_to_int_trunc() -> tensor { + %splat = "tosa.const"() {values = dense<-1> : tensor} : () -> tensor + // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<-1> : tensor} + %cast = tosa.cast %splat : (tensor) -> tensor + // CHECK: return %[[SPLAT]] + return %cast : tensor +} + +// ----- + +// CHECK: func.func @cast_int_to_int_sign +func.func @cast_int_to_int_sign() -> tensor { + %splat = "tosa.const"() {values = dense<-1> : tensor} : () -> tensor + // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<-1> : tensor} + %cast = tosa.cast %splat : (tensor) -> tensor + // CHECK: return %[[SPLAT]] + return %cast : tensor +} + +// ----- + +// CHECK-LABEL: @reverse_splat +func.func @reverse_splat() -> tensor<10xi32> { + // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<42> : tensor<10xi32>} + %splat = "tosa.const"() {values = dense<42> : tensor<10xi32>} : () -> tensor<10xi32> + %reverse = tosa.reverse %splat { axis = 0 : i32 } : (tensor<10xi32>) -> tensor<10xi32> + // CHECK: return %[[SPLAT]] + return %reverse : tensor<10xi32> +} + +// ----- + +// CHECK-LABEL: @reverse_length_one +func.func @reverse_length_one(%arg0 : tensor<10x1xi32>) -> (tensor<10x1xi32>, tensor<10x1xi32>) { + %nofold = tosa.reverse %arg0 { axis = 0 : i32 } : (tensor<10x1xi32>) -> tensor<10x1xi32> + %fold = tosa.reverse %arg0 { axis = 1 : i32 } : (tensor<10x1xi32>) -> tensor<10x1xi32> + // CHECK: %[[NOFOLD:.+]] = tosa.reverse %arg0 {axis = 0 : i32} + // CHECK: return %[[NOFOLD]], %arg0 + return %nofold, %fold : tensor<10x1xi32>, tensor<10x1xi32> +} + +// ----- + +// no_shift_op_reorder checks that %arg1 won't be reorder with %0 +// by the folder pass. +// CHECK-LABEL: @no_shift_op_reorder +func.func @no_shift_op_reorder (%arg0 : tensor<44x1xi16>, %arg1 : tensor<1xi8>) -> tensor<44x57xi32> { + %0 = "tosa.const"() {values = dense<1> : tensor<44x57xi16>} : () -> tensor<44x57xi16> + // CHECK: tosa.mul %arg0, %0, %arg1 + %1 = tosa.mul %arg0, %0, %arg1 : (tensor<44x1xi16>, tensor<44x57xi16>, tensor<1xi8>) -> tensor<44x57xi32> + return %1 : tensor<44x57xi32> +} diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/tosa-layerwise-constant-fold.mlir similarity index 63% rename from mlir/test/Dialect/Tosa/constant-op-fold.mlir rename to mlir/test/Dialect/Tosa/tosa-layerwise-constant-fold.mlir index b1fbcdcc53e2f..d95d267e8c907 100644 --- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir +++ b/mlir/test/Dialect/Tosa/tosa-layerwise-constant-fold.mlir @@ -1,6 +1,4 @@ // RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold %s | FileCheck %s - - // RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold="aggressive-reduce-constant=true" %s | FileCheck %s --check-prefix=AGGRESIVE // CHECK-LABEL: @armax_fold_dim_size_1 @@ -10,6 +8,8 @@ func.func @armax_fold_dim_size_1(%arg0: tensor<2x1x3xf32>) -> tensor<2x3xi32> { return %0 : tensor<2x3xi32> } +// ----- + // CHECK-LABEL: @argmax_dynamic_shape_no_fold_dim_size_1 func.func @argmax_dynamic_shape_no_fold_dim_size_1(%arg0: tensor) -> tensor { // CHECK: tosa.argmax @@ -17,6 +17,8 @@ func.func @argmax_dynamic_shape_no_fold_dim_size_1(%arg0: tensor) -> return %0 : tensor } +// ----- + // CHECK-LABEL: @transpose_fold func.func @transpose_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> { // CHECK: return %arg0 @@ -24,6 +26,8 @@ func.func @transpose_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> { return %1 : tensor<3x4xf32> } +// ----- + // CHECK-LABEL: @transpose_nofold func.func @transpose_nofold(%arg0: tensor<3x3xf32>) -> tensor<3x3xf32> { // CHECK: tosa.transpose @@ -31,6 +35,8 @@ func.func @transpose_nofold(%arg0: tensor<3x3xf32>) -> tensor<3x3xf32> { return %1 : tensor<3x3xf32> } +// ----- + // CHECK-LABEL: @transpose_nofold_shape func.func @transpose_nofold_shape(%arg0: tensor<3x4xf32>) -> tensor { // CHECK: tosa.transpose @@ -38,6 +44,8 @@ func.func @transpose_nofold_shape(%arg0: tensor<3x4xf32>) -> tensor { return %1 : tensor } +// ----- + // CHECK-LABEL: @transpose_fold_splat func.func @transpose_fold_splat() -> tensor<3x2xf32> { %input = "tosa.const"() {values = dense<4.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32> @@ -48,6 +56,8 @@ func.func @transpose_fold_splat() -> tensor<3x2xf32> { return %1 : tensor<3x2xf32> } +// ----- + // CHECK-LABEL: @transpose_fold_2d_float func.func @transpose_fold_2d_float() -> tensor<3x2xf32> { %input = "tosa.const"() {values = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32> @@ -58,6 +68,8 @@ func.func @transpose_fold_2d_float() -> tensor<3x2xf32> { return %1 : tensor<3x2xf32> } +// ----- + // CHECK-LABEL: @transpose_fold_2d_bool func.func @transpose_fold_2d_bool() -> tensor<3x2xi1> { %input = "tosa.const"() {values = dense<[[true, false, false], [false, false, true]]> : tensor<2x3xi1>} : () -> tensor<2x3xi1> @@ -68,6 +80,8 @@ func.func @transpose_fold_2d_bool() -> tensor<3x2xi1> { return %1 : tensor<3x2xi1> } +// ----- + // CHECK-LABEL: @transpose_fold_4d_int func.func @transpose_fold_4d_int() -> tensor<3x1x4x2xi32> { %input = "tosa.const"() {values = dense<[[ @@ -85,6 +99,8 @@ func.func @transpose_fold_4d_int() -> tensor<3x1x4x2xi32> { return %1 : tensor<3x1x4x2xi32> } +// ----- + // CHECK-LABEL: @transpose_nofold_non_cst_input func.func @transpose_nofold_non_cst_input(%input: tensor<2x3xf32>) -> tensor<3x2xf32> { // CHECK: tosa.transpose @@ -92,6 +108,8 @@ func.func @transpose_nofold_non_cst_input(%input: tensor<2x3xf32>) -> tensor<3x2 return %1 : tensor<3x2xf32> } +// ----- + // CHECK-LABEL: @transpose_nofold_multi_users func.func @transpose_nofold_multi_users() -> (tensor<3x2xf32>, tensor<2x3xf32>) { %input = "tosa.const"() {values = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32> @@ -100,6 +118,8 @@ func.func @transpose_nofold_multi_users() -> (tensor<3x2xf32>, tensor<2x3xf32>) return %1, %input : tensor<3x2xf32>, tensor<2x3xf32> } +// ----- + // CHECK-LABEL: @transpose_nofold_quantized_types func.func @transpose_nofold_quantized_types() -> tensor<1x1x2x2x!quant.uniform:f32:3, {1.000000e-01,1.000000e-01}>> { %input = "tosa.const"() {values = dense<-127> : tensor<2x1x1x2xi8>} : () -> tensor<2x1x1x2x!quant.uniform:f32:3, {1.000000e-01,1.000000e-01}>> @@ -108,6 +128,8 @@ func.func @transpose_nofold_quantized_types() -> tensor<1x1x2x2x!quant.uniform:f32:3, {1.000000e-01,1.000000e-01}>> } +// ----- + // CHECK-LABEL: @transpose_fold_dense_resource func.func @transpose_fold_dense_resource() -> tensor<2x2xf32> { %0 = "tosa.const"() <{values = dense_resource : tensor<2x2xf32>}> : () -> tensor<2x2xf32> @@ -124,498 +146,6 @@ func.func @transpose_fold_dense_resource() -> tensor<2x2xf32> { } #-} -// ----- - -// CHECK-LABEL: @fold_add_zero_rhs_f32 -func.func @fold_add_zero_rhs_f32(%arg0: tensor) -> tensor { - %zero = "tosa.const"() {values = dense<0.0> : tensor} : () -> tensor - %add = tosa.add %arg0, %zero : (tensor, tensor) -> tensor - // CHECK: return %arg0 - return %add : tensor -} - -// ----- - -// CHECK-LABEL: @fold_add_zero_lhs_f32 -func.func @fold_add_zero_lhs_f32(%arg0: tensor) -> tensor { - %zero = "tosa.const"() {values = dense<0.0> : tensor} : () -> tensor - %add = tosa.add %zero, %arg0 : (tensor, tensor) -> tensor - // CHECK: return %arg0 - return %add : tensor -} - -// ----- - -// CHECK-LABEL: @fold_add_zero_rhs_i32 -func.func @fold_add_zero_rhs_i32(%arg0: tensor) -> tensor { - %zero = "tosa.const"() {values = dense<0> : tensor} : () -> tensor - %add = tosa.add %arg0, %zero : (tensor, tensor) -> tensor - // CHECK: return %arg0 - return %add : tensor -} - -// ----- - -// CHECK-LABEL: @fold_add_zero_lhs_i32 -func.func @fold_add_zero_lhs_i32(%arg0: tensor) -> tensor { - %zero = "tosa.const"() {values = dense<0> : tensor} : () -> tensor - %add = tosa.add %zero, %arg0 : (tensor, tensor) -> tensor - // CHECK: return %arg0 - return %add : tensor -} - -// ----- - -// CHECK-LABEL: @fold_add_splat_i32 -func.func @fold_add_splat_i32() -> tensor<10xi32> { - %one = "tosa.const"() {values = dense<1> : tensor<10xi32>} : () -> tensor<10xi32> - %two = "tosa.const"() {values = dense<2> : tensor<10xi32>} : () -> tensor<10xi32> - %add = tosa.add %one, %two : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32> - // CHECK: %[[THREE:.+]] = "tosa.const"() <{values = dense<3> : tensor<10xi32>} - // CHECK: return %[[THREE]] - return %add : tensor<10xi32> -} - -// ----- - -// CHECK-LABEL: @fold_add_splat_f32 -func.func @fold_add_splat_f32() -> tensor<10xf32> { - %one = "tosa.const"() {values = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32> - %two = "tosa.const"() {values = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32> - %add = tosa.add %one, %two : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> - // CHECK: %[[THREE:.+]] = "tosa.const"() <{values = dense<3.000000e+00> - // CHECK: return %[[THREE]] - return %add : tensor<10xf32> -} - -// ----- - -// CHECK-LABEL: @fold_div_zero_lhs_i32 -func.func @fold_div_zero_lhs_i32(%arg0: tensor) -> tensor { - %zero = "tosa.const"() {values = dense<0> : tensor} : () -> tensor - // CHECK: %[[ZERO:.+]] = "tosa.const"() <{values = dense<0> - %div = tosa.intdiv %zero, %arg0 : (tensor, tensor) -> tensor - // CHECK: return %[[ZERO]] - return %div : tensor -} - -// ----- - -// CHECK-LABEL: @fold_div_one_rhs_i32 -func.func @fold_div_one_rhs_i32(%arg0: tensor) -> tensor { - %one = "tosa.const"() {values = dense<1> : tensor} : () -> tensor - %div = tosa.intdiv %arg0, %one : (tensor, tensor) -> tensor - // CHECK: return %arg0 - return %div : tensor -} - -// ----- - -// CHECK-LABEL: @fold_div_splat_i32 -func.func @fold_div_splat_i32() -> tensor { - %lhs = "tosa.const"() {values = dense<10> : tensor} : () -> tensor - %rhs = "tosa.const"() {values = dense<-3> : tensor} : () -> tensor - // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<-3> - %div = tosa.intdiv %lhs, %rhs : (tensor, tensor) -> tensor - // CHECK: return %[[SPLAT]] - return %div : tensor -} - -// ----- - - -// CHECK-LABEL: @fold_mul_zero_rhs_f32 -func.func @fold_mul_zero_rhs_f32(%arg0: tensor) -> tensor { - %zero = "tosa.const"() {values = dense<0.0> : tensor} : () -> tensor - // CHECK: %[[ZERO:.+]] = "tosa.const"() <{values = dense<0.000000e+00> - %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> - %mul = tosa.mul %arg0, %zero, %shift : (tensor, tensor, tensor<1xi8>) -> tensor - // CHECK: return %[[ZERO]] - return %mul : tensor -} - -// ----- - -// CHECK-LABEL: @fold_mul_zero_lhs_f32 -func.func @fold_mul_zero_lhs_f32(%arg0: tensor) -> tensor { - %zero = "tosa.const"() {values = dense<0.0> : tensor} : () -> tensor - // CHECK: %[[ZERO:.+]] = "tosa.const"() <{values = dense<0.000000e+00> - %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> - %mul = tosa.mul %zero, %arg0, %shift : (tensor, tensor, tensor<1xi8>) -> tensor - // CHECK: return %[[ZERO]] - return %mul : tensor -} - -// ----- - -// CHECK-LABEL: @fold_mul_zero_rhs_i32 -func.func @fold_mul_zero_rhs_i32(%arg0: tensor) -> tensor { - %zero = "tosa.const"() {values = dense<0> : tensor} : () -> tensor - %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> - // CHECK: %[[ZERO:.+]] = "tosa.const"() <{values = dense<0> - %mul = tosa.mul %arg0, %zero, %shift : (tensor, tensor, tensor<1xi8>) -> tensor - // CHECK: return %[[ZERO]] - return %mul : tensor -} - -// ----- - -// CHECK-LABEL: @fold_mul_zero_lhs_i32 -func.func @fold_mul_zero_lhs_i32(%arg0: tensor) -> tensor { - %zero = "tosa.const"() {values = dense<0> : tensor} : () -> tensor - %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> - // CHECK: %[[ZERO:.+]] = "tosa.const"() <{values = dense<0> - %mul = tosa.mul %zero, %arg0, %shift : (tensor, tensor, tensor<1xi8>) -> tensor - // CHECK: return %[[ZERO]] - return %mul : tensor -} - -// ----- - -// CHECK-LABEL: @fold_mul_one_rhs_f32 -func.func @fold_mul_one_rhs_f32(%arg0: tensor) -> tensor { - %one = "tosa.const"() {values = dense<1.0> : tensor} : () -> tensor - %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> - %mul = tosa.mul %arg0, %one, %shift : (tensor, tensor, tensor<1xi8>) -> tensor - // CHECK: return %arg0 - return %mul : tensor -} - -// ----- - -// CHECK-LABEL: @fold_mul_one_lhs_f32 -func.func @fold_mul_one_lhs_f32(%arg0: tensor) -> tensor { - %one = "tosa.const"() {values = dense<1.0> : tensor} : () -> tensor - %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> - %mul = tosa.mul %one, %arg0, %shift : (tensor, tensor, tensor<1xi8>) -> tensor - // CHECK: return %arg0 - return %mul : tensor -} - -// ----- - -// CHECK-LABEL: @fold_mul_one_rhs_i32 -func.func @fold_mul_one_rhs_i32(%arg0: tensor) -> tensor { - %one = "tosa.const"() {values = dense<64> : tensor} : () -> tensor - %shift = "tosa.const"() {values = dense<6> : tensor<1xi8>} : () -> tensor<1xi8> - %mul = tosa.mul %arg0, %one, %shift : (tensor, tensor, tensor<1xi8>) -> tensor - // CHECK: return %arg0 - return %mul : tensor -} - -// ----- - -// CHECK-LABEL: @fold_mul_one_lhs_i32 -func.func @fold_mul_one_lhs_i32(%arg0: tensor) -> tensor { - %one = "tosa.const"() {values = dense<64> : tensor} : () -> tensor - %shift = "tosa.const"() {values = dense<6> : tensor<1xi8>} : () -> tensor<1xi8> - %mul = tosa.mul %one, %arg0, %shift : (tensor, tensor, tensor<1xi8>) -> tensor - // CHECK: return %arg0 - return %mul : tensor -} - -// ----- - -// CHECK-LABEL: @fold_mul_splat_i8 -func.func @fold_mul_splat_i8() -> tensor<10xi32> { - %one = "tosa.const"() {values = dense<17> : tensor<10xi8>} : () -> tensor<10xi8> - %two = "tosa.const"() {values = dense<32> : tensor<10xi8>} : () -> tensor<10xi8> - %shift = "tosa.const"() {values = dense<3> : tensor<1xi8>} : () -> tensor<1xi8> - %mul = tosa.mul %one, %two, %shift : (tensor<10xi8>, tensor<10xi8>, tensor<1xi8>) -> tensor<10xi32> - // CHECK: %[[THREE:.+]] = "tosa.const"() <{values = dense<68> : tensor<10xi32>} - // CHECK: return %[[THREE]] - return %mul : tensor<10xi32> -} - -// ----- - -// CHECK-LABEL: @fold_mul_splat_f32 -func.func @fold_mul_splat_f32() -> tensor<10xf32> { - %one = "tosa.const"() {values = dense<3.0> : tensor<10xf32>} : () -> tensor<10xf32> - %two = "tosa.const"() {values = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32> - %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> - %mul = tosa.mul %one, %two, %shift : (tensor<10xf32>, tensor<10xf32>, tensor<1xi8>) -> tensor<10xf32> - // CHECK: %[[THREE:.+]] = "tosa.const"() <{values = dense<6.000000e+00> : tensor<10xf32>} - // CHECK: return %[[THREE]] - return %mul : tensor<10xf32> -} - -// ----- - -// CHECK-LABEL: @fold_sub_zero_rhs_f32 -func.func @fold_sub_zero_rhs_f32(%arg0: tensor) -> tensor { - %zero = "tosa.const"() {values = dense<0.0> : tensor} : () -> tensor - %sub = tosa.sub %arg0, %zero : (tensor, tensor) -> tensor - // CHECK: return %arg0 - return %sub : tensor -} - -// ----- - -// CHECK-LABEL: @fold_sub_zero_rhs_i32 -func.func @fold_sub_zero_rhs_i32(%arg0: tensor) -> tensor { - %zero = "tosa.const"() {values = dense<0> : tensor} : () -> tensor - %sub = tosa.sub %arg0, %zero : (tensor, tensor) -> tensor - // CHECK: return %arg0 - return %sub : tensor -} - -// ----- - -// CHECK-LABEL: @fold_sub_splat_i32 -func.func @fold_sub_splat_i32() -> tensor<10xi32> { - %one = "tosa.const"() {values = dense<1> : tensor<10xi32>} : () -> tensor<10xi32> - %two = "tosa.const"() {values = dense<2> : tensor<10xi32>} : () -> tensor<10xi32> - %sub = tosa.sub %one, %two : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32> - // CHECK: %[[THREE:.+]] = "tosa.const"() <{values = dense<-1> : tensor<10xi32>} - // CHECK: return %[[THREE]] - return %sub : tensor<10xi32> -} - -// ----- - -// CHECK-LABEL: @fold_sub_splat_f32 -func.func @fold_sub_splat_f32() -> tensor<10xf32> { - %one = "tosa.const"() {values = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32> - %two = "tosa.const"() {values = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32> - %sub = tosa.sub %one, %two : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> - // CHECK: %[[THREE:.+]] = "tosa.const"() <{values = dense<-1.000000e+00> : tensor<10xf32>} - // CHECK: return %[[THREE]] - return %sub : tensor<10xf32> -} - -// ----- - -// CHECK-LABEL: @fold_greater_splat_f32 -func.func @fold_greater_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) { - %0 = "tosa.const"() {values = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32> - %1 = "tosa.const"() {values = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32> - %2 = "tosa.const"() {values = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32> - %3 = "tosa.const"() {values = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32> - %true = tosa.greater %0, %1 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1> - %false = tosa.greater %2, %3 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1> - // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} - // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} - // CHECK: return %[[TRUE]], %[[FALSE]] - return %true, %false : tensor<10xi1>, tensor<10xi1> -} - -// ----- - -// CHECK-LABEL: @fold_greater_splat_i32 -func.func @fold_greater_splat_i32() -> (tensor<10xi1>, tensor<10xi1>) { - %0 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32> - %1 = "tosa.const"() {values = dense<8> : tensor<10xi32>} : () -> tensor<10xi32> - %2 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32> - %3 = "tosa.const"() {values = dense<-12> : tensor<10xi32>} : () -> tensor<10xi32> - %false = tosa.greater %0, %1 : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1> - %true = tosa.greater %2, %3 : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1> - // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} - // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} - // CHECK: return %[[FALSE]], %[[TRUE]] - return %false, %true : tensor<10xi1>, tensor<10xi1> -} - -// ----- - -// CHECK-LABEL: @fold_greater_eq_splat_f32 -func.func @fold_greater_eq_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) { - %0 = "tosa.const"() {values = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32> - %1 = "tosa.const"() {values = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32> - %2 = "tosa.const"() {values = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32> - %3 = "tosa.const"() {values = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32> - %true = tosa.greater_equal %0, %1 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1> - %false = tosa.greater_equal %2, %3 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1> - // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} - // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} - // CHECK: return %[[TRUE]], %[[FALSE]] - return %true, %false : tensor<10xi1>, tensor<10xi1> -} - -// ----- - -// CHECK-LABEL: @fold_greater_eq_splat_i32 -func.func @fold_greater_eq_splat_i32() -> (tensor<10xi1>, tensor<10xi1>) { - %0 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32> - %1 = "tosa.const"() {values = dense<8> : tensor<10xi32>} : () -> tensor<10xi32> - %2 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32> - %3 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32> - %true = tosa.greater_equal %2, %3 : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1> - %false = tosa.greater_equal %0, %1 : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1> - // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} - // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} - // CHECK: return %[[TRUE]], %[[FALSE]] - return %true, %false : tensor<10xi1>, tensor<10xi1> -} - -// ----- - -// CHECK-LABEL: @fold_eq_splat_f32 -func.func @fold_eq_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) { - %0 = "tosa.const"() {values = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32> - %1 = "tosa.const"() {values = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32> - %2 = "tosa.const"() {values = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32> - %3 = "tosa.const"() {values = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32> - %true = tosa.equal %0, %1 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1> - %false = tosa.equal %2, %3 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1> - // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} - // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} - // CHECK: return %[[TRUE]], %[[FALSE]] - return %true, %false : tensor<10xi1>, tensor<10xi1> -} - -// ----- - -// CHECK-LABEL: @fold_eq_splat_i32 -func.func @fold_eq_splat_i32() -> (tensor<10xi1>, tensor<10xi1>) { - %0 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32> - %1 = "tosa.const"() {values = dense<8> : tensor<10xi32>} : () -> tensor<10xi32> - %2 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32> - %3 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32> - %true = tosa.equal %2, %3 : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1> - %false = tosa.equal %0, %1 : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1> - // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} - // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} - // CHECK: return %[[TRUE]], %[[FALSE]] - return %true, %false : tensor<10xi1>, tensor<10xi1> -} - -// ----- - -// CHECK-LABEL: @fold_eq_i32 -func.func @fold_eq_i32(%arg0 : tensor<10xi32>) -> (tensor<10xi1>) { - // CHECK: %[[TRUE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} - %0 = tosa.equal %arg0, %arg0 : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1> - // CHECK: return %[[TRUE]] - return %0 : tensor<10xi1> -} - -// ----- - -func.func @reshape_splat() -> tensor<6x5x4xi32> { - // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<42> : tensor<6x5x4xi32>} - %splat = "tosa.const"() {values = dense<42> : tensor<4x5x6xi32>} : () -> tensor<4x5x6xi32> - %const = tosa.const_shape {values = dense<[6, 5, 4]> : tensor<3xindex>} : () -> !tosa.shape<3> - %reshape = tosa.reshape %splat, %const : (tensor<4x5x6xi32>, !tosa.shape<3>) -> tensor<6x5x4xi32> - // CHECK: return %[[SPLAT]] - return %reshape : tensor<6x5x4xi32> -} - -// ----- - -// CHECK-LABEL: @slice_splat -func.func @slice_splat() -> tensor<1x1x1xi32> { - // CHECK: %[[SLICE:.+]] = "tosa.const"() <{values = dense<42> : tensor<1x1x1xi32>} - %splat = "tosa.const"() {values = dense<42> : tensor<4x5x6xi32>} : () -> tensor<4x5x6xi32> - %start = tosa.const_shape {values = dense<[1, 2, 3]> : tensor<3xindex>} : () -> !tosa.shape<3> - %size = tosa.const_shape {values = dense<[1, 1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> - %slice= tosa.slice %splat, %start, %size : (tensor<4x5x6xi32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x1x1xi32> - - // CHECK: return %[[SLICE]] - return %slice : tensor<1x1x1xi32> -} - -// ----- - -// CHECK-LABEL: @slice_singleton -func.func @slice_singleton() -> tensor<1x1xi32> { - %splat = "tosa.const"() {values = dense<[[0, 1, 2], [3, 4, 5], [6, 7 ,8]]> : tensor<3x3xi32>} : () -> tensor<3x3xi32> - // CHECK: %[[SLICE:.+]] = "tosa.const"() <{values = dense<4> : tensor<1x1xi32>} - %start = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2> - %size = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2> - %slice= tosa.slice %splat, %start, %size : (tensor<3x3xi32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x1xi32> - // CHECK: return %[[SLICE]] - return %slice : tensor<1x1xi32> -} - -// ----- - -// CHECK: func.func @cast_float_to_float -func.func @cast_float_to_float() -> tensor { - %splat = "tosa.const"() {values = dense<42.0> : tensor} : () -> tensor - // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<4.200000e+01> : tensor} - %cast = tosa.cast %splat : (tensor) -> tensor - // CHECK: return %[[SPLAT]] - return %cast : tensor -} - -// ----- - -// CHECK: func.func @cast_int_to_float -func.func @cast_int_to_float() -> tensor { - %splat = "tosa.const"() {values = dense<4> : tensor} : () -> tensor - // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<4.000000e+00> : tensor} - %cast = tosa.cast %splat : (tensor) -> tensor - // CHECK: return %[[SPLAT]] - return %cast : tensor -} - -// ----- - -// CHECK: func.func @cast_float_to_int -func.func @cast_float_to_int() -> tensor { - %splat = "tosa.const"() {values = dense<-4.0> : tensor} : () -> tensor - // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<-4> : tensor} - %cast = tosa.cast %splat : (tensor) -> tensor - // CHECK: return %[[SPLAT]] - return %cast : tensor -} - -// ----- - -// CHECK: func.func @cast_float_to_int_round -func.func @cast_float_to_int_round() -> tensor { - %splat = "tosa.const"() {values = dense<-3.5> : tensor} : () -> tensor - // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<-4> : tensor} - %cast = tosa.cast %splat : (tensor) -> tensor - // CHECK: return %[[SPLAT]] - return %cast : tensor -} - -// ----- - -// CHECK: func.func @cast_int_to_int_trunc -func.func @cast_int_to_int_trunc() -> tensor { - %splat = "tosa.const"() {values = dense<-1> : tensor} : () -> tensor - // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<-1> : tensor} - %cast = tosa.cast %splat : (tensor) -> tensor - // CHECK: return %[[SPLAT]] - return %cast : tensor -} - -// ----- - -// CHECK: func.func @cast_int_to_int_sign -func.func @cast_int_to_int_sign() -> tensor { - %splat = "tosa.const"() {values = dense<-1> : tensor} : () -> tensor - // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<-1> : tensor} - %cast = tosa.cast %splat : (tensor) -> tensor - // CHECK: return %[[SPLAT]] - return %cast : tensor -} - -// ----- - -// CHECK-LABEL: @reverse_splat -func.func @reverse_splat() -> tensor<10xi32> { - // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<42> : tensor<10xi32>} - %splat = "tosa.const"() {values = dense<42> : tensor<10xi32>} : () -> tensor<10xi32> - %reverse = tosa.reverse %splat { axis = 0 : i32 } : (tensor<10xi32>) -> tensor<10xi32> - // CHECK: return %[[SPLAT]] - return %reverse : tensor<10xi32> -} - -// ----- - -// CHECK-LABEL: @reverse_length_one -func.func @reverse_length_one(%arg0 : tensor<10x1xi32>) -> (tensor<10x1xi32>, tensor<10x1xi32>) { - %nofold = tosa.reverse %arg0 { axis = 0 : i32 } : (tensor<10x1xi32>) -> tensor<10x1xi32> - %fold = tosa.reverse %arg0 { axis = 1 : i32 } : (tensor<10x1xi32>) -> tensor<10x1xi32> - // CHECK: %[[NOFOLD:.+]] = tosa.reverse %arg0 {axis = 0 : i32} - // CHECK: return %[[NOFOLD]], %arg0 - return %nofold, %fold : tensor<10x1xi32>, tensor<10x1xi32> -} - // ----- func.func @reduce_sum_constant() -> tensor<1x3xi32> { @@ -1172,15 +702,3 @@ func.func @reduce_sum_constant_aggressive() -> tensor<2x3xi32> { %res1 = tosa.add %res0, %argmax1 : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> return %res1 : tensor<2x3xi32> } - -// ----- - -// no_shift_op_reorder checks that %arg1 won't be reorder with %0 -// by the folder pass. -// CHECK-LABEL: @no_shift_op_reorder -func.func @no_shift_op_reorder (%arg0 : tensor<44x1xi16>, %arg1 : tensor<1xi8>) -> tensor<44x57xi32> { - %0 = "tosa.const"() {values = dense<1> : tensor<44x57xi16>} : () -> tensor<44x57xi16> - // CHECK: tosa.mul %arg0, %0, %arg1 - %1 = tosa.mul %arg0, %0, %arg1 : (tensor<44x1xi16>, tensor<44x57xi16>, tensor<1xi8>) -> tensor<44x57xi32> - return %1 : tensor<44x57xi32> -} From 168b359f5d96ae3eff6f7c1510b3cf6f09e1cd9d Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Wed, 17 Dec 2025 13:28:15 +0000 Subject: [PATCH 2/3] [mlir][tosa] Check for overflow in integer folders For these folders to be TOSA compliant, they need to check for overflow. This commit adds those checks, subsequently preventing folding if an overflow is detected. This commit also fixes the greater/greater_equal folders to account for unsigned types. Change-Id: I2b5a5b92fb840d6c34a1f2faa18ae68a20d0ecdf --- .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 177 +++++++++++++----- mlir/test/Dialect/Tosa/constant_folding.mlir | 121 ++++++++++++ 2 files changed, 246 insertions(+), 52 deletions(-) diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index c420a4c9596ff..3e9d803a916a9 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -889,33 +889,141 @@ void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results, // Operator Folders. //===----------------------------------------------------------------------===// -template +template static DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs, RankedTensorType returnTy) { if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) { - auto lETy = llvm::cast(lhs.getType()).getElementType(); - auto rETy = llvm::cast(rhs.getType()).getElementType(); + const auto lETy = llvm::cast(lhs.getType()).getElementType(); + const auto rETy = llvm::cast(rhs.getType()).getElementType(); if (lETy != rETy) return {}; - if (llvm::isa(lETy)) { - APInt l = lhs.getSplatValue(); - APInt r = rhs.getSplatValue(); - auto result = IntFolder()(l, r); - return DenseElementsAttr::get(returnTy, result); + if (const auto lIntTy = dyn_cast(lETy)) { + const APInt l = lhs.getSplatValue(); + const APInt r = rhs.getSplatValue(); + const auto maybeResult = Folder::fold(l, r, lIntTy.isUnsigned()); + if (failed(maybeResult)) + return {}; + return DenseElementsAttr::get(returnTy, maybeResult.value()); } if (llvm::isa(lETy)) { - APFloat l = lhs.getSplatValue(); - APFloat r = rhs.getSplatValue(); - auto result = FloatFolder()(l, r); - return DenseElementsAttr::get(returnTy, result); + const APFloat l = lhs.getSplatValue(); + const APFloat r = rhs.getSplatValue(); + const auto maybeResult = Folder::fold(l, r); + if (failed(maybeResult)) + return {}; + return DenseElementsAttr::get(returnTy, maybeResult.value()); } } return {}; } +struct AddFoldAdaptor { + static FailureOr fold(const APInt &lhs, const APInt &rhs, + const bool isUnsigned) { + const unsigned originalWidth = lhs.getBitWidth(); + + APInt lhs64, rhs64; + if (isUnsigned) { + lhs64 = lhs.zext(64); + rhs64 = rhs.zext(64); + + // Check for overflow + const APInt max = APInt::getMaxValue(originalWidth).zext(64); + if (lhs64.ugt(max - rhs64)) + return failure(); + } else { + lhs64 = lhs.sext(64); + rhs64 = rhs.sext(64); + + // Check for overflow + const APInt zero = APInt::getZero(64); + const APInt max = APInt::getSignedMaxValue(originalWidth).sext(64); + const APInt min = APInt::getSignedMinValue(originalWidth).sext(64); + if ((rhs64.sgt(zero) && lhs64.sgt(max - rhs64)) || + (rhs64.slt(zero) && lhs64.slt(min - rhs64))) + return failure(); + } + + const APInt result64 = lhs64 + rhs64; + return result64.trunc(originalWidth); + } + + static FailureOr fold(const APFloat &lhs, const APFloat &rhs) { + return lhs + rhs; + } +}; + +struct SubFoldAdaptor { + static FailureOr fold(const APInt &lhs, const APInt &rhs, + const bool isUnsigned) { + const unsigned originalWidth = lhs.getBitWidth(); + + APInt lhs64, rhs64; + if (isUnsigned) { + lhs64 = lhs.zext(64); + rhs64 = rhs.zext(64); + + // Check for overflow + const APInt max = APInt::getMaxValue(originalWidth).zext(64); + if (lhs64.ult(rhs64)) + return failure(); + } else { + lhs64 = lhs.sext(64); + rhs64 = rhs.sext(64); + + // Check for overflow + const APInt zero = APInt::getZero(64); + const APInt max = APInt::getSignedMaxValue(originalWidth).sext(64); + const APInt min = APInt::getSignedMinValue(originalWidth).sext(64); + if ((rhs64.sgt(zero) && lhs64.slt(min + rhs64)) || + (rhs64.slt(zero) && lhs64.sgt(max + rhs64))) + return failure(); + } + + const APInt result64 = lhs64 - rhs64; + return result64.trunc(originalWidth); + } + + static FailureOr fold(const APFloat &lhs, const APFloat &rhs) { + return lhs - rhs; + } +}; + +struct FoldGreaterAdaptor { + static FailureOr fold(const APInt &lhs, const APInt &rhs, + const bool isUnsigned) { + return isUnsigned ? APInt(1, lhs.ugt(rhs)) : APInt(1, lhs.sgt(rhs)); + } + + static FailureOr fold(const APFloat &lhs, const APFloat &rhs) { + return APInt(1, lhs > rhs); + } +}; + +struct FoldGreaterEqualAdaptor { + static FailureOr fold(const APInt &lhs, const APInt &rhs, + const bool isUnsigned) { + return isUnsigned ? APInt(1, lhs.uge(rhs)) : APInt(1, lhs.sge(rhs)); + } + + static FailureOr fold(const APFloat &lhs, const APFloat &rhs) { + return APInt(1, lhs >= rhs); + } +}; + +struct FoldEqualAdaptor { + static FailureOr fold(const APInt &lhs, const APInt &rhs, + const bool isUnsigned) { + return APInt(1, lhs == rhs); + } + + static FailureOr fold(const APFloat &lhs, const APFloat &rhs) { + return APInt(1, lhs == rhs); + } +}; static bool isSplatZero(Type elemType, DenseElementsAttr val) { if (llvm::isa(elemType)) @@ -963,8 +1071,7 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) { if (!lhsAttr || !rhsAttr) return {}; - return binaryFolder, std::plus>(lhsAttr, rhsAttr, - resultTy); + return binaryFolder(lhsAttr, rhsAttr, resultTy); } OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) { @@ -1145,38 +1252,9 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) { if (!lhsAttr || !rhsAttr) return {}; - return binaryFolder, std::minus>(lhsAttr, rhsAttr, - resultTy); + return binaryFolder(lhsAttr, rhsAttr, resultTy); } -namespace { -template -struct ComparisonFold { - ComparisonFold() = default; - APInt operator()(const APInt &l, const APInt &r) { - return APInt(1, Cmp()(l, r)); - } - - APInt operator()(const APFloat &l, const APFloat &r) { - return APInt(1, Cmp()(l, r)); - } -}; - -struct APIntFoldGreater { - APIntFoldGreater() = default; - APInt operator()(const APInt &l, const APInt &r) { - return APInt(1, l.sgt(r)); - } -}; - -struct APIntFoldGreaterEqual { - APIntFoldGreaterEqual() = default; - APInt operator()(const APInt &l, const APInt &r) { - return APInt(1, l.sge(r)); - } -}; -} // namespace - OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) { auto resultTy = llvm::dyn_cast(getType()); auto lhsAttr = @@ -1187,8 +1265,7 @@ OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) { if (!lhsAttr || !rhsAttr) return {}; - return binaryFolder>>( - lhsAttr, rhsAttr, resultTy); + return binaryFolder(lhsAttr, rhsAttr, resultTy); } OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) { @@ -1201,9 +1278,7 @@ OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) { if (!lhsAttr || !rhsAttr) return {}; - return binaryFolder>>( - lhsAttr, rhsAttr, resultTy); + return binaryFolder(lhsAttr, rhsAttr, resultTy); } OpFoldResult EqualOp::fold(FoldAdaptor adaptor) { @@ -1226,9 +1301,7 @@ OpFoldResult EqualOp::fold(FoldAdaptor adaptor) { if (!lhsAttr || !rhsAttr) return {}; - return binaryFolder>, - ComparisonFold>>(lhsAttr, rhsAttr, - resultTy); + return binaryFolder(lhsAttr, rhsAttr, resultTy); } OpFoldResult CastOp::fold(FoldAdaptor adaptor) { diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir index bf6e1ad23bcb9..0922d6d2ee6fb 100644 --- a/mlir/test/Dialect/Tosa/constant_folding.mlir +++ b/mlir/test/Dialect/Tosa/constant_folding.mlir @@ -92,6 +92,50 @@ func.func @fold_add_splat_f32() -> tensor<10xf32> { // ----- +// CHECK-LABEL: @fold_add_splat_i32_positive_overflow +func.func @fold_add_splat_i32_positive_overflow() -> tensor<10xi32> { + %one = "tosa.const"() {values = dense<2147483647> : tensor<10xi32>} : () -> tensor<10xi32> + %two = "tosa.const"() {values = dense<1> : tensor<10xi32>} : () -> tensor<10xi32> + // CHECK: tosa.add + %add = tosa.add %one, %two : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32> + return %add : tensor<10xi32> +} + +// ----- + +// CHECK-LABEL: @fold_add_splat_i32_negative_overflow +func.func @fold_add_splat_i32_negative_overflow() -> tensor<10xi32> { + %one = "tosa.const"() {values = dense<-1> : tensor<10xi32>} : () -> tensor<10xi32> + %two = "tosa.const"() {values = dense<-2147483648> : tensor<10xi32>} : () -> tensor<10xi32> + // CHECK: tosa.add + %add = tosa.add %one, %two : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32> + return %add : tensor<10xi32> +} + +// ----- + +// CHECK-LABEL: @fold_add_splat_ui8 +func.func @fold_add_splat_ui8() -> tensor<10xui8> { + %one = "tosa.const"() {values = dense<1> : tensor<10xui8>} : () -> tensor<10xui8> + %two = "tosa.const"() {values = dense<254> : tensor<10xui8>} : () -> tensor<10xui8> + // CHECK: "tosa.const"() <{values = dense<255> : tensor<10xui8>}> : () -> tensor<10xui8> + %add = tosa.add %one, %two : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xui8> + return %add : tensor<10xui8> +} + +// ----- + +// CHECK-LABEL: @fold_add_splat_ui8_overflow +func.func @fold_add_splat_ui8_overflow() -> tensor<10xui8> { + %one = "tosa.const"() {values = dense<2> : tensor<10xui8>} : () -> tensor<10xui8> + %two = "tosa.const"() {values = dense<254> : tensor<10xui8>} : () -> tensor<10xui8> + // CHECK: tosa.add + %add = tosa.add %one, %two : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xui8> + return %add : tensor<10xui8> +} + +// ----- + // CHECK-LABEL: @fold_div_zero_lhs_i32 func.func @fold_div_zero_lhs_i32(%arg0: tensor) -> tensor { %zero = "tosa.const"() {values = dense<0> : tensor} : () -> tensor @@ -288,6 +332,50 @@ func.func @fold_sub_splat_f32() -> tensor<10xf32> { // ----- +// CHECK-LABEL: @fold_sub_splat_i32_positive_overflow +func.func @fold_sub_splat_i32_positive_overflow() -> tensor<10xi32> { + %one = "tosa.const"() {values = dense<2147483647> : tensor<10xi32>} : () -> tensor<10xi32> + %two = "tosa.const"() {values = dense<-1> : tensor<10xi32>} : () -> tensor<10xi32> + // CHECK: tosa.sub + %sub = tosa.sub %one, %two : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32> + return %sub : tensor<10xi32> +} + +// ----- + +// CHECK-LABEL: @fold_sub_splat_i32_negative_overflow +func.func @fold_sub_splat_i32_negative_overflow() -> tensor<10xi32> { + %one = "tosa.const"() {values = dense<-2147483648> : tensor<10xi32>} : () -> tensor<10xi32> + %two = "tosa.const"() {values = dense<1> : tensor<10xi32>} : () -> tensor<10xi32> + // CHECK: tosa.sub + %sub = tosa.sub %one, %two : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32> + return %sub : tensor<10xi32> +} + +// ----- + +// CHECK-LABEL: @fold_sub_splat_ui8 +func.func @fold_sub_splat_ui8() -> tensor<10xui8> { + %one = "tosa.const"() {values = dense<255> : tensor<10xui8>} : () -> tensor<10xui8> + %two = "tosa.const"() {values = dense<253> : tensor<10xui8>} : () -> tensor<10xui8> + // CHECK: "tosa.const"() <{values = dense<2> : tensor<10xui8>}> : () -> tensor<10xui8> + %sub = tosa.sub %one, %two : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xui8> + return %sub : tensor<10xui8> +} + +// ----- + +// CHECK-LABEL: @fold_sub_splat_ui8_overflow +func.func @fold_sub_splat_ui8_overflow() -> tensor<10xui8> { + %one = "tosa.const"() {values = dense<1> : tensor<10xui8>} : () -> tensor<10xui8> + %two = "tosa.const"() {values = dense<253> : tensor<10xui8>} : () -> tensor<10xui8> + // CHECK: tosa.sub + %sub = tosa.sub %one, %two : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xui8> + return %sub : tensor<10xui8> +} + +// ----- + // CHECK-LABEL: @fold_greater_splat_f32 func.func @fold_greater_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) { %0 = "tosa.const"() {values = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32> @@ -320,6 +408,23 @@ func.func @fold_greater_splat_i32() -> (tensor<10xi1>, tensor<10xi1>) { // ----- +// CHECK-LABEL: @fold_greater_splat_ui8 +func.func @fold_greater_splat_ui8() -> (tensor<10xi1>, tensor<10xi1>, tensor<10xi1>) { + %0 = "tosa.const"() {values = dense<1> : tensor<10xui8>} : () -> tensor<10xui8> + %1 = "tosa.const"() {values = dense<1> : tensor<10xui8>} : () -> tensor<10xui8> + %2 = "tosa.const"() {values = dense<246> : tensor<10xui8>} : () -> tensor<10xui8> + %3 = "tosa.const"() {values = dense<245> : tensor<10xui8>} : () -> tensor<10xui8> + %true = tosa.greater %2, %3 : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xi1> + %false = tosa.greater %0, %1 : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xi1> + %false2 = tosa.greater %0, %2 : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xi1> + // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} + // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} + // CHECK: return %[[TRUE]], %[[FALSE]], %[[FALSE]] + return %true, %false, %false2 : tensor<10xi1>, tensor<10xi1>, tensor<10xi1> +} + +// ----- + // CHECK-LABEL: @fold_greater_eq_splat_f32 func.func @fold_greater_eq_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) { %0 = "tosa.const"() {values = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32> @@ -352,6 +457,22 @@ func.func @fold_greater_eq_splat_i32() -> (tensor<10xi1>, tensor<10xi1>) { // ----- +// CHECK-LABEL: @fold_greater_eq_splat_ui8 +func.func @fold_greater_eq_splat_ui8() -> (tensor<10xi1>, tensor<10xi1>) { + %0 = "tosa.const"() {values = dense<1> : tensor<10xui8>} : () -> tensor<10xui8> + %1 = "tosa.const"() {values = dense<255> : tensor<10xui8>} : () -> tensor<10xui8> + %2 = "tosa.const"() {values = dense<245> : tensor<10xui8>} : () -> tensor<10xui8> + %3 = "tosa.const"() {values = dense<245> : tensor<10xui8>} : () -> tensor<10xui8> + %true = tosa.greater_equal %2, %3 : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xi1> + %false = tosa.greater_equal %0, %1 : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xi1> + // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} + // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense : tensor<10xi1>} + // CHECK: return %[[TRUE]], %[[FALSE]] + return %true, %false : tensor<10xi1>, tensor<10xi1> +} + +// ----- + // CHECK-LABEL: @fold_eq_splat_f32 func.func @fold_eq_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) { %0 = "tosa.const"() {values = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32> From ca346bf411c4e022a3160587775b46eef57103a3 Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Wed, 17 Dec 2025 16:22:28 +0000 Subject: [PATCH 3/3] [mlir][tosa] Add constant folding for tosa.add_shape operation This commit introduces constant folding for the tosa.add_shape operation. When both operands of the add_shape operation are constant shapes, the operation is evaluated at compile-time. Change-Id: I5567fae8290bf238f809088573d40666fe3bdf51 --- .../mlir/Dialect/Tosa/IR/TosaShapeOps.td | 2 + .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 59 ++++++++++++++----- mlir/test/Dialect/Tosa/constant_folding.mlir | 33 +++++++++++ 3 files changed, 80 insertions(+), 14 deletions(-) diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td index d5a46b1b34312..c446dd67f37c5 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td @@ -67,6 +67,8 @@ def Tosa_AddShapeOp : Tosa_ElementwiseShapeOp<"add_shape", [Pure]> { ); let results = (outs Tosa_Shape:$output); + + let hasFolder = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 3e9d803a916a9..2ea8481cb0a51 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -890,16 +890,28 @@ void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results, //===----------------------------------------------------------------------===// template -static DenseElementsAttr binaryFolder(DenseElementsAttr lhs, - DenseElementsAttr rhs, - RankedTensorType returnTy) { - if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) { - const auto lETy = llvm::cast(lhs.getType()).getElementType(); - const auto rETy = llvm::cast(rhs.getType()).getElementType(); - if (lETy != rETy) - return {}; +static DenseElementsAttr +binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs, ShapedType returnTy, + bool foldDenseValues = false) { + if (!lhs || !rhs) + return {}; - if (const auto lIntTy = dyn_cast(lETy)) { + const auto lETy = llvm::cast(lhs.getType()).getElementType(); + const auto rETy = llvm::cast(rhs.getType()).getElementType(); + if (lETy != rETy) + return {}; + + if (lhs.isSplat() && rhs.isSplat()) { + if (isa(lETy)) { + const APFloat l = lhs.getSplatValue(); + const APFloat r = rhs.getSplatValue(); + const auto maybeResult = Folder::fold(l, r); + if (failed(maybeResult)) + return {}; + return DenseElementsAttr::get(returnTy, maybeResult.value()); + } + + if (const auto lIntTy = llvm::dyn_cast(lETy)) { const APInt l = lhs.getSplatValue(); const APInt r = rhs.getSplatValue(); const auto maybeResult = Folder::fold(l, r, lIntTy.isUnsigned()); @@ -907,17 +919,21 @@ static DenseElementsAttr binaryFolder(DenseElementsAttr lhs, return {}; return DenseElementsAttr::get(returnTy, maybeResult.value()); } + } - if (llvm::isa(lETy)) { - const APFloat l = lhs.getSplatValue(); - const APFloat r = rhs.getSplatValue(); - const auto maybeResult = Folder::fold(l, r); + if (foldDenseValues) { + SmallVector resultValues; + for (auto [l, r] : + llvm::zip(lhs.getValues(), rhs.getValues())) { + const auto maybeResult = Folder::fold(l, r, false); if (failed(maybeResult)) return {}; - return DenseElementsAttr::get(returnTy, maybeResult.value()); + resultValues.push_back(maybeResult.value()); } + return DenseElementsAttr::get(returnTy, resultValues); } + // Folding arbitrarily sized tensor operations is not supported return {}; } struct AddFoldAdaptor { @@ -1723,3 +1739,18 @@ OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) { return {}; } + +OpFoldResult tosa::AddShapeOp::fold(FoldAdaptor adaptor) { + auto input1ConstShape = + dyn_cast(getInput1().getDefiningOp()); + auto input2ConstShape = + dyn_cast(getInput2().getDefiningOp()); + if (!input1ConstShape || !input2ConstShape) + return {}; + + const auto input1Attr = cast(input1ConstShape.getValues()); + const auto input2Attr = cast(input2ConstShape.getValues()); + + return binaryFolder( + input1Attr, input2Attr, input1Attr.getType(), /*foldDenseValues=*/true); +} diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir index 0922d6d2ee6fb..4860c98d960e2 100644 --- a/mlir/test/Dialect/Tosa/constant_folding.mlir +++ b/mlir/test/Dialect/Tosa/constant_folding.mlir @@ -650,3 +650,36 @@ func.func @no_shift_op_reorder (%arg0 : tensor<44x1xi16>, %arg1 : tensor<1xi8>) %1 = tosa.mul %arg0, %0, %arg1 : (tensor<44x1xi16>, tensor<44x57xi16>, tensor<1xi8>) -> tensor<44x57xi32> return %1 : tensor<44x57xi32> } + +// ----- + +// CHECK-LABEL: @test_fold_add_shape +// CHECK: tosa.const_shape {values = dense<[2, 4, 6, 8, 10, 12]> : tensor<6xindex>} : () -> !tosa.shape<6> +func.func @test_fold_add_shape() -> !tosa.shape<6> { + %a = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6]> : tensor<6xindex>} : () -> !tosa.shape<6> + %b = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6]> : tensor<6xindex>} : () -> !tosa.shape<6> + %c = tosa.add_shape %a, %b : (!tosa.shape<6>, !tosa.shape<6>) -> !tosa.shape<6> + return %c : !tosa.shape<6> +} + +// ----- + +// CHECK-LABEL: @test_no_fold_add_shape_positive_overflow +// CHECK: tosa.add_shape +func.func @test_no_fold_add_shape_positive_overflow() -> !tosa.shape<6> { + %a = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 9223372036854775807]> : tensor<6xindex>} : () -> !tosa.shape<6> + %b = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 1]> : tensor<6xindex>} : () -> !tosa.shape<6> + %c = tosa.add_shape %a, %b : (!tosa.shape<6>, !tosa.shape<6>) -> !tosa.shape<6> + return %c : !tosa.shape<6> +} + +// ----- + +// CHECK-LABEL: @test_no_fold_add_shape_negative_overflow +// CHECK: tosa.add_shape +func.func @test_no_fold_add_shape_negative_overflow() -> !tosa.shape<6> { + %a = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, -9223372036854775808]> : tensor<6xindex>} : () -> !tosa.shape<6> + %b = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, -1]> : tensor<6xindex>} : () -> !tosa.shape<6> + %c = tosa.add_shape %a, %b : (!tosa.shape<6>, !tosa.shape<6>) -> !tosa.shape<6> + return %c : !tosa.shape<6> +}