Skip to content

Commit

Permalink
allow log parameters in pycbc_plot_bank_corner
Browse files Browse the repository at this point in the history
  • Loading branch information
GarethCabournDavies committed Aug 9, 2024
1 parent 84e8c6d commit 5f30ff1
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 8 deletions.
46 changes: 40 additions & 6 deletions bin/plotting/pycbc_plot_bank_corner
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,12 @@ parser.add_argument("--parameters",
"property of that parameter will be used. If not "
"provided, will plot all of the parameters in the "
"bank.")
parser.add_argument(
'--log-parameters',
nargs='+',
help="Which parameters are to be plotted on a log scale? "
"Must be also given in parameters"
)
parser.add_argument('--plot-histogram',
action='store_true',
help="Plot 1D histograms of parameters on the "
Expand Down Expand Up @@ -103,6 +109,11 @@ parser.add_argument("--color-parameter",
help="Color scatter points according to the parameter given. "
"May optionally provide a label in the same way as for "
"--parameters. Default=No scatter point coloring.")
parser.add_argument(
'--log-colormap',
action='store_true',
help="Should the colorbar be plotted on a log scale?"
)
parser.add_argument('--dpi',
type=int,
default=200,
Expand All @@ -117,6 +128,13 @@ parser.add_argument('--title',
add_style_opt_to_parser(parser)
args = parser.parse_args()

for lp in args.log_parameters:
if not lp in args.parameters:
parser.error(
"--log-parameters should be in --parameters. "
f"{lp} not in [{', '.join(args.parameters)}]"
)

pycbc.init_logging(args.verbose)
set_style_from_cli(args)

Expand Down Expand Up @@ -146,7 +164,7 @@ if args.fits_file is not None:
param = fits_f[p][:].astype(float)
# We need to check for the cardinal '-1' value which means
# that the fit is invalid
param[param <= 0] = 0 if 'count' in p and 'log' not in p else np.nan
param[param <= 0] = np.nan
bank[p] = param

logging.info("Got %d templates from the bank", banklen)
Expand Down Expand Up @@ -227,12 +245,21 @@ if cpar:
for p in required_minmax:
minval = np.nanmin(bank_fa[p][bank_fa[p] != -np.inf])
maxval = np.nanmax(bank_fa[p][bank_fa[p] != np.inf])
valrange = maxval - minval
if (p in args.log_parameters) or (p == cpar and args.log_colormap):
# Extend the range by 10% in log-space
logvalrange = np.log(maxval) - np.log(minval)
if p not in mins:
mins[p] = np.exp(np.log(minval) - 0.05 * logvalrange)
if p not in maxs:
maxs[p] = np.exp(np.log(maxval) + 0.05 * logvalrange)
else:
# Extend the range by 10%
valrange = maxval - minval
if p not in mins:
mins[p] = minval - 0.05 * valrange
if p not in maxs:
maxs[p] = maxval + 0.05 * valrange

if p not in mins:
mins[p] = minval - 0.05 * valrange
if p not in maxs:
maxs[p] = maxval + 0.05 * valrange

# Deal with non-coloring case:
zvals = bank_fa[cpar] if cpar else None
Expand All @@ -247,6 +274,7 @@ fig, axis_dict = create_multidim_plot(
plot_scatter=True,
plot_contours=False,
scatter_cmap="viridis",
scatter_log_cmap=args.log_colormap,
marginal_title=False,
marginal_percentiles=[],
fill_color='g',
Expand Down Expand Up @@ -293,6 +321,12 @@ for i in range(len(args.parameters)):
for s0, s1 in zip(sharex_axes[:-1], sharex_axes[1:]):
s0.sharex(s1)

for (p1, p2), ax in axis_dict.items():
if p1 in args.log_parameters:
ax[0].semilogx()
if p2 in args.log_parameters:
ax[0].semilogy()

logging.info("Plot generated")
fig.set_dpi(args.dpi)

Expand Down
13 changes: 11 additions & 2 deletions pycbc/results/scatter_histograms.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
if 'matplotlib.backends' not in sys.modules: # nopep8
matplotlib.use('agg')

from matplotlib import (offsetbox, pyplot, gridspec)
from matplotlib import (offsetbox, pyplot, gridspec, colors)

from pycbc.results import str_utils
from pycbc.io import FieldArray
Expand Down Expand Up @@ -545,6 +545,7 @@ def create_multidim_plot(parameters, samples, labels=None,
marginal_title=True, marginal_linestyle='-',
zvals=None, show_colorbar=True, cbar_label=None,
vmin=None, vmax=None, scatter_cmap='plasma',
scatter_log_cmap=False,
plot_density=False, plot_contours=True,
density_cmap='viridis',
contour_color=None, label_contours=True,
Expand Down Expand Up @@ -614,6 +615,8 @@ def create_multidim_plot(parameters, samples, labels=None,
zvals.
scatter_cmap : {'plasma', string}
The color map to use for the scatter points. Default is 'plasma'.
scatter_log_cmap : boolean
Should the scatter point coloring be on a log scale? Default False
plot_density : {False, bool}
Plot the density of points as a color map.
plot_contours : {True, bool}
Expand Down Expand Up @@ -749,8 +752,14 @@ def create_multidim_plot(parameters, samples, labels=None,
alpha = 0.3
else:
alpha = 1.
if scatter_log_cmap:
cmap_norm = colors.LogNorm(vmin=vmin, vmax=vmax)
else:
cmap_norm = colors.Normalize(vmin=vmin, vmax=vmax)


plt = ax.scatter(x=samples[px], y=samples[py], c=zvals, s=5,
edgecolors='none', vmin=vmin, vmax=vmax,
edgecolors='none', norm=cmap_norm,
cmap=scatter_cmap, alpha=alpha, zorder=2)

if plot_contours or plot_density:
Expand Down

0 comments on commit 5f30ff1

Please sign in to comment.