Skip to content

Commit

Permalink
add label after transform
Browse files Browse the repository at this point in the history
Signed-off-by: Carlos Gomes <[email protected]>
  • Loading branch information
CarlosGomes98 committed Aug 13, 2024
1 parent 88b4231 commit 680963f
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 6 deletions.
4 changes: 2 additions & 2 deletions terratorch/datasets/m_bigearthnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,12 @@ def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
labels_tensor = torch.tensor(labels_vector, dtype=torch.float)

output = {
"image": image,
"label": labels_tensor
"image": image
}

output = self.transform(**output)

output["label"] = labels_tensor
return output

def _validate_bands(self, bands: Sequence[str]) -> None:
Expand Down
4 changes: 3 additions & 1 deletion terratorch/datasets/m_brick_kiln.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,12 @@ def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
attr_dict = pickle.loads(ast.literal_eval(h5file.attrs["pickle"]))
class_index = attr_dict["label"]

output = {"image": image.astype(np.float32), "label": class_index}
output = {"image": image.astype(np.float32)}

output = self.transform(**output)

output["label"] = class_index

return output

def __len__(self):
Expand Down
4 changes: 3 additions & 1 deletion terratorch/datasets/m_eurosat.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,12 @@ def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
label_class = self.id_to_class[image_id]
label_index = list(self.label_map.keys()).index(label_class)

output = {"image": image.astype(np.float32), "label": label_index}
output = {"image": image.astype(np.float32)}

output = self.transform(**output)

output["label"] = label_index

return output

def _validate_bands(self, bands: Sequence[str]) -> None:
Expand Down
4 changes: 3 additions & 1 deletion terratorch/datasets/m_forestnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,12 @@ def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
attr_dict = pickle.loads(ast.literal_eval(h5file.attrs["pickle"]))
class_index = attr_dict["label"]

output = {"image": image.astype(np.float32), "label": class_index}
output = {"image": image.astype(np.float32)}

output = self.transform(**output)

output["label"] = class_index

return output

def _validate_bands(self, bands: Sequence[str]) -> None:
Expand Down
4 changes: 3 additions & 1 deletion terratorch/datasets/m_pv4ger.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,12 @@ def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
attr_dict = pickle.loads(ast.literal_eval(h5file.attrs["pickle"]))
class_index = attr_dict["label"]

output = {"image": image.astype(np.float32), "label": class_index}
output = {"image": image.astype(np.float32)}

output = self.transform(**output)

output["label"] = class_index

return output

def _validate_bands(self, bands: Sequence[str]) -> None:
Expand Down

0 comments on commit 680963f

Please sign in to comment.