diff --git a/pyproject.toml b/pyproject.toml
index f80ed0a..de5ebd4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,7 +26,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "coqui-tts-trainer"
-version = "0.2.1"
+version = "0.2.2"
 description = "General purpose model trainer for PyTorch that is more flexible than it should be, by 🐸Coqui."
 readme = "README.md"
 requires-python = ">=3.10, <3.13"
diff --git a/tests/test_config.py b/tests/test_config.py
deleted file mode 100644
index 0b0e500..0000000
--- a/tests/test_config.py
+++ /dev/null
@@ -1,30 +0,0 @@
-import pytest
-
-from trainer.config import TrainerConfig
-
-
-def test_optimizer_params():
-    TrainerConfig(
-        optimizer="optimizer",
-        grad_clip=0.0,
-        lr=0.1,
-        optimizer_params={},
-        lr_scheduler="scheduler",
-    )
-
-    TrainerConfig(
-        optimizer=["optimizer1", "optimizer2"],
-        grad_clip=[0.0, 0.0],
-        lr=[0.1, 0.01],
-        optimizer_params=[{}, {}],
-        lr_scheduler=["scheduler1", "scheduler2"],
-    )
-
-    with pytest.raises(TypeError, match="Either none or all of these fields must be a list:"):
-        TrainerConfig(
-            optimizer=["optimizer1", "optimizer2"],
-            grad_clip=0.0,
-            lr=[0.1, 0.01],
-            optimizer_params=[{}, {}],
-            lr_scheduler=["scheduler1", "scheduler2"],
-        )
diff --git a/trainer/config.py b/trainer/config.py
index ddfeb6f..795c14c 100644
--- a/trainer/config.py
+++ b/trainer/config.py
@@ -1,4 +1,4 @@
-from dataclasses import asdict, dataclass, field
+from dataclasses import dataclass, field
 from typing import Any
 
 from coqpit import Coqpit
@@ -227,13 +227,3 @@ class TrainerConfig(Coqpit):
         default=54321,
         metadata={"help": "Global seed for torch, random and numpy random number generator. Defaults to 54321"},
     )
-
-    def check_values(self) -> None:
-        """Check config fields."""
-        c = asdict(self)
-        optimizer_fields = ["optimizer", "grad_clip", "lr", "optimizer_params", "lr_scheduler"]
-        is_list = [isinstance(c[field], list) for field in optimizer_fields]
-        consistent = all(is_list) or not any(is_list)
-        if not consistent:
-            msg = f"Either none or all of these fields must be a list: {optimizer_fields}"
-            raise TypeError(msg)