From 9c611d3f293476b7335f1456f7709e6d73b166d1 Mon Sep 17 00:00:00 2001 From: annajungbluth Date: Sat, 27 Apr 2024 12:09:55 +0000 Subject: [PATCH] added new editor --- .../multi-sat/2.0-pipeline-dataloader.ipynb | 221 ++++++------------ rs_tools/_src/datamodule/datasets.py | 2 +- rs_tools/_src/datamodule/editor.py | 38 ++- rs_tools/_src/datamodule/utils.py | 1 + 4 files changed, 106 insertions(+), 156 deletions(-) diff --git a/notebooks/dev/multi-sat/2.0-pipeline-dataloader.ipynb b/notebooks/dev/multi-sat/2.0-pipeline-dataloader.ipynb index 84957ea..16f099d 100644 --- a/notebooks/dev/multi-sat/2.0-pipeline-dataloader.ipynb +++ b/notebooks/dev/multi-sat/2.0-pipeline-dataloader.ipynb @@ -9,10 +9,23 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 18, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n" + ] + } + ], "source": [ + "import sys\n", + "# TODO: Fix path\n", + "sys.path.append(\"/home/anna.jungbluth/rs_tools/\")\n", + "sys.path.append(\"/home/anna.jungbluth/InstrumentToInstrument/\")\n", "from typing import Optional, List, Union, Tuple, Dict\n", "from omegaconf import DictConfig\n", "from datetime import datetime\n", @@ -113,7 +126,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -122,16 +135,16 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ - "filenames = get_list_filenames('/Users/anna.jungbluth/Desktop/git/rs_tools/data/msg/analysis', 'nc')" + "filenames = get_list_filenames('/home/anna.jungbluth/data/msg/analysis', 'nc')" ] }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -140,7 +153,7 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -509,7 +522,7 @@ " stroke: currentColor;\n", " fill: currentColor;\n", "}\n", - "
<xarray.DataArray 'cloud_mask' (y: 256, x: 256)>\n",
+       "
<xarray.DataArray 'cloud_mask' (y: 256, x: 256)> Size: 524kB\n",
        "array([[3., 3., 3., ..., 2., 1., 1.],\n",
        "       [3., 3., 3., ..., 2., 2., 1.],\n",
        "       [3., 3., 3., ..., 2., 1., 1.],\n",
@@ -518,27 +531,27 @@
        "       [2., 2., 2., ..., 2., 2., 2.],\n",
        "       [2., 2., 2., ..., 1., 2., 2.]])\n",
        "Coordinates:\n",
-       "  * y           (y) float64 -1.449e+06 -1.446e+06 ... -6.871e+05 -6.841e+05\n",
-       "  * x           (x) float64 -5.137e+06 -5.134e+06 ... -4.375e+06 -4.372e+06\n",
-       "    cloud_mask  (y, x) float64 3.0 3.0 3.0 3.0 3.0 3.0 ... 2.0 1.0 1.0 2.0 2.0\n",
-       "    latitude    (y, x) float64 -15.0 -14.99 -14.98 -14.98 ... -6.611 -6.61 -6.61\n",
-       "    longitude   (y, x) float64 -69.95 -69.79 -69.63 ... -47.28 -47.23 -47.19\n",
+       "  * y           (y) float64 2kB -1.449e+06 -1.446e+06 ... -6.871e+05 -6.841e+05\n",
+       "  * x           (x) float64 2kB -5.137e+06 -5.134e+06 ... -4.375e+06 -4.372e+06\n",
+       "    cloud_mask  (y, x) float64 524kB 3.0 3.0 3.0 3.0 3.0 ... 2.0 1.0 1.0 2.0 2.0\n",
+       "    latitude    (y, x) float64 524kB -15.0 -14.99 -14.98 ... -6.611 -6.61 -6.61\n",
+       "    longitude   (y, x) float64 524kB -69.95 -69.79 -69.63 ... -47.23 -47.19\n",
        "Attributes:\n",
