diff --git a/python/text_utils/api/trainer.py b/python/text_utils/api/trainer.py index 3b6def9..044d4fa 100644 --- a/python/text_utils/api/trainer.py +++ b/python/text_utils/api/trainer.py @@ -1044,8 +1044,8 @@ def _train_one_epoch(self): 1, dtype=torch.long, device=self.info.device) metrics = [] - for name, cfg in self.cfg["train"].get("metrics", {}).items(): - metric = self._metric_from_config(name, cfg, "train") + for metric_cfg in self.cfg["train"].get("metrics", []): + metric = self._metric_from_config(metric_cfg, "train") metrics.append(metric) start_items = self.epoch_items @@ -1333,10 +1333,9 @@ def _evaluate_and_checkpoint(self): self.loss_fn = self.loss_fn.eval() metrics = [] - for name, cfg in self.cfg["val"].get("metrics", {}).items(): + for metric_cfg in self.cfg["val"].get("metrics", []): metric = self._metric_from_config( - name, - cfg, + metric_cfg, "val" ) metrics.append(metric)