-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_ferplus.py
More file actions
executable file
·103 lines (80 loc) · 3.05 KB
/
train_ferplus.py
File metadata and controls
executable file
·103 lines (80 loc) · 3.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
#!/usr/bin/env python
import numpy as np
import tensorflow as tf
import pandas as pd
import model
# Training parameters.
max_steps = 5000     # total number of training steps
batch_size = 128     # examples per batch (training, validation and test)
eval_interval = 240  # used as EvalSpec throttle_secs / start_delay_secs (seconds)
# Read the FER+ CSV with pandas.
ferplus = pd.read_csv('data/ferplus.csv')
# Split each space-separated pixel string and cast the intensities to float
# (downcast='float' selects the smallest float dtype, i.e. float32).
ferplus.pixels = ferplus.pixels.str.split()
ferplus.pixels = ferplus.pixels.map(lambda p: pd.to_numeric(p, downcast='float'))
# Filter out the "no face" class: keep only rows whose NF count is zero.
ferplus = ferplus.query('NF == 0')
# Filter out faces where four or more workers were not sure about the emotion.
ferplus = ferplus.query('unknown < 4')
# Emotion columns in the usual FER class order.
classes = ['anger', 'disgust', 'fear', 'happiness', 'sadness', 'surprise', 'neutral']
# Argmax (hard) label of each face's class distribution. It is kept for
# reference only: the training targets below are the `classes` columns
# themselves. assign() appends the column without depending on the CSV's
# column count (the old insert(loc=12, ...) did) and returns a fresh frame
# instead of mutating the filtered result of query() in place.
labels = ferplus[classes]
maxlabels = labels.idxmax(axis=1).map(labels.columns.get_loc)
ferplus = ferplus.assign(maxlabel=maxlabels)
# Split into train, validation and test sets using the Usage column.
train = ferplus.loc[ferplus['Usage'] == 'Training']
valid = ferplus.loc[ferplus['Usage'] == 'PublicTest']
test = ferplus.loc[ferplus['Usage'] == 'PrivateTest']
# Prepare input images: one row of pixel intensities per face. The explicit
# float32 dtype matches the serving placeholder used at export time, and
# ragged rows raise instead of silently producing an object array.
x_test = np.array(test['pixels'].values.tolist(), dtype=np.float32)
x_train = np.array(train['pixels'].values.tolist(), dtype=np.float32)
x_valid = np.array(valid['pixels'].values.tolist(), dtype=np.float32)
# Prepare input label distributions: the raw per-class columns, one column per
# entry of `classes`. NOTE(review): these are not normalised to sum to 1 —
# confirm that model.model_fn expects raw per-class counts.
y_test = test[classes].values
y_train = train[classes].values
y_valid = valid[classes].values
# Build the Estimator around the project's model function.
estimator = tf.estimator.Estimator(
    model_fn=model.model_fn,
    model_dir='./model/ferplus',
    params={
        'learning_rate': 0.001,
        # Derived from the class list so the two cannot drift apart.
        'num_classes': len(classes),
        'img_size': 48,
        'dropout_rate': 0.3,
    })
# Training input: shuffled batches repeated indefinitely (num_epochs=None);
# TrainSpec.max_steps decides when training stops.
train_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={'images': x_train}, y=y_train,
    batch_size=batch_size, num_epochs=None, shuffle=True)
# Validation input: exactly one ordered pass over the validation set.
# num_epochs must be 1 here. With num_epochs=None the input never ends, and
# each evaluation would stop after EvalSpec's default 100 steps
# (100 * batch_size examples) of a cycled copy of the data, so some faces
# would be counted several times and others fewer.
valid_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={'images': x_valid}, y=y_valid,
    batch_size=batch_size, num_epochs=1, shuffle=False)
# Specify training operations.
train_spec = tf.estimator.TrainSpec(
    input_fn=train_input_fn,
    max_steps=max_steps,
)
# Specify evaluation operations. steps=None evaluates until valid_input_fn is
# exhausted, i.e. over the full validation set once.
eval_spec = tf.estimator.EvalSpec(
    input_fn=valid_input_fn,
    steps=None,
    throttle_secs=eval_interval,
    start_delay_secs=eval_interval,
)
# Train the model and evaluate periodically.
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
# Define the input function for evaluating with the test set: one ordered
# pass (numpy_input_fn defaults to num_epochs=1).
input_fn = tf.estimator.inputs.numpy_input_fn(
    x={'images': x_test}, y=y_test,
    batch_size=batch_size, shuffle=False)
# Evaluate the model on the test set and report the returned metrics dict
# rather than discarding it.
test_metrics = estimator.evaluate(input_fn)
print('Test set metrics:', test_metrics)
# Export the model as a SavedModel for production use. The serving input is a
# batch of flattened img_size * img_size float32 images. img_size is read from
# the estimator's own params so the signature cannot drift from training.
img_size = estimator.params['img_size']
feature_spec = {
    'images': tf.placeholder(dtype=tf.float32, shape=[None, img_size * img_size]),
}
serving_input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(feature_spec)
estimator.export_savedmodel(
    export_dir_base='saved_models/ferplus',
    serving_input_receiver_fn=serving_input_fn,
    as_text=True)