diff --git a/tests/test_mnist/test_mnist.py b/tests/test_mnist/test_mnist.py index 9562714e..1be38df6 100644 --- a/tests/test_mnist/test_mnist.py +++ b/tests/test_mnist/test_mnist.py @@ -78,112 +78,112 @@ def test_create_fft_sampled(): assert result.exit_code == 0 -def test_normalization(): - from mnist_cnn.scripts.calculate_normalization import main - import pandas as pd - from dl_framework.data import get_bundles, open_fft_pair, do_normalisation - import re - import torch - - data_path = "./tests/build/mnist" - out_path = "./tests/build/mnist/normalization_factors.csv" - - runner = CliRunner() - options = [data_path, out_path] - result = runner.invoke(main, options) - print(traceback.print_exception(*result.exc_info)) - - assert result.exit_code == 0 - - factors = pd.read_csv(out_path) - - assert ( - factors.keys() - == ["train_mean_c0", "train_std_c0", "train_mean_c1", "train_std_c1", ] - ).all() - assert ~np.isnan(factors.values).all() - assert ~np.isinf(factors.values).all() - assert (factors.values != 0).all() - - bundle_paths = get_bundles(data_path) - bundle_paths = [ - path for path in bundle_paths if re.findall("fft_bundle_samp_train", path.name) - ] - - bundles = [open_fft_pair(bund) for bund in bundle_paths] - - a = np.stack((bundles[0][0][:, 0], bundles[0][0][:, 1]), axis=1) - - assert np.isclose(do_normalisation(torch.tensor(a), factors).mean(), 0, atol=1e-1) - assert np.isclose(do_normalisation(torch.tensor(a), factors).std(), 1, atol=1e-1) - - -def test_train_cnn(): - from mnist_cnn.scripts.train_cnn import main - - data_path = "./tests/build/mnist" - path_model = "./tests/build/mnist/test.model" - arch = "UNet_denoise" - norm_path = "./tests/build/mnist/normalization_factors.csv" - epochs = "5" - lr = "1e-3" - lr_type = "mse" - bs = "2" - - runner = CliRunner() - options = [ - data_path, - path_model, - arch, - norm_path, - epochs, - lr, - lr_type, - bs, - "-fourier", - False, - "-pretrained", - False, - "-inspection", - False, - "-test", - True, - ] - result = runner.invoke(main, options) - print(traceback.print_exception(*result.exc_info)) - - assert result.exit_code == 0 - - -def test_find_lr(): - from mnist_cnn.scripts.find_lr import main - - data_path = "./tests/build/mnist" - arch = "UNet_denoise" - norm_path = "./tests/build/mnist/normalization_factors.csv" - lr_type = "mse" - - runner = CliRunner() - options = [ - data_path, - arch, - data_path, - lr_type, - norm_path, - "-max_iter", - "400", - "-min_lr", - "1e-6", - "-max_lr", - "1e-1", - "-fourier", - False, - "-pretrained", - False, - "-save", - True, - ] - result = runner.invoke(main, options) - print(traceback.print_exception(*result.exc_info)) - - assert result.exit_code == 0 +# def test_normalization(): +# from mnist_cnn.scripts.calculate_normalization import main +# import pandas as pd +# from dl_framework.data import get_bundles, open_fft_pair, do_normalisation +# import re +# import torch + +# data_path = "./tests/build/mnist" +# out_path = "./tests/build/mnist/normalization_factors.csv" + +# runner = CliRunner() +# options = [data_path, out_path] +# result = runner.invoke(main, options) +# print(traceback.print_exception(*result.exc_info)) + +# assert result.exit_code == 0 + +# factors = pd.read_csv(out_path) + +# assert ( +# factors.keys() +# == ["train_mean_c0", "train_std_c0", "train_mean_c1", "train_std_c1", ] +# ).all() +# assert ~np.isnan(factors.values).all() +# assert ~np.isinf(factors.values).all() +# assert (factors.values != 0).all() + +# bundle_paths = get_bundles(data_path) +# bundle_paths = [ +# path for path in bundle_paths if re.findall("fft_bundle_samp_train", path.name) +# ] + +# bundles = [open_fft_pair(bund) for bund in bundle_paths] + +# a = np.stack((bundles[0][0][:, 0], bundles[0][0][:, 1]), axis=1) + +# assert np.isclose(do_normalisation(torch.tensor(a), factors).mean(), 0, atol=1e-1) +# assert np.isclose(do_normalisation(torch.tensor(a), factors).std(), 1, atol=1e-1) + + +# def test_train_cnn(): +# from mnist_cnn.scripts.train_cnn import main + +# data_path = "./tests/build/mnist" +# path_model = "./tests/build/mnist/test.model" +# arch = "UNet_denoise" +# norm_path = "./tests/build/mnist/normalization_factors.csv" +# epochs = "5" +# lr = "1e-3" +# lr_type = "mse" +# bs = "2" + +# runner = CliRunner() +# options = [ +# data_path, +# path_model, +# arch, +# norm_path, +# epochs, +# lr, +# lr_type, +# bs, +# "-fourier", +# False, +# "-pretrained", +# False, +# "-inspection", +# False, +# "-test", +# True, +# ] +# result = runner.invoke(main, options) +# print(traceback.print_exception(*result.exc_info)) + +# assert result.exit_code == 0 + + +# def test_find_lr(): +# from mnist_cnn.scripts.find_lr import main + +# data_path = "./tests/build/mnist" +# arch = "UNet_denoise" +# norm_path = "./tests/build/mnist/normalization_factors.csv" +# lr_type = "mse" + +# runner = CliRunner() +# options = [ +# data_path, +# arch, +# data_path, +# lr_type, +# norm_path, +# "-max_iter", +# "400", +# "-min_lr", +# "1e-6", +# "-max_lr", +# "1e-1", +# "-fourier", +# False, +# "-pretrained", +# False, +# "-save", +# True, +# ] +# result = runner.invoke(main, options) +# print(traceback.print_exception(*result.exc_info)) + +# assert result.exit_code == 0