Generalize sst_trends #99

Merged 8 commits on Oct 28, 2024
9 changes: 9 additions & 0 deletions diagnostics/physics/config.yaml
@@ -65,4 +65,13 @@ levels_step: 2
# Colorbar for sst difference plots
bias_min: -2
bias_max: 2.1
bias_min_trends: -1.5
bias_max_trends: 1.51
bias_step: 0.25

ticks: [-2, -1, 0, 1, 2]


# SST Trends Settings
start_year: "2005"
end_year: "2019"
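
For context, a minimal sketch (not part of this diff) of how the new config keys are consumed downstream in sst_trends.py; the relative config path and the final print are illustrative only:

```python
# Sketch: load the YAML config and derive the colorbar levels and trend window
# the way sst_trends.py does further down in this PR.
import numpy as np

from plot_common import load_config

config = load_config('diagnostics/physics/config.yaml')  # path is an assumption

# Colorbar levels for the trend panels and the bias (difference) panels
trend_levels = np.arange(config['bias_min'], config['bias_max'], config['bias_step'])
bias_levels = np.arange(config['bias_min_trends'], config['bias_max_trends'], config['bias_step'])
tick_values = config['ticks']

# Trend window; stored as strings in the YAML, cast where integers are needed
start, end = int(config['start_year']), int(config['end_year'])
print(trend_levels, bias_levels, tick_values, start, end)
```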
43 changes: 36 additions & 7 deletions diagnostics/physics/plot_common.py
@@ -176,13 +176,14 @@ def load_config(config_path: str):
logger.error(f"Error loading configuration from {config_path}: {e}")
raise

def process_oisst(config, target_grid, model_ave):
def process_oisst(config, target_grid, model_ave, start=1993, end = 2020, resamp_freq = None):
"""Open and regrid OISST dataset, return relevant vars from dataset."""
try:
oisst = (
xarray.open_mfdataset([config['oisst'] + f'sst.month.mean.{y}.nc' for y in range(1993, 2020)])
xarray.open_mfdataset([config['oisst'] + f'sst.month.mean.{y}.nc' for y in range(start, end)])
.sst
.sel(lat=slice(config['lat']['south'], config['lat']['north']), lon=slice(config['lon']['west'], config['lon']['east']))
.load()
)
except Exception as e:
logger.error(f"Error processing OISST data: {e}")
@@ -193,20 +194,33 @@ def process_oisst(config, target_grid, model_ave):

oisst_lonc, oisst_latc = corners(oisst.lon, oisst.lat)
oisst_lonc -= 360

mom_to_oisst = xesmf.Regridder(
target_grid,
{'lat': oisst.lat, 'lon': oisst.lon, 'lat_b': oisst_latc, 'lon_b': oisst_lonc},
method='conservative_normed',
unmapped_to_nan=True
)

oisst_ave = oisst.mean('time').load()
# If a resample frequency is provided, use it to resample the oisst data over time before taking the average
if resamp_freq:
oisst = oisst.resample( time = resamp_freq )

oisst_ave = oisst.mean('time')

mom_rg = mom_to_oisst(model_ave)
logger.info("OISST data processed successfully.")
return mom_rg, oisst_ave, oisst_lonc, oisst_latc
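
As a usage note (not part of this file), the new start/end/resamp_freq arguments are exercised in sst_trends.py later in this PR roughly as follows; config, target_grid, and model_trend are assumed to be built as in that script:

```python
# Sketch of the call in sst_trends.py: start/end come from the YAML config as
# strings, and end is exclusive in range(), hence the +1. resamp_freq='1AS'
# resamples OISST to annual means before the trend is fit.
mom_rg, oisst, oisst_lonc, oisst_latc = process_oisst(
    config, target_grid, model_trend,
    start=int(config['start_year']),
    end=int(config['end_year']) + 1,
    resamp_freq='1AS',
)
```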

def process_glorys(config, target_grid, var):
""" Open and regrid glorys data, return regridded glorys data """
def process_glorys(config, target_grid, var, sel_time = None, resamp_freq = None, preprocess_regrid = None):
"""
Open and regrid glorys data, return regridded glorys data.
If a function is passed via the preprocess_regrid option, it is called on the
data before the data is passed to the regridder, but after the regridder
is created and the average has been calculated.
NOTE: if preprocess_regrid returns a numpy array, glorys_ave will be returned as a
numpy array rather than an xarray DataArray, which is the default.
"""
glorys = xarray.open_dataset( config['glorys'] ).squeeze(drop=True) #.rename({'longitude': 'lon', 'latitude': 'lat'})
if var in glorys:
glorys = glorys[var]
Expand All @@ -225,15 +239,30 @@ def process_glorys(config, target_grid, var):
logger.info("Glorys data is using longitude/latitude")
except:
logger.error("Name of longitude and latitude variables is unknown")
raise Exception("Error: Lat/Latitude, Lon/Longitdue not found in glorys data")
raise Exception("Error: Lat/Latitude, Lon/Longitude not found in glorys data")

# If a time slice is provided use it to select a portion of the glorys data
if sel_time:
glorys = glorys.sel( time = sel_time )

# If a resample frequency is provided, use it to resample the glorys data over time before taking the average
if resamp_freq:
glorys = glorys.resample(time = resamp_freq)

glorys_ave = glorys.mean('time').load()

glorys_to_mom = xesmf.Regridder(glorys_ave, target_grid, method='bilinear', unmapped_to_nan=True)
glorys_rg = glorys_to_mom(glorys_ave)

# If a preprocessing function is provided, call it before doing any regridding
# glorys_ave may not remain an xarray DataArray after this step
if preprocess_regrid:
glorys_ave = preprocess_regrid(glorys_ave)

glorys_rg = glorys_to_mom(glorys_ave)
logger.info("Glorys data processed successfully.")
return glorys_rg, glorys_ave, glorys_lonc, glorys_latc
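
Similarly, a sketch (again, not part of this file) of how the new options combine in sst_trends.py; get_3d_trends is the trend-fitting function defined in that script and is assumed to be in scope here:

```python
# Sketch of the call in sst_trends.py: select the configured window, resample to
# annual means, and let get_3d_trends run on the annual-mean series before
# regridding. Because get_3d_trends returns a numpy array, the second return
# value (the trend field) is a numpy array rather than an xarray DataArray.
glorys_rg, glorys_trend, glorys_lonc, glorys_latc = process_glorys(
    config, target_grid, 'thetao',
    sel_time=slice(config['start_year'], config['end_year']),
    resamp_freq='1AS',
    preprocess_regrid=get_3d_trends,
)
```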


def get_end_of_climatology_period(clima_file):
"""
Determine the time period covered by the last climatology file. This function is needed
149 changes: 92 additions & 57 deletions diagnostics/physics/sst_trends.py
@@ -1,7 +1,7 @@
"""
Compare the model 2005-2019 sea surface temperature trends with those from OISST and GLORYS.
How to use:
python sst_trends.py /archive/acr/fre/NWA/2023_04/NWA12_COBALT_2023_04_kpo4-coastatten-physics/gfdl.ncrc5-intel22-prod
python sst_trends.py -p /archive/acr/fre/NWA/2023_04/NWA12_COBALT_2023_04_kpo4-coastatten-physics/gfdl.ncrc5-intel22-prod -c config.yaml
"""
import cartopy.crs as ccrs
from cartopy.mpl.geoaxes import GeoAxes
@@ -10,116 +10,147 @@
from mpl_toolkits.axes_grid1 import AxesGrid
import numpy as np
import xarray
import xesmf
import logging

from plot_common import autoextend_colorbar, corners, get_map_norm, annotate_skill, open_var, save_figure
from plot_common import (autoextend_colorbar, corners, get_map_norm,
annotate_skill, open_var, save_figure, load_config,
process_glorys, process_oisst)

PC = ccrs.PlateCarree()
# Configure logging for sst_trends
logger = logging.getLogger(__name__)
logging.basicConfig(filename="sst_trends.log", format='%(asctime)s %(levelname)s:%(name)s: %(message)s',level=logging.INFO)


def get_3d_trends(x, y):
x = np.array(x)
def get_3d_trends(y):
x = np.array( y['time.year'] )
y2 = np.array(y).reshape((len(x), -1))
coefs = np.polyfit(x, y2, 1)
trends = coefs[0, :].reshape(y.shape[1:])
trends = coefs[0, :].reshape(y.shape[1:]) * 10 # -> C / decade

return trends
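
A toy check (not part of the PR) of what the refactored get_3d_trends computes, assuming the function above is in scope: per-cell slopes of a linear fit against 'time.year', scaled by 10 to give °C per decade:

```python
import numpy as np
import pandas as pd
import xarray

# Synthetic SST warming by 0.02 °C per year on a tiny 4 x 5 grid, 2005-2019
time = pd.date_range('2005-01-01', periods=15, freq='AS')
warming = 15.0 + 0.02 * np.arange(15)
sst = xarray.DataArray(
    np.broadcast_to(warming[:, None, None], (15, 4, 5)).copy(),
    dims=('time', 'yh', 'xh'),
    coords={'time': time},
)

trend = get_3d_trends(sst)        # numpy array with shape (4, 5)
print(np.allclose(trend, 0.2))    # True: 0.02 °C/yr -> 0.2 °C per decade
```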


def plot_sst_trends(pp_root, label):
def plot_sst_trends(pp_root, label, config):
model = (
open_var(pp_root, 'ocean_monthly', 'tos')
.sel(time=slice('2005', '2019'))
.sel(time=slice(config['start_year'], config['end_year']))
.resample(time='1AS')
.mean('time')
.load()
)
model_grid = xarray.open_dataset('../data/geography/ocean_static.nc')
logger.info("MODEL: %s",model)
model_grid = xarray.open_dataset( config['model_grid'])
logger.info("MODEL_GRID: %s",model_grid)

# Verify that xh/yh are set as coordinates, then make sure model coordinates match grid data
model_grid = model_grid.assign_coords( {'xh':model_grid.xh, 'yh':model_grid.yh } )
model = xarray.align(model_grid, model, join='override', exclude='time')[1]
logger.info("Successfully modified coordinates of model grid, and aligned model coordinates to grid coordinates")

model_trend = get_3d_trends(model['time.year'], model) * 10 # -> C / decade
model_trend = get_3d_trends(model)
# Convert to Data Array, since xskillscore expects dataarrays to calculate skill metrics
model_trend = xarray.DataArray(model_trend, dims=['yh', 'xh'], coords={'yh': model.yh, 'xh': model.xh})
logger.info("MODEL_TREND: %s", model_trend)

oisst = (
xarray.open_mfdataset([f'/work/acr/oisstv2/sst.month.mean.{y}.nc' for y in range(2005, 2020)])
.sst
.sel(lat=slice(0, 60), lon=slice(360-100, 360-30))
.resample(time='1AS')
.mean('time')
.load()
)
oisst_trend = get_3d_trends(oisst['time.year'], oisst) * 10 # -> C / decade
target_grid = model_grid[ config['rename_map'].keys() ].rename( config['rename_map'] )

glorys = (
xarray.open_dataset('/work/acr/mom6/diagnostics/glorys/glorys_sfc.nc')
['thetao']
.sel(time=slice('2005', '2019'))
.resample(time='1AS')
.mean('time')
)
glorys_trend = get_3d_trends(glorys['time.year'], glorys) * 10 # -> C / decade
# Process OISST and get trend
mom_rg, oisst, oisst_lonc, oisst_latc = process_oisst(config, target_grid, model_trend, start = int(config['start_year']),
end = int(config['end_year'])+1, resamp_freq = '1AS')
logger.info("OISST: %s", oisst )
oisst_trend = get_3d_trends(oisst)
oisst_trend = xarray.DataArray(oisst_trend, dims=['lat','lon'], coords={'lat':oisst.lat,'lon':oisst.lon} )
logger.info("OISST_TREND: %s",oisst_trend)

oisst_lonc, oisst_latc = corners(oisst.lon, oisst.lat)
oisst_lonc -= 360
oisst_to_mom = xesmf.Regridder({'lat': oisst.lat, 'lon': oisst.lon}, model_grid[['geolon', 'geolat']].rename({'geolon': 'lon', 'geolat': 'lat'}), method='bilinear')
oisst_delta = mom_rg - oisst_trend
logger.info("MOM_RG: %s",mom_rg)
logger.info("OISST_DELTA: %s",oisst_delta)

glorys_lonc, glorys_latc = corners(glorys.lon, glorys.lat)
glorys_to_mom = xesmf.Regridder(glorys, model_grid[['geolon', 'geolat']].rename({'geolon': 'lon', 'geolat': 'lat'}), method='bilinear')
# Process Glorys and get trend
# NOTE: Glorys_ave is glorys_trends because we call get_3d_trends on it.
glorys_rg, glorys_trend, glorys_lonc, glorys_latc = process_glorys(config, target_grid, 'thetao',
sel_time = slice(config['start_year'], config['end_year']),
resamp_freq = '1AS', preprocess_regrid = get_3d_trends)
logger.info("GLORYS_TREND: %s",glorys_trend)

glorys_rg = glorys_to_mom(glorys_trend)
glorys_rg = xarray.DataArray(glorys_rg, dims=['yh', 'xh'], coords={'yh': model.yh, 'xh': model.xh})
glorys_delta = model_trend - glorys_rg
logger.info("GLORYS_RG: %s",glorys_rg)
logger.info("GLORYS_DELTA: %s",glorys_delta)

oisst_rg = oisst_to_mom(oisst_trend)
oisst_rg = xarray.DataArray(oisst_rg, dims=['yh', 'xh'], coords={'yh': model.yh, 'xh': model.xh})
oisst_delta = model_trend - oisst_rg
# Set projection of each grid in the plot
# For now, sst_trends.py will only support a projection for the Arctic and a projection for all other domains
if config['projection_grid'] == 'NorthPolarStereo':
p = ccrs.NorthPolarStereo()
else:
p = ccrs.PlateCarree()

fig = plt.figure(figsize=(10, 14))
grid = AxesGrid(fig, 111,
nrows_ncols=(2, 3),
axes_class = (GeoAxes, dict(projection=PC)),
axes_class = (GeoAxes, dict(projection=p)),
axes_pad=0.3,
cbar_location='bottom',
cbar_mode='edge',
cbar_pad=0.2,
cbar_size='15%',
label_mode=''
label_mode='keep'
)
logger.info("Successfully created grid")

cmap, norm = get_map_norm('cet_CET_D1', np.arange(-2, 2.1, .25), no_offset=True)
cmap, norm = get_map_norm('cet_CET_D1', np.arange(config['bias_min'], config['bias_max'], config['bias_step']), no_offset=True)
common = dict(cmap=cmap, norm=norm)

bias_cmap, bias_norm = get_map_norm('RdBu_r', np.arange(-1.5, 1.51, .25), no_offset=True)
bias_cmap, bias_norm = get_map_norm('RdBu_r', np.arange(config['bias_min_trends'], config['bias_max_trends'], config['bias_step']), no_offset=True)
bias_common = dict(cmap=bias_cmap, norm=bias_norm)

p0 = grid[0].pcolormesh(model_grid.geolon_c, model_grid.geolat_c, model_trend, **common)
# Set projection of input data files so that data is correctly transformed when plotting
# For now, sst_trends.py will only support a projection for the Arctic and a projection for all other domains
if config['projection_data'] == 'NorthPolarStereo':
proj = ccrs.NorthPolarStereo()
else:
proj = ccrs.PlateCarree()

# MODEL
p0 = grid[0].pcolormesh(model_grid.geolon_c, model_grid.geolat_c, model_trend, transform = proj, **common)
grid[0].set_title('(a) Model')
cbar0 = autoextend_colorbar(grid.cbar_axes[0], p0)
cbar0.ax.set_xlabel('SST trend (°C / decade)')
cbar0.set_ticks([-2, -1, 0, 1, 2])
cbar0.set_ticklabels([-2, -1, 0, 1, 2])
cbar0.set_ticks( config['ticks'] )
cbar0.set_ticklabels( config['ticks'] )
logger.info("Successfully plotted model data")

p1 = grid[1].pcolormesh(oisst_lonc, oisst_latc, oisst_trend, **common)
# OISST
p1 = grid[1].pcolormesh(oisst_lonc, oisst_latc, oisst_trend, transform = proj, **common)
grid[1].set_title('(b) OISST')
logger.info("Successfully plotted oisst")

grid[2].pcolormesh(model_grid.geolon_c, model_grid.geolat_c, oisst_delta, **bias_common)
# MODEL - OISST
grid[2].pcolormesh(oisst_lonc, oisst_latc, oisst_delta, transform = proj, **bias_common)
grid[2].set_title('(c) Model - OISST')
annotate_skill(model_trend, oisst_rg, grid[2], weights=model_grid.areacello)
# NOTE: OISST dims are [lat, lon], so the dim argument is needed. mom_rg must be used here, since oisst also contains
# an extra time dimension that changes the output of the xskillscore functions and leads to an error when annotating the plot
annotate_skill(mom_rg, oisst_trend, grid[2], dim= list(mom_rg.dims), x0=config['text_x'], y0=config['text_y'], xint=config['text_xint'], plot_lat=config['plot_lat'])
logger.info("Successfully plotted difference between model and oisst")

grid[4].pcolormesh(glorys_lonc, glorys_latc, glorys_trend, **common)
# GLORYS
grid[4].pcolormesh(glorys_lonc, glorys_latc, glorys_trend, transform = proj, **common)
grid[4].set_title('(d) GLORYS12')
cbar1 = autoextend_colorbar(grid.cbar_axes[1], p1)
cbar1.ax.set_xlabel('SST trend (°C / decade)')
cbar1.set_ticks([-2, -1, 0, 1, 2])
cbar1.set_ticklabels([-2, -1, 0, 1, 2])
cbar1.set_ticks( config['ticks'] )
cbar1.set_ticklabels( config['ticks'] )
logger.info("Successfully plotted glorys")

p2 = grid[5].pcolormesh(model_grid.geolon_c, model_grid.geolat_c, glorys_delta, **bias_common)
# MODEL - GLORYS
p2 = grid[5].pcolormesh(model_grid.geolon_c, model_grid.geolat_c, glorys_delta, transform = proj, **bias_common)
grid[5].set_title('(e) Model - GLORYS12')
cbar2 = autoextend_colorbar(grid.cbar_axes[2], p2)
cbar2.ax.set_xlabel('SST trend difference (°C / decade)')
annotate_skill(model_trend, glorys_rg, grid[5], weights=model_grid.areacello)
annotate_skill(model_trend, glorys_rg, grid[5], weights=model_grid.areacello, x0=config['text_x'], y0=config['text_y'], xint=config['text_xint'], plot_lat=config['plot_lat'])
logger.info("Successfully plotted difference between glorys and model")

for i, ax in enumerate(grid):
ax.set_xlim(-99, -35)
ax.set_ylim(4, 59)
ax.set_extent([ config['x']['min'], config['x']['max'], config['y']['min'], config['y']['max'] ], crs=proj)
ax.set_xticks([])
ax.set_yticks([])
ax.set_xticklabels([])
@@ -128,14 +159,18 @@ def plot_sst_trends(pp_root, label):
ax.set_facecolor('#bbbbbb')
for s in ax.spines.values():
s.set_visible(False)
logger.info("Successfully set extent of each axis")

save_figure('sst_trends', label=label)
logger.info("Successfully saved figure")


if __name__ == '__main__':
from argparse import ArgumentParser
parser = ArgumentParser()
parser.add_argument('pp_root', help='Path to postprocessed data (up to but not including /pp/)')
parser.add_argument('-p','--pp_root', help='Path to postprocessed data (up to but not including /pp/)', required = True)
parser.add_argument('-c','--config', help='Path to yaml config file', required = True)
parser.add_argument('-l', '--label', help='Label to add to figure file names', type=str, default='')
args = parser.parse_args()
plot_sst_trends(args.pp_root, args.label)
config = load_config(args.config)
plot_sst_trends(args.pp_root, args.label, config)