From 817fb28249c043e715756adee3883d55a1b5f9d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Niklas=20B=C3=B6hm?= Date: Fri, 13 Dec 2024 13:41:54 +0100 Subject: [PATCH] Fix loss logging --- tsimcne/losses/infonce.py | 12 ++++++++++-- tsimcne/tsimcne.py | 4 ++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/tsimcne/losses/infonce.py b/tsimcne/losses/infonce.py index 61428db..6df607c 100644 --- a/tsimcne/losses/infonce.py +++ b/tsimcne/losses/infonce.py @@ -83,7 +83,11 @@ def forward(self, features, backbone_features=None, labels=None): raw_uniformity = logsumexp_1 + logsumexp_2 loss = -(self.exaggeration * tempered_alignment - raw_uniformity / 2) - return loss + return dict( + loss=loss, + ta=-tempered_alignment, + ru=raw_uniformity / 2, + ) class InfoNCEGaussian(InfoNCECauchy): @@ -111,7 +115,11 @@ def forward(self, features, backbone_features=None, labels=None): raw_uniformity = logsumexp_1 + logsumexp_2 loss = -(tempered_alignment - raw_uniformity / 2) - return loss + return dict( + loss=loss, + ta=-tempered_alignment, + ru=raw_uniformity / 2, + ) class InfoNCELoss(LossBase): diff --git a/tsimcne/tsimcne.py b/tsimcne/tsimcne.py index 87dec06..95f5fcb 100644 --- a/tsimcne/tsimcne.py +++ b/tsimcne/tsimcne.py @@ -356,9 +356,9 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0): features, backbone_features = self(batch) # backbone_features are unused in infonce loss - loss = self.loss(features) + lossd = self.loss(features) - return dict(loss=loss) + return lossd elif dataloader_idx == 1: features, backbone_features = self(batch)