From 6925c4162b794520def6da2d387e61241046c810 Mon Sep 17 00:00:00 2001 From: Jonas Fuchs <78491186+jonas-fuchs@users.noreply.github.com> Date: Fri, 17 Nov 2023 11:38:37 +0100 Subject: [PATCH] Zoom (#11) * first zoom function heatmap restricted and genome restricted, genes missing * updated the genome plotting parameters * updated readme * added --name option for plot name * added error message if 'region' is missing in gff3 --- README.md | 6 ++++- virheat/__init__.py | 2 +- virheat/command.py | 41 +++++++++++++++++++++++++++----- virheat/scripts/data_prep.py | 14 +++++++++++ virheat/scripts/plotting.py | 45 ++++++++++++++++++++++-------------- 5 files changed, 83 insertions(+), 25 deletions(-) diff --git a/README.md b/README.md index e642424..305008c 100644 --- a/README.md +++ b/README.md @@ -60,6 +60,8 @@ positional arguments: options: -h, --help show this help message and exit + --name virHEAT_plot.pdf + plot name and file type (pdf, png, svg, jpg). Default: virHEAT_plot.pdf -l None, --genome-length None length of the genome (needed if gff3 is not provided) -g None, --gff3-path None @@ -70,7 +72,9 @@ options: --delete, --no-delete delete mutations that are present in all samples and their maximum frequency divergence is smaller than 0.5 (default: True) -n None, --delete-n None - do not show mutations that occur n times or less (default: Do not delete) + do not show mutations that occur n times or less (default: Do not delete) + -z start stop, --zoom start stop + restrict the plot to a specific genomic region. --sort, --no-sort sort sample names alphanumerically (default: False) --min-cov 20 display mutations covered at least x time (only if per base cov tsv files are provided) -v, --version show program's version number and exit diff --git a/virheat/__init__.py b/virheat/__init__.py index d80b915..0be322f 100644 --- a/virheat/__init__.py +++ b/virheat/__init__.py @@ -1,3 +1,3 @@ """plot vcf data as a heatmap mapped to a virus genome""" _program = "virheat" -__version__ = "0.5.4" +__version__ = "0.6" diff --git a/virheat/command.py b/virheat/command.py index d6b1fab..6d2fde8 100644 --- a/virheat/command.py +++ b/virheat/command.py @@ -32,6 +32,13 @@ def get_args(sysargs): nargs=2, help="folder containing input files and output folder" ) + parser.add_argument( + "--name", + type=str, + metavar="virHEAT_plot.pdf", + default="virHEAT_plot.pdf", + help="plot name and file type (pdf, png, svg, jpg). Default: virHEAT_plot.pdf" + ) parser.add_argument( "-l", "--genome-length", @@ -80,6 +87,14 @@ def get_args(sysargs): default=None, help="do not show mutations that occur n times or less (default: Do not delete)" ) + parser.add_argument( + "-z", + "--zoom", + type=int, + metavar=("start", "stop"), + nargs=2, + help="restrict the plot to a specific genomic region." + ) parser.add_argument( "--sort", action=argparse.BooleanOptionalAction, @@ -121,18 +136,21 @@ def main(sysargs=sys.argv[1:]): # extract vcf info reference_name, frequency_lists, unique_mutations, file_names = data_prep.extract_vcf_data(vcf_files, threshold=args.threshold) + if args.zoom: + unique_mutations = data_prep.zoom_to_genomic_regions(unique_mutations, args.zoom) frequency_array = data_prep.create_freq_array(unique_mutations, frequency_lists) + # user specified delete options (removes mutations based on various rationales) if args.delete: frequency_array = data_prep.delete_common_mutations(frequency_array, unique_mutations) if args.delete_n is not None: frequency_array = data_prep.delete_n_mutations(frequency_array, unique_mutations, args.delete_n) + # annotate low coverage if per base coveage from qualimap was provided data_prep.annotate_non_covered_regions(args.input[0], args.min_cov, frequency_array, file_names, unique_mutations) # define relative locations of all items in the plot - n_samples = len(frequency_array) - n_mutations = len(frequency_array[0]) + n_samples, n_mutations = len(frequency_array), len(frequency_array[0]) if n_mutations == 0: sys.exit("\033[31m\033[1mERROR:\033[0m Frequency array seems to be empty. There is nothing to plot.") if n_samples < 4: @@ -166,20 +184,31 @@ def main(sysargs=sys.argv[1:]): # ini the fig fig, ax = plt.subplots(figsize=[y_size, x_size]) + # define boundaries for the plot + if args.zoom: + start, stop = args.zoom[0], args.zoom[1] + # rescue plot if invalid zoom values are given + if args.zoom[0] < 0: + start = 0 + if args.zoom[1] > genome_end: + stop = genome_end + else: + start, stop = 0, genome_end + # plot all elements cmap = cm.gist_heat_r cmap.set_bad('silver', 1.) cmap_cells = cm.ScalarMappable(norm=colors.Normalize(0, 1), cmap=cmap) plotting.create_heatmap(ax, frequency_array, cmap_cells) - mutation_set = plotting.create_genome_vis(ax, genome_y_location, n_mutations, unique_mutations, genome_end) + mutation_set = plotting.create_genome_vis(ax, genome_y_location, n_mutations, unique_mutations, start, stop) if args.gff3_path is not None: if genes_with_mutations: # distinct colors for the genes cmap_genes = plt.get_cmap('tab20', len(genes_with_mutations)) colors_genes = [cmap_genes(i) for i in range(len(genes_with_mutations))] # plot gene track - plotting.create_gene_vis(ax, genes_with_mutations, n_mutations, y_size, n_tracks, genome_end, min_y_location, genome_y_location, colors_genes) - plotting.create_axis(ax, n_mutations, min_y_location, n_samples, file_names, genome_end, genome_y_location, unique_mutations, reference_name) + plotting.create_gene_vis(ax, genes_with_mutations, n_mutations, y_size, n_tracks, start, stop, min_y_location, genome_y_location, colors_genes) + plotting.create_axis(ax, n_mutations, min_y_location, n_samples, file_names, start, stop, genome_y_location, unique_mutations, reference_name) plotting.create_colorbar(args.threshold, cmap_cells, min_y_location, n_samples, ax) plotting.create_mutation_legend(mutation_set, min_y_location, n_samples) @@ -188,5 +217,5 @@ def main(sysargs=sys.argv[1:]): os.makedirs(args.input[1]) # save fig - fig.savefig(os.path.join(args.input[1], "virHEAT_plot.pdf"), bbox_inches="tight") + fig.savefig(os.path.join(args.input[1], args.name), bbox_inches="tight") diff --git a/virheat/scripts/data_prep.py b/virheat/scripts/data_prep.py index e8e338e..45ca30e 100644 --- a/virheat/scripts/data_prep.py +++ b/virheat/scripts/data_prep.py @@ -216,6 +216,18 @@ def delete_n_mutations(frequency_array, unique_mutations, min_mut): return np.delete(frequency_array, mut_to_del, axis=1) +def zoom_to_genomic_regions(unique_mutations, start_stop): + """ + restrict the displayed mutations to a user defined genomic range + """ + zoomed_unique = [] + + for mutation in unique_mutations: + if start_stop[0] <= int(mutation.split("_")[0]) <= start_stop[1]: + zoomed_unique.append(mutation) + + return zoomed_unique + def parse_gff3(file): """ parse gff3 to dictionary @@ -259,6 +271,8 @@ def get_genome_end(gff3_dict): genome_end = 0 + if "region" not in gff3_dict: + sys.exit("\033[31m\033[1mERROR:\033[0m Region annotation is missing in the gff3!") for attribute in gff3_dict["region"].keys(): stop = gff3_dict["region"][attribute]["stop"] if stop > genome_end: diff --git a/virheat/scripts/plotting.py b/virheat/scripts/plotting.py index 7819eda..415a018 100644 --- a/virheat/scripts/plotting.py +++ b/virheat/scripts/plotting.py @@ -25,7 +25,7 @@ def create_heatmap(ax, frequency_array, cmap): y_start += 1 -def create_genome_vis(ax, genome_y_location, n_mutations, unique_mutations, genome_end): +def create_genome_vis(ax, genome_y_location, n_mutations, unique_mutations, start, stop): """ create the genome rectangle, mutations and mappings to the heatmap """ @@ -51,11 +51,12 @@ def create_genome_vis(ax, genome_y_location, n_mutations, unique_mutations, geno # create mutation lines on the genome rectangle and the mapping to the respective cells x_start = 0 + length = stop - start for mutation in unique_mutations: mutation_attributes = mutation.split("_") mutation_color = mutation_type_colors[mutation_attributes[3]] mutation_set.add(mutation_attributes[3]) - mutation_x_location = n_mutations/genome_end*int(mutation_attributes[0]) + mutation_x_location = n_mutations/length*(int(mutation_attributes[0])-start) # create mutation lines plt.vlines(x=mutation_x_location, ymin=y_min, ymax=y_max, color=mutation_color) # create polygon @@ -111,7 +112,7 @@ def create_mutation_legend(mutation_set, min_y_location, n_samples): plt.legend(bbox_to_anchor=(1.01, 0.95-(n_samples/(min_y_location+n_samples))), handles=legend_patches) -def create_axis(ax, n_mutations, min_y_location, n_samples, file_names, genome_end, genome_y_location, unique_mutations, reference_name): +def create_axis(ax, n_mutations, min_y_location, n_samples, file_names, start, stop, genome_y_location, unique_mutations, reference_name): """ create the axis of the plot """ @@ -120,17 +121,18 @@ def create_axis(ax, n_mutations, min_y_location, n_samples, file_names, genome_e ax.set_xlim(0, n_mutations) ax.set_ylim(-min_y_location, n_samples) # define new ticks depending on the genome size + axis_length = stop - start if n_mutations >= 20: - xtick_dis = round(genome_end/6, -int(math.log10(genome_end / 6)) + 1) - xtick_labels = [0, xtick_dis, xtick_dis*2, xtick_dis*3, xtick_dis*4, xtick_dis*5, genome_end] + xtick_dis = round(axis_length/6, -int(math.log10(axis_length / 6)) + 1) + xtick_labels = [start, start + xtick_dis, start + xtick_dis*2, start + xtick_dis*3, start + xtick_dis*4, start + xtick_dis*5, stop] elif n_mutations >= 10: - xtick_dis = round(genome_end / 3, -int(math.log10(genome_end / 3)) + 1) - xtick_labels = [0, xtick_dis, xtick_dis * 2, genome_end] + xtick_dis = round(axis_length / 3, -int(math.log10(axis_length / 3)) + 1) + xtick_labels = [start, start + xtick_dis, start + xtick_dis * 2, stop] else: - xtick_labels = [0, genome_end] + xtick_labels = [start, stop] xtick_labels = [int(tick) for tick in xtick_labels] # get the correct location of the genome pos on the axis - xticks = [n_mutations/genome_end*tick for tick in xtick_labels] + xticks = [n_mutations/axis_length*(tick - start) for tick in xtick_labels] # set new ticks and change spines/yaxis ax.set_xticks(xticks, xtick_labels) # set y axis labels @@ -154,27 +156,36 @@ def create_axis(ax, n_mutations, min_y_location, n_samples, file_names, genome_e ax.spines["left"].set_visible(False) -def create_gene_vis(ax, genes_with_mutations, n_mutations, y_size, n_tracks, genome_end, min_y_location, genome_y_location, colors_genes): +def create_gene_vis(ax, genes_with_mutations, n_mutations, y_size, n_tracks, start, stop, min_y_location, genome_y_location, colors_genes): """ create the vis for the gene """ gene_annotations = [] - mult_factor = n_mutations/genome_end + mult_factor = n_mutations/(stop-start) for idx, gene in enumerate(genes_with_mutations): - start = (mult_factor*genes_with_mutations[gene][0][0], -min_y_location+(n_tracks-genes_with_mutations[gene][1])*genome_y_location/2-genome_y_location/2) - stop = mult_factor*genes_with_mutations[gene][0][1] + # define the plotting values for the patch + if genes_with_mutations[gene][0][0] < start: + x_value = 0 + else: + x_value = mult_factor*(genes_with_mutations[gene][0][0]-start) + y_value = -min_y_location+(n_tracks-genes_with_mutations[gene][1])*genome_y_location/2-genome_y_location/2 + if genes_with_mutations[gene][0][1] > stop: + width = n_mutations - x_value + else: + width = mult_factor*(genes_with_mutations[gene][0][1]-start) - x_value height = genome_y_location/2 + # plot the patch ax.add_patch( patches.FancyBboxPatch( - start, stop-start[0], height, + (x_value, y_value), width, height, boxstyle="round,pad=-0.0040,rounding_size=0.03", ec="black", fc=colors_genes[idx] ) ) # define text pos for gene description inside or below the gene box, depending if it fits within - if stop-start[0] > n_mutations/(y_size*8)*len(gene): - gene_annotations.append(ax.text(start[0]+(stop-start[0])/2, start[1]+height/2, gene, ha="center", va="center")) + if width > n_mutations/(y_size*8)*len(gene): + gene_annotations.append(ax.text(x_value+width/2, y_value+height/2, gene, ha="center", va="center")) else: - gene_annotations.append(ax.text(start[0]+(stop-start[0])/2, start[1]-height/4, gene, rotation=40, rotation_mode="anchor", ha="right", va="bottom")) + gene_annotations.append(ax.text(x_value+width/2, y_value-height/4, gene, rotation=40, rotation_mode="anchor", ha="right", va="bottom"))