From de4396869a7e198bbb4ec14d96bb3b64709b262d Mon Sep 17 00:00:00 2001 From: Marco Date: Tue, 7 Oct 2025 06:22:14 +0900 Subject: [PATCH 1/2] fix masked softmax --- maia2/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/maia2/inference.py b/maia2/inference.py index f4f2cfa..09e9710 100644 --- a/maia2/inference.py +++ b/maia2/inference.py @@ -60,7 +60,7 @@ def get_preds(model, dataloader, all_moves_dict_reversed): legal_moves = legal_moves.to(device) logits_maia, _, logits_value = model(boards, elos_self, elos_oppo) - logits_maia_legal = logits_maia * legal_moves + logits_maia_legal = logits_maia + legal_moves.log() probs = logits_maia_legal.softmax(dim=-1).cpu().tolist() logits_value = (logits_value / 2 + 0.5).clamp(0, 1).cpu().tolist() From c3b54d9e84248757534e8ddc2fd19b9dedbdd3b4 Mon Sep 17 00:00:00 2001 From: Marco Date: Sun, 16 Nov 2025 02:06:25 -0800 Subject: [PATCH 2/2] fix in inference_each also --- maia2/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/maia2/inference.py b/maia2/inference.py index 09e9710..8da5e91 100644 --- a/maia2/inference.py +++ b/maia2/inference.py @@ -154,7 +154,7 @@ def inference_each(model, prepared, fen, elo_self, elo_oppo): legal_moves = legal_moves.unsqueeze(dim=0).to(device) logits_maia, _, logits_value = model(board_input, elo_self, elo_oppo) - logits_maia_legal = logits_maia * legal_moves + logits_maia_legal = logits_maia + legal_moves.log() probs = logits_maia_legal.softmax(dim=-1).cpu().tolist() logits_value = (logits_value / 2 + 0.5).clamp(0, 1).item()