diff --git a/mc3/plots/posterior.py b/mc3/plots/posterior.py index ff940a7..1e3580a 100644 --- a/mc3/plots/posterior.py +++ b/mc3/plots/posterior.py @@ -193,6 +193,7 @@ def _pairwise( hist, hist_xran, axes, ranges, estimates, palette, nlevels, absolute_dens, lmax, linewidth, theme, alpha=0.8, clear=True, + swap_axes=False, ): """ Lowest-lever routine to plot pair-wise posterior distributions. @@ -206,13 +207,19 @@ def _pairwise( for icol in range(npars-1): for irow in range(icol, npars-1): ax = axes[irow,icol] + if swap_axes: + row, col = icol, irow+1 + histo = hist[col-1,row].T + else: + row, col = irow+1, icol + histo = hist[row-1,col] if clear: ax.clear() extent = ( - hist_xran[icol,0], - hist_xran[icol,-1], - hist_xran[irow+1,0], - hist_xran[irow+1,-1], + hist_xran[col, 0], + hist_xran[col, -1], + hist_xran[row, 0], + hist_xran[row, -1], ) levels = np.zeros(nlevels+1) levels[1:] = np.linspace(1.0, lmax[irow,icol], nlevels) @@ -220,29 +227,30 @@ def _pairwise( colors[0,3] = 0.0 colors[1,3] = 0.75*alpha cont = ax.contourf( - hist[irow,icol], + histo, colors=colors, levels=levels, origin='lower', extent=extent, ) - edge_color = to_rgba(theme.color, alpha=0.65) - for c in cont.collections: - c.set_edgecolor(edge_color) - c.set_linewidth(0.1) - cont.collections[0].set_edgecolor((1,1,1,0)) - if estimates[icol] is not None: + edge_color = to_rgba(theme.color, alpha=0.25) + edge_colors = [edge_color for _ in cont.get_paths()] + edge_colors[0] = (1,1,1,0) + cont.set_edgecolor(edge_colors) + cont.set_linewidth(0.1) + + if estimates[col] is not None: ax.axvline( - estimates[icol], + estimates[col], dashes=(9,2), lw=linewidth, color=theme.dark_color, ) - if estimates[irow+1] is not None: + if estimates[row] is not None: ax.axhline( - estimates[irow+1], + estimates[row], dashes=(9,2), lw=linewidth, color=theme.dark_color, ) - if ranges[icol] is not None: - ax.set_xlim(ranges[icol]) - if ranges[irow] is not None: - ax.set_ylim(ranges[irow+1]) + if ranges[col] is not None: + ax.set_xlim(ranges[col]) + if ranges[row] is not None: + ax.set_ylim(ranges[row]) def _plot_marginal(obj): @@ -302,7 +310,11 @@ def _plot_pairwise(obj): for icol in range(npars-1): for irow in range(icol, npars-1): ax = obj.pair_axes[irow,icol] - h = nx*irow + icol + 1 + npars*int(obj.plot_marginal) + if obj.orientation == 'vertical': + h = nx*irow + icol + 1 + npars*int(obj.plot_marginal) + else: + h = 2 + (nx+int(obj.plot_marginal))*icol + irow - icol + ax_position = subplot( obj.rect, obj.margin, h, nx, nx, obj.ymargin, dry=True, ) @@ -495,7 +507,7 @@ def __set__(self, obj, value): setattr(obj.source, var_name, value) -class Marginal(object): +class Marginal(): """A mostly-interactive marginal posterior plotting object.""" # Soft-update properties: pnames = SoftUpdate() @@ -625,6 +637,48 @@ def plot( if savefile is not None: self.fig.savefig(savefile, dpi=300) + def overplot(self, post, labels=None, nlevels=4, alpha=0.4): + """ + Overplot additional posteriors in the same figure. + + This method is still work in progress! + Note that a call to self.update() or even soft updates + will remove all/some of the overplot data. In such case + the user would need to make a new call to self.overplot(). + It is also recommended to set show_estimates=False to + prevent over-crowding the figures. + + Parameters + ---------- + posts: 1D iterable of Posterior objects + Currently there are no checks that these new posteriors + have the same parameters (nor same statistics) as self. + The user needs to make sure they are all compartible. + labels: 1D iterable of strings + Labels for each posterior. Note that if provided, the + length of labels has to be one more than posts, because + it also contains the label for self. + """ + if '_like' in post.statistics: + hpd_min = post.hpd_min + else: + hpd_min = None + + estimates = post.estimates + if not self.show_estimates: + estimates = [None for _ in estimates] + + _histogram( + post.posterior, estimates, post.ranges, + self.hist_axes, self.bins, + post.pdf, post.xpdf, + hpd_min, post.low_bounds, post.high_bounds, + self.linewidth, post.theme, + post.orientation, + top_pad=1.25, + clear=False, + ) + class Figure(Marginal): """A mostly-interactive pair-wise posterior plotting object.""" @@ -756,6 +810,7 @@ def plot( self.lmax, self.linewidth, self.theme, + swap_axes=self.orientation=='upper right', ) # The colorbar: