diff --git a/.gitignore b/.gitignore index f9bf2cc..c6e9902 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ virheat/__pycache__/** virheat.egg-info/** build/** venv/** +test.py diff --git a/README.md b/README.md index 934e018..aba7667 100644 --- a/README.md +++ b/README.md @@ -28,13 +28,8 @@ pip install virheat ```shell git clone https://github.com/jonas-fuchs/virHEAT cd virHEAT -``` -and then install virHEAT with: -```shell pip install -r requirements.txt -``` -or: -```shell +# or pip install . ``` That was already it. To check if it worked: @@ -55,21 +50,20 @@ usage: virheat -l or -g [additional argum ``` positional arguments: - input folder containing vcf (and tsv) files and output folder + input folder containing input files and output folder -optional arguments: +options: -h, --help show this help message and exit -l None, --genome-length None length of the genome (needed if gff3 is not provided) -g None, --gff3-path None path to gff3 (needed if length is not provided) - -a gene, --gff3-annotations gene - annotations to display from gff3 file (standard: gene) - -t 0, --threshold 0 display frequencies above this threshold + -a [gene ...], --gff3-annotations [gene ...] + annotations to display from gff3 file (standard: gene). Multiple possible. + -t 0, --threshold 0 display frequencies above this threshold (0-1) --delete, --no-delete - delete mutations with frequencies present in all - samples (default: True) - --sort, --no-sort sort alphanumerically (default: False) + delete mutations that are present in all samples and their maximum frequency divergence is smaller than 0.5 (default: True) + --sort, --no-sort sort sample names alphanumerically (default: False) --min-cov 20 display mutations covered at least x time (only if per base cov tsv files are provided) -v, --version show program's version number and exit ``` diff --git a/virheat/__init__.py b/virheat/__init__.py index e841665..bf5eead 100644 --- a/virheat/__init__.py +++ b/virheat/__init__.py @@ -1,3 +1,3 @@ """plot vcf data as a heatmap mapped to a virus genome""" _program = "virheat" -__version__ = "0.5.2" +__version__ = "0.5.3" diff --git a/virheat/command.py b/virheat/command.py index 963f4cc..74481ff 100644 --- a/virheat/command.py +++ b/virheat/command.py @@ -52,9 +52,11 @@ def get_args(sysargs): "-a", "--gff3-annotations", type=str, + action="store", metavar="gene", - default="gene", - help="annotations to display from gff3 file (standard: gene). Multiple possible (comma seperated)" + nargs="*", + default=["gene"], + help="annotations to display from gff3 file (standard: gene). Multiple possible." ) parser.add_argument( "-t", @@ -62,19 +64,19 @@ def get_args(sysargs): type=float, metavar="0", default=0, - help="display frequencies above this threshold" + help="display frequencies above this threshold (0-1)" ) parser.add_argument( "--delete", action=argparse.BooleanOptionalAction, default=True, - help="delete mutations with frequencies present in all samples" + help="delete mutations that are present in all samples and their maximum frequency divergence is smaller than 0.5" ) parser.add_argument( "--sort", action=argparse.BooleanOptionalAction, default=False, - help="sort alphanumerically" + help="sort sample names alphanumerically" ) parser.add_argument( "--min-cov", @@ -120,6 +122,8 @@ def main(sysargs=sys.argv[1:]): # define relative locations of all items in the plot n_samples = len(frequency_array) n_mutations = len(frequency_array[0]) + if n_mutations == 0: + sys.exit("\033[31m\033[1mERROR:\033[0m Frequency array seems to be empty. There is nothing to plot.") if n_samples < 4: genome_y_location = 2 else: @@ -134,8 +138,7 @@ def main(sysargs=sys.argv[1:]): if gff3_ref_name not in reference_name and reference_name not in gff3_ref_name: print("\033[31m\033[1mWARNING:\033[0m gff3 reference does not match the vcf reference!") genome_end = data_prep.get_genome_end(gff3_info) - annotation_list = args.gff3_annotations.split(",") - genes_with_mutations, n_tracks = data_prep.create_track_dict(unique_mutations, gff3_info, annotation_list) + genes_with_mutations, n_tracks = data_prep.create_track_dict(unique_mutations, gff3_info, args.gff3_annotations) # define space for the genome vis tracks min_y_location = genome_y_location + genome_y_location/2 * (n_tracks+1) elif args.genome_length is not None: @@ -159,11 +162,12 @@ def main(sysargs=sys.argv[1:]): plotting.create_heatmap(ax, frequency_array, cmap_cells) mutation_set = plotting.create_genome_vis(ax, genome_y_location, n_mutations, unique_mutations, genome_end) if args.gff3_path is not None: - # distinct colors for the genes - cmap_genes = plt.get_cmap('tab20', len(genes_with_mutations)) - colors_genes = [cmap_genes(i) for i in range(len(genes_with_mutations))] - # plot gene track - plotting.create_gene_vis(ax, genes_with_mutations, n_mutations, y_size, n_tracks, genome_end, min_y_location, genome_y_location, colors_genes) + if genes_with_mutations: + # distinct colors for the genes + cmap_genes = plt.get_cmap('tab20', len(genes_with_mutations)) + colors_genes = [cmap_genes(i) for i in range(len(genes_with_mutations))] + # plot gene track + plotting.create_gene_vis(ax, genes_with_mutations, n_mutations, y_size, n_tracks, genome_end, min_y_location, genome_y_location, colors_genes) plotting.create_axis(ax, n_mutations, min_y_location, n_samples, file_names, genome_end, genome_y_location, unique_mutations, reference_name) plotting.create_colorbar(args.threshold, cmap_cells, min_y_location, n_samples, ax) plotting.create_mutation_legend(mutation_set, min_y_location, n_samples) diff --git a/virheat/scripts/data_prep.py b/virheat/scripts/data_prep.py index ab2bd20..813776e 100644 --- a/virheat/scripts/data_prep.py +++ b/virheat/scripts/data_prep.py @@ -66,7 +66,7 @@ def read_vcf(vcf_file): for key in header[0:6]: vcf_dict[key] = [] # functional effect - vcf_dict["TYPE"] = [] + vcf_dict["MUT_TYPE_"] = [] # info field for line in lines: for info in line[7].split(";"): @@ -80,13 +80,13 @@ def read_vcf(vcf_file): vcf_dict[key].append(convert_string(line[idx])) # get mutation type if len(line[3]) == len(line[4]): - vcf_dict["TYPE"].append("SNV") + vcf_dict["MUT_TYPE_"].append("SNV") elif len(line[3]) < len(line[4]): - vcf_dict["TYPE"].append("INS") + vcf_dict["MUT_TYPE_"].append("INS") elif len(line[3]) > len(line[4]): - vcf_dict["TYPE"].append("DEL") + vcf_dict["MUT_TYPE_"].append("DEL") visited_keys.extend(header[0:6]) - visited_keys.append("TYPE") + visited_keys.append("MUT_TYPE_") # get data from info field for info in line[7].split(";"): if "=" in info: @@ -117,7 +117,7 @@ def extract_vcf_data(vcf_files, threshold=0): if not vcf_dict["AF"][idx] >= threshold: continue frequency_list.append( - (f"{vcf_dict['POS'][idx]}_{vcf_dict['REF'][idx]}_{vcf_dict['ALT'][idx]}_{vcf_dict['TYPE'][idx]}", vcf_dict['AF'][idx]) + (f"{vcf_dict['POS'][idx]}_{vcf_dict['REF'][idx]}_{vcf_dict['ALT'][idx]}_{vcf_dict['MUT_TYPE_'][idx]}", vcf_dict['AF'][idx]) ) frequency_lists.append(frequency_list) # sort by mutation index @@ -179,13 +179,15 @@ def delete_common_mutations(frequency_array, unique_mutations): mut_to_del = [] for idx in range(0, len(frequency_array[0])): + check_all = [] for frequency_list in frequency_array: - if frequency_list[idx] != 0: - common_mut = True - else: - common_mut = False - break - if common_mut: + check_all.append(frequency_list[idx]) + # check if all mutation in a column are zero (happens with some weird callers) + if all(x == 0 for x in check_all): + mut_to_del.append(idx) + # check if frequencies are present in all columns and the maximal diff is greater than 0.5 + # example [0.8, 0.7, 0.3] is not deleted whereas [0.8, 0.7, 0.7] is deleted + elif all(x > 0 for x in check_all) and max(check_all)-min(check_all) < 0.5: mut_to_del.append(idx) for idx in sorted(mut_to_del, reverse=True): @@ -272,7 +274,8 @@ def create_track_dict(unique_mutations, gff3_info, annotation_type): gff3_info[type][annotation]["stop"]) ) if not genes_with_mutations: - sys.exit("none of the given annotation types were found in gff3.") + print("\033[31m\033[1mWARNING:\033[0m either the annotation types were not found in gff3 or the mutations are not within genes.") + return {}, 0 # create a dict and sort gene_dict = {element[0]: [element[1:4]] for element in genes_with_mutations}