EO Development (draft pull request) #3

Open · wants to merge 52 commits into base: master

52 commits
36007ea
Merge remote-tracking branch 'origin/development' into development-eo
annajungbluth Apr 25, 2024
0c2b26f
started eo training development
annajungbluth Apr 25, 2024
9993839
started testing training pipeline
annajungbluth Apr 25, 2024
3330469
wip - tested training pipeline
annajungbluth Apr 25, 2024
9071b24
made training pipeline run
annajungbluth Apr 26, 2024
d7f2c9d
moved geo dataset and editors into ITI repo
annajungbluth Apr 26, 2024
b6983b2
added new editor and modified training script
annajungbluth Apr 27, 2024
2879551
modified editor
annajungbluth Apr 27, 2024
12ee80e
tested callbacks
annajungbluth Apr 28, 2024
cdc4c7c
added normalisation steps to training script, and started writing a n…
lillif May 13, 2024
92c26d6
training script for miniset set up, mean std normalisation finished
lillif May 14, 2024
a18e68a
normalisation script finished, attempted training
lillif May 17, 2024
9d592bb
started hydra training file
annajungbluth May 19, 2024
6d24344
added normalization and fixed bugs in training script
annajungbluth May 20, 2024
1ddd8d5
merge with master
annajungbluth Oct 3, 2024
d7b31d0
fixed small merge bugs and added autoroot file
annajungbluth Oct 3, 2024
ad54e8e
Added file with dataset information
annajungbluth Oct 4, 2024
72c3f82
fixed goes metrics file
annajungbluth Oct 18, 2024
4533b53
updated summary files
annajungbluth Oct 20, 2024
2e72ef8
added new normalization routine and started first experiment
annajungbluth Oct 20, 2024
7a79a1e
reduced val data
annajungbluth Oct 20, 2024
d3a4ad4
optimized dataloader to reduce memory consumption
annajungbluth Oct 24, 2024
e2c5675
added normalization files for subset of data
annajungbluth Oct 31, 2024
625852a
debugging constant channels
annajungbluth Oct 31, 2024
8770404
started new experiment
annajungbluth Oct 31, 2024
5b98ee1
added seed to training script
annajungbluth Nov 1, 2024
e586a44
added min max normalizer
annajungbluth Nov 18, 2024
3a9c243
updated center weighted cropping routine for better cropping
annajungbluth Nov 22, 2024
94ea2cd
removed file with large missing band
annajungbluth Nov 25, 2024
943fe17
started miniset experiment
annajungbluth Nov 25, 2024
520415b
removed files with pixel strips missing
annajungbluth Nov 25, 2024
69f5f68
removed files with missing or half missing channels
annajungbluth Nov 25, 2024
e61dca0
added files for miniset and removed files with artifacts
annajungbluth Nov 25, 2024
cc5031e
fixed subser summary file
annajungbluth Nov 27, 2024
2c17c29
updated config
annajungbluth Nov 27, 2024
479bb93
fixed mistake in notebook
annajungbluth Nov 27, 2024
eea208d
added script for goes-to-msg translation
annajungbluth Dec 6, 2024
2a17687
updated miniset
annajungbluth Dec 15, 2024
4b3caf0
added separation of A and B patch size
annajungbluth Dec 15, 2024
dad7444
modified config for miniset experiment
annajungbluth Dec 15, 2024
e6af551
experiment modifications
annajungbluth Dec 15, 2024
b4d3b62
merged with changes adding rotation transform
annajungbluth Dec 15, 2024
1fa9544
modified dataloader to handle no cropping
annajungbluth Dec 17, 2024
5eb95f8
updated config for experimentation with miniset
annajungbluth Dec 17, 2024
76dfecb
fixed problem with normalization and fixed experiment
annajungbluth Dec 19, 2024
4e37c53
added storage dataset
annajungbluth Dec 19, 2024
1bc1c8c
started new experiment for all infrared channels
annajungbluth Dec 21, 2024
f0ef3da
started new experiment with visible channels
annajungbluth Dec 21, 2024
399bfa6
started new experiment with visible channels
annajungbluth Dec 21, 2024
6a6641e
deleted testing code
annajungbluth Dec 21, 2024
5a5482c
updated config for new experiment
annajungbluth Dec 23, 2024
2d35d61
merged configs
annajungbluth Dec 23, 2024
Empty file added .project-root
Empty file.
54 changes: 54 additions & 0 deletions config/example-hydra-config/data.yaml
@@ -0,0 +1,54 @@
A_data:
  A_path: null
  A_train_dataset:
    _target_: iti.data.geo_datasets.GeoDataset # TODO: make specific msg dataset?
    data_dir: null
    editors: null # TODO: hard code in dataset?
    splits_dict:
      train:
        years: [2020]
        months: [10]
        days: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
    load_coords: False
    load_cloudmask: False
  A_val_dataset:
    _target_: iti.data.geo_datasets.GeoDataset # TODO: make specific msg dataset?
    data_dir: null
    editors: null # TODO: hard code in dataset?
    splits_dict:
      train:
        years: [2020]
        months: [10]
        days: [20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]
    load_coords: False
    load_cloudmask: False
  A_plot_settings: null

B_data:
  B_path: null
  B_train_dataset:
    _target_: iti.data.geo_datasets.GeoDataset # TODO: make specific goes dataset?
    data_dir: null
    editors: null # TODO: hard code in dataset?
    splits_dict:
      train:
        years: [2020]
        months: [10]
        days: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
    load_coords: False
    load_cloudmask: False
  B_val_dataset:
    _target_: iti.data.geo_datasets.GeoDataset # TODO: make specific goes dataset?
    data_dir: null
    editors: null # TODO: hard code in dataset?
    splits_dict:
      train:
        years: [2020]
        months: [10]
        days: [20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]
    load_coords: False
    load_cloudmask: False
  B_plot_settings: null

num_workers: 4
iterations_per_epoch: 1000
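For context, a minimal sketch of how a Hydra config like this is typically consumed (the `main` wrapper and script layout below are illustrative assumptions, not code from this PR; the null `data_dir` entries would need to be filled in first):

import hydra
from omegaconf import DictConfig

@hydra.main(config_path="config/example-hydra-config", config_name="data", version_base=None)
def main(cfg: DictConfig) -> None:
    # `_target_` names iti.data.geo_datasets.GeoDataset, so instantiate()
    # builds the dataset and forwards data_dir, editors, splits_dict, etc.
    train_a = hydra.utils.instantiate(cfg.A_data.A_train_dataset)
    val_a = hydra.utils.instantiate(cfg.A_data.A_val_dataset)
    print(len(train_a), len(val_a))

if __name__ == "__main__":
    main()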
9 changes: 9 additions & 0 deletions config/example-hydra-config/model.yaml
@@ -0,0 +1,9 @@
model:
  _target_: null
  input_dim_a: 11
  input_dim_b: 16
  upsampling: 0
  discriminator_mode: CHANNELS
  lambda_diversity: 0
  norm: 'none'
  use_batch_statistic: False
1 change: 1 addition & 0 deletions config/example-hydra-config/train.yaml
@@ -0,0 +1 @@
base_dir: /home/freischem/outputs/miniset/
6 changes: 6 additions & 0 deletions config/example-hydra-config/wandb.yaml
@@ -0,0 +1,6 @@
experiment_name: null
tags: null
wandb_entity: null
wandb_project: null
wandb_name: null
wandb_id: null
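These keys map directly onto the standard wandb.init arguments; the snippet below is an assumption about how the training script consumes them, not code from this PR:

import yaml
import wandb

with open("config/example-hydra-config/wandb.yaml") as f:
    cfg = yaml.safe_load(f)

run = wandb.init(
    entity=cfg["wandb_entity"],
    project=cfg["wandb_project"],
    name=cfg["wandb_name"],
    id=cfg["wandb_id"],
    tags=cfg["tags"],
    resume="allow",  # a fixed wandb_id lets an interrupted run resume
)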
28 changes: 28 additions & 0 deletions config/msg_to_goes.yaml
@@ -0,0 +1,28 @@
base_dir: /home/anna.jungbluth/outputs/msg-to-goes/
data:
  A_path: /mnt/disks/eo-data/msg/
  converted_A_path: /home/anna.jungbluth/tmp-data/msg/
  B_path: /mnt/disks/eo-data/goes/
  converted_B_path: /home/anna.jungbluth/tmp-data/goes/
  num_workers: 4
  iterations_per_epoch: 1000
  A_patch_size: (1024, 1024) # Larger patches are saved for accelerated training.
  B_patch_size: (1024, 1024) # Patches are further cropped to (256, 256) before training.
  A_bands: [6.25, 7.35, 8.7, 9.66, 10.8, 12.0, 13.4] # [0.64, 0.81, 1.64, 3.92, 6.25, 7.35, 8.7, 9.66, 10.8, 12.0, 13.4]
  B_bands: [6.17, 6.93, 7.34, 8.44, 9.61, 10.33, 11.19, 12.27, 13.27] # [0.47, 0.64, 0.87, 1.38, 1.61, 2.25, 3.89, 6.17, 6.93, 7.34, 8.44, 9.61, 10.33, 11.19, 12.27, 13.27]
model:
  input_dim_a: 7
  input_dim_b: 9
  upsampling: 0
  discriminator_mode: CHANNELS
  lambda_diversity: 0
  norm: 'in_rs_aff'
  use_batch_statistic: False
logging:
  wandb_entity: itieo
  wandb_project: msg-to-goes
  wandb_name: MSG_to_GOES-infrared-7bands (6.25 - 13.4 um)
training:
  epochs: 200
  limit_train_batches: null
normalization: # TODO: Change to avoid absolute paths
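Note that YAML has no tuple literal, so `A_patch_size: (1024, 1024)` loads as the string "(1024, 1024)". One way a training script might convert it back (the helper below is a hypothetical sketch, not necessarily how this repo does it):

import ast

def parse_patch_size(value):
    """Convert a YAML patch-size entry to a tuple of ints, or None."""
    if value is None:
        return None
    if isinstance(value, str):
        # ast.literal_eval safely parses "(1024, 1024)" into a tuple
        return tuple(ast.literal_eval(value))
    return tuple(value)

assert parse_patch_size("(1024, 1024)") == (1024, 1024)
assert parse_patch_size(None) is None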
29 changes: 29 additions & 0 deletions config/msg_to_goes_miniset.yaml
@@ -0,0 +1,29 @@
base_dir: /home/anna.jungbluth/outputs/msg-to-goes-miniset/
data:
  A_path: /home/anna.jungbluth/data-miniset/msg/
  B_path: /home/anna.jungbluth/data-miniset/goes/
  num_workers: 4
  iterations_per_epoch: 500
  A_patch_size: null # Already cropped to 150 x 150
  B_patch_size: null # Already cropped to 450 x 450
  A_bands: [10.8] # [0.64, 0.81, 1.64, 3.92, 6.25, 7.35, 8.7, 9.66, 10.8, 12.0, 13.4]
  B_bands: [10.33] # [0.47, 0.64, 0.87, 1.38, 1.61, 2.25, 3.89, 6.17, 6.93, 7.34, 8.44, 9.61, 10.33, 11.19, 12.27, 13.27]
model:
  input_dim_a: 1
  input_dim_b: 1
  upsampling: 1 # one upscaling of x3
  discriminator_mode: SINGLE
  lambda_diversity: 0
  norm: 'in_aff'
  use_batch_statistic: False
logging:
  wandb_entity: itieo
  wandb_project: msg-to-goes
  wandb_name: "[miniset]-MSG_to_GOES-infrared-upsampled (10.8 um) ['in_aff' norm]"
training:
  epochs: 200
  limit_train_batches: null
  limit_val_batches: null
normalization: # TODO: Change to avoid absolute paths
  A_norm_dir: /home/anna.jungbluth/InstrumentToInstrument/dataset/tmp/msg_2020_miniset.csv
  B_norm_dir: /home/anna.jungbluth/InstrumentToInstrument/dataset/tmp/goes_2020_miniset.csv
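The band selections must stay in sync with the model input dims (here one 10.8 um MSG channel is translated to one 10.33 um GOES channel). A small hypothetical sanity check, not part of the PR, makes the invariant explicit:

import yaml

# The number of selected bands must match the generator input dims.
with open("config/msg_to_goes_miniset.yaml") as f:
    cfg = yaml.safe_load(f)

assert len(cfg["data"]["A_bands"]) == cfg["model"]["input_dim_a"]
assert len(cfg["data"]["B_bands"]) == cfg["model"]["input_dim_b"]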
8,689 changes: 8,689 additions & 0 deletions dataset/goes_2020_hourly.csv

Large diffs are not rendered by default.

4,187 changes: 4,187 additions & 0 deletions dataset/goes_2020_hourly_subset.csv

Large diffs are not rendered by default.

8,707 changes: 8,707 additions & 0 deletions dataset/msg_2020_hourly.csv

Large diffs are not rendered by default.

4,350 changes: 4,350 additions & 0 deletions dataset/msg_2020_hourly_subset.csv

Large diffs are not rendered by default.

67 changes: 67 additions & 0 deletions dataset/tmp/goes_2020_miniset.csv

Large diffs are not rendered by default.

70 changes: 70 additions & 0 deletions dataset/tmp/msg_2020_miniset.csv

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions itipy/callback.py
@@ -102,7 +102,7 @@ def __init__(self, data, model, plot_settings_A=None, plot_settings_B=None, plot

plot_settings = [*plot_settings_A, *plot_settings_B, *plot_settings_A]

-        super().__init__(data, model, path, plot_id, plot_settings, **kwargs)
+        super().__init__(data, model, plot_id, plot_settings, **kwargs)

def predict(self, x):
x_ab, x_aba = self.model.forwardABA(x)
@@ -138,7 +138,7 @@ def __init__(self, data, model, plot_settings_A=None, plot_settings_B=None, plot

plot_settings = [*plot_settings_B, *plot_settings_A, *plot_settings_B]

-        super().__init__(data, model, path, plot_id, plot_settings, **kwargs)
+        super().__init__(data, model, plot_id, plot_settings, **kwargs)

def predict(self, x):
x_ba, x_bab = self.model.forwardBAB(x)
@@ -169,7 +169,7 @@ def __init__(self, data, model, plot_settings_A=None, plot_settings_B=None, plot

plot_settings = [*plot_settings_A, *plot_settings_B]

-        super().__init__(data, model, path, plot_id, plot_settings, **kwargs)
+        super().__init__(data, model, plot_id, plot_settings, **kwargs)

def predict(self, input_data):
x_ab = self.model.forwardAB(input_data)
142 changes: 142 additions & 0 deletions itipy/data/geo_datasets.py
@@ -0,0 +1,142 @@
from __future__ import annotations
import collections
import collections.abc

# hyper requires the following four aliases to be set manually.
collections.Iterable = collections.abc.Iterable
collections.Mapping = collections.abc.Mapping
collections.MutableSet = collections.abc.MutableSet
collections.MutableMapping = collections.abc.MutableMapping

import logging
import time
import torch
import numpy as np
import xarray as xr
from typing import List, Union, Dict
from loguru import logger

from itipy.data.editor import Editor
from itipy.data.geo_editor import CenterWeightedCropDatasetEditor
from itipy.data.dataset import BaseDataset
from itipy.data.geo_utils import get_split, get_list_filenames, _check_any_constant_channels

class GeoDataset(BaseDataset):
    def __init__(
        self,
        data_dir: List[str],
        splits_dict: Dict,
        editors: List[Editor] = None,
        ext: str = "nc",
        limit: int = None,
        fov_radius: float = 0.6,
        load_coords: bool = True,
        load_cloudmask: bool = True,
        patch_size: tuple[int, int] = (256, 256),
        **kwargs
    ):
        """
        Initialize the GeoDataset class.

        Args:
            data_dir (List[str]): A list of directories containing the data files.
            splits_dict (Dict): A dictionary specifying the train/val splits for the dataset.
            editors (List[Editor], optional): A list of editors for data preprocessing. Defaults to None.
            ext (str, optional): The file extension of the data files. Defaults to "nc".
            limit (int, optional): The maximum number of files to load. Defaults to None.
            fov_radius (float, optional): The radius of the field of view. Defaults to 0.6.
            load_coords (bool, optional): Whether to load the coordinates. Defaults to True.
            load_cloudmask (bool, optional): Whether to load the cloud mask. Defaults to True.
            patch_size (tuple[int, int], optional): The size of the patches to crop. Defaults to (256, 256).
            **kwargs: Additional keyword arguments.
        """
        self.data_dir = data_dir
        self.editors = editors
        self.splits_dict = splits_dict
        self.ext = ext
        self.limit = limit
        self.fov_radius = fov_radius
        self.load_coords = load_coords
        self.load_cloudmask = load_cloudmask
        self.patch_size = patch_size

        self.files = self.get_files()

        self.crop = CenterWeightedCropDatasetEditor(patch_shape=self.patch_size, fov_radius=self.fov_radius)

        super().__init__(
            data=self.files,
            editors=self.editors,
            ext=self.ext,
            limit=self.limit,
            **kwargs
        )

    def get_files(self):
        # Collect filenames from data_dir, then filter them by the split criteria
        files = get_list_filenames(data_path=self.data_dir, ext=self.ext)
        files = get_split(files=files, split_dict=self.splits_dict)
        return files

    def __len__(self):
        return len(self.files)

    def getIndex(self, data_dict, idx):
        # Attempt to apply the editors; log and re-raise on failure
        try:
            return self.convertData(data_dict)
        except Exception as ex:
            logging.error('Unable to convert %s: %s', self.files[idx], ex)
            raise ex

    def __getitem__(self, idx):
        data_dict = {}

        ds: xr.Dataset = xr.load_dataset(self.files[idx], engine="netcdf4")
        if self.patch_size is not None:
            ds, xmin, ymin = self.crop(ds)
        else:
            xmin, ymin = 0, 0  # Set to 0 if no cropping is done

        # Extract radiances; delete the local reference to reduce memory usage
        data = ds.Rad.compute().to_numpy()
        data_dict["data"] = data
        del data

        # Extract wavelengths
        wavelengths = ds.band_wavelength.compute().to_numpy()
        data_dict["wavelengths"] = wavelengths
        del wavelengths

        # Extract coordinates
        if self.load_coords:
            latitude = ds.latitude.compute().to_numpy()
            longitude = ds.longitude.compute().to_numpy()
            coords = np.stack([latitude, longitude], axis=0)
            data_dict["coords"] = coords
            del latitude, longitude, coords

        # Extract cloud mask
        if self.load_cloudmask:
            cloud_mask = ds.cloud_mask.compute().to_numpy()
            data_dict["cloud_mask"] = cloud_mask
            del cloud_mask

        # Delete dataset to reduce memory usage
        del ds

        if self.editors is not None:
            # Apply editors
            data, _ = self.getIndex(data_dict, idx)

            # Warn about constant (zero-variance) channels in the cropped patch
            if np.any(np.nanstd(data, axis=(1, 2)) == 0):
                logger.warning(
                    f"Constant channel in patch (file: {self.files[idx]}, patch x/y: {xmin}/{ymin})"
                )
            return data
        else:
            return data_dict
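A usage sketch of the new dataset class (the path and split values below are assumptions modeled on the example Hydra config, not code from this PR):

from itipy.data.geo_datasets import GeoDataset

dataset = GeoDataset(
    data_dir=["/mnt/disks/eo-data/msg/"],  # hypothetical data directory
    splits_dict={"train": {"years": [2020], "months": [10], "days": list(range(20))}},
    editors=None,  # with no editors, __getitem__ returns the raw data_dict
    load_coords=False,
    load_cloudmask=False,
    patch_size=(256, 256),
)

sample = dataset[0]
print(sample["data"].shape, sample["wavelengths"])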


