diff --git a/terratorch/models/backbones/utils.py b/terratorch/models/backbones/utils.py index 8945fdfd..06e4f462 100644 --- a/terratorch/models/backbones/utils.py +++ b/terratorch/models/backbones/utils.py @@ -14,7 +14,7 @@ def _estimate_in_chans(model_bands: list[HLSBands] | list[str] | tuple[int, int] # Conditional to deal with the different possible choices for the bands # Bands as lists of strings or enum - if all([isinstance(b, str) for b in model_bands]): + if all([isinstance(b, str) or isinstance(b, HLSBands) for b in model_bands]): in_chans = len(model_bands) # Bands as intervals limited by integers elif all([isinstance(b, int) for b in model_bands] or _are_sublists_of_int(model_bands)):