add score-based in-batch loss function
fschlatt committed Nov 21, 2024
1 parent 52cf726 commit 0eaa265
Showing 6 changed files with 143 additions and 27 deletions.
4 changes: 4 additions & 0 deletions lightning_ir/__init__.py
@@ -65,6 +65,8 @@
L2Regularization,
LocalizedContrastiveEstimation,
RankNet,
ScoreBasedInBatchCrossEntropy,
ScoreBasedInBatchLossFunction,
SupervisedMarginMSE,
)
from .main import LightningIRTrainer, LightningIRWandbLogger
@@ -186,6 +188,8 @@
"RankSample",
"ReRankCallback",
"RunDataset",
"ScoreBasedInBatchCrossEntropy",
"ScoreBasedInBatchLossFunction",
"ScoringFunction",
"SearchBatch",
"SearchCallback",
2 changes: 1 addition & 1 deletion lightning_ir/bi_encoder/module.py
@@ -149,7 +149,7 @@ def _compute_losses(self, batch: TrainBatch, output: BiEncoderOutput) -> List[torch.Tensor]:
losses = []
for loss_function, _ in self.loss_functions:
if isinstance(loss_function, InBatchLossFunction):
pos_idcs, neg_idcs = loss_function.get_ib_idcs(*output.scores.shape)
pos_idcs, neg_idcs = loss_function.get_ib_idcs(output, batch)
ib_doc_embeddings = self._get_ib_doc_embeddings(output.doc_embeddings, pos_idcs, neg_idcs, num_queries)
ib_scores = self.model.score(output.query_embeddings, ib_doc_embeddings)
ib_scores = ib_scores.view(num_queries, -1)
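
For context, a minimal sketch of the new interface: the loss function now receives the full model output and training batch instead of only the score shape, so sampling strategies can inspect batch.targets. The tensors and constructor values below are illustrative assumptions, not part of this commit; the import paths and call signatures mirror those shown elsewhere in this diff.

import torch

from lightning_ir.base.model import LightningIROutput
from lightning_ir.data.data import TrainBatch
from lightning_ir.loss.loss import InBatchCrossEntropy

# two queries with four candidate documents each (toy values)
output = LightningIROutput(scores=torch.randn(2, 4))
batch = TrainBatch(
    queries=["q1", "q2"],
    docs=[["d1", "d2", "d3", "d4"], ["d5", "d6", "d7", "d8"]],
    targets=torch.randint(0, 5, (2, 4)),
)

# "all_and_non_first" treats every document except each query's first one as a negative
loss_function = InBatchCrossEntropy("first", "all_and_non_first")
pos_idcs, neg_idcs = loss_function.get_ib_idcs(output, batch)  # indices into the flattened doc list
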
4 changes: 4 additions & 0 deletions lightning_ir/loss/__init__.py
@@ -10,6 +10,8 @@
L2Regularization,
LocalizedContrastiveEstimation,
RankNet,
ScoreBasedInBatchCrossEntropy,
ScoreBasedInBatchLossFunction,
SupervisedMarginMSE,
)

@@ -25,5 +27,7 @@
"L2Regularization",
"LocalizedContrastiveEstimation",
"RankNet",
"ScoreBasedInBatchCrossEntropy",
"ScoreBasedInBatchLossFunction",
"SupervisedMarginMSE",
]
93 changes: 85 additions & 8 deletions lightning_ir/loss/loss.py
@@ -5,9 +5,12 @@

import torch

from lightning_ir.data import TrainBatch

if TYPE_CHECKING:
from ..base import LightningIROutput
from ..bi_encoder import BiEncoderOutput
from ..data import TrainBatch


class LossFunction(ABC):
@@ -274,17 +277,19 @@ class InBatchLossFunction(LossFunction):
def __init__(
self,
pos_sampling_technique: Literal["all", "first"] = "all",
neg_sampling_technique: Literal["all", "first"] = "all",
neg_sampling_technique: Literal["all", "first", "all_and_non_first"] = "all",
max_num_neg_samples: int | None = None,
):
super().__init__()
self.pos_sampling_technique = pos_sampling_technique
self.neg_sampling_technique = neg_sampling_technique
self.max_num_neg_samples = max_num_neg_samples
if self.neg_sampling_technique == "all_and_non_first" and self.pos_sampling_technique != "first":
raise ValueError("all_and_non_first is only valid with pos_sampling_technique first")

def get_ib_idcs(self, num_queries: int, num_docs: int) -> Tuple[torch.Tensor, torch.Tensor]:
min_idx = torch.arange(num_queries)[:, None] * num_docs
max_idx = min_idx + num_docs
def _get_pos_mask(
self, num_queries: int, num_docs: int, max_idx: torch.Tensor, min_idx: torch.Tensor, batch: TrainBatch
) -> torch.Tensor:
if self.pos_sampling_technique == "all":
pos_mask = torch.arange(num_queries * num_docs)[None].greater_equal(min_idx) & torch.arange(
num_queries * num_docs
@@ -293,8 +298,14 @@ def get_ib_idcs(self, num_queries: int, num_docs: int) -> Tuple[torch.Tensor, torch.Tensor]:
pos_mask = torch.arange(num_queries * num_docs)[None].eq(min_idx)
else:
raise ValueError("invalid pos sampling technique")
pos_idcs = pos_mask.nonzero(as_tuple=True)[1]
if self.neg_sampling_technique == "all":
return pos_mask

def _get_neg_mask(
self, num_queries: int, num_docs: int, max_idx: torch.Tensor, min_idx: torch.Tensor, batch: TrainBatch
) -> torch.Tensor:
if self.neg_sampling_technique == "all_and_non_first":
neg_mask = torch.arange(num_queries * num_docs)[None].not_equal(min_idx)
elif self.neg_sampling_technique == "all":
neg_mask = torch.arange(num_queries * num_docs)[None].less(min_idx) | torch.arange(num_queries * num_docs)[
None
].greater_equal(max_idx)
@@ -304,6 +315,17 @@ def get_ib_idcs(self, num_queries: int, num_docs: int) -> Tuple[torch.Tensor, torch.Tensor]:
)[None].ne(min_idx)
else:
raise ValueError("invalid neg sampling technique")
return neg_mask

def get_ib_idcs(self, output: LightningIROutput, batch: TrainBatch) -> Tuple[torch.Tensor, torch.Tensor]:
if output.scores is None:
raise ValueError("Expected scores in LightningIROutput")
num_queries, num_docs = output.scores.shape
min_idx = torch.arange(num_queries)[:, None] * num_docs
max_idx = min_idx + num_docs
pos_mask = self._get_pos_mask(num_queries, num_docs, max_idx, min_idx, batch)
neg_mask = self._get_neg_mask(num_queries, num_docs, max_idx, min_idx, batch)
pos_idcs = pos_mask.nonzero(as_tuple=True)[1]
neg_idcs = neg_mask.nonzero(as_tuple=True)[1]
if self.max_num_neg_samples is not None:
neg_idcs = neg_idcs.view(num_queries, -1)
@@ -313,8 +335,54 @@ def get_ib_idcs(self, num_queries: int, num_docs: int) -> Tuple[torch.Tensor, torch.Tensor]:
neg_idcs = neg_idcs.reshape(-1)
return pos_idcs, neg_idcs

def compute_loss(self, output: LightningIROutput) -> torch.Tensor:
return super().compute_loss(output)

class ScoreBasedInBatchLossFunction(InBatchLossFunction):

def __init__(
self,
min_target_diff: float,
pos_sampling_technique: Literal["first"] = "first",
neg_sampling_technique: Literal["all_and_non_first"] = "all_and_non_first",
max_num_neg_samples: int | None = None,
):
super().__init__(pos_sampling_technique, neg_sampling_technique, max_num_neg_samples)
self.min_target_diff = min_target_diff

def _sort_mask(self, mask: torch.Tensor, num_queries: int, num_docs: int, batch: TrainBatch) -> torch.Tensor:
assert batch.targets is not None
idcs = batch.targets.argsort(descending=True).argsort()
idcs = idcs + torch.arange(num_queries)[:, None] * num_docs
block_idcs = torch.arange(num_docs)[None] + torch.arange(num_queries)[:, None] * num_docs
return mask.scatter(1, block_idcs, mask.gather(1, idcs))

def _get_pos_mask(
self, num_queries: int, num_docs: int, max_idx: torch.Tensor, min_idx: torch.Tensor, batch: TrainBatch
) -> torch.Tensor:
pos_mask = super()._get_pos_mask(num_queries, num_docs, max_idx, min_idx, batch)
pos_mask = self._sort_mask(pos_mask, num_queries, num_docs, batch)
return pos_mask

def _get_neg_mask(
self, num_queries: int, num_docs: int, max_idx: torch.Tensor, min_idx: torch.Tensor, batch: TrainBatch
) -> torch.Tensor:
neg_mask = super()._get_neg_mask(num_queries, num_docs, max_idx, min_idx, batch)
neg_mask = self._sort_mask(neg_mask, num_queries, num_docs, batch)
assert batch.targets is not None
max_score, _ = batch.targets.max(dim=-1, keepdim=True)
score_diff = max_score - batch.targets
score_mask = score_diff.ge(self.min_target_diff)
block_idcs = torch.arange(num_docs)[None] + torch.arange(num_queries)[:, None] * num_docs
neg_mask = neg_mask.scatter(1, block_idcs, score_mask)
# num_neg_samples might be different between queries
num_neg_samples = neg_mask.sum(dim=1)
min_num_neg_samples = num_neg_samples.min()
additional_neg_samples = num_neg_samples - min_num_neg_samples
for query_idx, neg_samples in enumerate(additional_neg_samples):
neg_idcs = neg_mask[query_idx].nonzero().squeeze(1)
additional_neg_idcs = neg_idcs[torch.randperm(neg_idcs.shape[0])[:neg_samples]]
neg_mask[query_idx, additional_neg_idcs] = False
assert neg_mask.sum(dim=1).eq(min_num_neg_samples).all()
return neg_mask


class InBatchCrossEntropy(InBatchLossFunction):
@@ -325,6 +393,15 @@ def compute_loss(self, output: LightningIROutput) -> torch.Tensor:
return loss


class ScoreBasedInBatchCrossEntropy(ScoreBasedInBatchLossFunction):

def compute_loss(self, output: LightningIROutput) -> torch.Tensor:
scores = self.process_scores(output)
targets = torch.zeros(scores.shape[0], dtype=torch.long, device=scores.device)
loss = torch.nn.functional.cross_entropy(scores, targets)
return loss


class RegularizationLossFunction(EmbeddingLossFunction):
def __init__(self, query_weight: float = 1e-4, doc_weight: float = 1e-4) -> None:
self.query_weight = query_weight
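
For orientation, a rough usage sketch of the new score-based loss, assuming toy data: negatives for a query are limited to documents whose target is at least min_target_diff below that query's best target, and the number of sampled negatives is balanced across queries by randomly dropping the surplus. The values and tensors below are illustrative assumptions mirroring the updated tests further down; during training, the bi-encoder module re-scores the selected documents (see the module.py change above) before compute_loss is called.

import torch

from lightning_ir.base.model import LightningIROutput
from lightning_ir.data.data import TrainBatch
from lightning_ir.loss.loss import ScoreBasedInBatchCrossEntropy

# graded relevance targets; with min_target_diff=2 only documents at least two
# relevance levels below each query's best document stay eligible as negatives
batch = TrainBatch(
    queries=["q1", "q2"],
    docs=[["d1", "d2", "d3", "d4"], ["d5", "d6", "d7", "d8"]],
    targets=torch.tensor([[3, 2, 1, 0], [2, 2, 0, 1]]),
)
output = LightningIROutput(scores=torch.randn(2, 4))

loss_function = ScoreBasedInBatchCrossEntropy(min_target_diff=2)
pos_idcs, neg_idcs = loss_function.get_ib_idcs(output, batch)
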
2 changes: 1 addition & 1 deletion tests/test_callbacks.py
@@ -168,7 +168,7 @@ def test_register_local_dataset_callback(module: LightningIRModule):
qrels=str(CORPUS_DIR / "qrels.tsv"),
docpairs=str(CORPUS_DIR / "docpairs.tsv"),
)
datamodule = LightningIRDataModule(train_dataset=RunDataset("test"))
datamodule = LightningIRDataModule(train_dataset=RunDataset("test"), train_batch_size=2)

trainer = LightningIRTrainer(logger=False, enable_checkpointing=False, callbacks=[callback])

65 changes: 48 additions & 17 deletions tests/test_loss.py
@@ -5,6 +5,7 @@

from lightning_ir.base.model import LightningIROutput
from lightning_ir.bi_encoder.model import BiEncoderEmbedding, BiEncoderOutput
from lightning_ir.data.data import TrainBatch
from lightning_ir.loss.loss import (
ApproxMRR,
ApproxNDCG,
@@ -19,6 +20,7 @@
LocalizedContrastiveEstimation,
RankNet,
RegularizationLossFunction,
ScoreBasedInBatchCrossEntropy,
ScoringLossFunction,
SupervisedMarginMSE,
)
@@ -63,37 +65,66 @@ def embeddings(batch_size: int, sequence_length: int, embedding_dim: int) -> torch.Tensor:
return tensor


@pytest.fixture(scope="module")
def batch(batch_size: int, depth: int) -> TrainBatch:
return TrainBatch(
queries=["query"] * batch_size,
docs=[[f"doc{i}" for i in range(depth)]] * batch_size,
targets=torch.randint(0, 5, (batch_size, depth)),
)


@pytest.mark.parametrize(
"LossFunc",
"loss_func",
[
ApproxMRR,
ApproxNDCG,
ApproxRankMSE,
ConstantMarginMSE,
KLDivergence,
LocalizedContrastiveEstimation,
RankNet,
SupervisedMarginMSE,
ApproxMRR(),
ApproxNDCG(),
ApproxRankMSE(),
ConstantMarginMSE(),
KLDivergence(),
LocalizedContrastiveEstimation(),
RankNet(),
SupervisedMarginMSE(),
],
ids=[
"ApproxMRR",
"ApproxNDCG",
"ApproxRankMSE",
"ConstantMarginMSE",
"KLDivergence",
"LocalizedContrastiveEstimation",
"RankNet",
"SupervisedMarginMSE",
],
)
def test_loss_func(output: LightningIROutput, labels: torch.Tensor, LossFunc: Type[ScoringLossFunction]):
loss_func = LossFunc()
def test_loss_func(output: LightningIROutput, labels: torch.Tensor, loss_func: ScoringLossFunction):
loss = loss_func.compute_loss(output, labels)
assert loss >= 0
assert loss.requires_grad


@pytest.mark.parametrize("InBatchLossFunc", [InBatchCrossEntropy])
def test_in_batch_loss_func(InBatchLossFunc: Type[InBatchLossFunction], output: LightningIROutput):
loss_func = InBatchLossFunc()
@pytest.mark.parametrize(
"loss_func",
[
InBatchCrossEntropy(),
InBatchCrossEntropy("first", "all_and_non_first"),
ScoreBasedInBatchCrossEntropy(2, "first", "all_and_non_first"),
],
ids=["InBatchCrossEntropy(default)", "InBatchCrossEntropy(all_and_non_first)", "ScoreBasedInBatchCrossEntropy"],
)
def test_in_batch_loss_func(loss_func: InBatchLossFunction, output: LightningIROutput, batch: TrainBatch):
pos_idcs, neg_idcs = loss_func.get_ib_idcs(output, batch)
assert pos_idcs is not None
assert neg_idcs is not None
loss = loss_func.compute_loss(output)
assert loss >= 0
assert loss.requires_grad


@pytest.mark.parametrize("RegularizationLossFunc", [L1Regularization, L2Regularization, FLOPSRegularization])
def test_regularization_loss_func(RegularizationLossFunc: Type[RegularizationLossFunction], embeddings: torch.Tensor):
loss_func = RegularizationLossFunc()
@pytest.mark.parametrize(
"loss_func", [L1Regularization(), L2Regularization(), FLOPSRegularization()], ids=["L1", "L2", "FLOPS"]
)
def test_regularization_loss_func(loss_func: RegularizationLossFunction, embeddings: torch.Tensor):
loss = loss_func.compute_loss(
BiEncoderOutput(
None, BiEncoderEmbedding(embeddings, torch.empty(0)), BiEncoderEmbedding(embeddings, torch.empty(0))
