diff --git a/test/test_discriminador.py b/test/test_discriminador.py index e890a1a..570f1f0 100644 --- a/test/test_discriminador.py +++ b/test/test_discriminador.py @@ -1,5 +1,6 @@ import unittest -import numpy as np +from unittest.mock import patch +from io import StringIO from keras.datasets import cifar100 from src.gan.discriminador import Discriminator from src.prueba.entrenamiento import Training @@ -10,6 +11,7 @@ def setUp(self): self.input_shape = (32, 32, 3) self.discriminador = Discriminator(self.input_shape) self.dataset = Training.load_images(cifar100) + self.stdout = StringIO() def test_model_structure(self): model = self.discriminador.model @@ -31,6 +33,23 @@ def test_fake_training(self): loss, accuracy = self.discriminador.evaluate(dataset, labels) self.assertTrue(accuracy > initial_accuracy) + def test_build_model(self): + model = self.discriminador.build_model() + self.assertIsNotNone(model) + + def test_summary(self): + with patch("sys.stdout", self.stdout): + self.discriminador.summary() + printed_output = self.stdout.getvalue().strip() + + self.assertNotEqual(printed_output, "") + + def test_evaluate(self): + dataset, labels = Training.load_real_data(self.dataset, 100) + result = self.discriminador.evaluate(dataset, labels) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 2) -if __name__ == "__main__": - unittest.main() + def test_trainable(self): + trainable = self.discriminador.trainable + self.assertIsNotNone(trainable) diff --git a/test/test_gan.py b/test/test_gan.py index cf74f12..7254763 100644 --- a/test/test_gan.py +++ b/test/test_gan.py @@ -13,7 +13,3 @@ def setUp(self): def test_gan_structure(self): model = self.gan self.assertIsNotNone(model) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/test_generador.py b/test/test_generador.py index 4b82d82..eb1d332 100644 --- a/test/test_generador.py +++ b/test/test_generador.py @@ -1,13 +1,15 @@ import unittest +from unittest.mock import patch import numpy as np from src.gan.generador import Generator - +from io import StringIO class TestGenerador(unittest.TestCase): def setUp(self): self.latent_dim = 100 self.output_shape = (32, 32, 3) self.gen = Generator(self.latent_dim, self.output_shape) + self.stdout = StringIO() def test_model_structure(self): model = self.gen.model @@ -31,6 +33,9 @@ def test_generalization(self): # Verifica la calidad de las imágenes generadas y su similitud con las imágenes reales return None + def test_summary(self): + with patch('sys.stdout', self.stdout): + self.gen.summary() + printed_output = self.stdout.getvalue().strip() -if __name__ == "__main__": - unittest.main() + self.assertNotEqual(printed_output, "") diff --git a/test/test_servidor.py b/test/test_servidor.py index 56775e9..b71bae8 100644 --- a/test/test_servidor.py +++ b/test/test_servidor.py @@ -38,7 +38,3 @@ def test_client_handler(self, mock_handler): mock_handler.assert_called_once_with(mock_client_socket) mock_handler.return_value.handle.assert_called_once() - - -if __name__ == "__main__": - unittest.main()