Skip to content

Commit

Permalink
Update the default path of downloaded pretrained weights from pytorch…
Browse files Browse the repository at this point in the history
…cv (#3837)

* Disable to download pretrained weights

* Revert "Disable to download pretrained weights"

This reverts commit c35b13d.

* Fix to use `models_cache_root`
  • Loading branch information
sungchul2 authored Aug 14, 2024
1 parent 0b5ed3b commit 617b5a7
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions src/otx/algo/common/backbones/pytorchcv_backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,12 +124,10 @@ def _build_pytorchcv_model(
**kwargs,
) -> nn.Module:
"""Build pytorchcv model."""
models_cache_root = kwargs.get("root", Path.home() / ".cache" / "torch" / "hub" / "checkpoints")
is_pretrained = kwargs.get("pretrained", False)
print(
f"Init model {type}, pretrained={is_pretrained}, models cache {models_cache_root}",
)
model = _models[type](**kwargs)
models_cache_root = kwargs.pop("root", Path.home() / ".cache" / "torch" / "hub" / "checkpoints")
pretrained = kwargs.pop("pretrained", False)
print(f"Init model {type}, pretrained={pretrained}, models cache {models_cache_root}")
model = _models[type](root=models_cache_root, pretrained=pretrained, **kwargs)
if activation_callable:
model = replace_activation(model, activation_callable)
if norm_cfg:
Expand Down

0 comments on commit 617b5a7

Please sign in to comment.