
Commit 47f539f
Removed all local file paths; linted and formatted
jdollinger-bit committed May 28, 2024
1 parent f8c5b1a commit 47f539f
Showing 13 changed files with 612 additions and 306 deletions.
.gitignore (1 change: 1 addition & 0 deletions)
@@ -2,3 +2,4 @@ __pychache__
 .ipynb_checkpoints
 .gitignore
 venv
+glc23_data
config/local.yaml (12 changes: 6 additions & 6 deletions)
@@ -1,7 +1,7 @@
-sent_data_path: "/data/jdolli/glc23_data/SatelliteImages/"
-bioclim_path: "/shares/wegner.ics.uzh/glc23_data/bioclim+elev/bioclim_elevation_scaled_europe.npy"
-dataset_file_path: "/shares/wegner.ics.uzh/glc23_data/Pot_10_to_1000.csv"
-cp_dir_path: "/scratch/jdolli/sent-sinr/checkpoints"
-logs_dir_path: "/scratch/jdolli/sent-sinr/"
-test_data_path: "/shares/wegner.ics.uzh/glc23_data/Presence_Absence_surveys/Presences_Absences_train.csv"
+sent_data_path: "glc23_data/SatelliteImages/"
+bioclim_path: "glc23_data/bioclim+elev/bioclim_elevation_scaled_europe.npy"
+dataset_file_path: "glc23_data/Pot_10_to_1000.csv"
+cp_dir_path: "cps/"
+logs_dir_path: "logs/"
+test_data_path: "glc23_data/Presence_Absence_surveys/Presences_Absences_train.csv"
 gpu: True
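
Editor's note: the replacement paths are relative, so the config no longer ties the project to one machine; glc23_data is also newly git-ignored above. A minimal sketch of how these keys are read (assumption: base_config.yaml pulls config/local.yaml into a `local` group via Hydra's defaults list, which is implied by the params.local.* accesses in main.py but not shown in this commit):

import hydra
from omegaconf import DictConfig


@hydra.main(version_base=None, config_path="config", config_name="base_config.yaml")
def show_paths(params: DictConfig) -> None:
    # After this commit both values resolve relative to the working directory.
    print(params.local.sent_data_path)  # "glc23_data/SatelliteImages/"
    print(params.local.cp_dir_path)     # "cps/"


if __name__ == "__main__":
    show_paths()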
dataset.py (106 changes: 83 additions & 23 deletions)
@@ -16,10 +16,21 @@ def __init__(self, params, dataset_file, predictors, bioclim_path, sent_data_pat
         # test_data is not used by the dataset itself, but the model needs this object
         with open(params.local.test_data_path, "r") as f:
             data_test = pd.read_csv(f, sep=";", header="infer", low_memory=False)
-        self.test_data = data_test.groupby(["patchID", "dayOfYear", "lon", "lat"]).agg(
-            {"speciesId": lambda x: list(x)}).reset_index()
-        self.test_data = {str(entry["lon"]) + "/" + str(entry["lat"]) + "/" + str(entry["dayOfYear"]) + "/" + str(
-            entry["patchID"]): entry["speciesId"] for idx, entry in self.test_data.iterrows()}
+        self.test_data = (
+            data_test.groupby(["patchID", "dayOfYear", "lon", "lat"])
+            .agg({"speciesId": lambda x: list(x)})
+            .reset_index()
+        )
+        self.test_data = {
+            str(entry["lon"])
+            + "/"
+            + str(entry["lat"])
+            + "/"
+            + str(entry["dayOfYear"])
+            + "/"
+            + str(entry["patchID"]): entry["speciesId"]
+            for idx, entry in self.test_data.iterrows()
+        }

         self.predictors = predictors
         if "sent2" in predictors:
@@ -28,14 +39,18 @@ def __init__(self, params, dataset_file, predictors, bioclim_path, sent_data_pat
             # The raster we are loading is already cropped to Europe and normalized
             context_feats = np.load(bioclim_path).astype(np.float32)
             self.raster = torch.from_numpy(context_feats)
-            self.raster[torch.isnan(self.raster)] = 0.0  # replace with mean value (0 is mean post-normalization)
+            self.raster[torch.isnan(self.raster)] = (
+                0.0  # replace with mean value (0 is mean post-normalization)
+            )

         self.sent_data_path = sent_data_path

-        self.transforms = v2.Compose([
-            v2.RandomHorizontalFlip(p=0.5),
-            v2.RandomVerticalFlip(p=0.5),
-        ])
+        self.transforms = v2.Compose(
+            [
+                v2.RandomHorizontalFlip(p=0.5),
+                v2.RandomVerticalFlip(p=0.5),
+            ]
+        )

     def __len__(self):
         return len(self.data)
@@ -50,7 +65,12 @@ def _normalize_loc_to_uniform(self, lon, lat):

     def _encode_loc(self, lon, lat):
         """Expects lon and lat to be scale between [-1,1]"""
-        features = [np.sin(np.pi * lon), np.cos(np.pi * lon), np.sin(np.pi * lat), np.cos(np.pi * lat)]
+        features = [
+            np.sin(np.pi * lon),
+            np.cos(np.pi * lon),
+            np.sin(np.pi * lat),
+            np.cos(np.pi * lat),
+        ]
         return np.stack(features, axis=-1)

     def sample_encoded_locs(self, size):
@@ -61,7 +81,9 @@ def sample_encoded_locs(self, size):
         lat = lat * 2 - 1
         loc_enc = torch.tensor(self._encode_loc(lon, lat), dtype=torch.float32)
         if "env" in self.predictors:
-            env_enc = bilinear_interpolate(torch.stack([torch.tensor(lon), torch.tensor(lat)], dim=1), self.raster)
+            env_enc = bilinear_interpolate(
+                torch.stack([torch.tensor(lon), torch.tensor(lat)], dim=1), self.raster
+            )
             if "loc" in self.predictors:
                 return torch.cat([loc_enc, env_enc], dim=1).type("torch.FloatTensor")
             else:
@@ -82,7 +104,9 @@ def get_env_raster(self, lon, lat):
     def get_loc_env(self, lon, lat):
         """Given lon and lat, create the location and environmental embedding."""
         lon_norm, lat_norm = self._normalize_loc_to_uniform(lon, lat)
-        loc_enc = torch.tensor(self._encode_loc(lon_norm, lat_norm), dtype=torch.float32)
+        loc_enc = torch.tensor(
+            self._encode_loc(lon_norm, lat_norm), dtype=torch.float32
+        )
         env_enc = self.get_env_raster(lon, lat).type("torch.FloatTensor")
         return torch.cat((loc_enc, env_enc.view(20)))

@@ -108,8 +132,26 @@ def encode(self, lon, lat):

     def get_gbif_sent2(self, pid):
         """Get Sentinel-2 image for patch_id."""
-        rgb_path = self.sent_data_path + "rgb/" + str(pid)[-2:] + "/" + str(pid)[-4:-2] + "/" + str(pid) + ".jpeg"
-        nir_path = self.sent_data_path + "nir/" + str(pid)[-2:] + "/" + str(pid)[-4:-2] + "/" + str(pid) + ".jpeg"
+        rgb_path = (
+            self.sent_data_path
+            + "rgb/"
+            + str(pid)[-2:]
+            + "/"
+            + str(pid)[-4:-2]
+            + "/"
+            + str(pid)
+            + ".jpeg"
+        )
+        nir_path = (
+            self.sent_data_path
+            + "nir/"
+            + str(pid)[-2:]
+            + "/"
+            + str(pid)[-4:-2]
+            + "/"
+            + str(pid)
+            + ".jpeg"
+        )
         rgb = Image.open(rgb_path)
         nir = Image.open(nir_path)
         img = torch.concat([self.to_tensor(rgb), self.to_tensor(nir)], dim=0) / 255
@@ -121,21 +163,39 @@ def __getitem__(self, idx):
         data_dict = self.data.iloc[idx]
         lon, lat = tuple(data_dict[["lon", "lat"]].to_numpy())
         if "sent2" in self.predictors:
-            return self.encode(lon, lat), self.get_gbif_sent2(data_dict["patchID"]), torch.tensor(
-                data_dict["speciesId"])
+            return (
+                self.encode(lon, lat),
+                self.get_gbif_sent2(data_dict["patchID"]),
+                torch.tensor(data_dict["speciesId"]),
+            )
         else:
             return self.encode(lon, lat), torch.tensor(data_dict["speciesId"])


 def create_datasets(params):
     """Creates dataset and dataloaders from the various files"""
-    dataset_file = pd.read_csv(params.local.dataset_file_path, sep=";", header='infer', low_memory=False)
+    dataset_file = pd.read_csv(
+        params.local.dataset_file_path, sep=";", header="infer", low_memory=False
+    )
     bioclim_path = params.local.bioclim_path
-    dataset = SINR_DS(params, dataset_file, params.dataset.predictors, sent_data_path=params.local.sent_data_path,
-                      bioclim_path=bioclim_path)
+    dataset = SINR_DS(
+        params,
+        dataset_file,
+        params.dataset.predictors,
+        sent_data_path=params.local.sent_data_path,
+        bioclim_path=bioclim_path,
+    )
     ds_train, ds_val = torch.utils.data.random_split(dataset, [0.9, 0.1])
-    train_loader = torch.utils.data.DataLoader(ds_train, shuffle=True, batch_size=params.dataset.batchsize,
-                                               num_workers=params.dataset.num_workers)
-    val_loader = torch.utils.data.DataLoader(ds_val, shuffle=False, batch_size=params.dataset.batchsize,
-                                             num_workers=params.dataset.num_workers)
+    train_loader = torch.utils.data.DataLoader(
+        ds_train,
+        shuffle=True,
+        batch_size=params.dataset.batchsize,
+        num_workers=params.dataset.num_workers,
+    )
+    val_loader = torch.utils.data.DataLoader(
+        ds_val,
+        shuffle=False,
+        batch_size=params.dataset.batchsize,
+        num_workers=params.dataset.num_workers,
+    )
     return dataset, train_loader, val_loader
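
Editor's note: two of the reformatted expressions are easier to check with concrete values. The test-data dict is keyed by a "lon/lat/dayOfYear/patchID" string, and get_gbif_sent2 shards images into subdirectories taken from the last digits of the patch id. An illustration with invented values (only the key and path construction mirrors the code above):

# Invented example values; real speciesId lists come from the groupby above.
entry = {"lon": 3.5, "lat": 43.2, "dayOfYear": 151, "patchID": 123456}

# Key scheme used for self.test_data:
key = f'{entry["lon"]}/{entry["lat"]}/{entry["dayOfYear"]}/{entry["patchID"]}'
assert key == "3.5/43.2/151/123456"

# Directory scheme used by get_gbif_sent2: last two digits, then the two before.
pid = str(entry["patchID"])
rgb_path = f"glc23_data/SatelliteImages/rgb/{pid[-2:]}/{pid[-4:-2]}/{pid}.jpeg"
assert rgb_path == "glc23_data/SatelliteImages/rgb/56/34/123456.jpeg"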
embedders.py (5 changes: 2 additions & 3 deletions)
@@ -38,8 +38,7 @@ def __init__(self, layer_removed=1, hidden_dim=128):
         super().__init__()
         self.center_crop = torchvision.transforms.functional.center_crop
         self.layer_removed = layer_removed
-        layers = [torch.nn.Conv2d(4, 32, 4, 2, 1),
-                  torch.nn.ReLU()]
+        layers = [torch.nn.Conv2d(4, 32, 4, 2, 1), torch.nn.ReLU()]
         for i in range(layer_removed):
             layers.append(torch.nn.Conv2d(32, 32, 3, 1, 1))
             layers.append(torch.nn.ReLU())
@@ -71,7 +70,7 @@ def forward(self, tensor):


 def get_embedder(params):
-    if params.embedder == "ae_default":
+    if params.embedder == "cnn_default":
         return AE_DEFAULT()
     elif params.embedder.startswith("cnn_si"):
         return CNN_SMALLERINPUT(int(params.embedder[-1]))
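
Editor's note: beyond formatting, this hunk renames the dispatch value from "ae_default" to "cnn_default", so a config that still passes embedder=ae_default will no longer match this branch. For the cnn_si variants, the trailing character of the config string selects layer_removed:

# Illustration of the cnn_si parsing above; "cnn_si3" is a hypothetical value.
embedder = "cnn_si3"
assert embedder.startswith("cnn_si")
layer_removed = int(embedder[-1])  # 3, passed to CNN_SMALLERINPUT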
main.py (63 changes: 47 additions & 16 deletions)
@@ -28,18 +28,30 @@ def get_logger(params, tag=""):
         name += " val"
     name += " " + tag

-    logger = hydra.utils.instantiate({"_target_": "pytorch_lightning.loggers.WandbLogger",
-                                      "name": name,
-                                      "save_dir": params.local.logs_dir_path,
-                                      "project": "sinr_on_glc23"})
+    logger = hydra.utils.instantiate(
+        {
+            "_target_": "pytorch_lightning.loggers.WandbLogger",
+            "name": name,
+            "save_dir": params.local.logs_dir_path,
+            "project": "sinr_on_glc23",
+        }
+    )
     return logger


-def train_model(params, dataset, train_loader, val_loader, provide_model=None, logger=None, validate=False):
+def train_model(
+    params,
+    dataset,
+    train_loader,
+    val_loader,
+    provide_model=None,
+    logger=None,
+    validate=False,
+):
     """
     Instantiates model, defines which epoch to save as checkpoint, and trains
     """
-    torch.set_float32_matmul_precision('medium')
+    torch.set_float32_matmul_precision("medium")

     if not provide_model:
         if params.model == "sinr" or params.model == "log_reg":
@@ -54,11 +66,17 @@ def train_model(params, dataset, train_loader, val_loader, provide_model=None, l
         monitor="val_loss",
         mode="min",
         dirpath=params.local.cp_dir_path,
-        filename=logger._name + "{val_loss:.4f}"
+        filename=logger._name + "{val_loss:.4f}",
     )
-    trainer = pl.Trainer(max_epochs=params.epochs, accelerator=("gpu" if params.local.gpu else "cpu"), devices=1,
-                         precision="16-mixed", logger=logger, log_every_n_steps=50,
-                         callbacks=[checkpoint_callback])
+    trainer = pl.Trainer(
+        max_epochs=params.epochs,
+        accelerator=("gpu" if params.local.gpu else "cpu"),
+        devices=1,
+        precision="16-mixed",
+        logger=logger,
+        log_every_n_steps=50,
+        callbacks=[checkpoint_callback],
+    )
     if validate:
         trainer.validate(model=model, dataloaders=[val_loader])
     else:
@@ -68,22 +86,35 @@ def load_cp(params, dataset):
 def load_cp(params, dataset):
     """Loads checkpoint."""
     if params.model == "sinr" or params.model == "log_reg":
-        model = SINR.load_from_checkpoint(params.checkpoint, params=params, dataset=dataset)
+        model = SINR.load_from_checkpoint(
+            params.checkpoint, params=params, dataset=dataset
+        )
     elif "sat" in params.model:
-        model = SAT_SINR.load_from_checkpoint(params.checkpoint, params=params, dataset=dataset,
-                                              sent2_net=get_embedder(params))
+        model = SAT_SINR.load_from_checkpoint(
+            params.checkpoint,
+            params=params,
+            dataset=dataset,
+            sent2_net=get_embedder(params),
+        )
     return model


-@hydra.main(version_base=None, config_path='config', config_name='base_config.yaml')
+@hydra.main(version_base=None, config_path="config", config_name="base_config.yaml")
 def main(params):
     """main funct."""
     dataset, train_loader, val_loader = create_datasets(params)
     logger = get_logger(params, tag=params.tag)
     if params.checkpoint != "None":
         model = load_cp(params, dataset)
-        train_model(params, dataset, train_loader, val_loader, provide_model=model, logger=logger,
-                    validate=params.validate)
+        train_model(
+            params,
+            dataset,
+            train_loader,
+            val_loader,
+            provide_model=model,
+            logger=logger,
+            validate=params.validate,
+        )
     else:
         train_model(params, dataset, train_loader, val_loader, logger=logger)
     wandb.finish()
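
Editor's note: the dict passed to hydra.utils.instantiate in get_logger is plain _target_-style instantiation, equivalent to constructing the logger directly (sketch; the name value is invented):

from pytorch_lightning.loggers import WandbLogger

logger = WandbLogger(
    name="sinr loc_env my-tag",  # get_logger builds this from params
    save_dir="logs/",            # params.local.logs_dir_path after this commit
    project="sinr_on_glc23",
)

Also note that main() gates checkpoint loading on the literal string "None"; if the config ever supplied a YAML null instead, params.checkpoint would be Python None and the load_cp branch would still be taken.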
(Diffs for the remaining 8 of the 13 changed files did not load and are not shown.)
