forked from ashleve/lightning-hydra-template
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_eval.py
39 lines (28 loc) · 1.28 KB
/
test_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import os
from pathlib import Path
import pytest
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, open_dict
from src.eval import evaluate
from src.train import train
@pytest.mark.slow
def test_train_eval(tmp_path: Path, cfg_train: DictConfig, cfg_eval: DictConfig) -> None:
    """Tests training and evaluation by training for 1 epoch with `train.py` then evaluating with
    `eval.py`.

    The checkpoint written by the training run (``checkpoints/last.ckpt`` under ``tmp_path``) is
    passed to the evaluation run, and the ``test/acc`` reported by both runs must agree to within
    0.001.

    :param tmp_path: The temporary logging path.
    :param cfg_train: A DictConfig containing a valid training configuration.
    :param cfg_eval: A DictConfig containing a valid evaluation configuration.
    """
    # Both runs must log into the same temporary directory so evaluation can find the
    # checkpoint produced by training.
    assert str(tmp_path) == cfg_train.paths.output_dir == cfg_eval.paths.output_dir

    with open_dict(cfg_train):
        cfg_train.trainer.max_epochs = 1
        # Enable the test stage during training so the training run also reports `test/acc`,
        # which is compared against the evaluation run below.
        cfg_train.test = True

    # NOTE: presumably required because `train`/`evaluate` read the global HydraConfig and no
    # `@hydra.main` application is running when they are called directly from a test.
    HydraConfig().set_config(cfg_train)
    train_metric_dict, _ = train(cfg_train)

    ckpt_path = tmp_path / "checkpoints" / "last.ckpt"
    assert ckpt_path.is_file(), f"expected training to write a checkpoint at {ckpt_path}"

    with open_dict(cfg_eval):
        cfg_eval.ckpt_path = str(ckpt_path)

    HydraConfig().set_config(cfg_eval)
    test_metric_dict, _ = evaluate(cfg_eval)

    assert test_metric_dict["test/acc"] > 0.0

    # Evaluating the saved checkpoint must reproduce the test accuracy seen at the end of training.
    train_acc = train_metric_dict["test/acc"].item()
    eval_acc = test_metric_dict["test/acc"].item()
    assert abs(train_acc - eval_acc) < 0.001, (
        f"test/acc mismatch: training run {train_acc} vs evaluation run {eval_acc}"
    )