diff --git a/tsimcne/tsimcne.py b/tsimcne/tsimcne.py index ee76935..1047042 100644 --- a/tsimcne/tsimcne.py +++ b/tsimcne/tsimcne.py @@ -559,6 +559,8 @@ def transform( elif not return_labels and return_backbone_feat: return Y, backbone_features elif return_labels and not return_backbone_feat: + # XXX: this for some reason changes the labels; but I + # don't know what causes this! labels = torch.hstack([lbl for _, lbl in loader]) return Y, labels else: