Skip to content

Commit

Permalink
use tuples for weighting loss functions to support yaml
Browse files Browse the repository at this point in the history
  • Loading branch information
fschlatt committed Jul 30, 2024
1 parent 6d3b06a commit e39ec5f
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions lightning_ir/base/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from collections import defaultdict
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Dict, Sequence, Tuple

import torch
from lightning import LightningModule
Expand All @@ -22,7 +22,7 @@ def __init__(
model_name_or_path: str | None = None,
config: LightningIRConfig | None = None,
model: LightningIRModel | None = None,
loss_functions: Sequence[LossFunction] | Mapping[LossFunction, float] | None = None,
loss_functions: Sequence[LossFunction] | Sequence[Tuple[LossFunction, float]] | None = None,
evaluation_metrics: Sequence[str] | None = None,
):
super().__init__()
Expand All @@ -43,9 +43,14 @@ def __init__(

self.model: LightningIRModel = model
self.config = self.model.config
if loss_functions is not None and not isinstance(loss_functions, dict):
loss_functions = {loss_function: 1.0 for loss_function in loss_functions}
self.loss_functions = loss_functions
self.loss_functions: Dict[LossFunction, float] | None = None
if loss_functions is not None:
self.loss_functions = {}
for loss_function in loss_functions:
if isinstance(loss_function, LossFunction):
self.loss_functions[loss_function] = 1.0
else:
self.loss_functions[loss_function[0]] = loss_function[1]
self.evaluation_metrics = evaluation_metrics
self.tokenizer = self.config.__class__.tokenizer_class.from_pretrained(
self.config.name_or_path, **self.config.to_tokenizer_dict()
Expand Down

0 comments on commit e39ec5f

Please sign in to comment.