Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generalize sst_trends #99

Merged
merged 8 commits into from
Oct 28, 2024
9 changes: 9 additions & 0 deletions diagnostics/physics/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,13 @@ levels_step: 2
# Colorbar for sst difference plots
bias_min: -2
bias_max: 2.1
bias_min_trends: -1.5
bias_max_trends: 1.51
bias_step: 0.25

ticks: [-2, -1, 0, 1, 2]


# SST Trends Settings
start_year: "2005"
end_year: "2019"
48 changes: 37 additions & 11 deletions diagnostics/physics/plot_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,13 +176,15 @@ def load_config(config_path: str):
logger.error(f"Error loading configuration from {config_path}: {e}")
raise

def process_oisst(config, target_grid, model_ave):
def process_oisst(config, target_grid, model_ave, start=1993, end = 2020, resamp_freq = None, do_regrid=True
) -> ( (xesmf.Regridder | xarray.DataArray), xarray.DataArray, xarray.DataArray, xarray.DataArray ):
"""Open and regrid OISST dataset, return relevant vars from dataset."""
try:
oisst = (
xarray.open_mfdataset([config['oisst'] + f'sst.month.mean.{y}.nc' for y in range(1993, 2020)])
xarray.open_mfdataset([config['oisst'] + f'sst.month.mean.{y}.nc' for y in range(start, end)])
.sst
.sel(lat=slice(config['lat']['south'], config['lat']['north']), lon=slice(config['lon']['west'], config['lon']['east']))
.load()
)
except Exception as e:
logger.error(f"Error processing OISST data: {e}")
Expand All @@ -193,19 +195,31 @@ def process_oisst(config, target_grid, model_ave):

oisst_lonc, oisst_latc = corners(oisst.lon, oisst.lat)
oisst_lonc -= 360

mom_to_oisst = xesmf.Regridder(
target_grid,
{'lat': oisst.lat, 'lon': oisst.lon, 'lat_b': oisst_latc, 'lon_b': oisst_lonc},
method='conservative_normed',
unmapped_to_nan=True
)

oisst_ave = oisst.mean('time').load()
mom_rg = mom_to_oisst(model_ave)
logger.info("OISST data processed successfully.")
return mom_rg, oisst_ave, oisst_lonc, oisst_latc
# If a resample frequency is provided, use it to resample the oisst data over time before taking the average
if resamp_freq:
oisst = oisst.resample( time = resamp_freq )

oisst_ave = oisst.mean('time')

# Either apply the regridder to the average, or return the regrid object itself
if do_regrid:
mom_rg = mom_to_oisst(model_ave)
logger.info("OISST data processed successfully.")
return mom_rg, oisst_ave, oisst_lonc, oisst_latc

return mom_to_oisst, oisst_ave, oisst_lonc, oisst_latc

def process_glorys(config, target_grid, var):
def process_glorys(
config, target_grid, var, sel_time = None, resamp_freq = None, do_regrid=True
) -> ( (xesmf.Regridder | xarray.DataArray), xarray.DataArray, xarray.DataArray, xarray.DataArray ):
""" Open and regrid glorys data, return regridded glorys data """
glorys = xarray.open_dataset( config['glorys'] ).squeeze(drop=True) #.rename({'longitude': 'lon', 'latitude': 'lat'})
if var in glorys:
Expand All @@ -225,14 +239,26 @@ def process_glorys(config, target_grid, var):
logger.info("Glorys data is using longitude/latitude")
except:
logger.error("Name of longitude and latitude variables is unknown")
raise Exception("Error: Lat/Latitude, Lon/Longitdue not found in glorys data")
raise Exception("Error: Lat/Latitude, Lon/Longitude not found in glorys data")

# If a time slice is provided use it to select a portion of the glorys data
if sel_time:
glorys = glorys.sel( time = sel_time )

# If a resample frequency is provided, use it to resample the glorys data over time before taking the average
if resamp_freq:
glorys = glorys.resample(time = resamp_freq)

glorys_ave = glorys.mean('time').load()
glorys_to_mom = xesmf.Regridder(glorys_ave, target_grid, method='bilinear', unmapped_to_nan=True)
glorys_rg = glorys_to_mom(glorys_ave)

logger.info("Glorys data processed successfully.")
return glorys_rg, glorys_ave, glorys_lonc, glorys_latc
# Either apply the regridder to the average, or return the regrid object itself
if do_regrid:
glorys_rg = glorys_to_mom(glorys_ave)
logger.info("Glorys data processed successfully.")
return glorys_rg, glorys_ave, glorys_lonc, glorys_latc

