-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathdatabase.py
More file actions
385 lines (309 loc) · 13.6 KB
/
database.py
File metadata and controls
385 lines (309 loc) · 13.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
#! /usr/bin/python
# Author: Jasjeet Dhaliwal
import sys,os, pickle, random, math, gc, codecs, time
from nltk import word_tokenize
from nltk.stem.porter import PorterStemmer
from stop_words import get_stop_words
from gensim import corpora
class Database(object):
    """Database object to read and store a text corpus (in English only).

    Every ``.txt`` file becomes one document. A document is lower-cased,
    tokenized with ``word_tokenize``, stripped of English stop words and
    stemmed with the Porter stemmer. The tokenized documents are then mapped
    to bag-of-words vectors through a gensim ``corpora.Dictionary`` and split
    at random into a training set and a test set.

    Attributes set by the constructor:
        db_name, stop_words, stemmer, files_read, tokenized_texts,
        word_to_id, train_set, test_set, train_epoch_idx, batch_size,
        test_pct (stored as a fraction in [0, 1]).
    """

    def __init__(self, db_name='', data_dir=None, test_pct=10):
        """Initialize database with all .txt files in the data directory
        (Functionality is currently limited to .txt files)

        Args:
            db_name(str): name of database. When empty or not a ``str`` a
                          timestamp is used instead.
            data_dir (str): absolute or relative path to directory containing
                            text files. (i.e. each text file will become
                            a document in the database). Sub-directories
                            are walked as well.
            test_pct(int): Percent of total data to be used as the test set
                           Must be between 0 and 100; any other value falls
                           back to 10 percent.

        Raises:
            AssertionError: if data_dir is given but is not a directory.
        """
        if not db_name or not isinstance(db_name, str):
            self.db_name = time.strftime("%D:%H:%M:%S")
        else:
            self.db_name = db_name

        # Get stop words for the English language
        self.stop_words = get_stop_words('en')
        # Get stemmer
        self.stemmer = PorterStemmer()
        # Store file stats
        self.files_read = 0
        # List of pre-processed text files
        self.tokenized_texts = list()
        # Map of words and unique ids assigned to words (required by gensim)
        self.word_to_id = None
        # Training data
        self.train_set = None
        # Test data
        self.test_set = None
        # Utility list used to segment corpus into training and test set
        self.train_epoch_idx = None
        # Batch size of mini_batches to be used from the training set
        self.batch_size = 0

        if 0 <= test_pct <= 100:
            self.test_pct = test_pct * 0.01
        else:
            self.test_pct = 0.1

        if data_dir is not None:
            if not os.path.isdir(data_dir):
                raise AssertionError("Invalid data directory path")
            # Count the entries of the top-level directory only; the default
            # guards against os.walk yielding nothing (unreadable directory).
            _, _, top_files = next(os.walk(data_dir), (data_dir, [], []))
            print('Files to add to database: {}'.format(len(top_files)))
            self._read_dir(data_dir, profile=True)
            print('Files successfully added to database: {}'.format(self.files_read))

        if self.files_read > 0:
            self._rebuild_corpus()
            # Profile mem usage (shallow size of every (id, count) tuple)
            size = sum(sys.getsizeof(tup)
                       for text in self.train_set
                       for tup in text)
            print('Train_set: {}'.format(size))
        else:
            print("Initialized empty database.")

    def _read_file(self, file_path):
        """Tokenize, stop-word filter and stem one text file, then store it.

        Appends the stemmed token list to ``tokenized_texts`` and increments
        ``files_read``.
        """
        # The with-block closes the file; no explicit close() is needed.
        with codecs.open(file_path, 'r', 'utf-8-sig') as data_f:
            doc = data_f.read().replace('\n', ' ')
        # Set lookup instead of scanning the stop-word list for every token
        stop_words = frozenset(self.stop_words)
        stem = self.stemmer.stem
        stem_tokens = [stem(token)
                       for token in word_tokenize(doc.lower())
                       if token not in stop_words]
        self.tokenized_texts.append(stem_tokens)
        self.files_read += 1

    def _read_dir(self, data_dir, profile=False):
        """Read every .txt file below data_dir; return how many were added.

        When ``profile`` is true, the accumulated shallow size of the tokens
        is printed each time ``files_read`` reaches a multiple of 1000.
        """
        added = 0
        profiled_upto = 0
        token_bytes = 0
        for root, _dirs, files in os.walk(data_dir):
            for fname in files:
                if not fname.endswith('.txt'):
                    continue
                # Join with root (not data_dir) so that files found in
                # sub-directories resolve to their real location.
                self._read_file(os.path.join(root, fname))
                added += 1
                if profile and self.files_read % 1000 == 0:
                    window = self.tokenized_texts[profiled_upto:profiled_upto + 1000]
                    for text in window:
                        for tok in text:
                            token_bytes += sys.getsizeof(tok)
                    print('Tokenized texts: {} ; Files read = {} '.format(
                        token_bytes, self.files_read))
                    profiled_upto += 1000
        return added

    def _split_corpus(self, corpus):
        """Randomly split corpus into ``train_set`` and ``test_set``.

        ``floor(len(corpus) * test_pct)`` documents go to the test set. When
        that number is zero the whole corpus becomes the training set and
        ``test_set`` is reset to None.
        """
        n_test = int(math.floor(len(corpus) * self.test_pct))
        test_idx = set(random.sample(range(len(corpus)), n_test))
        if test_idx:
            # Build both halves from the untouched corpus. Deleting entries
            # while enumerating shifts the indices and removes the wrong docs.
            self.test_set = [doc for idx, doc in enumerate(corpus) if idx in test_idx]
            self.train_set = [doc for idx, doc in enumerate(corpus) if idx not in test_idx]
            print('Training set size: {}, Test set size: {}'.format(
                len(self.train_set), len(self.test_set)))
        else:
            self.test_set = None
            self.train_set = corpus
            print('Training set size: {}, Test set size: {}'.format(len(self.train_set), 0))

    def _rebuild_corpus(self):
        """Recreate the word-id map and redraw the train/test split."""
        # Assign an integer to each unique word in the texts
        self.word_to_id = corpora.Dictionary(self.tokenized_texts)
        # Convert tokenized text into bow with id's used by gensim for LDA or LSI
        corpus = [self.word_to_id.doc2bow(text) for text in self.tokenized_texts]
        self._split_corpus(corpus)

    def add_text_file(self, data_file):
        """Adds another .txt file to the database
        Note: This is HIGHLY inefficient in time and space, because the
        dictionary and the bag-of-words corpus are rebuilt from scratch and
        a new random train/test split is drawn. ``train_epoch_idx`` is not
        updated, so call prep_train_epoch() again afterwards.

        Args:
            data_file (str): absolute or relative path to the .txt file

        Raises:
            AssertionError: if data_file is not an existing .txt file.
        """
        if not os.path.isfile(data_file):
            raise AssertionError("Invalid file path")
        if not data_file.endswith('.txt'):
            raise AssertionError("Invalid file type")
        self._read_file(data_file)
        self._rebuild_corpus()

    def add_data_dir(self, data_dir):
        """Adds .txt files from data_dir to the database
        Note: This is inefficient in time and space, because the dictionary
        and the bag-of-words corpus are rebuilt from scratch and a new random
        train/test split is drawn. ``train_epoch_idx`` is not updated, so
        call prep_train_epoch() again afterwards.

        Args:
            data_dir (str): absolute or relative path to the dir
                            containing .txt files

        Raises:
            AssertionError: if data_dir is not a directory.
        """
        if not os.path.isdir(data_dir):
            raise AssertionError("Invalid data directory path")
        self._read_dir(data_dir)
        self._rebuild_corpus()

    def store_to_disk(self, file_path):
        """Store the database object to disk for future use

        Args:
            file_path(str): absolute or relative path of file to store the db in

        Raises:
            AssertionError: if the parent directory does not exist or is not
                            writable.
        """
        # A bare file name has an empty dirname; it lives in the current dir.
        parent = os.path.dirname(file_path) or os.curdir
        if not os.path.isdir(parent):
            raise AssertionError('Invalid directory provided to save file')
        if not os.access(parent, os.W_OK):
            raise AssertionError('Need write permissions to parent dir')
        # Pickle streams are binary data, so the file is opened in 'wb' mode.
        with open(file_path, 'wb') as f:
            pickle.dump([self.train_set,
                         self.test_set,
                         self.stop_words,
                         self.stemmer,
                         self.files_read,
                         self.tokenized_texts,
                         self.word_to_id,
                         self.train_epoch_idx,
                         self.batch_size,
                         self.db_name],
                        f)

    def load_from_disk(self, file_path):
        """Load the corpus from disk
        Note: ``test_pct`` is not part of the stored state and keeps the
        value this instance already had.

        Args:
            file_path(str): absolute or relative path of the file written by
                            store_to_disk

        Raises:
            AssertionError: if file_path is not an existing file.
        """
        if not os.path.isfile(file_path):
            raise AssertionError('Invalid file path to load db')
        print('Loading database from file {}'.format(file_path))
        # SECURITY: unpickling can execute arbitrary code. Only load files
        # that were produced by store_to_disk() from a trusted source.
        with open(file_path, 'rb') as f:
            attr = pickle.load(f)
        (self.train_set,
         self.test_set,
         self.stop_words,
         self.stemmer,
         self.files_read,
         self.tokenized_texts,
         self.word_to_id,
         self.train_epoch_idx,
         self.batch_size,
         self.db_name) = attr[:10]
        print('Successfully loaded database.')

    def prep_train_epoch(self, batch_size=1, num_epochs=1):
        """Prepare the training corpus for training by queueing, for every
        epoch, a fresh random permutation of the training set indices. This
        function needs to be called before get_mini_batch() as it does the
        book keeping required to send minibatches of data

        Args:
            batch_size(int): Size of every mini batch that will be returned
                             by get_mini_batch
            num_epochs(int): Number of epochs to train for

        Raises:
            AssertionError: if there is no training data, or batch_size or
                            num_epochs are out of range.
        """
        if self.train_set is None:
            raise AssertionError('There is no training data in the database')
        if not batch_size > 0:
            raise AssertionError('Batch size must be a positive int less than size of training set')
        if not batch_size <= len(self.train_set):
            raise AssertionError('Batch size must be a positive int less than size of training set')
        if not num_epochs > 0:
            raise AssertionError('Num epochs must be a positive int')

        self.batch_size = batch_size
        n_docs = len(self.train_set)
        self.train_epoch_idx = []
        for _ in range(num_epochs):
            self.train_epoch_idx.extend(random.sample(range(n_docs), n_docs))

    def get_mini_batch(self):
        """Get a mini batch of data
        Note that if less than batch_size of training samples are remaining,
        then we return the remaining samples. The documents of a batch are
        returned in training-set order, each at most once.

        Returns:
            list: the next mini batch, or [] once the queued indices are
                  exhausted.

        Raises:
            AssertionError: if prep_train_epoch() has not been called.
        """
        if self.train_epoch_idx is None:
            raise AssertionError('Need to call prep_train_epoch(batch_size) before calling get_mini_batch()')
        if not self.train_epoch_idx:
            print("Training data exhausted.")
            return []
        # Slicing past the end simply yields the remaining indices.
        mb_idx = self.train_epoch_idx[:self.batch_size]
        self.train_epoch_idx = self.train_epoch_idx[self.batch_size:]
        # Direct indexing instead of scanning the whole training set; sorted
        # unique indices reproduce the training-set ordering of the batch.
        return [self.train_set[idx] for idx in sorted(set(mb_idx))]

    def get_train_set(self):
        """Get the complete training set.

        Returns:
            list: the training set, or None (after printing a notice) when
                  the database holds no training data.
        """
        if not self.train_set:
            print("There is no train data in the database.")
            return
        return self.train_set

    def get_test_set(self, test_pct=100):
        """Get the test set. In case the test set is very large, get a fraction of the test set

        Args:
            test_pct(int): percent of test set to return

        Returns:
            list: the whole test set for test_pct == 100, otherwise a random
                  subset of floor(len(test_set) * test_pct / 100) documents
                  in test-set order; None (after printing a notice) when
                  there is no test data.

        Raises:
            AssertionError: if test_pct is outside [0, 100].
        """
        if not (0 <= test_pct <= 100):
            raise AssertionError('test_pct must be a positive int <= 100')
        if not self.test_set:
            print("There is no test data in the database.")
            return
        if test_pct == 100:
            return self.test_set
        # The subset size is derived from the test set itself, so it can
        # never exceed the population that is sampled from.
        n_docs = len(self.test_set)
        count = int(n_docs * test_pct // 100)
        mb_idx = set(random.sample(range(n_docs), count))
        return [text for idx, text in enumerate(self.test_set) if idx in mb_idx]

    def get_word2id(self):
        """Return the word-to-id dictionary.

        Raises:
            AssertionError: if the database holds no data yet.
        """
        if self.word_to_id is None:
            raise AssertionError('Database is empty. Please initialize database with data.')
        return self.word_to_id

    def corpus_size(self):
        """Return the number of text files read into the database."""
        return self.files_read

    def train_set_size(self):
        """Return the number of training documents (0 when there are none)."""
        return len(self.train_set) if self.train_set else 0

    def test_set_size(self):
        """Return the number of test documents (0 when there are none)."""
        return len(self.test_set) if self.test_set else 0

    def get_name(self):
        """Return the name of the database."""
        return self.db_name