forked from brilee/MuGo
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path load_data_sets.py
More file actions
127 lines (107 loc) · 5.26 KB
/
load_data_sets.py
File metadata and controls
127 lines (107 loc) · 5.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import itertools
import gzip
import numpy as np
import os
import struct
import sys
from features import bulk_extract_features
import go
from sgf_wrapper import replay_sgf
import utils
# Number of data points to store in a chunk on disk
CHUNK_SIZE = 4096
# On-disk chunk header, written by DataSet.write and parsed by DataSet.read:
# (data_size, board_size, input_planes, is_test) as three 4-byte ints and a
# 1-byte bool. The explicit "<" (little-endian, standard sizes, no alignment)
# makes the header layout independent of the host machine; on little-endian
# hosts it yields exactly the same 13 bytes as the native "iii?" format, so
# existing chunk files remain readable.
CHUNK_HEADER_FORMAT = "<iii?"
CHUNK_HEADER_SIZE = struct.calcsize(CHUNK_HEADER_FORMAT)
def make_onehot(coords):
    """Encode a sequence of board coordinates as one-hot move rows.

    Returns a uint8 array of shape ``(len(coords), go.N ** 2)`` whose row ``i``
    is all zeros except for a 1 at column ``utils.flatten_coords(coords[i])``.
    """
    onehot = np.zeros((len(coords), go.N ** 2), dtype=np.uint8)
    flat_indices = map(utils.flatten_coords, coords)
    for row, flat_index in enumerate(flat_indices):
        onehot[row, flat_index] = 1
    return onehot
def find_sgf_files(*dataset_dirs):
    """Yield paths of the ``.sgf`` files directly inside each given directory.

    Relative directories are resolved against the current working directory.
    The listing is not recursive, and entries that are not regular files
    (e.g. a directory named ``x.sgf``) are skipped.
    """
    for directory in dataset_dirs:
        base = os.path.join(os.getcwd(), directory)
        candidates = [os.path.join(base, entry) for entry in os.listdir(base)]
        yield from (
            path for path in candidates
            if os.path.isfile(path) and path.endswith(".sgf")
        )
def get_positions_from_sgf(file):
    """Yield the usable positions (with context) replayed from one SGF file.

    Args:
        file: path of the SGF file to read.

    Yields:
        Each item produced by ``replay_sgf`` on the file's contents for which
        ``is_usable()`` is true.
    """
    # Read the whole file and close it *before* yielding, so the handle is not
    # held open across yields (e.g. while a consumer pauses or abandons this
    # generator part-way through).
    with open(file) as f:
        sgf_contents = f.read()
    for position_w_context in replay_sgf(sgf_contents):
        if position_w_context.is_usable():
            yield position_w_context
