Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ def Tosa_AddShapeOp : Tosa_ElementwiseShapeOp<"add_shape", [Pure]> {
);

let results = (outs Tosa_Shape:$output);

let hasFolder = 1;
}

//===----------------------------------------------------------------------===//
Expand Down
222 changes: 163 additions & 59 deletions mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -889,33 +889,157 @@ void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
// Operator Folders.
//===----------------------------------------------------------------------===//

/// Folds an elementwise binary operation over two constant operands.
///
/// `Folder` must expose two static overloads, each returning a `FailureOr`
/// holding the folded element:
///   - `fold(const APInt &, const APInt &, bool isUnsigned)`
///   - `fold(const APFloat &, const APFloat &)`
///
/// Two folding strategies are attempted, in order:
///   1. Both operands are splats with a float or integer element type: the
///      two splat values are folded into a single splat result.
///   2. `foldDenseValues` is set: the operands are folded element by element.
///      This path only handles integer/index element types.
///
/// \param lhs             Constant left operand (may be null).
/// \param rhs             Constant right operand (may be null).
/// \param returnTy        Type of the attribute to build.
/// \param foldDenseValues Opt in to elementwise folding of non-splat data.
/// \returns The folded attribute, or a null attribute when either operand is
///          null, the element types differ, `Folder` reports failure, or
///          neither strategy applies.
template <typename Folder>
static DenseElementsAttr
binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs, ShapedType returnTy,
             bool foldDenseValues = false) {
  if (!lhs || !rhs)
    return {};

  const auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
  const auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
  if (lETy != rETy)
    return {};

  if (lhs.isSplat() && rhs.isSplat()) {
    if (isa<FloatType>(lETy)) {
      const APFloat l = lhs.getSplatValue<APFloat>();
      const APFloat r = rhs.getSplatValue<APFloat>();
      const auto maybeResult = Folder::fold(l, r);
      if (failed(maybeResult))
        return {};
      return DenseElementsAttr::get(returnTy, maybeResult.value());
    }

    if (const auto lIntTy = llvm::dyn_cast<IntegerType>(lETy)) {
      const APInt l = lhs.getSplatValue<APInt>();
      const APInt r = rhs.getSplatValue<APInt>();
      const auto maybeResult = Folder::fold(l, r, lIntTy.isUnsigned());
      if (failed(maybeResult))
        return {};
      return DenseElementsAttr::get(returnTy, maybeResult.value());
    }
  }

  if (foldDenseValues) {
    // The elementwise path reads raw element bits as APInt, which is only
    // meaningful for integer and index element types.
    if (!lETy.isIntOrIndex())
      return {};

    // zip() silently stops at the shorter range; refuse mismatched operands
    // rather than producing a truncated result.
    if (lhs.getNumElements() != rhs.getNumElements())
      return {};

    const bool isUnsigned = lETy.isUnsignedInteger();
    SmallVector<APInt> resultValues;
    resultValues.reserve(lhs.getNumElements());
    for (auto [l, r] :
         llvm::zip(lhs.getValues<APInt>(), rhs.getValues<APInt>())) {
      const auto maybeResult = Folder::fold(l, r, isUnsigned);
      if (failed(maybeResult))
        return {};
      resultValues.push_back(maybeResult.value());
    }
    return DenseElementsAttr::get(returnTy, resultValues);
  }

  // Folding arbitrarily sized tensor operations is not supported.
  return {};
}
/// Element folder for addition, usable as the `Folder` argument of
/// `binaryFolder`.
struct AddFoldAdaptor {
  /// Adds two integers of the same bit width.
  ///
  /// \param lhs        Left operand.
  /// \param rhs        Right operand; must have the same bit width as `lhs`.
  /// \param isUnsigned Selects unsigned (true) or signed (false) overflow
  ///                   semantics.
  /// \returns `lhs + rhs` at the operands' bit width, or failure when the sum
  ///          is not representable at that width.
  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
                               const bool isUnsigned) {
    // Use APInt's overflow-checked arithmetic rather than widening to a fixed
    // 64 bits, so operands of any bit width are handled.
    bool overflowed = false;
    APInt sum = isUnsigned ? lhs.uadd_ov(rhs, overflowed)
                           : lhs.sadd_ov(rhs, overflowed);
    if (overflowed)
      return failure();
    return sum;
  }

  /// Adds two floating-point values.
  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
    return lhs + rhs;
  }
};

/// Element folder for subtraction, usable as the `Folder` argument of
/// `binaryFolder`.
struct SubFoldAdaptor {
  /// Subtracts two integers of the same bit width.
  ///
  /// \param lhs        Left operand (minuend).
  /// \param rhs        Right operand (subtrahend); must have the same bit
  ///                   width as `lhs`.
  /// \param isUnsigned Selects unsigned (true) or signed (false) overflow
  ///                   semantics.
  /// \returns `lhs - rhs` at the operands' bit width, or failure when the
  ///          difference is not representable at that width (for unsigned
  ///          operands this means `lhs < rhs`).
  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
                               const bool isUnsigned) {
    // Use APInt's overflow-checked arithmetic rather than widening to a fixed
    // 64 bits, so operands of any bit width are handled.
    bool overflowed = false;
    APInt difference = isUnsigned ? lhs.usub_ov(rhs, overflowed)
                                  : lhs.ssub_ov(rhs, overflowed);
    if (overflowed)
      return failure();
    return difference;
  }

  /// Subtracts two floating-point values.
  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
    return lhs - rhs;
  }
};

