diff --git a/dress/configs/generate_binfiller.yaml b/dress/configs/generate_binfiller.yaml index 18c344c..bb7e9a7 100644 --- a/dress/configs/generate_binfiller.yaml +++ b/dress/configs/generate_binfiller.yaml @@ -15,8 +15,8 @@ generate: preprocessing: cache_dir: data/cache/ genome: data/cache/Homo_sapiens.GRCh38.dna.primary_assembly.fa - use_model_resolution: false - use_full_triplet: false + extract_dynamically: false + extract_full_triplet: false fitness: minimize_fitness: false fitness_function: bin_filler diff --git a/dress/configs/generate_binfiller_pwm_grammar.yaml b/dress/configs/generate_binfiller_pwm_grammar.yaml index 8cb667c..fbc7b5a 100644 --- a/dress/configs/generate_binfiller_pwm_grammar.yaml +++ b/dress/configs/generate_binfiller_pwm_grammar.yaml @@ -15,8 +15,8 @@ generate: preprocessing: cache_dir: data/cache/ genome: data/cache/Homo_sapiens.GRCh38.dna.primary_assembly.fa - use_model_resolution: false - use_full_triplet: false + extract_dynamically: false + extract_full_triplet: false fitness: minimize_fitness: false fitness_function: bin_filler diff --git a/dress/configs/generate_iad.yaml b/dress/configs/generate_iad.yaml index 8506c55..183a23a 100644 --- a/dress/configs/generate_iad.yaml +++ b/dress/configs/generate_iad.yaml @@ -15,8 +15,8 @@ generate: preprocessing: cache_dir: data/cache/ genome: data/cache/Homo_sapiens.GRCh38.dna.primary_assembly.fa - use_model_resolution: false - use_full_triplet: false + extract_dynamically: false + extract_full_triplet: false fitness: minimize_fitness: false fitness_function: increase_archive_diversity diff --git a/dress/configs/generate_iad_pwm_grammar.yaml b/dress/configs/generate_iad_pwm_grammar.yaml index e771a3c..b3752a0 100644 --- a/dress/configs/generate_iad_pwm_grammar.yaml +++ b/dress/configs/generate_iad_pwm_grammar.yaml @@ -15,8 +15,8 @@ generate: preprocessing: cache_dir: data/cache/ genome: data/cache/Homo_sapiens.GRCh38.dna.primary_assembly.fa - use_model_resolution: false - use_full_triplet: 
false + extract_dynamically: false + extract_full_triplet: false fitness: minimize_fitness: false fitness_function: increase_archive_diversity diff --git a/dress/datasetgeneration/json_schema.py b/dress/datasetgeneration/json_schema.py index 87f85ab..3280900 100644 --- a/dress/datasetgeneration/json_schema.py +++ b/dress/datasetgeneration/json_schema.py @@ -7,8 +7,8 @@ "properties": { "dry_run": {"type": "boolean"}, "disable_gpu": {"type": "boolean"}, - "use_full_triplet": {"type": "boolean"}, - "use_model_resolution": {"type": "boolean"}, + "extract_full_triplet": {"type": "boolean"}, + "extract_dynamically": {"type": "boolean"}, "verbosity": {"type": "integer"}, "shuffle_input": {"type": ["null", "string"]}, "outdir": {"type": "string"}, diff --git a/dress/datasetgeneration/preprocessing/gtf_cache.py b/dress/datasetgeneration/preprocessing/gtf_cache.py index c466a48..83e35f9 100644 --- a/dress/datasetgeneration/preprocessing/gtf_cache.py +++ b/dress/datasetgeneration/preprocessing/gtf_cache.py @@ -121,8 +121,8 @@ def preprocessing(data: pr.PyRanges, **kwargs): df=extracted, fasta=genome, extend_borders=100, - use_full_triplet=kwargs["use_full_triplet"], - use_model_resolution=kwargs["use_model_resolution"], + extract_full_triplet=kwargs["extract_full_triplet"], + extract_dynamically=kwargs["extract_dynamically"], model = kwargs["model"] ) @@ -158,10 +158,10 @@ def write_output( Additional arguments in **kwargs: outdir (str): Output directory. outbasename (str): Output basename. - use_full_triplet (bool): Whether to use the full exon triplet as input sequence + extract_full_triplet (bool): Whether to use the full exon triplet as input sequence when making model inferences. - use_model_resolution (bool): Whether to use the model resolution to determine input - sequence size when making model inferences. + extract_dynamically (bool): Whether to extract input sequence dynamically based on + triplet length and model resolution. 
""" to_write = { @@ -200,10 +200,10 @@ def write_output( ) out_flag = '' - if kwargs['use_full_triplet']: + if kwargs['extract_full_triplet']: out_flag = "_full_triplet" - elif kwargs['use_model_resolution']: - out_flag = "_model_res" + elif kwargs['extract_dynamically']: + out_flag = "_dynamically" if len(extracted_with_seqs) > 0: extracted_with_seqs[ diff --git a/dress/datasetgeneration/preprocessing/utils.py b/dress/datasetgeneration/preprocessing/utils.py index a972aa6..93589a5 100644 --- a/dress/datasetgeneration/preprocessing/utils.py +++ b/dress/datasetgeneration/preprocessing/utils.py @@ -79,7 +79,7 @@ def process_ss_idx( def _get_flat_ss( - info: pd.Series, _level: str, start: int, end: int, use_full_tiplet: bool + info: pd.Series, _level: str, start: int, end: int, extract_full_triplet: bool ): """ Extracts flat splice site indexes for sequences @@ -96,7 +96,7 @@ def _get_flat_ss( extensions. end (int): End coordinate of the flat sequence, after accounting for the extensions. - use_full_tiplet (bool, optional): Whether `start` and `end` + extract_full_triplet (bool, optional): Whether `start` and `end` coordinates represent the true start and end of the sequence at a given surrounding level. If `False`, they represent start and end positions up to the @@ -109,7 +109,7 @@ def _get_flat_ss( _info = info.copy() # Check out of scope indexes - if use_full_tiplet is False: + if extract_full_triplet is False: if _level != 0: cols = [ "Start_upstream" + _level, @@ -443,8 +443,8 @@ def generate_pipeline_input( df: pd.DataFrame, fasta: str, extend_borders: int = 0, - use_full_triplet: bool = False, - use_model_resolution: bool = True, + extract_full_triplet: bool = False, + extract_dynamically: bool = False, model: str = "spliceai", ): """ @@ -457,22 +457,23 @@ def generate_pipeline_input( associated transcript ID. 
Sequence extraction can be done in three ways: - - `use_full_triplet` is `True`: The complete sequence + - `extract_full_triplet` is `True`: The complete sequence from the start of the exon upstream until the end of the exon downstream is extracted, regardless of the size of the resulting sequence. This can result in very large sequences if introns are very long. - - `use_model_resolution` is `True`: The sequence is - extracted up to the maximum resolution of the model. For example, - for SpliceAI, the extracted sequence will be the cassette exon plus - 5000bp upstream of the acceptor and 5000bp downstream of the donor. - - `use_model_resolution` is `False` and `use_full_triplet` is `False`: - If neither of the above is set (default), the sequence will be extracted + - `extract_dynamically` is `True`: The sequence will be extracted such that: if the full exon triplet is smaller than the model resolution, the full exon triplet is extracted, and then at inference time the sequence is padded to the model resolution. If the exon triplet is larger than the model resolution, the sequence is trimmed at the model resolution, which may not include upstream and/or downstream exons. + + - `extract_dynamically` is `False` and `extract_full_triplet` is `False`: + If neither of the above is set (default), the sequence is + extracted up to the maximum resolution of the model. For example, + for SpliceAI, the extracted sequence will be the cassette exon plus + 5000bp upstream of the acceptor and 5000bp downstream of the donor. If `extend_borders` > 0, the sequence will be extended on both sides. This is useful for cases where the acceptor of the uptream exon @@ -481,9 +482,8 @@ def generate_pipeline_input( may not be properly captured (because the full context is not present), extending the sequence can help to provide a more realistic prediction as if those positions were in the middle of the sequence. 
This extension - is applied when `use_full_triplet` is `True` or in the default setting - (`use_model_resolution=False` and `use_full_triplet=False`) when the - sequence is shorter than the model resolution. + is applied when `extract_full_triplet` or `extract_dynamically` is `True`, + the latter only if the upstream/downstream exon is within the model resolution. Returns: Tuples with the following information: @@ -493,8 +493,8 @@ def generate_pipeline_input( """ assert any( - x is False for x in [use_full_triplet, use_model_resolution] - ), "Can't set both `use_full_triplet` and `use_model_resolution` to `True`." + x is False for x in [extract_full_triplet, extract_dynamically] + ), "Can't set both `extract_full_triplet` and `extract_dynamically` to `True`." if list(df.filter(regex="stream")): # Extract level so that we know the @@ -524,7 +524,7 @@ def generate_pipeline_input( for _, seq_record in df.iterrows(): - if use_full_triplet: + if extract_full_triplet: if seq_record.Strand == "+": start = "Start_upstream" + _level if level != 0 else "Start" end = "End_downstream" + _level if level != 0 else "End" @@ -561,7 +561,7 @@ def generate_pipeline_input( _level, start=seq_record[start] - extend_borders, end=seq_record[end] + extend_borders, - use_full_tiplet=use_full_triplet, + extract_full_triplet=extract_full_triplet, ) out.append( @@ -583,7 +583,7 @@ def generate_pipeline_input( region="upstream", _level=_level, extend_borders=extend_borders, - use_model_resolution=use_model_resolution, + no_model_resolution=extract_dynamically, model=model, ) @@ -592,7 +592,7 @@ def generate_pipeline_input( region="downstream", _level=_level, extend_borders=extend_borders, - use_model_resolution=use_model_resolution, + no_model_resolution=extract_dynamically, model=model, ) @@ -629,7 +629,7 @@ def generate_pipeline_input( _level, start=min(left, right), end=max(left, right), - use_full_tiplet=use_full_triplet, + extract_full_triplet=extract_full_triplet, ) out.append( @@ -676,20
+676,20 @@ def _get_slack( region: Literal["upstream", "downstream"], _level: str, extend_borders: int, - use_model_resolution: bool, + no_model_resolution: bool, model: str ): """ Returns the number of base pairs to extend coordinates surrounding the central exon. - If `use_model_resolution` is `True` returns the number of - base pairs corresponding to the model resolution - (e.g., 5000bp on each side for SpliceAI). If `False`, - returns the number of base pairs up to location of the - upstream or downstream exons, if their distance is lower + If `no_model_resolution` is `True` returns the number + of base pairs up to location of the upstream or + downstream exons, if their distance is lower than the model resolution. If the distance is higher, - returns the distance to the model resolution. + returns the distance to the model resolution. If `False`, + returns the number of base pairs corresponding to the model + resolution (e.g., 5000bp on each side for SpliceAI). Args: seq_record (pd.Series): A row from the input dataframe @@ -697,7 +697,8 @@ def _get_slack( _level (str): The level to be considered extend_borders (int): The number of base pairs to extend the coordinates, regardless of the region - use_model_resolution (bool): Whether to use the model resolution + no_model_resolution (bool): Whether to avoid using the model resolution + to extract distances model (str): The model to be used Returns: @@ -705,7 +706,7 @@ def _get_slack( """ models = {"spliceai": 5000, "pangolin": 5000} model_res = models[model] - if use_model_resolution: + if no_model_resolution is False: return model_res if region == "upstream": diff --git a/dress/datasetgeneration/run.py b/dress/datasetgeneration/run.py index 3168b9d..75373a6 100755 --- a/dress/datasetgeneration/run.py +++ b/dress/datasetgeneration/run.py @@ -599,28 +599,26 @@ def grammar_options(fun): "config file is presented in 'dress/configs/generate.yaml' file.", ) @click.option( - "-umr", - "--use_model_resolution", + 
"-esd", + "--extract_dynamically", is_flag=True, - help="Whether to use the model resolution for extracting and predicting the full sequence. " - "If set, for example with '--model spliceai', the input sequence will include the " - "target exon plus 5000bp on either side. This ensures that both the acceptor and " - "donor positions are evaluated within the real genomic context, regardless of " - "whether surrounding exons fall within or outside the model resolution. " - "Default: 'False'. When not set (default), the surrounding context of the target exon is " - "dynamically extracted. If upstream/downstream exons are within the model resolution, " - "sequences are shorter and padded during inference. If they lie beyond the model " - "resolution, the sequence is trimmed to the model resolution. Cannot be used with '--use_full_triplet'.", + help="Extract sequences dynamically based on the exon triplet size and the model resolution. " + "If upstream/downstream exons are within the model resolution, sequences will be shorter and padded " + "during inference. If they lie beyond the model resolution, sequences are trimmed to the model resolution. " + "Default: use the model resolution for extracting surrounding context and making inferences. For example, " + "with '--model spliceai', the default setting will include the target exon plus 5000bp on either side. This ensures that both the acceptor and donor positions are evaluated within the real genomic context, regardless " + "of whether surrounding exons fall within or outside the 5000bp." ) @click.option( - "-ufs", - "--use_full_triplet", + "-efs", + "--extract_full_triplet", is_flag=True, help=( - "Whether to extract and predict the full sequence from the start coordinate of the upstream " + "Extract and predict the full sequences from the start coordinate of the upstream " "exon to the end coordinate of the downstream exon, regardless of the model resolution. " "Default: 'False'. 
Use this option with caution, as it can easily lead to memory exhaustion " - "if the sequence triplet size is too large (e.g., large introns). Cannot be used with '--use_model_resolution'." + "if the sequence triplet size is too large (e.g., large introns). Default: use the model resolution, " + "as described in '--extract_dynamically'." ), ) @click.option( diff --git a/dress/datasetgeneration/validate_args.py b/dress/datasetgeneration/validate_args.py index b181b50..6df7b80 100644 --- a/dress/datasetgeneration/validate_args.py +++ b/dress/datasetgeneration/validate_args.py @@ -15,8 +15,8 @@ "--cache_dir", "--genome", "--config", - "--use_model_resolution", - "--use_full_triplet" + "--extract_dynamically", + "--extract_full_triplet" ], }, { diff --git a/tests/preprocessing_raw_test.py b/tests/preprocessing_raw_test.py index 2812bf8..93c8130 100644 --- a/tests/preprocessing_raw_test.py +++ b/tests/preprocessing_raw_test.py @@ -86,8 +86,8 @@ def test_full_sequence(self, setup_paths, initial_setup, common_tests): df=extracted, fasta=open_fasta(genome, cache_dir), extend_borders=100, - use_full_triplet=True, - use_model_resolution=False, + extract_full_triplet=True, + extract_dynamically=False, ) assert na_exons.shape[0] == 0 @@ -148,8 +148,8 @@ def test_trimmed_sequence(self, setup_paths, initial_setup, common_tests): df=extracted, fasta=open_fasta(genome, cache_dir), extend_borders=100, - use_full_triplet=False, - use_model_resolution=False, + extract_full_triplet=False, + extract_dynamically=True, ) len_test = { @@ -231,8 +231,8 @@ def test_model_resolution(self, setup_paths, initial_setup, common_tests): df=extracted, fasta=open_fasta(genome, cache_dir), extend_borders=100, - use_full_triplet=False, - use_model_resolution=True, + extract_full_triplet=False, + extract_dynamically=False, model="spliceai", )