From 908c11b1a21cd33e9905811711ca64f4f828b14b Mon Sep 17 00:00:00 2001 From: EmilEOGG Date: Wed, 6 Apr 2022 15:39:44 +0200 Subject: [PATCH] Added model name to args --- training.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/training.py b/training.py index b4a2e14..6c3cb17 100644 --- a/training.py +++ b/training.py @@ -38,6 +38,7 @@ def add_argparse_args(parent_parser): """Argument parser for model.""" parser = parent_parser.add_argument_group("Classifier") parser.add_argument("--learning_rate", type=float, default=0.0005) + parser.add_argument("--model_name",type=str,default="mobilenetv3_rw") return parent_parser def forward(self, x): @@ -156,7 +157,7 @@ def test_dataloader(self): if __name__ == "__main__": - model_name = "mobilenetv3_rw" + parser = ArgumentParser() parser = pl.Trainer.add_argparse_args(parser) parser = Classifier.add_argparse_args(parser) @@ -164,12 +165,12 @@ def test_dataloader(self): checkpoint_callback = pl.callbacks.ModelCheckpoint( dirpath="lightning_logs", - filename=f"MNIST_classifier_{model_name}" + "{epoch}-{val_loss:.2f}", + filename=f"MNIST_classifier_{args.model_name}" + "{epoch}-{val_loss:.2f}", monitor="val_acc", mode="max", ) dm = MNISTDataModule() - model = Classifier(model_name=model_name) + model = Classifier(learning_rate=args.learning_rate,model_name=args.model_name) trainer = pl.Trainer.from_argparse_args(args, callbacks=[checkpoint_callback]) trainer.fit(model, dm) trainer.test(model, dm)