-       "    grid_mapping:  msg_seviri_fes_3km
  • grid_mapping :
    msg_seviri_fes_3km
  • " ], "text/plain": [ - "\n", + " Size: 524kB\n", "array([[3., 3., 3., ..., 2., 1., 1.],\n", " [3., 3., 3., ..., 2., 2., 1.],\n", " [3., 3., 3., ..., 2., 1., 1.],\n", @@ -592,16 +605,16 @@ " [2., 2., 2., ..., 2., 2., 2.],\n", " [2., 2., 2., ..., 1., 2., 2.]])\n", "Coordinates:\n", - " * y (y) float64 -1.449e+06 -1.446e+06 ... -6.871e+05 -6.841e+05\n", - " * x (x) float64 -5.137e+06 -5.134e+06 ... -4.375e+06 -4.372e+06\n", - " cloud_mask (y, x) float64 3.0 3.0 3.0 3.0 3.0 3.0 ... 2.0 1.0 1.0 2.0 2.0\n", - " latitude (y, x) float64 -15.0 -14.99 -14.98 -14.98 ... -6.611 -6.61 -6.61\n", - " longitude (y, x) float64 -69.95 -69.79 -69.63 ... -47.28 -47.23 -47.19\n", + " * y (y) float64 2kB -1.449e+06 -1.446e+06 ... -6.871e+05 -6.841e+05\n", + " * x (x) float64 2kB -5.137e+06 -5.134e+06 ... -4.375e+06 -4.372e+06\n", + " cloud_mask (y, x) float64 524kB 3.0 3.0 3.0 3.0 3.0 ... 2.0 1.0 1.0 2.0 2.0\n", + " latitude (y, x) float64 524kB -15.0 -14.99 -14.98 ... -6.611 -6.61 -6.61\n", + " longitude (y, x) float64 524kB -69.95 -69.79 -69.63 ... -47.23 -47.19\n", "Attributes:\n", " grid_mapping: msg_seviri_fes_3km" ] }, - "execution_count": 40, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -612,7 +625,7 @@ }, { "cell_type": "code", - "execution_count": 48, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -621,7 +634,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -631,7 +644,7 @@ " 6.25, 7.35])" ] }, - "execution_count": 33, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -650,7 +663,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 24, "metadata": {}, "outputs": [], "source": [ @@ -659,7 +672,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 25, "metadata": {}, "outputs": [], "source": [ @@ -671,13 +684,13 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 34, "metadata": {}, "outputs": [], "source": [ "geo_train = GeoDataset(\n", - " data_dir='/Users/anna.jungbluth/Desktop/git/rs_tools/data/msg/analysis',\n", - " editors=None,\n", + " data_dir='/home/anna.jungbluth/data/msg/analysis',\n", + " editors=[NanDictEditor(key=\"data\")],\n", " splits_dict=splits_dict['train'],\n", " load_coords=True,\n", " load_cloudmask=True,\n", @@ -686,7 +699,27 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": 35, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "108" + ] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(geo_train)" + ] + }, + { + "cell_type": "code", + "execution_count": 47, "metadata": {}, "outputs": [ { @@ -696,7 +729,7 @@ " 6.25, 7.35])" ] }, - "execution_count": 42, + "execution_count": 47, "metadata": {}, "output_type": "execute_result" } @@ -714,12 +747,12 @@ }, { "cell_type": "code", - "execution_count": 72, + "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "from rs_tools._src.datamodule.editor import BandSelectionEditor, NanMaskEditor, CoordNormEditor, NanDictEditor, RadUnitEditor, ToTensorEditor, StackDictEditor\n", - "from iti.data.editor import NanEditor, RandomPatchEditor, BrightestPixelPatchEditor\n", + "from iti.data.editor import NanEditor, RandomPatchEditor, BrightestPixelPatchEditor, Editor\n", "from rs_tools._src.geoprocessing.units import convert_units\n", "from rs_tools import MODIS_WAVELENGTHS, GOES_WAVELENGTHS, MSG_WAVELENGTHS" ] @@ -758,110 +791,6 @@ ")" ] }, - { - "cell_type": "code", - "execution_count": 75, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[[ 3.9061e+01, 3.9061e+01, 3.9061e+01, ..., 7.3757e+02,\n", - " 6.2728e+02, 6.6404e+02],\n", - " [ 3.9061e+01, 3.9061e+01, 3.9061e+01, ..., 6.6404e+02,\n", - " 6.6404e+02, 6.6404e+02],\n", - " [ 3.9061e+01, 3.9061e+01, 3.9061e+01, ..., 6.6404e+02,\n", - " 6.6404e+02, 6.2728e+02],\n", - " ...,\n", - " [ 7.5825e+01, 7.5825e+01, 7.5825e+01, ..., 9.2139e+02,\n", - " 8.4786e+02, 8.1110e+02],\n", - " [ 7.5825e+01, 7.5825e+01, 7.5825e+01, ..., 8.1110e+02,\n", - " 8.4786e+02, 8.8462e+02],\n", - " [ 7.5825e+01, 3.9061e+01, 7.5825e+01, ..., 8.8462e+02,\n", - " 9.2139e+02, 9.9491e+02]],\n", - "\n", - " [[ 2.1066e+01, 1.2392e+00, 1.2392e+00, ..., 1.7968e+02,\n", - " 1.5985e+02, 1.5985e+02],\n", - " [ 2.1066e+01, 2.1066e+01, 1.2392e+00, ..., 1.5985e+02,\n", - " 1.5985e+02, 1.5985e+02],\n", - " [ 2.1066e+01, 1.2392e+00, 1.2392e+00, ..., 1.4003e+02,\n", - " 1.7968e+02, 1.4003e+02],\n", - " ...,\n", - " [ 2.1066e+01, 2.1066e+01, 2.1066e+01, ..., 2.5899e+02,\n", - " 1.5985e+02, 1.7968e+02],\n", - " [ 2.1066e+01, 4.0893e+01, 2.1066e+01, ..., 1.5985e+02,\n", - " 1.7968e+02, 1.9951e+02],\n", - " [ 2.1066e+01, 2.1066e+01, 1.2392e+00, ..., 1.9951e+02,\n", - " 2.1933e+02, 2.5899e+02]],\n", - "\n", - " [[ 6.7058e-01, 6.7058e-01, 6.7058e-01, ..., 6.5047e+01,\n", - " 4.3588e+01, 6.5047e+01],\n", - " [ 6.7058e-01, 6.7058e-01, 6.7058e-01, ..., 5.4317e+01,\n", - " 5.4317e+01, 4.3588e+01],\n", - " [ 6.7058e-01, 6.7058e-01, 6.7058e-01, ..., 3.2859e+01,\n", - " 4.3588e+01, 3.2859e+01],\n", - " ...,\n", - " [ 1.1400e+01, 1.1400e+01, 1.1400e+01, ..., 7.5776e+01,\n", - " 4.3588e+01, 3.2859e+01],\n", - " [ 6.7058e-01, 6.7058e-01, 6.7058e-01, ..., 3.2859e+01,\n", - " 4.3588e+01, 4.3588e+01],\n", - " [ 1.1400e+01, 6.7058e-01, 6.7058e-01, ..., 3.2859e+01,\n", - " 8.6506e+01, 9.7235e+01]],\n", - "\n", - " ...,\n", - "\n", - " [[ 1.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [ 1.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [ 1.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [ 1.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [ 1.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [ 1.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[-1.6655e-01, -1.6654e-01, -1.6654e-01, ..., -1.6461e-01,\n", - " -1.6461e-01, -1.6460e-01],\n", - " [-1.6644e-01, -1.6643e-01, -1.6642e-01, ..., -1.6450e-01,\n", - " -1.6449e-01, -1.6449e-01],\n", - " [-1.6632e-01, -1.6631e-01, -1.6631e-01, ..., -1.6439e-01,\n", - " -1.6438e-01, -1.6437e-01],\n", - " ...,\n", - " [-1.3769e-01, -1.3769e-01, -1.3768e-01, ..., -1.3618e-01,\n", - " -1.3618e-01, -1.3617e-01],\n", - " [-1.3758e-01, -1.3757e-01, -1.3757e-01, ..., -1.3607e-01,\n", - " -1.3607e-01, -1.3606e-01],\n", - " [-1.3747e-01, -1.3746e-01, -1.3745e-01, ..., -1.3596e-01,\n", - " -1.3596e-01, -1.3595e-01]],\n", - "\n", - " [[-7.2208e-01, -7.2196e-01, -7.2184e-01, ..., -6.9359e-01,\n", - " -6.9348e-01, -6.9338e-01],\n", - " [-7.2205e-01, -7.2193e-01, -7.2181e-01, ..., -6.9356e-01,\n", - " -6.9346e-01, -6.9336e-01],\n", - " [-7.2202e-01, -7.2190e-01, -7.2178e-01, ..., -6.9354e-01,\n", - " -6.9344e-01, -6.9333e-01],\n", - " ...,\n", - " [-7.1541e-01, -7.1530e-01, -7.1518e-01, ..., -6.8835e-01,\n", - " -6.8825e-01, -6.8815e-01],\n", - " [-7.1539e-01, -7.1527e-01, -7.1516e-01, ..., -6.8833e-01,\n", - " -6.8823e-01, -6.8813e-01],\n", - " [-7.1537e-01, -7.1525e-01, -7.1514e-01, ..., -6.8831e-01,\n", - " -6.8821e-01, -6.8811e-01]]])" - ] - }, - "execution_count": 75, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "geo_train[0]" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -946,7 +875,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.8" + "version": "3.10.14" } }, "nbformat": 4, diff --git a/rs_tools/_src/datamodule/datasets.py b/rs_tools/_src/datamodule/datasets.py index 277a61a..b739c04 100644 --- a/rs_tools/_src/datamodule/datasets.py +++ b/rs_tools/_src/datamodule/datasets.py @@ -19,7 +19,7 @@ from rs_tools._src.utils.io import get_list_filenames from rs_tools._src.datamodule.utils import get_split -# TODO: To be moved into ITI repo +# NOTE: Code already moved to ITI repo class GeoDataset(BaseDataset): def __init__( self, diff --git a/rs_tools/_src/datamodule/editor.py b/rs_tools/_src/datamodule/editor.py index de4eb0e..771f424 100644 --- a/rs_tools/_src/datamodule/editor.py +++ b/rs_tools/_src/datamodule/editor.py @@ -15,15 +15,36 @@ ToTensor, ) -# Editors that already exist (but not for dictionaries) -# - NormalizeEditor / ImageNormalizeEditor] +# NOTE: Code already moved to ITI repo -# TODO: Check if this is still needed class BandOrderEditor(Editor): - def call(self, data, **kwargs): - raise NotImplementedError + """ + Reorders bands in data dictionary. + """ + + def __init__(self, target_order, key="data"): + """ + Args: + target_order (list): Order of bands + key (str): Key in dictionary to apply transformation + """ + self.target_order = target_order + self.key = key + + def call(self, data_dict, **kwargs): + source_order = data_dict["wavelengths"] + assert len(source_order) == len(self.target_order), "Length of source and target wavelengths must match." + # Get indexes of bands to select + indexes = [np.where(source_order == wvl)[0][0] for wvl in self.target_order] + # Extract data + data = data_dict[self.key] + # Subselect bands + data = data[indexes] + # Update dictionary + data_dict[self.key] = data + data_dict["wavelengths"] = np.array(self.target_order) + return data_dict -# TODO: Allow selecting by band name, rather than center wvl? class BandSelectionEditor(Editor): """ Selects a subset of available bands from data dictionary @@ -66,11 +87,10 @@ def call(self, data_dict, **kwargs): data_dict["nan_mask"] = mask return data_dict -# NOTE: Already exists in ITI repo for numpy arrays -# NOTE: Can also be used to replace NaN values for coordinates to remove off limb data class NanDictEditor(Editor): """ - Removes NaN values from data dictionary + Removes NaN values from data dictionary. + Can also be used to replace NaN values of coordinates to remove off limb data. """ def __init__(self, key="data", fill_value=0): self.key = key diff --git a/rs_tools/_src/datamodule/utils.py b/rs_tools/_src/datamodule/utils.py index a4f8985..9aefd7b 100644 --- a/rs_tools/_src/datamodule/utils.py +++ b/rs_tools/_src/datamodule/utils.py @@ -4,6 +4,7 @@ import pandas as pd from loguru import logger +# NOTE: Code already moved to ITI repo def split_train_val(files: List, split_spec: DictConfig) -> Tuple[List, List]: """ Split files into training and validation sets based on dataset specification.