diff --git a/tools/forecast/2021_09_03_1300_icenet_demo.json b/tools/forecast/2021_09_03_1300_icenet_demo.json
new file mode 100644
index 0000000..47691d7
--- /dev/null
+++ b/tools/forecast/2021_09_03_1300_icenet_demo.json
@@ -0,0 +1,207 @@
+{
+ "dataloader_name": "icenet_demo",
+ "dataset_name": "dataset1",
+ "input_data": {
+ "siconca": {
+ "abs": {
+ "include": true,
+ "max_lag": 12
+ },
+ "anom": {
+ "include": false,
+ "max_lag": 3
+ },
+ "linear_trend": {
+ "include": true
+ }
+ },
+ "tas": {
+ "abs": {
+ "include": false,
+ "max_lag": 3
+ },
+ "anom": {
+ "include": true,
+ "max_lag": 3
+ }
+ },
+ "ta500": {
+ "abs": {
+ "include": false,
+ "max_lag": 3
+ },
+ "anom": {
+ "include": true,
+ "max_lag": 3
+ }
+ },
+ "tos": {
+ "abs": {
+ "include": false,
+ "max_lag": 3
+ },
+ "anom": {
+ "include": true,
+ "max_lag": 3
+ }
+ },
+ "rsds": {
+ "abs": {
+ "include": false,
+ "max_lag": 3
+ },
+ "anom": {
+ "include": true,
+ "max_lag": 3
+ }
+ },
+ "rsus": {
+ "abs": {
+ "include": false,
+ "max_lag": 3
+ },
+ "anom": {
+ "include": true,
+ "max_lag": 3
+ }
+ },
+ "psl": {
+ "abs": {
+ "include": false,
+ "max_lag": 3
+ },
+ "anom": {
+ "include": true,
+ "max_lag": 3
+ }
+ },
+ "zg500": {
+ "abs": {
+ "include": false,
+ "max_lag": 3
+ },
+ "anom": {
+ "include": true,
+ "max_lag": 3
+ }
+ },
+ "zg250": {
+ "abs": {
+ "include": false,
+ "max_lag": 3
+ },
+ "anom": {
+ "include": true,
+ "max_lag": 3
+ }
+ },
+ "ua10": {
+ "abs": {
+ "include": true,
+ "max_lag": 3
+ },
+ "anom": {
+ "include": false,
+ "max_lag": 3
+ }
+ },
+ "uas": {
+ "abs": {
+ "include": true,
+ "max_lag": 1
+ },
+ "anom": {
+ "include": false,
+ "max_lag": 1
+ }
+ },
+ "vas": {
+ "abs": {
+ "include": true,
+ "max_lag": 1
+ },
+ "anom": {
+ "include": false,
+ "max_lag": 1
+ }
+ },
+ "land": {
+ "metadata": true,
+ "include": true
+ },
+ "circmonth": {
+ "metadata": true,
+ "include": true
+ }
+ },
+ "batch_size": 2,
+ "shuffle": true,
+ "n_forecast_months": 6,
+ "sample_IDs": {
+ "obs_train_dates": [
+ "1980-1-1",
+ "2011-6-1"
+ ],
+ "obs_val_dates": [
+ "2012-1-1",
+ "2017-6-1"
+ ],
+ "obs_test_dates": [
+ "2018-1-1",
+ "2019-6-1"
+ ]
+ },
+ "cmip6_run_dict": {
+ "EC-Earth3": {
+ "r2i1p1f1": [
+ "1851-1-1",
+ "2099-6-1"
+ ],
+ "r7i1p1f1": [
+ "1851-1-1",
+ "2099-6-1"
+ ],
+ "r10i1p1f1": [
+ "1851-1-1",
+ "2099-6-1"
+ ],
+ "r12i1p1f1": [
+ "1851-1-1",
+ "2099-6-1"
+ ],
+ "r14i1p1f1": [
+ "1851-1-1",
+ "2099-6-1"
+ ]
+ },
+ "MRI-ESM2-0": {
+ "r1i1p1f1": [
+ "1851-1-1",
+ "2099-6-1"
+ ],
+ "r2i1p1f1": [
+ "1851-1-1",
+ "2029-6-1"
+ ],
+ "r3i1p1f1": [
+ "1851-1-1",
+ "2029-6-1"
+ ],
+ "r4i1p1f1": [
+ "1851-1-1",
+ "2029-6-1"
+ ],
+ "r5i1p1f1": [
+ "1851-1-1",
+ "2029-6-1"
+ ]
+ }
+ },
+ "raw_data_shape": [
+ 432,
+ 432
+ ],
+ "default_seed": 42,
+ "loss_weight_months": true,
+ "verbose_level": 0
+}
diff --git a/tools/forecast/config.py b/tools/forecast/config.py
new file mode 100644
index 0000000..9fa6f12
--- /dev/null
+++ b/tools/forecast/config.py
@@ -0,0 +1,89 @@
+"""
+Code taken from https://github.com/tom-andersson/icenet-paper and slightly adjusted
+to fit the galaxy interface.
+"""
+
+import os
+import pandas as pd
+'''
+Defines globals used throughout the codebase.
+'''
+
+###############################################################################
+# Folder structure naming system
+###############################################################################
+
+data_folder = 'data'
+obs_data_folder = os.path.join(data_folder, 'obs')
+cmip6_data_folder = os.path.join(data_folder, 'cmip6')
+mask_data_folder = os.path.join(data_folder, 'masks')
+forecast_data_folder = os.path.join(data_folder, 'forecasts')
+network_dataset_folder = os.path.join(data_folder, 'network_datasets')
+
+dataloader_config_folder = 'dataloader_configs'
+
+networks_folder = 'trained_networks'
+
+results_folder = 'results'
+forecast_results_folder = os.path.join(results_folder, 'forecast_results')
+permute_and_predict_results_folder = os.path.join(results_folder, 'permute_and_predict_results')
+uncertainty_results_folder = os.path.join(results_folder, 'uncertainty_results')
+
+figure_folder = 'figures'
+
+video_folder = 'videos'
+
+active_grid_cell_file_format = 'active_grid_cell_mask_{}.npy'
+land_mask_filename = 'land_mask.npy'
+region_mask_filename = 'region_mask.npy'
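+
+# e.g. the January active-cell mask resolves to data/masks/active_grid_cell_mask_01.npy
+# via os.path.join(mask_data_folder, active_grid_cell_file_format.format('01'))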
+
+###############################################################################
+# Polar hole/missing months
+###############################################################################
+
+# Pre-defined polar hole radii (in number of 25km x 25km grid cells)
+# The polar hole radii were determined from Sections 2.1, 2.2, and 2.3 of
+# http://osisaf.met.no/docs/osisaf_cdop3_ss2_pum_sea-ice-conc-climate-data-record_v2p0.pdf
+polarhole1_radius = 28
+polarhole2_radius = 11
+polarhole3_radius = 3
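+
+# A sketch (not used by the pipeline) of how a circular polar-hole mask of
+# radius r grid cells could be built on the 432x432 EASE grid, assuming the
+# hole is centred on the grid centre:
+#
+#     import numpy as np
+#     yy, xx = np.mgrid[:432, :432]
+#     hole = (yy - 215.5) ** 2 + (xx - 215.5) ** 2 <= polarhole1_radius ** 2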
+
+# Whether or not to mask out the 3rd polar hole mask from
+# Nov 2005 to Dec 2015 with a radius of only 3 grid cells. Including it creates
+# some complications when analysing performance on a validation set that
+# overlaps with the 3rd polar hole period.
+use_polarhole3 = False
+
+polarhole1_fname = 'polarhole1_mask.npy'
+polarhole2_fname = 'polarhole2_mask.npy'
+polarhole3_fname = 'polarhole3_mask.npy'
+
+# Final month that each of the polar holes apply
+# NOTE: 1st of the month chosen arbitrarily throughout, as we always work with
+# monthly averages
+polarhole1_final_date = pd.Timestamp('1987-06-01') # 1987 June
+polarhole2_final_date = pd.Timestamp('2005-10-01') # 2005 Oct
+polarhole3_final_date = pd.Timestamp('2015-12-01') # 2015 Dec
+
+missing_dates = [pd.Timestamp('1986-4-1'), pd.Timestamp('1986-5-1'),
+ pd.Timestamp('1986-6-1'), pd.Timestamp('1987-12-1')]
+
+###############################################################################
+# Weights and biases config (https://docs.wandb.ai/guides/track/advanced/environment-variables)
+###############################################################################
+
+# Get API key from https://wandb.ai/authorize
+WANDB_API_KEY = 'YOUR-KEY-HERE'
+# Absolute path to store wandb generated files (folder must exist)
+# Note: user must have write access
+WANDB_DIR = '/path/to/wandb/dir'
+# Absolute path to the wandb config dir
+WANDB_CONFIG_DIR = '/path/to/wandb/config/dir'
+WANDB_CACHE_DIR = '/path/to/wandb/cache/dir'
+
+###############################################################################
+# ECMWF details
+###############################################################################
+
+ECMWF_API_KEY = 'YOUR-KEY-HERE'
+ECMWF_API_EMAIL = 'YOUR-EMAIL-HERE'
diff --git a/tools/forecast/forecast.py b/tools/forecast/forecast.py
new file mode 100644
index 0000000..16036a1
--- /dev/null
+++ b/tools/forecast/forecast.py
@@ -0,0 +1,145 @@
+"""
+Code taken from https://github.com/tom-andersson/icenet-paper and slightly adjusted
+to fit the galaxy interface.
+"""
+import os
+import sys
+sys.path.insert(0, os.path.join(os.getcwd(), 'icenet'))  # make local icenet modules importable
+import argparse
+import re
+import numpy as np
+import pandas as pd
+import xarray as xr
+from tqdm import tqdm
+from tensorflow.keras.models import load_model
+import config
+from utils import IceNetDataLoader
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--config", type=str, help="config file")
+parser.add_argument("--models", type=str, help="network models")
+parser.add_argument("--siconca", type=str, help="siconca netcdf file")
+parser.add_argument("--forecast_start", type=str, help="forecast start date")
+parser.add_argument("--forecast_end", type=str, help="forecast end date")
+args = parser.parse_args()
+
+# Set up the data loader from the supplied config file
+dataloader_ID = '2021_09_03_1300_icenet_demo'
+dataloader_config_fpath = args.config
+
+dataloader = IceNetDataLoader(dataloader_config_fpath)
+print('\n\nDone.\n')
+
+# Load the ensemble networks
+network_regex = re.compile(r'^network_tempscaled_([0-9]*)\.h5$')
+
+network_fpaths = args.models.split(",")
+
+# Ensemble member seeds, extracted from the canonical temperature-scaled
+# network filenames (the Galaxy-supplied model files themselves may be renamed)
+ensemble_seeds = [network_regex.match(f)[1] for f in
+                  ['network_tempscaled_36.h5', 'network_tempscaled_42.h5',
+                   'network_tempscaled_53.h5']]
+networks = []
+for network_fpath in network_fpaths:
+ print('Loading model from {}... '.format(network_fpath), end='', flush=True)
+ networks.append(load_model(network_fpath, compile=False))
+ print('Done.')
+
+model = 'IceNet'
+
+forecast_start = pd.Timestamp(args.forecast_start)
+forecast_end = pd.Timestamp(args.forecast_end)
+
+n_forecast_months = dataloader.config['n_forecast_months']
+
+
+forecast_folder = os.path.join(config.forecast_data_folder, 'icenet', dataloader_ID, model)
+
+if not os.path.exists(forecast_folder):
+ os.makedirs(forecast_folder)
+
+# load ground truth
+print('Loading ground truth SIC... ', end='', flush=True)
+true_sic_fpath = args.siconca
+true_sic_da = xr.open_dataarray(true_sic_fpath)
+print('Done.')
+
+
+# define list of lead times
+leadtimes = np.arange(1, n_forecast_months + 1)
+
+# add ensemble to the list of models
+ensemble_seeds_and_mean = ensemble_seeds.copy()
+ensemble_seeds_and_mean.append('ensemble')
+
+all_target_dates = pd.date_range(
+ start=forecast_start,
+ end=forecast_end,
+ freq='MS'
+)
+
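+# Forecasts must be initialised early enough that even the first target month
+# is predicted at every lead time from 1 to n_forecast_months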
+all_start_dates = pd.date_range(
+ start=forecast_start - pd.DateOffset(months=n_forecast_months - 1),
+ end=forecast_end,
+ freq='MS'
+)
+
+shape = (len(all_target_dates),
+ *dataloader.config['raw_data_shape'],
+ n_forecast_months)
+
+coords = {
+ 'time': all_target_dates, # To be sliced to target dates
+ 'yc': true_sic_da.coords['yc'],
+ 'xc': true_sic_da.coords['xc'],
+ 'lon': true_sic_da.isel(time=0).coords['lon'],
+ 'lat': true_sic_da.isel(time=0).coords['lat'],
+ 'leadtime': leadtimes,
+ 'seed': ensemble_seeds_and_mean,
+ 'ice_class': ['no_ice', 'marginal_ice', 'full_ice']
+}
+
+# Probabilistic SIC class forecasts
+dims = ('seed', 'time', 'yc', 'xc', 'leadtime', 'ice_class')
+shape = (len(ensemble_seeds_and_mean), *shape, 3)
+# Dims: (seed, time, yc, xc, leadtime, ice_class); the 'seed' axis holds each
+# ensemble member plus the ensemble mean
+model_forecast = xr.DataArray(
+ data=np.zeros(shape, dtype=np.float32),
+ coords=coords,
+ dims=dims
+)
+
+for start_date in tqdm(all_start_dates):
+
+ # Target forecast dates for the forecast beginning at this `start_date`
+ target_dates = pd.date_range(
+ start=start_date,
+ end=start_date + pd.DateOffset(months=n_forecast_months - 1),
+ freq='MS'
+ )
+
+ X, y, sample_weights = dataloader.data_generation([start_date])
+ mask = sample_weights > 0
+ pred = np.array([network.predict(X)[0] for network in networks])
+ pred *= mask # mask outside active grid cell region to zero
+ # concat ensemble mean to the set of network predictions
+ ensemble_mean_pred = pred.mean(axis=0, keepdims=True)
+ pred = np.concatenate([pred, ensemble_mean_pred], axis=0)
+
+ for i, (target_date, leadtime) in enumerate(zip(target_dates, leadtimes)):
+ if target_date in all_target_dates:
+ model_forecast.\
+ loc[:, target_date, :, :, leadtime] = pred[..., i]
+
+print('Saving forecast NetCDF for {}... '.format(model), end='', flush=True)
+
+forecast_fpath = os.path.join(forecast_folder, f'{model.lower()}_forecasts.nc')
+model_forecast.to_netcdf(forecast_fpath)  # export the forecasts as NetCDF
+
+print('Done.')
diff --git a/tools/forecast/forecast.xml b/tools/forecast/forecast.xml
new file mode 100644
index 0000000..1746b7a
--- /dev/null
+++ b/tools/forecast/forecast.xml
@@ -0,0 +1,419 @@
+<description>for forecasting sea ice concentration with IceNet</description>
+<requirements>
+    <requirement type="package">python</requirement>
+    <requirement type="package">xarray</requirement>
+    <requirement type="package">numpy</requirement>
+    <requirement type="package">pandas</requirement>
+    <requirement type="package">scipy</requirement>
+    <requirement type="package">iris</requirement>
+    <requirement type="package">netcdf4</requirement>
+    <requirement type="package">imageio</requirement>
+    <requirement type="package">matplotlib</requirement>
+    <requirement type="package">tqdm</requirement>
+    <requirement type="package">cdsapi</requirement>
+    <requirement type="package">tensorflow</requirement>
+</requirements>
+
+<citations>
+    <citation type="doi">10.5281/zenodo.5176573</citation>
+</citations>
diff --git a/tools/forecast/masks/.listing b/tools/forecast/masks/.listing
new file mode 100644
index 0000000..738cb2b
--- /dev/null
+++ b/tools/forecast/masks/.listing
@@ -0,0 +1,32 @@
+drwxr-xr-x 2 ftp ftp 0 May 09 2017 .
+drwxr-xr-x 2 ftp ftp 0 May 09 2017 ..
+-rwxr-xr-x 1 ftp ftp 9856120 May 09 2017 ice_conc_nh_ease2-250_cdr-v2p0_197901021200.nc
+-rwxr-xr-x 1 ftp ftp 9856120 May 09 2017 ice_conc_nh_ease2-250_cdr-v2p0_197901041200.nc
+-rwxr-xr-x 1 ftp ftp 9856120 May 09 2017 ice_conc_nh_ease2-250_cdr-v2p0_197901061200.nc
+-rwxr-xr-x 1 ftp ftp 9856120 May 09 2017 ice_conc_nh_ease2-250_cdr-v2p0_197901081200.nc
+-rwxr-xr-x 1 ftp ftp 9856120 May 09 2017 ice_conc_nh_ease2-250_cdr-v2p0_197901101200.nc
+-rwxr-xr-x 1 ftp ftp 9856120 May 09 2017 ice_conc_nh_ease2-250_cdr-v2p0_197901121200.nc
+-rwxr-xr-x 1 ftp ftp 9856120 May 09 2017 ice_conc_nh_ease2-250_cdr-v2p0_197901141200.nc
+-rwxr-xr-x 1 ftp ftp 9856120 May 09 2017 ice_conc_nh_ease2-250_cdr-v2p0_197901161200.nc
+-rwxr-xr-x 1 ftp ftp 9856120 May 09 2017 ice_conc_nh_ease2-250_cdr-v2p0_197901181200.nc
+-rwxr-xr-x 1 ftp ftp 9856120 May 09 2017 ice_conc_nh_ease2-250_cdr-v2p0_197901201200.nc
+-rwxr-xr-x 1 ftp ftp 9856120 May 09 2017 ice_conc_nh_ease2-250_cdr-v2p0_197901221200.nc
+-rwxr-xr-x 1 ftp ftp 9856120 May 09 2017 ice_conc_nh_ease2-250_cdr-v2p0_197901241200.nc
+-rwxr-xr-x 1 ftp ftp 9856120 May 09 2017 ice_conc_nh_ease2-250_cdr-v2p0_197901261200.nc
+-rwxr-xr-x 1 ftp ftp 9856120 May 09 2017 ice_conc_nh_ease2-250_cdr-v2p0_197901281200.nc
+-rwxr-xr-x 1 ftp ftp 9856120 May 09 2017 ice_conc_nh_ease2-250_cdr-v2p0_197901301200.nc
+-rwxr-xr-x 1 ftp ftp 9856141 Jun 08 2022 ice_conc_sh_ease2-250_cdr-v2p0_197901021200.nc
+-rwxr-xr-x 1 ftp ftp 9856141 Jun 08 2022 ice_conc_sh_ease2-250_cdr-v2p0_197901041200.nc
+-rwxr-xr-x 1 ftp ftp 9856141 Jun 08 2022 ice_conc_sh_ease2-250_cdr-v2p0_197901061200.nc
+-rwxr-xr-x 1 ftp ftp 9856141 Jun 08 2022 ice_conc_sh_ease2-250_cdr-v2p0_197901081200.nc
+-rwxr-xr-x 1 ftp ftp 9856141 Jun 08 2022 ice_conc_sh_ease2-250_cdr-v2p0_197901101200.nc
+-rwxr-xr-x 1 ftp ftp 9856141 Jun 08 2022 ice_conc_sh_ease2-250_cdr-v2p0_197901121200.nc
+-rwxr-xr-x 1 ftp ftp 9856141 Jun 08 2022 ice_conc_sh_ease2-250_cdr-v2p0_197901141200.nc
+-rwxr-xr-x 1 ftp ftp 9856141 Jun 08 2022 ice_conc_sh_ease2-250_cdr-v2p0_197901161200.nc
+-rwxr-xr-x 1 ftp ftp 9856141 Jun 08 2022 ice_conc_sh_ease2-250_cdr-v2p0_197901181200.nc
+-rwxr-xr-x 1 ftp ftp 9856141 Jun 08 2022 ice_conc_sh_ease2-250_cdr-v2p0_197901201200.nc
+-rwxr-xr-x 1 ftp ftp 9856141 Jun 08 2022 ice_conc_sh_ease2-250_cdr-v2p0_197901221200.nc
+-rwxr-xr-x 1 ftp ftp 9856141 Jun 08 2022 ice_conc_sh_ease2-250_cdr-v2p0_197901241200.nc
+-rwxr-xr-x 1 ftp ftp 9856141 Jun 08 2022 ice_conc_sh_ease2-250_cdr-v2p0_197901261200.nc
+-rwxr-xr-x 1 ftp ftp 9856141 Jun 08 2022 ice_conc_sh_ease2-250_cdr-v2p0_197901281200.nc
+-rwxr-xr-x 1 ftp ftp 9856141 Jun 08 2022 ice_conc_sh_ease2-250_cdr-v2p0_197901301200.nc
diff --git a/tools/forecast/masks/active_grid_cell_mask_01.npy b/tools/forecast/masks/active_grid_cell_mask_01.npy
new file mode 100644
index 0000000..75d5412
Binary files /dev/null and b/tools/forecast/masks/active_grid_cell_mask_01.npy differ
diff --git a/tools/forecast/masks/active_grid_cell_mask_02.npy b/tools/forecast/masks/active_grid_cell_mask_02.npy
new file mode 100644
index 0000000..a5b63e8
Binary files /dev/null and b/tools/forecast/masks/active_grid_cell_mask_02.npy differ
diff --git a/tools/forecast/masks/active_grid_cell_mask_03.npy b/tools/forecast/masks/active_grid_cell_mask_03.npy
new file mode 100644
index 0000000..7b1f255
Binary files /dev/null and b/tools/forecast/masks/active_grid_cell_mask_03.npy differ
diff --git a/tools/forecast/masks/active_grid_cell_mask_04.npy b/tools/forecast/masks/active_grid_cell_mask_04.npy
new file mode 100644
index 0000000..c233eb6
Binary files /dev/null and b/tools/forecast/masks/active_grid_cell_mask_04.npy differ
diff --git a/tools/forecast/masks/active_grid_cell_mask_05.npy b/tools/forecast/masks/active_grid_cell_mask_05.npy
new file mode 100644
index 0000000..afb4ceb
Binary files /dev/null and b/tools/forecast/masks/active_grid_cell_mask_05.npy differ
diff --git a/tools/forecast/masks/active_grid_cell_mask_06.npy b/tools/forecast/masks/active_grid_cell_mask_06.npy
new file mode 100644
index 0000000..78c5641
Binary files /dev/null and b/tools/forecast/masks/active_grid_cell_mask_06.npy differ
diff --git a/tools/forecast/masks/active_grid_cell_mask_07.npy b/tools/forecast/masks/active_grid_cell_mask_07.npy
new file mode 100644
index 0000000..7f322d3
Binary files /dev/null and b/tools/forecast/masks/active_grid_cell_mask_07.npy differ
diff --git a/tools/forecast/masks/active_grid_cell_mask_08.npy b/tools/forecast/masks/active_grid_cell_mask_08.npy
new file mode 100644
index 0000000..137838e
Binary files /dev/null and b/tools/forecast/masks/active_grid_cell_mask_08.npy differ
diff --git a/tools/forecast/masks/active_grid_cell_mask_09.npy b/tools/forecast/masks/active_grid_cell_mask_09.npy
new file mode 100644
index 0000000..621c997
Binary files /dev/null and b/tools/forecast/masks/active_grid_cell_mask_09.npy differ
diff --git a/tools/forecast/masks/active_grid_cell_mask_10.npy b/tools/forecast/masks/active_grid_cell_mask_10.npy
new file mode 100644
index 0000000..c55090f
Binary files /dev/null and b/tools/forecast/masks/active_grid_cell_mask_10.npy differ
diff --git a/tools/forecast/masks/active_grid_cell_mask_11.npy b/tools/forecast/masks/active_grid_cell_mask_11.npy
new file mode 100644
index 0000000..ecd33e5
Binary files /dev/null and b/tools/forecast/masks/active_grid_cell_mask_11.npy differ
diff --git a/tools/forecast/masks/active_grid_cell_mask_12.npy b/tools/forecast/masks/active_grid_cell_mask_12.npy
new file mode 100644
index 0000000..193ff1b
Binary files /dev/null and b/tools/forecast/masks/active_grid_cell_mask_12.npy differ
diff --git a/tools/forecast/masks/land_mask.npy b/tools/forecast/masks/land_mask.npy
new file mode 100644
index 0000000..a79ad9b
Binary files /dev/null and b/tools/forecast/masks/land_mask.npy differ
diff --git a/tools/forecast/masks/polarhole1_mask.npy b/tools/forecast/masks/polarhole1_mask.npy
new file mode 100644
index 0000000..e0f9251
Binary files /dev/null and b/tools/forecast/masks/polarhole1_mask.npy differ
diff --git a/tools/forecast/masks/polarhole2_mask.npy b/tools/forecast/masks/polarhole2_mask.npy
new file mode 100644
index 0000000..f5db999
Binary files /dev/null and b/tools/forecast/masks/polarhole2_mask.npy differ
diff --git a/tools/forecast/masks/polarhole3_mask.npy b/tools/forecast/masks/polarhole3_mask.npy
new file mode 100644
index 0000000..3322af1
Binary files /dev/null and b/tools/forecast/masks/polarhole3_mask.npy differ
diff --git a/tools/forecast/masks/region_mask.npy b/tools/forecast/masks/region_mask.npy
new file mode 100644
index 0000000..e3d1d46
Binary files /dev/null and b/tools/forecast/masks/region_mask.npy differ
diff --git a/tools/forecast/meta/cos_month_01.npy b/tools/forecast/meta/cos_month_01.npy
new file mode 100644
index 0000000..17db7aa
Binary files /dev/null and b/tools/forecast/meta/cos_month_01.npy differ
diff --git a/tools/forecast/meta/cos_month_02.npy b/tools/forecast/meta/cos_month_02.npy
new file mode 100644
index 0000000..910d0ac
Binary files /dev/null and b/tools/forecast/meta/cos_month_02.npy differ
diff --git a/tools/forecast/meta/cos_month_03.npy b/tools/forecast/meta/cos_month_03.npy
new file mode 100644
index 0000000..da3cfc2
Binary files /dev/null and b/tools/forecast/meta/cos_month_03.npy differ
diff --git a/tools/forecast/meta/cos_month_04.npy b/tools/forecast/meta/cos_month_04.npy
new file mode 100644
index 0000000..3cf8d50
Binary files /dev/null and b/tools/forecast/meta/cos_month_04.npy differ
diff --git a/tools/forecast/meta/cos_month_05.npy b/tools/forecast/meta/cos_month_05.npy
new file mode 100644
index 0000000..4bd371c
Binary files /dev/null and b/tools/forecast/meta/cos_month_05.npy differ
diff --git a/tools/forecast/meta/cos_month_06.npy b/tools/forecast/meta/cos_month_06.npy
new file mode 100644
index 0000000..3e502be
Binary files /dev/null and b/tools/forecast/meta/cos_month_06.npy differ
diff --git a/tools/forecast/meta/cos_month_07.npy b/tools/forecast/meta/cos_month_07.npy
new file mode 100644
index 0000000..4bd371c
Binary files /dev/null and b/tools/forecast/meta/cos_month_07.npy differ
diff --git a/tools/forecast/meta/cos_month_08.npy b/tools/forecast/meta/cos_month_08.npy
new file mode 100644
index 0000000..babc209
Binary files /dev/null and b/tools/forecast/meta/cos_month_08.npy differ
diff --git a/tools/forecast/meta/cos_month_09.npy b/tools/forecast/meta/cos_month_09.npy
new file mode 100644
index 0000000..6eb0ef0
Binary files /dev/null and b/tools/forecast/meta/cos_month_09.npy differ
diff --git a/tools/forecast/meta/cos_month_10.npy b/tools/forecast/meta/cos_month_10.npy
new file mode 100644
index 0000000..7e1463f
Binary files /dev/null and b/tools/forecast/meta/cos_month_10.npy differ
diff --git a/tools/forecast/meta/cos_month_11.npy b/tools/forecast/meta/cos_month_11.npy
new file mode 100644
index 0000000..ef78f85
Binary files /dev/null and b/tools/forecast/meta/cos_month_11.npy differ
diff --git a/tools/forecast/meta/cos_month_12.npy b/tools/forecast/meta/cos_month_12.npy
new file mode 100644
index 0000000..f9688e6
Binary files /dev/null and b/tools/forecast/meta/cos_month_12.npy differ
diff --git a/tools/forecast/meta/land.npy b/tools/forecast/meta/land.npy
new file mode 100644
index 0000000..580dd88
Binary files /dev/null and b/tools/forecast/meta/land.npy differ
diff --git a/tools/forecast/meta/sin_month_01.npy b/tools/forecast/meta/sin_month_01.npy
new file mode 100644
index 0000000..ca736b5
Binary files /dev/null and b/tools/forecast/meta/sin_month_01.npy differ
diff --git a/tools/forecast/meta/sin_month_02.npy b/tools/forecast/meta/sin_month_02.npy
new file mode 100644
index 0000000..5472b52
Binary files /dev/null and b/tools/forecast/meta/sin_month_02.npy differ
diff --git a/tools/forecast/meta/sin_month_03.npy b/tools/forecast/meta/sin_month_03.npy
new file mode 100644
index 0000000..f9688e6
Binary files /dev/null and b/tools/forecast/meta/sin_month_03.npy differ
diff --git a/tools/forecast/meta/sin_month_04.npy b/tools/forecast/meta/sin_month_04.npy
new file mode 100644
index 0000000..17db7aa
Binary files /dev/null and b/tools/forecast/meta/sin_month_04.npy differ
diff --git a/tools/forecast/meta/sin_month_05.npy b/tools/forecast/meta/sin_month_05.npy
new file mode 100644
index 0000000..5197c70
Binary files /dev/null and b/tools/forecast/meta/sin_month_05.npy differ
diff --git a/tools/forecast/meta/sin_month_06.npy b/tools/forecast/meta/sin_month_06.npy
new file mode 100644
index 0000000..1315ee1
Binary files /dev/null and b/tools/forecast/meta/sin_month_06.npy differ
diff --git a/tools/forecast/meta/sin_month_07.npy b/tools/forecast/meta/sin_month_07.npy
new file mode 100644
index 0000000..4ac3b1c
Binary files /dev/null and b/tools/forecast/meta/sin_month_07.npy differ
diff --git a/tools/forecast/meta/sin_month_08.npy b/tools/forecast/meta/sin_month_08.npy
new file mode 100644
index 0000000..37bf72b
Binary files /dev/null and b/tools/forecast/meta/sin_month_08.npy differ
diff --git a/tools/forecast/meta/sin_month_09.npy b/tools/forecast/meta/sin_month_09.npy
new file mode 100644
index 0000000..3e502be
Binary files /dev/null and b/tools/forecast/meta/sin_month_09.npy differ
diff --git a/tools/forecast/meta/sin_month_10.npy b/tools/forecast/meta/sin_month_10.npy
new file mode 100644
index 0000000..37bf72b
Binary files /dev/null and b/tools/forecast/meta/sin_month_10.npy differ
diff --git a/tools/forecast/meta/sin_month_11.npy b/tools/forecast/meta/sin_month_11.npy
new file mode 100644
index 0000000..e7cc0f5
Binary files /dev/null and b/tools/forecast/meta/sin_month_11.npy differ
diff --git a/tools/forecast/meta/sin_month_12.npy b/tools/forecast/meta/sin_month_12.npy
new file mode 100644
index 0000000..45bbaf0
Binary files /dev/null and b/tools/forecast/meta/sin_month_12.npy differ
diff --git a/tools/forecast/models.py b/tools/forecast/models.py
new file mode 100644
index 0000000..a3c59c8
--- /dev/null
+++ b/tools/forecast/models.py
@@ -0,0 +1,190 @@
+"""
+Code taken from https://github.com/tom-andersson/icenet-paper and slightly adjusted
+to fit the galaxy interface.
+"""
+import sys
+import os
+import config
+import numpy as np
+import pandas as pd
+import xarray as xr
+import tensorflow as tf
+from tensorflow.keras.models import Model
+from tensorflow.keras.layers import Conv2D, BatchNormalization, UpSampling2D, \
+ concatenate, MaxPooling2D, Input
+from tensorflow.keras.optimizers import Adam
+sys.path.insert(0, os.path.join(os.getcwd(), 'icenet')) # if using jupyter kernel
+'''
+Defines the Python-based sea ice forecasting models, such as the IceNet architecture
+and the linear trend extrapolation model.
+'''
+
+# Custom layers:
+# --------------------------------------------------------------------
+
+
+@tf.keras.utils.register_keras_serializable()
+class TemperatureScale(tf.keras.layers.Layer):
+ '''
+ Implements the temperature scaling layer for probability calibration,
+ as introduced in Guo 2017 (http://proceedings.mlr.press/v70/guo17a.html).
+ '''
+ def __init__(self, **kwargs):
+ super(TemperatureScale, self).__init__()
+ self.temp = tf.Variable(initial_value=1.0, trainable=False,
+ dtype=tf.float32, name='temp')
+
+ def call(self, inputs):
+ ''' Divide the input logits by the T value. '''
+ return tf.divide(inputs, self.temp)
+
+ def get_config(self):
+ ''' For saving and loading networks with this custom layer. '''
+ return {'temp': self.temp.numpy()}
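+
+    # Effect of temperature scaling (a sketch with illustrative values): with a
+    # fitted temperature T > 1, the scaled softmax is flatter than the unscaled
+    # one, correcting over-confident probabilities.
+    #
+    #     layer = TemperatureScale()
+    #     layer.temp.assign(1.5)
+    #     logits = tf.constant([[2.0, 0.5, -1.0]])
+    #     tf.nn.softmax(layer(logits))  # flatter than tf.nn.softmax(logits)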
+
+
+# Network architectures:
+# --------------------------------------------------------------------
+
+def unet_batchnorm(input_shape, loss, weighted_metrics, learning_rate=1e-4, filter_size=3,
+ n_filters_factor=1, n_forecast_months=1, use_temp_scaling=False,
+ n_output_classes=3,
+ **kwargs):
+ inputs = Input(shape=input_shape)
+
+    conv1 = Conv2D(int(64 * n_filters_factor), filter_size, activation='relu', padding='same', kernel_initializer='he_normal')(inputs)
+    conv1 = Conv2D(int(64 * n_filters_factor), filter_size, activation='relu', padding='same', kernel_initializer='he_normal')(conv1)
+    bn1 = BatchNormalization(axis=-1)(conv1)
+    pool1 = MaxPooling2D(pool_size=(2, 2))(bn1)
+
+    conv2 = Conv2D(int(128 * n_filters_factor), filter_size, activation='relu', padding='same', kernel_initializer='he_normal')(pool1)
+    conv2 = Conv2D(int(128 * n_filters_factor), filter_size, activation='relu', padding='same', kernel_initializer='he_normal')(conv2)
+    bn2 = BatchNormalization(axis=-1)(conv2)
+    pool2 = MaxPooling2D(pool_size=(2, 2))(bn2)
+
+    conv3 = Conv2D(int(256 * n_filters_factor), filter_size, activation='relu', padding='same', kernel_initializer='he_normal')(pool2)
+    conv3 = Conv2D(int(256 * n_filters_factor), filter_size, activation='relu', padding='same', kernel_initializer='he_normal')(conv3)
+    bn3 = BatchNormalization(axis=-1)(conv3)
+    pool3 = MaxPooling2D(pool_size=(2, 2))(bn3)
+
+    conv4 = Conv2D(int(256 * n_filters_factor), filter_size, activation='relu', padding='same', kernel_initializer='he_normal')(pool3)
+    conv4 = Conv2D(int(256 * n_filters_factor), filter_size, activation='relu', padding='same', kernel_initializer='he_normal')(conv4)
+    bn4 = BatchNormalization(axis=-1)(conv4)
+    pool4 = MaxPooling2D(pool_size=(2, 2))(bn4)
+
+    conv5 = Conv2D(int(512 * n_filters_factor), filter_size, activation='relu', padding='same', kernel_initializer='he_normal')(pool4)
+    conv5 = Conv2D(int(512 * n_filters_factor), filter_size, activation='relu', padding='same', kernel_initializer='he_normal')(conv5)
+    bn5 = BatchNormalization(axis=-1)(conv5)
+
+    up6 = Conv2D(int(256 * n_filters_factor), 2, activation='relu', padding='same', kernel_initializer='he_normal')(UpSampling2D(size=(2, 2), interpolation='nearest')(bn5))
+    merge6 = concatenate([bn4, up6], axis=3)
+    conv6 = Conv2D(int(256 * n_filters_factor), filter_size, activation='relu', padding='same', kernel_initializer='he_normal')(merge6)
+    conv6 = Conv2D(int(256 * n_filters_factor), filter_size, activation='relu', padding='same', kernel_initializer='he_normal')(conv6)
+    bn6 = BatchNormalization(axis=-1)(conv6)
+
+    up7 = Conv2D(int(256 * n_filters_factor), 2, activation='relu', padding='same', kernel_initializer='he_normal')(UpSampling2D(size=(2, 2), interpolation='nearest')(bn6))
+    merge7 = concatenate([bn3, up7], axis=3)
+    conv7 = Conv2D(int(256 * n_filters_factor), filter_size, activation='relu', padding='same', kernel_initializer='he_normal')(merge7)
+    conv7 = Conv2D(int(256 * n_filters_factor), filter_size, activation='relu', padding='same', kernel_initializer='he_normal')(conv7)
+    bn7 = BatchNormalization(axis=-1)(conv7)
+
+    up8 = Conv2D(int(128 * n_filters_factor), 2, activation='relu', padding='same', kernel_initializer='he_normal')(UpSampling2D(size=(2, 2), interpolation='nearest')(bn7))
+    merge8 = concatenate([bn2, up8], axis=3)
+    conv8 = Conv2D(int(128 * n_filters_factor), filter_size, activation='relu', padding='same', kernel_initializer='he_normal')(merge8)
+    conv8 = Conv2D(int(128 * n_filters_factor), filter_size, activation='relu', padding='same', kernel_initializer='he_normal')(conv8)
+    bn8 = BatchNormalization(axis=-1)(conv8)
+
+    up9 = Conv2D(int(64 * n_filters_factor), 2, activation='relu', padding='same', kernel_initializer='he_normal')(UpSampling2D(size=(2, 2), interpolation='nearest')(bn8))
+    merge9 = concatenate([conv1, up9], axis=3)
+    conv9 = Conv2D(int(64 * n_filters_factor), filter_size, activation='relu', padding='same', kernel_initializer='he_normal')(merge9)
+    conv9 = Conv2D(int(64 * n_filters_factor), filter_size, activation='relu', padding='same', kernel_initializer='he_normal')(conv9)
+    conv9 = Conv2D(int(64 * n_filters_factor), filter_size, activation='relu', padding='same', kernel_initializer='he_normal')(conv9)
+
+ final_layer_logits = [(Conv2D(n_output_classes, 1, activation='linear')(conv9)) for i in range(n_forecast_months)]
+ final_layer_logits = tf.stack(final_layer_logits, axis=-1)
+
+ if use_temp_scaling:
+ # Temperature scaling of the logits
+ final_layer_logits_scaled = TemperatureScale()(final_layer_logits)
+ final_layer = tf.nn.softmax(final_layer_logits_scaled, axis=-2)
+ else:
+ final_layer = tf.nn.softmax(final_layer_logits, axis=-2)
+
+ model = Model(inputs, final_layer)
+
+    model.compile(optimizer=Adam(learning_rate=learning_rate), loss=loss, weighted_metrics=weighted_metrics)
+
+ return model
+
+
+# Benchmark models:
+# --------------------------------------------------------------------
+
+
+def linear_trend_forecast(forecast_month, n_linear_years='all', da=None, dataset='obs'):
+ '''
+ Returns a simple sea ice forecast based on a gridcell-wise linear extrapolation.
+
+ Parameters:
+ forecast_month (datetime.datetime): The month to forecast
+
+ n_linear_years (int or str): Number of past years to use for linear trend
+ extrapolation.
+
+ da (xr.DataArray): xarray data array to use instead of observational
+ data (used for setting up CMIP6 pre-training linear trend inputs in IceUNetDataPreProcessor).
+
+ dataset (str): 'obs' or 'cmip6'. If 'obs', missing observational SIC months
+ will be skipped
+
+ Returns:
+ output_map (np.ndarray): The output SIC map predicted
+ by fitting a least squares linear trend to the past n_linear_years
+ for the month being predicted.
+
+        sie (float): The predicted sea ice extent (SIE) in km^2.
+ '''
+
+ if da is None:
+ with xr.open_dataset('data/obs/siconca_EASE.nc') as ds:
+ da = next(iter(ds.data_vars.values()))
+
+ valid_dates = [pd.Timestamp(date) for date in da.time.values]
+
+    if n_linear_years == 'all':
+        # Use all years of data available before the forecast month
+        n_linear_years = forecast_month.year - valid_dates[0].year
+
+    input_dates = [forecast_month - pd.DateOffset(years=1 + lag) for lag in range(n_linear_years)]
+
+ # Do not use missing months in the linear trend projection
+ input_dates = [date for date in input_dates if date not in config.missing_dates]
+
+ # Chop off input date from before data start
+ input_dates = [date for date in input_dates if date in valid_dates]
+
+ input_dates = sorted(input_dates)
+
+ # The actual number of past years used
+ actual_n_linear_years = len(input_dates)
+
+ da = da.sel(time=input_dates)
+
+ input_maps = np.array(da.data)
+
+ x = np.arange(actual_n_linear_years)
+ y = input_maps.reshape(actual_n_linear_years, -1)
+
+ # Fit the least squares linear coefficients
+ r = np.linalg.lstsq(np.c_[x, np.ones_like(x)], y, rcond=None)[0]
+
+ # y = mx + c
+ output_map = np.matmul(np.array([actual_n_linear_years, 1]), r).reshape(432, 432)
+
+ land_mask_path = os.path.join(config.mask_data_folder, config.land_mask_filename)
+ land_mask = np.load(land_mask_path)
+ output_map[land_mask] = 0.
+
+ output_map[output_map < 0] = 0.
+ output_map[output_map > 1] = 1.
+
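+    # SIE: total area of cells with SIC > 15%, in km^2 (25 km x 25 km grid cells)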
+ sie = np.sum(output_map > 0.15) * 25**2
+
+ return output_map, sie
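+
+
+if __name__ == '__main__':
+    # Minimal smoke tests (a sketch with illustrative values, not the paper's
+    # configuration).
+    # 1) The least-squares trend arithmetic used by linear_trend_forecast:
+    #    fit y = m*x + c to three past values and extrapolate one step ahead.
+    x = np.arange(3)
+    y = np.array([0.2, 0.3, 0.4]).reshape(3, 1)
+    r = np.linalg.lstsq(np.c_[x, np.ones_like(x)], y, rcond=None)[0]
+    assert np.isclose(np.matmul(np.array([3, 1]), r), 0.5)
+
+    # 2) Build a small U-Net and check its probabilistic output shape:
+    #    (batch, y, x, ice_class, leadtime).
+    toy = unet_batchnorm(input_shape=(64, 64, 8), loss='categorical_crossentropy',
+                         weighted_metrics=['accuracy'], n_forecast_months=2)
+    print(toy.output_shape)  # expected: (None, 64, 64, 3, 2)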
diff --git a/tools/forecast/utils.py b/tools/forecast/utils.py
new file mode 100644
index 0000000..6787625
--- /dev/null
+++ b/tools/forecast/utils.py
@@ -0,0 +1,1955 @@
+"""
+Code taken from https://github.com/tom-andersson/icenet-paper and slightly adjusted
+to fit the galaxy interface.
+"""
+import os
+import sys
+import numpy as np
+import tensorflow as tf
+from models import linear_trend_forecast
+import config
+import itertools
+import requests
+import json
+import time
+import re
+import xarray as xr
+import pandas as pd
+from dateutil.relativedelta import relativedelta
+import iris
+import cartopy.crs as ccrs
+import matplotlib.pyplot as plt
+from mpl_toolkits.axes_grid1 import make_axes_locatable
+import imageio
+from tqdm import tqdm
+sys.path.insert(0, os.path.join(os.getcwd(), 'icenet')) # if using jupyter kernel
+
+
+###############################################################################
+# DATA PROCESSING & LOADING
+###############################################################################
+
+
+class IceNetDataPreProcessor(object):
+ """
+ Normalises IceNet input data and saves the normalised monthly averages
+ as .npy files. If preprocessing climate model data for transfer learning,
+ the observational normalisation is repeated for the climate model data in order
+ to preserve the mapping from raw values to normalised values.
+
+ Data is stored in the following form with observations separated from climate
+ model transfer learning data:
+    - data/network_datasets/<dataset_name>/obs/tas/2006_04.npy
+    - data/network_datasets/<dataset_name>/transfer/MRI-ESM2-0/r1i1p1f1/tas/2056_09.npy
+
+    Normalisation parameters computed over the observational training data are
+    stored in a JSON file at data/network_datasets/<dataset_name>/norm_params.json
+ so that they are only computed once. Similarly, monthly climatology fields
+ used for computing anomaly fields are saved next to the raw NetCDF files so that
+ climatologies are only computed once for each variable.
+ """
+
+ def __init__(self, dataloader_config_fpath, preproc_vars,
+ n_linear_years, minmax, verbose_level,
+ preproc_obs_data=True,
+ preproc_transfer_data=False, cmip_transfer_data={}):
+ """
+ Parameters:
+
+ dataloader_config_fpath (str): Path to the data loader configuration
+ settings JSON file, defining IceNet's input-output data configuration.
+ This also defines the dataset name, used as the folder name to
+ store the preprocessed data within data/network_datasets/.
+
+ preproc_vars (dict): Which variables to preprocess. Example:
+
+ preproc_vars = {
+ 'siconca': {'anom': True, 'abs': True},
+ 'tas': {'anom': True, 'abs': False},
+ 'tos': {'anom': True, 'abs': False},
+ 'rsds': {'anom': True, 'abs': False},
+ 'rsus': {'anom': True, 'abs': False},
+ 'psl': {'anom': False, 'abs': True},
+ 'zg500': {'anom': False, 'abs': True},
+ 'zg250': {'anom': False, 'abs': True},
+ 'ua10': {'anom': False, 'abs': True},
+ 'uas': {'anom': False, 'abs': True},
+ 'vas': {'anom': False, 'abs': True},
+ 'sfcWind': {'anom': False, 'abs': True},
+ 'land': {'metadata': True, 'include': True},
+ 'circmonth': {'metadata': True, 'include': True}
+ }
+
+            n_linear_years (int): Number of past years to use in the linear trend
+            projections.
+
+            minmax (bool): Whether to normalise inputs using the training-set min
+            and max (True) or the training-set mean and standard deviation (False).
+
+ verbose_level (int): Controls how much to print. 0: Print nothing.
+ 1: Print key set-up stages. 2: Print debugging info.
+
+ preproc_obs_data (bool): Whether to preprocess observational data
+ (default True).
+
+ preproc_transfer_data (bool): Whether to also preprocess CMIP6 data for each variable
+ (default False).
+
+ cmip_transfer_data (dict): Which CMIP6 models & model runs to
+ preprocess for transfer learning. Example:
+
+ cmip_transfer_data = {
+ 'MRI-ESM2-0': ('r1i1p1f1', 'r2i1p1f1', 'r3i1p1f1',
+ 'r4i1p1f1', 'r5i1p1f1')
+ }
+
+ """
+
+ with open(dataloader_config_fpath, 'r') as readfile:
+ self.config = json.load(readfile)
+
+ self.preproc_vars = preproc_vars
+ self.n_linear_years = n_linear_years
+ self.minmax = minmax
+ self.verbose_level = verbose_level
+ self.preproc_obs_data = preproc_obs_data
+ self.preproc_transfer_data = preproc_transfer_data
+ self.cmip_transfer_data = cmip_transfer_data
+
+ self.load_or_instantiate_norm_params_dict()
+ self.set_obs_train_dates()
+ self.set_up_folder_hierarchy()
+
+ if self.verbose_level >= 1:
+ print("Loading and normalising the raw input maps.\n")
+ tic = time.time()
+
+ self.preproc_and_save_icenet_data()
+
+ if self.verbose_level >= 1:
+ print("\nPreprocessing completed in {:.0f}s.\n".format(time.time() - tic))
+
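+    # Example construction (a sketch; the config path and preproc_vars dict are
+    # illustrative, not the full set used in the paper):
+    #
+    #     dpp = IceNetDataPreProcessor(
+    #         dataloader_config_fpath='dataloader_configs/2021_09_03_1300_icenet_demo.json',
+    #         preproc_vars={'siconca': {'abs': True, 'anom': False, 'linear_trend': True}},
+    #         n_linear_years=35, minmax=False, verbose_level=1)
+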
+ def load_or_instantiate_norm_params_dict(self):
+
+ # Path to JSON file storing normalisation parameters for each variable
+ self.norm_params_fpath = os.path.join(
+ config.network_dataset_folder, self.config['dataset_name'], 'norm_params.json')
+
+ if not os.path.exists(self.norm_params_fpath):
+ self.norm_params = {}
+
+ else:
+ with open(self.norm_params_fpath, 'r') as readfile:
+ self.norm_params = json.load(readfile)
+
+ def set_obs_train_dates(self):
+
+ forecast_start_date_ends = self.config['sample_IDs']['obs_train_dates']
+
+ if forecast_start_date_ends is not None:
+
+ # Convert to Pandas Timestamps
+ forecast_start_date_ends = [
+ pd.Timestamp(date).to_pydatetime() for date in forecast_start_date_ends
+ ]
+
+ self.obs_train_dates = list(pd.date_range(
+ forecast_start_date_ends[0],
+ forecast_start_date_ends[1],
+ freq='MS',
+ closed='right',
+ ))
+
+ def set_up_folder_hierarchy(self):
+
+ """
+ Initialise the folders to store the datasets.
+ """
+
+ if self.verbose_level >= 1:
+ print('Setting up the folder hierarchy for {}... '.format(self.config['dataset_name']),
+ end='', flush=True)
+
+ # Parent folder for this dataset
+ self.dataset_path = os.path.join(config.data_folder, 'network_datasets', self.config['dataset_name'])
+
+ # Dictionary data structure to store folder paths
+ self.paths = {}
+
+ # Set up the folder hierarchy
+ self.paths['obs'] = {}
+
+ for varname, vardict in self.preproc_vars.items():
+
+ if 'metadata' not in vardict.keys():
+ self.paths['obs'][varname] = {}
+
+ for data_format in vardict.keys():
+
+ if vardict[data_format] is True:
+ path = os.path.join(self.dataset_path, 'obs',
+ varname, data_format)
+
+ self.paths['obs'][varname][data_format] = path
+
+ if not os.path.exists(path):
+ os.makedirs(path)
+
+ self.paths['transfer'] = {}
+
+ for model_name, member_ids in self.cmip_transfer_data.items():
+ self.paths['transfer'][model_name] = {}
+ for member_id in member_ids:
+ self.paths['transfer'][model_name][member_id] = {}
+
+ for varname, vardict in self.preproc_vars.items():
+
+ if 'metadata' not in vardict.keys():
+ self.paths['transfer'][model_name][member_id][varname] = {}
+
+ for data_format in vardict.keys():
+
+ if vardict[data_format] is True:
+ path = os.path.join(self.dataset_path, 'transfer',
+ model_name, member_id,
+ varname, data_format)
+
+ self.paths['transfer'][model_name][member_id][varname][data_format] = path
+
+ if not os.path.exists(path):
+ os.makedirs(path)
+
+ for varname, vardict in self.preproc_vars.items():
+ if 'metadata' in vardict.keys():
+
+ if vardict['include'] is True:
+ path = os.path.join(self.dataset_path, 'meta')
+
+ self.paths['meta'] = path
+
+ if not os.path.exists(path):
+ os.makedirs(path)
+
+ if self.verbose_level >= 1:
+ print('Done.')
+
+ @staticmethod
+ def standardise_cmip6_time_coord(da):
+
+ """
+ Convert the cmip6 xarray time dimension to use day=1, hour=0 convention
+ used in the rest of the project.
+ """
+
+ standardised_dates = []
+ for datetime64 in da.time.values:
+ date = pd.Timestamp(datetime64, unit='s')
+ date = date.replace(day=1, hour=0)
+ standardised_dates.append(date)
+ da = da.assign_coords({'time': standardised_dates})
+
+ return da
+
+    @staticmethod
+    def mean_and_std(data, verbose_level=2):
+
+        # Must use float64s to be JSON serialisable
+        mean = np.nanmean(data, dtype=np.float64)
+        std = np.nanstd(data, dtype=np.float64)
+
+        return mean, std
+
+ def normalise_array_using_all_training_months(self, da, minmax=False,
+ mean=None, std=None,
+ min=None, max=None):
+
+ """
+ Using the *training* months only, compute the mean and
+ standard deviation of the input raw satellite DataArray (`da`)
+ and return a normalised version. If minmax=True,
+        instead scale `da` using the min and max of its training-month elements.
+
+ If min, max, mean, or std are given values other than None,
+ those values are used rather than being computed from the training months.
+
+ Returns:
+ new_da (xarray.DataArray): Normalised array.
+
+ mean, std (float): Pre-computed mean and standard deviation for the
+ normalisation.
+
+ min, max (float): Pre-computed min and max for the normalisation.
+ """
+
+ if (min is not None and max is not None) or (mean is not None and std is not None):
+ # Function has been passed precomputed normalisation parameters
+ pass
+ else:
+ # Function will be computing new normalisation parameters
+ training_samples = da.sel(time=self.obs_train_dates).data
+ training_samples = training_samples.ravel()
+
+ if not minmax:
+ if mean is None and std is None:
+ # Compute mean and std
+ mean, std = IceNetDataPreProcessor.mean_and_std(
+ training_samples, self.verbose_level)
+ elif mean is not None and std is None:
+ # Compute std only
+ _, std = IceNetDataPreProcessor.mean_and_std(
+ training_samples, self.verbose_level)
+ elif mean is None and std is not None:
+ # Compute mean only
+ mean, _ = IceNetDataPreProcessor.mean_and_std(
+ training_samples, self.verbose_level)
+
+ new_da = (da - mean) / std
+
+ elif minmax:
+ if min is None:
+ # Compute min
+ min = np.nanmin(training_samples).astype(np.float64)
+ if max is None:
+ # Compute max
+ max = np.nanmax(training_samples).astype(np.float64)
+
+ new_da = (da - min) / (max - min)
+
+ if minmax:
+ return new_da, min, max
+ elif not minmax:
+ return new_da, mean, std
+
+ def save_xarray_in_monthly_averages(self, da, dataset_type, varname, data_format,
+ model_name=None, member_id=None):
+
+ """
+ Saves an xarray DataArray as monthly averaged .npy files using the
+ self.paths data structure.
+
+ Parameters:
+ da (xarray.DataArray): The DataArray to save.
+
+ dataset_type (str): Either 'obs' or 'transfer' (for CMIP6 data) - the type
+ of dataset being saved.
+
+ varname (str): Variable name being saved.
+
+            data_format (str): Either 'abs' or 'anom' - the format of the data
+            being saved.
+
+            model_name, member_id (str): Identify the CMIP6 model and run; used
+            to resolve the save path when dataset_type is 'transfer'.
+        """
+
+ if self.verbose_level >= 2:
+ print('Saving {} {} monthly averages... '.format(data_format, varname), end='', flush=True)
+
+        # Allow for datasets without a time dimension (a single time slice)
+        dates = da.time.values
+        if not hasattr(dates, '__iter__'):
+            dates = [dates]  # Convert single time value to an iterable
+            da = da.expand_dims({'time': dates})
+
+        for date in dates:
+            month_slice = da.sel(time=date).data
+            date = pd.Timestamp(date)
+            year_str = '{:04d}'.format(date.year)
+            month_str = '{:02d}'.format(date.month)
+            fname = '{}_{}.npy'.format(year_str, month_str)
+
+            if dataset_type == 'obs':
+                np.save(os.path.join(self.paths[dataset_type][varname][data_format], fname),
+                        month_slice)
+
+            if dataset_type == 'transfer':
+                np.save(os.path.join(self.paths[dataset_type][model_name][member_id][varname][data_format], fname),
+                        month_slice)
+
+ if self.verbose_level >= 2:
+ print('Done.')
+
+ def build_linear_trend_da(self, input_da, dataset):
+
+ """
+        Construct a DataArray `linear_trend_da` containing the linear trend SIC
+ forecasts based on the input DataArray `input_da`.
+
+ `linear_trend_da` will be saved in monthly averages using
+ the `save_xarray_in_monthly_averages` method.
+
+ Parameters:
+ `input_da` (xarray.DataArray): Input DataArray to produce linear SIC
+ forecasts for.
+
+ `dataset` (str): 'obs' or 'cmip6' (dictates whether to skip missing
+ observational months in the linear trend extrapolation)
+
+ Returns:
+ `linear_trend_da` (xarray.DataArray): DataArray whose time slices
+ correspond to the linear trend SIC projection for that month.
+ """
+
+ linear_trend_da = input_da.copy(data=np.zeros(input_da.shape, dtype=np.float32))
+
+ # No prediction possible for the first year of data
+ forecast_dates = input_da.time.values[12:]
+
+ # Convert from datetime64 to pd.Timestamp
+ forecast_dates = [pd.Timestamp(date) for date in forecast_dates]
+
+ # Add on the future year
+ last_year = forecast_dates[-12:]
+ forecast_dates.extend([date + pd.DateOffset(years=1) for date in last_year])
+
+ linear_trend_da = linear_trend_da.assign_coords({'time': forecast_dates})
+
+ for forecast_date in forecast_dates:
+ linear_trend_da.loc[dict(time=forecast_date)] = \
+ linear_trend_forecast(forecast_date, self.n_linear_years, da=input_da, dataset=dataset)[0]
+
+ return linear_trend_da
+
+ def check_if_params_precomputed(self, varname, data_format):
+ ''' Searches self.norm_params for normalisation parameters
+ for a given variable name and data format. '''
+
+ if varname == 'siconca':
+ # No normalisation for SIC
+ return True
+
+ # Grab existing parameters if stored in norm_params JSON file
+ precomputed_params_exists = False
+ if varname in self.norm_params.keys():
+ if data_format in self.norm_params[varname].keys():
+ params = self.norm_params[varname][data_format]
+ if self.minmax:
+ if 'min' in params.keys() and 'max' in params.keys():
+ precomputed_params_exists = True
+ elif not self.minmax:
+ if 'mean' in params.keys() and 'std' in params.keys():
+ precomputed_params_exists = True
+
+ return precomputed_params_exists
+
+ def save_variable(self, varname, data_format, dates=None):
+
+ """
+ Save a normalised 3-dimensional satellite/reanalysis dataset as monthly
+ averages (either the absolute values or the monthly anomalies
+ computed with xarray).
+
+ This method assumes there is only one variable stored in the NetCDF files.
+
+ Parameters:
+ varname (str): Name of the variable to load & save
+
+ data_format (str): 'abs' for absolute values, or 'anom' to compute the
+ anomalies, or 'linear_trend' for SIC linear trend projections.
+
+ dates (list of dates): Months to use to compute the monthly
+ climatologies (defaults to the months used for training).
+ """
+
+ if data_format == 'anom':
+ if dates is None:
+ dates = self.obs_train_dates
+
+ ########################################################################
+ # Observational variable
+ ########################################################################
+
+ if self.preproc_obs_data:
+ if self.verbose_level >= 2:
+ print("Preprocessing {} data for {}... ".format(data_format, varname), end='', flush=True)
+ tic = time.time()
+
+ fpath = os.path.join(config.obs_data_folder, '{}_EASE.nc'.format(varname))
+ with xr.open_dataset(fpath) as ds:
+ da = next(iter(ds.data_vars.values()))
+
+ if data_format == 'anom':
+
+ # Check if climatology already computed
+ train_start = self.obs_train_dates[0].strftime('%Y')
+ train_end = self.obs_train_dates[-1].strftime('%Y')
+
+ climatology_fpath = os.path.join(
+ config.obs_data_folder,
+ '{}_climatology_{}_{}.nc'.format(varname, train_start, train_end))
+
+ if os.path.exists(climatology_fpath):
+ with xr.open_dataset(climatology_fpath) as ds:
+ climatology = next(iter(ds.data_vars.values()))
+ else:
+ climatology = da.sel(time=dates). \
+ groupby("time.month", restore_coord_dims=True).mean("time")
+ climatology.to_netcdf(climatology_fpath)
+
+ da = da.groupby("time.month", restore_coord_dims=True) - climatology
+
+ elif data_format == 'linear_trend':
+ da = self.build_linear_trend_da(da, dataset='obs')
+
+ # Realise the array
+ da.data = np.asarray(da.data, dtype=np.float32)
+
+ # Normalise the array
+ if varname == 'siconca':
+                # Don't normalise SIC - already between 0 and 1
+ mean, std = None, None
+ min, max = None, None
+
+ elif varname != 'siconca':
+ precomputed_params_exists = self.check_if_params_precomputed(varname, data_format)
+
+ if precomputed_params_exists:
+ if self.minmax:
+ min = self.norm_params[varname][data_format]['min']
+ max = self.norm_params[varname][data_format]['max']
+ if self.verbose_level >= 2:
+ print("Using precomputed min/max: {}/{}... ".format(min, max),
+ end='', flush=True)
+ elif not self.minmax:
+ mean = self.norm_params[varname][data_format]['mean']
+ std = self.norm_params[varname][data_format]['std']
+ if self.verbose_level >= 2:
+ print("Using precomputed mean/std: {}/{}... ".format(mean, std),
+ end='', flush=True)
+ elif not precomputed_params_exists:
+ mean, std = None, None
+ min, max = None, None
+ self.norm_params[varname] = {}
+ self.norm_params[varname][data_format] = {}
+
+ if self.minmax:
+ da, min, max = self.normalise_array_using_all_training_months(
+ da, self.minmax, min=min, max=max)
+ if not precomputed_params_exists:
+ if self.verbose_level >= 2:
+ print("Newly computed min/max: {}/{}... ".format(min, max),
+ end='', flush=True)
+ self.norm_params[varname][data_format]['min'] = min
+ self.norm_params[varname][data_format]['max'] = max
+ elif not self.minmax:
+ da, mean, std = self.normalise_array_using_all_training_months(
+ da, self.minmax, mean=mean, std=std)
+ if not precomputed_params_exists:
+ if self.verbose_level >= 2:
+ print("Newly computed mean/std: {}/{}... ".format(mean, std),
+ end='', flush=True)
+ self.norm_params[varname][data_format]['mean'] = mean
+ self.norm_params[varname][data_format]['std'] = std
+
+ da.data[np.isnan(da.data)] = 0. # Convert any NaNs to zeros
+
+ self.save_xarray_in_monthly_averages(da, 'obs', varname, data_format)
+
+ if self.verbose_level >= 2:
+ print("Done in {:.0f}s.\n".format(time.time() - tic))
+
+ ########################################################################
+ # Transfer variable
+ ########################################################################
+
+ if self.preproc_transfer_data:
+ if self.verbose_level >= 2:
+ print("Preprocessing CMIP6 {} data for {}... ".format(data_format, varname), end='', flush=True)
+ tic = time.time()
+
+ if not self.check_if_params_precomputed(varname, data_format):
+ raise ValueError('Normalisation parameters must be computed '
+ 'from observational data before preprocessing '
+ 'CMIP6 data.')
+
+ elif varname != 'siconca' and self.minmax:
+ min = self.norm_params[varname][data_format]['min']
+ max = self.norm_params[varname][data_format]['max']
+ if self.verbose_level >= 2:
+ print("Using precomputed min/max: {}/{}... ".format(min, max),
+ end='', flush=True)
+
+ elif varname != 'siconca' and not self.minmax:
+ mean = self.norm_params[varname][data_format]['mean']
+ std = self.norm_params[varname][data_format]['std']
+ if self.verbose_level >= 2:
+ print("Using precomputed mean/std: {}/{}... ".format(mean, std),
+ end='', flush=True)
+
+ for model_name, member_ids in self.cmip_transfer_data.items():
+ print('{}: '.format(model_name), end='', flush=True)
+
+ for member_id in member_ids:
+ print('{}, '.format(member_id), end='', flush=True)
+
+ fname = '{}_EASE_cmpr.nc'.format(varname)
+ fpath = os.path.join(config.cmip6_data_folder, model_name, member_id, fname)
+
+ with xr.open_dataset(fpath) as ds:
+ da = next(iter(ds.data_vars.values()))
+
+                    # Convert to the project's month convention of day=1 and time=00:00
+ da = IceNetDataPreProcessor.standardise_cmip6_time_coord(da)
+
+ # Realise the array
+ da.data = np.asarray(da.data, dtype=np.float32)
+
+ if data_format == 'anom':
+
+ climatology = da.sel(time=dates). \
+ groupby("time.month", restore_coord_dims=True).mean("time")
+ da = da.groupby("time.month", restore_coord_dims=True) - climatology
+
+ elif data_format == 'linear_trend':
+ da = self.build_linear_trend_da(da, dataset='cmip6')
+
+ # Normalise the array
+ if varname != 'siconca':
+ if self.minmax:
+ da, _, _ = self.normalise_array_using_all_training_months(
+ da, self.minmax, min=min, max=max)
+ elif not self.minmax:
+ da, _, _ = self.normalise_array_using_all_training_months(
+ da, self.minmax, mean=mean, std=std)
+
+ self.save_xarray_in_monthly_averages(da, 'transfer', varname, data_format,
+ model_name, member_id)
+
+ if self.verbose_level >= 2:
+ print("Done in {:.0f}s.\n".format(time.time() - tic))
+
+ def preproc_and_save_icenet_data(self):
+
+ '''
+ Loop through each variable, preprocessing and saving.
+ '''
+
+ for varname, vardict in self.preproc_vars.items():
+
+ if 'metadata' not in vardict.keys():
+
+ for data_format in vardict.keys():
+
+ if vardict[data_format] is True:
+
+ self.save_variable(varname, data_format)
+
+ elif 'metadata' in vardict.keys():
+
+ if vardict['include']:
+ if varname == 'land':
+ if self.verbose_level >= 2:
+ print("Setting up the land map: ", end='', flush=True)
+
+ land_mask = np.load(os.path.join(config.mask_data_folder, config.land_mask_filename))
+ land_map = np.ones(self.config['raw_data_shape'], np.float32)
+ land_map[~land_mask] = -1.
+
+ np.save(os.path.join(self.paths['meta'], 'land.npy'), land_map)
+
+ print('\n')
+
+ elif varname == 'circmonth':
+ if self.verbose_level >= 2:
+ print("Computing circular month values... ", end='', flush=True)
+ tic = time.time()
+
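+                        # The cos/sin pair encodes the month cyclically, so
+                        # December and January are adjacent (e.g. month=3 gives
+                        # cos(pi/2)=0, sin(pi/2)=1)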
+ for month in np.arange(1, 13):
+ cos_month = np.cos(2 * np.pi * month / 12, dtype='float32')
+ sin_month = np.sin(2 * np.pi * month / 12, dtype='float32')
+
+ np.save(os.path.join(self.paths['meta'], 'cos_month_{:02d}.npy'.format(month)), cos_month)
+ np.save(os.path.join(self.paths['meta'], 'sin_month_{:02d}.npy'.format(month)), sin_month)
+
+ if self.verbose_level >= 2:
+ print("Done in {:.0f}s.\n".format(time.time() - tic))
+
+ with open(self.norm_params_fpath, 'w') as outfile:
+ json.dump(self.norm_params, outfile)
+
+
+class IceNetDataLoader(tf.keras.utils.Sequence):
+ """
+    Custom data loader class for generating batches of input-output tensors for
+    training IceNet. Inherits from keras.utils.Sequence, which ensures the
+    network trains exactly once on each sample per epoch. Must implement a
+    __len__ method that returns the number of batches and a __getitem__ method
+    that returns a batch of data. The on_epoch_end method is called after each
+    epoch.
+    See: https://www.tensorflow.org/api_docs/python/tf/keras/utils/Sequence
+
+ """
+
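+    # Keras Sequence contract (a sketch; the config path is illustrative): Keras
+    # queries len(loader) for the number of batches per epoch and loader[i] for
+    # each batch, calling on_epoch_end() between epochs:
+    #
+    #     loader = IceNetDataLoader('2021_09_03_1300_icenet_demo.json')
+    #     X, y, sample_weights = loader[0]  # first batch, as built by data_generation
+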
+ def __init__(self, dataloader_config_fpath, seed=None):
+
+ '''
+ Params:
+ dataloader_config_fpath (str): Path to the data loader configuration
+ settings JSON file, defining IceNet's input-output data configuration.
+
+ seed (int): Random seed used for shuffling the training samples before
+ each epoch.
+ '''
+
+ with open(dataloader_config_fpath, 'r') as readfile:
+ self.config = json.load(readfile)
+
+ if seed is None:
+ self.set_seed(self.config['default_seed'])
+ else:
+ self.set_seed(seed)
+
+ self.do_transfer_learning = False
+
+ self.set_obs_forecast_IDs(dataset='train')
+ self.set_transfer_forecast_IDs()
+ self.all_forecast_IDs = self.obs_forecast_IDs
+ self.remove_missing_dates()
+ self.set_variable_path_formats()
+ self.set_number_of_input_channels_for_each_input_variable()
+ self.load_polarholes()
+ self.determine_tot_num_channels()
+ self.on_epoch_end()
+
+ if self.config['verbose_level'] >= 1:
+ print("Setup complete.\n")
+
+ def set_obs_forecast_IDs(self, dataset='train'):
+ """
+ Build up a list of forecast initialisation dates for the train, val, or
+ test sets based on the configuration JSON file start & end points for
+ each dataset.
+ """
+
+ forecast_start_date_ends = self.config['sample_IDs']['obs_{}_dates'.format(dataset)]
+
+ if forecast_start_date_ends is not None:
+
+ # Convert to Pandas Timestamps
+ forecast_start_date_ends = [
+ pd.Timestamp(date).to_pydatetime() for date in forecast_start_date_ends
+ ]
+
+ self.obs_forecast_IDs = list(pd.date_range(
+ forecast_start_date_ends[0],
+ forecast_start_date_ends[1],
+ freq='MS',
+ closed='right',
+ ))
+
+ def set_transfer_forecast_IDs(self):
+
+        '''
+        Use self.config['cmip6_run_dict'] to set up a list of 3-tuples of the form:
+            (cmip6_model_name, member_id, forecast_start_date)
+
+        This list is used as IDs into the transfer data hierarchy
+        to train on all cmip6 models and their runs simultaneously.
+        '''
+
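+        # e.g. ('MRI-ESM2-0', 'r1i1p1f1', Timestamp('1851-02-01')), one ID per
+        # (model, run, start date) combination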
+ self.transfer_forecast_IDs = []
+ for cmip6_model_name, member_id_dict in self.config['cmip6_run_dict'].items():
+ for member_id, (start_date, end_date) in member_id_dict.items():
+
+ member_id_dates = list(pd.date_range(
+ start_date,
+ end_date,
+ freq='MS',
+ closed='right',
+ ))
+
+ self.transfer_forecast_IDs.extend(
+ itertools.product([cmip6_model_name], [member_id], member_id_dates)
+ )
+
+ def set_variable_path_formats(self):
+
+ """
+ Initialise the paths to the .npy files of each variable based on
+ `self.config['input_data']`.
+ """
+
+ if self.config['verbose_level'] >= 1:
+ print('Setting up the variable paths for {}... '.format(self.config['dataset_name']),
+ end='', flush=True)
+
+ # Parent folder for this dataset
+ self.dataset_path = os.path.join(config.network_dataset_folder, self.config['dataset_name'])
+
+ # Dictionary data structure to store image variable paths
+ self.variable_paths = {}
+
+ for varname, vardict in self.config['input_data'].items():
+
+ if 'metadata' not in vardict.keys():
+ self.variable_paths[varname] = {}
+
+ for data_format in vardict.keys():
+
+ if vardict[data_format]['include'] is True:
+
+ if not self.do_transfer_learning:
+ path = os.path.join(
+ self.dataset_path, 'obs',
+ varname, data_format, '{:04d}_{:02d}.npy'
+ )
+ elif self.do_transfer_learning:
+ path = os.path.join(
+ self.dataset_path, 'transfer', '{}', '{}',
+ varname, data_format, '{:04d}_{:02d}.npy'
+ )
+
+ self.variable_paths[varname][data_format] = path
+
+ elif 'metadata' in vardict.keys():
+
+ if vardict['include'] is True:
+
+ if varname == 'land':
+ path = os.path.join(self.dataset_path, 'meta', 'land.npy')
+ self.variable_paths['land'] = path
+
+ elif varname == 'circmonth':
+ path = os.path.join(self.dataset_path, 'meta',
+ '{}_month_{:02d}.npy')
+ self.variable_paths['circmonth'] = path
+
+ if self.config['verbose_level'] >= 1:
+ print('Done.')
+
+ def set_seed(self, seed):
+ """
+ Set the seed used by the random generator (used to randomly shuffle
+ the ordering of training samples after each epoch).
+ """
+ if self.config['verbose_level'] >= 1:
+ print("Setting the data generator's random seed to {}".format(seed))
+ self.rng = np.random.default_rng(seed)
+
+ def determine_variable_names(self):
+ """
+ Set up a list of strings for the names of each input variable (in the
+ correct order) by looping over the `input_data` dictionary.
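+
+ With the demo configuration, for example, this yields names such as
+ 'siconca_abs_1' ... 'siconca_abs_12', 'siconca_linear_trend_1' ...
+ 'siconca_linear_trend_6', 'tas_anom_1' ... 'tas_anom_3', and finally
+ 'land', 'cos(month)' and 'sin(month)'.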
+ """
+ variable_names = []
+
+ for varname, vardict in self.config['input_data'].items():
+ # Input variables that span time
+ if 'metadata' not in vardict.keys():
+ for data_format in vardict.keys():
+ if vardict[data_format]['include']:
+ if data_format != 'linear_trend':
+ for lag in np.arange(1, vardict[data_format]['max_lag'] + 1):
+ variable_names.append(varname + '_{}_{}'.format(data_format, lag))
+ elif data_format == 'linear_trend':
+ for leadtime in np.arange(1, self.config['n_forecast_months'] + 1):
+ variable_names.append(varname + '_{}_{}'.format(data_format, leadtime))
+
+ # Metadata input variables that don't span time
+ elif 'metadata' in vardict.keys() and vardict['include']:
+ if varname == 'land':
+ variable_names.append(varname)
+
+ elif varname == 'circmonth':
+ variable_names.append('cos(month)')
+ variable_names.append('sin(month)')
+
+ return variable_names
+
+ def set_number_of_input_channels_for_each_input_variable(self):
+ """
+ Build up the dict `self.num_input_channels_dict` to store the number of input
+ channels spanned by each input variable.
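+
+ With the demo configuration this gives, for example,
+ {'siconca_abs': 12, 'siconca_linear_trend': 6, 'tas_anom': 3, ...,
+ 'land': 1, 'circmonth': 2}.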
+ """
+
+ if self.config['verbose_level'] >= 1:
+ print("Setting the number of input months for each input variable.")
+
+ self.num_input_channels_dict = {}
+
+ for varname, vardict in self.config['input_data'].items():
+ if 'metadata' not in vardict.keys():
+ # Variables that span time
+ for data_format in vardict.keys():
+ if vardict[data_format]['include']:
+ varname_format = varname + '_{}'.format(data_format)
+ if data_format != 'linear_trend':
+ self.num_input_channels_dict[varname_format] = vardict[data_format]['max_lag']
+ elif data_format == 'linear_trend':
+ self.num_input_channels_dict[varname_format] = self.config['n_forecast_months']
+
+ # Metadata input variables that don't span time
+ elif 'metadata' in vardict.keys() and vardict['include']:
+ if varname == 'land':
+ self.num_input_channels_dict[varname] = 1
+
+ if varname == 'circmonth':
+ self.num_input_channels_dict[varname] = 2
+
+ def determine_tot_num_channels(self):
+ """
+ Determine the number of channels for the input 3D volumes.
+ """
+
+ self.tot_num_channels = 0
+ for varname, num_channels in self.num_input_channels_dict.items():
+ self.tot_num_channels += num_channels
+
+ def all_sic_input_dates_from_forecast_start_date(self, forecast_start_date):
+ """
+ Return a list of all the SIC dates used as input for a particular forecast
+ date, based on the "max_lag" options of self.config['input_data'].
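+
+ With the demo configuration (max_lag of 12 for absolute SIC), for
+ example, a forecast starting 2018-01-01 uses the SIC maps from
+ 2017-12 back to 2017-01.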
+ """
+
+ # Find all SIC lags
+ max_lags = []
+ if self.config['input_data']['siconca']['abs']['include']:
+ max_lags.append(self.config['input_data']['siconca']['abs']['max_lag'])
+ if self.config['input_data']['siconca']['anom']['include']:
+ max_lags.append(self.config['input_data']['siconca']['anom']['max_lag'])
+ max_lag = np.max(max_lags)
+
+ input_dates = [
+ forecast_start_date - pd.DateOffset(months=int(lag)) for lag in np.arange(1, max_lag + 1)
+ ]
+
+ return input_dates
+
+ def check_for_missing_date_dependence(self, forecast_start_date):
+ """
+ Check a forecast ID and return a bool for whether any of the input SIC maps
+ are missing. Used to remove forecast IDs that depend on missing SIC data.
+
+ Note: If one of the _forecast_ dates is missing but none of the _input_ dates,
+ the sample weight matrix for that date will be all zeroes so that the
+ samples for that date do not appear in the loss function.
+ """
+ contains_missing_date = False
+
+ # Check SIC input dates
+ input_dates = self.all_sic_input_dates_from_forecast_start_date(forecast_start_date)
+
+ for input_date in input_dates:
+ if any([input_date == missing_date for missing_date in config.missing_dates]):
+ contains_missing_date = True
+ break
+
+ return contains_missing_date
+
+ def remove_missing_dates(self):
+
+ '''
+ Remove dates from self.obs_forecast_IDs that depend on a missing
+ observation of SIC.
+ '''
+
+ if self.config['verbose_level'] >= 2:
+ print('Checking forecast start dates for missing SIC dates... ', end='', flush=True)
+
+ new_obs_forecast_IDs = []
+ for forecast_start_date in self.obs_forecast_IDs:
+ if self.check_for_missing_date_dependence(forecast_start_date):
+ if self.config['verbose_level'] >= 3:
+ print('Removing {}, '.format(
+ forecast_start_date.strftime('%Y_%m_%d')), end='', flush=True)
+
+ else:
+ new_obs_forecast_IDs.append(forecast_start_date)
+
+ self.obs_forecast_IDs = new_obs_forecast_IDs
+
+ def load_polarholes(self):
+ """
+ Loads each of the polar holes.
+ """
+
+ if self.config['verbose_level'] >= 1:
+ tic = time.time()
+ print("Loading and augmenting the polar holes... ", end='', flush=True)
+
+ polarhole_path = os.path.join(config.mask_data_folder, config.polarhole1_fname)
+ self.polarhole1_mask = np.load(polarhole_path)
+
+ polarhole_path = os.path.join(config.mask_data_folder, config.polarhole2_fname)
+ self.polarhole2_mask = np.load(polarhole_path)
+
+ if config.use_polarhole3:
+ polarhole_path = os.path.join(config.mask_data_folder, config.polarhole3_fname)
+ self.polarhole3_mask = np.load(polarhole_path)
+
+ self.nopolarhole_mask = np.full(self.config['raw_data_shape'], False)
+
+ if self.config['verbose_level'] >= 1:
+ print("Done in {:.0f}s.\n".format(time.time() - tic))
+
+ def determine_polar_hole_mask(self, forecast_start_date):
+ """
+ Determine which polar hole mask to use (if any) by finding the oldest SIC
+ input month based on the current output month. The polar hole active for
+ the oldest input month is used (because the polar hole size decreases
+ monotonically over time, and we wish to use the largest polar hole for
+ the input data).
+
+ Parameters:
+ forecast_start_date (pd.Timestamp): Timepoint for the forecast initialisation.
+
+ Returns:
+ polarhole_mask: Mask array with NaNs on polar hole grid cells and 1s
+ elsewhere.
+ """
+
+ oldest_input_date = min(self.all_sic_input_dates_from_forecast_start_date(forecast_start_date))
+
+ if oldest_input_date <= config.polarhole1_final_date:
+ polarhole_mask = self.polarhole1_mask
+ if self.config['verbose_level'] >= 3:
+ print("Forecast start date: {}, polar hole: {}".format(
+ forecast_start_date.strftime("%Y_%m"), 1))
+
+ elif oldest_input_date <= config.polarhole2_final_date:
+ polarhole_mask = self.polarhole2_mask
+ if self.config['verbose_level'] >= 3:
+ print("Forecast start date: {}, polar hole: {}".format(
+ forecast_start_date.strftime("%Y_%m"), 2))
+
+ else:
+ polarhole_mask = self.nopolarhole_mask
+ if self.config['verbose_level'] >= 3:
+ print("Forecast start date: {}, polar hole: {}".format(
+ forecast_start_date.strftime("%Y_%m"), "none"))
+
+ return polarhole_mask
+
+ def determine_active_grid_cell_mask(self, forecast_date):
+ """
+ Determine which active grid cell mask to use (a boolean array with
+ True on active cells and False on inactive cells). The cells with 'True'
+ are where predictions are to be made with IceNet. The active grid cell
+ mask for a particular month is determined by the sum of the land cells,
+ the ocean cells (for that month), and the missing polar hole.
+
+ The mask is used for removing 'inactive' cells (such as land or polar
+ hole cells) from the loss function in self.data_generation.
+ """
+
+ output_month_str = '{:02d}'.format(forecast_date.month)
+ output_active_grid_cell_mask_fname = config.active_grid_cell_file_format. \
+ format(output_month_str)
+ output_active_grid_cell_mask_path = os.path.join(
+ config.mask_data_folder, output_active_grid_cell_mask_fname)
+ output_active_grid_cell_mask = np.load(output_active_grid_cell_mask_path)
+
+ # Only use the polar hole mask if predicting observational data
+ if not self.do_transfer_learning:
+ polarhole_mask = self.determine_polar_hole_mask(forecast_date)
+
+ # Add the polar hole mask to the land/ocean mask for the current month
+ output_active_grid_cell_mask[polarhole_mask] = False
+
+ return output_active_grid_cell_mask
+
+ def turn_on_transfer_learning(self):
+
+ '''
+ Converts the data loader to use CMIP6 pre-training data
+ for transfer learning.
+ '''
+
+ self.do_transfer_learning = True
+ self.all_forecast_IDs = self.transfer_forecast_IDs
+ self.on_epoch_end() # Shuffle transfer training indexes
+ self.set_variable_path_formats()
+
+ def turn_off_transfer_learning(self):
+
+ '''
+ Converts the data loader back to using ERA5/OSI-SAF observational
+ training data.
+ '''
+
+ self.do_transfer_learning = False
+ self.all_forecast_IDs = self.obs_forecast_IDs
+ self.on_epoch_end() # Shuffle transfer training indexes
+ self.set_variable_path_formats()
+
+ def convert_to_validation_data_loader(self):
+
+ """
+ Resets the `all_forecast_IDs` array to correspond to the validation
+ months defined by the data loader configuration file.
+ """
+
+ self.set_obs_forecast_IDs(dataset='val')
+ self.remove_missing_dates()
+ self.all_forecast_IDs = self.obs_forecast_IDs
+
+ def convert_to_test_data_loader(self):
+
+ """
+ As above but for the testing months.
+ """
+
+ self.set_obs_forecast_IDs(dataset='test')
+ self.remove_missing_dates()
+ self.all_forecast_IDs = self.obs_forecast_IDs
+
+ def data_generation(self, forecast_IDs):
+ """
+ Generate input-output data for IceNet for a given batch of forecast IDs.
+
+ Parameters:
+ forecast_IDs (list):
+ If self.do_transfer_learning is False, a list of pd.Timestamp objects
+ corresponding to the forecast initialisation dates (first month
+ being forecast) for the batch of X-y data to load.
+
+ If self.do_transfer_learning is True, a list of tuples
+ of the form (cmip6_model_name, member_id, forecast_start_date).
+
+ Returns:
+ X (ndarray): Batch of input 3D volumes.
+
+ y (ndarray): Batch of ground truth output SIC class maps
+
+ sample_weight (ndarray): Batch of pixelwise weights for weighting the
+ loss function (masking outside the active grid cell region and
+ up-weighting summer months).
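+
+ With the demo configuration, for example, the shapes are
+ X: (batch, 432, 432, tot_num_channels),
+ y: (batch, 432, 432, 3, n_forecast_months) and
+ sample_weight: (batch, 432, 432, 1, n_forecast_months).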
+ """
+
+ # Allow non-list input for single forecasts
+ forecast_IDs = pd.Timestamp(forecast_IDs) if isinstance(forecast_IDs, str) else forecast_IDs
+ forecast_IDs = [forecast_IDs] if not isinstance(forecast_IDs, list) else forecast_IDs
+
+ current_batch_size = len(forecast_IDs)
+
+ if self.do_transfer_learning:
+ cmip6_model_names = [forecast_ID[0] for forecast_ID in forecast_IDs]
+ cmip6_member_ids = [forecast_ID[1] for forecast_ID in forecast_IDs]
+ forecast_start_dates = [forecast_ID[2] for forecast_ID in forecast_IDs]
+ else:
+ forecast_start_dates = forecast_IDs
+
+ ########################################################################
+ # OUTPUT LABELS
+ ########################################################################
+
+ # Build up the set of N_samps output SIC time-series
+ # (each n_forecast_months long in the time dimension)
+
+ # To become array of shape (N_samps, *raw_data_shape, n_forecast_months)
+ batch_sic_list = []
+
+ # Maps each sample index to a list of forecast leadtime indexes
+ # for which no data exists
+ missing_month_dict = {}
+
+ for sample_idx, forecast_date in enumerate(forecast_start_dates):
+
+ # To become array of shape (*raw_data_shape, n_forecast_months)
+ sample_sic_list = []
+
+ # List of forecast indexes with missing data
+ missing_month_dict[sample_idx] = []
+
+ for forecast_leadtime_idx in range(self.config['n_forecast_months']):
+
+ forecast_date = forecast_start_dates[sample_idx] + pd.DateOffset(months=forecast_leadtime_idx)
+
+ if self.do_transfer_learning:
+ sample_sic_list.append(
+ np.load(self.variable_paths['siconca']['abs'].format(
+ cmip6_model_names[sample_idx], cmip6_member_ids[sample_idx],
+ forecast_date.year, forecast_date.month))
+ )
+
+ elif not self.do_transfer_learning:
+ if any([forecast_date == missing_date for missing_date in config.missing_dates]):
+ # Output file does not exist: fill with zeros and record the
+ # missing leadtime so its sample weights are zeroed out below
+ sample_sic_list.append(np.zeros(self.config['raw_data_shape']))
+ missing_month_dict[sample_idx].append(forecast_leadtime_idx)
+
+ else:
+ fpath = self.variable_paths['siconca']['abs'].format(
+ forecast_date.year, forecast_date.month)
+ if os.path.exists(fpath):
+ sample_sic_list.append(np.load(fpath))
+ else:
+ # Ground truth data doesn't exist: fill with NaNs
+ sample_sic_list.append(
+ np.full(self.config['raw_data_shape'], np.nan, dtype=np.float32))
+
+ batch_sic_list.append(np.stack(sample_sic_list, axis=2))
+
+ batch_sic = np.stack(batch_sic_list, axis=0)
+
+ no_ice_gridcells = batch_sic <= 0.15
+ ice_gridcells = batch_sic >= 0.80
+ marginal_ice_gridcells = ~((no_ice_gridcells) | (ice_gridcells))
+
+ # Categorical representation with channel dimension for class probs
+ y = np.zeros((
+ current_batch_size,
+ *self.config['raw_data_shape'],
+ self.config['n_forecast_months'],
+ 3
+ ), dtype=np.float32)
+
+ y[no_ice_gridcells, 0] = 1
+ y[marginal_ice_gridcells, 1] = 1
+ y[ice_gridcells, 2] = 1
+
+ # Move lead time to final axis
+ y = np.moveaxis(y, source=3, destination=4)
+
+ # Missing months
+ for sample_idx, forecast_leadtime_idx_list in missing_month_dict.items():
+ if len(forecast_leadtime_idx_list) > 0:
+ y[sample_idx, :, :, :, forecast_leadtime_idx_list] = 0
+
+ ########################################################################
+ # PIXELWISE LOSS FUNCTION WEIGHTING
+ ########################################################################
+
+ sample_weight = np.zeros((
+ current_batch_size,
+ *self.config['raw_data_shape'],
+ 1, # Broadcastable class dimension
+ self.config['n_forecast_months']
+ ), dtype=np.float32)
+ for sample_idx, forecast_date in enumerate(forecast_start_dates):
+
+ for forecast_leadtime_idx in range(self.config['n_forecast_months']):
+
+ forecast_date = forecast_start_dates[sample_idx] + pd.DateOffset(months=forecast_leadtime_idx)
+
+ if any([forecast_date == missing_date for missing_date in config.missing_dates]):
+ # Leave sample weighting as all-zeros
+ pass
+
+ else:
+ # Zero loss outside of 'active grid cells'
+ sample_weight_ij = self.determine_active_grid_cell_mask(forecast_date)
+ sample_weight_ij = sample_weight_ij.astype(np.float32)
+
+ # Scale the loss for each month s.t. March is
+ # scaled by 1 and Sept is scaled by 1.77
+ if self.config['loss_weight_months']:
+ sample_weight_ij *= 33928. / np.sum(sample_weight_ij)
+
+ sample_weight[sample_idx, :, :, 0, forecast_leadtime_idx] = \
+ sample_weight_ij
+
+ ########################################################################
+ # INPUT FEATURES
+ ########################################################################
+
+ # Batch tensor
+ X = np.zeros((
+ current_batch_size,
+ *self.config['raw_data_shape'],
+ self.tot_num_channels
+ ), dtype=np.float32)
+
+ # Build up the batch of inputs
+ for sample_idx, forecast_start_date in enumerate(forecast_start_dates):
+
+ present_date = forecast_start_date - relativedelta(months=1)
+
+ # Initialise variable indexes used to fill the input tensor `X`
+ variable_idx1 = 0
+ variable_idx2 = 0
+
+ for varname, vardict in self.config['input_data'].items():
+
+ if 'metadata' not in vardict.keys():
+
+ for data_format in vardict.keys():
+
+ if vardict[data_format]['include']:
+
+ varname_format = '{}_{}'.format(varname, data_format)
+
+ if data_format != 'linear_trend':
+ lbs = range(vardict[data_format]['max_lag'])
+ input_months = [present_date - relativedelta(months=lb) for lb in lbs]
+ elif data_format == 'linear_trend':
+ input_months = [present_date + relativedelta(months=forecast_leadtime)
+ for forecast_leadtime in np.arange(1, self.config['n_forecast_months'] + 1)]
+
+ variable_idx2 += self.num_input_channels_dict[varname_format]
+
+ if not self.do_transfer_learning:
+ X[sample_idx, :, :, variable_idx1:variable_idx2] = \
+ np.stack([np.load(self.variable_paths[varname][data_format].format(
+ date.year, date.month))
+ for date in input_months], axis=-1)
+ elif self.do_transfer_learning:
+ cmip6_model_name = cmip6_model_names[sample_idx]
+ cmip6_member_id = cmip6_member_ids[sample_idx]
+
+ X[sample_idx, :, :, variable_idx1:variable_idx2] = \
+ np.stack([np.load(self.variable_paths[varname][data_format].format(
+ cmip6_model_name, cmip6_member_id, date.year, date.month))
+ for date in input_months], axis=-1)
+
+ variable_idx1 += self.num_input_channels_dict[varname_format]
+
+ elif 'metadata' in vardict.keys() and vardict['include']:
+
+ variable_idx2 += self.num_input_channels_dict[varname]
+
+ if varname == 'land':
+ X[sample_idx, :, :, variable_idx1] = np.load(self.variable_paths['land'])
+
+ elif varname == 'circmonth':
+ X[sample_idx, :, :, variable_idx1] = \
+ np.load(self.variable_paths['circmonth'].format('cos', forecast_start_date.month))
+ X[sample_idx, :, :, variable_idx1 + 1] = \
+ np.load(self.variable_paths['circmonth'].format('sin', forecast_start_date.month))
+
+ variable_idx1 += self.num_input_channels_dict[varname]
+
+ return X, y, sample_weight
+
+ def __getitem__(self, batch_idx):
+ '''
+ Generate one batch of data of size `batch_size` at batch index `batch_idx`
+ into the set of batches in the epoch.
+ '''
+
+ batch_start = batch_idx * self.config['batch_size']
+ batch_end = np.min([(batch_idx + 1) * self.config['batch_size'], len(self.all_forecast_IDs)])
+
+ sample_idxs = np.arange(batch_start, batch_end)
+ batch_IDs = [self.all_forecast_IDs[sample_idx] for sample_idx in sample_idxs]
+
+ return self.data_generation(batch_IDs)
+
+ def __len__(self):
+ ''' Returns the number of batches per training epoch. '''
+ return int(np.ceil(len(self.all_forecast_IDs) / self.config['batch_size']))
+
+ def on_epoch_end(self):
+ """ Randomly shuffles training samples after each epoch. """
+
+ if self.config['verbose_level'] >= 2:
+ print("on_epoch_end called")
+
+ # Randomly shuffle the forecast IDs in-place
+ self.rng.shuffle(self.all_forecast_IDs)
+
+
+# MISC FUNCTIONS
+################################################################################
+
+
+def create_results_dataset_index(model_compute_list, leadtimes,
+ all_target_dates, icenet_ID,
+ icenet_seeds):
+
+ '''
+ Returns a pandas.MultiIndex object of results dataset indexes for a
+ given list of models to compute metrics for. For IceNet, the 'Ensemble
+ member' column delineates the performance of each IceNet ensemble
+ member (identified by the integer random seed value it was trained
+ with) and the ensemble mean models ('ensemble' or 'ensemble_tempscaled').
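+
+ For example (illustrative values), model_compute_list=['IceNet', 'SEAS5']
+ with icenet_ID='IceNet' and icenet_seeds=[42, 43] yields rows such as
+ ('IceNet', 42, 1, <date>) and ('SEAS5', 'NA', 1, <date>) after the
+ level reordering below.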
+ '''
+
+ multi_index = pd.MultiIndex.from_product(
+ [model_compute_list, leadtimes, all_target_dates])
+
+ idxs = []
+ for row in multi_index:
+ model = row[0]
+ row = [[item] for item in row]
+ if model == icenet_ID:
+ idxs.extend(list(itertools.product(*row, icenet_seeds)))
+ else:
+ idxs.extend(list(itertools.product(*row, ['NA'])))
+
+ multi_index = pd.MultiIndex.from_tuples(
+ idxs, names=['Model', 'Leadtime', 'Forecast date', 'Ensemble member']).\
+ reorder_levels(['Model', 'Ensemble member', 'Leadtime', 'Forecast date'])
+
+ return multi_index
+
+
+def make_varname_verbose(varname, leadtime, fc_month_idx):
+
+ '''
+ Takes IceNet short variable name (e.g. siconca_abs_3) and converts it to a
+ long name for a given forecast calendar month and lead time (e.g.
+ 'Feb SIC').
+
+ Inputs:
+ varname: Short variable name.
+ leadtime: Lead time of the forecast.
+ fc_month_idx: Zero-indexed (mod-12) calendar month index for the month
+ being forecast (e.g. 8 for September).
+
+ Returns:
+ verbose_varname: Long variable name.
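+
+ Example (illustrative): make_varname_verbose('siconca_abs_3', leadtime=1,
+ fc_month_idx=8) returns 'Jun SIC': an August-initialised September
+ forecast whose lag-3 SIC input month is June.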
+ '''
+
+ month_names = np.array(['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sept', 'Oct', 'Nov', 'Dec'])
+
+ varname_regex = re.compile('^(.*)_(abs|anom|linear_trend)_([0-9]+)$')
+
+ var_lookup_table = {
+ 'siconca': 'SIC',
+ 'tas': '2m air temperature',
+ 'ta500': '500 hPa air temperature',
+ 'tos': 'sea surface temperature',
+ 'rsds': 'downwelling solar radiation',
+ 'rsus': 'upwelling solar radiation',
+ 'psl': 'sea level pressure',
+ 'zg500': '500 hPa geopotential height',
+ 'zg250': '250 hPa geopotential height',
+ 'ua10': '10 hPa zonal wind speed',
+ 'uas': 'x-direction wind',
+ 'vas': 'y-direction wind'
+ }
+
+ initialisation_month_idx = (fc_month_idx - leadtime) % 12
+
+ varname_match = varname_regex.match(varname)
+
+ field = varname_match[1]
+ data_format = varname_match[2]
+ lead_or_lag = int(varname_match[3])
+
+ verbose_varname = ''
+
+ month_suffix = ' '
+ month_prefix = ''
+ if data_format != 'linear_trend':
+ # Read back from initialisation month to get input lag month
+ lag = lead_or_lag # In no of months
+ input_month_name = month_names[(initialisation_month_idx - lag + 1) % 12]
+
+ if (initialisation_month_idx - lag + 1) // 12 == -1:
+ # Previous calendar year
+ month_prefix = 'Previous '
+
+ elif data_format == 'linear_trend':
+ # Read forward from initialisation month to get linear trend forecast month
+ lead = lead_or_lag # In no of months
+ input_month_name = month_names[(initialisation_month_idx + lead) % 12]
+
+ if (initialisation_month_idx + lead) // 12 == 1:
+ # Next calendar year
+ month_prefix = 'Next '
+
+ # Month the input corresponds to
+ verbose_varname += month_prefix + input_month_name + month_suffix
+
+ # verbose variable name
+ if data_format != 'linear_trend':
+ verbose_varname += var_lookup_table[field]
+ if data_format == 'anom':
+ verbose_varname += ' anomaly'
+ elif data_format == 'linear_trend':
+ verbose_varname += 'linear trend SIC forecast'
+
+ return verbose_varname
+
+
+def make_varname_verbose_any_leadtime(varname):
+
+ ''' As above, but agnostic to what the target month or lead time is. E.g.
+ "SIC (1)" for sea ice concentration at a lag of 1 month. '''
+
+ varname_regex = re.compile('^(.*)_(abs|anom|linear_trend)_([0-9]+)$')
+
+ var_lookup_table = {
+ 'siconca': 'SIC',
+ 'tas': '2m air temperature',
+ 'ta500': '500 hPa air temperature',
+ 'tos': 'sea surface temperature',
+ 'rsds': 'downwelling solar radiation',
+ 'rsus': 'upwelling solar radiation',
+ 'psl': 'sea level pressure',
+ 'zg500': '500 hPa geopotential height',
+ 'zg250': '250 hPa geopotential height',
+ 'ua10': '10 hPa zonal wind speed',
+ 'uas': 'x-direction wind',
+ 'vas': 'y-direction wind',
+ 'land': 'land mask',
+ 'cos(month)': 'cos(init month)',
+ 'sin(month)': 'sin(init month)',
+ }
+
+ exception_vars = ['cos(month)', 'sin(month)', 'land']
+
+ if varname in exception_vars:
+ return var_lookup_table[varname]
+ else:
+ varname_match = varname_regex.match(varname)
+
+ field = varname_match[1]
+ data_format = varname_match[2]
+ lead_or_lag = int(varname_match[3])
+
+ # verbose variable name
+ if data_format != 'linear_trend':
+ verbose_varname = var_lookup_table[field]
+ if data_format == 'anom':
+ verbose_varname += ' anomaly'
+ elif data_format == 'linear_trend':
+ verbose_varname = 'Linear trend SIC forecast'
+
+ verbose_varname += ' ({:.0f})'.format(lead_or_lag)
+
+ return verbose_varname
+
+
+################################################################################
+# FUNCTIONS
+################################################################################
+
+
+def assignLatLonCoordSystem(cube):
+ ''' Assign coordinate system to iris cube to allow regridding. '''
+
+ cube.coord('latitude').coord_system = iris.coord_systems.GeogCS(6367470.0)
+ cube.coord('longitude').coord_system = iris.coord_systems.GeogCS(6367470.0)
+
+ return cube
+
+
+def fix_near_real_time_era5_func(latlon_path):
+
+ '''
+ Near-real-time ERA5 data is classed as a different dataset called 'ERA5T'.
+ This results in a spurious 'expver' dimension. This method detects
+ whether that dim is present and removes it, concatenating into one array
+ '''
+
+ ds = xr.open_dataarray(latlon_path)
+
+ if len(ds.data.shape) == 4:
+ print('Fixing spurious ERA5 "expver" dimension for {}.'.format(latlon_path))
+
+ arr = ds.data
+ # Expver 1 (ERA5)
+ era5_months = ~np.isnan(arr[:, 0, :, :]).all(axis=(1, 2))
+
+ # Expver 2 (ERA5T - near real time)
+ era5t_months = ~np.isnan(arr[:, 1, :, :]).all(axis=(1, 2))
+
+ ds = xr.concat((ds[era5_months, 0, :], ds[era5t_months, 1, :]), dim='time')
+
+ ds = ds.reset_coords('expver', drop=True)
+
+ os.remove(latlon_path)
+ ds.load().to_netcdf(latlon_path)
+
+
+###############################################################################
+# LEARNING RATE SCHEDULER
+###############################################################################
+
+
+def make_exp_decay_lr_schedule(rate, start_epoch=1, end_epoch=np.inf, verbose=False):
+
+ ''' Returns an exponential learning rate function that multiplies by
+ exp(-rate) each epoch after `start_epoch`. '''
+
+ def lr_scheduler_exp_decay(epoch, lr):
+ ''' Learning rate scheduler for fine tuning.
+ Exponential decrease after start_epoch until end_epoch. '''
+
+ if epoch >= start_epoch and epoch < end_epoch:
+ lr = lr * np.exp(-rate)
+
+ if verbose:
+ print('\nSetting learning rate to: {}\n'.format(lr))
+
+ return lr
+
+ return lr_scheduler_exp_decay
+
+
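+# Illustrative usage (assuming a compiled Keras `model`; the rate and epoch
+# values are examples only):
+#
+#     lr_scheduler = tf.keras.callbacks.LearningRateScheduler(
+#         make_exp_decay_lr_schedule(rate=0.1, start_epoch=3))
+#     model.fit(dataloader, epochs=10, callbacks=[lr_scheduler])
+
+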
+###############################################################################
+# REGRIDDING VECTOR DATA
+###############################################################################
+
+
+def rotate_grid_vectors(u_cube, v_cube, angles):
+ """
+ Author: Tony Phillips (BAS)
+
+ Wrapper for :func:`~iris.analysis.cartography.rotate_grid_vectors`
+ that can rotate multiple masked spatial fields in one go by iterating
+ over the horizontal spatial axes in slices
+ """
+ # lists to hold slices of rotated vectors
+ u_r_all = iris.cube.CubeList()
+ v_r_all = iris.cube.CubeList()
+
+ # get the X and Y dimension coordinates for each source cube
+ u_xy_coords = [u_cube.coord(axis='x', dim_coords=True),
+ u_cube.coord(axis='y', dim_coords=True)]
+ v_xy_coords = [v_cube.coord(axis='x', dim_coords=True),
+ v_cube.coord(axis='y', dim_coords=True)]
+
+ # iterate over X, Y slices of the source cubes, rotating each in turn
+ for u, v in zip(u_cube.slices(u_xy_coords, ordered=False),
+ v_cube.slices(v_xy_coords, ordered=False)):
+ u_r, v_r = iris.analysis.cartography.rotate_grid_vectors(u, v, angles)
+ u_r_all.append(u_r)
+ v_r_all.append(v_r)
+
+ # return the slices, merged back together into a pair of cubes
+ return (u_r_all.merge_cube(), v_r_all.merge_cube())
+
+
+def gridcell_angles_from_dim_coords(cube):
+ """
+ Author: Tony Phillips (BAS)
+
+ Wrapper for :func:`~iris.analysis.cartography.gridcell_angles`
+ that derives the 2D X and Y lon/lat coordinates from 1D X and Y
+ coordinates identifiable as 'x' and 'y' axes
+
+ The provided cube must have a coordinate system so that its
+ X and Y coordinate bounds (which are derived if necessary)
+ can be converted to lons and lats
+ """
+
+ # get the X and Y dimension coordinates for the cube
+ x_coord = cube.coord(axis='x', dim_coords=True)
+ y_coord = cube.coord(axis='y', dim_coords=True)
+
+ # add bounds if necessary
+ if not x_coord.has_bounds():
+ x_coord = x_coord.copy()
+ x_coord.guess_bounds()
+ if not y_coord.has_bounds():
+ y_coord = y_coord.copy()
+ y_coord.guess_bounds()
+
+ # get the grid cell bounds
+ x_bounds = x_coord.bounds
+ y_bounds = y_coord.bounds
+ nx = x_bounds.shape[0]
+ ny = y_bounds.shape[0]
+
+ # make arrays to hold the ordered X and Y bound coordinates
+ x = np.zeros((ny, nx, 4))
+ y = np.zeros((ny, nx, 4))
+
+ # iterate over the bounds (in order BL, BR, TL, TR), mesh them and
+ # put them into the X and Y bound coordinates (in order BL, BR, TR, TL)
+ c = [0, 1, 3, 2]
+ cind = 0
+ for yi in [0, 1]:
+ for xi in [0, 1]:
+ xy = np.meshgrid(x_bounds[:, xi], y_bounds[:, yi])
+ x[:, :, c[cind]] = xy[0]
+ y[:, :, c[cind]] = xy[1]
+ cind += 1
+
+ # convert the X and Y coordinates to longitudes and latitudes
+ source_crs = cube.coord_system().as_cartopy_crs()
+ target_crs = ccrs.PlateCarree()
+ pts = target_crs.transform_points(source_crs, x.flatten(), y.flatten())
+ lons = pts[:, 0].reshape(x.shape)
+ lats = pts[:, 1].reshape(x.shape)
+
+ # get the angles
+ angles = iris.analysis.cartography.gridcell_angles(lons, lats)
+
+ # add the X and Y dimension coordinates from the cube to the angles cube
+ angles.add_dim_coord(y_coord, 0)
+ angles.add_dim_coord(x_coord, 1)
+
+ # if the cube's X dimension precedes its Y dimension
+ # transpose the angles to match
+ if cube.coord_dims(x_coord)[0] < cube.coord_dims(y_coord)[0]:
+ angles.transpose()
+
+ return angles
+
+
+def invert_gridcell_angles(angles):
+ """
+ Author: Tony Phillips (BAS)
+
+ Negate a cube of gridcell angles in place, transforming
+ gridcell_angle_from_true_east <--> true_east_from_gridcell_angle
+ """
+ angles.data *= -1
+
+ names = ['true_east_from_gridcell_angle', 'gridcell_angle_from_true_east']
+ name = angles.name()
+ if name in names:
+ angles.rename(names[1 - names.index(name)])
+
+
+###############################################################################
+# CMIP6
+###############################################################################
+
+
+# Below taken from https://hub.binder.pangeo.io/user/pangeo-data-pan--cmip6-examples-ro965nih/lab
+def esgf_search(server="https://esgf-node.llnl.gov/esg-search/search",
+ files_type="OPENDAP", local_node=False, latest=True, project="CMIP6",
+ verbose1=False, verbose2=False, format="application%2Fsolr%2Bjson",
+ use_csrf=False, **search):
+ client = requests.session()
+ payload = search
+ payload["project"] = project
+ payload["type"] = "File"
+ if latest:
+ payload["latest"] = "true"
+ if local_node:
+ payload["distrib"] = "false"
+ if use_csrf:
+ client.get(server)
+ if 'csrftoken' in client.cookies:
+ # Django 1.6 and up
+ csrftoken = client.cookies['csrftoken']
+ else:
+ # older versions
+ csrftoken = client.cookies['csrf']
+ payload["csrfmiddlewaretoken"] = csrftoken
+
+ payload["format"] = format
+
+ offset = 0
+ numFound = 10000
+ all_files = []
+ files_type = files_type.upper()
+ while offset < numFound:
+ payload["offset"] = offset
+ url_keys = []
+ for k in payload:
+ url_keys += ["{}={}".format(k, payload[k])]
+
+ url = "{}/?{}".format(server, "&".join(url_keys))
+ if verbose1:
+ print(url)
+ r = client.get(url)
+ r.raise_for_status()
+ resp = r.json()["response"]
+ numFound = int(resp["numFound"])
+ resp = resp["docs"]
+ offset += len(resp)
+ for d in resp:
+ if verbose2:
+ for k in d:
+ print("{}: {}".format(k, d[k]))
+ url = d["url"]
+ for f in d["url"]:
+ sp = f.split("|")
+ if sp[-1] == files_type:
+ all_files.append(sp[0].split(".html")[0])
+ return sorted(all_files)
+
+
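+# An illustrative ESGF query (facet values are examples only; any CMIP6
+# search facet can be passed as a keyword argument):
+#
+#     files = esgf_search(source_id='EC-Earth3', experiment_id='historical',
+#                         variable_id='siconc', table_id='SImon',
+#                         member_id='r2i1p1f1')
+
+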
+def regrid_cmip6(cmip6_cube, grid_cube, verbose=False):
+
+ if verbose:
+ tic = time.time()
+ print("regridding... ", end='', flush=True)
+
+ cs = grid_cube.coord_system().ellipsoid
+
+ for coord in ['longitude', 'latitude']:
+ cmip6_cube.coord(coord).coord_system = cs
+
+ cmip6_ease = cmip6_cube.regrid(grid_cube, iris.analysis.Linear())
+
+ if verbose:
+ dur = time.time() - tic
+ print("done in {}m:{:.0f}s... ".format(np.floor(dur / 60), dur % 60), end='', flush=True)
+
+ return cmip6_ease
+
+
+def save_cmip6(cmip6_ease, fpath, compress=True, verbose=False):
+ tic = time.time()
+
+ if compress:
+ if verbose:
+ print('compressing & saving... ', end='', flush=True)
+ iris.fileformats.netcdf.save(cmip6_ease, fpath, complevel=7, zlib=True)
+ else:
+ if verbose:
+ print('saving uncompressed... ', end='', flush=True)
+ iris.save(cmip6_ease, fpath)
+
+ if verbose:
+ dur = time.time() - tic
+ print("done in {}m:{:.0f}s... ".format(np.floor(dur / 60), dur % 60), end='', flush=True)
+
+
+###############################################################################
+# PLOTTING
+###############################################################################
+
+
+def compute_heatmap(results_df, model, seed='NA', metric='Binary accuracy'):
+ '''
+ Returns a binary accuracy heatmap of lead time vs. calendar month
+ for a given model.
+ '''
+
+ month_names = np.array(['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
+ 'Jul', 'Aug', 'Sept', 'Oct', 'Nov', 'Dec'])
+
+ # Mean over calendar month
+ mean_df = results_df.loc[model, seed].reset_index().\
+ groupby(['Calendar month', 'Leadtime']).mean()
+
+ # Pivot
+ heatmap_df = mean_df.reset_index().\
+ pivot(index='Calendar month', columns='Leadtime', values=metric).reindex(month_names)
+
+ return heatmap_df
+
+
+def arr_to_ice_edge_arr(arr, thresh, land_mask, region_mask):
+
+ '''
+ Compute a boolean mask with True over ice edge contour grid cells using
+ matplotlib.pyplot.contour and an input threshold to define the ice edge
+ (e.g. 0.15 for the 15% SIC ice edge or 0.5 for SIP forecasts). The contour
+ along the coastline is removed using the region mask.
+ '''
+
+ X, Y = np.meshgrid(np.arange(arr.shape[0]), np.arange(arr.shape[1]))
+ X = X.T
+ Y = Y.T
+
+ cs = plt.contour(X, Y, arr, [thresh], alpha=0) # Do not plot on any axes
+ x = []
+ y = []
+ for p in cs.collections[0].get_paths():
+ x_i, y_i = p.vertices.T
+ x.extend(np.round(x_i))
+ y.extend(np.round(y_i))
+ x = np.array(x, int)
+ y = np.array(y, int)
+ ice_edge_arr = np.zeros(arr.shape, dtype=bool)
+ ice_edge_arr[x, y] = True
+ # Mask out ice edge contour that hugs the coastline
+ ice_edge_arr[land_mask] = False
+ ice_edge_arr[region_mask == 13] = False
+
+ return ice_edge_arr
+
+
+def arr_to_ice_edge_rgba_arr(arr, thresh, land_mask, region_mask, rgb):
+
+ ice_edge_arr = arr_to_ice_edge_arr(arr, thresh, land_mask, region_mask)
+
+ # Contour pixels -> alpha=1, alpha=0 elsewhere
+ ice_edge_rgba_arr = np.zeros((*arr.shape, 4))
+ ice_edge_rgba_arr[:, :, 3] = ice_edge_arr
+ ice_edge_rgba_arr[:, :, :3] = rgb
+
+ return ice_edge_rgba_arr
+
+
+###############################################################################
+# VIDEOS
+###############################################################################
+
+
+def xarray_to_video(da, video_path, fps, mask=None, mask_type='contour', clim=None,
+ crop=None, data_type='abs', video_dates=None, cmap='viridis',
+ figsize=15, dpi=300):
+
+ '''
+ Generate video of an xarray.DataArray. Optionally input a list of
+ `video_dates` to show, otherwise the full set of time coordinates
+ of the dataset is used.
+
+ Parameters:
+ da (xr.DataArray): Dataset to create video of.
+
+ video_path (str): Path to save the video to.
+
+ fps (int): Frames per second of the video.
+
+ mask (np.ndarray): Boolean mask with True over masked elements to overlay
+ as a contour or filled contour. Defaults to None (no mask plotting).
+
+ mask_type (str): 'contour' or 'contourf' dictating whether the mask is overlaid
+ as a contour line or a filled contour.
+
+ data_type (str): 'abs' or 'anom' describing whether the data is in absolute
+ or anomaly format. If anomaly, the colorbar is centred on 0.
+
+ video_dates (list): List of Pandas Timestamps or datetime.datetime objects
+ to plot video from the dataset.
+
+ crop (list): [(a, b), (c, d)] to crop the video from a:b and c:d
+
+ clim (list): Colormap limits. Default is None, in which case the min and max values
+ of the array are used.
+
+ cmap (str): Matplotlib colormap.
+
+ figsize (int or float): Figure size in inches.
+
+ dpi (int): Figure DPI.
+ '''
+
+ if clim is not None:
+ vmin, vmax = clim
+ else:
+ vmin = da.min().values
+ vmax = da.max().values
+
+ if data_type == 'anom':
+ if np.abs(vmax) > np.abs(vmin):
+ vmin = -vmax
+ elif np.abs(vmin) > np.abs(vmax):
+ vmax = -vmin
+
+ def make_frame(date):
+ fig, ax = plt.subplots(figsize=(figsize, figsize))
+ fig.set_dpi(dpi)
+ im = ax.imshow(da.sel(time=date), cmap=cmap, vmin=vmin, vmax=vmax)
+ if mask is not None:
+ if mask_type == 'contour':
+ ax.contour(mask, levels=[.5, 1], colors='k')
+ elif mask_type == 'contourf':
+ ax.contourf(mask, levels=[.5, 1], colors='k')
+ ax.axes.xaxis.set_visible(False)
+ ax.axes.yaxis.set_visible(False)
+
+ ax.set_title('{:04d}/{:02d}/{:02d}'.format(date.year, date.month, date.day), fontsize=figsize * 4)
+
+ divider = make_axes_locatable(ax)
+ cax = divider.append_axes('right', size='5%', pad=0.05)
+ plt.colorbar(im, cax)
+
+ # TEMP crop to image
+ # fig.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0, hspace=0)
+
+ fig.canvas.draw()
+ image = np.frombuffer(fig.canvas.tostring_rgb(), dtype='uint8')
+ image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+
+ plt.close()
+ return image
+
+ if video_dates is None:
+ video_dates = [pd.Timestamp(date).to_pydatetime() for date in da.time.values]
+
+ if crop is not None:
+ a = crop[0][0]
+ b = crop[0][1]
+ c = crop[1][0]
+ d = crop[1][1]
+ da = da.isel(xc=np.arange(a, b), yc=np.arange(c, d))
+ if mask is not None:
+ mask = mask[a:b, c:d]
+
+ imageio.mimsave(video_path,
+ [make_frame(date) for date in tqdm(video_dates)],
+ fps=fps)
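+
+
+# Illustrative usage (assuming `sic_da` is a (time, yc, xc) SIC DataArray
+# and `land_mask` a boolean land mask):
+#
+#     xarray_to_video(sic_da, 'videos/sic_2018.mp4', fps=6, mask=land_mask,
+#                     mask_type='contour', data_type='abs', cmap='Blues_r')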