Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ common_files += [
'src/neural/shared_params.cc',
'src/neural/wrapper.cc',
'src/search/classic/node.cc',
'src/search/classic/search_tree.cc',
'src/syzygy/syzygy.cc',
'src/trainingdata/reader.cc',
'src/trainingdata/trainingdata.cc',
Expand Down
40 changes: 7 additions & 33 deletions src/search/classic/node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ std::string Node::DebugString() const {
oss << " Term:" << static_cast<int>(terminal_type_) << " This:" << this
<< " Parent:" << parent_ << " Index:" << index_
<< " Child:" << child_.get() << " Sibling:" << sibling_.get()
<< " WL:" << wl_ << " N:" << n_ << " N_:" << n_in_flight_
<< " WL:" << wl_ << " N:" << n_
<< " Edges:" << static_cast<int>(num_edges_)
<< " Bounds:" << static_cast<int>(lower_bound_) - 2 << ","
<< static_cast<int>(upper_bound_) - 2 << " Solid:" << solid_children_;
Expand All @@ -244,28 +244,11 @@ std::string Node::DebugString() const {

bool Node::MakeSolid() {
if (solid_children_ || num_edges_ == 0 || IsTerminal()) return false;
// Can only make solid if no immediate leaf children are in flight since we
// allow the search code to hold references to leaf nodes across locks.
Node* old_child_to_check = child_.get();
uint32_t total_in_flight = 0;
while (old_child_to_check != nullptr) {
if (old_child_to_check->GetN() <= 1 &&
old_child_to_check->GetNInFlight() > 0) {
return false;
}
if (old_child_to_check->IsTerminal() &&
old_child_to_check->GetNInFlight() > 0) {
return false;
}
total_in_flight += old_child_to_check->GetNInFlight();
old_child_to_check = old_child_to_check->sibling_.get();
}
// If the total of children in flight is not the same as self, then there are
// collisions against immediate children (which don't update the GetNInFlight
// of the leaf) and its not safe.
if (total_in_flight != GetNInFlight()) {
return false;
}
// MakeSolid is currently disabled in the backup path (#if 0 block in
// DoBackupUpdateSingleNode). It should only be called when no active search
// is in progress (i.e., no thread holds nodes_mutex_ for selection/backup),
// ensuring no SearchTree n_in_flight counts are outstanding. When that
// invariant holds the previous n_in_flight safety checks are unnecessary.
std::allocator<Node> alloc;
auto* new_children = alloc.allocate(num_edges_);
for (int i = 0; i < num_edges_; i++) {
Expand Down Expand Up @@ -345,14 +328,6 @@ void Node::SetBounds(GameResult lower, GameResult upper) {
upper_bound_ = upper;
}

bool Node::TryStartScoreUpdate() {
if (n_ == 0 && n_in_flight_ > 0) return false;
++n_in_flight_;
return true;
}

void Node::CancelScoreUpdate(int multivisit) { n_in_flight_ -= multivisit; }

void Node::FinalizeScoreUpdate(float v, float d, float m, int multivisit) {
// Recompute Q.
wl_ += multivisit * (v - wl_) / (n_ + multivisit);
Expand All @@ -361,8 +336,7 @@ void Node::FinalizeScoreUpdate(float v, float d, float m, int multivisit) {

// Increment N.
n_ += multivisit;
// Decrement virtual loss.
n_in_flight_ -= multivisit;
// n_in_flight is now tracked by SearchTree; no decrement here.
}

void Node::AdjustForTerminal(float v, float d, float m, int multivisit) {
Expand Down
57 changes: 16 additions & 41 deletions src/search/classic/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,10 +161,7 @@ class Node {
// Returns sum of policy priors which have had at least one playout.
float GetVisitedPolicy() const;
uint32_t GetN() const { return n_; }
uint32_t GetNInFlight() const { return n_in_flight_; }
uint32_t GetChildrenVisits() const { return n_ > 0 ? n_ - 1 : 0; }
// Returns n = n_if_flight.
int GetNStarted() const { return n_ + n_in_flight_; }
float GetQ(float draw_score) const { return wl_ + draw_score * d_; }
// Returns node eval, i.e. average subtree V for non-terminal node and -1/0/1
// for terminal nodes.
Expand Down Expand Up @@ -196,27 +193,16 @@ class Node {
void MakeNotTerminal();
void SetBounds(GameResult lower, GameResult upper);

// If this node is not in the process of being expanded by another thread
// (which can happen only if n==0 and n-in-flight==1), mark the node as
// "being updated" by incrementing n-in-flight, and return true.
// Otherwise return false.
bool TryStartScoreUpdate();
// Decrements n-in-flight back.
void CancelScoreUpdate(int multivisit);
// Updates the node with newly computed value v.
// Updates:
// * Q (weighted average of all V in a subtree)
// * N (+=1)
// * N-in-flight (-=1)
// * N (+=multivisit)
// Note: n_in_flight accounting is handled by SearchTree, not here.
void FinalizeScoreUpdate(float v, float d, float m, int multivisit);
// Like FinalizeScoreUpdate, but it updates n existing visits by delta amount.
void AdjustForTerminal(float v, float d, float m, int multivisit);
// Revert visits to a node which ended in a now reverted terminal.
void RevertTerminalVisits(float v, float d, float m, int multivisit);
// When search decides to treat one visit as several (in case of collisions
// or visiting terminal nodes several times), it amplifies the visit by
// incrementing n_in_flight.
void IncrementNInFlight(int multivisit) { n_in_flight_ += multivisit; }

// Updates max depth, if new depth is larger.
void UpdateMaxDepth(int depth);
Expand Down Expand Up @@ -307,10 +293,6 @@ class Node {
float m_ = 0.0f;
// How many completed visits this node had.
uint32_t n_ = 0;
// (AKA virtual loss.) How many threads currently process this node (started
// but not finished). This value is added to n during selection which node
// to pick in MCTS, and also when selecting the best move.
uint32_t n_in_flight_ = 0;

// 2 byte fields.
// Index of this node is parent's edge list.
Expand Down Expand Up @@ -348,9 +330,9 @@ class Node {

// A basic sanity check. This must be adjusted when Node members are adjusted.
#if defined(__i386__) || (defined(__arm__) && !defined(__aarch64__))
static_assert(sizeof(Node) == 48, "Unexpected size of Node for 32bit compile");
static_assert(sizeof(Node) == 44, "Unexpected size of Node for 32bit compile");
#else
static_assert(sizeof(Node) == 64, "Unexpected size of Node");
static_assert(sizeof(Node) == 56, "Unexpected size of Node");
#endif

// Contains Edge and Node pair and set of proxy functions to simplify access
Expand Down Expand Up @@ -386,8 +368,6 @@ class EdgeAndNode {
}
// N-related getters, from Node (if exists).
uint32_t GetN() const { return node_ ? node_->GetN() : 0; }
int GetNStarted() const { return node_ ? node_->GetNStarted() : 0; }
uint32_t GetNInFlight() const { return node_ ? node_->GetNInFlight() : 0; }

// Whether the node is known to be terminal.
bool IsTerminal() const { return node_ ? node_->IsTerminal() : false; }
Expand All @@ -403,10 +383,11 @@ class EdgeAndNode {
return edge_ ? edge_->GetMove(flip) : Move();
}

// Returns U = numerator * p / N.
// Returns U = numerator * p / (1 + nstarted).
// @nstarted must be pre-computed as node->GetN() + search_tree.GetNInFlight(node).
// Passed numerator is expected to be equal to (cpuct * sqrt(N[parent])).
float GetU(float numerator) const {
return numerator * GetP() / (1 + GetNStarted());
float GetU(float numerator, int nstarted) const {
return numerator * GetP() / (1 + nstarted);
}

std::string DebugString() const;
Expand Down Expand Up @@ -583,27 +564,21 @@ class VisitedNode_Iterator {
if (solid_) {
while (++current_idx_ != total_count_ &&
node_ptr_[current_idx_].GetN() == 0) {
if (node_ptr_[current_idx_].GetNInFlight() == 0) {
// Once there is not even n in flight, we can skip to the end. This is
// due to policy being in sorted order meaning that additional n in
// flight are always selected from the front of the section with no n
// in flight or visited.
current_idx_ = total_count_;
break;
}
// Once there is an N=0 entry, all remaining entries are also unvisited
// (due to sorted policy). Jump to the end.
current_idx_ = total_count_;
break;
}
if (current_idx_ == total_count_) {
node_ptr_ = nullptr;
}
} else {
do {
node_ptr_ = node_ptr_->sibling_.get();
// If n started is 0, can jump direct to end due to sorted policy
// ensuring that each time a new edge becomes best for the first time,
// it is always the first of the section at the end that has NStarted of
// 0.
if (node_ptr_ != nullptr && node_ptr_->GetN() == 0 &&
node_ptr_->GetNInFlight() == 0) {
// If N is 0, jump to end: due to sorted policy, once the first
// unvisited (N=0) node is encountered, all subsequent nodes are also
// unvisited.
if (node_ptr_ != nullptr && node_ptr_->GetN() == 0) {
node_ptr_ = nullptr;
break;
}
Expand Down
Loading