From 2bc8f800dbe77a0c9b5388c37552b4b29c260080 Mon Sep 17 00:00:00 2001 From: Valentin Trifonov Date: Sun, 26 Feb 2017 11:42:27 +0100 Subject: [PATCH] compatibility with tensorflow 1.0+. Fixes #19. --- train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/train.py b/train.py index 162a8bd..cd3f2a6 100755 --- a/train.py +++ b/train.py @@ -126,9 +126,9 @@ def get_loss(y, y_): # Calculate the loss from digits being incorrect. Don't count loss from # digits that are in non-present plates. digits_loss = tf.nn.softmax_cross_entropy_with_logits( - tf.reshape(y[:, 1:], + logits=tf.reshape(y[:, 1:], [-1, len(common.CHARS)]), - tf.reshape(y_[:, 1:], + labels=tf.reshape(y_[:, 1:], [-1, len(common.CHARS)])) digits_loss = tf.reshape(digits_loss, [-1, 7]) digits_loss = tf.reduce_sum(digits_loss, 1) @@ -137,7 +137,7 @@ def get_loss(y, y_): # Calculate the loss from presence indicator being wrong. presence_loss = tf.nn.sigmoid_cross_entropy_with_logits( - y[:, :1], y_[:, :1]) + logits=y[:, :1], labels=y_[:, :1]) presence_loss = 7 * tf.reduce_sum(presence_loss) return digits_loss, presence_loss, digits_loss + presence_loss