diff --git a/lightning_ir/base/module.py b/lightning_ir/base/module.py
index b19fbef..08e63c9 100644
--- a/lightning_ir/base/module.py
+++ b/lightning_ir/base/module.py
@@ -2,7 +2,7 @@

 from collections import defaultdict
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Dict, Mapping, Sequence
+from typing import TYPE_CHECKING, Any, Dict, Sequence, Tuple

 import torch
 from lightning import LightningModule
@@ -22,7 +22,7 @@ def __init__(
         model_name_or_path: str | None = None,
         config: LightningIRConfig | None = None,
         model: LightningIRModel | None = None,
-        loss_functions: Sequence[LossFunction] | Mapping[LossFunction, float] | None = None,
+        loss_functions: Sequence[LossFunction] | Sequence[Tuple[LossFunction, float]] | None = None,
         evaluation_metrics: Sequence[str] | None = None,
     ):
         super().__init__()
@@ -43,9 +43,14 @@ def __init__(

         self.model: LightningIRModel = model
         self.config = self.model.config
-        if loss_functions is not None and not isinstance(loss_functions, dict):
-            loss_functions = {loss_function: 1.0 for loss_function in loss_functions}
-        self.loss_functions = loss_functions
+        self.loss_functions: Dict[LossFunction, float] | None = None
+        if loss_functions is not None:
+            self.loss_functions = {}
+            for loss_function in loss_functions:
+                if isinstance(loss_function, LossFunction):
+                    self.loss_functions[loss_function] = 1.0
+                else:
+                    self.loss_functions[loss_function[0]] = loss_function[1]
         self.evaluation_metrics = evaluation_metrics
         self.tokenizer = self.config.__class__.tokenizer_class.from_pretrained(
             self.config.name_or_path, **self.config.to_tokenizer_dict()
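
For callers, this change means weighted losses are passed as (loss, weight) tuples in a sequence
rather than as a dict keyed by LossFunction instances. Below is a minimal, self-contained sketch
of the new normalization behavior; LossFunction here is a stub standing in for lightning_ir's
actual base class, and MarginLoss/KLLoss are hypothetical subclasses used only for illustration:

    from typing import Dict, Sequence, Tuple, Union


    class LossFunction:  # stub for lightning_ir's LossFunction base class
        pass


    class MarginLoss(LossFunction):  # hypothetical subclass, illustration only
        pass


    class KLLoss(LossFunction):  # hypothetical subclass, illustration only
        pass


    def normalize_loss_functions(
        loss_functions: Sequence[Union[LossFunction, Tuple[LossFunction, float]]],
    ) -> Dict[LossFunction, float]:
        # Mirrors the constructor logic above: bare losses default to weight 1.0,
        # (loss, weight) tuples keep their explicit weight.
        normalized: Dict[LossFunction, float] = {}
        for loss_function in loss_functions:
            if isinstance(loss_function, LossFunction):
                normalized[loss_function] = 1.0
            else:
                normalized[loss_function[0]] = loss_function[1]
        return normalized


    margin, kl = MarginLoss(), KLLoss()
    weights = normalize_loss_functions([margin, (kl, 0.5)])
    assert weights == {margin: 1.0, kl: 0.5}

A plausible motivation for the signature change: a sequence of pairs is easier to express in
YAML/CLI configuration than a mapping whose keys are LossFunction objects, and mixing bare
losses with weighted tuples in one list works because each element is checked individually
with isinstance.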