Commit

Set model label info first (#3213)
* Set model label info first

* Add notes
jaegukhyun authored Mar 28, 2024
1 parent a6c4280 commit 9a47ee8
Showing 1 changed file with 37 additions and 30 deletions.
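
As context for the hunks below, here is a minimal sketch (hypothetical helper name and simplified signatures, not the OTX source) of the ordering this commit enforces in `Engine.train()`: `label_info` is synchronized with the datamodule before any checkpoint weights are restored, since label-aware weight loading compares label names as well as the number of classes.

```python
import logging
from typing import Any

import torch


def prepare_fit_kwargs(model, datamodule, checkpoint=None, resume=False) -> dict[str, Any]:
    """Hypothetical helper mirroring the new ordering: sync label_info first, then load weights."""
    fit_kwargs: dict[str, Any] = {}

    # 1) Align label info with the datamodule BEFORE touching the checkpoint,
    #    because label-aware weight loading checks label names and class counts.
    if model.label_info != datamodule.label_info:
        logging.warning(
            "Model label_info is not equal to the Datamodule label_info. "
            f"It will be overriden: {model.label_info} => {datamodule.label_info}",
        )
        model.label_info = datamodule.label_info

    # 2) Only then decide how the checkpoint is consumed.
    if resume:
        fit_kwargs["ckpt_path"] = checkpoint  # trainer resumes full training state
    elif checkpoint is not None:
        loaded = torch.load(checkpoint)       # plain state-dict restore
        model.load_state_dict(loaded)

    return fit_kwargs
```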
src/otx/engine/engine.py (67 changes: 37 additions & 30 deletions)
@@ -264,13 +264,9 @@ def train(
             **kwargs,
         )
         fit_kwargs: dict[str, Any] = {}
-        if resume:
-            fit_kwargs["ckpt_path"] = self.checkpoint
-        elif self.checkpoint is not None:
-            loaded_checkpoint = torch.load(self.checkpoint)
-            # loaded checkpoint have keys (OTX1.5): model, config, labels, input_size, VERSION
-            self.model.load_state_dict(loaded_checkpoint)

+        # NOTE Model's label info should be converted datamodule's label info before ckpt loading
+        # This is due to smart weight loading check label name as well as number of classes.
         if self.model.label_info != self.datamodule.label_info:
             # TODO (vinnamki): Revisit label_info logic to make it cleaner
             msg = (
@@ -280,6 +276,13 @@ def train(
             logging.warning(msg)
             self.model.label_info = self.datamodule.label_info

+        if resume:
+            fit_kwargs["ckpt_path"] = self.checkpoint
+        elif self.checkpoint is not None:
+            loaded_checkpoint = torch.load(self.checkpoint)
+            # loaded checkpoint have keys (OTX1.5): model, config, labels, input_size, VERSION
+            self.model.load_state_dict(loaded_checkpoint)
+
         with override_metric_callable(model=self.model, new_metric_callable=metric) as model:
             self.trainer.fit(
                 model=model,
@@ -335,6 +338,20 @@ def test(
             otx test --config <CONFIG_PATH, str> --checkpoint <CKPT_PATH, str>
             ```
         """
+        # NOTE Model's label info should be converted datamodule's label info before ckpt loading
+        # This is due to smart weight loading check label name as well as number of classes.
+        if self.model.label_info != self.datamodule.label_info:
+            # TODO (vinnamki): Revisit label_info logic to make it cleaner
+            msg = (
+                "Model label_info is not equal to the Datamodule label_info. "
+                f"It will be overriden: {self.model.label_info} => {self.datamodule.label_info}"
+            )
+            logging.warning(msg)
+            self.model.label_info = self.datamodule.label_info
+
+            # TODO (vinnamki): This should be changed to raise an error if not equivalent in case of test
+            # raise ValueError()
+
         model = self.model
         checkpoint = checkpoint if checkpoint is not None else self.checkpoint
         datamodule = datamodule if datamodule is not None else self.datamodule
@@ -352,18 +369,6 @@ def test(

         self._build_trainer(**kwargs)

-        if self.model.label_info != self.datamodule.label_info:
-            # TODO (vinnamki): Revisit label_info logic to make it cleaner
-            msg = (
-                "Model label_info is not equal to the Datamodule label_info. "
-                f"It will be overriden: {self.model.label_info} => {self.datamodule.label_info}"
-            )
-            logging.warning(msg)
-            self.model.label_info = self.datamodule.label_info
-
-            # TODO (vinnamki): This should be changed to raise an error if not equivalent in case of test
-            # raise ValueError()
-
         with override_metric_callable(model=model, new_metric_callable=metric) as model:
             self.trainer.test(
                 model=model,
@@ -416,6 +421,20 @@ def predict(
         """
         from otx.algo.utils.xai_utils import process_saliency_maps_in_pred_entity

+        # NOTE Model's label info should be converted datamodule's label info before ckpt loading
+        # This is due to smart weight loading check label name as well as number of classes.
+        if self.model.label_info != self.datamodule.label_info:
+            # TODO (vinnamki): Revisit label_info logic to make it cleaner
+            msg = (
+                "Model label_info is not equal to the Datamodule label_info. "
+                f"It will be overriden: {self.model.label_info} => {self.datamodule.label_info}"
+            )
+            logging.warning(msg)
+            self.model.label_info = self.datamodule.label_info
+
+            # TODO (vinnamki): This should be changed to raise an error if not equivalent in case of test
+            # raise ValueError()
+
         model = self.model

         checkpoint = checkpoint if checkpoint is not None else self.checkpoint
@@ -434,18 +453,6 @@ def predict(

         self._build_trainer(**kwargs)

-        if self.model.label_info != self.datamodule.label_info:
-            # TODO (vinnamki): Revisit label_info logic to make it cleaner
-            msg = (
-                "Model label_info is not equal to the Datamodule label_info. "
-                f"It will be overriden: {self.model.label_info} => {self.datamodule.label_info}"
-            )
-            logging.warning(msg)
-            self.model.label_info = self.datamodule.label_info
-
-            # TODO (vinnamki): This should be changed to raise an error if not equivalent in case of test
-            # raise ValueError()
-
         predict_result = self.trainer.predict(
             model=model,
             dataloaders=datamodule,
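The NOTE comments added in this diff refer to label-aware ("smart") weight loading that compares label names and class counts. The toy check below (illustrative only, not the OTX implementation) shows why `label_info` must be synchronized with the datamodule before loading: with stale placeholder labels the comparison fails and the head weights would be discarded.

```python
from dataclasses import dataclass


@dataclass
class LabelInfo:
    """Toy stand-in for OTX label metadata."""

    label_names: list[str]

    @property
    def num_classes(self) -> int:
        return len(self.label_names)


def head_weights_compatible(ckpt: LabelInfo, model: LabelInfo) -> bool:
    """Toy version of a label-aware loading check: names and class count must match."""
    return ckpt.num_classes == model.num_classes and ckpt.label_names == model.label_names


# A model still carrying placeholder labels fails the check, so its head weights
# would be re-initialized; syncing label_info with the datamodule first lets the
# checkpoint weights pass the comparison.
ckpt_labels = LabelInfo(["cat", "dog"])
assert not head_weights_compatible(ckpt_labels, LabelInfo(["label_0", "label_1"]))
assert head_weights_compatible(ckpt_labels, LabelInfo(["cat", "dog"]))
```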