def split_test_training(positions_w_context, est_num_positions):
    """Split a stream of positions into a test chunk and training chunks.

    Args:
        positions_w_context: iterable of positions with context.
        est_num_positions: rough size estimate used to pick a strategy and to
            report an estimated chunk count on stderr.

    Returns:
        ``(test_chunk, training_chunks)``. For small estimates (below twice the
        desired test size of 10**5) the data is materialized unshuffled: the
        first third becomes the test set and the rest is returned as a single
        training chunk inside a list. Otherwise the stream goes through
        ``utils.shuffler``; the test chunk comes from ``utils.take_n`` and the
        training chunks from ``utils.iter_chunks`` with ``CHUNK_SIZE`` (exact
        laziness/shuffle semantics depend on those helpers — not visible here).
    """
    print("Estimated number of chunks: %s" % (est_num_positions // CHUNK_SIZE), file=sys.stderr)
    desired_test_size = 10**5

    if est_num_positions >= 2 * desired_test_size:
        shuffled = utils.shuffler(positions_w_context)
        test_chunk = utils.take_n(desired_test_size, shuffled)
        return test_chunk, utils.iter_chunks(CHUNK_SIZE, shuffled)

    everything = list(positions_w_context)
    cutoff = len(everything) // 3
    return everything[:cutoff], [everything[cutoff:]]
class DataSet(object):
    """A chunk of positions (as feature planes) paired with their next moves.

    Attributes:
        pos_features: array indexed ``[example, row, col, plane]``; ``read``
            builds it with shape ``(data_size, board_size, board_size,
            input_planes)``. ``write`` bit-packs it, so it is assumed to hold
            only 0/1 values (any nonzero value is stored as 1).
        next_moves: array with one row per example; ``read`` builds it with
            shape ``(data_size, board_size * board_size)``.
        results: per-example results as given to the constructor. ``read``
            passes an empty list because the file format does not store them.
        is_test: whether this chunk is test data; stored in the file header.
    """

    def __init__(self, pos_features, next_moves, results, is_test=False):
        self.pos_features = pos_features
        self.next_moves = next_moves
        self.results = results
        self.is_test = is_test
        assert pos_features.shape[0] == next_moves.shape[0], "Didn't pass in same number of pos_features and next_moves."
        self.data_size = pos_features.shape[0]
        # Axis 1 is taken as the board edge length (the board is assumed square).
        self.board_size = pos_features.shape[1]
        self.input_planes = pos_features.shape[-1]
        # Index of the next unread example within the current epoch.
        self._index_within_epoch = 0

    def shuffle(self):
        """Randomly reorder the examples and restart the epoch.

        Features, moves and -- when there is one entry per example -- results
        are reordered with the same permutation so they stay aligned.
        """
        perm = np.arange(self.data_size)
        np.random.shuffle(perm)
        self.pos_features = self.pos_features[perm]
        self.next_moves = self.next_moves[perm]
        # results may be a tuple/list (from from_positions_w_context) or empty
        # (from read); only reorder it when it actually has one entry per example.
        if len(self.results) == self.data_size:
            self.results = [self.results[i] for i in perm]
        self._index_within_epoch = 0

    def get_batch(self, batch_size):
        """Return the next ``(pos_features, next_moves)`` minibatch.

        When fewer than ``batch_size`` unread examples remain in the current
        epoch, the data set is reshuffled and a new epoch starts; the leftover
        examples of the old epoch are skipped. The first epoch is served in
        stored order.

        Raises:
            AssertionError: (when assertions are enabled) if ``batch_size``
                exceeds the number of examples.
        """
        # A batch covering the whole data set is valid: it simply triggers a
        # reshuffle on every call after the first.
        assert batch_size <= self.data_size, "batch_size may not exceed data_size."
        if self._index_within_epoch + batch_size > self.data_size:
            self.shuffle()
        start = self._index_within_epoch
        end = start + batch_size
        self._index_within_epoch += batch_size
        return self.pos_features[start:end], self.next_moves[start:end]

    @staticmethod
    def from_positions_w_context(positions_w_context, is_test=False):
        """Build a DataSet from an iterable of ``(position, next_move, result)`` triples.

        Features come from ``bulk_extract_features`` and moves are one-hot
        encoded with ``make_onehot``. An empty iterable raises ``ValueError``
        from the tuple unpacking.
        """
        positions, next_moves, results = zip(*positions_w_context)
        extracted_features = bulk_extract_features(positions)
        encoded_moves = make_onehot(next_moves)
        return DataSet(extracted_features, encoded_moves, results, is_test=is_test)

    def write(self, filename):
        """Serialize this chunk to a gzip file.

        Layout: a ``CHUNK_HEADER_FORMAT`` header (data_size, board_size,
        input_planes, is_test), then the bit-packed features, then the
        bit-packed next moves. ``results`` is not written.
        """
        header_bytes = struct.pack(CHUNK_HEADER_FORMAT, self.data_size, self.board_size, self.input_planes, self.is_test)
        # tobytes() is the non-deprecated spelling of tostring(); same bytes.
        position_bytes = np.packbits(self.pos_features).tobytes()
        next_move_bytes = np.packbits(self.next_moves).tobytes()
        with gzip.open(filename, "wb", compresslevel=6) as f:
            f.write(header_bytes)
            f.write(position_bytes)
            f.write(next_move_bytes)

    @staticmethod
    def read(filename):
        """Load a chunk previously written by ``write``.

        The returned DataSet has an empty ``results`` list, since results are
        not stored on disk.

        Raises:
            AssertionError: (when assertions are enabled) if the file holds
                bytes beyond the expected payload.
        """
        with gzip.open(filename, "rb") as f:
            header_bytes = f.read(CHUNK_HEADER_SIZE)
            data_size, board_size, input_planes, is_test = struct.unpack(CHUNK_HEADER_FORMAT, header_bytes)
            position_dims = data_size * board_size * board_size * input_planes
            next_move_dims = data_size * board_size * board_size
            # the +7 // 8 compensates for numpy's bitpacking padding
            packed_position_bytes = f.read((position_dims + 7) // 8)
            packed_next_move_bytes = f.read((next_move_dims + 7) // 8)
            # should have cleanly finished reading all bytes from file!
            assert len(f.read()) == 0
        # np.frombuffer replaces the deprecated binary mode of np.fromstring.
        # The buffer view is read-only, but unpackbits returns a fresh array.
        # Slicing drops the zero bits packbits appended as padding.
        flat_position = np.unpackbits(np.frombuffer(packed_position_bytes, dtype=np.uint8))[:position_dims]
        flat_nextmoves = np.unpackbits(np.frombuffer(packed_next_move_bytes, dtype=np.uint8))[:next_move_dims]
        pos_features = flat_position.reshape(data_size, board_size, board_size, input_planes)
        next_moves = flat_nextmoves.reshape(data_size, board_size * board_size)
        return DataSet(pos_features, next_moves, [], is_test=is_test)
def parse_data_sets(*data_sets):
    """Find SGF games under ``data_sets`` and split their positions.

    Reports the number of SGF files on stderr and returns the
    ``(test_chunk, training_chunks)`` pair from ``split_test_training``.
    Positions are streamed lazily, file by file.
    """
    sgf_paths = list(find_sgf_files(*data_sets))
    print("%s sgfs found." % len(sgf_paths), file=sys.stderr)
    # Rough size estimate (~200 moves per game); split_test_training uses it
    # to pick its split strategy and to report an estimated chunk count.
    estimated_positions = 200 * len(sgf_paths)
    all_positions = itertools.chain.from_iterable(
        get_positions_from_sgf(path) for path in sgf_paths
    )
    return split_test_training(all_positions, estimated_positions)