Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 106 additions & 57 deletions src/search/classic/search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ Search::Search(const NodeTree& tree, Backend* backend,
: ContemptMode::WHITE;
}
}
search_root_node_ = SearchNode{root_node_, nullptr, {}};
}

namespace {
Expand Down Expand Up @@ -1065,10 +1066,9 @@ void Search::Wait() {

void Search::CancelSharedCollisions() REQUIRES(nodes_mutex_) {
for (auto& entry : shared_collisions_) {
Node* node = entry.first;
for (node = node->GetParent(); node != root_node_->GetParent();
node = node->GetParent()) {
node->CancelScoreUpdate(entry.second);
for (SearchNode* sn = entry.first->parent; sn != nullptr;
sn = sn->parent) {
sn->node->CancelScoreUpdate(entry.second);
}
}
shared_collisions_.clear();
Expand Down Expand Up @@ -1144,7 +1144,8 @@ void SearchWorker::RunTasks(int tid) {
if (task != nullptr) {
switch (task->task_type) {
case PickTask::kGathering: {
PickNodesToExtendTask(task->start, task->base_depth,
task_workspaces_[tid].current_search_node = task->start_search_node;
PickNodesToExtendTask(task->start_search_node, task->base_depth,
task->collision_limit, task->moves_to_base,
&(task->results), &(task_workspaces_[tid]));
break;
Expand Down Expand Up @@ -1399,11 +1400,9 @@ void SearchWorker::GatherMinibatch() {
// This may remove too many items, but hopefully most of the time they
// will just be added back in the same in the next gather.
if (minibatch_[i].IsCollision()) {
Node* node = minibatch_[i].node;
for (node = node->GetParent();
node != search_->root_node_->GetParent();
node = node->GetParent()) {
node->CancelScoreUpdate(minibatch_[i].multivisit);
for (SearchNode* sn = minibatch_[i].search_node->parent;
sn != nullptr; sn = sn->parent) {
sn->node->CancelScoreUpdate(minibatch_[i].multivisit);
}
minibatch_.erase(minibatch_.begin() + i);
} else if (minibatch_[i].ooo_completed) {
Expand All @@ -1428,11 +1427,9 @@ void SearchWorker::GatherMinibatch() {
int extra = std::min(picked_node.maxvisit, collisions_left) -
picked_node.multivisit;
picked_node.multivisit += extra;
Node* node = picked_node.node;
for (node = node->GetParent();
node != search_->root_node_->GetParent();
node = node->GetParent()) {
node->IncrementNInFlight(extra);
for (SearchNode* sn = picked_node.search_node->parent; sn != nullptr;
sn = sn->parent) {
sn->node->IncrementNInFlight(extra);
}
}
if ((collisions_left -= picked_node.multivisit) <= 0) return;
Expand Down Expand Up @@ -1517,8 +1514,9 @@ void SearchWorker::PickNodesToExtend(int collision_limit) {
// Since the tasks perform work which assumes they have the lock, even though
// actually this thread does.
SharedMutex::Lock lock(search_->nodes_mutex_);
PickNodesToExtendTask(search_->root_node_, 0, collision_limit, empty_movelist,
&minibatch_, &main_workspace_);
main_workspace_.current_search_node = &search_->search_root_node_;
PickNodesToExtendTask(&search_->search_root_node_, 0, collision_limit,
empty_movelist, &minibatch_, &main_workspace_);

WaitForTasks();
for (int i = 0; i < static_cast<int>(picking_tasks_.size()); i++) {
Expand All @@ -1529,13 +1527,14 @@ void SearchWorker::PickNodesToExtend(int collision_limit) {
}
}

void SearchWorker::EnsureNodeTwoFoldCorrectForDepth(Node* child_node,
void SearchWorker::EnsureNodeTwoFoldCorrectForDepth(SearchNode* sn,
int depth) {
// Check whether first repetition was before root. If yes, remove
// terminal status of node and revert all visits in the tree.
// Length of repetition was stored in m_. This code will only do
// something when tree is reused and twofold visits need to be
// reverted.
Node* child_node = sn->node;
if (child_node->IsTwoFoldTerminal() && depth < child_node->GetM()) {
// Take a mutex - any SearchWorker specific mutex... since this is
// not safe to do concurrently between multiple tasks.
Expand All @@ -1548,11 +1547,10 @@ void SearchWorker::EnsureNodeTwoFoldCorrectForDepth(Node* child_node,
const auto d = child_node->GetD();
const auto m = child_node->GetM();
const auto terminal_visits = child_node->GetN();
for (Node* node_to_revert = child_node; node_to_revert != nullptr;
node_to_revert = node_to_revert->GetParent()) {
for (SearchNode* s = sn; s != nullptr; s = s->parent) {
// Revert all visits on twofold draw when making it non terminal.
node_to_revert->RevertTerminalVisits(wl, d, m + (float)depth_counter,
terminal_visits);
s->node->RevertTerminalVisits(wl, d, m + (float)depth_counter,
terminal_visits);
depth_counter++;
// Even if original tree still exists, we don't want to revert
// more than until new root.
Expand All @@ -1571,7 +1569,7 @@ void SearchWorker::EnsureNodeTwoFoldCorrectForDepth(Node* child_node,
}

void SearchWorker::PickNodesToExtendTask(
Node* node, int base_depth, int collision_limit,
SearchNode* start_search_node, int base_depth, int collision_limit,
const std::vector<Move>& moves_to_base,
std::vector<NodeToProcess>* receiver,
TaskWorkspace* workspace) NO_THREAD_SAFETY_ANALYSIS {
Expand All @@ -1589,6 +1587,10 @@ void SearchWorker::PickNodesToExtendTask(
current_path.clear();
auto& moves_to_path = workspace->moves_to_path;
moves_to_path = moves_to_base;
// Current position in the shadow search tree.
auto& current_sn = workspace->current_search_node;
current_sn = start_search_node;
Node* node = current_sn->node;
// Sometimes receiver is reused, othertimes not, so only jump start if small.
if (receiver->capacity() < 30) {
receiver->reserve(receiver->size() + 30);
Expand All @@ -1603,8 +1605,6 @@ void SearchWorker::PickNodesToExtendTask(
std::array<int, 256> current_nstarted;
auto& cur_iters = workspace->cur_iters;

Node::Iterator best_edge;
Node::Iterator second_best_edge;
// Fetch the current best root node visits for possible smart pruning.
const int64_t best_node_n = search_->current_best_edge_.GetN();

Expand Down Expand Up @@ -1640,7 +1640,8 @@ void SearchWorker::PickNodesToExtendTask(
if (node->TryStartScoreUpdate()) {
cur_limit -= 1;
minibatch_.push_back(NodeToProcess::Visit(
node, static_cast<uint16_t>(current_path.size() + base_depth)));
current_sn,
static_cast<uint16_t>(current_path.size() + base_depth)));
completed_visits++;
}
}
Expand All @@ -1652,11 +1653,13 @@ void SearchWorker::PickNodesToExtendTask(
max_count = max_limit;
}
receiver->push_back(NodeToProcess::Collision(
node, static_cast<uint16_t>(current_path.size() + base_depth),
current_sn,
static_cast<uint16_t>(current_path.size() + base_depth),
cur_limit, max_count));
completed_visits += cur_limit;
}
node = node->GetParent();
current_sn = current_sn->parent;
node = current_sn ? current_sn->node : nullptr;
current_path.pop_back();
continue;
}
Expand Down Expand Up @@ -1718,28 +1721,41 @@ void SearchWorker::PickNodesToExtendTask(
const float cpuct = ComputeCpuct(params_, node->GetN(), is_root_node);
const float puct_mult =
cpuct * std::sqrt(std::max(node->GetChildrenVisits(), 1u));
// Reserve children slots up-front so that iterators stored in cur_iters
// are not invalidated when GetOrSpawnAtIdx resizes children later.
current_sn->children.reserve(max_needed);
// edge_iter walks node->Edges() lazily in step with cache_filled_idx.
Node::Iterator edge_iter;
// Snapshot of edge_iter taken when best_idx was last updated inside the
// fill region (idx > cache_filled_idx). Valid only when the captured
// best_idx was filled in the current while-loop iteration; cleared
// otherwise so we know to reconstruct.
Node::Iterator best_spawn_iter;
int cache_filled_idx = -1;
while (cur_limit > 0) {
// Perform UCT for current node.
float best = std::numeric_limits<float>::lowest();
int best_idx = -1;
float best_without_u = std::numeric_limits<float>::lowest();
float second_best = std::numeric_limits<float>::lowest();
bool second_best_valid = false;
bool can_exit = false;
best_edge.Reset();
for (int idx = 0; idx < max_needed; ++idx) {
if (idx > cache_filled_idx) {
const bool just_filled = (idx > cache_filled_idx);
if (just_filled) {
if (idx == 0) {
cur_iters[idx] = node->Edges();
edge_iter = node->Edges();
} else {
cur_iters[idx] = cur_iters[idx - 1];
++cur_iters[idx];
++edge_iter;
}
current_nstarted[idx] = cur_iters[idx].GetNStarted();
current_sn->GetOrSpawnAtIdx(idx, edge_iter.edge(),
edge_iter.node());
cur_iters[idx] = current_sn->children.begin() + idx;
current_nstarted[idx] = edge_iter.GetNStarted();
}
int nstarted = current_nstarted[idx];
const float util = current_util[idx];
if (idx > cache_filled_idx) {
if (just_filled) {
current_score[idx] =
current_pol[idx] * puct_mult / (1 + nstarted) + util;
cache_filled_idx++;
Expand All @@ -1750,30 +1766,38 @@ void SearchWorker::PickNodesToExtendTask(
// best_move_node_ could have changed since best_node_n was
// retrieved. To ensure we have at least one node to expand, always
// include current best node.
if (cur_iters[idx] != search_->current_best_edge_ &&
if ((*cur_iters[idx])->edge != search_->current_best_edge_.edge() &&
latest_time_manager_hints_.GetEstimatedRemainingPlayouts() <
best_node_n - cur_iters[idx].GetN()) {
best_node_n -
static_cast<int64_t>(
(*cur_iters[idx])->node
? (*cur_iters[idx])->node->GetN()
: 0)) {
continue;
}
// If root move filter exists, make sure move is in the list.
if (!root_move_filter.empty() &&
std::find(root_move_filter.begin(), root_move_filter.end(),
cur_iters[idx].GetMove()) == root_move_filter.end()) {
(*cur_iters[idx])->edge->GetMove()) ==
root_move_filter.end()) {
continue;
}
}

float score = current_score[idx];
if (score > best) {
second_best = best;
second_best_edge = best_edge;
second_best_valid = (best_idx != -1);
best = score;
best_idx = idx;
best_without_u = util;
best_edge = cur_iters[idx];
// If this idx was just filled, edge_iter is already pointing at it;
// save a copy for the spawning step below. Otherwise clear it so
// we know to reconstruct when the child hasn't been spawned yet.
best_spawn_iter = just_filled ? edge_iter : Node::Iterator{};
} else if (score > second_best) {
second_best = score;
second_best_edge = cur_iters[idx];
second_best_valid = true;
}
if (can_exit) break;
if (nstarted == 0) {
Expand All @@ -1784,7 +1808,7 @@ void SearchWorker::PickNodesToExtendTask(
}
}
int new_visits = 0;
if (second_best_edge) {
if (second_best_valid) {
int estimated_visits_to_change_best = std::numeric_limits<int>::max();
if (best_without_u < second_best) {
const auto n1 = current_nstarted[best_idx] + 1;
Expand All @@ -1794,7 +1818,7 @@ void SearchWorker::PickNodesToExtendTask(
n1 + 1,
1e9f)));
}
second_best_edge.Reset();
second_best_valid = false;
max_limit = std::min(max_limit, estimated_visits_to_change_best);
new_visits = std::min(cur_limit, estimated_visits_to_change_best);
} else {
Expand All @@ -1808,12 +1832,24 @@ void SearchWorker::PickNodesToExtendTask(
}
(*visits_to_perform.back())[best_idx] += new_visits;
cur_limit -= new_visits;
Node* child_node = best_edge.GetOrSpawnNode(/* parent */ node);
Node* child_node = (*cur_iters[best_idx])->node;
if (child_node == nullptr) {
// Child not yet spawned in the NodeTree. Use the cached iterator
// snapshot if available; otherwise reconstruct from the edge list.
if (!best_spawn_iter) {
best_spawn_iter = node->Edges();
for (int j = 0; j < best_idx; j++) ++best_spawn_iter;
}
child_node = best_spawn_iter.GetOrSpawnNode(/* parent */ node);
(*cur_iters[best_idx])->node = child_node;
}
// The shadow node for this child was created in GetOrSpawnAtIdx.
SearchNode* child_sn = (*cur_iters[best_idx]).get();

// Probably best place to check for two-fold draws consistently.
// Depth starts with 1 at root, so real depth is depth - 1.
EnsureNodeTwoFoldCorrectForDepth(
child_node, current_path.size() + base_depth + 1 - 1);
child_sn, current_path.size() + base_depth + 1 - 1);

bool decremented = false;
if (child_node->TryStartScoreUpdate()) {
Expand All @@ -1834,12 +1870,13 @@ void SearchWorker::PickNodesToExtendTask(
// doesn't include this visit.
(*visits_to_perform.back())[best_idx] -= 1;
receiver->push_back(NodeToProcess::Visit(
child_node,
child_sn,
static_cast<uint16_t>(current_path.size() + 1 + base_depth)));
completed_visits++;
receiver->back().moves_to_visit.reserve(moves_to_path.size() + 1);
receiver->back().moves_to_visit = moves_to_path;
receiver->back().moves_to_visit.push_back(best_edge.GetMove());
receiver->back().moves_to_visit.push_back(
(*cur_iters[best_idx])->edge->GetMove());
}
if (best_idx > vtp_last_filled.back() &&
(*visits_to_perform.back())[best_idx] > 0) {
Expand All @@ -1858,19 +1895,22 @@ void SearchWorker::PickNodesToExtendTask(
child_limit + passed_off + completed_visits <
collision_limit -
params_.GetMinimumRemainingWorkSizeForPicking()) {
Node* child_node = cur_iters[i].GetOrSpawnNode(/* parent */ node);
// Don't split if not expanded or terminal.
if (child_node->GetN() == 0 || child_node->IsTerminal()) continue;
Node* child_node = (*cur_iters[i])->node;
// Don't split if not yet spawned, not expanded, or terminal.
if (!child_node || child_node->GetN() == 0 ||
child_node->IsTerminal())
continue;

bool passed = false;
{
// Multiple writers, so need mutex here.
Mutex::Lock lock(picking_tasks_mutex_);
// Ensure not to exceed size of reservation.
if (picking_tasks_.size() < MAX_TASKS) {
moves_to_path.push_back(cur_iters[i].GetMove());
moves_to_path.push_back((*cur_iters[i])->edge->GetMove());
picking_tasks_.emplace_back(
child_node, current_path.size() - 1 + base_depth + 1,
(*cur_iters[i]).get(),
current_path.size() - 1 + base_depth + 1,
moves_to_path, child_limit);
moves_to_path.pop_back();
task_count_.fetch_add(1, std::memory_order_acq_rel);
Expand Down Expand Up @@ -1901,14 +1941,18 @@ void SearchWorker::PickNodesToExtendTask(
current_path.back() = idx;
current_path.push_back(-1);
node = child.GetOrSpawnNode(/* parent */ node);
// current_sn->children was reserved to max_needed above, so
// resize inside GetOrSpawnAtIdx will not invalidate cur_iters.
current_sn = current_sn->GetOrSpawnAtIdx(idx, child.edge(), node);
found_child = true;
break;
}
if (idx >= vtp_last_filled.back()) break;
}
}
if (!found_child) {
node = node->GetParent();
current_sn = current_sn->parent;
node = current_sn ? current_sn->node : nullptr;
if (!moves_to_path.empty()) moves_to_path.pop_back();
current_path.pop_back();
vtp_buffer.push_back(std::move(visits_to_perform.back()));
Expand Down Expand Up @@ -2022,7 +2066,7 @@ void SearchWorker::CollectCollisions() {

for (const NodeToProcess& node_to_process : minibatch_) {
if (node_to_process.IsCollision()) {
search_->shared_collisions_.emplace_back(node_to_process.node,
search_->shared_collisions_.emplace_back(node_to_process.search_node,
node_to_process.multivisit);
}
}
Expand Down Expand Up @@ -2236,8 +2280,11 @@ void SearchWorker::DoBackupUpdateSingleNode(
float m_delta = 0.0f;
uint32_t solid_threshold =
static_cast<uint32_t>(params_.GetSolidTreeThreshold());
for (Node *n = node, *p; n != search_->root_node_->GetParent(); n = p) {
p = n->GetParent();
// Traverse from leaf to root via the shadow search tree. The root SearchNode
// has parent == nullptr, so the loop terminates naturally after root_node_.
for (SearchNode* cur_sn = node_to_process.search_node; cur_sn != nullptr;
cur_sn = cur_sn->parent) {
Node* n = cur_sn->node;

// Current node might have become terminal from some other descendant, so
// backup the rest of the way with more accurate values.
Expand All @@ -2260,7 +2307,9 @@ void SearchWorker::DoBackupUpdateSingleNode(
}

// Nothing left to do without ancestors to update.
if (!p) break;
SearchNode* parent_sn = cur_sn->parent;
if (!parent_sn) break;
Node* p = parent_sn->node;

bool old_update_parent_bounds = update_parent_bounds;
// If parent already is terminal further adjustment is not required.
Expand Down
Loading