fix scored in-batch loss for multi-dim targets
fschlatt committed Nov 21, 2024
1 parent 0164c53 commit f52fb47
Showing 1 changed file with 45 additions and 18 deletions: lightning_ir/loss/loss.py
@@ -22,17 +22,17 @@ def process_scores(self, output: LightningIROutput) -> torch.Tensor:
raise ValueError("Expected scores in LightningIROutput")
return output.scores


class ScoringLossFunction(LossFunction):
@abstractmethod
def compute_loss(self, output: LightningIROutput, targets: torch.Tensor) -> torch.Tensor: ...

def process_targets(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
if targets.ndim > scores.ndim:
return targets.max(-1).values
return targets


class ScoringLossFunction(LossFunction):
@abstractmethod
def compute_loss(self, output: LightningIROutput, targets: torch.Tensor) -> torch.Tensor: ...


class EmbeddingLossFunction(LossFunction):
@abstractmethod
def compute_loss(self, output: BiEncoderOutput) -> torch.Tensor: ...
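
The relocated `process_targets` helper is the core of the fix: when the targets carry an extra grading dimension beyond the score matrix (for example, several graded judgments per document), it collapses that dimension by taking the maximum, yielding one target per (query, document) pair. A minimal sketch of that behavior, with hypothetical shapes:

    import torch

    # Hypothetical shapes: 2 queries x 3 docs, plus a trailing dimension
    # of 4 graded judgments per document.
    scores = torch.randn(2, 3)          # scores.ndim == 2
    targets = torch.rand(2, 3, 4)       # targets.ndim == 3 > scores.ndim

    # process_targets collapses the extra dimension with a max, so the
    # targets line up with the scores one-to-one.
    collapsed = targets.max(-1).values  # shape (2, 3)
    assert collapsed.shape == scores.shape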
@@ -288,7 +288,13 @@ def __init__(
raise ValueError("all_and_non_first is only valid with pos_sampling_technique first")

def _get_pos_mask(
self, num_queries: int, num_docs: int, max_idx: torch.Tensor, min_idx: torch.Tensor, batch: TrainBatch
self,
num_queries: int,
num_docs: int,
max_idx: torch.Tensor,
min_idx: torch.Tensor,
output: LightningIROutput,
batch: TrainBatch,
) -> torch.Tensor:
if self.pos_sampling_technique == "all":
pos_mask = torch.arange(num_queries * num_docs)[None].greater_equal(min_idx) & torch.arange(
@@ -301,7 +307,13 @@ def _get_pos_mask(
         return pos_mask
 
     def _get_neg_mask(
-        self, num_queries: int, num_docs: int, max_idx: torch.Tensor, min_idx: torch.Tensor, batch: TrainBatch
+        self,
+        num_queries: int,
+        num_docs: int,
+        max_idx: torch.Tensor,
+        min_idx: torch.Tensor,
+        output: LightningIROutput,
+        batch: TrainBatch,
     ) -> torch.Tensor:
         if self.neg_sampling_technique == "all_and_non_first":
             neg_mask = torch.arange(num_queries * num_docs)[None].not_equal(min_idx)
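
The `all_and_non_first` branch starts from a mask that excludes each query's first document (typically the labeled positive): broadcasting `not_equal(min_idx)` over the flat index range clears exactly one position per query row. A quick check of that broadcast with toy sizes:

    import torch

    num_queries, num_docs = 2, 3
    min_idx = torch.arange(num_queries)[:, None] * num_docs  # tensor([[0], [3]])
    neg_mask = torch.arange(num_queries * num_docs)[None].not_equal(min_idx)
    # row 0 is False only at flat index 0, row 1 only at flat index 3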
@@ -323,8 +335,8 @@ def get_ib_idcs(self, output: LightningIROutput, batch: TrainBatch) -> Tuple[torch.Tensor, torch.Tensor]:
         num_queries, num_docs = output.scores.shape
         min_idx = torch.arange(num_queries)[:, None] * num_docs
         max_idx = min_idx + num_docs
-        pos_mask = self._get_pos_mask(num_queries, num_docs, max_idx, min_idx, batch)
-        neg_mask = self._get_neg_mask(num_queries, num_docs, max_idx, min_idx, batch)
+        pos_mask = self._get_pos_mask(num_queries, num_docs, max_idx, min_idx, output, batch)
+        neg_mask = self._get_neg_mask(num_queries, num_docs, max_idx, min_idx, output, batch)
         pos_idcs = pos_mask.nonzero(as_tuple=True)[1]
         neg_idcs = neg_mask.nonzero(as_tuple=True)[1]
         if self.max_num_neg_samples is not None:
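
For orientation, `get_ib_idcs` treats the score matrix as one flat list of num_queries * num_docs documents, with `min_idx` and `max_idx` delimiting each query's contiguous block; the only change here is that `output` is now threaded through to the mask helpers so they can consult the scores when processing targets. A small standalone illustration of the block-index arithmetic (the `less(max_idx)` half of the positive mask is a plausible completion of the expression truncated above):

    import torch

    num_queries, num_docs = 2, 3
    min_idx = torch.arange(num_queries)[:, None] * num_docs  # tensor([[0], [3]])
    max_idx = min_idx + num_docs                             # tensor([[3], [6]])

    # Flat indices 0..5; each query's "all" positives are the indices
    # inside its own block: row 0 -> 0..2, row 1 -> 3..5.
    flat = torch.arange(num_queries * num_docs)[None]        # shape (1, 6)
    pos_mask = flat.greater_equal(min_idx) & flat.less(max_idx)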
@@ -346,25 +358,40 @@ def __init__(self, min_target_diff: float, max_num_neg_samples: int | None = None
         )
         self.min_target_diff = min_target_diff
 
-    def _sort_mask(self, mask: torch.Tensor, num_queries: int, num_docs: int, batch: TrainBatch) -> torch.Tensor:
-        assert batch.targets is not None
-        idcs = batch.targets.argsort(descending=True).argsort()
+    def _sort_mask(
+        self, mask: torch.Tensor, num_queries: int, num_docs: int, output: LightningIROutput, batch: TrainBatch
+    ) -> torch.Tensor:
+        assert output.scores is not None and batch.targets is not None
+        targets = self.process_targets(output.scores, batch.targets)
+        idcs = targets.argsort(descending=True).argsort()
         idcs = idcs + torch.arange(num_queries)[:, None] * num_docs
         block_idcs = torch.arange(num_docs)[None] + torch.arange(num_queries)[:, None] * num_docs
         return mask.scatter(1, block_idcs, mask.gather(1, idcs))
 
     def _get_pos_mask(
-        self, num_queries: int, num_docs: int, max_idx: torch.Tensor, min_idx: torch.Tensor, batch: TrainBatch
+        self,
+        num_queries: int,
+        num_docs: int,
+        max_idx: torch.Tensor,
+        min_idx: torch.Tensor,
+        output: LightningIROutput,
+        batch: TrainBatch,
     ) -> torch.Tensor:
-        pos_mask = super()._get_pos_mask(num_queries, num_docs, max_idx, min_idx, batch)
-        pos_mask = self._sort_mask(pos_mask, num_queries, num_docs, batch)
+        pos_mask = super()._get_pos_mask(num_queries, num_docs, max_idx, min_idx, output, batch)
+        pos_mask = self._sort_mask(pos_mask, num_queries, num_docs, output, batch)
         return pos_mask
 
     def _get_neg_mask(
-        self, num_queries: int, num_docs: int, max_idx: torch.Tensor, min_idx: torch.Tensor, batch: TrainBatch
+        self,
+        num_queries: int,
+        num_docs: int,
+        max_idx: torch.Tensor,
+        min_idx: torch.Tensor,
+        output: LightningIROutput,
+        batch: TrainBatch,
     ) -> torch.Tensor:
-        neg_mask = super()._get_neg_mask(num_queries, num_docs, max_idx, min_idx, batch)
-        neg_mask = self._sort_mask(neg_mask, num_queries, num_docs, batch)
+        neg_mask = super()._get_neg_mask(num_queries, num_docs, max_idx, min_idx, output, batch)
+        neg_mask = self._sort_mask(neg_mask, num_queries, num_docs, output, batch)
         assert batch.targets is not None
         max_score, _ = batch.targets.max(dim=-1, keepdim=True)
         score_diff = max_score - batch.targets
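
The substantive bug fix sits in `_sort_mask`: it previously argsorted `batch.targets` directly, which breaks once the targets carry an extra grading dimension; it now collapses them through `process_targets` first, so the per-query sort operates on one scalar per document. A toy reproduction of the reordering under that assumption (values and shapes are made up):

    import torch

    num_queries, num_docs = 1, 4
    scores = torch.randn(num_queries, num_docs)
    mask = torch.tensor([[True, True, False, False]])
    # Multi-dim targets: two graded judgments per document.
    targets = torch.tensor([[[0.1, 0.9], [0.0, 0.2], [0.5, 0.6], [0.3, 0.0]]])

    # Collapse the grading dimension exactly as process_targets does.
    if targets.ndim > scores.ndim:
        targets = targets.max(-1).values  # tensor([[0.9, 0.2, 0.6, 0.3]])

    # Rank position of each document under a descending target sort, then
    # permute the mask into that order, query block by query block.
    idcs = targets.argsort(descending=True).argsort()
    idcs = idcs + torch.arange(num_queries)[:, None] * num_docs
    block_idcs = torch.arange(num_docs)[None] + torch.arange(num_queries)[:, None] * num_docs
    sorted_mask = mask.scatter(1, block_idcs, mask.gather(1, idcs))
    print(sorted_mask)  # tensor([[ True, False,  True, False]])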
