diff --git a/src/otx/algo/common/backbones/pytorchcv_backbones.py b/src/otx/algo/common/backbones/pytorchcv_backbones.py index 3ed30d8e73d..4b11751080a 100644 --- a/src/otx/algo/common/backbones/pytorchcv_backbones.py +++ b/src/otx/algo/common/backbones/pytorchcv_backbones.py @@ -124,12 +124,10 @@ def _build_pytorchcv_model( **kwargs, ) -> nn.Module: """Build pytorchcv model.""" - models_cache_root = kwargs.get("root", Path.home() / ".cache" / "torch" / "hub" / "checkpoints") - is_pretrained = kwargs.get("pretrained", False) - print( - f"Init model {type}, pretrained={is_pretrained}, models cache {models_cache_root}", - ) - model = _models[type](**kwargs) + models_cache_root = kwargs.pop("root", Path.home() / ".cache" / "torch" / "hub" / "checkpoints") + pretrained = kwargs.pop("pretrained", False) + print(f"Init model {type}, pretrained={pretrained}, models cache {models_cache_root}") + model = _models[type](root=models_cache_root, pretrained=pretrained, **kwargs) if activation_callable: model = replace_activation(model, activation_callable) if norm_cfg: