diff --git a/dress/configs/generate_binfiller.yaml b/dress/configs/generate_binfiller.yaml index 32a98fa..18c344c 100644 --- a/dress/configs/generate_binfiller.yaml +++ b/dress/configs/generate_binfiller.yaml @@ -15,7 +15,8 @@ generate: preprocessing: cache_dir: data/cache/ genome: data/cache/Homo_sapiens.GRCh38.dna.primary_assembly.fa - use_full_sequence: false + use_model_resolution: false + use_full_triplet: false fitness: minimize_fitness: false fitness_function: bin_filler diff --git a/dress/configs/generate_binfiller_pwm_grammar.yaml b/dress/configs/generate_binfiller_pwm_grammar.yaml index a183281..8cb667c 100644 --- a/dress/configs/generate_binfiller_pwm_grammar.yaml +++ b/dress/configs/generate_binfiller_pwm_grammar.yaml @@ -15,7 +15,8 @@ generate: preprocessing: cache_dir: data/cache/ genome: data/cache/Homo_sapiens.GRCh38.dna.primary_assembly.fa - use_full_sequence: false + use_model_resolution: false + use_full_triplet: false fitness: minimize_fitness: false fitness_function: bin_filler diff --git a/dress/configs/generate_iad.yaml b/dress/configs/generate_iad.yaml index 279e126..8506c55 100644 --- a/dress/configs/generate_iad.yaml +++ b/dress/configs/generate_iad.yaml @@ -15,7 +15,8 @@ generate: preprocessing: cache_dir: data/cache/ genome: data/cache/Homo_sapiens.GRCh38.dna.primary_assembly.fa - use_full_sequence: false + use_model_resolution: false + use_full_triplet: false fitness: minimize_fitness: false fitness_function: increase_archive_diversity diff --git a/dress/configs/generate_iad_pwm_grammar.yaml b/dress/configs/generate_iad_pwm_grammar.yaml index 3545246..e771a3c 100644 --- a/dress/configs/generate_iad_pwm_grammar.yaml +++ b/dress/configs/generate_iad_pwm_grammar.yaml @@ -15,7 +15,8 @@ generate: preprocessing: cache_dir: data/cache/ genome: data/cache/Homo_sapiens.GRCh38.dna.primary_assembly.fa - use_full_sequence: false + use_model_resolution: false + use_full_triplet: false fitness: minimize_fitness: false fitness_function: increase_archive_diversity diff --git a/dress/datasetevaluation/representation/motifs/utils.py b/dress/datasetevaluation/representation/motifs/utils.py index 484a21f..9d495db 100644 --- a/dress/datasetevaluation/representation/motifs/utils.py +++ b/dress/datasetevaluation/representation/motifs/utils.py @@ -4,6 +4,8 @@ import re import pandas as pd from typing import Union +import pandas as pd +pd.set_option('future.no_silent_downcasting', True) from dress.datasetevaluation.representation.motifs.rbp_lists import RBP_SUBSETS import numpy as np @@ -406,7 +408,7 @@ def _remove_self_contained(gr: pr.PyRanges, scan_method: str) -> pr.PyRanges: df = pd.merge(df, contained_same_rbp, how="left", on=to_drop_cols).drop( columns=to_clean_cols ) - df.Has_self_submotif.fillna(False, inplace=True) + df.fillna({'Has_self_submotif': False}, inplace=True) ####################### # Other RBP contained # @@ -428,7 +430,7 @@ def _remove_self_contained(gr: pr.PyRanges, scan_method: str) -> pr.PyRanges: df = pd.merge(df, contained_other_rbp, how="left", on=to_drop_cols).drop( columns=to_clean_cols[:-1] ) - df.Has_other_submotif.fillna(False, inplace=True) + df.fillna({'Has_other_submotif': False}, inplace=True) # logger.debug(".. {} hits flagged ..".format(contained_other_rbp.shape[0])) return pr.PyRanges(df) diff --git a/dress/datasetgeneration/json_schema.py b/dress/datasetgeneration/json_schema.py index 05c87e0..57dc399 100644 --- a/dress/datasetgeneration/json_schema.py +++ b/dress/datasetgeneration/json_schema.py @@ -7,6 +7,8 @@ "properties": { "dry_run": {"type": "boolean"}, "disable_gpu": {"type": "boolean"}, + "use_full_triplet": {"type": "boolean"}, + "use_model_resolution": {"type": "boolean"}, "verbosity": {"type": "integer"}, "shuffle_input": {"type": ["null", "string"]}, "outdir": {"type": "string"}, diff --git a/dress/datasetgeneration/preprocessing/gtf_cache.py b/dress/datasetgeneration/preprocessing/gtf_cache.py index 711e159..c466a48 100644 --- a/dress/datasetgeneration/preprocessing/gtf_cache.py +++ b/dress/datasetgeneration/preprocessing/gtf_cache.py @@ -121,7 +121,9 @@ def preprocessing(data: pr.PyRanges, **kwargs): df=extracted, fasta=genome, extend_borders=100, - use_full_seqs=kwargs["use_full_sequence"], + use_full_triplet=kwargs["use_full_triplet"], + use_model_resolution=kwargs["use_model_resolution"], + model = kwargs["model"] ) if os.path.isdir(kwargs["outdir"]): @@ -156,7 +158,10 @@ def write_output( Additional arguments in **kwargs: outdir (str): Output directory. outbasename (str): Output basename. - use_full_sequence (bool): Whether to use the full sequence when running the black box model. + use_full_triplet (bool): Whether to use the full exon triplet as input sequence + when making model inferences. + use_model_resolution (bool): Whether to use the model resolution to determine input + sequence size when making model inferences. """ to_write = { @@ -193,8 +198,13 @@ def write_output( ] ], ) + + out_flag = '' + if kwargs['use_full_triplet']: + out_flag = "_full_triplet" + elif kwargs['use_model_resolution']: + out_flag = "_model_res" - out_flag = "" if kwargs["use_full_sequence"] else "_trimmed_at_5000bp" if len(extracted_with_seqs) > 0: extracted_with_seqs[ ["header", "acceptor_idx", "donor_idx", "tx_id", "exon"] diff --git a/dress/datasetgeneration/preprocessing/utils.py b/dress/datasetgeneration/preprocessing/utils.py index 5231986..a972aa6 100644 --- a/dress/datasetgeneration/preprocessing/utils.py +++ b/dress/datasetgeneration/preprocessing/utils.py @@ -79,7 +79,7 @@ def process_ss_idx( def _get_flat_ss( - info: pd.Series, _level: str, start: int, end: int, use_full_seqs: bool + info: pd.Series, _level: str, start: int, end: int, use_full_tiplet: bool ): """ Extracts flat splice site indexes for sequences @@ -96,11 +96,11 @@ def _get_flat_ss( extensions. end (int): End coordinate of the flat sequence, after accounting for the extensions. - use_full_seqs (bool, optional): Whether `start` and `end` + use_full_tiplet (bool, optional): Whether `start` and `end` coordinates represent the true start and end of the sequence at a given surrounding level. If `False`, they represent start and end positions up to the - limit of deep learning model resolution. + limit of the model resolution. Returns: str: Flat acceptor indexes @@ -109,7 +109,7 @@ def _get_flat_ss( _info = info.copy() # Check out of scope indexes - if use_full_seqs is False: + if use_full_tiplet is False: if _level != 0: cols = [ "Start_upstream" + _level, @@ -205,7 +205,6 @@ def get_fasta_sequences( end_col (str, optional): Colname in `x` where end coordinate is. Defaults to "End". Returns: - pd.Series: Additional column with the fasta sequence for the requested interval """ @@ -444,38 +443,47 @@ def generate_pipeline_input( df: pd.DataFrame, fasta: str, extend_borders: int = 0, - use_full_seqs: bool = True, + use_full_triplet: bool = False, + use_model_resolution: bool = True, + model: str = "spliceai", ): """ - Generates SpliceAI input sequences from a dataframe - of target features with upstream and downstream + Generates model input sequences from a dataframe + of target exons with upstream and downstream intervals. It will generate the splice site indexes of the exons and introns surrounding the target exon of interest, for the associated transcript ID. - Additionally, two sets of fasta sequences can be written: - If `use_full_seqs` is `True`: - - Complete sequences ranging from the start of the highest - upstream level available (by default 2) up to the end of the - feature represented by the highest downstream level available. - Depending on intron sizes, these sequences may be very large. - If `extend_borders` is > 0, this number of nucleotides will - be additionally extracted on each side. - - If `use_full_seqs` is `False`: - - Sequences surrounding the target exon up to a maximum size - of 5000 nucleotides on each side will be extracted. This is - the maximum resolution sequence-based models can get to - see long-range effects on splice site definition. - When this is set, sequences may be much shorter than the expected - size, specially if dealing with introns that are >50kb long. If `False`, - inference time for these long sequences may be quite high. - - On the other side, if the length of the final sequence is small - (let's say less than 10000 + average exon size), the input will - be padded so that models can accept the input. + Sequence extraction can be done in three ways: + - `use_full_triplet` is `True`: The complete sequence + from the start of the exon upstream until the end of the + exon downstream is extracted, regardless of the size of + the resulting sequence. This can result in very large + sequences if introns are very long. + - `use_model_resolution` is `True`: The sequence is + extracted up to the maximum resolution of the model. For example, + for SpliceAI, the extracted sequence will be the cassette exon plus + 5000bp upstream of the acceptor and 5000bp downstream of the donor. + - `use_model_resolution` is `False` and `use_full_triplet` is `False`: + If neither of the above is set (default), the sequence will be extracted + such that: if the full exon triplet is smaller than the model resolution, + the full exon triplet is extracted, and then at inference time the sequence + is padded to the model resolution. If the exon triplet is larger than the + model resolution, the sequence is trimmed at the model resolution, which may not + include upstream and/or downstream exons. + + If `extend_borders` > 0, the sequence will be extended on both sides. + This is useful for cases where the acceptor of the uptream exon + or the donor of the downstream exon represent the start or end of the + sequence, respectively. Because the splice site score of such positions + may not be properly captured (because the full context is not present), + extending the sequence can help to provide a more realistic prediction + as if those positions were in the middle of the sequence. This extension + is applied when `use_full_triplet` is `True` or in the default setting + (`use_model_resolution=False` and `use_full_triplet=False`) when the + sequence is shorter than the model resolution. Returns: Tuples with the following information: @@ -484,6 +492,10 @@ def generate_pipeline_input( - A dataframe with the exons excluded due to having NAs, if level == 2 """ + assert any( + x is False for x in [use_full_triplet, use_model_resolution] + ), "Can't set both `use_full_triplet` and `use_model_resolution` to `True`." + if list(df.filter(regex="stream")): # Extract level so that we know the # borders of the genomic sequence to extract @@ -512,7 +524,7 @@ def generate_pipeline_input( for _, seq_record in df.iterrows(): - if use_full_seqs: + if use_full_triplet: if seq_record.Strand == "+": start = "Start_upstream" + _level if level != 0 else "Start" end = "End_downstream" + _level if level != 0 else "End" @@ -549,7 +561,7 @@ def generate_pipeline_input( _level, start=seq_record[start] - extend_borders, end=seq_record[end] + extend_borders, - use_full_seqs=use_full_seqs, + use_full_tiplet=use_full_triplet, ) out.append( @@ -565,101 +577,14 @@ def generate_pipeline_input( ) else: - # Get sequences trimmed at max SpliceAI resolution - def _get_slack( - seq_record: pd.Series, - region: Literal["upstream", "downstream"], - _level: str, - extend_borders: int, - ): - """ - Returns the number of base pairs to extend - coordinates surrounding the central - exon, for cases where the upstream or downstream (`region`) - exon lie within the boundaries of the model resolution - (e.g, SpliceAI = 5000bp), so that the sequence to be run contains - only the triplet exon + `extend_borders` region. - It is the same procedure as when `use_full_sequence` - for cases where the upstream or downstream exon is less than - 5000bp away from the central exon. - - Args: - seq_record (pd.Series): A row from the input dataframe - region (Literal): Either 'upstream' or 'downstream' - _level (str): The level to be considered - extend_borders (int): The number of base pairs to extend the coordinates, - regardless of the region - - Returns: - int: The number of base pairs to extend the coordinates - """ - if region == "upstream": - - if seq_record.Strand == "+": - if ( - seq_record.Start - - seq_record["End_upstream{}".format(_level)] - >= 5000 - ): - - return 5000 - - else: - return ( - seq_record.Start - - seq_record["Start_upstream{}".format(_level)] - + extend_borders - ) - - else: - if ( - seq_record["Start_upstream{}".format(_level)] - - seq_record.End - >= 5000 - ): - - return 5000 - else: - return ( - seq_record["End_upstream{}".format(_level)] - - seq_record.End - + extend_borders - ) - - else: - if seq_record.Strand == "+": - if ( - seq_record["Start_downstream{}".format(_level)] - - seq_record.End - >= 5000 - ): - return 5000 - else: - return ( - seq_record["End_downstream{}".format(_level)] - - seq_record.End - + extend_borders - ) - else: - - if ( - seq_record.Start - - seq_record["End_downstream{}".format(_level)] - >= 5000 - ): - return 5000 - else: - return ( - seq_record.Start - - seq_record["Start_downstream{}".format(_level)] - + extend_borders - ) slack_upst = _get_slack( seq_record, region="upstream", _level=_level, extend_borders=extend_borders, + use_model_resolution=use_model_resolution, + model=model, ) slack_downst = _get_slack( @@ -667,6 +592,8 @@ def _get_slack( region="downstream", _level=_level, extend_borders=extend_borders, + use_model_resolution=use_model_resolution, + model=model, ) seq = get_fasta_sequences( @@ -702,7 +629,7 @@ def _get_slack( _level, start=min(left, right), end=max(left, right), - use_full_seqs=use_full_seqs, + use_full_tiplet=use_full_triplet, ) out.append( @@ -743,3 +670,85 @@ def _get_slack( ).drop_duplicates() return out, out_dpsi, _with_NAs + +def _get_slack( + seq_record: pd.Series, + region: Literal["upstream", "downstream"], + _level: str, + extend_borders: int, + use_model_resolution: bool, + model: str +): + """ + Returns the number of base pairs to extend + coordinates surrounding the central exon. + + If `use_model_resolution` is `True` returns the number of + base pairs corresponding to the model resolution + (e.g., 5000bp on each side for SpliceAI). If `False`, + returns the number of base pairs up to location of the + upstream or downstream exons, if their distance is lower + than the model resolution. If the distance is higher, + returns the distance to the model resolution. + + Args: + seq_record (pd.Series): A row from the input dataframe + region (Literal): Either 'upstream' or 'downstream' + _level (str): The level to be considered + extend_borders (int): The number of base pairs to extend the coordinates, + regardless of the region + use_model_resolution (bool): Whether to use the model resolution + model (str): The model to be used + + Returns: + int: The number of base pairs to extend the coordinates + """ + models = {"spliceai": 5000, "pangolin": 5000} + model_res = models[model] + if use_model_resolution: + return model_res + + if region == "upstream": + + if seq_record.Strand == "+": + if seq_record.Start - seq_record["End_upstream{}".format(_level)] >= model_res: + return model_res + + else: + return ( + seq_record.Start + - seq_record["Start_upstream{}".format(_level)] + + extend_borders + ) + + else: + if seq_record["Start_upstream{}".format(_level)] - seq_record.End >= model_res: + + return model_res + else: + return ( + seq_record["End_upstream{}".format(_level)] + - seq_record.End + + extend_borders + ) + + else: + if seq_record.Strand == "+": + if seq_record["Start_downstream{}".format(_level)] - seq_record.End >= model_res: + return model_res + else: + return ( + seq_record["End_downstream{}".format(_level)] + - seq_record.End + + extend_borders + ) + else: + + if seq_record.Start - seq_record["End_downstream{}".format(_level)] >= model_res: + return model_res + else: + return ( + seq_record.Start + - seq_record["Start_downstream{}".format(_level)] + + extend_borders + ) diff --git a/dress/datasetgeneration/run.py b/dress/datasetgeneration/run.py index 99ac47d..1590a8d 100755 --- a/dress/datasetgeneration/run.py +++ b/dress/datasetgeneration/run.py @@ -595,16 +595,30 @@ def grammar_options(fun): "it overrides all other non-mandatory arguments. Default: None. A working " "config file is presented in 'dress/configs/generate.yaml' file.", ) +@click.option( + "-umr", + "--use_model_resolution", + is_flag=True, + help="Whether to use the model resolution for extracting and predicting the full sequence. " + "If set, for example with '--model spliceai', the input sequence will include the " + "target exon plus 5000bp on either side. This ensures that both the acceptor and " + "donor positions are evaluated within the real genomic context, regardless of " + "whether surrounding exons fall within or outside the model resolution. " + "Default: 'False'. When not set (default), the surrounding context of the target exon is " + "dynamically extracted. If upstream/downstream exons are within the model resolution, " + "sequences are shorter and padded during inference. If they lie beyond the model " + "resolution, the sequence is trimmed to the model resolution. Cannot be used with '--use_full_triplet'.", +) @click.option( "-ufs", - "--use_full_sequence", + "--use_full_triplet", is_flag=True, - help="Whether to extract and predict the full sequence " - "from the start coordinate of the upstream exon to the end coordinate of the " - "downstream exon. Default: 'False': Use restricted sequence regions up to the resolution limit " - "of the selected model (e.g, for SpliceaAI, 5000bp on each side of the target exon). " - "Only used when input is 'bed' or 'tabular'. Use this option with caution, it can easily lead to " - "memory exhaustion if the sequence triplet size is too large.", + help=( + "Whether to extract and predict the full sequence from the start coordinate of the upstream " + "exon to the end coordinate of the downstream exon, regardless of the model resolution. " + "Default: 'False'. Use this option with caution, as it can easily lead to memory exhaustion " + "if the sequence triplet size is too large (e.g., large introns). Cannot be used with '--use_model_resolution'." + ), ) @click.option( "-dr", diff --git a/dress/datasetgeneration/validate_args.py b/dress/datasetgeneration/validate_args.py index 0887261..b181b50 100644 --- a/dress/datasetgeneration/validate_args.py +++ b/dress/datasetgeneration/validate_args.py @@ -15,6 +15,8 @@ "--cache_dir", "--genome", "--config", + "--use_model_resolution", + "--use_full_triplet" ], }, { @@ -89,7 +91,6 @@ { "name": "Other options", "options": [ - "--use_full_sequence", "--dry_run", "--disable_gpu", "--verbosity", diff --git a/tests/preprocessing_raw_test.py b/tests/preprocessing_raw_test.py index fe3f876..4e6ea24 100644 --- a/tests/preprocessing_raw_test.py +++ b/tests/preprocessing_raw_test.py @@ -11,41 +11,92 @@ ) -abs_path = os.path.dirname(os.path.abspath(__file__)) -raw_data = os.path.join(abs_path,"data/raw_data.tsv") -cache_dir = os.path.join(abs_path,"data") -genome_c = os.path.join(abs_path, "data/chr22.fa.gz") -level = 2 - -with gzip.open(genome_c, "rb") as f_in: - tmp_f = tempfile.NamedTemporaryFile() - genome = tmp_f.name - with open(genome, "wb") as f_out: - shutil.copyfileobj(f_in, f_out) - +# abs_path = os.path.dirname(os.path.abspath(__file__)) +# raw_data = os.path.join(abs_path,"data/raw_data.tsv") +# cache_dir = os.path.join(abs_path,"data") +# genome_c = os.path.join(abs_path, "data/chr22.fa.gz") +# level = 2 + +# with gzip.open(genome_c, "rb") as f_in: +# tmp_f = tempfile.NamedTemporaryFile() +# genome = tmp_f.name +# with open(genome, "wb") as f_out: +# shutil.copyfileobj(f_in, f_out) + +@pytest.fixture(scope="module") +def setup_paths(): + abs_path = os.path.dirname(os.path.abspath(__file__)) + raw_data = os.path.join(abs_path, "data/raw_data.tsv") + cache_dir = os.path.join(abs_path, "data") + genome_c = os.path.join(abs_path, "data/chr22.fa.gz") + + with gzip.open(genome_c, "rb") as f_in: + tmp_f = tempfile.NamedTemporaryFile(delete=False) + genome = tmp_f.name + with open(genome, "wb") as f_out: + shutil.copyfileobj(f_in, f_out) + + yield raw_data, cache_dir, genome + + os.remove(genome) + +@pytest.fixture +def initial_setup(setup_paths): + raw_data, cache_dir, genome = setup_paths + df = tabular_file_to_genomics_df( + raw_data, col_index=0, is_0_based=False, header=0 + ) + extracted, absent_in_gtf = extractGeneStructure( + df.as_df(), cache_dir=cache_dir, genome=genome, level=2 + ) + + assert extracted.shape[0] == 8 + assert absent_in_gtf.shape[0] == 0 + return extracted, absent_in_gtf + +@pytest.fixture +def common_tests(): + def assert_common_tests(seq_info, len_test, header_test, acceptors_test, donors_test): + exon = seq_info.exon + header = seq_info.header + acceptor_idx = seq_info.acceptor_idx + donor_idx = seq_info.donor_idx + seq = seq_info.seq + + # Assert that the header (after expanding exon coordinates) is correct + assert header_test[exon] == header + + # Assert that sequences have the correct length + assert len_test[exon] == len(seq) + + # Assert that sequences have the correct indexes for donors and acceptors + assert acceptors_test[exon] == acceptor_idx + assert donors_test[exon] == donor_idx + + # Assert that the acceptor and donor sites at the retrieved indexes are correct + for i, (acceptor, donor) in enumerate( + zip(acceptor_idx.split(";"), donor_idx.split(";")) + ): + if i == 1: + acceptor = int(acceptor) + donor = int(donor) + assert seq[acceptor - 2 : acceptor] == "AG" + assert seq[donor + 1 : donor + 3] == "GT" + + return assert_common_tests class TestRawPreprocessing: - @pytest.fixture - def initial_setup(self): - df = tabular_file_to_genomics_df( - raw_data, col_index=0, is_0_based=False, header=0 - ) - extracted, absent_in_gtf = extractGeneStructure( - df.as_df(), cache_dir=cache_dir, genome=genome, level=level - ) - - assert extracted.shape[0] == 8 - assert absent_in_gtf.shape[0] == 0 - return extracted, absent_in_gtf - - def test_full_sequence(self, initial_setup): - extracted = initial_setup[0] + + def test_full_sequence(self, setup_paths, initial_setup, common_tests): + extracted, _ = initial_setup + _, cache_dir, genome = setup_paths data, _, na_exons = generate_pipeline_input( df=extracted, fasta=open_fasta(genome, cache_dir), extend_borders=100, - use_full_seqs=True, + use_full_triplet=True, + use_model_resolution=False, ) assert na_exons.shape[0] == 0 @@ -96,41 +147,18 @@ def test_full_sequence(self, initial_setup): } for _, seq_info in data.iterrows(): - exon = seq_info.exon - header = seq_info.header - acceptor_idx = seq_info.acceptor_idx - donor_idx = seq_info.donor_idx - seq = seq_info.seq - - # Assert that the header (after expanding exon coordinates) is correct - assert header_test[exon] == header - - # Assert that sequences have the correct length - assert len_test[exon] == len(seq) + common_tests(seq_info, len_test, header_test, acceptors_test, donors_test) - # Asser that sequences have the correct indexes for donors and acceptors - assert acceptors_test[exon] == acceptor_idx - assert donors_test[exon] == donor_idx - - # Assert that the acceptor and donor sites at the retrieved indexes are correct - for i, (acceptor, donor) in enumerate( - zip(acceptor_idx.split(";"), donor_idx.split(";")) - ): - if i == 1: - acceptor = int(acceptor) - donor = int(donor) - - assert seq[acceptor - 2 : acceptor] == "AG" - assert seq[donor + 1 : donor + 3] == "GT" - - def test_trimmed_sequence(self, initial_setup): - extracted = initial_setup[0] + def test_trimmed_sequence(self, setup_paths, initial_setup, common_tests): + extracted, _ = initial_setup + _, cache_dir, genome = setup_paths data, _, _ = generate_pipeline_input( df=extracted, fasta=open_fasta(genome, cache_dir), extend_borders=100, - use_full_seqs=False, + use_full_triplet=False, + use_model_resolution=False, ) len_test = { @@ -178,29 +206,17 @@ def test_trimmed_sequence(self, initial_setup): } for _, seq_info in data.iterrows(): - exon = seq_info.exon + common_tests(seq_info, len_test, header_test, acceptors_test, donors_test) - header = seq_info.header + # Specific tests for trimmed (default) resolution + exon = seq_info.exon acceptor_idx = seq_info.acceptor_idx donor_idx = seq_info.donor_idx seq = seq_info.seq - # Assert that the header (after expanding exon coordinates) is correct - assert header_test[exon] == header - - # Assert that sequences have the correct length - assert len_test[exon] == len(seq) - - # Assert that sequences have the correct indexes for donors and acceptors - assert acceptors_test[exon] == acceptor_idx - assert donors_test[exon] == donor_idx - - # Assert that the acceptor and donor sites at the retrieved indexes are correct - for i, (acceptor, donor) in enumerate( - zip(acceptor_idx.split(";"), donor_idx.split(";")) - ): - # Case where just the upstream is trimmed - if exon == "chr22:31688156-31688227": + # Case where just the upstream is trimmed + if exon == "chr22:31688156-31688227": + for i, (acceptor, donor) in enumerate(zip(acceptor_idx.split(";"), donor_idx.split(";"))): if i == 0: assert acceptor == "" assert donor == "" @@ -214,9 +230,83 @@ def test_trimmed_sequence(self, initial_setup): assert seq[donor - 1 : donor + 5] == "AATAAC" assert len(seq[acceptor : donor + 1]) == 4457 - if i == 1: - acceptor = int(acceptor) - donor = int(donor) + def test_model_resolution(self, setup_paths, initial_setup, common_tests): + extracted, _ = initial_setup + _, cache_dir, genome = setup_paths - assert seq[acceptor - 2 : acceptor] == "AG" - assert seq[donor + 1 : donor + 3] == "GT" \ No newline at end of file + data, _, _ = generate_pipeline_input( + df=extracted, + fasta=open_fasta(genome, cache_dir), + extend_borders=100, + use_full_triplet=False, + use_model_resolution=True, + model="spliceai" + ) + + len_test = { + "chr22:29957036-29957088": 53 + 10000, + "chr22:29970976-29971062": 87 + 10000, + "chr22:36253838-36253991": 154 + 10000, + "chr22:17812507-17812557": 51 + 10000, + "chr22:31688156-31688227": 72 + 10000, + "chr22:42601970-42602107": 138 + 10000, + "chr22:43137081-43137298": 218 + 10000, + "chr22:50457034-50457111": 78 + 10000, + } + + header_test = { + "chr22:29957036-29957088": "chr22:29952036-29962088(+)", + "chr22:29970976-29971062": "chr22:29965976-29976062(+)", + "chr22:36253838-36253991": "chr22:36248838-36258991(+)", + "chr22:17812507-17812557": "chr22:17807507-17817557(-)", + "chr22:31688156-31688227": "chr22:31683156-31693227(-)", + "chr22:42601970-42602107": "chr22:42596970-42607107(-)", + "chr22:43137081-43137298": "chr22:43132081-43142298(-)", + "chr22:50457034-50457111": "chr22:50452034-50462111(-)", + } + + acceptors_test = { + "chr22:29957036-29957088": ";5000;", + "chr22:29970976-29971062": ";5000;", + "chr22:36253838-36253991": "4240;5000;6099", + "chr22:17812507-17812557": "773;5000;6744", + "chr22:31688156-31688227": ";5000;7424", + "chr22:42601970-42602107": "3947;5000;7314", + "chr22:43137081-43137298": "1049;5000;8812", + "chr22:50457034-50457111": "2719;5000;5438", + } + + donors_test = { + "chr22:29957036-29957088": ";5052;", + "chr22:29970976-29971062": ";5086;", + "chr22:36253838-36253991": "4381;5153;6161", + "chr22:17812507-17812557": "867;5050;6854", + "chr22:31688156-31688227": ";5071;", + "chr22:42601970-42602107": "4337;5137;7409", + "chr22:43137081-43137298": "1136;5217;10089", + "chr22:50457034-50457111": "2856;5077;5619", + } + + for _, seq_info in data.iterrows(): + common_tests(seq_info, len_test, header_test, acceptors_test, donors_test) + + exon = seq_info.exon + acceptor_idx = seq_info.acceptor_idx + donor_idx = seq_info.donor_idx + seq = seq_info.seq + + if exon == "chr22:31688156-31688227": + for i, (acceptor, donor) in enumerate( + zip(acceptor_idx.split(";"), donor_idx.split(";")) + ): + if i == 0: + assert acceptor == "" + assert donor == "" + + if i == 2: + acceptor = int(acceptor) + assert seq[acceptor - 2 : acceptor] == "AG" + assert seq[acceptor - 8 : acceptor + 4] == "ATCCTTAGGTGT" + + assert donor == "" + assert len(seq[acceptor:]) == 2648 \ No newline at end of file