Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions badgyal/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
from badgyal.abstractnet import AbstractNet
from badgyal.bgnet import BGNet
from badgyal.mgnet import MGNet
from badgyal.lenet import LENet
from badgyal.abstractnet import AbstractNet, LoadedNet
from badgyal.named_nets import *
from badgyal.policy_index import policy_index
from badgyal.board2planes import board2planes, bulk_board2planes
44 changes: 43 additions & 1 deletion badgyal/abstractnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from badgyal.board2planes import board2planes, policy2moves, bulk_board2planes
import pylru
import sys

import os
from collections import defaultdict


CACHE=100000
Expand Down Expand Up @@ -133,3 +134,44 @@ def bulk_eval(self, boards, softmax_temp=1.61):
self.cache[b.epd()] = [policy, value]

return retval_p, retval_v

class LoadedNet(AbstractNet):
def __init__(self, path, channels=128, blocks=10, se=4, policy_channels=None, classical=True, cuda=True):
self.path = path
self.channels = channels
self.blocks = blocks
self.se = se
if policy_channels == None:
self.policy_channels = channels
else:
self.policy_channels = policy_channels
self.classical = classical
super().__init__(cuda=cuda)

def load_net(self):
cwd = os.path.abspath(os.path.dirname(__file__))
full_path = os.path.join(cwd, self.path)
net = model.Net(self.channels,
self.blocks,
self.policy_channels,
self.se,
classical=self.classical)
if self.classical:
net.import_proto_classical(full_path)
else:
net.import_proto(full_path)
return net


def value_to_scalar(self, value):
if not self.classical:
wdl0 = value[0].item()
wdl1 = value[1].item()
wdl2 = value[2].item()
min_val = min(wdl0, wdl1, wdl2)
w_val = math.exp(wdl0 - min_val)
d_val = math.exp(wdl1 - min_val)
l_val = math.exp(wdl2 - min_val)
p = (w_val * 1.0 + d_val * 0.5 ) / (w_val + d_val + l_val)
return 2.0*p-1.0;
return value.item()
26 changes: 0 additions & 26 deletions badgyal/bgnet.py

This file was deleted.

38 changes: 0 additions & 38 deletions badgyal/lenet.py

This file was deleted.

27 changes: 0 additions & 27 deletions badgyal/mgnet.py

This file was deleted.

8 changes: 8 additions & 0 deletions badgyal/named_nets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from badgyal.abstractnet import LoadedNet, MultiNet

BGNet = lambda cuda: LoadedNet("badgyal-8.pb.gz", 128, 10, 4, cuda=cuda)
MGNet = lambda cuda: LoadedNet("meangirl-8.pb.gz", 32, 4, 2, cuda=cuda)
LENet = lambda cuda: LoadedNet("LE.pb.gz", 128, 10, 4, classical=False, cuda=cuda)
T59 = lambda cuda: LoadedNet("../../nets/591226.pb.gz", 128, 10, 4, classical=False, cuda=cuda)
T70 = lambda cuda: LoadedNet("../../nets/701494.pb.gz", 128, 10, 4, classical=False, cuda=cuda)
M1 = lambda cuda: MultiNet([BGNet])
Loading