diff --git a/virheat/command.py b/virheat/command.py index 96abda1..dd1fcf4 100644 --- a/virheat/command.py +++ b/virheat/command.py @@ -144,7 +144,7 @@ def main(sysargs=sys.argv[1:]): # extract vcf info n_scores = 0 if args.scores: - reference_name, frequency_lists, unique_mutations, file_names = data_prep.extract_vcf_data(vcf_files, threshold=args.threshold, scores=1) + reference_name, frequency_lists, unique_mutations, file_names = data_prep.extract_vcf_data(vcf_files, threshold=args.threshold, scores=True) n_scores = len(args.scores) else: reference_name, frequency_lists, unique_mutations, file_names = data_prep.extract_vcf_data(vcf_files, threshold=args.threshold) @@ -188,14 +188,14 @@ def main(sysargs=sys.argv[1:]): # define min y coordinate if n_tracks != 0 or n_scores != 0: - min_y_location = genome_y_location + genome_y_location / 2 * (n_tracks + n_scores + 1) + min_y_location = genome_y_location + genome_y_location / 2 * (n_tracks + n_scores + 2) else: min_y_location = genome_y_location # define size of the plot y_size = n_mutations*0.4 if args.scores: - x_size = y_size * (n_samples + min_y_location + len(args.scores)) / n_mutations + x_size = y_size * (n_samples + min_y_location + n_scores) / n_mutations else: x_size = y_size*(n_samples + min_y_location)/n_mutations x_size = x_size-x_size*0.15 # compensate of heatmap annotation @@ -223,37 +223,16 @@ def main(sysargs=sys.argv[1:]): plotting.create_axis(ax, n_mutations, min_y_location, n_samples, file_names, start, stop, genome_y_location, unique_mutations, reference_name, n_scores) plotting.create_mutation_legend(mutation_set, min_y_location, n_samples, n_scores) - plotting.create_colorbar(args.threshold, cmap_cells, min_y_location, n_samples, ax, n_scores) + plotting.create_colorbar(args.threshold, cmap_cells, min_y_location, n_samples, ax) # plot scores as track below the genome track if args.scores: score_count = 1 for score_params in args.scores: - score_count += 1 scores_file, pos_col, score_col, score_name = score_params unique_scores = data_prep.extract_scores(unique_mutations, scores_file, pos_col, score_col) - plotting.create_scores_vis(ax, genome_y_location, n_mutations, n_tracks, unique_scores, start, stop, score_name=score_name, score_count=score_count) - - - - #cmap_scores = plt.cm.get_cmap('coolwarm') # blue to red colormap - #if n_scoresets < 3: - # plotting.create_scores_cbar(cmap_scores, min_y_location, n_scoresets, score_set, score_name, score_count, ax) - - # plotting colorbars for scorsets if more than 3 scoresets - move colorbars under - # if n_scoresets > 2: - # ax2 = fig.add_axes(ax.get_position(), frameon=False) - # ax2.set_position([ax.get_position().x0, ax.get_position().y0 - 1.1 * ax.get_position().height*0.5, - # ax.get_position().width, 0.4]) - # ax2.set_xticks([]) - # ax2.set_yticks([]) - # score_count = 0 - # for score_params in args.scores: - # score_count += 1 - # scores_file, pos_col, score_col, score_name = score_params - # unique_scores = data_prep.extract_scores(unique_mutations, scores_file, pos_col, score_col) - # score_set = plotting.create_scores_vis(ax, genome_y_location, n_mutations, unique_scores, start, stop, score_name=score_name, score_count=score_count, no_plot=1) - # cmap_scores = plt.cm.get_cmap('coolwarm') - # #plotting.create_scores_cbar(cmap_scores, min_y_location, n_scoresets, score_set, score_name, score_count, ax2) + track_created = plotting.create_scores_vis(ax, genome_y_location, n_mutations, n_tracks, unique_scores, start, stop, score_count, score_name) + if track_created: + score_count += 1 if args.gff3_path is not None: if genes_with_mutations: @@ -261,11 +240,7 @@ def main(sysargs=sys.argv[1:]): cmap_genes = plt.get_cmap('tab20', len(genes_with_mutations)) colors_genes = [cmap_genes(i) for i in range(len(genes_with_mutations))] # plot gene track - if args.scores and n_samples >= 4: - min_y_location_sc = min_y_location - 1.7 * len(args.scores) - plotting.create_gene_vis(ax, genes_with_mutations, n_mutations, y_size, n_tracks, start, stop, min_y_location_sc, genome_y_location, colors_genes) - else: - plotting.create_gene_vis(ax, genes_with_mutations, n_mutations, y_size, n_tracks, start, stop, min_y_location, genome_y_location, colors_genes) + plotting.create_gene_vis(ax, genes_with_mutations, n_mutations, y_size, n_tracks, start, stop, min_y_location, genome_y_location, colors_genes, n_scores) # create output folder if not os.path.exists(args.input[1]): diff --git a/virheat/scripts/data_prep.py b/virheat/scripts/data_prep.py index d174e42..af1c957 100644 --- a/virheat/scripts/data_prep.py +++ b/virheat/scripts/data_prep.py @@ -100,7 +100,7 @@ def read_vcf(vcf_file): return vcf_dict -def extract_vcf_data(vcf_files, threshold=0, scores=None): +def extract_vcf_data(vcf_files, threshold=0, scores=False): """ extract relevant vcf data """ @@ -118,7 +118,7 @@ def extract_vcf_data(vcf_files, threshold=0, scores=None): continue if scores: if vcf_dict['EFF'][idx] is not None: - aa_change = vcf_dict['EFF'][idx].split('|')[3] # extract amino acid changes if provided + aa_change = vcf_dict['EFF'][idx].split('|')[3] # extract amino acid changes if provided else: aa_change = '-' frequency_list.append( @@ -228,6 +228,7 @@ def delete_common_mutations(frequency_array, unique_mutations): return np.delete(frequency_array, mut_to_del, axis=1) + def delete_n_mutations(frequency_array, unique_mutations, min_mut): """ delete mutations that are not present in more than n samples @@ -261,6 +262,7 @@ def zoom_to_genomic_regions(unique_mutations, start_stop): return zoomed_unique + def parse_gff3(file): """ parse gff3 to dictionary @@ -367,4 +369,4 @@ def create_track_dict(unique_mutations, gff3_info, annotation_type): # and indicate the track in the dict gene_dict[gene].append(track) - return gene_dict, len(track_stops) \ No newline at end of file + return gene_dict, len(track_stops) diff --git a/virheat/scripts/plotting.py b/virheat/scripts/plotting.py index 0d832c1..9cf4e65 100644 --- a/virheat/scripts/plotting.py +++ b/virheat/scripts/plotting.py @@ -5,14 +5,11 @@ # BUILT-INS import math import numpy as np -import textwrap # LIBS import matplotlib.pyplot as plt import matplotlib.patches as patches import matplotlib.colors as mcolors -import matplotlib.cm as cm - def create_heatmap(ax, frequency_array, cmap): """ @@ -71,74 +68,44 @@ def create_genome_vis(ax, genome_y_location, n_mutations, unique_mutations, star return mutation_set -def create_scores_vis(ax, genome_y_location, n_mutations, n_tracks, unique_mutations, start, stop, score_name=None, score_count=None): +def create_scores_vis(ax, genome_y_location, n_mutations, n_tracks, unique_mutations, start, stop, score_count, score_name): """ create the scores rectangles, mappings to the reference """ score_set = [] - if n_tracks == 0: - y_min = -genome_y_location - 1 - score_count * (genome_y_location / 2) - else: - y_min = -genome_y_location - (n_tracks + 1) * (genome_y_location / 2) - score_count * (genome_y_location / 2) - y_max = y_min + genome_y_location / 2 - - if score_name: - for mutation in unique_mutations: - mutation_attributes = mutation.split("_") - score_set.append(float(mutation_attributes[5])) - score_set = [value for value in score_set if not np.isnan(value)] - if score_set: - norm = mcolors.Normalize(vmin=min(score_set), vmax=max(score_set)) # Normalization for the score range - else: - print("\033[31m\033[1mERROR:\033[0m Seems like there are no scores in the score set '{}' corresponding to the plotted mutation positions.".format(score_name)) - cmap = plt.cm.get_cmap('coolwarm') # blue to red colormap - ax.text(-0.5, y_min+genome_y_location/4, score_name, ha='right', va='center') - - # create a rectangle for the scores - rect = patches.FancyBboxPatch( - (0, y_min), n_mutations, y_max - y_min, - boxstyle="round,pad=-0.0040,rounding_size=0.03", - ec="lightgray", fc=(0, 0, 0, 0.05) - ) - ax.add_patch(rect) - - # create score lines on the score rectangle and the mapping to the respective mutation lines - x_start = 0 + y_zero = -genome_y_location - (n_tracks + score_count + 1) * (genome_y_location / 2) length = stop - start + + # create list of tuples [(nt pos, score)] for mutation in unique_mutations: mutation_attributes = mutation.split("_") - mutation_x_location = n_mutations/length*(int(mutation_attributes[0])-start) - x_start += 1 - # create lines for score_set - if score_name and score_set: - score_value = float(mutation_attributes[5]) - if score_value in score_set: - color = cmap(norm(score_value)) # map score to colormap - plt.vlines(x=mutation_x_location, ymin=y_min, ymax=y_max, color=color, linestyle='-') - - return score_set - + if not np.isnan(float(mutation_attributes[5])): + score_set.append((int(mutation_attributes[0]), float(mutation_attributes[5]))) -def create_scores_cbar(cmap, min_y_location, n_scoresets, score_set, score_name, score_count, ax): - """ - create a colorbar for the scoreset - """ if score_set: - if n_scoresets<3: - cbar = plt.colorbar(cm.ScalarMappable(cmap=cmap, norm=mcolors.Normalize(vmin=min(score_set), vmax=max(score_set))), - pad=0, shrink=1.5 / (min_y_location + 1), anchor=(-0.8, 0), - aspect=15, orientation='vertical', ax=ax, label=score_name) - else: - cbar = plt.colorbar(cm.ScalarMappable(cmap=cmap, norm=mcolors.Normalize(vmin=min(score_set), vmax=max(score_set))), label=score_name, ax=ax) - ax_pos = ax.get_position() - cbar.ax.set_position([ax_pos.x0 + score_count * 0.05, ax_pos.y0, ax_pos.width * 0.7, ax_pos.height*0.9]) - ticks = [min(score_set), min(score_set)+abs((max(score_set)-min(score_set))/2), max(score_set), 0] - cbar.set_ticks(ticks) - cbar.ax.tick_params(direction='in', labelsize='x-small') - cbar.set_label('\n'.join(textwrap.wrap(score_name, 20)), size='x-small') + # create zero line and score name on the left + plt.axhline(y=y_zero, color='black', linestyle='-', linewidth=0.5) + ax.text(-0.5, y_zero, score_name, ha='right', va='center') + # define normalization multiplier + multiplier = max([abs(score[1]) for score in score_set]) / abs((y_zero + genome_y_location / 4) - y_zero) + + # create score lines on the score rectangle and the mapping to the respective mutation lines + for score in score_set: + # define x value + mutation_x_location = n_mutations / length * (score[0] - start) + # define y value + # create lines for score_set + if score[1] < 0: + plt.vlines(x=mutation_x_location, ymin=y_zero + score[1]/multiplier, ymax=y_zero, color="red", linestyle='-') + else: + plt.vlines(x=mutation_x_location, ymin=y_zero, ymax=y_zero + score[1] / multiplier, color="blue", linestyle='-') + return True + else: + print("\033[31m\033[1mERROR:\033[0m Seems like there are no scores in the score set '{}' corresponding to the plotted mutation positions.".format(score_name)) + return False -def create_colorbar(threshold, cmap, min_y_location, n_samples, ax, n_scores=0): +def create_colorbar(threshold, cmap, min_y_location, n_samples, ax): """ creates a custom colorbar and annotates the threshold """ @@ -162,12 +129,12 @@ def create_colorbar(threshold, cmap, min_y_location, n_samples, ax, n_scores=0): labels.remove(rounded_threshold) ticks.append(threshold) labels.append(f"threshold\n={threshold}") - cbar = plt.colorbar(cmap, label="variant frequency", pad=0, shrink=n_samples/(min_y_location+n_samples+n_scores), anchor=(0.1,1), aspect=15, ax=ax) + cbar = plt.colorbar(cmap, label="variant frequency", pad=0, shrink=n_samples/(min_y_location+n_samples), anchor=(0.1,1), aspect=15, ax=ax) cbar.set_ticks(ticks) cbar.set_ticklabels(labels) -def create_mutation_legend(mutation_set, min_y_location, n_samples, n_scoresets=0): +def create_mutation_legend(mutation_set, min_y_location, n_samples, n_scores): """ create a legend for the mutation type """ @@ -179,11 +146,7 @@ def create_mutation_legend(mutation_set, min_y_location, n_samples, n_scoresets= legend_patches.append(patches.Patch(color="blue", label="INS")) if "SNV" in mutation_set: legend_patches.append(patches.Patch(color="dimgrey", label="SNV")) - - if n_scoresets != 0: - plt.legend(handles=legend_patches, bbox_to_anchor=(1.02, 0.95 - (n_samples / (min_y_location + n_samples))), loc='upper left', ncol=len(legend_patches)) - else: - plt.legend(bbox_to_anchor=(1.01, 0.95-(n_samples/(min_y_location+n_samples))), handles=legend_patches) + plt.legend(bbox_to_anchor=(1.01, 0.95-(n_samples/(min_y_location+n_samples+n_scores))), handles=legend_patches) def create_axis(ax, n_mutations, min_y_location, n_samples, file_names, start, stop, genome_y_location, unique_mutations, reference_name, n_scoresets=0): @@ -193,7 +156,7 @@ def create_axis(ax, n_mutations, min_y_location, n_samples, file_names, start, s # define plot limits ax.set_xlim(0, n_mutations) - ax.set_ylim(-min_y_location - n_scoresets, n_samples) + ax.set_ylim(-min_y_location, n_samples) # define new ticks depending on the genome size axis_length = stop - start if n_mutations >= 20: @@ -230,7 +193,7 @@ def create_axis(ax, n_mutations, min_y_location, n_samples, file_names, start, s ax.spines["left"].set_visible(False) -def create_gene_vis(ax, genes_with_mutations, n_mutations, y_size, n_tracks, start, stop, min_y_location, genome_y_location, colors_genes): +def create_gene_vis(ax, genes_with_mutations, n_mutations, y_size, n_tracks, start, stop, min_y_location, genome_y_location, colors_genes, n_scores): """ create the vis for the gene """ @@ -244,7 +207,7 @@ def create_gene_vis(ax, genes_with_mutations, n_mutations, y_size, n_tracks, sta x_value = 0 else: x_value = mult_factor*(genes_with_mutations[gene][0][0]-start) - y_value = -min_y_location+(n_tracks-genes_with_mutations[gene][1])*genome_y_location/2-genome_y_location/2 + y_value = -min_y_location+(n_tracks+n_scores-genes_with_mutations[gene][1])*genome_y_location/2 if genes_with_mutations[gene][0][1] > stop: width = n_mutations - x_value else: