Skip to content

Commit

Permalink
Merge pull request #107 from IBM/overwrite_default_prithvi
Browse files Browse the repository at this point in the history
Overwrite default prithvi
  • Loading branch information
Joao-L-S-Almeida authored Aug 20, 2024
2 parents ec8125e + 5cf5876 commit cf963a8
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion terratorch/models/backbones/prithvi_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
from functools import partial
from pathlib import Path
from collections import defaultdict

import torch
from timm.models import FeatureInfo
Expand Down Expand Up @@ -140,13 +141,22 @@ def create_prithvi_vit_100(
"norm_layer": partial(nn.LayerNorm, eps=1e-6),
"num_frames": 1,
}

# It is possible to overwrite default parameters using
# config file
kwargs_ = defaultdict()
kwargs_.update(model_args)
kwargs_.update(kwargs)
kwargs_ = dict(kwargs_)

model = _create_prithvi(
model_name,
pretrained=pretrained,
model_bands=bands,
pretrained_bands=pretrained_bands,
**dict(model_args, **kwargs),
**kwargs_,
)

return model


Expand Down

0 comments on commit cf963a8

Please sign in to comment.