From 9a47ee861717b90b592971a59177568a72cf1165 Mon Sep 17 00:00:00 2001 From: Jaeguk Hyun Date: Thu, 28 Mar 2024 11:34:15 +0900 Subject: [PATCH] Set model label info first (#3213) * Set model label info first * Add notes --- src/otx/engine/engine.py | 67 ++++++++++++++++++++++------------------ 1 file changed, 37 insertions(+), 30 deletions(-) diff --git a/src/otx/engine/engine.py b/src/otx/engine/engine.py index 44d757630b5..75899b5cdd1 100644 --- a/src/otx/engine/engine.py +++ b/src/otx/engine/engine.py @@ -264,13 +264,9 @@ def train( **kwargs, ) fit_kwargs: dict[str, Any] = {} - if resume: - fit_kwargs["ckpt_path"] = self.checkpoint - elif self.checkpoint is not None: - loaded_checkpoint = torch.load(self.checkpoint) - # loaded checkpoint have keys (OTX1.5): model, config, labels, input_size, VERSION - self.model.load_state_dict(loaded_checkpoint) + # NOTE Model's label info should be converted datamodule's label info before ckpt loading + # This is due to smart weight loading check label name as well as number of classes. if self.model.label_info != self.datamodule.label_info: # TODO (vinnamki): Revisit label_info logic to make it cleaner msg = ( @@ -280,6 +276,13 @@ def train( logging.warning(msg) self.model.label_info = self.datamodule.label_info + if resume: + fit_kwargs["ckpt_path"] = self.checkpoint + elif self.checkpoint is not None: + loaded_checkpoint = torch.load(self.checkpoint) + # loaded checkpoint have keys (OTX1.5): model, config, labels, input_size, VERSION + self.model.load_state_dict(loaded_checkpoint) + with override_metric_callable(model=self.model, new_metric_callable=metric) as model: self.trainer.fit( model=model, @@ -335,6 +338,20 @@ def test( otx test --config --checkpoint ``` """ + # NOTE Model's label info should be converted datamodule's label info before ckpt loading + # This is due to smart weight loading check label name as well as number of classes. + if self.model.label_info != self.datamodule.label_info: + # TODO (vinnamki): Revisit label_info logic to make it cleaner + msg = ( + "Model label_info is not equal to the Datamodule label_info. " + f"It will be overriden: {self.model.label_info} => {self.datamodule.label_info}" + ) + logging.warning(msg) + self.model.label_info = self.datamodule.label_info + + # TODO (vinnamki): This should be changed to raise an error if not equivalent in case of test + # raise ValueError() + model = self.model checkpoint = checkpoint if checkpoint is not None else self.checkpoint datamodule = datamodule if datamodule is not None else self.datamodule @@ -352,18 +369,6 @@ def test( self._build_trainer(**kwargs) - if self.model.label_info != self.datamodule.label_info: - # TODO (vinnamki): Revisit label_info logic to make it cleaner - msg = ( - "Model label_info is not equal to the Datamodule label_info. " - f"It will be overriden: {self.model.label_info} => {self.datamodule.label_info}" - ) - logging.warning(msg) - self.model.label_info = self.datamodule.label_info - - # TODO (vinnamki): This should be changed to raise an error if not equivalent in case of test - # raise ValueError() - with override_metric_callable(model=model, new_metric_callable=metric) as model: self.trainer.test( model=model, @@ -416,6 +421,20 @@ def predict( """ from otx.algo.utils.xai_utils import process_saliency_maps_in_pred_entity + # NOTE Model's label info should be converted datamodule's label info before ckpt loading + # This is due to smart weight loading check label name as well as number of classes. + if self.model.label_info != self.datamodule.label_info: + # TODO (vinnamki): Revisit label_info logic to make it cleaner + msg = ( + "Model label_info is not equal to the Datamodule label_info. " + f"It will be overriden: {self.model.label_info} => {self.datamodule.label_info}" + ) + logging.warning(msg) + self.model.label_info = self.datamodule.label_info + + # TODO (vinnamki): This should be changed to raise an error if not equivalent in case of test + # raise ValueError() + model = self.model checkpoint = checkpoint if checkpoint is not None else self.checkpoint @@ -434,18 +453,6 @@ def predict( self._build_trainer(**kwargs) - if self.model.label_info != self.datamodule.label_info: - # TODO (vinnamki): Revisit label_info logic to make it cleaner - msg = ( - "Model label_info is not equal to the Datamodule label_info. " - f"It will be overriden: {self.model.label_info} => {self.datamodule.label_info}" - ) - logging.warning(msg) - self.model.label_info = self.datamodule.label_info - - # TODO (vinnamki): This should be changed to raise an error if not equivalent in case of test - # raise ValueError() - predict_result = self.trainer.predict( model=model, dataloaders=datamodule,