return glorys_to_mom, glorys_ave, glorys_lonc, glorys_latc

def get_end_of_climatology_period(clima_file):
"""
Expand Down
144 changes: 91 additions & 53 deletions diagnostics/physics/sst_trends.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
Compare the model 2005-019 sea surface temperature trends from OISST and GLORYS.
How to use:
python sst_trends.py /archive/acr/fre/NWA/2023_04/NWA12_COBALT_2023_04_kpo4-coastatten-physics/gfdl.ncrc5-intel22-prod
python sst_trends.py -p /archive/acr/fre/NWA/2023_04/NWA12_COBALT_2023_04_kpo4-coastatten-physics/gfdl.ncrc5-intel22-prod -c config.yaml
"""
import cartopy.crs as ccrs
from cartopy.mpl.geoaxes import GeoAxes
Expand All @@ -10,12 +10,15 @@
from mpl_toolkits.axes_grid1 import AxesGrid
import numpy as np
import xarray
import xesmf
import logging

from plot_common import autoextend_colorbar, corners, get_map_norm, annotate_skill, open_var, save_figure

PC = ccrs.PlateCarree()
from plot_common import( autoextend_colorbar, corners, get_map_norm,
annotate_skill, open_var, save_figure, load_config,
process_glorys, process_oisst)

# Configure logging for sst_eval
logger = logging.getLogger(__name__)
logging.basicConfig(filename="sst_trends.log", format='%(asctime)s %(levelname)s:%(name)s: %(message)s',level=logging.INFO)

def get_3d_trends(x, y):
x = np.array(x)
Expand All @@ -25,101 +28,132 @@ def get_3d_trends(x, y):
return trends


def plot_sst_trends(pp_root, label):
def plot_sst_trends(pp_root, label, config):
model = (
open_var(pp_root, 'ocean_monthly', 'tos')
.sel(time=slice('2005', '2019'))
.sel(time=slice(config['start_year'], config['end_year']))
.resample(time='1AS')
.mean('time')
.load()
)
model_grid = xarray.open_dataset('../data/geography/ocean_static.nc')
logger.info("MODEL: %s",model)
model_grid = xarray.open_dataset( config['model_grid'])
logger.info("MODEL_GRID: %s",model_grid)

# Verify that xh/yh are set as coordinates, then make sure model coordinates match grid data
model_grid = model_grid.assign_coords( {'xh':model_grid.xh, 'yh':model_grid.yh } )
model = xarray.align(model_grid, model, join='override', exclude='time')[1]
logger.info("Successfully modified coordinates of model grid, and aligned model coordinates to grid coordinates")

model_trend = get_3d_trends(model['time.year'], model) * 10 # -> C / decade
# Convert to Data Array, since xskillscore expects dataarrays to calculate skill metrics
model_trend = xarray.DataArray(model_trend, dims=['yh', 'xh'], coords={'yh': model.yh, 'xh': model.xh})
logger.info("MODEL_TREND: %s", model_trend)

oisst = (
xarray.open_mfdataset([f'/work/acr/oisstv2/sst.month.mean.{y}.nc' for y in range(2005, 2020)])
.sst
.sel(lat=slice(0, 60), lon=slice(360-100, 360-30))
.resample(time='1AS')
.mean('time')
.load()
)
oisst_trend = get_3d_trends(oisst['time.year'], oisst) * 10 # -> C / decade
target_grid = model_grid[ config['rename_map'].keys() ].rename( config['rename_map'] )

glorys = (
xarray.open_dataset('/work/acr/mom6/diagnostics/glorys/glorys_sfc.nc')
['thetao']
.sel(time=slice('2005', '2019'))
.resample(time='1AS')
.mean('time')
)
# Process OISST and get trend
mom_to_oisst, oisst, oisst_lonc, oisst_latc = process_oisst(config, target_grid, model, start = int(config['start_year']),
end = int(config['end_year'])+1, resamp_freq = '1AS',
do_regrid = False) # (note that model data is not used if do_regrid = False)
uwagura marked this conversation as resolved.
Show resolved Hide resolved
logger.info("OISST: %s", oisst )
oisst_trend = get_3d_trends(oisst['time.year'], oisst) * 10 # -> C / decade
oisst_trend = xarray.DataArray(oisst_trend, dims=['lat','lon'], coords={'lat':oisst.lat,'lon':oisst.lon} )
logger.info("OISST_TREND: %s",oisst_trend)

