From 164bacf5ca76cd4176a8c6b4052b1c2613fee6c7 Mon Sep 17 00:00:00 2001 From: Chris Leary Date: Tue, 30 Dec 2025 23:42:05 -0800 Subject: [PATCH 1/2] [opt] Optimize one-bit logical shifts. --- xls/passes/arith_simplification_pass.cc | 16 ++++++++++++ xls/passes/arith_simplification_pass_test.cc | 26 ++++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/xls/passes/arith_simplification_pass.cc b/xls/passes/arith_simplification_pass.cc index 6d06e247db..6a0917ca76 100644 --- a/xls/passes/arith_simplification_pass.cc +++ b/xls/passes/arith_simplification_pass.cc @@ -1187,6 +1187,22 @@ absl::StatusOr MatchArithPatterns(int64_t opt_level, Node* n, return true; } + // Logical shift of a 1-bit value: + // shll(y:bits[1], amt) == y & (amt == 0) + // shrl(y:bits[1], amt) == y & (amt == 0) + // + // A 1-bit logical shift by a nonzero amount produces zero, and a shift by + // zero is a no-op. + if ((n->op() == Op::kShll || n->op() == Op::kShrl) && + n->operand(0)->BitCountOrDie() == 1) { + VLOG(2) << "FOUND: logical shift of 1-bit value"; + XLS_ASSIGN_OR_RETURN(Node * amt_is_zero, + CompareLiteral(n->operand(1), 0, Op::kEq)); + std::vector args = {n->operand(0), amt_is_zero}; + XLS_RETURN_IF_ERROR(n->ReplaceUsesWithNew(args, Op::kAnd).status()); + return true; + } + // Ext(Ext(x, w_0), w_1) => Ext(x, w_1) if (n->Is() && n->op() == n->operand(0)->op()) { VLOG(2) << "FOUND: replace extend(extend(x)) with extend(x)"; diff --git a/xls/passes/arith_simplification_pass_test.cc b/xls/passes/arith_simplification_pass_test.cc index e0e6adddfc..6f38a57e1a 100644 --- a/xls/passes/arith_simplification_pass_test.cc +++ b/xls/passes/arith_simplification_pass_test.cc @@ -2170,6 +2170,32 @@ TEST_F(ArithSimplificationPassTest, GuardedShiftOperationLowLimit) { EXPECT_THAT(f->return_value(), m::Shll(m::Param("x"), m::Select())); } +TEST_F(ArithSimplificationPassTest, LogicalShiftLeftOfOneBitUnknownAmount) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue amt = fb.Param("amt", p->GetBitsType(10)); + BValue y = fb.Param("y", p->GetBitsType(1)); + fb.Shll(y, amt); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build()); + ScopedVerifyEquivalence sve(f, kProverTimeout); + ASSERT_THAT(Run(p.get()), IsOkAndHolds(true)); + EXPECT_THAT(f->return_value(), + m::And(m::Param("y"), m::Eq(m::Param("amt"), m::Literal(0)))); +} + +TEST_F(ArithSimplificationPassTest, LogicalShiftRightOfOneBitUnknownAmount) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue amt = fb.Param("amt", p->GetBitsType(10)); + BValue y = fb.Param("y", p->GetBitsType(1)); + fb.Shrl(y, amt); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build()); + ScopedVerifyEquivalence sve(f, kProverTimeout); + ASSERT_THAT(Run(p.get()), IsOkAndHolds(true)); + EXPECT_THAT(f->return_value(), + m::And(m::Param("y"), m::Eq(m::Param("amt"), m::Literal(0)))); +} + TEST_F(ArithSimplificationPassTest, UMulCompare) { auto p = CreatePackage(); FunctionBuilder fb(TestName(), p.get()); From 3417825d530f1fbcc7953034c596a7dd21b33c10 Mon Sep 17 00:00:00 2001 From: Chris Leary Date: Fri, 2 Jan 2026 09:53:08 -0800 Subject: [PATCH 2/2] Move the logical shift opt down towards related opt. --- xls/passes/arith_simplification_pass.cc | 32 ++++++++++++------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/xls/passes/arith_simplification_pass.cc b/xls/passes/arith_simplification_pass.cc index 6a0917ca76..f3532d51f0 100644 --- a/xls/passes/arith_simplification_pass.cc +++ b/xls/passes/arith_simplification_pass.cc @@ -1187,22 +1187,6 @@ absl::StatusOr MatchArithPatterns(int64_t opt_level, Node* n, return true; } - // Logical shift of a 1-bit value: - // shll(y:bits[1], amt) == y & (amt == 0) - // shrl(y:bits[1], amt) == y & (amt == 0) - // - // A 1-bit logical shift by a nonzero amount produces zero, and a shift by - // zero is a no-op. - if ((n->op() == Op::kShll || n->op() == Op::kShrl) && - n->operand(0)->BitCountOrDie() == 1) { - VLOG(2) << "FOUND: logical shift of 1-bit value"; - XLS_ASSIGN_OR_RETURN(Node * amt_is_zero, - CompareLiteral(n->operand(1), 0, Op::kEq)); - std::vector args = {n->operand(0), amt_is_zero}; - XLS_RETURN_IF_ERROR(n->ReplaceUsesWithNew(args, Op::kAnd).status()); - return true; - } - // Ext(Ext(x, w_0), w_1) => Ext(x, w_1) if (n->Is() && n->op() == n->operand(0)->op()) { VLOG(2) << "FOUND: replace extend(extend(x)) with extend(x)"; @@ -1247,6 +1231,22 @@ absl::StatusOr MatchArithPatterns(int64_t opt_level, Node* n, return true; } + // Logical shift of a 1-bit value: + // shll(y:bits[1], amt) == y & (amt == 0) + // shrl(y:bits[1], amt) == y & (amt == 0) + // + // A 1-bit logical shift by a nonzero amount produces zero, and a shift by + // zero is a no-op. + if ((n->op() == Op::kShll || n->op() == Op::kShrl) && + n->operand(0)->BitCountOrDie() == 1) { + VLOG(2) << "FOUND: logical shift of 1-bit value"; + XLS_ASSIGN_OR_RETURN(Node * amt_is_zero, + CompareLiteral(n->operand(1), 0, Op::kEq)); + std::vector args = {n->operand(0), amt_is_zero}; + XLS_RETURN_IF_ERROR(n->ReplaceUsesWithNew(args, Op::kAnd).status()); + return true; + } + // An arithmetic shift-right of a 1-bit value is a no-op. if (n->op() == Op::kShra && n->operand(0)->BitCountOrDie() == 1) { VLOG(2) << "FOUND: arithmetic shift-right of 1-bit value";