From 245ca437429a2890feac416f012cd2ec4612b30c Mon Sep 17 00:00:00 2001
From: jonas-fuchs <jonas.michael.fuchs@googlemail.com>
Date: Thu, 16 Nov 2023 18:16:09 +0100
Subject: [PATCH 1/5] first zoom function heatmap restricted and genome
 restricted, genes missing

---
 virheat/__init__.py          |  2 +-
 virheat/command.py           | 26 +++++++++++++++++++++++---
 virheat/scripts/data_prep.py | 12 ++++++++++++
 virheat/scripts/plotting.py  | 20 +++++++++++---------
 4 files changed, 47 insertions(+), 13 deletions(-)

diff --git a/virheat/__init__.py b/virheat/__init__.py
index d80b915..0be322f 100644
--- a/virheat/__init__.py
+++ b/virheat/__init__.py
@@ -1,3 +1,3 @@
 """plot vcf data as a heatmap mapped to a virus genome"""
 _program = "virheat"
-__version__ = "0.5.4"
+__version__ = "0.6"
diff --git a/virheat/command.py b/virheat/command.py
index d6b1fab..8d286f8 100644
--- a/virheat/command.py
+++ b/virheat/command.py
@@ -80,6 +80,14 @@ def get_args(sysargs):
         default=None,
         help="do not show mutations that occur n times or less (default: Do not delete)"
     )
