Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions xls/ir/node_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -908,6 +908,25 @@ bool IsBinaryPrioritySelect(Node* node) {
return sel->cases().size() == 1;
}

absl::StatusOr<std::optional<BinarySelectArms>> MatchBinarySelectLike(
Node* node) {
if (IsBinarySelect(node)) {
Select* sel = node->As<Select>();
XLS_RET_CHECK_EQ(sel->selector()->BitCountOrDie(), 1);
return BinarySelectArms{.selector = sel->selector(),
.on_false = sel->get_case(0),
.on_true = sel->get_case(1)};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible for the select to have a single case and the default value to cover the case for selector == 0b1?

}
if (IsBinaryPrioritySelect(node)) {
PrioritySelect* sel = node->As<PrioritySelect>();
XLS_RET_CHECK_EQ(sel->selector()->BitCountOrDie(), 1);
return BinarySelectArms{.selector = sel->selector(),
.on_false = sel->default_value(),
.on_true = sel->get_case(0)};
}
return std::nullopt;
}

absl::StatusOr<absl::flat_hash_map<Channel*, std::vector<Node*>>> ChannelUsers(
Package* package) {
absl::flat_hash_map<Channel*, std::vector<Node*>> channel_users;
Expand Down
17 changes: 17 additions & 0 deletions xls/ir/node_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,23 @@ bool IsBinarySelect(Node* node);
// default value)
bool IsBinaryPrioritySelect(Node* node);

// A uniform view of a "binary select-like" node.
//
// For `sel(p, cases=[on_false, on_true])`, the selector is `p`.
// For a binary `priority_sel(p, cases=[on_true], default=on_false)`, the
// selector is `p` (which is required to be a single bit).
struct BinarySelectArms {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe rename to BinarySelect since this includes both the arms and the selector?

Node* selector;
Node* on_false;
Node* on_true;
};

// Returns a uniform view of a binary `sel` or a binary `priority_sel`.
//
// Returns std::nullopt if `node` is not one of those binary forms.
absl::StatusOr<std::optional<BinarySelectArms>> MatchBinarySelectLike(
Node* node);

// Returns the op which is the inverse of the given comparison.
//
// That is (not (op L R)) == ((InvertComparisonOp op) L R).
Expand Down
48 changes: 48 additions & 0 deletions xls/ir/node_util_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,54 @@ TEST_F(NodeUtilTest, GatherAllTheBits) {
EXPECT_THAT(f->return_value(), m::Param("x"));
}

TEST_F(NodeUtilTest, MatchBinarySelectLikeSelect) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());
BValue selector = fb.Param("p", p->GetBitsType(1));
BValue a = fb.Param("a", p->GetBitsType(32));
BValue b = fb.Param("b", p->GetBitsType(32));
BValue sel = fb.Select(selector, {a, b});
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unused variable sel; same for the below 2 unit tests

XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build());

XLS_ASSERT_OK_AND_ASSIGN(std::optional<BinarySelectArms> arms,
MatchBinarySelectLike(f->return_value()));
ASSERT_TRUE(arms.has_value());
EXPECT_EQ(arms->selector, f->param(0));
EXPECT_EQ(arms->on_false, f->param(1));
EXPECT_EQ(arms->on_true, f->param(2));
}

TEST_F(NodeUtilTest, MatchBinarySelectLikePrioritySelect) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());
BValue selector = fb.Param("p", p->GetBitsType(1));
BValue a = fb.Param("a", p->GetBitsType(32));
BValue b = fb.Param("b", p->GetBitsType(32));
BValue sel = fb.PrioritySelect(selector, {a}, b);
XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build());

XLS_ASSERT_OK_AND_ASSIGN(std::optional<BinarySelectArms> arms,
MatchBinarySelectLike(f->return_value()));
ASSERT_TRUE(arms.has_value());
EXPECT_EQ(arms->selector, f->param(0));
EXPECT_EQ(arms->on_false, f->param(2));
EXPECT_EQ(arms->on_true, f->param(1));
}

TEST_F(NodeUtilTest, MatchBinarySelectLikeNonMatch) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());
BValue selector = fb.Param("p", p->GetBitsType(2));
BValue a = fb.Param("a", p->GetBitsType(32));
BValue b = fb.Param("b", p->GetBitsType(32));
BValue sel = fb.Select(selector, {a, b}, /*default_value=*/a);
XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build());

