From d020c0602b5eabb79becbea5b0765cd428006834 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Mon, 6 Jan 2025 22:56:30 +0100 Subject: [PATCH 1/2] Revert "feat(config): add consistency checks for optimizer params" This reverts commit b62be68263c4bdaa2839bff76957e4f18818ff8c. --- tests/test_config.py | 30 ------------------------------ trainer/config.py | 12 +----------- 2 files changed, 1 insertion(+), 41 deletions(-) delete mode 100644 tests/test_config.py diff --git a/tests/test_config.py b/tests/test_config.py deleted file mode 100644 index 0b0e500..0000000 --- a/tests/test_config.py +++ /dev/null @@ -1,30 +0,0 @@ -import pytest - -from trainer.config import TrainerConfig - - -def test_optimizer_params(): - TrainerConfig( - optimizer="optimizer", - grad_clip=0.0, - lr=0.1, - optimizer_params={}, - lr_scheduler="scheduler", - ) - - TrainerConfig( - optimizer=["optimizer1", "optimizer2"], - grad_clip=[0.0, 0.0], - lr=[0.1, 0.01], - optimizer_params=[{}, {}], - lr_scheduler=["scheduler1", "scheduler2"], - ) - - with pytest.raises(TypeError, match="Either none or all of these fields must be a list:"): - TrainerConfig( - optimizer=["optimizer1", "optimizer2"], - grad_clip=0.0, - lr=[0.1, 0.01], - optimizer_params=[{}, {}], - lr_scheduler=["scheduler1", "scheduler2"], - ) diff --git a/trainer/config.py b/trainer/config.py index ddfeb6f..795c14c 100644 --- a/trainer/config.py +++ b/trainer/config.py @@ -1,4 +1,4 @@ -from dataclasses import asdict, dataclass, field +from dataclasses import dataclass, field from typing import Any from coqpit import Coqpit @@ -227,13 +227,3 @@ class TrainerConfig(Coqpit): default=54321, metadata={"help": "Global seed for torch, random and numpy random number generator. Defaults to 54321"}, ) - - def check_values(self) -> None: - """Check config fields.""" - c = asdict(self) - optimizer_fields = ["optimizer", "grad_clip", "lr", "optimizer_params", "lr_scheduler"] - is_list = [isinstance(c[field], list) for field in optimizer_fields] - consistent = all(is_list) or not any(is_list) - if not consistent: - msg = f"Either none or all of these fields must be a list: {optimizer_fields}" - raise TypeError(msg) From c77a3ae9674da3d6c3d31569fd3fa7a7dfb294b1 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Mon, 6 Jan 2025 22:58:49 +0100 Subject: [PATCH 2/2] chore: bump version to 0.2.2 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index f80ed0a..de5ebd4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ build-backend = "hatchling.build" [project] name = "coqui-tts-trainer" -version = "0.2.1" +version = "0.2.2" description = "General purpose model trainer for PyTorch that is more flexible than it should be, by 🐸Coqui." readme = "README.md" requires-python = ">=3.10, <3.13"