pass model output to loss functions
fschlatt committed Oct 30, 2024
1 parent 6eb3e81 commit 34b7be3
Showing 5 changed files with 109 additions and 92 deletions.
35 changes: 18 additions & 17 deletions lightning_ir/base/module.py
@@ -211,7 +211,7 @@ def validation_step(
 
         dataset_id = self.get_dataset_id(dataloader_idx)
         metrics = self.validate(
-            scores=output.scores,
+            output=output,
             query_ids=batch.query_ids,
             doc_ids=batch.doc_ids,
             qrels=batch.qrels,
@@ -262,7 +262,7 @@ def get_dataset_id(self, dataloader_idx: int) -> str:
 
     def validate(
         self,
-        scores: torch.Tensor | None = None,
+        output: LightningIROutput,
         query_ids: Sequence[str] | None = None,
         doc_ids: Sequence[Sequence[str]] | None = None,
         qrels: Sequence[Dict[str, int]] | None = None,
@@ -271,8 +271,8 @@ def validate(
     ) -> Dict[str, float]:
         """Validates the model output with the evaluation metrics and loss functions.
 
-        :param scores: Model output scores, defaults to None
-        :type scores: torch.Tensor | None, optional
+        :param output: Model output
+        :type output: LightningIROutput
         :param query_ids: ids of the queries, defaults to None
         :type query_ids: Sequence[str] | None, optional
         :param doc_ids: ids of the documents, defaults to None
@@ -289,7 +289,7 @@ def validate(
         :rtype: Dict[str, float]
         """
         metrics: Dict[str, float] = {}
-        if self.evaluation_metrics is None or scores is None:
+        if self.evaluation_metrics is None or output.scores is None:
             return metrics
         if query_ids is None:
             if num_docs is None:
@@ -299,21 +299,21 @@ def validate(
             if num_docs is None:
                 raise ValueError("num_docs must be set if doc_ids is not set")
             doc_ids = tuple(tuple(f"{i}-{j}" for j in range(docs)) for i, docs in enumerate(num_docs))
-        metrics.update(self.validate_metrics(scores, query_ids, doc_ids, qrels))
-        metrics.update(self.validate_loss(scores, query_ids, targets))
+        metrics.update(self.validate_metrics(output, query_ids, doc_ids, qrels))
+        metrics.update(self.validate_loss(output, query_ids, targets))
         return metrics
 
     def validate_metrics(
         self,
-        scores: torch.Tensor,
+        output: LightningIROutput,
         query_ids: Sequence[str],
         doc_ids: Sequence[Sequence[str]],
         qrels: Sequence[Dict[str, int]] | None,
     ) -> Dict[str, float]:
         """Validates the model output with the evaluation metrics.
 
-        :param scores: Model output scores
-        :type scores: torch.Tensor
+        :param output: Model output
+        :type output: LightningIROutput
         :param query_ids: ids of the queries
         :type query_ids: Sequence[str]
         :param doc_ids: ids of the documents
@@ -328,18 +328,18 @@ def validate_metrics(
             return metrics
         evaluation_metrics = [metric for metric in self.evaluation_metrics if metric != "loss"]
         ir_measures_qrels = create_qrels_from_dicts(qrels)
-        if evaluation_metrics and qrels is not None:
-            run = create_run_from_scores(query_ids, doc_ids, scores)
+        if evaluation_metrics and qrels is not None and output.scores is not None:
+            run = create_run_from_scores(query_ids, doc_ids, output.scores)
             metrics.update(evaluate_run(run, ir_measures_qrels, evaluation_metrics))
         return metrics
 
     def validate_loss(
-        self, scores: torch.Tensor, query_ids: Sequence[str], targets: torch.Tensor | None
+        self, output: LightningIROutput, query_ids: Sequence[str], targets: torch.Tensor | None
     ) -> Dict[str, float]:
         """Validates the model output with the loss functions.
 
-        :param scores: Model output scores
-        :type scores: torch.Tensor
+        :param output: Model output
+        :type output: LightningIROutput
         :param query_ids: ids of the queries
         :type query_ids: Sequence[str]
         :param targets: Target tensor used during fine-tuning
@@ -353,15 +353,16 @@ def validate_loss(
             or "loss" not in self.evaluation_metrics
             or targets is None
             or self.loss_functions is None
+            or output.scores is None
         ):
             return metrics
-        scores = scores.view(len(query_ids), -1)
+        output.scores = output.scores.view(len(query_ids), -1)
         for loss_function, _ in self.loss_functions:
             # NOTE skip in-batch losses because they can use a lot of memory
             if isinstance(loss_function, InBatchLossFunction):
                 continue
             metrics[f"validation-{loss_function.__class__.__name__}"] = loss_function.compute_loss(
-                scores, targets
+                output, targets
             ).item()
         return metrics
 
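With this change, validation callers hand `validate` the whole output object rather than a pre-extracted scores tensor. A minimal sketch of the new call pattern; `module` stands in for any configured LightningIRModule, the ids, scores, and qrels are invented for illustration, and the sketch assumes `LightningIROutput` accepts its scores tensor as the first constructor argument, as the bi-encoder hunk below does with `LightningIROutput(ib_scores)`:

    import torch

    from lightning_ir.base import LightningIROutput

    # One query with four candidate documents; values are invented.
    output = LightningIROutput(torch.tensor([0.9, 0.7, 0.2, 0.1]))
    metrics = module.validate(
        output=output,  # pass the full output object, not output.scores
        query_ids=["q1"],
        doc_ids=[("d1", "d2", "d3", "d4")],
        qrels=[{"d1": 1, "d3": 1}],
    )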
34 changes: 19 additions & 15 deletions lightning_ir/bi_encoder/module.py
@@ -3,7 +3,7 @@
 
 import torch
 
-from ..base import LightningIRModule
+from ..base import LightningIRModule, LightningIROutput
 from ..data import IndexBatch, RankBatch, SearchBatch, TrainBatch
 from ..loss.loss import EmbeddingLossFunction, InBatchLossFunction, LossFunction, ScoringLossFunction
 from ..retrieve import SearchConfig, Searcher
@@ -79,36 +79,40 @@ def compute_losses(self, batch: TrainBatch, output: BiEncoderOutput) -> List[torch.Tensor]:
         if self.loss_functions is None:
             raise ValueError("Loss function is not set")
 
-        scores = output.scores
-        query_embeddings = output.query_embeddings
-        doc_embeddings = output.doc_embeddings
-        if batch.targets is None or query_embeddings is None or doc_embeddings is None or scores is None:
+        if (
+            batch.targets is None
+            or output.query_embeddings is None
+            or output.doc_embeddings is None
+            or output.scores is None
+        ):
             raise ValueError(
                 "targets, scores, query_embeddings, and doc_embeddings must be set in " "the output and batch"
             )
 
         num_queries = len(batch.queries)
-        scores = scores.view(num_queries, -1)
-        targets = batch.targets.view(*scores.shape, -1)
+        output.scores = output.scores.view(num_queries, -1)
+        targets = batch.targets.view(*output.scores.shape, -1)
         losses = []
         for loss_function, _ in self.loss_functions:
             if isinstance(loss_function, InBatchLossFunction):
-                pos_idcs, neg_idcs = loss_function.get_ib_idcs(*scores.shape)
-                ib_doc_embeddings = self.get_ib_doc_embeddings(doc_embeddings, pos_idcs, neg_idcs, num_queries)
-                ib_scores = self.model.score(query_embeddings, ib_doc_embeddings)
+                pos_idcs, neg_idcs = loss_function.get_ib_idcs(*output.scores.shape)
+                ib_doc_embeddings = self.get_ib_doc_embeddings(output.doc_embeddings, pos_idcs, neg_idcs, num_queries)
+                ib_scores = self.model.score(output.query_embeddings, ib_doc_embeddings)
                 ib_scores = ib_scores.view(num_queries, -1)
-                losses.append(loss_function.compute_loss(ib_scores))
+                losses.append(loss_function.compute_loss(LightningIROutput(ib_scores)))
             elif isinstance(loss_function, EmbeddingLossFunction):
-                losses.append(loss_function.compute_loss(query_embeddings, doc_embeddings))
+                losses.append(loss_function.compute_loss(output))
             elif isinstance(loss_function, ScoringLossFunction):
-                losses.append(loss_function.compute_loss(scores, targets))
+                losses.append(loss_function.compute_loss(output, targets))
             else:
                 raise ValueError(f"Unknown loss function type {loss_function.__class__.__name__}")
         if self.config.sparsification is not None:
             query_num_nonzero = (
-                torch.nonzero(query_embeddings.embeddings).shape[0] / query_embeddings.embeddings.shape[0]
+                torch.nonzero(output.query_embeddings.embeddings).shape[0] / output.query_embeddings.embeddings.shape[0]
             )
-            doc_num_nonzero = torch.nonzero(doc_embeddings.embeddings).shape[0] / doc_embeddings.embeddings.shape[0]
+            doc_num_nonzero = (
+                torch.nonzero(output.doc_embeddings.embeddings).shape[0] / output.doc_embeddings.embeddings.shape[0]
+            )
             self.log("query_num_nonzero", query_num_nonzero)
             self.log("doc_num_nonzero", doc_num_nonzero)
         return losses
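Since `EmbeddingLossFunction.compute_loss` (changed in loss.py below) now receives the whole `BiEncoderOutput`, a custom embedding loss reads whatever it needs from that object instead of relying on the module to unpack embeddings. A minimal sketch against the new interface; `NormVariancePenalty` is an invented example, not part of the library, and it assumes `BiEncoderOutput` exposes optional `query_embeddings`/`doc_embeddings` carrying an `.embeddings` tensor, as the hunks in this commit do:

    import torch

    from lightning_ir.bi_encoder import BiEncoderOutput
    from lightning_ir.loss.loss import EmbeddingLossFunction

    class NormVariancePenalty(EmbeddingLossFunction):
        # Invented example: penalize spread in embedding norms so query and
        # document vectors stay on a comparable scale.
        def compute_loss(self, output: BiEncoderOutput) -> torch.Tensor:
            if output.query_embeddings is None or output.doc_embeddings is None:
                raise ValueError("Expected embeddings in BiEncoderOutput")
            query_norms = output.query_embeddings.embeddings.norm(dim=-1)
            doc_norms = output.doc_embeddings.embeddings.norm(dim=-1)
            return query_norms.var() + doc_norms.var()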
15 changes: 7 additions & 8 deletions lightning_ir/cross_encoder/module.py
@@ -4,7 +4,7 @@
 
 from ..base.module import LightningIRModule
 from ..data import RankBatch, SearchBatch, TrainBatch
-from ..loss.loss import InBatchLossFunction, LossFunction
+from ..loss.loss import LossFunction, ScoringLossFunction
 from .config import CrossEncoderConfig
 from .model import CrossEncoderModel, CrossEncoderOutput
 from .tokenizer import CrossEncoderTokenizer
@@ -38,16 +38,15 @@ def compute_losses(self, batch: TrainBatch, output: CrossEncoderOutput) -> List[torch.Tensor]:
         if self.loss_functions is None:
             raise ValueError("loss_functions must be set in the module")
         output = self.forward(batch)
-        scores = output.scores
-        if scores is None or batch.targets is None:
+        if output.scores is None or batch.targets is None:
             raise ValueError("scores and targets must be set in the output and batch")
 
-        scores = scores.view(len(batch.query_ids), -1)
-        targets = batch.targets.view(*scores.shape, -1)
+        output.scores = output.scores.view(len(batch.query_ids), -1)
+        targets = batch.targets.view(*output.scores.shape, -1)
 
         losses = []
         for loss_function, _ in self.loss_functions:
-            if isinstance(loss_function, InBatchLossFunction):
-                raise NotImplementedError("InBatchLossFunction not implemented for cross-encoders")
-            losses.append(loss_function.compute_loss(scores, targets))
+            if not isinstance(loss_function, ScoringLossFunction):
+                raise RuntimeError(f"Loss function {loss_function} is not a scoring loss function")
+            losses.append(loss_function.compute_loss(output, targets))
         return losses
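The new guard accepts exactly the `ScoringLossFunction` subclasses and rejects everything else up front, which subsumes the old special-casing of `InBatchLossFunction`. A quick sketch of how the hierarchy from loss.py (below) interacts with the check, assuming default constructor arguments:

    from lightning_ir.loss.loss import InBatchCrossEntropy, MarginMSE, ScoringLossFunction

    # MarginMSE is a PairwiseLossFunction and hence a ScoringLossFunction: accepted.
    assert isinstance(MarginMSE(), ScoringLossFunction)
    # InBatchCrossEntropy derives from InBatchLossFunction, which this commit
    # rebases onto plain LossFunction: rejected with a RuntimeError above.
    assert not isinstance(InBatchCrossEntropy(), ScoringLossFunction)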
83 changes: 52 additions & 31 deletions lightning_ir/loss/loss.py
@@ -6,21 +6,23 @@
 import torch
 
 if TYPE_CHECKING:
-    from ..bi_encoder import BiEncoderEmbedding
+    from ..base import LightningIROutput
+    from ..bi_encoder import BiEncoderOutput
 
 
 class LossFunction(ABC):
     @abstractmethod
-    def compute_loss(self, *args, **kwargs) -> torch.Tensor: ...
+    def compute_loss(self, output: LightningIROutput, *args, **kwargs) -> torch.Tensor: ...
+
+    def process_scores(self, output: LightningIROutput) -> torch.Tensor:
+        if output.scores is None:
+            raise ValueError("Expected scores in LightningIROutput")
+        return output.scores
 
 
 class ScoringLossFunction(LossFunction):
     @abstractmethod
-    def compute_loss(
-        self,
-        scores: torch.Tensor,
-        targets: torch.Tensor,
-    ) -> torch.Tensor: ...
+    def compute_loss(self, output: LightningIROutput, targets: torch.Tensor) -> torch.Tensor: ...
 
     def process_targets(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
         if targets.ndim > scores.ndim:
@@ -30,9 +32,7 @@ def process_targets(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
 
 
 class EmbeddingLossFunction(LossFunction):
     @abstractmethod
-    def compute_loss(
-        self, query_embeddings: BiEncoderEmbedding, doc_embeddings: BiEncoderEmbedding
-    ) -> torch.Tensor: ...
+    def compute_loss(self, output: BiEncoderOutput) -> torch.Tensor: ...
 
 
 class PairwiseLossFunction(ScoringLossFunction):
@@ -49,7 +49,8 @@ class MarginMSE(PairwiseLossFunction):
     def __init__(self, margin: float | Literal["scores"] = 1.0):
         self.margin = margin
 
-    def compute_loss(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+    def compute_loss(self, output: LightningIROutput, targets: torch.Tensor) -> torch.Tensor:
+        scores = self.process_scores(output)
         targets = self.process_targets(scores, targets)
         query_idcs, pos_idcs, neg_idcs = self.get_pairwise_idcs(targets)
         pos = scores[query_idcs, pos_idcs]
@@ -76,7 +77,8 @@ def __init__(self):
 
 
 class RankNet(PairwiseLossFunction):
-    def compute_loss(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+    def compute_loss(self, output: LightningIROutput, targets: torch.Tensor) -> torch.Tensor:
+        scores = self.process_scores(output)
         targets = self.process_targets(scores, targets)
         query_idcs, pos_idcs, neg_idcs = self.get_pairwise_idcs(targets)
         pos = scores[query_idcs, pos_idcs]
@@ -87,7 +89,8 @@ def compute_loss(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
 
 
 class KLDivergence(ListwiseLossFunction):
-    def compute_loss(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+    def compute_loss(self, output: LightningIROutput, targets: torch.Tensor) -> torch.Tensor:
+        scores = self.process_scores(output)
         targets = self.process_targets(scores, targets)
         scores = torch.nn.functional.log_softmax(scores, dim=-1)
         targets = torch.nn.functional.log_softmax(targets.to(scores), dim=-1)
@@ -96,7 +99,8 @@ def compute_loss(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
 
 
 class LocalizedContrastiveEstimation(ListwiseLossFunction):
-    def compute_loss(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+    def compute_loss(self, output: LightningIROutput, targets: torch.Tensor) -> torch.Tensor:
+        scores = self.process_scores(output)
         targets = self.process_targets(scores, targets)
         targets = targets.argmax(dim=1)
         loss = torch.nn.functional.cross_entropy(scores, targets)
@@ -159,7 +163,9 @@ def get_ndcg(
         ndcg = dcg / (idcg.clamp(min=1e-12))
         return ndcg
 
-    def compute_loss(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+    def compute_loss(self, output: LightningIROutput, targets: torch.Tensor) -> torch.Tensor:
+        scores = self.process_scores(output)
+        scores = self.process_scores(output)
         targets = self.process_targets(scores, targets)
         approx_ranks = self.get_approx_ranks(scores, self.temperature)
         ndcg = self.get_ndcg(approx_ranks, targets, k=None, scale_gains=self.scale_gains)
@@ -181,7 +187,8 @@ def get_mrr(ranks: torch.Tensor, targets: torch.Tensor, k: int | None = None) -> torch.Tensor:
         mrr = mrr.max(dim=-1)[0]
         return mrr
 
-    def compute_loss(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+    def compute_loss(self, output: LightningIROutput, targets: torch.Tensor) -> torch.Tensor:
+        scores = self.process_scores(output)
         targets = self.process_targets(scores, targets)
         approx_ranks = self.get_approx_ranks(scores, self.temperature)
         mrr = self.get_mrr(approx_ranks, targets, k=None)
@@ -198,7 +205,8 @@ def __init__(
         super().__init__(temperature)
         self.discount = discount
 
-    def compute_loss(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+    def compute_loss(self, output: LightningIROutput, targets: torch.Tensor) -> torch.Tensor:
+        scores = self.process_scores(output)
         targets = self.process_targets(scores, targets)
         approx_ranks = self.get_approx_ranks(scores, self.temperature)
         ranks = torch.argsort(torch.argsort(targets, descending=True)) + 1
@@ -262,7 +270,7 @@ def get_sorted_targets(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
         return pred_sorted_targets
 
 
-class InBatchLossFunction(ScoringLossFunction):
+class InBatchLossFunction(LossFunction):
     def __init__(
         self,
         pos_sampling_technique: Literal["all", "first"] = "all",
@@ -305,12 +313,13 @@ def get_ib_idcs(self, num_queries: int, num_docs: int) -> Tuple[torch.Tensor, torch.Tensor]:
         neg_idcs = neg_idcs.reshape(-1)
         return pos_idcs, neg_idcs
 
-    def compute_loss(self, scores: torch.Tensor) -> torch.Tensor:
-        raise NotImplementedError("InBatchLossFunction.compute_loss must be implemented by subclasses")
+    def compute_loss(self, output: LightningIROutput) -> torch.Tensor:
+        return super().compute_loss(output)
 
 
 class InBatchCrossEntropy(InBatchLossFunction):
-    def compute_loss(self, scores: torch.Tensor) -> torch.Tensor:
+    def compute_loss(self, output: LightningIROutput) -> torch.Tensor:
+        scores = self.process_scores(output)
         targets = torch.zeros(scores.shape[0], dtype=torch.long, device=scores.device)
         loss = torch.nn.functional.cross_entropy(scores, targets)
         return loss
@@ -321,27 +330,39 @@ def __init__(self, query_weight: float = 1e-4, doc_weight: float = 1e-4) -> None:
         self.query_weight = query_weight
         self.doc_weight = doc_weight
 
+    def process_embeddings(self, output: BiEncoderOutput) -> Tuple[torch.Tensor, torch.Tensor]:
+        query_embeddings = output.query_embeddings
+        doc_embeddings = output.doc_embeddings
+        if query_embeddings is None:
+            raise ValueError("Expected query_embeddings in BiEncoderOutput")
+        if doc_embeddings is None:
+            raise ValueError("Expected doc_embeddings in BiEncoderOutput")
+        return query_embeddings.embeddings, doc_embeddings.embeddings
+
 
 class L2Regularization(RegularizationLossFunction):
-    def compute_loss(self, query_embeddings: BiEncoderEmbedding, doc_embeddings: BiEncoderEmbedding) -> torch.Tensor:
-        query_loss = self.query_weight * query_embeddings.embeddings.norm(dim=-1).mean()
-        doc_loss = self.doc_weight * doc_embeddings.embeddings.norm(dim=-1).mean()
+    def compute_loss(self, output: BiEncoderOutput) -> torch.Tensor:
+        query_embeddings, doc_embeddings = self.process_embeddings(output)
+        query_loss = self.query_weight * query_embeddings.norm(dim=-1).mean()
+        doc_loss = self.doc_weight * doc_embeddings.norm(dim=-1).mean()
         loss = query_loss + doc_loss
         return loss
 
 
 class L1Regularization(RegularizationLossFunction):
-    def compute_loss(self, query_embeddings: BiEncoderEmbedding, doc_embeddings: BiEncoderEmbedding) -> torch.Tensor:
-        query_loss = self.query_weight * query_embeddings.embeddings.norm(p=1, dim=-1).mean()
-        doc_loss = self.doc_weight * doc_embeddings.embeddings.norm(p=1, dim=-1).mean()
+    def compute_loss(self, output: BiEncoderOutput) -> torch.Tensor:
+        query_embeddings, doc_embeddings = self.process_embeddings(output)
+        query_loss = self.query_weight * query_embeddings.norm(p=1, dim=-1).mean()
+        doc_loss = self.doc_weight * doc_embeddings.norm(p=1, dim=-1).mean()
         loss = query_loss + doc_loss
         return loss
 
 
 class FLOPSRegularization(RegularizationLossFunction):
-    def compute_loss(self, query_embeddings: BiEncoderEmbedding, doc_embeddings: BiEncoderEmbedding) -> torch.Tensor:
-        query_loss = torch.sum(torch.mean(torch.abs(query_embeddings.embeddings), dim=0) ** 2)
-        doc_loss = torch.sum(torch.mean(torch.abs(doc_embeddings.embeddings), dim=0) ** 2)
-        anti_zero = 1 / (torch.sum(query_embeddings.embeddings) ** 2) + 1 / (torch.sum(doc_embeddings.embeddings) ** 2)
+    def compute_loss(self, output: BiEncoderOutput) -> torch.Tensor:
+        query_embeddings, doc_embeddings = self.process_embeddings(output)
+        query_loss = torch.sum(torch.mean(torch.abs(query_embeddings), dim=0) ** 2)
+        doc_loss = torch.sum(torch.mean(torch.abs(doc_embeddings), dim=0) ** 2)
+        anti_zero = 1 / (torch.sum(query_embeddings) ** 2) + 1 / (torch.sum(doc_embeddings) ** 2)
         loss = self.query_weight * query_loss + self.doc_weight * doc_loss + anti_zero
         return loss
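Taken together, these hunks make the output object the single argument type for every `compute_loss`, with `process_scores` and `process_embeddings` centralizing the None-checks. A minimal end-to-end sketch under the new interface; it assumes `LightningIROutput` can be built directly from a scores tensor (as `LightningIROutput(ib_scores)` above does) and defaults `scores` to None, and the tensor values are invented:

    import torch

    from lightning_ir.base import LightningIROutput
    from lightning_ir.loss.loss import MarginMSE

    # Two queries with three documents each; the first document of each is relevant.
    scores = torch.tensor([[0.8, 0.3, 0.1], [0.6, 0.5, 0.2]])
    targets = torch.tensor([[1, 0, 0], [1, 0, 0]])

    loss_fn = MarginMSE(margin=1.0)
    loss = loss_fn.compute_loss(LightningIROutput(scores), targets)

    # An output without scores now fails fast inside process_scores:
    try:
        loss_fn.compute_loss(LightningIROutput(), targets)
    except ValueError as err:
        print(err)  # Expected scores in LightningIROutput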