XLS_ASSERT_OK_AND_ASSIGN(std::optional<BinarySelectArms> arms,
MatchBinarySelectLike(f->return_value()));
EXPECT_FALSE(arms.has_value());
}

TEST_F(NodeUtilTest, GatherTreeBits) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());
Expand Down
22 changes: 22 additions & 0 deletions xls/passes/basic_simplification_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,28 @@ absl::StatusOr<bool> MatchPatterns(Node* n) {
return true;
}

// X(sel(p, cases=[a, b]), sel(p, cases=[b, a])) => X(a, b)
//
// Because X is commutative (X(a, b) == X(b, a)), the select becomes
// redundant.
if (n->operand_count() == 2 && OpIsCommutative(n->op()) &&
!OpIsSideEffecting(n->op())) {
XLS_ASSIGN_OR_RETURN(std::optional<BinarySelectArms> sel0,
MatchBinarySelectLike(n->operand(0)));
XLS_ASSIGN_OR_RETURN(std::optional<BinarySelectArms> sel1,
MatchBinarySelectLike(n->operand(1)));
if (sel0.has_value() && sel1.has_value() &&
sel0->selector == sel1->selector && sel0->on_false == sel1->on_true &&
sel0->on_true == sel1->on_false) {
VLOG(2) << "FOUND: commutative op on swapped two-way selects: "
<< OpToString(n->op());
std::vector<Node*> new_operands = {sel0->on_false, sel0->on_true};
XLS_ASSIGN_OR_RETURN(Node * replacement, n->Clone(new_operands));
XLS_RETURN_IF_ERROR(n->ReplaceUsesWith(replacement));
return true;
}
}

// Remove duplicate operands of XORs.
// For any duplicate operand that appears an even number of times, remove all
// instances. For duplicates that appear an odd number of times, collapse to a
Expand Down
48 changes: 48 additions & 0 deletions xls/passes/basic_simplification_pass_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,54 @@ TEST_F(BasicSimplificationPassTest, XorDuplicateOperandsCollapseToXor) {
EXPECT_THAT(f->return_value(), m::Xor(m::Param("x"), m::Param("y")));
}

TEST_F(BasicSimplificationPassTest, AndOfSwappedTwoWaySelects) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());
BValue p_sel = fb.Param("p", p->GetBitsType(1));
BValue a = fb.Param("a", p->GetBitsType(32));
BValue b = fb.Param("b", p->GetBitsType(32));
BValue sel_ab = fb.Select(p_sel, {a, b});
BValue sel_ba = fb.Select(p_sel, {b, a});
fb.And(sel_ab, sel_ba);
XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build());

ScopedVerifyEquivalence sve(f);
ASSERT_THAT(Run(p.get()), IsOkAndHolds(true));
EXPECT_THAT(f->return_value(), m::And(m::Param("a"), m::Param("b")));
}

TEST_F(BasicSimplificationPassTest, XorOfSwappedTwoWaySelects) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());
BValue p_sel = fb.Param("p", p->GetBitsType(1));
BValue a = fb.Param("a", p->GetBitsType(32));
BValue b = fb.Param("b", p->GetBitsType(32));
BValue sel_ab = fb.Select(p_sel, {a, b});
BValue sel_ba = fb.Select(p_sel, {b, a});
fb.Xor({sel_ab, sel_ba});
XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build());

ScopedVerifyEquivalence sve(f);
ASSERT_THAT(Run(p.get()), IsOkAndHolds(true));
EXPECT_THAT(f->return_value(), m::Xor(m::Param("a"), m::Param("b")));
}

TEST_F(BasicSimplificationPassTest, EqOfSwappedTwoWaySelects) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());
BValue p_sel = fb.Param("p", p->GetBitsType(1));
BValue a = fb.Param("a", p->GetBitsType(32));
BValue b = fb.Param("b", p->GetBitsType(32));
BValue sel_ab = fb.Select(p_sel, {a, b});
BValue sel_ba = fb.Select(p_sel, {b, a});
fb.Eq(sel_ab, sel_ba);
XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build());

ScopedVerifyEquivalence sve(f);
ASSERT_THAT(Run(p.get()), IsOkAndHolds(true));
EXPECT_THAT(f->return_value(), m::Eq(m::Param("a"), m::Param("b")));
}

TEST_F(BasicSimplificationPassTest, AddWithZero) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a unit test on a function that is non-commutative?

auto p = CreatePackage();
XLS_ASSERT_OK_AND_ASSIGN(Function * f, ParseFunction(R"(
Expand Down