add pipeline
parakawa committed Jan 11, 2025
0 parents commit de9bf41
Showing 11 changed files with 422 additions and 0 deletions.
7 changes: 7 additions & 0 deletions .gitignore
@@ -0,0 +1,7 @@
**/.DS_Store
**/__pycache__/
all_recommended_organisms.pkl
df_organisms_selection.pkl
models
test_genes
grid_search_results.csv
Empty file added README.md
Empty file.
35 changes: 35 additions & 0 deletions main.py
@@ -0,0 +1,35 @@
import os
import pandas as pd
from src.model_training import evaluate_and_save_model
from src.utils import load_dataframe

print("loading data...")
# load data
df = load_dataframe("df_organisms_selection.pkl")

chunk_sizes = [500, 1000, 1500, 2000, 2500, 3000, 3500, 4000]
ks = [2, 3, 4, 5, 6, 7, 8]
rf_params = {"n_estimators": 100, "max_depth": None, "random_state": 42}

# create directories for saving outputs
os.makedirs("models", exist_ok=True)
os.makedirs("test_genes", exist_ok=True)

# train models for all parameter combinations
results = []
for chunk_size in chunk_sizes:
    for k in ks:
        accuracy = evaluate_and_save_model(
            df,
            chunk_size,
            k,
            rf_params,
            overlap=0,
            model_dir="models",
            test_gene_dir="test_genes"
        )
        results.append({"chunk_size": chunk_size, "k": k, "accuracy": accuracy})

# save grid search results
results_df = pd.DataFrame(results)
results_df.to_csv("grid_search_results.csv", index=False)
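
Not part of the commit, but for orientation: a minimal sketch of how the grid search output written by main.py could be inspected afterwards to pick the best parameter combination (it assumes a prior run has already produced grid_search_results.csv).

import pandas as pd

# load the grid search results written by main.py
results_df = pd.read_csv("grid_search_results.csv")

# pick the (chunk_size, k) combination with the highest accuracy
best = results_df.loc[results_df["accuracy"].idxmax()]
print(f"best parameters: chunk_size={int(best['chunk_size'])}, k={int(best['k'])}, "
      f"accuracy={best['accuracy']:.3f}")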
15 changes: 15 additions & 0 deletions requirements.txt
@@ -0,0 +1,15 @@
biopython @ file:///Users/runner/miniforge3/conda-bld/biopython_1720014912120/work
colorama @ file:///home/conda/feedstock_root/build_artifacts/colorama_1733218098505/work
joblib @ file:///home/conda/feedstock_root/build_artifacts/joblib_1733736026804/work
numpy @ file:///Users/runner/miniforge3/conda-bld/numpy_1734904295467/work/dist/numpy-2.2.1-cp312-cp312-macosx_11_0_arm64.whl#sha256=db59a85bf3c4ae6cff04ba4cf7a70f066547c91c14f592eb291ef71ba81f5da8
pandas @ file:///Users/runner/miniforge3/conda-bld/pandas_1726878422361/work
python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1733215673016/work
pytz @ file:///home/conda/feedstock_root/build_artifacts/pytz_1706886791323/work
scikit-learn @ file:///Users/runner/miniforge3/conda-bld/scikit-learn_1736496824048/work/dist/scikit_learn-1.6.1-cp312-cp312-macosx_11_0_arm64.whl#sha256=b73b4426f1e03a197903eddede31a62b6859e1eb5d298b145d003d742720c9f7
scipy @ file:///Users/runner/miniforge3/conda-bld/scipy-split_1736351819767/work/dist/scipy-1.15.0-cp312-cp312-macosx_11_0_arm64.whl#sha256=073562fec8c5aba27fbbb060efef7cbd039c0c5ea33e65214d07f80e8571e016
setuptools==75.8.0
six @ file:///home/conda/feedstock_root/build_artifacts/six_1733380938961/work
threadpoolctl @ file:///home/conda/feedstock_root/build_artifacts/threadpoolctl_1714400101435/work
tqdm @ file:///home/conda/feedstock_root/build_artifacts/tqdm_1735661334605/work
tzdata @ file:///home/conda/feedstock_root/build_artifacts/python-tzdata_1733235305708/work
wheel==0.45.1
Empty file added src/__init__.py
Empty file.
186 changes: 186 additions & 0 deletions src/data_processing.py
@@ -0,0 +1,186 @@
from Bio import Entrez
import pandas as pd
from utils import save_dataframe

# NCBI requires a contact email for Entrez queries; replace the placeholder below
Entrez.email = "your.email@example.com"

# list of organisms
organisms = [
"Aeropyrum pernix K1",
"Archaeoglobus fulgidus DSM 4304",
"Archaeoglobus profundus DSM 5631",
"Caldivirga maquilingensis IC-167",
"Candidatus Korarchaeum cryptofilum OPF8",
"Candidatus Methanoregula boonei 6A8",
"Candidatus Methanosphaerula palustris E1-9c",
"Cenarchaeum symbiosum A",
"Desulfurococcus kamchatkensis 1221n",
"Haloarcula marismortui ATCC 43049",
"Haloarcula marismortui ATCC 43049",
"Halobacterium sp. NRC-1",
"Halomicrobium mukohataei DSM 12286",
"Haloquadratum walsbyi DSM 16790",
"Halorhabdus utahensis DSM 12940",
"Halorubrum lacusprofundi ATCC 49239",
"Halorubrum lacusprofundi ATCC 49239",
"Haloterrigena turkmenica DSM 5511",
"Hyperthermus butylicus DSM 5456",
"Ignicoccus hospitalis KIN4/I",
"Metallosphaera sedula DSM 5348",
"Methanobrevibacter ruminantium M1",
"Methanocaldococcus fervens AG86",
"Methanocella paludicola SANAE",
"Methanococcoides burtonii DSM 6242",
"Methanococcus aeolicus Nankai-3",
"Methanococcus maripaludis C6",
"Methanococcus vannielii SB",
"Methanocorpusculum labreanum Z",
"Methanoculleus marisnigri JR1",
"Methanopyrus kandleri AV19",
"Methanosaeta thermophila PT",
"Methanosarcina acetivorans C2A",
"Methanosarcina barkeri str. Fusaro",
"Methanosarcina mazei Go1",
"Methanosphaera stadtmanae DSM 3091",
"Methanospirillum hungatei JF-1",
"Nanoarchaeum equitans Kin4-M",
"Natronomonas pharaonis DSM 2160",
"Nitrosopumilus maritimus SCM1",
"Picrophilus torridus DSM 9790",
"Pyrobaculum aerophilum str. IM2",
"Pyrobaculum arsenaticum DSM 13514",
"Pyrococcus abyssi GE5",
"Pyrococcus furiosus DSM 3638",
"Pyrococcus horikoshii OT3",
"Staphylothermus marinus F1",
"Sulfolobus acidocaldarius DSM 639",
"Sulfolobus solfataricus P2",
"Thermococcus gammatolerans EJ3",
"Thermofilum pendens Hrk 5",
"Thermoplasma acidophilum DSM 1728",
"Thermoplasma volcanium GSS1",
"Thermoproteus neutrophilus V24Sta",
"Acholeplasma laidlawii PG-8A",
"Acidobacterium capsulatum ATCC 51196",
"Akkermansia muciniphila ATCC BAA-835",
"Alicyclobacillus acidocaldarius subsp. acidocaldarius DSM 446",
"Aquifex aeolicus VF5",
"Bacillus cereus Q1",
"Bacillus pseudofirmus OF4",
"Bacteroides fragilis YCH46",
"Bdellovibrio bacteriovorus HD100",
"Bordetella pertussis Tohama I",
"Borrelia burgdorferi B31",
"Campylobacter jejuni subsp. jejuni 81-176",
"Candidatus Amoebophilus asiaticus 5a2",
"Candidatus Cloacamonas acidaminovorans",
"Candidatus Endomicrobium sp. Rs-D17",
"Carboxydothermus hydrogenoformans Z-2901",
"Chlamydia trachomatis 434/Bu",
"Chlorobium chlorochromatii CaD3",
"Chloroflexus aurantiacus J-10-fl",
"Clostridium acetobutylicum ATCC 824",
"Corynebacterium glutamicum ATCC 13032",
"Coxiella burnetii RSA 493",
"Cupriavidus taiwanensis",
"Cupriavidus taiwanensis",
"Cyanothece sp. ATCC 51142",
"Cyanothece sp. ATCC 51142",
"Dehalococcoides ethenogenes 195",
"Deinococcus radiodurans R1",
"Deinococcus radiodurans R1",
"Dictyoglomus thermophilum H-6-12",
"Elusimicrobium minutum Pei191",
"Fibrobacter succinogenes subsp. succinogenes S85",
"Flavobacterium psychrophilum JIP02/86",
"Fusobacterium nucleatum subsp. nucleatum ATCC 25586",
"Gemmata obscuriglobus UQM 2246",
"Gemmatimonas aurantiaca T-27",
"Gloeobacter violaceus PCC 7421",
"Leptospira interrogans serovar Lai str. 56601",
"Leptospira interrogans serovar Lai str. 56601",
"Magnetococcus sp. MC-1",
"Methylacidiphilum infernorum V4",
"Mycoplasma genitalium G37",
"Nostoc punctiforme PCC 73102",
"Opitutus terrae PB90-1",
"Pedobacter heparinus DSM 2366",
"Pirellula staleyi DSM 6068",
"Prochlorococcus marinus str. AS9601",
"Psychrobacter arcticus 273-4",
"Rhizobium leguminosarum bv. trifolii WSM1325",
"Rhodopirellula baltica SH 1",
"Rhodospirillum rubrum ATCC 11170",
"Rickettsia rickettsii str. Iowa",
"Shewanella putrefaciens CN-32",
"Solibacter usitatus Ellin6076",
"Synechococcus elongatus PCC 6301",
"Thermanaerovibrio acidaminovorans DSM 6589",
"Thermoanaerobacter tengcongensis MB4",
"Thermobaculum terrenum ATCC BAA-798",
"Thermobaculum terrenum ATCC BAA-798",
"Thermodesulfovibrio yellowstonii DSM 11347",
"Thermomicrobium roseum DSM 5159",
"Thermotoga maritima MSB8"
]

# list to store results
genomes_data = []

# function to download genomes
def download_genome(organism_name):
    try:
        # search for genome using the organism name
        search_handle = Entrez.esearch(db="nucleotide", term=organism_name, retmax=1)
        search_results = Entrez.read(search_handle)
        search_handle.close()

        # get the id of the first result
        id_list = search_results["IdList"]
        if not id_list:
            print(f"no genome found for {organism_name}")
            return None

        # fetch the genome in fasta format
        genome_id = id_list[0]
        fetch_handle = Entrez.efetch(db="nucleotide", id=genome_id, rettype="fasta", retmode="text")
        genome_data = fetch_handle.read()
        fetch_handle.close()

        # clean fasta header
        genome_sequence = genome_data.split("\n", 1)[1].replace("\n", "")  # remove header and line breaks

        # return organism name and genome sequence
        return organism_name, genome_sequence
    except Exception as e:
        print(f"error fetching genome for {organism_name}: {e}")
        return None

# download genomes for each organism
for organism in organisms:
    print(f"downloading genome for {organism}")
    result = download_genome(organism)
    if result:
        genomes_data.append(result)

# create a dataframe with the results
all_recommended_organisms = pd.DataFrame(genomes_data, columns=["organism name", "genome sequence"])

# save the original dataframe
save_dataframe(all_recommended_organisms, "all_recommended_organisms.pkl")

# function to filter organisms with invalid sequences
def delete_bad_organisms(df):
    mask = df["genome sequence"].str.fullmatch(r"[ACGT]*")  # keep only sequences made of A, C, G, T
    filtered_df = df[mask].reset_index(drop=True)
    return filtered_df

# filter the organisms
all_organisms_filtered = delete_bad_organisms(all_recommended_organisms)

# select 10 organisms randomly
df_organisms_selection = all_organisms_filtered.sample(n=10, random_state=42)

# save the selected organisms
save_dataframe(df_organisms_selection, "df_organisms_selection.pkl")


51 changes: 51 additions & 0 deletions src/encoding.py
@@ -0,0 +1,51 @@
from itertools import product

def encode_kmer(kmer):
    """convert a k-mer to its integer representation"""
    encoding = {'A': 0b00, 'C': 0b01, 'G': 0b10, 'T': 0b11}
    result = 0
    for base in kmer:
        result = (result << 2) | encoding[base]
    return result

def decode_kmer(encoded_kmer, k):
    """convert an encoded k-mer back to its string representation"""
    decoding = ['A', 'C', 'G', 'T']
    kmer = []
    for _ in range(k):
        base = encoded_kmer & 0b11  # take the last 2 bits
        kmer.append(decoding[base])
        encoded_kmer >>= 2
    return ''.join(reversed(kmer))

def reverse_complement_bits(kmer, k):
    """compute the reverse complement of a k-mer encoded as an integer"""
    mask = (1 << (2 * k)) - 1  # mask to keep only 2k bits
    complement = 0
    for _ in range(k):
        complement = (complement << 2) | ((kmer & 0b11) ^ 0b11)  # complement the base
        kmer >>= 2
    return complement & mask  # ensure there are no extra bits

def canonical_kmer_bits(kmer, k):
    """return the canonical k-mer in its integer form"""
    rev_comp = reverse_complement_bits(kmer, k)
    return min(kmer, rev_comp)

def generate_all_canonical_kmers_bits(k):
    """generate all canonical k-mers of length k as integers"""
    canonical_kmers = set()
    for kmer_tuple in product([0b00, 0b01, 0b10, 0b11], repeat=k):
        kmer = 0
        for base in kmer_tuple:
            kmer = (kmer << 2) | base
        canonical_kmers.add(canonical_kmer_bits(kmer, k))
    return sorted(canonical_kmers)


if __name__ == "__main__":
    # quick sanity check of the bitwise encoding using this module's own functions
    k = 3
    kmer = encode_kmer("ACG")
    print(f"encoded ACG: {kmer:06b}")
    print(f"decoded back: {decode_kmer(kmer, k)}")
    print(f"canonical form: {decode_kmer(canonical_kmer_bits(kmer, k), k)}")
    print(f"number of canonical {k}-mers: {len(generate_all_canonical_kmers_bits(k))}")
34 changes: 34 additions & 0 deletions src/kmer_signatures.py
@@ -0,0 +1,34 @@
from collections import defaultdict
from tqdm import tqdm
from .encoding import encode_kmer, canonical_kmer_bits, generate_all_canonical_kmers_bits

def compute_kmer_signature_bits(sequence, k=2):
    """
    compute the k-mer signature vector for a sequence
    uses bitwise encoding to handle k-mers efficiently
    """
    if len(sequence) < k:
        raise ValueError("Sequence length must be at least k.")

    # dictionary to store frequencies (default value is 0)
    kmer_counts = defaultdict(int)

    # encode the entire sequence into bits
    encoded_sequence = encode_kmer(sequence)

    # mask to extract k-mers
    mask = (1 << (2 * k)) - 1  # mask for the last 2k bits
    for i in range(len(sequence) - k + 1):
        kmer = (encoded_sequence >> (2 * (len(sequence) - k - i))) & mask
        canonical = canonical_kmer_bits(kmer, k)
        kmer_counts[canonical] += 1

    # generate the signature vector
    possible_kmers = generate_all_canonical_kmers_bits(k)
    signature_vector = [kmer_counts[kmer] for kmer in possible_kmers]

    return signature_vector

def signature_for_all_genes(X, k):
    """compute k-mer signatures for all sequences"""
    return [compute_kmer_signature_bits(seq, k) for seq in tqdm(X)]
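
A small usage sketch, not part of the commit, showing what compute_kmer_signature_bits returns for a toy sequence (it assumes the package is importable as src, e.g. when run from the repository root):

from src.encoding import decode_kmer, generate_all_canonical_kmers_bits
from src.kmer_signatures import compute_kmer_signature_bits

k = 2
signature = compute_kmer_signature_bits("ACGTACGT", k=k)

# label each count with the canonical k-mer it belongs to, in the same sorted order
labels = [decode_kmer(kmer, k) for kmer in generate_all_canonical_kmers_bits(k)]
print(dict(zip(labels, signature)))  # e.g. AC and GT collapse onto the same canonical 2-mer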
59 changes: 59 additions & 0 deletions src/model_training.py
@@ -0,0 +1,59 @@
import os
import joblib
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

def evaluate_and_save_model(df, chunk_size, k, rf_params, overlap, model_dir, test_gene_dir):
    """
    train and save a random forest model for a specific parameter combination
    also saves test genes in their raw form (before signatures)
    """
    from .sequence_operations import split_gens_over_all_genomes
    from .kmer_signatures import signature_for_all_genes

    print(f"Training model for chunk_size={chunk_size}, k={k}...")

    # extract genes from genomes
    X_raw, y = split_gens_over_all_genomes(df, chunk_size=chunk_size, overlap=overlap)
    print(f"Generated genes: {len(X_raw)}")

    # compute k-mer signatures for the genes
    X_signature = signature_for_all_genes(X_raw, k)
    print("Generated signatures.")

    # split into training and testing sets
    X_train_sig, X_test_sig, y_train, y_test, X_train_raw, X_test_raw = train_test_split(
        X_signature, y, X_raw, test_size=0.2, random_state=42
    )

    # train random forest model
    rf_model = RandomForestClassifier(**rf_params)
    rf_model.fit(X_train_sig, y_train)

    # evaluate the model
    y_pred = rf_model.predict(X_test_sig)
    accuracy = accuracy_score(y_test, y_pred)
    print(f"Accuracy for chunk_size={chunk_size}, k={k}: {accuracy}")

    # save the model
    model_path = os.path.join(model_dir, f"model_chunk{chunk_size}_k{k}.joblib")
    joblib.dump(rf_model, model_path)
    print(f"Model saved to {model_path}")

    # select random test genes and save them
    selected_indices = np.random.choice(len(X_test_raw), size=10, replace=False)
    selected_genes = [X_test_raw[i] for i in selected_indices]
    selected_labels = [y_test[i] for i in selected_indices]

    test_genes_df = pd.DataFrame({
        "gene": selected_genes,
        "label": selected_labels
    })
    test_gene_path = os.path.join(test_gene_dir, f"test_genes_chunk{chunk_size}_k{k}.csv")
    test_genes_df.to_csv(test_gene_path, index=False)
    print(f"Test genes saved to {test_gene_path}")

    return accuracy
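
For context, a hedged sketch (not part of this commit) of how the saved artifacts could be consumed afterwards: load one of the serialized models, recompute the signature of a saved raw test gene with the same k, and compare the prediction against the stored label. The chunk_size=1000, k=4 pair is only an example combination from the grid in main.py.

import joblib
import pandas as pd
from src.kmer_signatures import compute_kmer_signature_bits

chunk_size, k = 1000, 4  # example combination; any pair from the grid in main.py works
model = joblib.load(f"models/model_chunk{chunk_size}_k{k}.joblib")
test_genes = pd.read_csv(f"test_genes/test_genes_chunk{chunk_size}_k{k}.csv")

# signature of the first saved test gene, then predict its source organism
signature = compute_kmer_signature_bits(test_genes.loc[0, "gene"], k=k)
print("predicted:", model.predict([signature])[0])
print("actual:   ", test_genes.loc[0, "label"])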