/// Element folder for the "greater than" comparison. Both overloads produce
/// a 1-bit APInt holding the comparison outcome.
struct FoldGreaterAdaptor {
  /// Integer comparison; `isUnsigned` picks between the unsigned and signed
  /// ordering.
  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
                               const bool isUnsigned) {
    if (isUnsigned)
      return APInt(1, lhs.ugt(rhs));
    return APInt(1, lhs.sgt(rhs));
  }

  /// Floating-point comparison.
  static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) {
    const bool isGreater = lhs > rhs;
    return APInt(1, isGreater);
  }
};

/// Element folder for the "greater than or equal" comparison. Both overloads
/// produce a 1-bit APInt holding the comparison outcome.
struct FoldGreaterEqualAdaptor {
  /// Integer comparison; `isUnsigned` picks between the unsigned and signed
  /// ordering.
  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
                               const bool isUnsigned) {
    if (isUnsigned)
      return APInt(1, lhs.uge(rhs));
    return APInt(1, lhs.sge(rhs));
  }

  /// Floating-point comparison.
  static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) {
    const bool isGreaterOrEqual = lhs >= rhs;
    return APInt(1, isGreaterOrEqual);
  }
};

/// Element folder for the equality comparison. Both overloads produce a
/// 1-bit APInt holding the comparison outcome.
struct FoldEqualAdaptor {
  /// Integer comparison. `isUnsigned` is accepted to match the signature the
  /// other folders use, but is not consulted here.
  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
                               const bool isUnsigned) {
    const bool areEqual = (lhs == rhs);
    return APInt(1, areEqual);
  }

  /// Floating-point comparison.
  static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) {
    const bool areEqual = (lhs == rhs);
    return APInt(1, areEqual);
  }
};

static bool isSplatZero(Type elemType, DenseElementsAttr val) {
if (llvm::isa<FloatType>(elemType))
Expand Down Expand Up @@ -963,8 +1087,7 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
if (!lhsAttr || !rhsAttr)
return {};

return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
resultTy);
return binaryFolder<AddFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
}

OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
Expand Down Expand Up @@ -1145,38 +1268,9 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
if (!lhsAttr || !rhsAttr)
return {};

return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
resultTy);
return binaryFolder<SubFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
}

namespace {
/// Wraps a binary predicate `Cmp` so that its boolean outcome is returned as
/// a 1-bit APInt, for both APInt and APFloat operands.
template <typename Cmp>
struct ComparisonFold {
  ComparisonFold() = default;

  APInt operator()(const APInt &lhs, const APInt &rhs) {
    const bool outcome = Cmp{}(lhs, rhs);
    return APInt(1, outcome);
  }

  APInt operator()(const APFloat &lhs, const APFloat &rhs) {
    const bool outcome = Cmp{}(lhs, rhs);
    return APInt(1, outcome);
  }
};

/// Signed "greater than" on APInt operands, returned as a 1-bit APInt.
struct APIntFoldGreater {
  APIntFoldGreater() = default;

  APInt operator()(const APInt &lhs, const APInt &rhs) {
    const bool outcome = lhs.sgt(rhs);
    return APInt(1, outcome);
  }
};

/// Signed "greater than or equal" on APInt operands, returned as a 1-bit
/// APInt.
struct APIntFoldGreaterEqual {
  APIntFoldGreaterEqual() = default;

  APInt operator()(const APInt &lhs, const APInt &rhs) {
    const bool outcome = lhs.sge(rhs);
    return APInt(1, outcome);
  }
};
} // namespace

OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
auto lhsAttr =
Expand All @@ -1187,8 +1281,7 @@ OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
if (!lhsAttr || !rhsAttr)
return {};

return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
lhsAttr, rhsAttr, resultTy);
return binaryFolder<FoldGreaterAdaptor>(lhsAttr, rhsAttr, resultTy);
}

OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
Expand All @@ -1201,9 +1294,7 @@ OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
if (!lhsAttr || !rhsAttr)
return {};

return binaryFolder<APIntFoldGreaterEqual,
ComparisonFold<std::greater_equal<APFloat>>>(
lhsAttr, rhsAttr, resultTy);
return binaryFolder<FoldGreaterEqualAdaptor>(lhsAttr, rhsAttr, resultTy);
}

OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
Expand All @@ -1226,9 +1317,7 @@ OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
if (!lhsAttr || !rhsAttr)
return {};

return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
resultTy);
return binaryFolder<FoldEqualAdaptor>(lhsAttr, rhsAttr, resultTy);
}

OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
Expand Down Expand Up @@ -1650,3 +1739,18 @@ OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {

return {};
}

/// Folds `add_shape` when both inputs are produced by `tosa::ConstShapeOp`.
///
/// The two constant shapes are added element by element through
/// `binaryFolder<AddFoldAdaptor>`, so the fold is abandoned (null result) if
/// any element addition reports failure.
///
/// \returns The folded constant attribute, or a null result when either
///          input is not defined by a `tosa::ConstShapeOp` or folding fails.
OpFoldResult tosa::AddShapeOp::fold(FoldAdaptor adaptor) {
  // getDefiningOp<OpTy>() tolerates values without a defining operation
  // (e.g. block arguments), whereas dyn_cast on a null Operation* asserts.
  auto input1ConstShape = getInput1().getDefiningOp<tosa::ConstShapeOp>();
  auto input2ConstShape = getInput2().getDefiningOp<tosa::ConstShapeOp>();
  if (!input1ConstShape || !input2ConstShape)
    return {};

  const auto input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
  const auto input2Attr = cast<DenseElementsAttr>(input2ConstShape.getValues());

  return binaryFolder<AddFoldAdaptor>(
      input1Attr, input2Attr, input1Attr.getType(), /*foldDenseValues=*/true);
}
Loading