-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathpredict.py
More file actions
executable file
·90 lines (69 loc) · 2.1 KB
/
Copy pathpredict.py
File metadata and controls
executable file
·90 lines (69 loc) · 2.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#!/bin/python
import sys;
import os;
from common import *
from latent_factor import *;
import logging, Logger;
import arffio;
import pickle;
import numpy as np;
import copy;
import time;
def printUsages():
    """Print this script's command-line usage line to standard output."""
    usage = "Usage: predict.py test_file result_file model_file"
    print(usage)
def parseParameter(argv):
    """Extract the test, result and model file paths from a command line.

    The LAST three entries of ``argv`` are taken as ``test_file``,
    ``result_file`` and ``model_file`` (in that order); any entries between
    the script name and those three are ignored.

    Args:
        argv: the argument vector, e.g. ``sys.argv`` (index 0 is the script).

    Returns:
        dict with keys "test_file", "result_file" and "model_file".

    Exits the process with status 1, after printing the usage line, when
    fewer than three arguments follow the script name.
    """
    # argv[0] is the script itself, so three real arguments means len >= 4.
    if len(argv) < 4:
        printUsages()
        # sys.exit instead of the site-provided exit(): that helper exists
        # for the interactive shell and is absent when run with `python -S`.
        sys.exit(1)
    test_file, result_file, model_file = argv[-3:]
    return {
        "test_file": test_file,
        "result_file": result_file,
        "model_file": model_file,
    }
def predict(model, x):
    """Binarize the model's scores for ``x`` against its selected threshold.

    Args:
        model: object exposing ``ff(x)`` (returns a score array) and
            ``thrsel.threshold``. Presumably a loaded latent_factor Model;
            assumes ``ff`` returns a numpy array/matrix, so confirm this.
        x: input features, passed straight through to ``model.ff``.

    Returns:
        A new array with the same shape and dtype as ``model.ff(x)``,
        holding 1 where the score is strictly greater than the threshold
        and 0 everywhere else (including NaN scores).
    """
    scores = model.ff(x)
    # Compare once and cast, instead of masking in place: the two-step
    # `p[p > t] = 1; p[p != 1] = 0` would keep scores that are already
    # exactly 1 even when they do not exceed the threshold (t >= 1).
    return (scores > model.thrsel.threshold).astype(scores.dtype)
if __name__ == "__main__":
    # scipy.sparse is used below; import it explicitly rather than relying on
    # it arriving through one of the wildcard imports at the top of the file.
    import scipy.sparse as sp

    logger = logging.getLogger(Logger.project_name)

    parameters = parseParameter(sys.argv)
    test_file = parameters["test_file"]
    model_file = parameters["model_file"]
    result_file = parameters["result_file"]

    # The very large batch size presumably makes SvmReader return the whole
    # test set in one read; confirm against arffio.SvmReader.
    reader = arffio.SvmReader(test_file, batch=1000000000000)
    x, _ = reader.full_read()

    model = Model(dict())
    model.load(model_file)

    start = time.time()
    p = predict(model, x)
    end = time.time()
    logger.info("predict time is %f seconds", end - start)

    # Predictions are written as the sparse label matrix. Each output row
    # gets a single constant feature (column 0 set to 1) as its feature
    # vector, so the test features themselves are not written out.
    y = sp.csr_matrix(p)
    placeholder_x = sp.csr_matrix(np.ones((p.shape[0], 1)))

    writer = arffio.SvmWriter(result_file, model.num_feature, model.num_label)
    try:
        writer.write(placeholder_x, y)
    finally:
        # Release the output even if write() raises.
        writer.close()