
`pyani` graphics code

Bailey Harrington edited this page Jul 26, 2022 · 1 revision

Pyani plotting code overview

The layout of the code used by pyani plot can be difficult to follow because it does not follow the convention used by other features, which each have a subcommand file, scripts/subcommands/subcmd_<feature>.py, and an API file, <feature>.py.

Instead, the relevant code is contained in the subtree:

pyani_graphics
│
├── __pycache__
│   └── ...
├── mpl
│   ├── __pycache__
│   │   └── ...
│   └── __init__.py
├── sns
│   ├── __pycache__
│   │   └── ...
│   └── __init__.py
└── __init__.py

The entry point to this code is the subcmd_plot.py file. At the top of this file, pyani_graphics is imported. Three module-level dictionaries are then defined, each mapping a method name ("mpl" or "seaborn") to the corresponding custom plotting function in mpl/__init__.py or sns/__init__.py:

from pyani import pyani_graphics

# Dictionary of matrix graphics methods
GMETHODS = {"mpl": pyani_graphics.mpl.heatmap, "seaborn": pyani_graphics.sns.heatmap}

# Dictionary of scatter plot graphics methods
SMETHODS = {"mpl": pyani_graphics.mpl.scatter, "seaborn": pyani_graphics.sns.scatter}

# Distribution dictionary of distribution graphics methods
DISTMETHODS = {
    "mpl": pyani_graphics.mpl.distribution, "seaborn": pyani_graphics.sns.distribution,
}

An important point to note here is that mpl and sns in these dictionaries are not the matplotlib and seaborn plotting libraries. They are the pyani_graphics.mpl and pyani_graphics.sns subpackages, that is, the directories shown above. This layer of indirection keeps the code that uses these functions compact, as the write_distribution() function shows:

def write_distribution(
    run_id: int, matdata: MatrixData, outfmts: List[str], args: Namespace
) -> None:
    """Write distribution plots for each matrix type.

    :param run_id:  int, run_id for this run
    :param matdata:  MatrixData object for this distribution plot
    :param outfmts:  list of output formats for files
    :param args:  Namespace for command-line arguments
    """
    logger = logging.getLogger(__name__)

    logger.info("Writing distribution plot for %s matrix", matdata.name)
    for fmt in outfmts:
        outfname = Path(args.outdir) / f"distribution_{matdata.name}_run{run_id}.{fmt}"
        logger.debug("\tWriting graphics to %s", outfname)
        DISTMETHODS[args.method](
            matdata.data,
            outfname,
            matdata.name,
            title=f"matrix_{matdata.name}_run{run_id}",
        )

    # Be tidy with matplotlib caches
    plt.close("all")

The custom distribution plotting functions defined in mpl/__init__.py and sns/__init__.py are accessed by this portion of the function:

DISTMETHODS[args.method](
    matdata.data,
    outfname,
    matdata.name,
    title=f"matrix_{matdata.name}_run{run_id}",
)

DISTMETHODS[args.method] retrieves the function object stored under whichever key args.method holds ("mpl" or "seaborn"). That function is then called with the given arguments. When the key is "seaborn", the function called is pyani_graphics.sns.distribution():

def distribution(dfr, outfilename, matname, title=None):
    """Return seaborn distribution plot for matrix.

    :param dfr:  DataFrame with results matrix
    :param outfilename:  Path to output file for writing
    :param matname:  str, type of matrix being plotted
    :param title:  str, optional title
    """
    fill = "#A6C8E0"
    rug = "#2678B2"
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))
    fig.suptitle(title)
    sns.histplot(
        dfr.values.flatten(),
        ax=axes[0],
        stat="count",
        element="step",
        color=fill,
        edgecolor=fill,
    )
    axes[0].set_ylim(ymin=0)
    sns.kdeplot(dfr.values.flatten(), ax=axes[1])
    sns.rugplot(dfr.values.flatten(), ax=axes[1], color=rug)

    # Modify axes after data is plotted
    for _ in axes:
        if matname == "sim_errors":
            _.set_xlim(0, _.get_xlim()[1])
        elif matname in ["hadamard", "coverage"]:
            _.set_xlim(0, 1.01)
        elif matname == "identity":
            _.set_xlim(0.75, 1.01)

    # Tidy figure
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])

    if outfilename:
        # fig is the Figure created by plt.subplots() above, so it can be
        # saved directly
        fig.savefig(outfilename)

    return fig

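This function is a thin wrapper around seaborn. It draws a histogram of all matrix values with sns.histplot() in the left panel. In the right panel, it draws a kernel density estimate from sns.kdeplot() overlaid with a rug plot from sns.rugplot(). It then sets x-axis limits according to the matrix type. Note that inside sns/__init__.py, the name sns refers to the seaborn library itself, not to the pyani subpackage.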
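To see the dictionary dispatch in isolation, here is a minimal sketch. The two backend functions are hypothetical stand-ins for pyani_graphics.mpl.distribution() and pyani_graphics.sns.distribution(); they return a string instead of drawing anything. The write_distribution() here is also simplified: it takes a plain method name rather than pyani's MatrixData and Namespace arguments. The sketch also shows the constraint the dictionaries impose. Every backend must accept the same arguments, because there is only one call site.

```python
from typing import Callable, Dict, List, Optional


# Hypothetical stand-ins for the real backends; both share one signature.
def mpl_distribution(values: List[float], outfilename: str, matname: str,
                     title: Optional[str] = None) -> str:
    return f"mpl:{matname}:{outfilename}"


def sns_distribution(values: List[float], outfilename: str, matname: str,
                     title: Optional[str] = None) -> str:
    return f"seaborn:{matname}:{outfilename}"


# Same shape as DISTMETHODS in subcmd_plot.py: method name -> function object.
DISTMETHODS: Dict[str, Callable[..., str]] = {
    "mpl": mpl_distribution,
    "seaborn": sns_distribution,
}


def write_distribution(method: str, values: List[float], matname: str,
                       run_id: int, outfmts: List[str]) -> List[str]:
    """Simplified write_distribution(): one call site serves every backend."""
    if method not in DISTMETHODS:
        raise ValueError(f"unknown method {method!r}; choose from {sorted(DISTMETHODS)}")
    plot = DISTMETHODS[method]  # look up the function object itself
    written = []
    for fmt in outfmts:
        outfname = f"distribution_{matname}_run{run_id}.{fmt}"
        written.append(plot(values, outfname, matname,
                            title=f"matrix_{matname}_run{run_id}"))
    return written
```

For example, write_distribution("seaborn", [0.98, 0.99], "identity", 1, ["png"]) returns ["seaborn:identity:distribution_identity_run1.png"]. Supporting a further backend would only need another dictionary entry, not another branch at the call site.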

Multiprocessing

To speed up plotting (each run that is plotted produces 15 plots), multiprocessing is used to distribute the individual plots across a pool of worker processes. This significantly reduces the time needed to generate all of the plots. The cost is that logging and debugging the code that creates those plots is no longer straightforward.

First, in the function write_run_plots() in subcmd_plot.py, a pool of worker processes is created:

pool = multiprocessing.Pool(processes=args.workers)

Then a list, plotting_commands, is created, and tuples of the form:

(<function>, <list of arguments>)

are appended to it, for example:

(write_distribution, [run_id, matdata, outfmts, args])

This tuple form is used to make it easy to unpack the values for each plotting command in a for loop, and pass them to the pool.apply_async() method:

    for func, options in plotting_commands:
        logger.debug("Running %s with options %s", func, options)
        pool.apply_async(func, args=options)
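The loop above submits the tasks but does not show what happens to them afterwards. The sketch below is one way to complete it, assuming the same list of (function, arguments) tuples. run_plotting_commands() is a hypothetical helper, not a pyani function. It keeps the AsyncResult returned by each apply_async() call, closes the pool, waits for the workers with join(), and then calls get() on each result. That last step matters for debugging. An exception raised inside a worker is stored in its AsyncResult and is only re-raised in the parent when get() is called, so without it a plot that fails in a worker can go unreported.

```python
import logging
from typing import Any, Callable, List, Sequence, Tuple

# (function, arguments) pairs, as appended to plotting_commands
PlotCommand = Tuple[Callable[..., Any], Sequence[Any]]


def run_plotting_commands(pool, plotting_commands: List[PlotCommand]) -> List[Any]:
    """Submit every command to pool, wait for all of them, and return results.

    Re-raises the first worker exception found, in submission order.
    """
    logger = logging.getLogger(__name__)
    pending = []
    for func, options in plotting_commands:
        logger.debug("Running %s with options %s",
                     getattr(func, "__name__", func), options)
        pending.append(pool.apply_async(func, args=options))
    pool.close()  # no further tasks will be submitted
    pool.join()   # block until every worker has finished
    # get() re-raises, in this process, any exception raised in a worker.
    return [result.get() for result in pending]
```

A possible call site is with multiprocessing.Pool(processes=args.workers) as pool: run_plotting_commands(pool, plotting_commands). The join() inside the helper matters here, because leaving a Pool's with block calls terminate() rather than waiting for outstanding work.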

Because each call to pool.apply_async() hands its plotting command to one of the pool's separate worker processes, rather than running it in the main pyani process, logging and writing to stdout and stderr become difficult. For each plotting command to write to the log file specified on the pyani command line, process safety would need to be considered, to prevent multiple processes from writing to the file at once.
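One pattern for this is described in the Python logging cookbook; as far as this page records, it is not used in pyani. The idea is to make the parent process the only writer to the log file. Each worker sends its log records to a shared queue through a logging.handlers.QueueHandler, and a logging.handlers.QueueListener in the parent writes them to the file. The sketch below assumes the queue is handed to each worker through the pool's initializer. init_worker_logging() and start_listener() are hypothetical names.

```python
import logging
import logging.handlers


def init_worker_logging(log_queue) -> None:
    """Pool initializer: send every record logged in this worker to log_queue."""
    root = logging.getLogger()
    # Replace (rather than add to) any handlers inherited from the parent,
    # so that workers never write to the log file directly.
    root.handlers[:] = [logging.handlers.QueueHandler(log_queue)]
    root.setLevel(logging.DEBUG)


def start_listener(log_queue, logfile: str) -> logging.handlers.QueueListener:
    """Start a listener thread in the parent; it is the only writer to logfile."""
    handler = logging.FileHandler(logfile)
    handler.setFormatter(
        logging.Formatter("%(processName)s %(levelname)s %(message)s")
    )
    listener = logging.handlers.QueueListener(log_queue, handler)
    listener.start()
    return listener
```

In write_run_plots() this could look like the following steps:

- Create log_queue = multiprocessing.Manager().Queue().
- Start the listener with start_listener(log_queue, <path to the log file>).
- Pass initializer=init_worker_logging, initargs=(log_queue,) to multiprocessing.Pool().
- Call listener.stop() once the workers have finished.

The plotting functions could then keep using logging.getLogger(__name__) as usual.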

It is not impossible to get 'logging' statements out of this black box, but an elegant way to do so in pyani has not yet been found. Possibilities include using the callback parameter of pool.apply_async():

apply_async(func[, args[, kwds[, callback[, error_callback]]]])

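and passing a function such as logging.info() or sys.stdout.write() as the callback. The callback is called in the parent process with the plotting function's return value. So instead of using normal logging or print statements, each plotting function would have to collect the desired information (for example, by appending messages to a list) and return it for the callback to write out.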
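A minimal sketch of the callback approach follows. It assumes a hypothetical wrapper, plot_with_messages(), that runs a plotting function in the worker and returns a list of messages instead of logging them. The callback, log_messages(), receives that list. Callbacks passed to apply_async() run in the parent process, on the pool's result-handling thread, so ordinary logging is safe there. The sketch also passes an error_callback, so that a failed plot is logged rather than lost. All function names here are illustrative, not part of pyani.

```python
import logging
from typing import Any, Callable, List, Sequence, Tuple


def plot_with_messages(func: Callable[..., Any], *options: Any) -> List[str]:
    """Hypothetical worker wrapper: collect messages instead of logging them."""
    messages = [f"Starting {func.__name__}"]
    func(*options)
    messages.append(f"Finished {func.__name__}")
    return messages


def log_messages(messages: List[str]) -> None:
    """callback: runs in the parent process with the wrapper's return value."""
    logger = logging.getLogger(__name__)
    for message in messages:
        logger.info(message)


def log_error(exc: BaseException) -> None:
    """error_callback: runs in the parent process if the task raised."""
    logging.getLogger(__name__).error("Plotting task failed: %r", exc)


def submit_with_logging(
    pool, plotting_commands: Sequence[Tuple[Callable[..., Any], Sequence[Any]]]
) -> None:
    for func, options in plotting_commands:
        pool.apply_async(
            plot_with_messages,
            args=(func, *options),
            callback=log_messages,
            error_callback=log_error,
        )
    pool.close()
    pool.join()  # also waits for the result-handler thread that runs callbacks
```

The drawback noted above is visible here. Messages only reach the log after a plot finishes, and anything the plotting function wants to report has to travel back through its return value.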

tl;dr

While doing active development on the plotting code, it may be worthwhile to disable the asynchronous bits so that logging works as expected.
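One way to do that without restructuring write_run_plots() is to swap the pool for an object with the same apply_async() interface that runs each task immediately in the main process. Logging, print() and debuggers then behave normally. SerialPool below is a hypothetical helper written for this page, not part of pyani or the standard library, and it implements only the methods the plotting loop uses. The standard library's multiprocessing.pool.ThreadPool offers the same interface backed by threads; it keeps everything in one process but is still asynchronous.

```python
from typing import Any, Callable, Dict, Iterable, Optional


class _ImmediateResult:
    """Minimal stand-in for multiprocessing.pool.AsyncResult."""

    def __init__(self, value: Any = None, error: Optional[BaseException] = None):
        self._value = value
        self._error = error

    def ready(self) -> bool:
        return True

    def successful(self) -> bool:
        return self._error is None

    def get(self, timeout: Optional[float] = None) -> Any:
        # Like AsyncResult.get(), re-raise the task's exception here.
        if self._error is not None:
            raise self._error
        return self._value


class SerialPool:
    """Hypothetical drop-in for multiprocessing.Pool that runs tasks at once."""

    def apply_async(self, func: Callable[..., Any], args: Iterable[Any] = (),
                    kwds: Optional[Dict[str, Any]] = None,
                    callback: Optional[Callable[[Any], None]] = None,
                    error_callback: Optional[Callable[[BaseException], None]] = None,
                    ) -> _ImmediateResult:
        try:
            value = func(*args, **(kwds or {}))
        except Exception as exc:  # mirror Pool: hand the error to error_callback
            if error_callback is not None:
                error_callback(exc)
            return _ImmediateResult(error=exc)
        if callback is not None:
            callback(value)
        return _ImmediateResult(value=value)

    def close(self) -> None:
        """Nothing to shut down: every task has already run."""

    def join(self) -> None:
        """Nothing to wait for."""

    def __enter__(self) -> "SerialPool":
        return self

    def __exit__(self, *exc_info: Any) -> None:
        self.close()
```

While debugging, the line pool = multiprocessing.Pool(processes=args.workers) could temporarily become pool = SerialPool(), leaving the rest of the loop unchanged. How that switch is exposed, for example through a debugging option, is left open here and is not an existing pyani feature.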