diff --git a/src/search/classic/search.cc b/src/search/classic/search.cc index 7a4d1c7deb..01acbda88d 100644 --- a/src/search/classic/search.cc +++ b/src/search/classic/search.cc @@ -196,6 +196,7 @@ Search::Search(const NodeTree& tree, Backend* backend, : ContemptMode::WHITE; } } + search_root_node_ = SearchNode{root_node_, nullptr, {}}; } namespace { @@ -1065,10 +1066,9 @@ void Search::Wait() { void Search::CancelSharedCollisions() REQUIRES(nodes_mutex_) { for (auto& entry : shared_collisions_) { - Node* node = entry.first; - for (node = node->GetParent(); node != root_node_->GetParent(); - node = node->GetParent()) { - node->CancelScoreUpdate(entry.second); + for (SearchNode* sn = entry.first->parent; sn != nullptr; + sn = sn->parent) { + sn->node->CancelScoreUpdate(entry.second); } } shared_collisions_.clear(); @@ -1144,7 +1144,8 @@ void SearchWorker::RunTasks(int tid) { if (task != nullptr) { switch (task->task_type) { case PickTask::kGathering: { - PickNodesToExtendTask(task->start, task->base_depth, + task_workspaces_[tid].current_search_node = task->start_search_node; + PickNodesToExtendTask(task->start_search_node, task->base_depth, task->collision_limit, task->moves_to_base, &(task->results), &(task_workspaces_[tid])); break; @@ -1399,11 +1400,9 @@ void SearchWorker::GatherMinibatch() { // This may remove too many items, but hopefully most of the time they // will just be added back in the same in the next gather. if (minibatch_[i].IsCollision()) { - Node* node = minibatch_[i].node; - for (node = node->GetParent(); - node != search_->root_node_->GetParent(); - node = node->GetParent()) { - node->CancelScoreUpdate(minibatch_[i].multivisit); + for (SearchNode* sn = minibatch_[i].search_node->parent; + sn != nullptr; sn = sn->parent) { + sn->node->CancelScoreUpdate(minibatch_[i].multivisit); } minibatch_.erase(minibatch_.begin() + i); } else if (minibatch_[i].ooo_completed) { @@ -1428,11 +1427,9 @@ void SearchWorker::GatherMinibatch() { int extra = std::min(picked_node.maxvisit, collisions_left) - picked_node.multivisit; picked_node.multivisit += extra; - Node* node = picked_node.node; - for (node = node->GetParent(); - node != search_->root_node_->GetParent(); - node = node->GetParent()) { - node->IncrementNInFlight(extra); + for (SearchNode* sn = picked_node.search_node->parent; sn != nullptr; + sn = sn->parent) { + sn->node->IncrementNInFlight(extra); } } if ((collisions_left -= picked_node.multivisit) <= 0) return; @@ -1517,8 +1514,9 @@ void SearchWorker::PickNodesToExtend(int collision_limit) { // Since the tasks perform work which assumes they have the lock, even though // actually this thread does. SharedMutex::Lock lock(search_->nodes_mutex_); - PickNodesToExtendTask(search_->root_node_, 0, collision_limit, empty_movelist, - &minibatch_, &main_workspace_); + main_workspace_.current_search_node = &search_->search_root_node_; + PickNodesToExtendTask(&search_->search_root_node_, 0, collision_limit, + empty_movelist, &minibatch_, &main_workspace_); WaitForTasks(); for (int i = 0; i < static_cast(picking_tasks_.size()); i++) { @@ -1529,13 +1527,14 @@ void SearchWorker::PickNodesToExtend(int collision_limit) { } } -void SearchWorker::EnsureNodeTwoFoldCorrectForDepth(Node* child_node, +void SearchWorker::EnsureNodeTwoFoldCorrectForDepth(SearchNode* sn, int depth) { // Check whether first repetition was before root. If yes, remove // terminal status of node and revert all visits in the tree. // Length of repetition was stored in m_. This code will only do // something when tree is reused and twofold visits need to be // reverted. + Node* child_node = sn->node; if (child_node->IsTwoFoldTerminal() && depth < child_node->GetM()) { // Take a mutex - any SearchWorker specific mutex... since this is // not safe to do concurrently between multiple tasks. @@ -1548,11 +1547,10 @@ void SearchWorker::EnsureNodeTwoFoldCorrectForDepth(Node* child_node, const auto d = child_node->GetD(); const auto m = child_node->GetM(); const auto terminal_visits = child_node->GetN(); - for (Node* node_to_revert = child_node; node_to_revert != nullptr; - node_to_revert = node_to_revert->GetParent()) { + for (SearchNode* s = sn; s != nullptr; s = s->parent) { // Revert all visits on twofold draw when making it non terminal. - node_to_revert->RevertTerminalVisits(wl, d, m + (float)depth_counter, - terminal_visits); + s->node->RevertTerminalVisits(wl, d, m + (float)depth_counter, + terminal_visits); depth_counter++; // Even if original tree still exists, we don't want to revert // more than until new root. @@ -1571,7 +1569,7 @@ void SearchWorker::EnsureNodeTwoFoldCorrectForDepth(Node* child_node, } void SearchWorker::PickNodesToExtendTask( - Node* node, int base_depth, int collision_limit, + SearchNode* start_search_node, int base_depth, int collision_limit, const std::vector& moves_to_base, std::vector* receiver, TaskWorkspace* workspace) NO_THREAD_SAFETY_ANALYSIS { @@ -1589,6 +1587,10 @@ void SearchWorker::PickNodesToExtendTask( current_path.clear(); auto& moves_to_path = workspace->moves_to_path; moves_to_path = moves_to_base; + // Current position in the shadow search tree. + auto& current_sn = workspace->current_search_node; + current_sn = start_search_node; + Node* node = current_sn->node; // Sometimes receiver is reused, othertimes not, so only jump start if small. if (receiver->capacity() < 30) { receiver->reserve(receiver->size() + 30); @@ -1603,8 +1605,6 @@ void SearchWorker::PickNodesToExtendTask( std::array current_nstarted; auto& cur_iters = workspace->cur_iters; - Node::Iterator best_edge; - Node::Iterator second_best_edge; // Fetch the current best root node visits for possible smart pruning. const int64_t best_node_n = search_->current_best_edge_.GetN(); @@ -1640,7 +1640,8 @@ void SearchWorker::PickNodesToExtendTask( if (node->TryStartScoreUpdate()) { cur_limit -= 1; minibatch_.push_back(NodeToProcess::Visit( - node, static_cast(current_path.size() + base_depth))); + current_sn, + static_cast(current_path.size() + base_depth))); completed_visits++; } } @@ -1652,11 +1653,13 @@ void SearchWorker::PickNodesToExtendTask( max_count = max_limit; } receiver->push_back(NodeToProcess::Collision( - node, static_cast(current_path.size() + base_depth), + current_sn, + static_cast(current_path.size() + base_depth), cur_limit, max_count)); completed_visits += cur_limit; } - node = node->GetParent(); + current_sn = current_sn->parent; + node = current_sn ? current_sn->node : nullptr; current_path.pop_back(); continue; } @@ -1718,6 +1721,16 @@ void SearchWorker::PickNodesToExtendTask( const float cpuct = ComputeCpuct(params_, node->GetN(), is_root_node); const float puct_mult = cpuct * std::sqrt(std::max(node->GetChildrenVisits(), 1u)); + // Reserve children slots up-front so that iterators stored in cur_iters + // are not invalidated when GetOrSpawnAtIdx resizes children later. + current_sn->children.reserve(max_needed); + // edge_iter walks node->Edges() lazily in step with cache_filled_idx. + Node::Iterator edge_iter; + // Snapshot of edge_iter taken when best_idx was last updated inside the + // fill region (idx > cache_filled_idx). Valid only when the captured + // best_idx was filled in the current while-loop iteration; cleared + // otherwise so we know to reconstruct. + Node::Iterator best_spawn_iter; int cache_filled_idx = -1; while (cur_limit > 0) { // Perform UCT for current node. @@ -1725,21 +1738,24 @@ void SearchWorker::PickNodesToExtendTask( int best_idx = -1; float best_without_u = std::numeric_limits::lowest(); float second_best = std::numeric_limits::lowest(); + bool second_best_valid = false; bool can_exit = false; - best_edge.Reset(); for (int idx = 0; idx < max_needed; ++idx) { - if (idx > cache_filled_idx) { + const bool just_filled = (idx > cache_filled_idx); + if (just_filled) { if (idx == 0) { - cur_iters[idx] = node->Edges(); + edge_iter = node->Edges(); } else { - cur_iters[idx] = cur_iters[idx - 1]; - ++cur_iters[idx]; + ++edge_iter; } - current_nstarted[idx] = cur_iters[idx].GetNStarted(); + current_sn->GetOrSpawnAtIdx(idx, edge_iter.edge(), + edge_iter.node()); + cur_iters[idx] = current_sn->children.begin() + idx; + current_nstarted[idx] = edge_iter.GetNStarted(); } int nstarted = current_nstarted[idx]; const float util = current_util[idx]; - if (idx > cache_filled_idx) { + if (just_filled) { current_score[idx] = current_pol[idx] * puct_mult / (1 + nstarted) + util; cache_filled_idx++; @@ -1750,15 +1766,20 @@ void SearchWorker::PickNodesToExtendTask( // best_move_node_ could have changed since best_node_n was // retrieved. To ensure we have at least one node to expand, always // include current best node. - if (cur_iters[idx] != search_->current_best_edge_ && + if ((*cur_iters[idx])->edge != search_->current_best_edge_.edge() && latest_time_manager_hints_.GetEstimatedRemainingPlayouts() < - best_node_n - cur_iters[idx].GetN()) { + best_node_n - + static_cast( + (*cur_iters[idx])->node + ? (*cur_iters[idx])->node->GetN() + : 0)) { continue; } // If root move filter exists, make sure move is in the list. if (!root_move_filter.empty() && std::find(root_move_filter.begin(), root_move_filter.end(), - cur_iters[idx].GetMove()) == root_move_filter.end()) { + (*cur_iters[idx])->edge->GetMove()) == + root_move_filter.end()) { continue; } } @@ -1766,14 +1787,17 @@ void SearchWorker::PickNodesToExtendTask( float score = current_score[idx]; if (score > best) { second_best = best; - second_best_edge = best_edge; + second_best_valid = (best_idx != -1); best = score; best_idx = idx; best_without_u = util; - best_edge = cur_iters[idx]; + // If this idx was just filled, edge_iter is already pointing at it; + // save a copy for the spawning step below. Otherwise clear it so + // we know to reconstruct when the child hasn't been spawned yet. + best_spawn_iter = just_filled ? edge_iter : Node::Iterator{}; } else if (score > second_best) { second_best = score; - second_best_edge = cur_iters[idx]; + second_best_valid = true; } if (can_exit) break; if (nstarted == 0) { @@ -1784,7 +1808,7 @@ void SearchWorker::PickNodesToExtendTask( } } int new_visits = 0; - if (second_best_edge) { + if (second_best_valid) { int estimated_visits_to_change_best = std::numeric_limits::max(); if (best_without_u < second_best) { const auto n1 = current_nstarted[best_idx] + 1; @@ -1794,7 +1818,7 @@ void SearchWorker::PickNodesToExtendTask( n1 + 1, 1e9f))); } - second_best_edge.Reset(); + second_best_valid = false; max_limit = std::min(max_limit, estimated_visits_to_change_best); new_visits = std::min(cur_limit, estimated_visits_to_change_best); } else { @@ -1808,12 +1832,24 @@ void SearchWorker::PickNodesToExtendTask( } (*visits_to_perform.back())[best_idx] += new_visits; cur_limit -= new_visits; - Node* child_node = best_edge.GetOrSpawnNode(/* parent */ node); + Node* child_node = (*cur_iters[best_idx])->node; + if (child_node == nullptr) { + // Child not yet spawned in the NodeTree. Use the cached iterator + // snapshot if available; otherwise reconstruct from the edge list. + if (!best_spawn_iter) { + best_spawn_iter = node->Edges(); + for (int j = 0; j < best_idx; j++) ++best_spawn_iter; + } + child_node = best_spawn_iter.GetOrSpawnNode(/* parent */ node); + (*cur_iters[best_idx])->node = child_node; + } + // The shadow node for this child was created in GetOrSpawnAtIdx. + SearchNode* child_sn = (*cur_iters[best_idx]).get(); // Probably best place to check for two-fold draws consistently. // Depth starts with 1 at root, so real depth is depth - 1. EnsureNodeTwoFoldCorrectForDepth( - child_node, current_path.size() + base_depth + 1 - 1); + child_sn, current_path.size() + base_depth + 1 - 1); bool decremented = false; if (child_node->TryStartScoreUpdate()) { @@ -1834,12 +1870,13 @@ void SearchWorker::PickNodesToExtendTask( // doesn't include this visit. (*visits_to_perform.back())[best_idx] -= 1; receiver->push_back(NodeToProcess::Visit( - child_node, + child_sn, static_cast(current_path.size() + 1 + base_depth))); completed_visits++; receiver->back().moves_to_visit.reserve(moves_to_path.size() + 1); receiver->back().moves_to_visit = moves_to_path; - receiver->back().moves_to_visit.push_back(best_edge.GetMove()); + receiver->back().moves_to_visit.push_back( + (*cur_iters[best_idx])->edge->GetMove()); } if (best_idx > vtp_last_filled.back() && (*visits_to_perform.back())[best_idx] > 0) { @@ -1858,9 +1895,11 @@ void SearchWorker::PickNodesToExtendTask( child_limit + passed_off + completed_visits < collision_limit - params_.GetMinimumRemainingWorkSizeForPicking()) { - Node* child_node = cur_iters[i].GetOrSpawnNode(/* parent */ node); - // Don't split if not expanded or terminal. - if (child_node->GetN() == 0 || child_node->IsTerminal()) continue; + Node* child_node = (*cur_iters[i])->node; + // Don't split if not yet spawned, not expanded, or terminal. + if (!child_node || child_node->GetN() == 0 || + child_node->IsTerminal()) + continue; bool passed = false; { @@ -1868,9 +1907,10 @@ void SearchWorker::PickNodesToExtendTask( Mutex::Lock lock(picking_tasks_mutex_); // Ensure not to exceed size of reservation. if (picking_tasks_.size() < MAX_TASKS) { - moves_to_path.push_back(cur_iters[i].GetMove()); + moves_to_path.push_back((*cur_iters[i])->edge->GetMove()); picking_tasks_.emplace_back( - child_node, current_path.size() - 1 + base_depth + 1, + (*cur_iters[i]).get(), + current_path.size() - 1 + base_depth + 1, moves_to_path, child_limit); moves_to_path.pop_back(); task_count_.fetch_add(1, std::memory_order_acq_rel); @@ -1901,6 +1941,9 @@ void SearchWorker::PickNodesToExtendTask( current_path.back() = idx; current_path.push_back(-1); node = child.GetOrSpawnNode(/* parent */ node); + // current_sn->children was reserved to max_needed above, so + // resize inside GetOrSpawnAtIdx will not invalidate cur_iters. + current_sn = current_sn->GetOrSpawnAtIdx(idx, child.edge(), node); found_child = true; break; } @@ -1908,7 +1951,8 @@ void SearchWorker::PickNodesToExtendTask( } } if (!found_child) { - node = node->GetParent(); + current_sn = current_sn->parent; + node = current_sn ? current_sn->node : nullptr; if (!moves_to_path.empty()) moves_to_path.pop_back(); current_path.pop_back(); vtp_buffer.push_back(std::move(visits_to_perform.back())); @@ -2022,7 +2066,7 @@ void SearchWorker::CollectCollisions() { for (const NodeToProcess& node_to_process : minibatch_) { if (node_to_process.IsCollision()) { - search_->shared_collisions_.emplace_back(node_to_process.node, + search_->shared_collisions_.emplace_back(node_to_process.search_node, node_to_process.multivisit); } } @@ -2236,8 +2280,11 @@ void SearchWorker::DoBackupUpdateSingleNode( float m_delta = 0.0f; uint32_t solid_threshold = static_cast(params_.GetSolidTreeThreshold()); - for (Node *n = node, *p; n != search_->root_node_->GetParent(); n = p) { - p = n->GetParent(); + // Traverse from leaf to root via the shadow search tree. The root SearchNode + // has parent == nullptr, so the loop terminates naturally after root_node_. + for (SearchNode* cur_sn = node_to_process.search_node; cur_sn != nullptr; + cur_sn = cur_sn->parent) { + Node* n = cur_sn->node; // Current node might have become terminal from some other descendant, so // backup the rest of the way with more accurate values. @@ -2260,7 +2307,9 @@ void SearchWorker::DoBackupUpdateSingleNode( } // Nothing left to do without ancestors to update. - if (!p) break; + SearchNode* parent_sn = cur_sn->parent; + if (!parent_sn) break; + Node* p = parent_sn->node; bool old_update_parent_bounds = update_parent_bounds; // If parent already is terminal further adjustment is not required. diff --git a/src/search/classic/search.h b/src/search/classic/search.h index 34293f3173..823e4f8e95 100644 --- a/src/search/classic/search.h +++ b/src/search/classic/search.h @@ -47,6 +47,38 @@ namespace lczero { namespace classic { +// A shadow node that mirrors a Node in the NodeTree for the duration of a +// search. The tree of SearchNodes tracks all paths that have been explored +// during this search, providing parent/child links without requiring the +// NodeTree to expose them in a search-path-specific way. The tree is shared +// across all SearchWorkers and grows on demand; all accesses are serialized +// by Search::nodes_mutex_. +struct SearchNode { + Node* node = nullptr; + Edge* edge = nullptr; + SearchNode* parent = nullptr; + std::vector> children; + + // Returns the child SearchNode at @edge_idx, creating it (and any missing + // slots before it) if it does not yet exist. @edge is the corresponding Edge + // pointer; @child_node is the child Node pointer (may be nullptr if not yet + // spawned in the NodeTree). Callers must have reserved children to at least + // @edge_idx + 1 slots before the first call at a given depth level to avoid + // invalidating iterators. + SearchNode* GetOrSpawnAtIdx(int edge_idx, Edge* edge, Node* child_node) { + if (static_cast(children.size()) <= edge_idx) { + children.resize(edge_idx + 1); + } + if (children[edge_idx] == nullptr) { + children[edge_idx] = + std::make_unique(SearchNode{child_node, edge, this, {}}); + } else if (child_node != nullptr && children[edge_idx]->node == nullptr) { + children[edge_idx]->node = child_node; + } + return children[edge_idx].get(); + } +}; + class Search { public: Search(const NodeTree& tree, Backend* network, @@ -188,6 +220,13 @@ class Search { // Cumulative depth of all paths taken in PickNodetoExtend. uint64_t cum_depth_ GUARDED_BY(nodes_mutex_) = 0; + // Shadow search tree shared across all SearchWorkers. Mirrors the parts of + // the NodeTree that have been explored during this search. Children are added + // on demand (via SearchNode::GetOrSpawnAtIdx) as paths are explored; existing + // nodes are reused across iterations since parent/child relationships in the + // NodeTree are stable throughout a search. All accesses are under nodes_mutex_. + SearchNode search_root_node_ GUARDED_BY(nodes_mutex_); + std::optional nps_start_time_ GUARDED_BY(counters_mutex_); @@ -195,7 +234,7 @@ class Search { std::atomic backend_waiting_counter_{0}; std::atomic thread_count_{0}; - std::vector> shared_collisions_ + std::vector> shared_collisions_ GUARDED_BY(nodes_mutex_); std::unique_ptr uci_responder_; @@ -310,7 +349,9 @@ class SearchWorker { return is_cache_hit || node->IsTerminal(); } - // The node to extend. + // Shadow-tree node representing the path from the search root to this node. + SearchNode* search_node = nullptr; + // The node to extend (equal to search_node->node). Node* node; std::unique_ptr eval; int multivisit = 0; @@ -328,22 +369,23 @@ class SearchWorker { // Details that are filled in as we go. bool ooo_completed = false; - static NodeToProcess Collision(Node* node, uint16_t depth, + static NodeToProcess Collision(SearchNode* search_node, uint16_t depth, int collision_count) { - return NodeToProcess(node, depth, true, collision_count, 0); + return NodeToProcess(search_node, depth, true, collision_count, 0); } - static NodeToProcess Collision(Node* node, uint16_t depth, + static NodeToProcess Collision(SearchNode* search_node, uint16_t depth, int collision_count, int max_count) { - return NodeToProcess(node, depth, true, collision_count, max_count); + return NodeToProcess(search_node, depth, true, collision_count, max_count); } - static NodeToProcess Visit(Node* node, uint16_t depth) { - return NodeToProcess(node, depth, false, 1, 0); + static NodeToProcess Visit(SearchNode* search_node, uint16_t depth) { + return NodeToProcess(search_node, depth, false, 1, 0); } private: - NodeToProcess(Node* node, uint16_t depth, bool is_collision, int multivisit, - int max_count) - : node(node), + NodeToProcess(SearchNode* search_node, uint16_t depth, bool is_collision, + int multivisit, int max_count) + : search_node(search_node), + node(search_node->node), eval(std::make_unique()), multivisit(multivisit), maxvisit(max_count), @@ -353,12 +395,15 @@ class SearchWorker { // Holds per task worker scratch data struct TaskWorkspace { - std::array cur_iters; + std::array>::iterator, 256> + cur_iters; std::vector>> vtp_buffer; std::vector>> visits_to_perform; std::vector vtp_last_filled; std::vector current_path; std::vector moves_to_path; + // Current position in the shadow search tree during PickNodesToExtendTask. + SearchNode* current_search_node = nullptr; PositionHistory history; TaskWorkspace() { vtp_buffer.reserve(30); @@ -375,7 +420,7 @@ class SearchWorker { PickTaskType task_type; // For task type gathering. - Node* start; + SearchNode* start_search_node; int base_depth; int collision_limit; std::vector moves_to_base; @@ -387,10 +432,10 @@ class SearchWorker { bool complete = false; - PickTask(Node* node, uint16_t depth, const std::vector& base_moves, - int collision_limit) + PickTask(SearchNode* search_node, uint16_t depth, + const std::vector& base_moves, int collision_limit) : task_type(kGathering), - start(node), + start_search_node(search_node), base_depth(depth), collision_limit(collision_limit), moves_to_base(base_moves) {} @@ -405,12 +450,12 @@ class SearchWorker { bool MaybeSetBounds(Node* p, float m, int* n_to_fix, float* v_delta, float* d_delta, float* m_delta) const; void PickNodesToExtend(int collision_limit); - void PickNodesToExtendTask(Node* starting_point, int base_depth, + void PickNodesToExtendTask(SearchNode* start_search_node, int base_depth, int collision_limit, const std::vector& moves_to_base, std::vector* receiver, TaskWorkspace* workspace); - void EnsureNodeTwoFoldCorrectForDepth(Node* node, int depth); + void EnsureNodeTwoFoldCorrectForDepth(SearchNode* sn, int depth); void ProcessPickedTask(int batch_start, int batch_end, TaskWorkspace* workspace); void ExtendNode(Node* node, int depth, const std::vector& moves_to_add,