From 5cf587631140d9075b7a53a89d736587ce9fd645 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Fri, 9 Aug 2024 16:08:00 -0300 Subject: [PATCH] Allowing the configuration to overwrite default arguments from the model constructors MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- terratorch/models/backbones/prithvi_vit.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/terratorch/models/backbones/prithvi_vit.py b/terratorch/models/backbones/prithvi_vit.py index a722d674..3ccaba5e 100644 --- a/terratorch/models/backbones/prithvi_vit.py +++ b/terratorch/models/backbones/prithvi_vit.py @@ -4,6 +4,7 @@ import logging from functools import partial from pathlib import Path +from collections import defaultdict import torch from timm.models import FeatureInfo @@ -140,13 +141,22 @@ def create_prithvi_vit_100( "norm_layer": partial(nn.LayerNorm, eps=1e-6), "num_frames": 1, } + + # It is possible to overwrite default parameters using + # config file + kwargs_ = defaultdict() + kwargs_.update(model_args) + kwargs_.update(kwargs) + kwargs_ = dict(kwargs_) + model = _create_prithvi( model_name, pretrained=pretrained, model_bands=bands, pretrained_bands=pretrained_bands, - **dict(model_args, **kwargs), + **kwargs_, ) + return model