tf_losses.py
import tensorflow as tf


def policy_loss(target, output):
    # Illegal moves are marked by a value of -1 in the labels; mask the
    # corresponding logits with a large negative value so softmax assigns
    # them ~zero probability.
    output = tf.where(target < 0, -1e5, output)
    # The -1 labels would still break the cross-entropy, so clamp them to 0
    # once the masking is done.
    target = tf.nn.relu(target)
    # The stop_gradient is maybe paranoia, but it can't hurt.
    policy_cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
        labels=tf.stop_gradient(target), logits=output
    )
    return tf.reduce_mean(input_tensor=policy_cross_entropy)
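
# Illustrative only (not from the original file): a minimal sketch of the
# masking behaviour, with made-up shapes and values. For a position where
# the third move is illegal:
#
#   target = tf.constant([[0.7, 0.3, -1.0]])  # -1 marks the illegal move
#   logits = tf.constant([[1.0, 2.0, 3.0]])
#   # Inside policy_loss, the illegal logit becomes -1e5 (softmax prob ~0)
#   # and the -1 label is clamped to 0 before the cross-entropy is taken.
#   policy_loss(target, logits)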


def value_loss(target, output):
    # The value head outputs win/draw/loss (WDL) logits, so this is a
    # cross-entropy loss too.
    output = tf.cast(output, tf.float32)
    value_cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
        labels=tf.stop_gradient(target), logits=output
    )
    return tf.reduce_mean(input_tensor=value_cross_entropy)
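
# Illustrative only (not from the original file): the WDL target is a
# probability distribution over game outcomes. The [win, draw, loss]
# ordering below is an assumption for demonstration, e.g.
#
#   wdl_target = tf.constant([[1.0, 0.0, 0.0]])  # a decisive win
#   wdl_logits = tf.constant([[2.0, 0.5, -1.0]])
#   value_loss(wdl_target, wdl_logits)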


def moves_left_loss(target, output):
    # Scale the inputs so this loss lands in a similar range to the other
    # losses.
    scale = 20.0
    target = target / scale
    output = tf.cast(output, tf.float32) / scale
    # Huber loss: quadratic for small errors, linear beyond the delta
    # (10 before scaling), so outlier game lengths don't dominate.
    huber = tf.keras.losses.Huber(10.0 / scale)
    return tf.reduce_mean(huber(target, output))
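

if __name__ == "__main__":
    # Illustrative smoke test, not part of the original file. All shapes,
    # values, and the [win, draw, loss] ordering below are assumptions
    # chosen purely for demonstration.
    policy_target = tf.constant([[0.7, 0.3, -1.0]])  # -1 marks an illegal move
    policy_logits = tf.constant([[1.0, 2.0, 3.0]])
    wdl_target = tf.constant([[1.0, 0.0, 0.0]])
    wdl_logits = tf.constant([[2.0, 0.5, -1.0]])
    moves_target = tf.constant([[40.0]])  # assumed: moves remaining in the game
    moves_output = tf.constant([[35.0]])
    print("policy loss:", float(policy_loss(policy_target, policy_logits)))
    print("value loss:", float(value_loss(wdl_target, wdl_logits)))
    print("moves left loss:", float(moves_left_loss(moves_target, moves_output)))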