+    parser.add_argument(
+        "-z",
+        "--zoom",
+        type=int,
+        metavar=("start", "stop"),
+        nargs=2,
+        help="restrict the plot to a specific genomic region."
+    )
     parser.add_argument(
         "--sort",
         action=argparse.BooleanOptionalAction,
@@ -118,9 +126,10 @@ def main(sysargs=sys.argv[1:]):
     vcf_files = data_prep.get_files(args.input[0], "vcf")
     if args.sort:
         vcf_files = sorted(vcf_files, key=lambda x: data_prep.get_digit_and_alpha(os.path.basename(x)))
-
     # extract vcf info
     reference_name, frequency_lists, unique_mutations, file_names = data_prep.extract_vcf_data(vcf_files, threshold=args.threshold)
+    if args.zoom:
+        unique_mutations = data_prep.zoom_to_genomic_regions(unique_mutations, args.zoom)
     frequency_array = data_prep.create_freq_array(unique_mutations, frequency_lists)
     # user specified delete options (removes mutations based on various rationales)
     if args.delete:
@@ -166,12 +175,23 @@ def main(sysargs=sys.argv[1:]):
     # ini the fig
     fig, ax = plt.subplots(figsize=[y_size, x_size])
 
+    # define boundaries for the plot
+    if args.zoom:
+        start, stop = args.zoom[0], args.zoom[1]
+        # rescue plot if invalid zoom values are given
+        if args.zoom[0] < 0:
+            start = 0
+        if args.zoom[1] > genome_end:
+            stop = genome_end
+    else:
+        start, stop = 0, genome_end
+
     # plot all elements
     cmap = cm.gist_heat_r
     cmap.set_bad('silver', 1.)
     cmap_cells = cm.ScalarMappable(norm=colors.Normalize(0, 1), cmap=cmap)
     plotting.create_heatmap(ax, frequency_array, cmap_cells)
-    mutation_set = plotting.create_genome_vis(ax, genome_y_location, n_mutations, unique_mutations, genome_end)
+    mutation_set = plotting.create_genome_vis(ax, genome_y_location, n_mutations, unique_mutations, start, stop)
     if args.gff3_path is not None:
         if genes_with_mutations:
             # distinct colors for the genes
@@ -179,7 +199,7 @@ def main(sysargs=sys.argv[1:]):
             colors_genes = [cmap_genes(i) for i in range(len(genes_with_mutations))]
             # plot gene track
             plotting.create_gene_vis(ax, genes_with_mutations, n_mutations, y_size, n_tracks, genome_end, min_y_location, genome_y_location, colors_genes)
-    plotting.create_axis(ax, n_mutations, min_y_location, n_samples, file_names, genome_end, genome_y_location, unique_mutations, reference_name)
+    plotting.create_axis(ax, n_mutations, min_y_location, n_samples, file_names, start, stop, genome_y_location, unique_mutations, reference_name)
     plotting.create_colorbar(args.threshold, cmap_cells, min_y_location, n_samples, ax)
     plotting.create_mutation_legend(mutation_set, min_y_location, n_samples)
 
diff --git a/virheat/scripts/data_prep.py b/virheat/scripts/data_prep.py
index e8e338e..edea269 100644
--- a/virheat/scripts/data_prep.py
+++ b/virheat/scripts/data_prep.py
@@ -216,6 +216,18 @@ def delete_n_mutations(frequency_array, unique_mutations, min_mut):
     return np.delete(frequency_array, mut_to_del, axis=1)
 
 
+def zoom_to_genomic_regions(unique_mutations, start_stop):
+    """
+    restrict the displayed mutations to a user defined genomic range
+    """
+    zoomed_unique = []
+
+    for mutation in unique_mutations:
+        if start_stop[0] <= int(mutation.split("_")[0]) <= start_stop[1]:
+            zoomed_unique.append(mutation)
+
+    return zoomed_unique
+
 def parse_gff3(file):
     """
     parse gff3 to dictionary
diff --git a/virheat/scripts/plotting.py b/virheat/scripts/plotting.py
index 7819eda..56d1967 100644
--- a/virheat/scripts/plotting.py
+++ b/virheat/scripts/plotting.py
@@ -25,7 +25,7 @@ def create_heatmap(ax, frequency_array, cmap):
         y_start += 1
 
 
-def create_genome_vis(ax, genome_y_location, n_mutations, unique_mutations, genome_end):
+def create_genome_vis(ax, genome_y_location, n_mutations, unique_mutations, genome_start, genome_stop):
     """
     create the genome rectangle, mutations and mappings to the heatmap
     """
@@ -51,11 +51,12 @@ def create_genome_vis(ax, genome_y_location, n_mutations, unique_mutations, geno
 
     # create mutation lines on the genome rectangle and the mapping to the respective cells
     x_start = 0
+    length = genome_stop - genome_start
     for mutation in unique_mutations:
         mutation_attributes = mutation.split("_")
         mutation_color = mutation_type_colors[mutation_attributes[3]]
         mutation_set.add(mutation_attributes[3])
-        mutation_x_location = n_mutations/genome_end*int(mutation_attributes[0])
+        mutation_x_location = n_mutations/length*(int(mutation_attributes[0])-genome_start)
         # create mutation lines
         plt.vlines(x=mutation_x_location, ymin=y_min, ymax=y_max, color=mutation_color)
         # create polygon
@@ -111,7 +112,7 @@ def create_mutation_legend(mutation_set, min_y_location, n_samples):
     plt.legend(bbox_to_anchor=(1.01, 0.95-(n_samples/(min_y_location+n_samples))), handles=legend_patches)
 
 
-def create_axis(ax, n_mutations, min_y_location, n_samples, file_names, genome_end, genome_y_location, unique_mutations, reference_name):
+def create_axis(ax, n_mutations, min_y_location, n_samples, file_names, start, stop, genome_y_location, unique_mutations, reference_name):
     """
     create the axis of the plot
     """
@@ -120,17 +121,18 @@ def create_axis(ax, n_mutations, min_y_location, n_samples, file_names, genome_e
     ax.set_xlim(0, n_mutations)
     ax.set_ylim(-min_y_location, n_samples)
     # define new ticks depending on the genome size
+    axis_length = stop - start
     if n_mutations >= 20:
-        xtick_dis = round(genome_end/6, -int(math.log10(genome_end / 6)) + 1)
-        xtick_labels = [0, xtick_dis, xtick_dis*2, xtick_dis*3, xtick_dis*4, xtick_dis*5, genome_end]
+        xtick_dis = round(axis_length/6, -int(math.log10(axis_length / 6)) + 1)
+        xtick_labels = [start, start + xtick_dis, start + xtick_dis*2, start + xtick_dis*3, start + xtick_dis*4, start + xtick_dis*5, stop]
     elif n_mutations >= 10:
-        xtick_dis = round(genome_end / 3, -int(math.log10(genome_end / 3)) + 1)
-        xtick_labels = [0, xtick_dis, xtick_dis * 2, genome_end]
+        xtick_dis = round(axis_length / 3, -int(math.log10(axis_length / 3)) + 1)
+        xtick_labels = [start, start + xtick_dis, start + xtick_dis * 2, stop]
     else:
-        xtick_labels = [0, genome_end]
+        xtick_labels = [start, stop]
     xtick_labels = [int(tick) for tick in xtick_labels]
     # get the correct location of the genome pos on the axis
-    xticks = [n_mutations/genome_end*tick for tick in xtick_labels]
+    xticks = [n_mutations/axis_length*(tick - start) for tick in xtick_labels]
     # set new ticks and change spines/yaxis
     ax.set_xticks(xticks, xtick_labels)
     # set y axis labels

From ffad4a6ce4a959fc3ab810a4b0ad2b1a8869c871 Mon Sep 17 00:00:00 2001
From: jonas-fuchs <jonas.michael.fuchs@googlemail.com>
Date: Fri, 17 Nov 2023 10:27:28 +0100
Subject: [PATCH 2/5] updated the genome plotting parameters

---
 virheat/command.py          |  2 +-
 virheat/scripts/plotting.py | 31 ++++++++++++++++++++-----------
 2 files changed, 21 insertions(+), 12 deletions(-)

diff --git a/virheat/command.py b/virheat/command.py
index 8d286f8..851ebed 100644
--- a/virheat/command.py
+++ b/virheat/command.py
@@ -198,7 +198,7 @@ def main(sysargs=sys.argv[1:]):
             cmap_genes = plt.get_cmap('tab20', len(genes_with_mutations))
             colors_genes = [cmap_genes(i) for i in range(len(genes_with_mutations))]
             # plot gene track
-            plotting.create_gene_vis(ax, genes_with_mutations, n_mutations, y_size, n_tracks, genome_end, min_y_location, genome_y_location, colors_genes)
+            plotting.create_gene_vis(ax, genes_with_mutations, n_mutations, y_size, n_tracks, start, stop, min_y_location, genome_y_location, colors_genes)
     plotting.create_axis(ax, n_mutations, min_y_location, n_samples, file_names, start, stop, genome_y_location, unique_mutations, reference_name)
     plotting.create_colorbar(args.threshold, cmap_cells, min_y_location, n_samples, ax)
     plotting.create_mutation_legend(mutation_set, min_y_location, n_samples)
diff --git a/virheat/scripts/plotting.py b/virheat/scripts/plotting.py
index 56d1967..415a018 100644
--- a/virheat/scripts/plotting.py
+++ b/virheat/scripts/plotting.py
@@ -25,7 +25,7 @@ def create_heatmap(ax, frequency_array, cmap):
         y_start += 1
 
 
-def create_genome_vis(ax, genome_y_location, n_mutations, unique_mutations, genome_start, genome_stop):
+def create_genome_vis(ax, genome_y_location, n_mutations, unique_mutations, start, stop):
     """
     create the genome rectangle, mutations and mappings to the heatmap
     """
@@ -51,12 +51,12 @@ def create_genome_vis(ax, genome_y_location, n_mutations, unique_mutations, geno
 
     # create mutation lines on the genome rectangle and the mapping to the respective cells
     x_start = 0
-    length = genome_stop - genome_start
+    length = stop - start
     for mutation in unique_mutations:
         mutation_attributes = mutation.split("_")
         mutation_color = mutation_type_colors[mutation_attributes[3]]
         mutation_set.add(mutation_attributes[3])
-        mutation_x_location = n_mutations/length*(int(mutation_attributes[0])-genome_start)
+        mutation_x_location = n_mutations/length*(int(mutation_attributes[0])-start)
         # create mutation lines
         plt.vlines(x=mutation_x_location, ymin=y_min, ymax=y_max, color=mutation_color)
         # create polygon
@@ -156,27 +156,36 @@ def create_axis(ax, n_mutations, min_y_location, n_samples, file_names, start, s
     ax.spines["left"].set_visible(False)
 
 
-def create_gene_vis(ax, genes_with_mutations, n_mutations, y_size, n_tracks, genome_end, min_y_location, genome_y_location, colors_genes):
+def create_gene_vis(ax, genes_with_mutations, n_mutations, y_size, n_tracks, start, stop, min_y_location, genome_y_location, colors_genes):
     """
     create the vis for the gene
     """
 
     gene_annotations = []
-    mult_factor = n_mutations/genome_end
+    mult_factor = n_mutations/(stop-start)
 
     for idx, gene in enumerate(genes_with_mutations):
-        start = (mult_factor*genes_with_mutations[gene][0][0], -min_y_location+(n_tracks-genes_with_mutations[gene][1])*genome_y_location/2-genome_y_location/2)
-        stop = mult_factor*genes_with_mutations[gene][0][1]
+        # define the plotting values for the patch
+        if genes_with_mutations[gene][0][0] < start:
+            x_value = 0
+        else:
+            x_value = mult_factor*(genes_with_mutations[gene][0][0]-start)
+        y_value = -min_y_location+(n_tracks-genes_with_mutations[gene][1])*genome_y_location/2-genome_y_location/2
+        if genes_with_mutations[gene][0][1] > stop:
+            width = n_mutations - x_value
+        else:
+            width = mult_factor*(genes_with_mutations[gene][0][1]-start) - x_value
         height = genome_y_location/2
+        # plot the patch
         ax.add_patch(
             patches.FancyBboxPatch(
-                                start, stop-start[0], height,
+                                (x_value, y_value), width, height,
                                 boxstyle="round,pad=-0.0040,rounding_size=0.03",
                                 ec="black", fc=colors_genes[idx]
                             )
         )
         # define text pos for gene description inside or below the gene box, depending if it fits within
-        if stop-start[0] > n_mutations/(y_size*8)*len(gene):
-            gene_annotations.append(ax.text(start[0]+(stop-start[0])/2, start[1]+height/2, gene, ha="center", va="center"))
+        if width > n_mutations/(y_size*8)*len(gene):
+            gene_annotations.append(ax.text(x_value+width/2, y_value+height/2, gene, ha="center", va="center"))
         else:
-            gene_annotations.append(ax.text(start[0]+(stop-start[0])/2, start[1]-height/4, gene, rotation=40, rotation_mode="anchor", ha="right", va="bottom"))
+            gene_annotations.append(ax.text(x_value+width/2, y_value-height/4, gene, rotation=40, rotation_mode="anchor", ha="right", va="bottom"))

From 4df87fc6904c5af714081e09af63544119a614d6 Mon Sep 17 00:00:00 2001
From: jonas-fuchs <jonas.michael.fuchs@googlemail.com>
Date: Fri, 17 Nov 2023 10:40:34 +0100
Subject: [PATCH 3/5] updated readme

---
 README.md | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index e642424..8f0220a 100644
--- a/README.md
+++ b/README.md
@@ -70,7 +70,9 @@ options:
   --delete, --no-delete
                         delete mutations that are present in all samples and their maximum frequency divergence is smaller than 0.5 (default: True)
   -n None, --delete-n None
-                        do not show mutations that occur n times or less (default: Do not delete)                      
+                        do not show mutations that occur n times or less (default: Do not delete)
+  -z start stop, --zoom start stop
+                        restrict the plot to a specific genomic region.                      
   --sort, --no-sort     sort sample names alphanumerically (default: False)
   --min-cov 20          display mutations covered at least x time (only if per base cov tsv files are provided)
   -v, --version         show program's version number and exit

From af905a536ca0c9375ab5eaa4bcd793bd913a2008 Mon Sep 17 00:00:00 2001
From: jonas-fuchs <jonas.michael.fuchs@googlemail.com>
Date: Fri, 17 Nov 2023 11:12:11 +0100
Subject: [PATCH 4/5] added --name option for plot name

---
 README.md          |  2 ++
 virheat/command.py | 15 ++++++++++++---
 2 files changed, 14 insertions(+), 3 deletions(-)

diff --git a/README.md b/README.md
index 8f0220a..305008c 100644
--- a/README.md
+++ b/README.md
@@ -60,6 +60,8 @@ positional arguments:
 
 options:
   -h, --help            show this help message and exit
+  --name virHEAT_plot.pdf
+                        plot name and file type (pdf, png, svg, jpg). Default: virHEAT_plot.pdf
   -l None, --genome-length None
                         length of the genome (needed if gff3 is not provided)
   -g None, --gff3-path None
diff --git a/virheat/command.py b/virheat/command.py
index 851ebed..6d2fde8 100644
--- a/virheat/command.py
+++ b/virheat/command.py
@@ -32,6 +32,13 @@ def get_args(sysargs):
         nargs=2,
         help="folder containing input files and output folder"
     )
+    parser.add_argument(
+        "--name",
+        type=str,
+        metavar="virHEAT_plot.pdf",
+        default="virHEAT_plot.pdf",
+        help="plot name and file type (pdf, png, svg, jpg). Default: virHEAT_plot.pdf"
+    )
     parser.add_argument(
         "-l",
         "--genome-length",
@@ -126,22 +133,24 @@ def main(sysargs=sys.argv[1:]):
     vcf_files = data_prep.get_files(args.input[0], "vcf")
     if args.sort:
         vcf_files = sorted(vcf_files, key=lambda x: data_prep.get_digit_and_alpha(os.path.basename(x)))
+
     # extract vcf info
     reference_name, frequency_lists, unique_mutations, file_names = data_prep.extract_vcf_data(vcf_files, threshold=args.threshold)
     if args.zoom:
         unique_mutations = data_prep.zoom_to_genomic_regions(unique_mutations, args.zoom)
     frequency_array = data_prep.create_freq_array(unique_mutations, frequency_lists)
+
     # user specified delete options (removes mutations based on various rationales)
     if args.delete:
         frequency_array = data_prep.delete_common_mutations(frequency_array, unique_mutations)
     if args.delete_n is not None:
         frequency_array = data_prep.delete_n_mutations(frequency_array, unique_mutations, args.delete_n)
+
     # annotate low coverage if per base coveage from qualimap was provided
     data_prep.annotate_non_covered_regions(args.input[0], args.min_cov, frequency_array, file_names, unique_mutations)
 
     # define relative locations of all items in the plot
-    n_samples = len(frequency_array)
-    n_mutations = len(frequency_array[0])
+    n_samples, n_mutations = len(frequency_array), len(frequency_array[0])
     if n_mutations == 0:
         sys.exit("\033[31m\033[1mERROR:\033[0m Frequency array seems to be empty. There is nothing to plot.")
     if n_samples < 4:
@@ -208,5 +217,5 @@ def main(sysargs=sys.argv[1:]):
         os.makedirs(args.input[1])
 
     # save fig
-    fig.savefig(os.path.join(args.input[1], "virHEAT_plot.pdf"), bbox_inches="tight")
+    fig.savefig(os.path.join(args.input[1], args.name), bbox_inches="tight")
 

From f397a53e3d820423f3082ce741db6c8f92bbfda2 Mon Sep 17 00:00:00 2001
From: jonas-fuchs <jonas.michael.fuchs@googlemail.com>
Date: Fri, 17 Nov 2023 11:21:23 +0100
Subject: [PATCH 5/5] added error message if 'region' is missing in gff3

---
 virheat/scripts/data_prep.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/virheat/scripts/data_prep.py b/virheat/scripts/data_prep.py
index edea269..45ca30e 100644
--- a/virheat/scripts/data_prep.py
+++ b/virheat/scripts/data_prep.py
@@ -271,6 +271,8 @@ def get_genome_end(gff3_dict):
 
     genome_end = 0
 
+    if "region" not in gff3_dict:
+        sys.exit("\033[31m\033[1mERROR:\033[0m Region annotation is missing in the gff3!")
     for attribute in gff3_dict["region"].keys():
         stop = gff3_dict["region"][attribute]["stop"]
         if stop > genome_end: