diff --git a/ptn/model_rotator.py b/ptn/model_rotator.py index 32210c81dd4..6c0c97c2ef5 100644 --- a/ptn/model_rotator.py +++ b/ptn/model_rotator.py @@ -190,8 +190,34 @@ def get_train_op_for_scope(loss, optimizer, scopes, params): def get_metrics(inputs, outputs, params): - names_to_values, names_to_updates = metrics.rotator_metrics( - inputs, outputs, params) + """Aggregate the metrics for rotator model. + + Args: + inputs: Input dictionary of the rotator model. + outputs: Output dictionary returned by the rotator model. + params: Hyperparameters of the rotator model. + + Returns: + names_to_values: metrics->values (dict). + names_to_updates: metrics->ops (dict). + """ + names_to_values = dict() + names_to_updates = dict() + + tmp_values, tmp_updates = metrics.add_image_pred_metrics( + inputs, outputs, params.num_views, 3*params.image_size**2) + names_to_values.update(tmp_values) + names_to_updates.update(tmp_updates) + + tmp_values, tmp_updates = metrics.add_mask_pred_metrics( + inputs, outputs, params.num_views, params.image_size**2) + names_to_values.update(tmp_values) + names_to_updates.update(tmp_updates) + + for name, value in names_to_values.iteritems(): + slim.summaries.add_scalar_summary( + value, name, prefix='eval', print_summary=True) + return names_to_values, names_to_updates