From 5f30ff1f43ce9edc388678d155e71892131464b7 Mon Sep 17 00:00:00 2001 From: GarethCabournDavies Date: Fri, 9 Aug 2024 02:30:10 -0700 Subject: [PATCH] allow log parameters in pycbc_plot_bank_corner --- bin/plotting/pycbc_plot_bank_corner | 46 +++++++++++++++++++++++++---- pycbc/results/scatter_histograms.py | 13 ++++++-- 2 files changed, 51 insertions(+), 8 deletions(-) diff --git a/bin/plotting/pycbc_plot_bank_corner b/bin/plotting/pycbc_plot_bank_corner index 01fdf9100cf..627de91cccd 100644 --- a/bin/plotting/pycbc_plot_bank_corner +++ b/bin/plotting/pycbc_plot_bank_corner @@ -72,6 +72,12 @@ parser.add_argument("--parameters", "property of that parameter will be used. If not " "provided, will plot all of the parameters in the " "bank.") +parser.add_argument( + '--log-parameters', + nargs='+', + help="Which parameters are to be plotted on a log scale? " + "Must be also given in parameters" +) parser.add_argument('--plot-histogram', action='store_true', help="Plot 1D histograms of parameters on the " @@ -103,6 +109,11 @@ parser.add_argument("--color-parameter", help="Color scatter points according to the parameter given. " "May optionally provide a label in the same way as for " "--parameters. Default=No scatter point coloring.") +parser.add_argument( + '--log-colormap', + action='store_true', + help="Should the colorbar be plotted on a log scale?" +) parser.add_argument('--dpi', type=int, default=200, @@ -117,6 +128,13 @@ parser.add_argument('--title', add_style_opt_to_parser(parser) args = parser.parse_args() +for lp in args.log_parameters: + if not lp in args.parameters: + parser.error( + "--log-parameters should be in --parameters. " + f"{lp} not in [{', '.join(args.parameters)}]" + ) + pycbc.init_logging(args.verbose) set_style_from_cli(args) @@ -146,7 +164,7 @@ if args.fits_file is not None: param = fits_f[p][:].astype(float) # We need to check for the cardinal '-1' value which means # that the fit is invalid - param[param <= 0] = 0 if 'count' in p and 'log' not in p else np.nan + param[param <= 0] = np.nan bank[p] = param logging.info("Got %d templates from the bank", banklen) @@ -227,12 +245,21 @@ if cpar: for p in required_minmax: minval = np.nanmin(bank_fa[p][bank_fa[p] != -np.inf]) maxval = np.nanmax(bank_fa[p][bank_fa[p] != np.inf]) - valrange = maxval - minval + if (p in args.log_parameters) or (p == cpar and args.log_colormap): + # Extend the range by 10% in log-space + logvalrange = np.log(maxval) - np.log(minval) + if p not in mins: + mins[p] = np.exp(np.log(minval) - 0.05 * logvalrange) + if p not in maxs: + maxs[p] = np.exp(np.log(maxval) + 0.05 * logvalrange) + else: + # Extend the range by 10% + valrange = maxval - minval + if p not in mins: + mins[p] = minval - 0.05 * valrange + if p not in maxs: + maxs[p] = maxval + 0.05 * valrange - if p not in mins: - mins[p] = minval - 0.05 * valrange - if p not in maxs: - maxs[p] = maxval + 0.05 * valrange # Deal with non-coloring case: zvals = bank_fa[cpar] if cpar else None @@ -247,6 +274,7 @@ fig, axis_dict = create_multidim_plot( plot_scatter=True, plot_contours=False, scatter_cmap="viridis", + scatter_log_cmap=args.log_colormap, marginal_title=False, marginal_percentiles=[], fill_color='g', @@ -293,6 +321,12 @@ for i in range(len(args.parameters)): for s0, s1 in zip(sharex_axes[:-1], sharex_axes[1:]): s0.sharex(s1) +for (p1, p2), ax in axis_dict.items(): + if p1 in args.log_parameters: + ax[0].semilogx() + if p2 in args.log_parameters: + ax[0].semilogy() + logging.info("Plot generated") fig.set_dpi(args.dpi) diff --git a/pycbc/results/scatter_histograms.py b/pycbc/results/scatter_histograms.py index f89cc5a563f..e4f8e927135 100644 --- a/pycbc/results/scatter_histograms.py +++ b/pycbc/results/scatter_histograms.py @@ -43,7 +43,7 @@ if 'matplotlib.backends' not in sys.modules: # nopep8 matplotlib.use('agg') -from matplotlib import (offsetbox, pyplot, gridspec) +from matplotlib import (offsetbox, pyplot, gridspec, colors) from pycbc.results import str_utils from pycbc.io import FieldArray @@ -545,6 +545,7 @@ def create_multidim_plot(parameters, samples, labels=None, marginal_title=True, marginal_linestyle='-', zvals=None, show_colorbar=True, cbar_label=None, vmin=None, vmax=None, scatter_cmap='plasma', + scatter_log_cmap=False, plot_density=False, plot_contours=True, density_cmap='viridis', contour_color=None, label_contours=True, @@ -614,6 +615,8 @@ def create_multidim_plot(parameters, samples, labels=None, zvals. scatter_cmap : {'plasma', string} The color map to use for the scatter points. Default is 'plasma'. + scatter_log_cmap : boolean + Should the scatter point coloring be on a log scale? Default False plot_density : {False, bool} Plot the density of points as a color map. plot_contours : {True, bool} @@ -749,8 +752,14 @@ def create_multidim_plot(parameters, samples, labels=None, alpha = 0.3 else: alpha = 1. + if scatter_log_cmap: + cmap_norm = colors.LogNorm(vmin=vmin, vmax=vmax) + else: + cmap_norm = colors.Normalize(vmin=vmin, vmax=vmax) + + plt = ax.scatter(x=samples[px], y=samples[py], c=zvals, s=5, - edgecolors='none', vmin=vmin, vmax=vmax, + edgecolors='none', norm=cmap_norm, cmap=scatter_cmap, alpha=alpha, zorder=2) if plot_contours or plot_density: