From 29b57735399155154fa630a3187edc10a8eb4391 Mon Sep 17 00:00:00 2001 From: Burak <68427259+burakekim@users.noreply.github.com> Date: Mon, 3 Feb 2025 13:19:42 +0100 Subject: [PATCH] EuroCrops: handle Nones in get_label (#2499) * handle Nones in get_label * make ruff hapy * unit test for the win * remove print statement --- tests/datasets/test_eurocrops.py | 5 +++++ torchgeo/datasets/eurocrops.py | 3 +++ 2 files changed, 8 insertions(+) diff --git a/tests/datasets/test_eurocrops.py b/tests/datasets/test_eurocrops.py index 3b2d4fc63f7..63896e23090 100644 --- a/tests/datasets/test_eurocrops.py +++ b/tests/datasets/test_eurocrops.py @@ -82,6 +82,11 @@ def test_invalid_query(self, dataset: EuroCrops) -> None: ): dataset[query] + def test_get_label_with_none_hcat_code(self, dataset: EuroCrops) -> None: + mock_feature = {'properties': {dataset.label_name: None}} + label = dataset.get_label(mock_feature) + assert label == 0, "Expected label to be 0 when 'EC_hcat_c' is None." + def test_integrity_error(self, dataset: EuroCrops) -> None: dataset.zenodo_files = (('AA.zip', 'invalid'),) assert not dataset._check_integrity() diff --git a/torchgeo/datasets/eurocrops.py b/torchgeo/datasets/eurocrops.py index 5f438143c87..8832905c2d6 100644 --- a/torchgeo/datasets/eurocrops.py +++ b/torchgeo/datasets/eurocrops.py @@ -204,6 +204,9 @@ def get_label(self, feature: 'fiona.model.Feature') -> int: # We go up the class hierarchy until there is a match. # (Parent code is computed by replacing rightmost non-0 character with 0.) hcat_code = feature['properties'][self.label_name] + if hcat_code is None: + return 0 + while True: if hcat_code in self.class_map: return self.class_map[hcat_code]