Skip to content

Commit

Permalink
Added model name to args
Browse files Browse the repository at this point in the history
  • Loading branch information
L1m3D4rkn3ss committed Apr 6, 2022
1 parent d521bc8 commit 908c11b
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions training.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def add_argparse_args(parent_parser):
"""Argument parser for model."""
parser = parent_parser.add_argument_group("Classifier")
parser.add_argument("--learning_rate", type=float, default=0.0005)
parser.add_argument("--model_name",type=str,default="mobilenetv3_rw")
return parent_parser

def forward(self, x):
Expand Down Expand Up @@ -156,20 +157,20 @@ def test_dataloader(self):


if __name__ == "__main__":
model_name = "mobilenetv3_rw"

parser = ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser = Classifier.add_argparse_args(parser)
args = parser.parse_args()

checkpoint_callback = pl.callbacks.ModelCheckpoint(
dirpath="lightning_logs",
filename=f"MNIST_classifier_{model_name}" + "{epoch}-{val_loss:.2f}",
filename=f"MNIST_classifier_{args.model_name}" + "{epoch}-{val_loss:.2f}",
monitor="val_acc",
mode="max",
)
dm = MNISTDataModule()
model = Classifier(model_name=model_name)
model = Classifier(learning_rate=args.learning_rate,model_name=args.model_name)
trainer = pl.Trainer.from_argparse_args(args, callbacks=[checkpoint_callback])
trainer.fit(model, dm)
trainer.test(model, dm)

0 comments on commit 908c11b

Please sign in to comment.