From 48ad75c235303002d731058cb6dec349a11fba4e Mon Sep 17 00:00:00 2001 From: Ethan Blackwood Date: Thu, 11 May 2017 04:03:53 -0400 Subject: [PATCH 1/4] Change dice coefficient calculation --- tf_unet/unet.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tf_unet/unet.py b/tf_unet/unet.py index 2b3b24f..39542c3 100644 --- a/tf_unet/unet.py +++ b/tf_unet/unet.py @@ -229,9 +229,9 @@ def _get_cost(self, logits, cost_name, cost_kwargs): elif cost_name == "dice_coefficient": eps = 1e-5 prediction = pixel_wise_softmax_2(logits) - intersection = tf.reduce_sum(prediction * self.y) - union = eps + tf.reduce_sum(prediction) + tf.reduce_sum(self.y) - loss = -(2 * intersection/ (union)) + intersection = tf.reduce_sum(prediction * self.y, axis=[0, 1, 2]) + union = eps + tf.reduce_sum(prediction, axis=[0, 1, 2]) + tf.reduce_sum(self.y, axis=[0, 1, 2]) + loss = tf.reduce_sum(-(2 * intersection/ (union))) else: raise ValueError("Unknown cost function: "%cost_name) From cb77c4820591795e4495f0633dda5025daf03013 Mon Sep 17 00:00:00 2001 From: Ethan Blackwood Date: Mon, 4 Jun 2018 21:38:55 -0400 Subject: [PATCH 2/4] Fix dice_coefficient denominator name and add iou --- tf_unet/unet.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/tf_unet/unet.py b/tf_unet/unet.py index 4e5ceb0..ba85855 100644 --- a/tf_unet/unet.py +++ b/tf_unet/unet.py @@ -230,12 +230,17 @@ def _get_cost(self, logits, cost_name, cost_kwargs): else: loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=flat_logits, labels=flat_labels)) - elif cost_name == "dice_coefficient": + elif cost_name == "dice_coefficient" or cost_name == "iou": eps = 1e-5 prediction = pixel_wise_softmax_2(logits) - intersection = tf.reduce_sum(prediction * self.y, axis=[0, 1, 2]) - union = eps + tf.reduce_sum(prediction, axis=[0, 1, 2]) + tf.reduce_sum(self.y, axis=[0, 1, 2]) - loss = tf.reduce_sum(-(2 * intersection / (union))) + A_intersect_B = tf.reduce_sum(prediction * self.y, axis=[0, 1, 2]) + A_plus_B = tf.reduce_sum(prediction, axis=[0, 1, 2]) + tf.reduce_sum(self.y, axis=[0, 1, 2]) + if cost_name == "dice_coefficient" + denominator = A_plus_B + else # intersection over union + A_union_B = A_plus_B - A_intersect_B + denominator = A_union_B + loss = tf.reduce_sum(-(2 * A_intersect_B / (eps + denominator))) else: raise ValueError("Unknown cost function: " % cost_name) From 828db7921b138a16d12784c1d9cc7db1bc23efbc Mon Sep 17 00:00:00 2001 From: Ethan Blackwood Date: Wed, 6 Jun 2018 21:59:48 -0400 Subject: [PATCH 3/4] Add documentation for iou --- tf_unet/unet.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tf_unet/unet.py b/tf_unet/unet.py index ba85855..8489327 100644 --- a/tf_unet/unet.py +++ b/tf_unet/unet.py @@ -204,7 +204,8 @@ def __init__(self, channels=3, n_class=2, cost="cross_entropy", cost_kwargs={}, def _get_cost(self, logits, cost_name, cost_kwargs): """ - Constructs the cost function, either cross_entropy, weighted cross_entropy or dice_coefficient. + Constructs the cost function, either cross_entropy, weighted cross_entropy, + dice_coefficient, or iou (intersection over union). Optional arguments are: class_weights: weights for the different classes in case of multi-class imbalance regularizer: power of the L2 regularizers added to the loss function From adb031c17e51b3273834b5f10f3cbe8c05c45a14 Mon Sep 17 00:00:00 2001 From: Ethan Blackwood Date: Fri, 8 Jun 2018 13:14:02 -0400 Subject: [PATCH 4/4] Fix syntax errors --- tf_unet/unet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tf_unet/unet.py b/tf_unet/unet.py index 8489327..f8ef61a 100644 --- a/tf_unet/unet.py +++ b/tf_unet/unet.py @@ -236,9 +236,9 @@ def _get_cost(self, logits, cost_name, cost_kwargs): prediction = pixel_wise_softmax_2(logits) A_intersect_B = tf.reduce_sum(prediction * self.y, axis=[0, 1, 2]) A_plus_B = tf.reduce_sum(prediction, axis=[0, 1, 2]) + tf.reduce_sum(self.y, axis=[0, 1, 2]) - if cost_name == "dice_coefficient" + if cost_name == "dice_coefficient": denominator = A_plus_B - else # intersection over union + else: # intersection over union A_union_B = A_plus_B - A_intersect_B denominator = A_union_B loss = tf.reduce_sum(-(2 * A_intersect_B / (eps + denominator)))