
Commit

update docs
fschlatt committed Nov 4, 2024
1 parent 2fa6f68 commit 2875f3a
Showing 9 changed files with 155 additions and 27 deletions.
9 changes: 6 additions & 3 deletions lightning_ir/base/module.py
@@ -269,7 +269,7 @@ def validate(
doc_ids: Sequence[Sequence[str]] | None = None,
qrels: Sequence[Dict[str, int]] | None = None,
targets: torch.Tensor | None = None,
num_docs: Sequence[int] | None = None,
num_docs: Sequence[int] | int | None = None,
) -> Dict[str, float]:
"""Validates the model output with the evaluation metrics and loss functions.
@@ -283,8 +283,11 @@ def validate(
:type qrels: Sequence[Dict[str, int]] | None, optional
:param targets: Target tensor used during fine-tuning, defaults to None
:type targets: torch.Tensor | None, optional
:param num_docs: Number of documents per query, defaults to None
:type num_docs: Sequence[int] | None, optional
:param num_docs: Specifies how many documents are passed per query. If a sequence of integers, `len(num_docs)`
should be equal to the number of queries and `sum(num_docs)` equal to the number of documents, i.e., the
sequence contains one value per query specifying the number of documents for that query. If an integer,
assumes an equal number of documents per query. If None, tries to infer the number of documents per query
by dividing the total number of documents by the number of queries, defaults to None
:raises ValueError: If num_docs can not be parsed and query_ids are not set
:raises ValueError: If num_docs can not be parsed and doc_ids are not set
:return: Dictionary of metric values keyed by metric name
:rtype: Dict[str, float]
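As a quick illustration of the num_docs convention described above (a hypothetical example, not part of this commit), the three accepted forms relate to a batch as follows:

# Hypothetical batch illustrating the num_docs convention of validate().
query_ids = ["q1", "q2"]
doc_ids = [["d1", "d2", "d3"], ["d4", "d5", "d6"]]

# Sequence form: one entry per query, summing to the total number of documents.
num_docs = [3, 3]
assert len(num_docs) == len(query_ids)
assert sum(num_docs) == sum(len(ids) for ids in doc_ids)

# Integer form: the same number of documents for every query.
num_docs = 3

# None: the number of documents per query is inferred as total documents // number of queries.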
17 changes: 12 additions & 5 deletions lightning_ir/bi_encoder/config.py
@@ -13,9 +13,8 @@


class BiEncoderConfig(LightningIRConfig):
"""The configuration class to instantiate a Bi-Encoder model."""

model_type = "bi-encoder"
model_type: str = "bi-encoder"
"""Model type for bi-encoder models."""

TOKENIZER_ARGS = LightningIRConfig.TOKENIZER_ARGS.union(
{
@@ -26,6 +25,7 @@ class BiEncoderConfig(LightningIRConfig):
"add_marker_tokens",
}
)
"""Arguments for the tokenizer."""

ADDED_ARGS = LightningIRConfig.ADDED_ARGS.union(
{
@@ -41,9 +41,12 @@
"projection",
}
).union(TOKENIZER_ARGS)
"""Arguments added to the configuration."""

def __init__(
self,
query_length: int = 32,
doc_length: int = 512,
similarity_function: Literal["cosine", "dot"] = "dot",
query_expansion: bool = False,
attend_to_query_expanded_tokens: bool = False,
@@ -61,8 +64,12 @@ def __init__(
projection: Literal["linear", "linear_no_bias", "mlm"] | None = "linear",
**kwargs,
):
"""Initializes a bi-encoder configuration.
"""Configuration class for a bi-encoder model.
:param query_length: Maximum query length, defaults to 32
:type query_length: int, optional
:param doc_length: Maximum document length, defaults to 512
:type doc_length: int, optional
:param similarity_function: Similarity function to compute scores between query and document embeddings,
defaults to "dot"
:type similarity_function: Literal['cosine', 'dot'], optional
@@ -97,7 +104,7 @@ def __init__(
:param projection: Whether and how to project the output embeddings, defaults to "linear"
:type projection: Literal['linear', 'linear_no_bias', 'mlm'] | None, optional
"""
super().__init__(**kwargs)
super().__init__(query_length=query_length, doc_length=doc_length, **kwargs)
self.similarity_function = similarity_function
self.query_expansion = query_expansion
self.attend_to_query_expanded_tokens = attend_to_query_expanded_tokens
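For context, a minimal usage sketch of the updated constructor (illustrative only; the argument values are arbitrary and every other parameter keeps its default):

from lightning_ir.bi_encoder.config import BiEncoderConfig

# query_length and doc_length are now explicit arguments and are forwarded to the
# parent LightningIRConfig via super().__init__().
config = BiEncoderConfig(
    query_length=32,
    doc_length=512,
    similarity_function="dot",
    projection="linear",
)
print(config.model_type)  # "bi-encoder"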
4 changes: 2 additions & 2 deletions lightning_ir/bi_encoder/model.py
@@ -8,7 +8,7 @@
from dataclasses import dataclass
from functools import wraps
from string import punctuation
from typing import Callable, Iterable, Literal, Sequence, Tuple, overload
from typing import Callable, Iterable, Literal, Sequence, Tuple, Type, overload

import torch
from transformers import BatchEncoding
@@ -124,7 +124,7 @@ class BiEncoderModel(LightningIRModel):
_tied_weights_keys = ["projection.decoder.bias", "projection.decoder.weight", "encoder.embed_tokens.weight"]
_keys_to_ignore_on_load_unexpected = [r"decoder"]

config_class = BiEncoderConfig
config_class: Type[BiEncoderConfig] = BiEncoderConfig
"""Configuration class for the bi-encoder model."""

def __init__(self, config: BiEncoderConfig, *args, **kwargs) -> None:
4 changes: 2 additions & 2 deletions lightning_ir/bi_encoder/tokenizer.py
@@ -5,7 +5,7 @@
"""

import warnings
from typing import Dict, Sequence
from typing import Dict, Sequence, Type

from tokenizers.processors import TemplateProcessing
from transformers import BatchEncoding, BertTokenizer, BertTokenizerFast
@@ -16,7 +16,7 @@

class BiEncoderTokenizer(LightningIRTokenizer):

config_class = BiEncoderConfig
config_class: Type[BiEncoderConfig] = BiEncoderConfig
"""Configuration class for the tokenizer."""

QUERY_TOKEN: str = "[QUE]"
5 changes: 5 additions & 0 deletions lightning_ir/cross_encoder/__init__.py
@@ -1,3 +1,8 @@
"""Base module for cross-encoder models.
This module provides the main classes and functions for cross-encoder models, including configurations, models,
modules, and tokenizers."""

from .config import CrossEncoderConfig
from .model import CrossEncoderModel, CrossEncoderOutput
from .module import CrossEncoderModule
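As a trivial usage note (assuming only the re-exports shown above), these classes become importable directly from the subpackage:

from lightning_ir.cross_encoder import (
    CrossEncoderConfig,
    CrossEncoderModel,
    CrossEncoderModule,
    CrossEncoderOutput,
)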
22 changes: 21 additions & 1 deletion lightning_ir/cross_encoder/config.py
@@ -1,12 +1,20 @@
"""
Configuration module for cross-encoder models.
This module defines the configuration class used to instantiate cross-encoder models.
"""

from typing import Literal

from ..base import LightningIRConfig


class CrossEncoderConfig(LightningIRConfig):
model_type = "cross-encoder"
model_type: str = "cross-encoder"
"""Model type for cross-encoder models."""

ADDED_ARGS = LightningIRConfig.ADDED_ARGS.union({"pooling_strategy", "linear_bias"})
"""Arguments added to the configuration."""

def __init__(
self,
@@ -16,6 +24,18 @@ def __init__(
linear_bias: bool = False,
**kwargs
):
"""Configuration class for a cross-encoder model
:param query_length: Maximum query length, defaults to 32
:type query_length: int, optional
:param doc_length: Maximum document length, defaults to 512
:type doc_length: int, optional
:param pooling_strategy: Pooling strategy to aggregate the contextualized embeddings into a single vector for
computing a relevance score, defaults to "first"
:type pooling_strategy: Literal['first', 'mean', 'max', 'sum'], optional
:param linear_bias: Whether to use a bias in the prediction linear layer, defaults to False
:type linear_bias: bool, optional
"""
super().__init__(query_length=query_length, doc_length=doc_length, **kwargs)
self.pooling_strategy = pooling_strategy
self.linear_bias = linear_bias
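A minimal sketch of the documented constructor (hypothetical values; all other arguments keep their defaults):

from lightning_ir.cross_encoder.config import CrossEncoderConfig

# pooling_strategy selects how contextualized embeddings are aggregated into one vector;
# linear_bias toggles the bias of the final scoring layer.
config = CrossEncoderConfig(
    query_length=32,
    doc_length=512,
    pooling_strategy="mean",
    linear_bias=False,
)
print(config.model_type)  # "cross-encoder"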
24 changes: 24 additions & 0 deletions lightning_ir/cross_encoder/model.py
@@ -1,3 +1,9 @@
"""
Model module for cross-encoder models.
This module defines the model class used to implement cross-encoder models.
"""

from dataclasses import dataclass
from typing import Type

@@ -11,19 +17,37 @@

@dataclass
class CrossEncoderOutput(LightningIROutput):
"""Dataclass containing the output of a cross-encoder model"""

embeddings: torch.Tensor | None = None
"""Joint query-document embeddings"""


class CrossEncoderModel(LightningIRModel):
config_class: Type[CrossEncoderConfig] = CrossEncoderConfig
"""Configuration class for cross-encoder models."""

def __init__(self, config: CrossEncoderConfig, *args, **kwargs):
"""A cross-encoder model that jointly encodes a query and document(s). The contextualized embeddings are
aggregated into a single vector and fed to a linear layer which computes a final relevance score.
:param config: Configuration for the cross-encoder model
:type config: CrossEncoderConfig
"""
super().__init__(config, *args, **kwargs)
self.config: CrossEncoderConfig
self.linear = torch.nn.Linear(config.hidden_size, 1, bias=config.linear_bias)

@batch_encoding_wrapper
def forward(self, encoding: BatchEncoding) -> CrossEncoderOutput:
"""Computes contextualized embeddings for the joint query-document input sequence and computes a relevance
score.
:param encoding: Tokenizer encoding for the joint query-document input sequence
:type encoding: BatchEncoding
:return: Output of the model
:rtype: CrossEncoderOutput
"""
embeddings = self._backbone_forward(**encoding).last_hidden_state
embeddings = self._pooling(
embeddings, encoding.get("attention_mask", None), pooling_strategy=self.config.pooling_strategy
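The scoring pattern described in the docstrings above (pool the contextualized embeddings, then apply a linear layer) can be sketched in isolation as follows; the tensors, sizes, and "mean" pooling choice are hypothetical and this is not the library's exact implementation:

import torch

batch_size, seq_len, hidden_size = 2, 16, 768
embeddings = torch.randn(batch_size, seq_len, hidden_size)  # stand-in for last_hidden_state
attention_mask = torch.ones(batch_size, seq_len)

# "mean" pooling over non-padded positions (one of the configurable strategies).
masked = embeddings * attention_mask.unsqueeze(-1)
pooled = masked.sum(dim=1) / attention_mask.sum(dim=1, keepdim=True)

# Linear layer mapping each pooled vector to a single relevance score.
linear = torch.nn.Linear(hidden_size, 1, bias=False)
scores = linear(pooled).squeeze(-1)
print(scores.shape)  # torch.Size([2])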
33 changes: 32 additions & 1 deletion lightning_ir/cross_encoder/module.py
@@ -1,3 +1,9 @@
"""
Module module for cross-encoder models.
This module defines the Lightning IR module class used to implement cross-encoder models.
"""

from typing import List, Sequence, Tuple

import torch
@@ -19,14 +25,38 @@ def __init__(
loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None = None,
evaluation_metrics: Sequence[str] | None = None,
):
""":class:`.LightningIRModule` for cross-encoder models. It contains a :class:`.CrossEncoderModel` and a
:class:`.CrossEncoderTokenizer` and implements the training, validation, and testing steps for the model.
:param model_name_or_path: Name or path of backbone model or fine-tuned Lightning IR model, defaults to None
:type model_name_or_path: str | None, optional
:param config: CrossEncoderConfig to apply when loading from backbone model, defaults to None
:type config: CrossEncoderConfig | None, optional
:param model: Already instantiated CrossEncoderModel, defaults to None
:type model: CrossEncoderModel | None, optional
:param loss_functions: Loss functions to apply during fine-tuning, optional loss weights can be provided per
loss function, defaults to None
:type loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None, optional
:param evaluation_metrics: Metrics corresponding to ir-measures_ measure strings to apply during validation or
testing, defaults to None
"""
super().__init__(model_name_or_path, config, model, loss_functions, evaluation_metrics)
self.model: CrossEncoderModel
self.config: CrossEncoderConfig
self.tokenizer: CrossEncoderTokenizer

def forward(self, batch: RankBatch | TrainBatch | SearchBatch) -> CrossEncoderOutput:
"""Runs a forward pass of the model on a batch of data and returns the contextualized embeddings from the
backbone model as well as the relevance scores.
:param batch: Batch of data to run the forward pass on
:type batch: RankBatch | TrainBatch | SearchBatch
:raises ValueError: If the batch is a SearchBatch
:return: Output of the model
:rtype: CrossEncoderOutput
"""
if isinstance(batch, SearchBatch):
raise NotImplementedError("Searching is not available for cross-encoders")
raise ValueError("Searching is not available for cross-encoders")
queries = batch.queries
docs = [d for docs in batch.docs for d in docs]
num_docs = [len(docs) for docs in batch.docs]
@@ -35,6 +65,7 @@ def forward(self, batch: RankBatch | TrainBatch | SearchBatch) -> CrossEncoderOutput:
return output

def _compute_losses(self, batch: TrainBatch, output: CrossEncoderOutput) -> List[torch.Tensor]:
"""Computes the losses for a training batch."""
if self.loss_functions is None:
raise ValueError("loss_functions must be set in the module")
output = self.forward(batch)
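To make the batch flattening in forward concrete, here is a small stand-alone sketch with hypothetical data (mirroring the list comprehensions above, not the module's actual code path):

queries = ["what is information retrieval", "bm25"]
docs = [["doc a", "doc b"], ["doc c"]]

flat_docs = [d for per_query in docs for d in per_query]  # ["doc a", "doc b", "doc c"]
num_docs = [len(per_query) for per_query in docs]         # [2, 1]
assert sum(num_docs) == len(flat_docs)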
64 changes: 51 additions & 13 deletions lightning_ir/cross_encoder/tokenizer.py
@@ -1,3 +1,9 @@
"""
Tokenizer module for cross-encoder models.
This module contains the tokenizer class for cross-encoder models."""
"""

from typing import Dict, List, Sequence, Tuple, Type

from transformers import BatchEncoding
@@ -9,11 +15,21 @@
class CrossEncoderTokenizer(LightningIRTokenizer):

config_class: Type[CrossEncoderConfig] = CrossEncoderConfig
"""Configuration class for the tokenizer."""

def __init__(self, *args, query_length: int = 32, doc_length: int = 512, **kwargs):
""":class:`.LightningIRTokenizer` for cross-encoder models. Encodes queries and documents jointly and ensures
that the input sequences are of the correct length.
:param query_length: Maximum query length, defaults to 32
:type query_length: int, optional
:param doc_length: Maximum document length, defaults to 512
:type doc_length: int, optional
"""
super().__init__(*args, query_length=query_length, doc_length=doc_length, **kwargs)

def truncate(self, text: Sequence[str], max_length: int) -> List[str]:
def _truncate(self, text: Sequence[str], max_length: int) -> List[str]:
"""Encodes a list of texts, truncates them to a maximum number of tokens and decodes them to strings."""
return self.batch_decode(
self(
text,
@@ -25,15 +41,17 @@ def truncate(self, text: Sequence[str], max_length: int) -> List[str]:
).input_ids
)

def expand_queries(self, queries: Sequence[str], num_docs: Sequence[int]) -> List[str]:
def _repeat_queries(self, queries: Sequence[str], num_docs: Sequence[int]) -> List[str]:
"""Repeats queries to match the number of documents."""
return [query for query_idx, query in enumerate(queries) for _ in range(num_docs[query_idx])]

def preprocess(
def _preprocess(
self,
queries: str | Sequence[str] | None,
docs: str | Sequence[str] | None,
num_docs: Sequence[int] | None,
num_docs: Sequence[int] | int | None,
) -> Tuple[str | Sequence[str], str | Sequence[str]]:
"""Preprocesses queries and documents to ensure that they are truncated their respective maximum lengths."""
if queries is None or docs is None:
raise ValueError("Both queries and docs must be provided.")
queries_is_string = isinstance(queries, str)
@@ -43,27 +61,47 @@ def preprocess(
if queries_is_string and docs_is_string:
queries = [queries]
docs = [docs]
truncated_queries = self.truncate(queries, self.query_length)
truncated_docs = self.truncate(docs, self.doc_length)
truncated_queries = self._truncate(queries, self.query_length)
truncated_docs = self._truncate(docs, self.doc_length)
if not queries_is_string:
if num_docs is None:
num_docs = [len(docs) // len(queries) for _ in range(len(queries))]
expanded_queries = self.expand_queries(truncated_queries, num_docs)
if isinstance(num_docs, int):
num_docs = [num_docs] * len(queries)
else:
if len(docs) % len(queries) != 0:
raise ValueError("Number of documents must be divisible by the number of queries.")
num_docs = [len(docs) // len(queries) for _ in range(len(queries))]
repeated_queries = self._repeat_queries(truncated_queries, num_docs)
docs = truncated_docs
else:
expanded_queries = truncated_queries[0]
repeated_queries = truncated_queries[0]
docs = truncated_docs[0]
return expanded_queries, docs
return repeated_queries, docs

def tokenize(
self,
queries: str | Sequence[str] | None = None,
docs: str | Sequence[str] | None = None,
num_docs: Sequence[int] | None = None,
num_docs: Sequence[int] | int | None = None,
**kwargs,
) -> Dict[str, BatchEncoding]:
expanded_queries, docs = self.preprocess(queries, docs, num_docs)
"""Tokenizes queries and documents into a single sequence of tokens.
:param queries: Queries to tokenize, defaults to None
:type queries: str | Sequence[str] | None, optional
:param docs: Documents to tokenize, defaults to None
:type docs: str | Sequence[str] | None, optional
:param num_docs: Specifies how many documents are passed per query. If a sequence of integers, `len(num_docs)`
should be equal to the number of queries and `sum(num_docs)` equal to the number of documents, i.e., the
sequence contains one value per query specifying the number of documents for that query. If an integer,
assumes an equal number of documents per query. If None, tries to infer the number of documents per query
by dividing the total number of documents by the number of queries, defaults to None
:type num_docs: Sequence[int] | int | None, optional
:return: Tokenized query-document sequence
:rtype: Dict[str, BatchEncoding]
"""
repeated_queries, docs = self._preprocess(queries, docs, num_docs)
return_tensors = kwargs.get("return_tensors", None)
if return_tensors is not None:
kwargs["pad_to_multiple_of"] = 8
return {"encoding": self(expanded_queries, docs, **kwargs)}
return {"encoding": self(repeated_queries, docs, **kwargs)}
