diff --git a/lightning_ir/base/module.py b/lightning_ir/base/module.py
index 2ff542d..801ab44 100644
--- a/lightning_ir/base/module.py
+++ b/lightning_ir/base/module.py
@@ -269,7 +269,7 @@ def validate(
         doc_ids: Sequence[Sequence[str]] | None = None,
         qrels: Sequence[Dict[str, int]] | None = None,
         targets: torch.Tensor | None = None,
-        num_docs: Sequence[int] | None = None,
+        num_docs: Sequence[int] | int | None = None,
     ) -> Dict[str, float]:
         """Validates the model output with the evaluation metrics and loss functions.

@@ -283,8 +283,11 @@
         :type qrels: Sequence[Dict[str, int]] | None, optional
         :param targets: Target tensor used during fine-tuning, defaults to None
         :type targets: torch.Tensor | None, optional
-        :param num_docs: Number of documents per query, defaults to None
-        :type num_docs: Sequence[int] | None, optional
+        :param num_docs: Specifies how many documents are passed per query. If a sequence of integers, `len(num_docs)`
+            should be equal to the number of queries and `sum(num_docs)` equal to the number of documents, i.e., the
+            sequence contains one value per query specifying the number of documents for that query. If an integer,
+            assumes an equal number of documents per query. If None, the number of documents per query is inferred by
+            dividing the total number of documents by the number of queries, defaults to None
         :raises ValueError: If num_docs can not be parsed and query_ids are not set
         :raises ValueError: If num_docs can not be parsed and doc_ids are not set
         :return: _description_
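# --- Illustrative sketch (not part of the diff): the three forms ``num_docs`` now accepts in
# ``validate``. ``module``, ``output``, and ``qrels`` are hypothetical, pre-built objects, and
# passing ``output`` as the first positional argument is an assumption; only the ``num_docs``
# semantics come from the docstring above.
query_ids = ["q1", "q2"]
doc_ids = [["d1", "d2", "d3"], ["d4", "d5", "d6"]]

module.validate(output, query_ids=query_ids, doc_ids=doc_ids, qrels=qrels, num_docs=[3, 3])  # one count per query
module.validate(output, query_ids=query_ids, doc_ids=doc_ids, qrels=qrels, num_docs=3)       # same count for every query
module.validate(output, query_ids=query_ids, doc_ids=doc_ids, qrels=qrels)                   # None: inferred as 6 docs / 2 queries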
diff --git a/lightning_ir/bi_encoder/config.py b/lightning_ir/bi_encoder/config.py
index d3007d9..3259972 100644
--- a/lightning_ir/bi_encoder/config.py
+++ b/lightning_ir/bi_encoder/config.py
@@ -13,9 +13,8 @@ class BiEncoderConfig(LightningIRConfig):
-    """The configuration class to instantiate a Bi-Encoder model."""
-
-    model_type = "bi-encoder"
+    model_type: str = "bi-encoder"
+    """Model type for bi-encoder models."""

     TOKENIZER_ARGS = LightningIRConfig.TOKENIZER_ARGS.union(
         {
@@ -26,6 +25,7 @@ class BiEncoderConfig(LightningIRConfig):
             "add_marker_tokens",
         }
     )
+    """Arguments for the tokenizer."""

     ADDED_ARGS = LightningIRConfig.ADDED_ARGS.union(
         {
@@ -41,9 +41,12 @@
             "projection",
         }
     ).union(TOKENIZER_ARGS)
+    """Arguments added to the configuration."""

     def __init__(
         self,
+        query_length: int = 32,
+        doc_length: int = 512,
         similarity_function: Literal["cosine", "dot"] = "dot",
         query_expansion: bool = False,
         attend_to_query_expanded_tokens: bool = False,
@@ -61,8 +64,12 @@
         projection: Literal["linear", "linear_no_bias", "mlm"] | None = "linear",
         **kwargs,
     ):
-        """Initializes a bi-encoder configuration.
+        """Configuration class for a bi-encoder model.

+        :param query_length: Maximum query length, defaults to 32
+        :type query_length: int, optional
+        :param doc_length: Maximum document length, defaults to 512
+        :type doc_length: int, optional
         :param similarity_function: Similarity function to compute scores between query and document embeddings,
             defaults to "dot"
         :type similarity_function: Literal['cosine', 'dot'], optional
@@ -97,7 +104,7 @@ def __init__(
         :param projection: Whether and how to project the output emeddings, defaults to "linear"
         :type projection: Literal['linear', 'linear_no_bias', 'mlm'] | None, optional
         """
-        super().__init__(**kwargs)
+        super().__init__(query_length=query_length, doc_length=doc_length, **kwargs)
         self.similarity_function = similarity_function
         self.query_expansion = query_expansion
         self.attend_to_query_expanded_tokens = attend_to_query_expanded_tokens
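# --- Illustrative sketch (not part of the diff): with this change the maximum lengths are regular
# constructor arguments of BiEncoderConfig and are forwarded to LightningIRConfig. The import path
# follows the file shown in this diff; the keyword values are the defaults from the new signature.
from lightning_ir.bi_encoder.config import BiEncoderConfig

config = BiEncoderConfig(
    query_length=32,            # maximum number of query tokens
    doc_length=512,             # maximum number of document tokens
    similarity_function="dot",  # "dot" or "cosine"
    query_expansion=False,
)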
+""" + from typing import Literal from ..base import LightningIRConfig class CrossEncoderConfig(LightningIRConfig): - model_type = "cross-encoder" + model_type: str = "cross-encoder" + """Model type for cross-encoder models.""" ADDED_ARGS = LightningIRConfig.ADDED_ARGS.union({"pooling_strategy", "linear_bias"}) + """Arguments added to the configuration.""" def __init__( self, @@ -16,6 +24,18 @@ def __init__( linear_bias: bool = False, **kwargs ): + """Configuration class for a cross-encoder model + + :param query_length: Maximum query length, defaults to 32 + :type query_length: int, optional + :param doc_length: Maximum document length, defaults to 512 + :type doc_length: int, optional + :param pooling_strategy: Pooling strategy to aggregate the contextualized embeddings into a single vector for + computing a relevance score, defaults to "first" + :type pooling_strategy: Literal['first', 'mean', 'max', 'sum'], optional + :param linear_bias: Whether to use a bias in the prediction linear layer, defaults to False + :type linear_bias: bool, optional + """ super().__init__(query_length=query_length, doc_length=doc_length, **kwargs) self.pooling_strategy = pooling_strategy self.linear_bias = linear_bias diff --git a/lightning_ir/cross_encoder/model.py b/lightning_ir/cross_encoder/model.py index 54fa1aa..e3a21e0 100644 --- a/lightning_ir/cross_encoder/model.py +++ b/lightning_ir/cross_encoder/model.py @@ -1,3 +1,9 @@ +""" +Model module for cross-encoder models. + +This module defines the model class used to implement cross-encoder models. +""" + from dataclasses import dataclass from typing import Type @@ -11,19 +17,37 @@ @dataclass class CrossEncoderOutput(LightningIROutput): + """Dataclass containing the output of a cross-encoder model""" + embeddings: torch.Tensor | None = None + """Joint query-document embeddings""" class CrossEncoderModel(LightningIRModel): config_class: Type[CrossEncoderConfig] = CrossEncoderConfig + """Configuration class for cross-encoder models.""" def __init__(self, config: CrossEncoderConfig, *args, **kwargs): + """A cross-encoder model that jointly encodes a query and document(s). The contextualized embeddings are + aggragated into a single vector and fed to a linear layer which computes a final relevance score. + + :param config: Configuration for the cross-encoder model + :type config: CrossEncoderConfig + """ super().__init__(config, *args, **kwargs) self.config: CrossEncoderConfig self.linear = torch.nn.Linear(config.hidden_size, 1, bias=config.linear_bias) @batch_encoding_wrapper def forward(self, encoding: BatchEncoding) -> CrossEncoderOutput: + """Computes contextualized embeddings for the joint query-document input sequence and computes a relevance + score. + + :param encoding: Tokenizer encoding for the joint query-document input sequence + :type encoding: BatchEncoding + :return: Output of the model + :rtype: CrossEncoderOutput + """ embeddings = self._backbone_forward(**encoding).last_hidden_state embeddings = self._pooling( embeddings, encoding.get("attention_mask", None), pooling_strategy=self.config.pooling_strategy diff --git a/lightning_ir/cross_encoder/module.py b/lightning_ir/cross_encoder/module.py index 04488b2..7d40ab3 100644 --- a/lightning_ir/cross_encoder/module.py +++ b/lightning_ir/cross_encoder/module.py @@ -1,3 +1,9 @@ +""" +Module module for cross-encoder models. + +This module defines the Lightning IR module class used to implement cross-encoder models. 
+""" + from typing import List, Sequence, Tuple import torch @@ -19,14 +25,38 @@ def __init__( loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None = None, evaluation_metrics: Sequence[str] | None = None, ): + """:class:`.LightningIRModule` for cross-encoder models. It contains a :class:`.CrossEncoderModel` and a + :class:`.CrossEncoderTokenizer` and implements the training, validation, and testing steps for the model. + + :param model_name_or_path: Name or path of backbone model or fine-tuned Lightning IR model, defaults to None + :type model_name_or_path: str | None, optional + :param config: CrossEncoderConfig to apply when loading from backbone model, defaults to None + :type config: CrossEncoderConfig | None, optional + :param model: Already instantiated CrossEncoderModel, defaults to None + :type model: CrossEncoderModel | None, optional + :param loss_functions: Loss functions to apply during fine-tuning, optional loss weights can be provided per + loss function, defaults to None + :type loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None, optional + :param evaluation_metrics: Metrics corresponding to ir-measures_ measure strings to apply during validation or + testing, defaults to None + """ super().__init__(model_name_or_path, config, model, loss_functions, evaluation_metrics) self.model: CrossEncoderModel self.config: CrossEncoderConfig self.tokenizer: CrossEncoderTokenizer def forward(self, batch: RankBatch | TrainBatch | SearchBatch) -> CrossEncoderOutput: + """Runs a forward pass of the model on a batch of data and returns the contextualized embeddings from the + backbone model as well as the relevance scores. + + :param batch: Batch of data to run the forward pass on + :type batch: RankBatch | TrainBatch | SearchBatch + :raises ValueError: If the batch is a SearchBatch + :return: Output of the model + :rtype: CrossEncoderOutput + """ if isinstance(batch, SearchBatch): - raise NotImplementedError("Searching is not available for cross-encoders") + raise ValueError("Searching is not available for cross-encoders") queries = batch.queries docs = [d for docs in batch.docs for d in docs] num_docs = [len(docs) for docs in batch.docs] @@ -35,6 +65,7 @@ def forward(self, batch: RankBatch | TrainBatch | SearchBatch) -> CrossEncoderOu return output def _compute_losses(self, batch: TrainBatch, output: CrossEncoderOutput) -> List[torch.Tensor]: + """Computes the losses for a training batch.""" if self.loss_functions is None: raise ValueError("loss_functions must be set in the module") output = self.forward(batch) diff --git a/lightning_ir/cross_encoder/tokenizer.py b/lightning_ir/cross_encoder/tokenizer.py index 991feed..c4b1500 100644 --- a/lightning_ir/cross_encoder/tokenizer.py +++ b/lightning_ir/cross_encoder/tokenizer.py @@ -1,3 +1,9 @@ +""" +Tokenizer module for cross-encoder models. + +This module contains the tokenizer class cross-encoder models. +""" + from typing import Dict, List, Sequence, Tuple, Type from transformers import BatchEncoding @@ -9,11 +15,21 @@ class CrossEncoderTokenizer(LightningIRTokenizer): config_class: Type[CrossEncoderConfig] = CrossEncoderConfig + """Configuration class for the tokenizer.""" def __init__(self, *args, query_length: int = 32, doc_length: int = 512, **kwargs): + """:class:`.LightningIRTokenizer` for cross-encoder models. Encodes queries and documents jointly and ensures + that the input sequences are of the correct length. 
diff --git a/lightning_ir/cross_encoder/tokenizer.py b/lightning_ir/cross_encoder/tokenizer.py
index 991feed..c4b1500 100644
--- a/lightning_ir/cross_encoder/tokenizer.py
+++ b/lightning_ir/cross_encoder/tokenizer.py
@@ -1,3 +1,9 @@
+"""
+Tokenizer module for cross-encoder models.
+
+This module contains the tokenizer class for cross-encoder models.
+"""
+
 from typing import Dict, List, Sequence, Tuple, Type

 from transformers import BatchEncoding
@@ -9,11 +15,21 @@
 class CrossEncoderTokenizer(LightningIRTokenizer):

     config_class: Type[CrossEncoderConfig] = CrossEncoderConfig
+    """Configuration class for the tokenizer."""

     def __init__(self, *args, query_length: int = 32, doc_length: int = 512, **kwargs):
+        """:class:`.LightningIRTokenizer` for cross-encoder models. Encodes queries and documents jointly and ensures
+        that the input sequences are of the correct length.
+
+        :param query_length: Maximum query length, defaults to 32
+        :type query_length: int, optional
+        :param doc_length: Maximum document length, defaults to 512
+        :type doc_length: int, optional
+        """
         super().__init__(*args, query_length=query_length, doc_length=doc_length, **kwargs)

-    def truncate(self, text: Sequence[str], max_length: int) -> List[str]:
+    def _truncate(self, text: Sequence[str], max_length: int) -> List[str]:
+        """Encodes a list of texts, truncates them to a maximum number of tokens and decodes them to strings."""
         return self.batch_decode(
             self(
                 text,
@@ -25,15 +41,17 @@ def __init__(self, *args, query_length: int = 32, doc_length: int = 512, **kwarg
             ).input_ids
         )

-    def expand_queries(self, queries: Sequence[str], num_docs: Sequence[int]) -> List[str]:
+    def _repeat_queries(self, queries: Sequence[str], num_docs: Sequence[int]) -> List[str]:
+        """Repeats queries to match the number of documents."""
         return [query for query_idx, query in enumerate(queries) for _ in range(num_docs[query_idx])]

-    def preprocess(
+    def _preprocess(
         self,
         queries: str | Sequence[str] | None,
         docs: str | Sequence[str] | None,
-        num_docs: Sequence[int] | None,
+        num_docs: Sequence[int] | int | None,
     ) -> Tuple[str | Sequence[str], str | Sequence[str]]:
+        """Preprocesses queries and documents to ensure that they are truncated to their respective maximum lengths."""
         if queries is None or docs is None:
             raise ValueError("Both queries and docs must be provided.")
         queries_is_string = isinstance(queries, str)
@@ -43,27 +61,47 @@ def preprocess(
         if queries_is_string and docs_is_string:
             queries = [queries]
             docs = [docs]
-        truncated_queries = self.truncate(queries, self.query_length)
-        truncated_docs = self.truncate(docs, self.doc_length)
+        truncated_queries = self._truncate(queries, self.query_length)
+        truncated_docs = self._truncate(docs, self.doc_length)
         if not queries_is_string:
             if num_docs is None:
-                num_docs = [len(docs) // len(queries) for _ in range(len(queries))]
-            expanded_queries = self.expand_queries(truncated_queries, num_docs)
+                if len(docs) % len(queries) != 0:
+                    raise ValueError("Number of documents must be divisible by the number of queries.")
+                num_docs = [len(docs) // len(queries) for _ in range(len(queries))]
+            elif isinstance(num_docs, int):
+                # a single integer means the same number of documents for every query
+                num_docs = [num_docs] * len(queries)
+            repeated_queries = self._repeat_queries(truncated_queries, num_docs)
             docs = truncated_docs
         else:
-            expanded_queries = truncated_queries[0]
+            repeated_queries = truncated_queries[0]
             docs = truncated_docs[0]
-        return expanded_queries, docs
+        return repeated_queries, docs

     def tokenize(
         self,
         queries: str | Sequence[str] | None = None,
         docs: str | Sequence[str] | None = None,
-        num_docs: Sequence[int] | None = None,
+        num_docs: Sequence[int] | int | None = None,
         **kwargs,
     ) -> Dict[str, BatchEncoding]:
-        expanded_queries, docs = self.preprocess(queries, docs, num_docs)
+        """Tokenizes queries and documents into a single sequence of tokens.
+
+        :param queries: Queries to tokenize, defaults to None
+        :type queries: str | Sequence[str] | None, optional
+        :param docs: Documents to tokenize, defaults to None
+        :type docs: str | Sequence[str] | None, optional
+        :param num_docs: Specifies how many documents are passed per query. If a sequence of integers, `len(num_docs)`
+            should be equal to the number of queries and `sum(num_docs)` equal to the number of documents, i.e., the
+            sequence contains one value per query specifying the number of documents for that query. If an integer,
+            assumes an equal number of documents per query. If None, the number of documents per query is inferred by
+            dividing the total number of documents by the number of queries, defaults to None
+        :type num_docs: Sequence[int] | int | None, optional
+        :return: Tokenized query-document sequence
+        :rtype: Dict[str, BatchEncoding]
+        """
+        repeated_queries, docs = self._preprocess(queries, docs, num_docs)
         return_tensors = kwargs.get("return_tensors", None)
         if return_tensors is not None:
             kwargs["pad_to_multiple_of"] = 8
-        return {"encoding": self(expanded_queries, docs, **kwargs)}
+        return {"encoding": self(repeated_queries, docs, **kwargs)}
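# --- Illustrative sketch (not part of the diff): the three ``num_docs`` forms accepted by
# ``CrossEncoderTokenizer.tokenize``. ``tokenizer`` is a hypothetical instance and the queries and
# documents are dummy strings; the argument semantics follow the docstring above.
queries = ["first query", "second query"]
docs = ["doc a", "doc b", "doc c", "doc d", "doc e", "doc f"]

tokenizer.tokenize(queries, docs, num_docs=[2, 4])  # explicit per-query counts, sum(num_docs) == len(docs)
tokenizer.tokenize(queries, docs, num_docs=3)       # a single int: every query has the same number of documents
tokenizer.tokenize(queries, docs)                   # None: inferred as len(docs) // len(queries); raises a
                                                    # ValueError if len(docs) is not divisible by len(queries)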