From f52fb47710c68c403fc2520016565c4f51c3f461 Mon Sep 17 00:00:00 2001
From: Ferdinand Schlatt
Date: Thu, 21 Nov 2024 17:19:48 +0100
Subject: [PATCH] fix score-based in-batch loss for multi-dim targets

---
 lightning_ir/loss/loss.py | 63 ++++++++++++++++++++++++++++-----------
 1 file changed, 45 insertions(+), 18 deletions(-)

diff --git a/lightning_ir/loss/loss.py b/lightning_ir/loss/loss.py
index ff0f6a7..17a577f 100644
--- a/lightning_ir/loss/loss.py
+++ b/lightning_ir/loss/loss.py
@@ -22,17 +22,17 @@ def process_scores(self, output: LightningIROutput) -> torch.Tensor:
             raise ValueError("Expected scores in LightningIROutput")
         return output.scores
 
-
-class ScoringLossFunction(LossFunction):
-    @abstractmethod
-    def compute_loss(self, output: LightningIROutput, targets: torch.Tensor) -> torch.Tensor: ...
-
     def process_targets(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
         if targets.ndim > scores.ndim:
             return targets.max(-1).values
         return targets
 
 
+class ScoringLossFunction(LossFunction):
+    @abstractmethod
+    def compute_loss(self, output: LightningIROutput, targets: torch.Tensor) -> torch.Tensor: ...
+
+
 class EmbeddingLossFunction(LossFunction):
     @abstractmethod
     def compute_loss(self, output: BiEncoderOutput) -> torch.Tensor: ...
@@ -288,7 +288,13 @@ def __init__(
             raise ValueError("all_and_non_first is only valid with pos_sampling_technique first")
 
     def _get_pos_mask(
-        self, num_queries: int, num_docs: int, max_idx: torch.Tensor, min_idx: torch.Tensor, batch: TrainBatch
+        self,
+        num_queries: int,
+        num_docs: int,
+        max_idx: torch.Tensor,
+        min_idx: torch.Tensor,
+        output: LightningIROutput,
+        batch: TrainBatch,
     ) -> torch.Tensor:
         if self.pos_sampling_technique == "all":
             pos_mask = torch.arange(num_queries * num_docs)[None].greater_equal(min_idx) & torch.arange(
@@ -301,7 +307,13 @@ def _get_neg_mask(
         return pos_mask
 
     def _get_neg_mask(
-        self, num_queries: int, num_docs: int, max_idx: torch.Tensor, min_idx: torch.Tensor, batch: TrainBatch
+        self,
+        num_queries: int,
+        num_docs: int,
+        max_idx: torch.Tensor,
+        min_idx: torch.Tensor,
+        output: LightningIROutput,
+        batch: TrainBatch,
     ) -> torch.Tensor:
         if self.neg_sampling_technique == "all_and_non_first":
             neg_mask = torch.arange(num_queries * num_docs)[None].not_equal(min_idx)
@@ -323,8 +335,8 @@ def get_ib_idcs(self, output: LightningIROutput, batch: TrainBatch) -> Tuple[tor
         num_queries, num_docs = output.scores.shape
         min_idx = torch.arange(num_queries)[:, None] * num_docs
         max_idx = min_idx + num_docs
-        pos_mask = self._get_pos_mask(num_queries, num_docs, max_idx, min_idx, batch)
-        neg_mask = self._get_neg_mask(num_queries, num_docs, max_idx, min_idx, batch)
+        pos_mask = self._get_pos_mask(num_queries, num_docs, max_idx, min_idx, output, batch)
+        neg_mask = self._get_neg_mask(num_queries, num_docs, max_idx, min_idx, output, batch)
         pos_idcs = pos_mask.nonzero(as_tuple=True)[1]
         neg_idcs = neg_mask.nonzero(as_tuple=True)[1]
         if self.max_num_neg_samples is not None:
@@ -346,25 +358,40 @@ def __init__(self, min_target_diff: float, max_num_neg_samples: int | None = Non
         )
         self.min_target_diff = min_target_diff
 
-    def _sort_mask(self, mask: torch.Tensor, num_queries: int, num_docs: int, batch: TrainBatch) -> torch.Tensor:
-        assert batch.targets is not None
-        idcs = batch.targets.argsort(descending=True).argsort()
+    def _sort_mask(
+        self, mask: torch.Tensor, num_queries: int, num_docs: int, output: LightningIROutput, batch: TrainBatch
+    ) -> torch.Tensor:
+        assert output.scores is not None and batch.targets is not None
+        targets = self.process_targets(output.scores, batch.targets)
+        idcs = targets.argsort(descending=True).argsort()
         idcs = idcs + torch.arange(num_queries)[:, None] * num_docs
         block_idcs = torch.arange(num_docs)[None] + torch.arange(num_queries)[:, None] * num_docs
         return mask.scatter(1, block_idcs, mask.gather(1, idcs))
 
     def _get_pos_mask(
-        self, num_queries: int, num_docs: int, max_idx: torch.Tensor, min_idx: torch.Tensor, batch: TrainBatch
+        self,
+        num_queries: int,
+        num_docs: int,
+        max_idx: torch.Tensor,
+        min_idx: torch.Tensor,
+        output: LightningIROutput,
+        batch: TrainBatch,
     ) -> torch.Tensor:
-        pos_mask = super()._get_pos_mask(num_queries, num_docs, max_idx, min_idx, batch)
-        pos_mask = self._sort_mask(pos_mask, num_queries, num_docs, batch)
+        pos_mask = super()._get_pos_mask(num_queries, num_docs, max_idx, min_idx, output, batch)
+        pos_mask = self._sort_mask(pos_mask, num_queries, num_docs, output, batch)
         return pos_mask
 
     def _get_neg_mask(
-        self, num_queries: int, num_docs: int, max_idx: torch.Tensor, min_idx: torch.Tensor, batch: TrainBatch
+        self,
+        num_queries: int,
+        num_docs: int,
+        max_idx: torch.Tensor,
+        min_idx: torch.Tensor,
+        output: LightningIROutput,
+        batch: TrainBatch,
    ) -> torch.Tensor:
-        neg_mask = super()._get_neg_mask(num_queries, num_docs, max_idx, min_idx, batch)
-        neg_mask = self._sort_mask(neg_mask, num_queries, num_docs, batch)
+        neg_mask = super()._get_neg_mask(num_queries, num_docs, max_idx, min_idx, output, batch)
+        neg_mask = self._sort_mask(neg_mask, num_queries, num_docs, output, batch)
         assert batch.targets is not None
         max_score, _ = batch.targets.max(dim=-1, keepdim=True)
         score_diff = max_score - batch.targets
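
A minimal sketch of the behavior this patch addresses, under hypothetical tensor shapes (not part of the commit): when targets carry one extra trailing dimension relative to scores, e.g. one column per relevance grade, calling argsort directly on batch.targets, as the old _sort_mask did, sorts along that grade dimension rather than along documents. process_targets first collapses the extra dimension with max, so the ranks line up with the 2-D score matrix.

    import torch

    num_queries, num_docs, num_grades = 2, 3, 4  # hypothetical shapes
    scores = torch.randn(num_queries, num_docs)
    targets = torch.randint(0, 4, (num_queries, num_docs, num_grades)).float()

    # process_targets: targets has one more dimension than scores, so collapse
    # the trailing (grade) dimension, keeping one relevance value per document.
    if targets.ndim > scores.ndim:
        targets = targets.max(-1).values  # shape: (num_queries, num_docs)

    # _sort_mask can now rank documents per query; on the raw 3-D targets the
    # same argsort would have ranked grade positions instead of documents.
    idcs = targets.argsort(descending=True).argsort()
    assert idcs.shape == scores.shape

Threading output through _get_pos_mask and _get_neg_mask, as the patch does, is what makes output.scores available to _sort_mask so the target shape can be checked against the score shape.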