
Commit

update doc strings
fschlatt committed Aug 7, 2024
1 parent e17fcf2 commit f83ca07
Showing 5 changed files with 265 additions and 35 deletions.
15 changes: 14 additions & 1 deletion lightning_ir/base/config.py
@@ -51,7 +51,9 @@ def to_tokenizer_dict(self) -> Dict[str, Any]:
return {arg: getattr(self, arg) for arg in self.TOKENIZER_ARGS}

def to_dict(self) -> Dict[str, Any]:
"""Overrides the `to_dict` method to include the added arguments and the backbone model type.
"""Overrides the transformers.PretrainedConfig.to_dict_ method to include the added arguments and the backbone model type.
.. _transformers.PretrainedConfig.to_dict: https://huggingface.co/docs/transformers/main_classes/configuration#transformers.PretrainedConfig.to_dict
:return: Configuration dictionary
:rtype: Dict[str, Any]
@@ -64,6 +66,17 @@ def to_dict(self) -> Dict[str, Any]:

@classmethod
def from_dict(cls, config_dict: Dict[str, Any], *args, **kwargs) -> "LightningIRConfig":
"""Loads the configuration from a dictionary. Wraps the transformers.PretrainedConfig.from_dict_ method to
return a derived LightningIRConfig class. See :class:`.LightningIRConfigClassFactory` for more details.
.. _transformers.PretrainedConfig.from_dict: https://huggingface.co/docs/transformers/main_classes/configuration#transformers.PretrainedConfig.from_dict
:param config_dict: Configuration dictionary
:type config_dict: Dict[str, Any]
:raises ValueError: If the model type does not match the configuration model type
:return: Derived LightningIRConfig class
:rtype: LightningIRConfig
"""
if all(issubclass(base, LightningIRConfig) for base in cls.__bases__) or cls is LightningIRConfig:
if "backbone_model_type" in config_dict:
backbone_model_type = config_dict["backbone_model_type"]
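For context, a minimal sketch of the round trip these two methods document; the top-level import path for CrossEncoderConfig is an assumption and may differ in the actual package:

from lightning_ir import CrossEncoderConfig  # assumed import path

config = CrossEncoderConfig()
config_dict = config.to_dict()  # includes the added arguments and the backbone model type

# from_dict dispatches on "backbone_model_type" (when present) and returns a derived config
restored = CrossEncoderConfig.from_dict(config_dict)
print(type(restored).__name__)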
22 changes: 13 additions & 9 deletions lightning_ir/base/model.py
@@ -25,7 +25,7 @@ class LightningIROutput(ModelOutput):


class LightningIRModel:
"""Base class for the LightningIR models. Derived classes implement the forward functionality for handling query
"""Base class for LightningIR models. Derived classes implement the forward method for handling query
and document embeddings. It acts as a mixin for a transformers.PreTrainedModel_ backbone model.
.. _transformers.PreTrainedModel: https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel
@@ -122,24 +122,28 @@ def _pooling(

@classmethod
def from_pretrained(cls, model_name_or_path: str, *args, **kwargs) -> "LightningIRModel":
"""Loads a pretrained model. Wraps the transformers.PreTrainedModel.from_pretrained_ method and returns a
derived LightningIRModel. See :func:`LightningIRModelClassFactory` for more details.
"""Loads a pretrained model. Wraps the transformers.PreTrainedModel.from_pretrained_ method and to return a
derived LightningIRModel. See :class:`LightningIRModelClassFactory` for more details.
.. _transformers.PreTrainedModel.from_pretrained: https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained
.. highlight:: python
.. code-block:: python
>>> # Loading using model class and backbone checkpoint
>>> type(CrossEncoderModel.from_pretrained("bert-base-uncased"))
...
<class 'lightning_ir.base.model.CrossEncoderBertModel'>
>>> type(ColModel.from_pretrained("bert-base-uncased"))
<class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>
>>> # Loading using base class and backbone checkpoint
>>> type(LightningIRModel.from_pretrained("bert-base-uncased", config=CrossEncoderConfig()))
...
<class 'lightning_ir.base.model.ColBertModel'>
<class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>
:raises ValueError: If called on the abstract class :class:`LightningIRModel`.
:raises ValueError: If the backbone model is not found.
:return: A derived LightningIRModel consisting of a backbone model and a LightningIRModel mixin.
:param model_name_or_path: Name or path of the pretrained model
:type model_name_or_path: str
:raises ValueError: If called on the abstract class :class:`LightningIRModel` and no config is passed
:return: A derived LightningIRModel consisting of a backbone model and a LightningIRModel mixin
:rtype: LightningIRModel
"""
# provides AutoModel.from_pretrained support
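A short usage sketch mirroring the docstring example above; the top-level imports are assumptions about the package layout:

from lightning_ir import CrossEncoderConfig, LightningIRModel  # assumed import paths

# Loading a backbone checkpoint through the base class requires a config that selects
# the model type; the class factory then derives the concrete backbone-specific class.
model = LightningIRModel.from_pretrained("bert-base-uncased", config=CrossEncoderConfig())
print(type(model))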
184 changes: 160 additions & 24 deletions lightning_ir/base/module.py
@@ -19,6 +19,11 @@


class LightningIRModule(LightningModule):
"""LightningIRModule base class. LightningIRModules contain a LightningIRModel and a LightningIRTokenizer and
implements the training, validation, and testing steps for the model. Derived classes must implement the forward
method for the model.
"""

def __init__(
self,
model_name_or_path: str | None = None,
@@ -27,6 +32,25 @@ def __init__(
loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None = None,
evaluation_metrics: Sequence[str] | None = None,
):
"""Initializes the LightningIRModule.
.. _ir-measures: https://ir-measur.es/en/latest/index.html
:param model_name_or_path: Name or path of backbone model or fine-tuned LightningIR model, defaults to None
:type model_name_or_path: str | None, optional
:param config: LightningIRConfig to apply when loading from backbone model, defaults to None
:type config: LightningIRConfig | None, optional
:param model: Already instantiated LightningIR model, defaults to None
:type model: LightningIRModel | None, optional
:param loss_functions: Loss functions to apply during fine-tuning, optional loss weights can be provided per
loss function, defaults to None
:type loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None, optional
:param evaluation_metrics: Metrics corresponding to ir-measures_ measure strings to apply during validation or
testing, defaults to None
:type evaluation_metrics: Sequence[str] | None, optional
:raises ValueError: If both model and model_name_or_path are provided
:raises ValueError: If neither model nor model_name_or_path are provided
"""
super().__init__()
if model is not None and model_name_or_path is not None:
raise ValueError("Only one of model or model_name_or_path must be provided.")
@@ -49,36 +73,73 @@ def __init__(
self.tokenizer = LightningIRTokenizer.from_pretrained(self.config.name_or_path, config=config)

def on_fit_start(self) -> None:
"""Called at the very beginning of fit.
If on DDP it is called on every process
"""
# NOTE huggingface models are in eval mode by default
self.train()
return super().on_fit_start()

def forward(self, batch: TrainBatch | RankBatch) -> LightningIROutput:
"""Handles the forward pass of the model.
:param batch: Batch of training or ranking data
:type batch: TrainBatch | RankBatch
:raises NotImplementedError: Must be implemented by derived class
:return: Model output
:rtype: LightningIROutput
"""
raise NotImplementedError

def prepare_input(
self,
queries: Sequence[str] | None,
docs: Sequence[str] | None,
num_docs: Sequence[int] | int | None,
self, queries: Sequence[str] | None, docs: Sequence[str] | None, num_docs: Sequence[int] | int | None
) -> Dict[str, BatchEncoding]:
"""Tokenizes queries and documents and returns the tokenized BatchEncoding_.
.. _BatchEncoding: https://huggingface.co/transformers/main_classes/tokenizer#transformers.BatchEncoding
:param queries: Queries to tokenize
:type queries: Sequence[str] | None
:param docs: Documents to tokenize
:type docs: Sequence[str] | None
:param num_docs: Number of documents per query; if None, num_docs is inferred as `len(docs) // len(queries)`,
defaults to None
:type num_docs: Sequence[int] | int | None
:return: Tokenized queries and documents, format depends on the tokenizer
:rtype: Dict[str, BatchEncoding]
"""
encodings = self.tokenizer.tokenize(
queries,
docs,
return_tensors="pt",
padding=True,
truncation=True,
num_docs=num_docs,
queries, docs, return_tensors="pt", padding=True, truncation=True, num_docs=num_docs
)
for key in encodings:
encodings[key] = encodings[key].to(self.device)
return encodings

def compute_losses(self, batch: TrainBatch) -> List[torch.Tensor]:
"""Computes the losses for the batch.
:param batch: Batch of training data
:type batch: TrainBatch
:raises NotImplementedError: Must be implemented by derived class
:return: List of losses, one for each loss function
:rtype: List[torch.Tensor]
"""
raise NotImplementedError

def training_step(self, batch: TrainBatch, batch_idx: int) -> torch.Tensor:
"""Handles the training step for the model.
:param batch: Batch of training data
:type batch: TrainBatch
:param batch_idx: Index of the batch
:type batch_idx: int
:raises ValueError: If no loss functions are set
:return: Sum of the losses weighted by the loss weights
:rtype: torch.Tensor
"""
if self.loss_functions is None:
raise ValueError("Loss function is not set")
raise ValueError("Loss functions are not set")
losses = self.compute_losses(batch)
total_loss = torch.tensor(0)
assert len(losses) == len(self.loss_functions)
Expand All @@ -89,11 +150,19 @@ def training_step(self, batch: TrainBatch, batch_idx: int) -> torch.Tensor:
return total_loss

def validation_step(
self,
batch: TrainBatch | RankBatch,
batch_idx: int,
dataloader_idx: int = 0,
self, batch: TrainBatch | RankBatch, batch_idx: int, dataloader_idx: int = 0
) -> LightningIROutput:
"""Handles the validation step for the model.
:param batch: Batch of validation or testing data
:type batch: TrainBatch | RankBatch
:param batch_idx: Index of the batch
:type batch_idx: int
:param dataloader_idx: Index of the dataloader, defaults to 0
:type dataloader_idx: int, optional
:return: Model output
:rtype: LightningIROutput
"""
output = self.forward(batch)

if self.evaluation_metrics is None:
@@ -118,9 +187,29 @@ def test_step(
batch_idx: int,
dataloader_idx: int = 0,
) -> LightningIROutput:
"""Handles the testing step for the model. Passes the batch to the validation step.
:param batch: Batch of testing data
:type batch: TrainBatch | RankBatch
:param batch_idx: Index of the batch
:type batch_idx: int
:param dataloader_idx: Index of the dataloader, defaults to 0
:type dataloader_idx: int, optional
:return: Model output
:rtype: LightningIROutput
"""
return self.validation_step(batch, batch_idx, dataloader_idx)

def get_dataset_id(self, dataloader_idx: int) -> str:
"""Gets the dataset id from the dataloader index for logging.
.. _ir-datasets: https://ir-datasets.com/
:param dataloader_idx: Index of the dataloader
:type dataloader_idx: int
:return: ir-datasets_ dataset id or dataloader index
:rtype: str
"""
dataset_id = str(dataloader_idx)
datamodule = None
try:
@@ -139,7 +228,26 @@ def validate(
targets: torch.Tensor | None = None,
num_docs: Sequence[int] | None = None,
) -> Dict[str, float]:
metrics = {}
"""Validates the model output with the evaluation metrics and loss functions.
:param scores: Model output scores, defaults to None
:type scores: torch.Tensor | None, optional
:param query_ids: Ids of the queries, defaults to None
:type query_ids: Sequence[str] | None, optional
:param doc_ids: Ids of the documents, defaults to None
:type doc_ids: Sequence[Sequence[str]] | None, optional
:param qrels: Mappings of doc_id -> relevance for each query, defaults to None
:type qrels: Sequence[Dict[str, int]] | None, optional
:param targets: Target tensor used during fine-tuning, defaults to None
:type targets: torch.Tensor | None, optional
:param num_docs: Number of documents per query, defaults to None
:type num_docs: Sequence[int] | None, optional
:raises ValueError: If num_docs cannot be parsed and query_ids are not set
:raises ValueError: If num_docs cannot be parsed and doc_ids are not set
:return: Dictionary of evaluation and loss metrics
:rtype: Dict[str, float]
"""
metrics: Dict[str, float] = {}
if self.evaluation_metrics is None or scores is None:
return metrics
if query_ids is None:
@@ -151,7 +259,7 @@ def validate_metrics(
raise ValueError("num_docs must be set if doc_ids is not set")
doc_ids = tuple(tuple(f"{i}-{j}" for j in range(docs)) for i, docs in enumerate(num_docs))
metrics.update(self.validate_metrics(scores, query_ids, doc_ids, qrels))
metrics.update(self.validate_loss(scores, query_ids, doc_ids, targets))
metrics.update(self.validate_loss(scores, query_ids, targets))
return metrics

def validate_metrics(
@@ -161,7 +269,20 @@ def validate_metrics(
doc_ids: Sequence[Sequence[str]],
qrels: Sequence[Dict[str, int]] | None,
) -> Dict[str, float]:
metrics = {}
"""Validates the model output with the evaluation metrics.
:param scores: Model output scores
:type scores: torch.Tensor
:param query_ids: Ids of the queries
:type query_ids: Sequence[str]
:param doc_ids: Ids of the documents
:type doc_ids: Sequence[Sequence[str]]
:param qrels: Mappings of doc_id -> relevance for each query, defaults to None
:type qrels: Sequence[Dict[str, int]] | None
:return: Evaluation metrics
:rtype: Dict[str, float]
"""
metrics: Dict[str, float] = {}
if self.evaluation_metrics is None or qrels is None:
return metrics
evaluation_metrics = [metric for metric in self.evaluation_metrics if metric != "loss"]
@@ -172,13 +293,20 @@ def validate_loss(
return metrics

def validate_loss(
self,
scores: torch.Tensor,
query_ids: Sequence[str],
doc_ids: Sequence[Sequence[str]],
targets: torch.Tensor | None,
self, scores: torch.Tensor, query_ids: Sequence[str], targets: torch.Tensor | None
) -> Dict[str, float]:
metrics = {}
"""Validates the model output with the loss functions.
:param scores: Model output scores
:type scores: torch.Tensor
:param query_ids: Ids of the queries
:type query_ids: Sequence[str]
:param targets: Target tensor used during fine-tuning
:type targets: torch.Tensor | None
:return: Loss metrics
:rtype: Dict[str, float]
"""
metrics: Dict[str, float] = {}
if (
self.evaluation_metrics is None
or "loss" not in self.evaluation_metrics
@@ -197,6 +325,7 @@ def validate_loss(
return metrics

def on_validation_epoch_end(self) -> None:
"""Logs the accumulated metrics for each dataloader."""
try:
trainer = self.trainer
except RuntimeError:
@@ -212,13 +341,20 @@ def on_validation_epoch_end(self) -> None:
self.log(key, torch.stack(value).mean(), logger=False)

def on_test_epoch_end(self) -> None:
"""Logs the accumulated metrics for each dataloader."""
self.on_validation_epoch_end()

def save_pretrained(self, save_path: str | Path) -> None:
"""Saves the model and tokenizer to the save path.
:param save_path: Path to save the model and tokenizer
:type save_path: str | Path
"""
self.model.save_pretrained(save_path)
self.tokenizer.save_pretrained(save_path)

def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
"""Saves the model and tokenizer to the trainer's log directory."""
if self.trainer is not None and self.trainer.log_dir is not None:
if self.trainer.global_rank != 0:
return