diff --git a/nemo/models/nemo.py b/nemo/models/nemo.py index 135b2e1..fabfac0 100644 --- a/nemo/models/nemo.py +++ b/nemo/models/nemo.py @@ -524,8 +524,7 @@ def evaluate(self, sample, debug=False): for i, pred in enumerate(preds): if "azimuth" in sample and "elevation" in sample and "theta" in sample: pose_error_ = pose_error({k: sample[k][i] for k in ["azimuth", "elevation", "theta"]}, pred["final"][0]) - if CATEGORIES.index(self.cate) == sample["label"][i]: - pred["pose_error"] = pose_error_ + pred["pose_error"] = pose_error_ if self.training_params.classification: # print(pred['final'][0]['score']) classification_result[sample['this_name'][i]] = (pred['final'][0]['score'], pose_error_)