mom_rg = mom_to_oisst(model_trend)
mom_rg = xarray.DataArray(mom_rg, dims = ['lat','lon'], coords = {'lat':oisst.lat, 'lon':oisst.lon} )
oisst_delta = mom_rg - oisst_trend
logger.info("MOM_RG: %s",mom_rg)
logger.info("OISST_DELTA: %s",oisst_delta)

# Process Glorys and get trend
glorys_to_mom , glorys, glorys_lonc, glorys_latc = process_glorys(config, target_grid, 'thetao',
sel_time = slice(config['start_year'], config['end_year']),
resamp_freq = '1AS', do_regrid=False)
glorys_trend = get_3d_trends(glorys['time.year'], glorys) * 10 # -> C / decade

oisst_lonc, oisst_latc = corners(oisst.lon, oisst.lat)
oisst_lonc -= 360
oisst_to_mom = xesmf.Regridder({'lat': oisst.lat, 'lon': oisst.lon}, model_grid[['geolon', 'geolat']].rename({'geolon': 'lon', 'geolat': 'lat'}), method='bilinear')

glorys_lonc, glorys_latc = corners(glorys.lon, glorys.lat)
glorys_to_mom = xesmf.Regridder(glorys, model_grid[['geolon', 'geolat']].rename({'geolon': 'lon', 'geolat': 'lat'}), method='bilinear')
logger.info("GLORYS_TREND: %s",glorys_trend)

glorys_rg = glorys_to_mom(glorys_trend)
glorys_rg = xarray.DataArray(glorys_rg, dims=['yh', 'xh'], coords={'yh': model.yh, 'xh': model.xh})
glorys_delta = model_trend - glorys_rg
logger.info("GLORYS_RG: %s",glorys_rg)
logger.info("GLORYS_DELTA: %s",glorys_delta)

oisst_rg = oisst_to_mom(oisst_trend)
oisst_rg = xarray.DataArray(oisst_rg, dims=['yh', 'xh'], coords={'yh': model.yh, 'xh': model.xh})
oisst_delta = model_trend - oisst_rg
# Set projection of each grid in the plot
# For now, sst_eval.py will only support a projection for the arctic and a projection for all other domains
if config['projection_grid'] == 'NorthPolarStereo':
p = ccrs.NorthPolarStereo()
else:
p = ccrs.PlateCarree()

fig = plt.figure(figsize=(10, 14))
grid = AxesGrid(fig, 111,
nrows_ncols=(2, 3),
axes_class = (GeoAxes, dict(projection=PC)),
axes_class = (GeoAxes, dict(projection=p)),
axes_pad=0.3,
cbar_location='bottom',
cbar_mode='edge',
cbar_pad=0.2,
cbar_size='15%',
label_mode=''
label_mode='keep'
)
logger.info("Successfully created grid")

cmap, norm = get_map_norm('cet_CET_D1', np.arange(-2, 2.1, .25), no_offset=True)
cmap, norm = get_map_norm('cet_CET_D1', np.arange(config['bias_min'], config['bias_max'], config['bias_step']), no_offset=True)
common = dict(cmap=cmap, norm=norm)

bias_cmap, bias_norm = get_map_norm('RdBu_r', np.arange(-1.5, 1.51, .25), no_offset=True)
bias_cmap, bias_norm = get_map_norm('RdBu_r', np.arange(config['bias_min_trends'], config['bias_max_trends'], config['bias_step']), no_offset=True)
bias_common = dict(cmap=bias_cmap, norm=bias_norm)

p0 = grid[0].pcolormesh(model_grid.geolon_c, model_grid.geolat_c, model_trend, **common)
# Set projection of input data files so that data is correctly tranformed when plotting
# For now, sst_eval.py will only support a projection for the arctic and a projection for all other domains
if config['projection_data'] == 'NorthPolarStereo':
proj = ccrs.NorthPolarStereo()
else:
proj = ccrs.PlateCarree()

