Commit
Add curl and metadata filter option
anna-parker committed Dec 9, 2024
1 parent f1f9245 commit f77cd7a
Showing 3 changed files with 128 additions and 0 deletions.
46 changes: 46 additions & 0 deletions ingest/Snakefile
@@ -26,6 +26,7 @@ NCBI_API_KEY = os.getenv("NCBI_API_KEY")
APPROVE_TIMEOUT_MIN = config.get("approve_timeout_min") # time in minutes
CHECK_ENA_DEPOSITION = config.get("check_ena_deposition", False)
ALIGN = True
FILTER = config.get("filter", None)

dataset_server_map = {}
dataset_name_map = {}
@@ -319,6 +320,38 @@ rule group_segments:
            --log-level {params.log_level} \
        """

if FILTER:

    rule filter_fasta_headers:
        input:
            metadata=(
                "results/metadata_post_group.json"
                if SEGMENTED
                else "results/metadata_post_prepare.json"
            ),
            sequences=(
                "results/sequences_post_group.ndjson"
                if SEGMENTED
                else "results/sequences.ndjson"
            ),
            script="scripts/filter.py",
            config="results/config.yaml",
        output:
            sequences="results/sequences_filtered.fasta",
            metadata="results/metadata_filtered.json",
        params:
            log_level=LOG_LEVEL,
        shell:
            """
            python {input.script} \
                --input-seq {input.sequences} \
                --output-seq {output.sequences} \
                --input-metadata {input.metadata} \
                --output-metadata {output.metadata} \
                --log-level {params.log_level} \
                --config-file {input.config} \
            """


rule get_previous_submissions:
"""Download metadata and sequence hashes of all previously submitted sequences
Expand All @@ -340,6 +373,9 @@ rule get_previous_submissions:
# By delaying the start of the script
script="scripts/call_loculus.py",
prepped_metadata=(
"results/metadata_filtered.json"
if FILTER
else
"results/metadata_post_group.json"
if SEGMENTED
else "results/metadata_post_prepare.json"
@@ -367,6 +403,9 @@ rule compare_hashes:
        config="results/config.yaml",
        old_hashes="results/previous_submissions.json",
        metadata=(
            "results/metadata_filtered.json"
            if FILTER
            else "results/metadata_post_group.json"
            if SEGMENTED
            else "results/metadata_post_prepare.json"
@@ -403,11 +442,18 @@ rule prepare_files:
script="scripts/prepare_files.py",
config="results/config.yaml",
metadata=(
"results/metadata_filtered.json"
if FILTER
else
"results/metadata_post_group.json"
if SEGMENTED
else "results/metadata_post_prepare.json"
),
metadata
sequences=(
"results/sequences_filtered.fasta"
if FILTER
else
"results/sequences_post_group.ndjson"
if SEGMENTED
else "results/sequences.ndjson"
1 change: 1 addition & 0 deletions ingest/environment.yml
@@ -21,3 +21,4 @@ dependencies:
- snakemake
- tsv-utils
- unzip
- curl
81 changes: 81 additions & 0 deletions ingest/scripts/filter.py
@@ -0,0 +1,81 @@
"""filters sequences by metadata fields"""

import logging
from dataclasses import dataclass

import click
import pandas as pd
import yaml
from Bio import SeqIO


@dataclass
class FilterObjects:
    name: str
    value: str | int


@dataclass
class Config:
    filter: list[FilterObjects]
    nucleotide_sequences: list[str]
    segmented: bool


logger = logging.getLogger(__name__)
logging.basicConfig(
    encoding="utf-8",
    level=logging.DEBUG,
    format="%(asctime)s %(levelname)8s (%(filename)20s:%(lineno)4d) - %(message)s ",
    datefmt="%H:%M:%S",
)


@click.command(help="Filter metadata by the configured field values and keep only matching sequences")
@click.option("--config-file", required=True, type=click.Path(exists=True))
@click.option("--input-seq", required=True, type=click.Path(exists=True))
@click.option("--output-seq", required=True, type=click.Path())
@click.option("--input-metadata", required=True, type=click.Path(exists=True))
@click.option("--output-metadata", required=True, type=click.Path())
@click.option(
    "--log-level",
    default="INFO",
    type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]),
)
def main(
    config_file: str,
    input_seq: str,
    output_seq: str,
    input_metadata: str,
    output_metadata: str,
    log_level: str,
) -> None:
    logger.setLevel(log_level)
    with open(config_file, encoding="utf-8") as file:
        full_config = yaml.safe_load(file)
    relevant_config = {key: full_config.get(key, []) for key in Config.__annotations__}
    config = Config(**relevant_config)
    config.filter = [FilterObjects(**f) for f in config.filter]
    df = pd.read_csv(input_metadata, sep="\t", dtype=str, keep_default_na=False)
    for f in config.filter:
        df = df[df[f.name].str.contains(str(f.value))]
    # Use a set for O(1) membership checks when scanning the fasta records.
    submission_ids = set(df["submissionId"])
    df.to_csv(output_metadata, sep="\t", index=False)
    if not config.segmented:
        with (
            open(input_seq, encoding="utf-8") as f_in,
            open(output_seq, "w", encoding="utf-8") as f_out,
        ):
            records = SeqIO.parse(f_in, "fasta")
            for record in records:
                if record.id in submission_ids:
                    SeqIO.write(record, f_out, "fasta")
        return
    with (
        open(input_seq, encoding="utf-8") as f_in,
        open(output_seq, "w", encoding="utf-8") as f_out,
    ):
        records = SeqIO.parse(f_in, "fasta")
        for record in records:
            # Segmented fasta ids end in "_<segment>"; strip the suffix to
            # recover the submissionId before the membership test.
            if record.id.rsplit("_", 1)[0] in submission_ids:
                SeqIO.write(record, f_out, "fasta")


if __name__ == "__main__":
    main()
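A quick self-contained sanity check of the filtering logic — the submission ids and the "host" column are made up, and the "_<segment>" id suffix convention is inferred from the handling above:

import pandas as pd

metadata = pd.DataFrame(
    {"submissionId": ["s1", "s2"], "host": ["Homo sapiens", "Mus musculus"]}
)
kept = set(metadata[metadata["host"].str.contains("Homo sapiens")]["submissionId"])
assert kept == {"s1"}
# Segmented fasta ids carry a trailing segment name, e.g. "s1_L";
# stripping the final "_" component recovers the submissionId.
assert "s1_L".rsplit("_", 1)[0] in kept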
