Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split Gulf Steam plot from ssh_eval and rewrite ssh_eval to use config file #101

Merged
merged 9 commits into from
Oct 28, 2024
6 changes: 6 additions & 0 deletions diagnostics/physics/config.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
figures_dir: 'figures/'
glorys: '/work/acr/mom6/diagnostics/glorys/glorys_sfc.nc'
glorys_zos: '/work/acr/glorys/GLOBAL_MULTIYEAR_PHY_001_030/monthly/glorys_monthly_z_fine_*.nc'
model_grid: '../data/geography/ocean_static.nc'

# Variables to rename
Expand Down Expand Up @@ -66,3 +67,8 @@ levels_step: 2
bias_min: -2
bias_max: 2.1
bias_step: 0.25

# Colorbar for ssh plots
ssh_levels_min: -1.1
ssh_levels_max: .8
ssh_levels_step: .1
9 changes: 5 additions & 4 deletions diagnostics/physics/plot_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def get_map_norm(cmap, levels, no_offset=True):
norm = BoundaryNorm(levels, ncolors=nlev, clip=False)
return cmap, norm

def annotate_skill(model, obs, ax, dim=['yh', 'xh'], x0=-98.5, y0=54, yint=4, xint=4, weights=None, cols=1, proj = ccrs.PlateCarree(), plot_lat=False,**kwargs):
def annotate_skill(model, obs, ax, dim=['yh', 'xh'], x0=-98.5, y0=54, yint=4, xint=4, weights=None, cols=1, proj = ccrs.PlateCarree(), plot_lat=False, **kwargs):
"""
Annotate an axis with model vs obs skill metrics
"""
Expand All @@ -65,6 +65,7 @@ def annotate_skill(model, obs, ax, dim=['yh', 'xh'], x0=-98.5, y0=54, yint=4, xi
medae = xskillscore.median_absolute_error(model, obs, dim=dim, skipna=True)

ax.text(x0, y0, f'Bias: {float(bias):2.2f}', transform=proj, **kwargs)

# Set plot_lat=True in order to plot skill along a line of latitude. Otherwise, plot along longitude
if plot_lat:
ax.text(x0-xint, y0, f'RMSE: {float(rmse):2.2f}', transform=proj, **kwargs)
Expand Down Expand Up @@ -113,20 +114,20 @@ def autoextend_colorbar(ax, plot, plot_array=None, **kwargs):
extend = 'neither'
return ax.colorbar(plot, extend=extend, **kwargs)

def add_ticks(ax, xticks=np.arange(-100, -31, 1), yticks=np.arange(2, 61, 1), xlabelinterval=2, ylabelinterval=2, fontsize=10, **kwargs):
def add_ticks(ax, xticks=np.arange(-100, -31, 1), yticks=np.arange(2, 61, 1), xlabelinterval=2, ylabelinterval=2, fontsize=10, projection = ccrs.PlateCarree(), **kwargs):
"""
Add lat and lon ticks and labels to a plot axis.
By default, tick at 1 degree intervals for x and y, and label every other tick.
Additional kwargs are passed to LongitudeFormatter and LatitudeFormatter.
"""
ax.yaxis.tick_right()
ax.set_xticks(xticks, crs=ccrs.PlateCarree())
ax.set_xticks(xticks, crs = projection)
if xlabelinterval == 0:
plt.setp(ax.get_xticklabels(), visible=False)
else:
plt.setp([l for i, l in enumerate(ax.get_xticklabels()) if i % xlabelinterval != 0], visible=False, fontsize=fontsize)
plt.setp([l for i, l in enumerate(ax.get_xticklabels()) if i % xlabelinterval == 0], fontsize=fontsize)
ax.set_yticks(yticks, crs=ccrs.PlateCarree())
ax.set_yticks(yticks, crs = projection)
if ylabelinterval == 0:
plt.setp(ax.get_yticklabels(), visible=False)
else:
Expand Down
153 changes: 153 additions & 0 deletions diagnostics/physics/plot_gulf_stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
"""
Plot of the Gulf Stream position and index,
Uses whatever model data can be found within the directory pp_root,
and does not try to match the model and observed time periods.
How to use:
python plot_gulf_stream.py -p /archive/acr/fre/NWA/2023_04/NWA12_COBALT_2023_04_kpo4-coastatten-physics/gfdl.ncrc5-intel22-prod
"""
import xarray
import xesmf
import pandas as pd
import numpy as np
import cartopy.feature as feature
import cartopy.crs as ccrs
from cartopy.mpl.geoaxes import GeoAxes
import matplotlib.gridspec as gridspec
from matplotlib.lines import Line2D
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import AxesGrid

from plot_common import open_var, add_ticks, save_figure

def compute_gs(ssh, data_grid=None):
lons = np.arange(360-72, 360-51.9, 1)
lats = np.arange(36, 42, 0.1)
target_grid = {'lat': lats, 'lon': lons}

if data_grid is None:
data_grid = {'lat': ssh.lat, 'lon': ssh.lon}

ssh_to_grid = xesmf.Regridder(
data_grid,
target_grid,
method='bilinear'
)

# Interpolate the SSH data onto the index grid.
regridded = ssh_to_grid(ssh)

# Find anomalies relative to the calendar month mean SSH over the full model run.
anom = regridded.groupby('time.month') - regridded.groupby('time.month').mean('time')

# For each longitude point, the Gulf Stream is located at the latitude with the maximum SSH anomaly variance.
stdev = anom.std('time')
amax = stdev.argmax('lat').compute()
gs_points = stdev.lat.isel(lat=amax).compute()

# The index is the mean latitude of the Gulf Stream, divided by the standard deviation of the mean latitude of the Gulf Stream.
index = ((anom.isel(lat=amax).mean('lon')) / anom.isel(lat=amax).mean('lon').std('time')).compute()

# Move times to the beginning of the month to match observations.
monthly_index = index.to_pandas().resample('1MS').first()
return monthly_index, gs_points

def plot_gulf_stream(pp_root, label):

# Load Natural Earth Shapefiles
_LAND_50M = feature.NaturalEarthFeature(
'physical', 'land', '50m',
edgecolor='face',
facecolor='#999999'
)

# Get model grid
model_grid = xarray.open_dataset( '../data/geography/ocean_static.nc' )

# Get model thetao data TODO: maki this comment better
model_thetao = open_var(pp_root, 'ocean_monthly_z', 'thetao')

if '01_l' in model_thetao.coords:
model_thetao = model_thetao.rename({'01_l': 'z_l'})

model_t200 = model_thetao.interp(z_l=200).mean('time')

# Ideally would use SSH, but some diag_tables only saved zos
try:
model_ssh = open_var(pp_root, 'ocean_monthly', 'ssh')
except:
print('Using zos')
model_ssh = open_var(pp_root, 'ocean_monthly', 'zos')

model_ssh_index, model_ssh_points = compute_gs(
model_ssh,
data_grid=model_grid[['geolon', 'geolat']].rename({'geolon': 'lon', 'geolat': 'lat'})
)

# Get Glorys data
glorys_t200 = xarray.open_dataarray('../data/diagnostics/glorys_T200.nc')

# Get satellite points
#satellite_ssh_index, satellite_ssh_points = compute_gs(satellite['adt'])
#satellite_ssh_points.to_netcdf('../data/obs/satellite_ssh_points.nc')
#satellite_ssh_index.to_pickle('../data/obs/satellite_ssh_index.pkl')
#read pre-calculate satellite_ssh_index and points
satellite_ssh_points = xarray.open_dataset('../data/obs/satellite_ssh_points.nc')
satellite_ssh_index = pd.read_pickle('../data/obs/satellite_ssh_index.pkl')
satellite_rolled = satellite_ssh_index.rolling(25, center=True, min_periods=25).mean().dropna()

#satellite = xarray.open_mfdataset([f'/net2/acr/altimetry/SEALEVEL_GLO_PHY_L4_MY_008_047/adt_{y}_{m:02d}.nc' for y in range(1993, 2020) for m in range(1, 13)])
#satellite = satellite.rename({'longitude': 'lon', 'latitude': 'lat'})
#satellite = satellite.resample(time='1MS').mean('time')

# Get rolling averages and correlations
model_rolled = model_ssh_index.rolling(25, center=True, min_periods=25).mean().dropna()
corr = pd.concat((model_ssh_index, satellite_ssh_index), axis=1).corr().iloc[0, 1]
corr_rolled = pd.concat((model_rolled, satellite_rolled), axis=1).corr().iloc[0, 1]

# Plot of Gulf Stream position and index based on SSH,
# plus position based on T200.
fig = plt.figure(figsize=(10, 6), tight_layout=True)
gs = gridspec.GridSpec(2, 2, hspace=.25)

# Set projection of input data files so that data is correctly tranformed when plotting
proj = ccrs.PlateCarree()

ax = fig.add_subplot(gs[0, 0], projection = proj)
ax.add_feature(_LAND_50M)
ax.contour(model_grid.geolon, model_grid.geolat, model_t200, levels=[15], colors='r')
ax.contour(glorys_t200.longitude, glorys_t200.latitude, glorys_t200, levels=[15], colors='k')
add_ticks(ax, xlabelinterval=5)
ax.set_extent([-82, -50, 25, 41])
ax.set_title('(a) Gulf Stream position based on T200')
custom_lines = [Line2D([0], [0], color=c, lw=2) for c in ['r', 'k']]
ax.legend(custom_lines, ['Model', 'GLORYS12'], loc='lower right', frameon=False)

ax = fig.add_subplot(gs[0, 1], projection = proj)
ax.add_feature(_LAND_50M)
ax.plot(model_ssh_points.lon-360, model_ssh_points, c='r')
ax.plot(satellite_ssh_points.lon-360, satellite_ssh_points['__xarray_dataarray_variable__'], c='k')
add_ticks(ax, xlabelinterval=5)
ax.set_extent([-82, -50, 25, 41])
ax.set_title('(b) Gulf Stream position based on SSH variance')
ax.legend(custom_lines, ['Model', 'Altimetry'], loc='lower right', frameon=False)

ax = fig.add_subplot(gs[1, :])
model_ssh_index.plot(ax=ax, c='#ffbbbb', label='Model')
satellite_ssh_index.plot(ax=ax, c='#bbbbbb', label=f'Altimetry (r={corr:2.2f})')
model_rolled.plot(ax=ax, c='r', label='Model rolling mean')
satellite_rolled.plot(ax=ax, c='k', label=f'Altimetry rolling mean (r={corr_rolled:2.2f})')
ax.set_title('(c) Gulf Stream index based on SSH variance')
ax.set_xlabel('')
ax.set_ylim(-3, 3)
ax.set_ylabel('Index (positive north)')
ax.legend(ncol=4, loc='lower right', frameon=False, fontsize=8)

save_figure('gulfstream_eval', label=label, pdf=True)

if __name__ == '__main__':
from argparse import ArgumentParser
parser = ArgumentParser()
parser.add_argument('-p','--pp_root', help='Path to postprocessed data (up to but not including /pp/)', required = True)
parser.add_argument('-l', '--label', help='Label to add to figure file names', type=str, default='')
args = parser.parse_args()
plot_gulf_stream(args.pp_root, args.label)
Loading