-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathW2V_MultinomialNB.py
More file actions
74 lines (57 loc) · 2.12 KB
/
W2V_MultinomialNB.py
File metadata and controls
74 lines (57 loc) · 2.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import numpy as np
from sklearn.naive_bayes import MultinomialNB
from sklearn import metrics
from sklearn.decomposition import NMF
import datetime
import matplotlib.pyplot as plt
if __name__ == "__main__":
startTime = datetime.datetime.now()
# Load training data
x = np.load('data/train_w2v_data_array.npy')
y = np.load('data/train_w2v_target_array.npy')
y = y.astype('int')
y = y.flatten()
# Load test data
z = np.load('data/test_w2v_data_array.npy')
t = np.load('data/test_w2v_target_array.npy')
t = t.astype('int')
t = t.flatten()
#Remove -ve values and scale all values by smallest -ve value in array
xmin = np.amin(x)
zmin = np.amin(z)
scale_min = min(xmin, zmin) * -1
x = np.add(x, scale_min)
z = np.add(z, scale_min)
# x = x + 11.573273289802543
# z = z + 16.698667840828804
# Predict using Naive Bayes Model
clf = MultinomialNB(alpha=1)
clf.fit(x, y)
p = clf.predict(z)
# Compute training time
endTime = datetime.datetime.now() - startTime
print("Total time taken to train: ", endTime)
print("\n")
print("W2V Multinomial Naive Bayes")
# Compute accuracy
accuracy = metrics.accuracy_score(t, p, normalize=False)
print("Accuracy: ", (accuracy / len(t)) * 100)
# Confusion matrix
confusion_matrix = metrics.confusion_matrix(t, p)
print("Confusion Matrix:\n", confusion_matrix)
# Replace 4s with 1s
t[np.where(t == 4)] = 1
p[np.where(p == 4)] = 1
y_scores = clf.predict_proba(z)
# Plot the Precision-Recall curve
precision, recall, _ = metrics.precision_recall_curve(t, y_scores[:, 1])
plt.step(recall, precision, color='b', alpha=0.2, where='post')
plt.fill_between(recall, precision, step='post', alpha=0.2, color='b')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.ylim([0.0, 1.05])
plt.xlim([0.0, 1.0])
average_precision = metrics.average_precision_score(t, p)
plt.title('W2V Multinomial NB Precision-Recall curve: AP={0:0.2f}'.format(average_precision))
plt.savefig('data/w2v_MultinomialNB_precisionRecall.png')
plt.show()