From de9bf41b00125aeb0944dc1cc18f7e13746e0df1 Mon Sep 17 00:00:00 2001
From: Patricia Arakawa
Date: Sat, 11 Jan 2025 16:55:04 +0100
Subject: [PATCH] add pipeline

---
 .gitignore                 |   7 ++
 README.md                  |   0
 main.py                    |  35 +++++++
 requirements.txt           |  17 ++++
 src/__init__.py            |   0
 src/data_processing.py     | 205 +++++++++++++++++++++++++++++++++++++
 src/encoding.py            |  49 +++++++++
 src/kmer_signatures.py     |  46 +++++++++
 src/model_training.py      |  70 +++++++++++++
 src/sequence_operations.py |  42 ++++++++
 src/utils.py               |  15 +++
 11 files changed, 486 insertions(+)
 create mode 100644 .gitignore
 create mode 100644 README.md
 create mode 100644 main.py
 create mode 100644 requirements.txt
 create mode 100644 src/__init__.py
 create mode 100644 src/data_processing.py
 create mode 100644 src/encoding.py
 create mode 100644 src/kmer_signatures.py
 create mode 100644 src/model_training.py
 create mode 100644 src/sequence_operations.py
 create mode 100644 src/utils.py

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..6fa2352
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,7 @@
+**/.DS_Store
+**/__pycache__/
+all_recommended_organisms.pkl
+df_organisms_selection.pkl
+models
+test_genes
+grid_search_results.csv
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..e69de29
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..9c8632b
--- /dev/null
+++ b/main.py
@@ -0,0 +1,35 @@
+import os
+import pandas as pd
+from src.model_training import evaluate_and_save_model
+from src.utils import load_dataframe
+
+print("loading data...")
+# load data
+df = load_dataframe("df_organisms_selection.pkl")
+
+chunk_sizes = [500, 1000, 1500, 2000, 2500, 3000, 3500, 4000]
+ks = [2, 3, 4, 5, 6, 7, 8]
+rf_params = {"n_estimators": 100, "max_depth": None, "random_state": 42}
+
+# create directories for saving outputs
+os.makedirs("models", exist_ok=True)
+os.makedirs("test_genes", exist_ok=True)
+
+# train models for all parameter combinations
+results = []
+for chunk_size in chunk_sizes:
+    for k in ks:
+        accuracy = evaluate_and_save_model(
+            df,
+            chunk_size,
+            k,
+            rf_params,
+            overlap=0,
+            model_dir="models",
+            test_gene_dir="test_genes"
+        )
+        results.append({"chunk_size": chunk_size, "k": k, "accuracy": accuracy})
+
+# save grid search results
+results_df = pd.DataFrame(results)
+results_df.to_csv("grid_search_results.csv", index=False)
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..2936b26
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,17 @@
+# versions are pinned only where the original conda build recorded one;
+# the machine-local "@ file:///..." references were not installable elsewhere
+biopython
+colorama
+joblib
+numpy==2.2.1
+pandas
+python-dateutil
+pytz
+scikit-learn==1.6.1
+scipy==1.15.0
+setuptools==75.8.0
+six
+threadpoolctl
+tqdm
+tzdata
+wheel==0.45.1
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/data_processing.py b/src/data_processing.py
new file mode 100644
index 0000000..3bc6432
--- /dev/null
+++ b/src/data_processing.py
@@ -0,0 +1,205 @@
+from Bio import Entrez
+import pandas as pd
+from utils import save_dataframe, load_dataframe
+import random
+
+# NOTE(review): NCBI asks callers to identify themselves; set Entrez.email
+# (and optionally Entrez.api_key) before running this script
+
+# list of organisms (a few names are listed twice; the download loop skips repeats)
+organisms = [
+    "Aeropyrum pernix K1",
+    "Archaeoglobus fulgidus DSM 4304",
+    "Archaeoglobus profundus DSM 5631",
+    "Caldivirga maquilingensis IC-167",
+    "Candidatus Korarchaeum cryptofilum OPF8",
+    "Candidatus Methanoregula boonei 6A8",
+    "Candidatus Methanosphaerula palustris E1-9c",
+    "Cenarchaeum symbiosum A",
+    "Desulfurococcus kamchatkensis 1221n",
+    "Haloarcula marismortui ATCC 43049",
+    "Haloarcula marismortui ATCC 43049",
+    "Halobacterium sp. NRC-1",
+    "Halomicrobium mukohataei DSM 12286",
+    "Haloquadratum walsbyi DSM 16790",
+    "Halorhabdus utahensis DSM 12940",
+    "Halorubrum lacusprofundi ATCC 49239",
+    "Halorubrum lacusprofundi ATCC 49239",
+    "Haloterrigena turkmenica DSM 5511",
+    "Hyperthermus butylicus DSM 5456",
+    "Ignicoccus hospitalis KIN4/I",
+    "Metallosphaera sedula DSM 5348",
+    "Methanobrevibacter ruminantium M1",
+    "Methanocaldococcus fervens AG86",
+    "Methanocella paludicola SANAE",
+    "Methanococcoides burtonii DSM 6242",
+    "Methanococcus aeolicus Nankai-3",
+    "Methanococcus maripaludis C6",
+    "Methanococcus vannielii SB",
+    "Methanocorpusculum labreanum Z",
+    "Methanoculleus marisnigri JR1",
+    "Methanopyrus kandleri AV19",
+    "Methanosaeta thermophila PT",
+    "Methanosarcina acetivorans C2A",
+    "Methanosarcina barkeri str. Fusaro",
+    "Methanosarcina mazei Go1",
+    "Methanosphaera stadtmanae DSM 3091",
+    "Methanospirillum hungatei JF-1",
+    "Nanoarchaeum equitans Kin4-M",
+    "Natronomonas pharaonis DSM 2160",
+    "Nitrosopumilus maritimus SCM1",
+    "Picrophilus torridus DSM 9790",
+    "Pyrobaculum aerophilum str. IM2",
+    "Pyrobaculum arsenaticum DSM 13514",
+    "Pyrococcus abyssi GE5",
+    "Pyrococcus furiosus DSM 3638",
+    "Pyrococcus horikoshii OT3",
+    "Staphylothermus marinus F1",
+    "Sulfolobus acidocaldarius DSM 639",
+    "Sulfolobus solfataricus P2",
+    "Thermococcus gammatolerans EJ3",
+    "Thermofilum pendens Hrk 5",
+    "Thermoplasma acidophilum DSM 1728",
+    "Thermoplasma volcanium GSS1",
+    "Thermoproteus neutrophilus V24Sta",
+    "Acholeplasma laidlawii PG-8A",
+    "Acidobacterium capsulatum ATCC 51196",
+    "Akkermansia muciniphila ATCC BAA-835",
+    "Alicyclobacillus acidocaldarius subsp. acidocaldarius DSM 446",
+    "Aquifex aeolicus VF5",
+    "Bacillus cereus Q1",
+    "Bacillus pseudofirmus OF4",
+    "Bacteroides fragilis YCH46",
+    "Bdellovibrio bacteriovorus HD100",
+    "Bordetella pertussis Tohama I",
+    "Borrelia burgdorferi B31",
+    "Campylobacter jejuni subsp. jejuni 81-176",
+    "Candidatus Amoebophilus asiaticus 5a2",
+    "Candidatus Cloacamonas acidaminovorans",
+    "Candidatus Endomicrobium sp. Rs-D17",
+    "Carboxydothermus hydrogenoformans Z-2901",
+    "Chlamydia trachomatis 434/Bu",
+    "Chlorobium chlorochromatii CaD3",
+    "Chloroflexus aurantiacus J-10-fl",
+    "Clostridium acetobutylicum ATCC 824",
+    "Corynebacterium glutamicum ATCC 13032",
+    "Coxiella burnetii RSA 493",
+    "Cupriavidus taiwanensis",
+    "Cupriavidus taiwanensis",
+    "Cyanothece sp. ATCC 51142",
+    "Cyanothece sp. ATCC 51142",
+    "Dehalococcoides ethenogenes 195",
+    "Deinococcus radiodurans R1",
+    "Deinococcus radiodurans R1",
+    "Dictyoglomus thermophilum H-6-12",
+    "Elusimicrobium minutum Pei191",
+    "Fibrobacter succinogenes subsp. succinogenes S85",
+    "Flavobacterium psychrophilum JIP02/86",
+    "Fusobacterium nucleatum subsp. nucleatum ATCC 25586",
+    "Gemmata obscuriglobus UQM 2246",
+    "Gemmatimonas aurantiaca T-27",
+    "Gloeobacter violaceus PCC 7421",
+    "Leptospira interrogans serovar Lai str. 56601",
+    "Leptospira interrogans serovar Lai str. 56601",
+    "Magnetococcus sp. MC-1",
+    "Methylacidiphilum infernorum V4",
+    "Mycoplasma genitalium G37",
+    "Nostoc punctiforme PCC 73102",
+    "Opitutus terrae PB90-1",
+    "Pedobacter heparinus DSM 2366",
+    "Pirellula staleyi DSM 6068",
+    "Prochlorococcus marinus str. AS9601",
+    "Psychrobacter arcticus 273-4",
+    "Rhizobium leguminosarum bv. trifolii WSM1325",
+    "Rhodopirellula baltica SH 1",
+    "Rhodospirillum rubrum ATCC 11170",
+    "Rickettsia rickettsii str. Iowa",
+    "Shewanella putrefaciens CN-32",
+    "Solibacter usitatus Ellin6076",
+    "Synechococcus elongatus PCC 6301",
+    "Thermanaerovibrio acidaminovorans DSM 6589",
+    "Thermoanaerobacter tengcongensis MB4",
+    "Thermobaculum terrenum ATCC BAA-798",
+    "Thermobaculum terrenum ATCC BAA-798",
+    "Thermodesulfovibrio yellowstonii DSM 11347",
+    "Thermomicrobium roseum DSM 5159",
+    "Thermotoga maritima MSB8"
+]
+
+# list to store results
+genomes_data = []
+
+# function to download genomes
+def download_genome(organism_name):
+    """
+    fetch the first nucleotide record matching organism_name from ncbi
+
+    returns (organism_name, sequence), where sequence is the fasta body without
+    header and line breaks, or None if nothing was found or any step failed
+    """
+    try:
+        # search for genome using the organism name
+        search_handle = Entrez.esearch(db="nucleotide", term=organism_name, retmax=1)
+        try:
+            search_results = Entrez.read(search_handle)
+        finally:
+            # close the handle even if parsing the response fails
+            search_handle.close()
+
+        # get the id of the first result
+        id_list = search_results["IdList"]
+        if not id_list:
+            print(f"no genome found for {organism_name}")
+            return None
+
+        # fetch the genome in fasta format
+        genome_id = id_list[0]
+        fetch_handle = Entrez.efetch(db="nucleotide", id=genome_id, rettype="fasta", retmode="text")
+        try:
+            genome_data = fetch_handle.read()
+        finally:
+            fetch_handle.close()
+
+        # clean fasta header
+        genome_sequence = genome_data.split("\n", 1)[1].replace("\n", "")  # remove header and line breaks
+
+        # return organism name and genome sequence
+        return organism_name, genome_sequence
+    except Exception as e:
+        # best effort: a single failed download must not abort the whole run
+        print(f"error fetching genome for {organism_name}: {e}")
+        return None
+
+# download genomes for each organism
+# dict.fromkeys drops repeated names (keeping order) so no genome is fetched and stored twice
+for organism in dict.fromkeys(organisms):
+    print(f"download organism {organism}")
+    result = download_genome(organism)
+    if result:
+        genomes_data.append(result)
+
+# create a dataframe with the results
+all_recommended_organisms = pd.DataFrame(genomes_data, columns=["organism name", "genome sequence"])
+
+# save the original dataframe
+save_dataframe(all_recommended_organisms, "all_recommended_organisms.pkl")
+
+# function to filter organisms with invalid sequences
+def delete_bad_organisms(df):
+    """
+    keep only the rows whose "genome sequence" is a non-empty string of A, C, G, T
+    returns a new dataframe with a fresh 0..n-1 index
+    """
+    mask = df["genome sequence"].str.fullmatch(r"[ACGT]+")  # "+" also rejects empty sequences
+    filtered_df = df[mask].reset_index(drop=True)
+    return filtered_df
+
+# filter the organisms
+all_organisms_filtered = delete_bad_organisms(all_recommended_organisms)
+
+# select 10 organisms randomly (fewer if less than 10 passed the filter)
+n_selected = min(10, len(all_organisms_filtered))
+df_organisms_selection = all_organisms_filtered.sample(n=n_selected, random_state=42)
+
+# save the selected organisms under the name main.py loads
+save_dataframe(df_organisms_selection, "df_organisms_selection.pkl")
diff --git a/src/encoding.py b/src/encoding.py
new file mode 100644
index 0000000..821de25
--- /dev/null
+++ b/src/encoding.py
@@ -0,0 +1,49 @@
+from itertools import product
+
+# two-bit code per base; chosen so that complementing a base is an xor with 0b11
+_BASE_TO_BITS = {'A': 0b00, 'C': 0b01, 'G': 0b10, 'T': 0b11}
+
+def encode_kmer(kmer):
+    """
+    convert a k-mer to its integer representation (two bits per base, first base highest)
+    raises KeyError for any character other than A, C, G, T
+    """
+    result = 0
+    for base in kmer:
+        result = (result << 2) | _BASE_TO_BITS[base]
+    return result
+
+def decode_kmer(encoded_kmer, k):
+    """convert an encoded k-mer of length k back to its string representation"""
+    decoding = ['A', 'C', 'G', 'T']
+    kmer = []
+    for _ in range(k):
+        base = encoded_kmer & 0b11  # take the last 2 bits
+        kmer.append(decoding[base])
+        encoded_kmer >>= 2
+    return ''.join(reversed(kmer))
+
+def reverse_complement_bits(kmer, k):
+    """compute the reverse complement of a k-mer encoded as an integer"""
+    mask = (1 << (2 * k)) - 1  # mask to keep only 2k bits
+    complement = 0
+    for _ in range(k):
+        # the lowest base becomes the next-highest one, complemented via xor
+        complement = (complement << 2) | ((kmer & 0b11) ^ 0b11)
+        kmer >>= 2
+    return complement & mask  # ensure there are no extra bits
+
+def canonical_kmer_bits(kmer, k):
+    """return the canonical k-mer (the smaller of k-mer and reverse complement) as an integer"""
+    rev_comp = reverse_complement_bits(kmer, k)
+    return min(kmer, rev_comp)
+
+def generate_all_canonical_kmers_bits(k):
+    """generate all canonical k-mers of length k as a sorted list of integers"""
+    canonical_kmers = set()
+    for kmer_tuple in product([0b00, 0b01, 0b10, 0b11], repeat=k):
+        kmer = 0
+        for base in kmer_tuple:
+            kmer = (kmer << 2) | base
+        canonical_kmers.add(canonical_kmer_bits(kmer, k))
+    return sorted(canonical_kmers)
diff --git a/src/kmer_signatures.py b/src/kmer_signatures.py
new file mode 100644
index 0000000..0622164
--- /dev/null
+++ b/src/kmer_signatures.py
@@ -0,0 +1,46 @@
+from collections import defaultdict
+from functools import lru_cache
+from tqdm import tqdm
+from .encoding import encode_kmer, canonical_kmer_bits, generate_all_canonical_kmers_bits
+
+@lru_cache(maxsize=None)
+def _possible_canonical_kmers(k):
+    """
+    return all canonical k-mers of length k as a sorted tuple
+    cached per k so the 4**k k-mers are not re-enumerated for every sequence
+    """
+    return tuple(generate_all_canonical_kmers_bits(k))
+
+def compute_kmer_signature_bits(sequence, k=2):
+    """
+    compute the k-mer signature vector for a sequence
+    uses bitwise encoding to handle k-mers efficiently
+
+    returns a list with one count per canonical k-mer, in ascending integer order
+    raises ValueError if the sequence is shorter than k
+    """
+    if len(sequence) < k:
+        raise ValueError("Sequence length must be at least k.")
+
+    # dictionary to store frequencies (default value is 0)
+    kmer_counts = defaultdict(int)
+
+    # encode the entire sequence into bits
+    encoded_sequence = encode_kmer(sequence)
+
+    # mask to extract k-mers
+    mask = (1 << (2 * k)) - 1  # mask for the last 2k bits
+    for i in range(len(sequence) - k + 1):
+        # shift the i-th window down to the lowest 2k bits
+        kmer = (encoded_sequence >> (2 * (len(sequence) - k - i))) & mask
+        canonical = canonical_kmer_bits(kmer, k)
+        kmer_counts[canonical] += 1
+
+    # generate the signature vector
+    signature_vector = [kmer_counts[kmer] for kmer in _possible_canonical_kmers(k)]
+
+    return signature_vector
+
+def signature_for_all_genes(X, k):
+    """compute k-mer signatures for all sequences in X (shows a tqdm progress bar)"""
+    return [compute_kmer_signature_bits(seq, k) for seq in tqdm(X)]
diff --git a/src/model_training.py b/src/model_training.py
new file mode 100644
index 0000000..4d7ab8d
--- /dev/null
+++ b/src/model_training.py
@@ -0,0 +1,70 @@
+import os
+import joblib
+import pandas as pd
+import numpy as np
+from sklearn.ensemble import RandomForestClassifier
+from sklearn.model_selection import train_test_split
+from sklearn.metrics import accuracy_score
+
+def evaluate_and_save_model(df, chunk_size, k, rf_params, overlap, model_dir, test_gene_dir):
+    """
+    train and save a random forest model for a specific parameter combination
+    also saves test genes in their raw form (before signatures)
+
+    df needs the columns "organism name" and "genome sequence"
+    rf_params is passed as keyword arguments to RandomForestClassifier
+    returns the accuracy on the held-out 20% test split
+    """
+    from .sequence_operations import split_gens_over_all_genomes
+    from .kmer_signatures import signature_for_all_genes
+
+    print(f"Training model for chunk_size={chunk_size}, k={k}...")
+
+    # extract genes from genomes
+    X_raw, y = split_gens_over_all_genomes(df, chunk_size=chunk_size, overlap=overlap)
+    print(f"Generated genes: {len(X_raw)}")
+
+    # sign the genes with k-mers
+    X_signature = signature_for_all_genes(X_raw, k)
+    print("Generated signatures.")
+
+    # split into training and testing sets
+    X_train_sig, X_test_sig, y_train, y_test, X_train_raw, X_test_raw = train_test_split(
+        X_signature, y, X_raw, test_size=0.2, random_state=42
+    )
+
+    # train random forest model
+    rf_model = RandomForestClassifier(**rf_params)
+    rf_model.fit(X_train_sig, y_train)
+
+    # evaluate the model
+    y_pred = rf_model.predict(X_test_sig)
+    accuracy = accuracy_score(y_test, y_pred)
+    print(f"Accuracy for chunk_size={chunk_size}, k={k}: {accuracy}")
+
+    # save the model
+    os.makedirs(model_dir, exist_ok=True)
+    model_path = os.path.join(model_dir, f"model_chunk{chunk_size}_k{k}.joblib")
+    joblib.dump(rf_model, model_path)
+    print(f"Model saved to {model_path}")
+
+    # select random test genes and save them
+    # never ask for more genes than the test split holds (choice would raise)
+    n_selected = min(10, len(X_test_raw))
+    # reuse the forest's integer seed, if any, so the saved sample is reproducible
+    seed = rf_params.get("random_state")
+    rng = np.random.default_rng(seed if isinstance(seed, int) else None)
+    selected_indices = rng.choice(len(X_test_raw), size=n_selected, replace=False)
+    selected_genes = [X_test_raw[i] for i in selected_indices]
+    selected_labels = [y_test[i] for i in selected_indices]
+
+    test_genes_df = pd.DataFrame({
+        "gene": selected_genes,
+        "label": selected_labels
+    })
+    os.makedirs(test_gene_dir, exist_ok=True)
+    test_gene_path = os.path.join(test_gene_dir, f"test_genes_chunk{chunk_size}_k{k}.csv")
+    test_genes_df.to_csv(test_gene_path, index=False)
+    print(f"Test genes saved to {test_gene_path}")
+
+    return accuracy
diff --git a/src/sequence_operations.py b/src/sequence_operations.py
new file mode 100644
index 0000000..6f9e411
--- /dev/null
+++ b/src/sequence_operations.py
@@ -0,0 +1,42 @@
+def split_sequence_alternate_overlap(sequence, chunk_size=500, overlap=100):
+    """
+    cut sequence into fragments of exactly chunk_size characters
+
+    the start position advances alternately by chunk_size - overlap and by
+    chunk_size, so every second fragment overlaps its predecessor by overlap
+    a trailing remainder shorter than chunk_size is dropped
+    raises ValueError unless chunk_size > 0 and 0 <= overlap < chunk_size
+    """
+    if chunk_size <= 0:
+        raise ValueError("chunk_size must be positive.")
+    if not 0 <= overlap < chunk_size:
+        # otherwise the start position could stall or move backwards
+        raise ValueError("overlap must satisfy 0 <= overlap < chunk_size.")
+
+    fragments = []
+    i = 0
+    toggle_overlap = True
+    while i <= len(sequence) - chunk_size:
+        fragments.append(sequence[i:i + chunk_size])
+        if toggle_overlap:
+            i += chunk_size - overlap  # overlapping
+        else:
+            i += chunk_size  # no overlapping
+        toggle_overlap = not toggle_overlap
+    return fragments
+
+def split_gens_over_all_genomes(df, chunk_size=500, overlap=100):
+    """
+    split all genomes in the dataframe into smaller fragments
+    returns (X, y): the fragments and, per fragment, the name of its organism
+    """
+    X = []
+    y = []
+    for _, row in df.iterrows():
+        sequence = row["genome sequence"]
+        organism_name = row["organism name"]
+
+        gens_for_sequence = split_sequence_alternate_overlap(sequence, chunk_size, overlap=overlap)
+        X.extend(gens_for_sequence)
+        y.extend([organism_name] * len(gens_for_sequence))
+    return X, y
diff --git a/src/utils.py b/src/utils.py
new file mode 100644
index 0000000..1669583
--- /dev/null
+++ b/src/utils.py
@@ -0,0 +1,15 @@
+import pickle
+
+def save_dataframe(df, filename):
+    """pickle df to filename, overwriting an existing file"""
+    with open(filename, 'wb') as file:
+        pickle.dump(df, file)
+
+
+def load_dataframe(filename):
+    """
+    load an object pickled with save_dataframe
+    unpickling can execute arbitrary code, so only load files this project wrote itself
+    """
+    with open(filename, 'rb') as file:
+        return pickle.load(file)