# MODEL
p0 = grid[0].pcolormesh(model_grid.geolon_c, model_grid.geolat_c, model_trend, transform = proj, **common)
grid[0].set_title('(a) Model')
cbar0 = autoextend_colorbar(grid.cbar_axes[0], p0)
cbar0.ax.set_xlabel('SST trend (°C / decade)')
cbar0.set_ticks([-2, -1, 0, 1, 2])
cbar0.set_ticklabels([-2, -1, 0, 1, 2])
cbar0.set_ticks( config['ticks'] )
cbar0.set_ticklabels( config['ticks'] )
logger.info("Successfully plotted model data")

p1 = grid[1].pcolormesh(oisst_lonc, oisst_latc, oisst_trend, **common)
# OISST
p1 = grid[1].pcolormesh(oisst_lonc, oisst_latc, oisst_trend, transform = proj, **common)
grid[1].set_title('(b) OISST')
logger.info("Successfully plotted oisst")

grid[2].pcolormesh(model_grid.geolon_c, model_grid.geolat_c, oisst_delta, **bias_common)
# MODEL - OISST
grid[2].pcolormesh(oisst_lonc, oisst_latc, oisst_delta, transform = proj, **bias_common)
grid[2].set_title('(c) Model - OISST')
annotate_skill(model_trend, oisst_rg, grid[2], weights=model_grid.areacello)
# NOTE: Oisst dims are [lat,lon], so dim argument is needed. Must use mom_rg though, since oisst also contains
# an extra time dimension that changes output of xskillscore functions and leads to error when annotating plot
annotate_skill(mom_rg, oisst_trend, grid[2], dim= list(mom_rg.dims), x0=config['text_x'], y0=config['text_y'], xint=config['text_xint'], plot_lat=config['plot_lat'])
logger.info("Successfully plotted difference between model and oisst")

grid[4].pcolormesh(glorys_lonc, glorys_latc, glorys_trend, **common)
# GLORYS
grid[4].pcolormesh(glorys_lonc, glorys_latc, glorys_trend, transform = proj, **common)
grid[4].set_title('(d) GLORYS12')
cbar1 = autoextend_colorbar(grid.cbar_axes[1], p1)
cbar1.ax.set_xlabel('SST trend (°C / decade)')
cbar1.set_ticks([-2, -1, 0, 1, 2])
cbar1.set_ticklabels([-2, -1, 0, 1, 2])
cbar1.set_ticks( config['ticks'] )
cbar1.set_ticklabels( config['ticks'] )
logger.info("Successfully plotted glorys")

p2 = grid[5].pcolormesh(model_grid.geolon_c, model_grid.geolat_c, glorys_delta, **bias_common)
# MODEL - GLORYS
p2 = grid[5].pcolormesh(model_grid.geolon_c, model_grid.geolat_c, glorys_delta, transform = proj, **bias_common)
grid[5].set_title('(e) Model - GLORYS12')
cbar2 = autoextend_colorbar(grid.cbar_axes[2], p2)
cbar2.ax.set_xlabel('SST trend difference (°C / decade)')
annotate_skill(model_trend, glorys_rg, grid[5], weights=model_grid.areacello)
annotate_skill(model_trend, glorys_rg, grid[5], weights=model_grid.areacello, x0=config['text_x'], y0=config['text_y'], xint=config['text_xint'], plot_lat=config['plot_lat'])
logger.info("Successfully plotted difference between glorys and model")

for i, ax in enumerate(grid):
ax.set_xlim(-99, -35)
ax.set_ylim(4, 59)
ax.set_extent([ config['x']['min'], config['x']['max'], config['y']['min'], config['y']['max'] ], crs=proj)
ax.set_xticks([])
ax.set_yticks([])
ax.set_xticklabels([])
Expand All @@ -128,14 +162,18 @@ def plot_sst_trends(pp_root, label):
ax.set_facecolor('#bbbbbb')
for s in ax.spines.values():
s.set_visible(False)
logger.info("Successfully set extent of each axis")

save_figure('sst_trends', label=label)
logger.info("Successfully saved figure")


if __name__ == '__main__':
from argparse import ArgumentParser
parser = ArgumentParser()
parser.add_argument('pp_root', help='Path to postprocessed data (up to but not including /pp/)')
parser.add_argument('-p','--pp_root', help='Path to postprocessed data (up to but not including /pp/)', required = True)
parser.add_argument('-c','--config', help='Path to yaml config file', required = True)
parser.add_argument('-l', '--label', help='Label to add to figure file names', type=str, default='')
args = parser.parse_args()
plot_sst_trends(args.pp_root, args.label)
config = load_config(args.config)
plot_sst_trends(args.pp_root, args.label, config)