From 245ca437429a2890feac416f012cd2ec4612b30c Mon Sep 17 00:00:00 2001 From: jonas-fuchs Date: Thu, 16 Nov 2023 18:16:09 +0100 Subject: [PATCH] first zoom function heatmap restricted and genome restricted, genes missing --- virheat/__init__.py | 2 +- virheat/command.py | 26 +++++++++++++++++++++++--- virheat/scripts/data_prep.py | 12 ++++++++++++ virheat/scripts/plotting.py | 20 +++++++++++--------- 4 files changed, 47 insertions(+), 13 deletions(-) diff --git a/virheat/__init__.py b/virheat/__init__.py index d80b915..0be322f 100644 --- a/virheat/__init__.py +++ b/virheat/__init__.py @@ -1,3 +1,3 @@ """plot vcf data as a heatmap mapped to a virus genome""" _program = "virheat" -__version__ = "0.5.4" +__version__ = "0.6" diff --git a/virheat/command.py b/virheat/command.py index d6b1fab..8d286f8 100644 --- a/virheat/command.py +++ b/virheat/command.py @@ -80,6 +80,14 @@ def get_args(sysargs): default=None, help="do not show mutations that occur n times or less (default: Do not delete)" ) + parser.add_argument( + "-z", + "--zoom", + type=int, + metavar=("start", "stop"), + nargs=2, + help="restrict the plot to a specific genomic region." + ) parser.add_argument( "--sort", action=argparse.BooleanOptionalAction, @@ -118,9 +126,10 @@ def main(sysargs=sys.argv[1:]): vcf_files = data_prep.get_files(args.input[0], "vcf") if args.sort: vcf_files = sorted(vcf_files, key=lambda x: data_prep.get_digit_and_alpha(os.path.basename(x))) - # extract vcf info reference_name, frequency_lists, unique_mutations, file_names = data_prep.extract_vcf_data(vcf_files, threshold=args.threshold) + if args.zoom: + unique_mutations = data_prep.zoom_to_genomic_regions(unique_mutations, args.zoom) frequency_array = data_prep.create_freq_array(unique_mutations, frequency_lists) # user specified delete options (removes mutations based on various rationales) if args.delete: @@ -166,12 +175,23 @@ def main(sysargs=sys.argv[1:]): # ini the fig fig, ax = plt.subplots(figsize=[y_size, x_size]) + # define boundaries for the plot + if args.zoom: + start, stop = args.zoom[0], args.zoom[1] + # rescue plot if invalid zoom values are given + if args.zoom[0] < 0: + start = 0 + if args.zoom[1] > genome_end: + stop = genome_end + else: + start, stop = 0, genome_end + # plot all elements cmap = cm.gist_heat_r cmap.set_bad('silver', 1.) cmap_cells = cm.ScalarMappable(norm=colors.Normalize(0, 1), cmap=cmap) plotting.create_heatmap(ax, frequency_array, cmap_cells) - mutation_set = plotting.create_genome_vis(ax, genome_y_location, n_mutations, unique_mutations, genome_end) + mutation_set = plotting.create_genome_vis(ax, genome_y_location, n_mutations, unique_mutations, start, stop) if args.gff3_path is not None: if genes_with_mutations: # distinct colors for the genes @@ -179,7 +199,7 @@ def main(sysargs=sys.argv[1:]): colors_genes = [cmap_genes(i) for i in range(len(genes_with_mutations))] # plot gene track plotting.create_gene_vis(ax, genes_with_mutations, n_mutations, y_size, n_tracks, genome_end, min_y_location, genome_y_location, colors_genes) - plotting.create_axis(ax, n_mutations, min_y_location, n_samples, file_names, genome_end, genome_y_location, unique_mutations, reference_name) + plotting.create_axis(ax, n_mutations, min_y_location, n_samples, file_names, start, stop, genome_y_location, unique_mutations, reference_name) plotting.create_colorbar(args.threshold, cmap_cells, min_y_location, n_samples, ax) plotting.create_mutation_legend(mutation_set, min_y_location, n_samples) diff --git a/virheat/scripts/data_prep.py b/virheat/scripts/data_prep.py index e8e338e..edea269 100644 --- a/virheat/scripts/data_prep.py +++ b/virheat/scripts/data_prep.py @@ -216,6 +216,18 @@ def delete_n_mutations(frequency_array, unique_mutations, min_mut): return np.delete(frequency_array, mut_to_del, axis=1) +def zoom_to_genomic_regions(unique_mutations, start_stop): + """ + restrict the displayed mutations to a user defined genomic range + """ + zoomed_unique = [] + + for mutation in unique_mutations: + if start_stop[0] <= int(mutation.split("_")[0]) <= start_stop[1]: + zoomed_unique.append(mutation) + + return zoomed_unique + def parse_gff3(file): """ parse gff3 to dictionary diff --git a/virheat/scripts/plotting.py b/virheat/scripts/plotting.py index 7819eda..56d1967 100644 --- a/virheat/scripts/plotting.py +++ b/virheat/scripts/plotting.py @@ -25,7 +25,7 @@ def create_heatmap(ax, frequency_array, cmap): y_start += 1 -def create_genome_vis(ax, genome_y_location, n_mutations, unique_mutations, genome_end): +def create_genome_vis(ax, genome_y_location, n_mutations, unique_mutations, genome_start, genome_stop): """ create the genome rectangle, mutations and mappings to the heatmap """ @@ -51,11 +51,12 @@ def create_genome_vis(ax, genome_y_location, n_mutations, unique_mutations, geno # create mutation lines on the genome rectangle and the mapping to the respective cells x_start = 0 + length = genome_stop - genome_start for mutation in unique_mutations: mutation_attributes = mutation.split("_") mutation_color = mutation_type_colors[mutation_attributes[3]] mutation_set.add(mutation_attributes[3]) - mutation_x_location = n_mutations/genome_end*int(mutation_attributes[0]) + mutation_x_location = n_mutations/length*(int(mutation_attributes[0])-genome_start) # create mutation lines plt.vlines(x=mutation_x_location, ymin=y_min, ymax=y_max, color=mutation_color) # create polygon @@ -111,7 +112,7 @@ def create_mutation_legend(mutation_set, min_y_location, n_samples): plt.legend(bbox_to_anchor=(1.01, 0.95-(n_samples/(min_y_location+n_samples))), handles=legend_patches) -def create_axis(ax, n_mutations, min_y_location, n_samples, file_names, genome_end, genome_y_location, unique_mutations, reference_name): +def create_axis(ax, n_mutations, min_y_location, n_samples, file_names, start, stop, genome_y_location, unique_mutations, reference_name): """ create the axis of the plot """ @@ -120,17 +121,18 @@ def create_axis(ax, n_mutations, min_y_location, n_samples, file_names, genome_e ax.set_xlim(0, n_mutations) ax.set_ylim(-min_y_location, n_samples) # define new ticks depending on the genome size + axis_length = stop - start if n_mutations >= 20: - xtick_dis = round(genome_end/6, -int(math.log10(genome_end / 6)) + 1) - xtick_labels = [0, xtick_dis, xtick_dis*2, xtick_dis*3, xtick_dis*4, xtick_dis*5, genome_end] + xtick_dis = round(axis_length/6, -int(math.log10(axis_length / 6)) + 1) + xtick_labels = [start, start + xtick_dis, start + xtick_dis*2, start + xtick_dis*3, start + xtick_dis*4, start + xtick_dis*5, stop] elif n_mutations >= 10: - xtick_dis = round(genome_end / 3, -int(math.log10(genome_end / 3)) + 1) - xtick_labels = [0, xtick_dis, xtick_dis * 2, genome_end] + xtick_dis = round(axis_length / 3, -int(math.log10(axis_length / 3)) + 1) + xtick_labels = [start, start + xtick_dis, start + xtick_dis * 2, stop] else: - xtick_labels = [0, genome_end] + xtick_labels = [start, stop] xtick_labels = [int(tick) for tick in xtick_labels] # get the correct location of the genome pos on the axis - xticks = [n_mutations/genome_end*tick for tick in xtick_labels] + xticks = [n_mutations/axis_length*(tick - start) for tick in xtick_labels] # set new ticks and change spines/yaxis ax.set_xticks(xticks, xtick_labels) # set y axis labels