diff --git a/xls/passes/arith_simplification_pass.cc b/xls/passes/arith_simplification_pass.cc index 6d06e247db..f3532d51f0 100644 --- a/xls/passes/arith_simplification_pass.cc +++ b/xls/passes/arith_simplification_pass.cc @@ -1231,6 +1231,22 @@ absl::StatusOr MatchArithPatterns(int64_t opt_level, Node* n, return true; } + // Logical shift of a 1-bit value: + // shll(y:bits[1], amt) == y & (amt == 0) + // shrl(y:bits[1], amt) == y & (amt == 0) + // + // A 1-bit logical shift by a nonzero amount produces zero, and a shift by + // zero is a no-op. + if ((n->op() == Op::kShll || n->op() == Op::kShrl) && + n->operand(0)->BitCountOrDie() == 1) { + VLOG(2) << "FOUND: logical shift of 1-bit value"; + XLS_ASSIGN_OR_RETURN(Node * amt_is_zero, + CompareLiteral(n->operand(1), 0, Op::kEq)); + std::vector args = {n->operand(0), amt_is_zero}; + XLS_RETURN_IF_ERROR(n->ReplaceUsesWithNew(args, Op::kAnd).status()); + return true; + } + // An arithmetic shift-right of a 1-bit value is a no-op. if (n->op() == Op::kShra && n->operand(0)->BitCountOrDie() == 1) { VLOG(2) << "FOUND: arithmetic shift-right of 1-bit value"; diff --git a/xls/passes/arith_simplification_pass_test.cc b/xls/passes/arith_simplification_pass_test.cc index e0e6adddfc..6f38a57e1a 100644 --- a/xls/passes/arith_simplification_pass_test.cc +++ b/xls/passes/arith_simplification_pass_test.cc @@ -2170,6 +2170,32 @@ TEST_F(ArithSimplificationPassTest, GuardedShiftOperationLowLimit) { EXPECT_THAT(f->return_value(), m::Shll(m::Param("x"), m::Select())); } +TEST_F(ArithSimplificationPassTest, LogicalShiftLeftOfOneBitUnknownAmount) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue amt = fb.Param("amt", p->GetBitsType(10)); + BValue y = fb.Param("y", p->GetBitsType(1)); + fb.Shll(y, amt); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build()); + ScopedVerifyEquivalence sve(f, kProverTimeout); + ASSERT_THAT(Run(p.get()), IsOkAndHolds(true)); + EXPECT_THAT(f->return_value(), + m::And(m::Param("y"), m::Eq(m::Param("amt"), m::Literal(0)))); +} + +TEST_F(ArithSimplificationPassTest, LogicalShiftRightOfOneBitUnknownAmount) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue amt = fb.Param("amt", p->GetBitsType(10)); + BValue y = fb.Param("y", p->GetBitsType(1)); + fb.Shrl(y, amt); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build()); + ScopedVerifyEquivalence sve(f, kProverTimeout); + ASSERT_THAT(Run(p.get()), IsOkAndHolds(true)); + EXPECT_THAT(f->return_value(), + m::And(m::Param("y"), m::Eq(m::Param("amt"), m::Literal(0)))); +} + TEST_F(ArithSimplificationPassTest, UMulCompare) { auto p = CreatePackage(); FunctionBuilder fb(TestName(), p.get());