From 617b5a7ffadbab0bedde3353ef392aecdf1a3d97 Mon Sep 17 00:00:00 2001 From: "Kim, Sungchul" Date: Wed, 14 Aug 2024 12:48:49 +0900 Subject: [PATCH] Update the default path of downloaded pretrained weights from pytorchcv (#3837) * Disable to download pretrained weights * Revert "Disable to download pretrained weights" This reverts commit c35b13df80c1e062d3c775752cab8da0d389b6f7. * Fix to use `models_cache_root` --- src/otx/algo/common/backbones/pytorchcv_backbones.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/otx/algo/common/backbones/pytorchcv_backbones.py b/src/otx/algo/common/backbones/pytorchcv_backbones.py index 3ed30d8e73d..4b11751080a 100644 --- a/src/otx/algo/common/backbones/pytorchcv_backbones.py +++ b/src/otx/algo/common/backbones/pytorchcv_backbones.py @@ -124,12 +124,10 @@ def _build_pytorchcv_model( **kwargs, ) -> nn.Module: """Build pytorchcv model.""" - models_cache_root = kwargs.get("root", Path.home() / ".cache" / "torch" / "hub" / "checkpoints") - is_pretrained = kwargs.get("pretrained", False) - print( - f"Init model {type}, pretrained={is_pretrained}, models cache {models_cache_root}", - ) - model = _models[type](**kwargs) + models_cache_root = kwargs.pop("root", Path.home() / ".cache" / "torch" / "hub" / "checkpoints") + pretrained = kwargs.pop("pretrained", False) + print(f"Init model {type}, pretrained={pretrained}, models cache {models_cache_root}") + model = _models[type](root=models_cache_root, pretrained=pretrained, **kwargs) if activation_callable: model = replace_activation(model, activation_callable) if norm_cfg: