-
Notifications
You must be signed in to change notification settings - Fork 78
/
Copy pathtrain_emotion_classifier.py
65 lines (48 loc) · 1.95 KB
/
train_emotion_classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import tensorflow as tf
import pandas as pd
import numpy as np
import subprocess
import os
import wandb
# Start a Weights & Biases run; hyperparameters are stored on its config
# and read back from there when training.
run = wandb.init()
config = run.config
config.batch_size, config.num_epochs = 32, 20

# Model input: one 48x48 image with a single channel per example.
input_shape = (48, 48, 1)
def load_fer2013():
    """Load the FER2013 facial-expression dataset and split it 80/20.

    When the ``fer2013`` directory does not exist in the working directory,
    the archive is downloaded and extracted first. The images and labels are
    then read from ``fer2013/fer2013.csv``.

    Returns:
        A tuple ``(train_faces, train_emotions, val_faces, val_emotions)``.
        Faces are float32 arrays of shape ``(n, 48, 48, 1)`` holding the raw
        pixel values (not rescaled). Emotions are one-hot uint8 arrays with
        one column per distinct value of the ``emotion`` column. The first
        80% of rows (in file order) form the training split; the remaining
        rows form the validation split.
    """
    if not os.path.exists("fer2013"):
        print("Downloading the face emotion dataset...")
        # Fixed command string (no external input), so shell=True is safe here.
        subprocess.check_output(
            "curl -SL https://www.dropbox.com/s/opuvvdv3uligypx/fer2013.tar | tar xz", shell=True)
    data = pd.read_csv("fer2013/fer2013.csv")
    width, height = 48, 48
    # Each row stores one image as a space-separated string of 8-bit pixels;
    # rows of the reshaped image correspond to the image height.
    faces = np.asarray(
        [np.asarray(pixel_sequence.split(' '), dtype=np.uint8).reshape(height, width)
         for pixel_sequence in data['pixels']],
        dtype=np.float32)
    faces = np.expand_dims(faces, -1)  # add channel axis -> (n, 48, 48, 1)
    # DataFrame.as_matrix() was removed in pandas 1.0. Request uint8 explicitly
    # because get_dummies yields bool columns by default in pandas >= 2.0.
    emotions = pd.get_dummies(data['emotion'], dtype=np.uint8).to_numpy()
    split = int(len(faces) * 0.8)
    train_faces, val_faces = faces[:split], faces[split:]
    train_emotions, val_emotions = emotions[:split], emotions[split:]
    return train_faces, train_emotions, val_faces, val_emotions
# Load the 80/20 split and rescale 8-bit pixel values into [0, 1].
train_faces, train_emotions, val_faces, val_emotions = load_fer2013()
num_samples, num_classes = train_emotions.shape
for split_faces in (train_faces, val_faces):
    split_faces /= 255.  # in-place, so train_faces/val_faces are updated

# Baseline model: flatten each image and apply a single softmax layer
# over the raw pixels.
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=input_shape),
    tf.keras.layers.Dense(num_classes, activation="softmax"),
])
model.compile(optimizer='adam', loss='categorical_crossentropy',
              metrics=['accuracy'])

# Class names handed to the W&B callback; presumably ordered by emotion id
# 0-6 as in the dataset's `emotion` column -- confirm against the data.
emotion_labels = ["Angry", "Disgust", "Fear", "Happy", "Sad", "Surprise", "Neutral"]
wandb_callback = wandb.keras.WandbCallback(data_type="image", labels=emotion_labels)

model.fit(train_faces, train_emotions,
          batch_size=config.batch_size,
          epochs=config.num_epochs,
          verbose=1,
          callbacks=[wandb_callback],
          validation_data=(val_faces, val_emotions))
model.save("emotion.h5")