Skip to content

Commit

Permalink
Update log method to include start_time parameter (#2381)
Browse files: browse the repository at this point in the history
  • Loading branch information
qgallouedec authored Nov 21, 2024
1 parent bdeb117 commit 672c965
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 10 deletions.
12 changes: 10 additions & 2 deletions trl/trainer/bco_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@
import torch.amp as amp
import torch.nn as nn
import torch.nn.functional as F
import transformers
from accelerate import PartialState
from accelerate.utils import is_deepspeed_available, tqdm
from datasets import Dataset
from packaging import version
from torch.utils.data import DataLoader, SequentialSampler
from transformers import (
AutoModelForCausalLM,
Expand Down Expand Up @@ -1437,13 +1439,15 @@ def evaluation_loop(

return initial_output

def log(self, logs: Dict[str, float]) -> None:
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
"""
Log `logs` on the various objects watching training, including stored metrics.
Args:
logs (`Dict[str, float]`):
The values to log.
start_time (`float` or `None`, *optional*, defaults to `None`):
Start time of the training.
"""
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
Expand All @@ -1468,7 +1472,11 @@ def log(self, logs: Dict[str, float]) -> None:
for key, metrics in self._stored_metrics[train_eval].items():
logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
return super().log(logs)

if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super().log(logs, start_time)
else: # transformers<=4.46
return super().log(logs)

def create_model_card(
self,
Expand Down
12 changes: 10 additions & 2 deletions trl/trainer/cpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@
import torch.amp as amp
import torch.nn as nn
import torch.nn.functional as F
import transformers
from accelerate import PartialState
from datasets import Dataset
from packaging import version
from torch.utils.data import DataLoader
from transformers import (
AutoModelForCausalLM,
Expand Down Expand Up @@ -963,21 +965,27 @@ def evaluation_loop(

return initial_output

def log(self, logs: Dict[str, float]) -> None:
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
"""
Log `logs` on the various objects watching training, including stored metrics.
Args:
logs (`Dict[str, float]`):
The values to log.
start_time (`float` or `None`, *optional*, defaults to `None`):
Start time of the training.
"""
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
return super().log(logs)

if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super().log(logs, start_time)
else: # transformers<=4.46
return super().log(logs)

def _shift_right(self, input_ids):
if self.decoder_start_token_id is None:
Expand Down
12 changes: 10 additions & 2 deletions trl/trainer/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@
import torch.amp as amp
import torch.nn as nn
import torch.nn.functional as F
import transformers
from accelerate import PartialState
from accelerate.utils import is_deepspeed_available, tqdm
from datasets import Dataset
from packaging import version
from torch.utils.data import DataLoader
from transformers import (
AutoModelForCausalLM,
Expand Down Expand Up @@ -1390,21 +1392,27 @@ def evaluation_loop(

return initial_output

def log(self, logs: Dict[str, float]) -> None:
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
"""
Log `logs` on the various objects watching training, including stored metrics.
Args:
logs (`Dict[str, float]`):
The values to log.
start_time (`float` or `None`, *optional*, defaults to `None`):
Start time of the training.
"""
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
return super().log(logs)

if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super().log(logs, start_time)
else: # transformers<=4.46
return super().log(logs)

def create_model_card(
self,
Expand Down
12 changes: 10 additions & 2 deletions trl/trainer/kto_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@
import torch.amp as amp
import torch.nn as nn
import torch.nn.functional as F
import transformers
from accelerate import PartialState
from accelerate.utils import is_deepspeed_available, tqdm
from datasets import Dataset, concatenate_datasets
from packaging import version
from torch.utils.data import DataLoader, SequentialSampler
from transformers import (
AutoModelForCausalLM,
Expand Down Expand Up @@ -1442,13 +1444,15 @@ def evaluation_loop(

return initial_output

def log(self, logs: Dict[str, float]) -> None:
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
"""
Log `logs` on the various objects watching training, including stored metrics.
Args:
logs (`Dict[str, float]`):
The values to log.
start_time (`float` or `None`, *optional*, defaults to `None`):
Start time of the training.
"""
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
Expand All @@ -1473,7 +1477,11 @@ def log(self, logs: Dict[str, float]) -> None:
for key, metrics in self._stored_metrics[train_eval].items():
logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
return super().log(logs)

if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super().log(logs, start_time)
else: # transformers<=4.46
return super().log(logs)

def create_model_card(
self,
Expand Down
12 changes: 10 additions & 2 deletions trl/trainer/orpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,11 @@
import torch.amp as amp
import torch.nn as nn
import torch.nn.functional as F
import transformers
from accelerate import PartialState
from accelerate.utils import is_deepspeed_available
from datasets import Dataset
from packaging import version
from torch.utils.data import DataLoader
from transformers import (
AutoModelForCausalLM,
Expand Down Expand Up @@ -981,21 +983,27 @@ def evaluation_loop(

return initial_output

def log(self, logs: Dict[str, float]) -> None:
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
"""
Log `logs` on the various objects watching training, including stored metrics.
Args:
logs (`Dict[str, float]`):
The values to log.
start_time (`float` or `None`, *optional*, defaults to `None`):
Start time of the training.
"""
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
return super().log(logs)

if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super().log(logs, start_time)
else: # transformers<=4.46
return super().log(logs)

def _shift_right(self, input_ids):
if self.decoder_start_token_id is None:
Expand Down

0 comments on commit 672c965

Please sign in to comment.