forked from parakawa/genom_classify
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit de9bf41
Showing
11 changed files
with
422 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
**/.DS_Store | ||
**/__pycache__/ | ||
all_recommended_organisms.pkl | ||
df_organisms_selection.pkl | ||
models | ||
test_genes | ||
grid_search_results.csv |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import os | ||
import pandas as pd | ||
from src.model_training import evaluate_and_save_model | ||
from src.utils import load_dataframe | ||
|
||
print("loading data...") | ||
# load data | ||
df = load_dataframe("df_organisms_selection.pkl") | ||
|
||
chunk_sizes = [500, 1000, 1500, 2000, 2500, 3000, 3500, 4000] | ||
ks = [2, 3, 4, 5, 6, 7, 8] | ||
rf_params = {"n_estimators": 100, "max_depth": None, "random_state": 42} | ||
|
||
# create directories for saving outputs | ||
os.makedirs("models", exist_ok=True) | ||
os.makedirs("test_genes", exist_ok=True) | ||
|
||
# train models for all parameter combinations | ||
results = [] | ||
for chunk_size in chunk_sizes: | ||
for k in ks: | ||
accuracy = evaluate_and_save_model( | ||
df, | ||
chunk_size, | ||
k, | ||
rf_params, | ||
overlap=0, | ||
model_dir="models", | ||
test_gene_dir="test_genes" | ||
) | ||
results.append({"chunk_size": chunk_size, "k": k, "accuracy": accuracy}) | ||
|
||
# save grid search results | ||
results_df = pd.DataFrame(results) | ||
results_df.to_csv("grid_search_results.csv", index=False) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
biopython @ file:///Users/runner/miniforge3/conda-bld/biopython_1720014912120/work | ||
colorama @ file:///home/conda/feedstock_root/build_artifacts/colorama_1733218098505/work | ||
joblib @ file:///home/conda/feedstock_root/build_artifacts/joblib_1733736026804/work | ||
numpy @ file:///Users/runner/miniforge3/conda-bld/numpy_1734904295467/work/dist/numpy-2.2.1-cp312-cp312-macosx_11_0_arm64.whl#sha256=db59a85bf3c4ae6cff04ba4cf7a70f066547c91c14f592eb291ef71ba81f5da8 | ||
pandas @ file:///Users/runner/miniforge3/conda-bld/pandas_1726878422361/work | ||
python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1733215673016/work | ||
pytz @ file:///home/conda/feedstock_root/build_artifacts/pytz_1706886791323/work | ||
scikit-learn @ file:///Users/runner/miniforge3/conda-bld/scikit-learn_1736496824048/work/dist/scikit_learn-1.6.1-cp312-cp312-macosx_11_0_arm64.whl#sha256=b73b4426f1e03a197903eddede31a62b6859e1eb5d298b145d003d742720c9f7 | ||
scipy @ file:///Users/runner/miniforge3/conda-bld/scipy-split_1736351819767/work/dist/scipy-1.15.0-cp312-cp312-macosx_11_0_arm64.whl#sha256=073562fec8c5aba27fbbb060efef7cbd039c0c5ea33e65214d07f80e8571e016 | ||
setuptools==75.8.0 | ||
six @ file:///home/conda/feedstock_root/build_artifacts/six_1733380938961/work | ||
threadpoolctl @ file:///home/conda/feedstock_root/build_artifacts/threadpoolctl_1714400101435/work | ||
tqdm @ file:///home/conda/feedstock_root/build_artifacts/tqdm_1735661334605/work | ||
tzdata @ file:///home/conda/feedstock_root/build_artifacts/python-tzdata_1733235305708/work | ||
wheel==0.45.1 |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,186 @@ | ||
from Bio import Entrez | ||
import pandas as pd | ||
from utils import save_dataframe, load_dataframe | ||
import random | ||
|
||
# list of organisms | ||
organisms = [ | ||
"Aeropyrum pernix K1", | ||
"Archaeoglobus fulgidus DSM 4304", | ||
"Archaeoglobus profundus DSM 5631", | ||
"Caldivirga maquilingensis IC-167", | ||
"Candidatus Korarchaeum cryptofilum OPF8", | ||
"Candidatus Methanoregula boonei 6A8", | ||
"Candidatus Methanosphaerula palustris E1-9c", | ||
"Cenarchaeum symbiosum A", | ||
"Desulfurococcus kamchatkensis 1221n", | ||
"Haloarcula marismortui ATCC 43049", | ||
"Haloarcula marismortui ATCC 43049", | ||
"Halobacterium sp. NRC-1", | ||
"Halomicrobium mukohataei DSM 12286", | ||
"Haloquadratum walsbyi DSM 16790", | ||
"Halorhabdus utahensis DSM 12940", | ||
"Halorubrum lacusprofundi ATCC 49239", | ||
"Halorubrum lacusprofundi ATCC 49239", | ||
"Haloterrigena turkmenica DSM 5511", | ||
"Hyperthermus butylicus DSM 5456", | ||
"Ignicoccus hospitalis KIN4/I", | ||
"Metallosphaera sedula DSM 5348", | ||
"Methanobrevibacter ruminantium M1", | ||
"Methanocaldococcus fervens AG86", | ||
"Methanocella paludicola SANAE", | ||
"Methanococcoides burtonii DSM 6242", | ||
"Methanococcus aeolicus Nankai-3", | ||
"Methanococcus maripaludis C6", | ||
"Methanococcus vannielii SB", | ||
"Methanocorpusculum labreanum Z", | ||
"Methanoculleus marisnigri JR1", | ||
"Methanopyrus kandleri AV19", | ||
"Methanosaeta thermophila PT", | ||
"Methanosarcina acetivorans C2A", | ||
"Methanosarcina barkeri str. Fusaro", | ||
"Methanosarcina mazei Go1", | ||
"Methanosphaera stadtmanae DSM 3091", | ||
"Methanospirillum hungatei JF-1", | ||
"Nanoarchaeum equitans Kin4-M", | ||
"Natronomonas pharaonis DSM 2160", | ||
"Nitrosopumilus maritimus SCM1", | ||
"Picrophilus torridus DSM 9790", | ||
"Pyrobaculum aerophilum str. IM2", | ||
"Pyrobaculum arsenaticum DSM 13514", | ||
"Pyrococcus abyssi GE5", | ||
"Pyrococcus furiosus DSM 3638", | ||
"Pyrococcus horikoshii OT3", | ||
"Staphylothermus marinus F1", | ||
"Sulfolobus acidocaldarius DSM 639", | ||
"Sulfolobus solfataricus P2", | ||
"Thermococcus gammatolerans EJ3", | ||
"Thermofilum pendens Hrk 5", | ||
"Thermoplasma acidophilum DSM 1728", | ||
"Thermoplasma volcanium GSS1", | ||
"Thermoproteus neutrophilus V24Sta", | ||
"Acholeplasma laidlawii PG-8A", | ||
"Acidobacterium capsulatum ATCC 51196", | ||
"Akkermansia muciniphila ATCC BAA-835", | ||
"Alicyclobacillus acidocaldarius subsp. acidocaldarius DSM 446", | ||
"Aquifex aeolicus VF5", | ||
"Bacillus cereus Q1", | ||
"Bacillus pseudofirmus OF4", | ||
"Bacteroides fragilis YCH46", | ||
"Bdellovibrio bacteriovorus HD100", | ||
"Bordetella pertussis Tohama I", | ||
"Borrelia burgdorferi B31", | ||
"Campylobacter jejuni subsp. jejuni 81-176", | ||
"Candidatus Amoebophilus asiaticus 5a2", | ||
"Candidatus Cloacamonas acidaminovorans", | ||
"Candidatus Endomicrobium sp. Rs-D17", | ||
"Carboxydothermus hydrogenoformans Z-2901", | ||
"Chlamydia trachomatis 434/Bu", | ||
"Chlorobium chlorochromatii CaD3", | ||
"Chloroflexus aurantiacus J-10-fl", | ||
"Clostridium acetobutylicum ATCC 824", | ||
"Corynebacterium glutamicum ATCC 13032", | ||
"Coxiella burnetii RSA 493", | ||
"Cupriavidus taiwanensis", | ||
"Cupriavidus taiwanensis", | ||
"Cyanothece sp. ATCC 51142", | ||
"Cyanothece sp. ATCC 51142", | ||
"Dehalococcoides ethenogenes 195", | ||
"Deinococcus radiodurans R1", | ||
"Deinococcus radiodurans R1", | ||
"Dictyoglomus thermophilum H-6-12", | ||
"Elusimicrobium minutum Pei191", | ||
"Fibrobacter succinogenes subsp. succinogenes S85", | ||
"Flavobacterium psychrophilum JIP02/86", | ||
"Fusobacterium nucleatum subsp. nucleatum ATCC 25586", | ||
"Gemmata obscuriglobus UQM 2246", | ||
"Gemmatimonas aurantiaca T-27", | ||
"Gloeobacter violaceus PCC 7421", | ||
"Leptospira interrogans serovar Lai str. 56601", | ||
"Leptospira interrogans serovar Lai str. 56601", | ||
"Magnetococcus sp. MC-1", | ||
"Methylacidiphilum infernorum V4", | ||
"Mycoplasma genitalium G37", | ||
"Nostoc punctiforme PCC 73102", | ||
"Opitutus terrae PB90-1", | ||
"Pedobacter heparinus DSM 2366", | ||
"Pirellula staleyi DSM 6068", | ||
"Prochlorococcus marinus str. AS9601", | ||
"Psychrobacter arcticus 273-4", | ||
"Rhizobium leguminosarum bv. trifolii WSM1325", | ||
"Rhodopirellula baltica SH 1", | ||
"Rhodospirillum rubrum ATCC 11170", | ||
"Rickettsia rickettsii str. Iowa", | ||
"Shewanella putrefaciens CN-32", | ||
"Solibacter usitatus Ellin6076", | ||
"Synechococcus elongatus PCC 6301", | ||
"Thermanaerovibrio acidaminovorans DSM 6589", | ||
"Thermoanaerobacter tengcongensis MB4", | ||
"Thermobaculum terrenum ATCC BAA-798", | ||
"Thermobaculum terrenum ATCC BAA-798", | ||
"Thermodesulfovibrio yellowstonii DSM 11347", | ||
"Thermomicrobium roseum DSM 5159", | ||
"Thermotoga maritima MSB8" | ||
] | ||
|
||
# list to store results | ||
genomes_data = [] | ||
|
||
# function to download genomes | ||
def download_genome(organism_name): | ||
try: | ||
# search for genome using the organism name | ||
search_handle = Entrez.esearch(db="nucleotide", term=organism_name, retmax=1) | ||
search_results = Entrez.read(search_handle) | ||
search_handle.close() | ||
|
||
# get the id of the first result | ||
id_list = search_results["IdList"] | ||
if not id_list: | ||
print(f"no genome found for {organism_name}") | ||
return None | ||
|
||
# fetch the genome in fasta format | ||
genome_id = id_list[0] | ||
fetch_handle = Entrez.efetch(db="nucleotide", id=genome_id, rettype="fasta", retmode="text") | ||
genome_data = fetch_handle.read() | ||
fetch_handle.close() | ||
|
||
# clean fasta header | ||
genome_sequence = genome_data.split("\n", 1)[1].replace("\n", "") # remove header and line breaks | ||
|
||
# return organism name and genome sequence | ||
return organism_name, genome_sequence | ||
except Exception as e: | ||
print(f"error fetching genome for {organism_name}: {e}") | ||
return None | ||
|
||
# download genomes for each organism | ||
for organism in organisms: | ||
print(f"download organism {organism}") | ||
result = download_genome(organism) | ||
if result: | ||
genomes_data.append(result) | ||
|
||
# create a dataframe with the results | ||
all_recommended_organisms = pd.DataFrame(genomes_data, columns=["organism name", "genome sequence"]) | ||
|
||
# save the original dataframe | ||
save_dataframe(all_recommended_organisms, "all_recommended_organisms.pkl") | ||
|
||
# function to filter organisms with invalid sequences | ||
def delete_bad_organisms(df): | ||
mask = df["genome sequence"].str.fullmatch(r"[ACGT]*") # validate only a, c, g, t | ||
filtered_df = df[mask].reset_index(drop=True) | ||
return filtered_df | ||
|
||
# filter the organisms | ||
all_organisms_filtered = delete_bad_organisms(all_recommended_organisms) | ||
|
||
# select 10 organisms randomly | ||
df_organisms_selection = all_organisms_filtered.sample(n=10, random_state=42) | ||
|
||
# save the selected organisms | ||
save_dataframe(df_organisms_selection, "df_organisms_selection2.pkl") | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
from itertools import product | ||
|
||
def encode_kmer(kmer): | ||
"""convert a k-mer to its integer representation""" | ||
encoding = {'A': 0b00, 'C': 0b01, 'G': 0b10, 'T': 0b11} | ||
result = 0 | ||
for base in kmer: | ||
result = (result << 2) | encoding[base] | ||
return result | ||
|
||
def decode_kmer(encoded_kmer, k): | ||
"""convert an encoded k-mer back to its string representation""" | ||
decoding = ['A', 'C', 'G', 'T'] | ||
kmer = [] | ||
for _ in range(k): | ||
base = encoded_kmer & 0b11 # take the last 2 bits | ||
kmer.append(decoding[base]) | ||
encoded_kmer >>= 2 | ||
return ''.join(reversed(kmer)) | ||
|
||
def reverse_complement_bits(kmer, k): | ||
"""compute the reverse complement of a k-mer encoded as an integer""" | ||
mask = (1 << (2 * k)) - 1 # mask to keep only 2k bits | ||
complement = 0 | ||
for _ in range(k): | ||
complement = (complement << 2) | ((kmer & 0b11) ^ 0b11) # complement the base | ||
kmer >>= 2 | ||
return complement & mask # ensure there are no extra bits | ||
|
||
def canonical_kmer_bits(kmer, k): | ||
"""return the canonical k-mer in its integer form""" | ||
rev_comp = reverse_complement_bits(kmer, k) | ||
return min(kmer, rev_comp) | ||
|
||
def generate_all_canonical_kmers_bits(k): | ||
"""generate all canonical k-mers of length k as integers""" | ||
canonical_kmers = set() | ||
for kmer_tuple in product([0b00, 0b01, 0b10, 0b11], repeat=k): | ||
kmer = 0 | ||
for base in kmer_tuple: | ||
kmer = (kmer << 2) | base | ||
canonical_kmers.add(canonical_kmer_bits(kmer, k)) | ||
return sorted(canonical_kmers) | ||
|
||
|
||
if False: | ||
sequence = "ACGTTCGACG" | ||
k = 3 | ||
signature = compute_kmer_signature(sequence, k) | ||
print(f"k-mer signature vector (k={k}): {signature}") | ||
canonical_kmers = generate_all_canonical_kmers(k) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from collections import defaultdict | ||
from tqdm import tqdm | ||
from .encoding import encode_kmer, canonical_kmer_bits, generate_all_canonical_kmers_bits | ||
|
||
def compute_kmer_signature_bits(sequence, k=2): | ||
""" | ||
compute the k-mer signature vector for a sequence | ||
uses bitwise encoding to handle k-mers efficiently | ||
""" | ||
if len(sequence) < k: | ||
raise ValueError("Sequence length must be at least k.") | ||
|
||
# dictionary to store frequencies (default value is 0) | ||
kmer_counts = defaultdict(int) | ||
|
||
# encode the entire sequence into bits | ||
encoded_sequence = encode_kmer(sequence) | ||
|
||
# mask to extract k-mers | ||
mask = (1 << (2 * k)) - 1 # mask for the last 2k bits | ||
for i in range(len(sequence) - k + 1): | ||
kmer = (encoded_sequence >> (2 * (len(sequence) - k - i))) & mask | ||
canonical = canonical_kmer_bits(kmer, k) | ||
kmer_counts[canonical] += 1 | ||
|
||
# generate the signature vector | ||
possible_kmers = generate_all_canonical_kmers_bits(k) | ||
signature_vector = [kmer_counts[kmer] for kmer in possible_kmers] | ||
|
||
return signature_vector | ||
|
||
def signature_for_all_genes(X, k): | ||
"""compute k-mer signatures for all sequences""" | ||
return [compute_kmer_signature_bits(seq, k) for seq in tqdm(X)] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import os | ||
import joblib | ||
import pandas as pd | ||
import numpy as np | ||
from sklearn.ensemble import RandomForestClassifier | ||
from sklearn.model_selection import train_test_split | ||
from sklearn.metrics import accuracy_score | ||
|
||
def evaluate_and_save_model(df, chunk_size, k, rf_params, overlap, model_dir, test_gene_dir): | ||
""" | ||
train and save a random forest model for a specific parameter combination | ||
also saves test genes in their raw form (before signatures) | ||
""" | ||
from .sequence_operations import split_gens_over_all_genomes | ||
from .kmer_signatures import signature_for_all_genes | ||
|
||
print(f"Training model for chunk_size={chunk_size}, k={k}...") | ||
|
||
# extract genes from genomes | ||
X_raw, y = split_gens_over_all_genomes(df, chunk_size=chunk_size, overlap=overlap) | ||
print(f"Generated genes: {len(X_raw)}") | ||
|
||
# sign the genes with k-mers | ||
X_signature = signature_for_all_genes(X_raw, k) | ||
print("Generated signatures.") | ||
|
||
# split into training and testing sets | ||
X_train_sig, X_test_sig, y_train, y_test, X_train_raw, X_test_raw = train_test_split( | ||
X_signature, y, X_raw, test_size=0.2, random_state=42 | ||
) | ||
|
||
# train random forest model | ||
rf_model = RandomForestClassifier(**rf_params) | ||
rf_model.fit(X_train_sig, y_train) | ||
|
||
# evaluate the model | ||
y_pred = rf_model.predict(X_test_sig) | ||
accuracy = accuracy_score(y_test, y_pred) | ||
print(f"Accuracy for chunk_size={chunk_size}, k={k}: {accuracy}") | ||
|
||
# save the model | ||
model_path = os.path.join(model_dir, f"model_chunk{chunk_size}_k{k}.joblib") | ||
joblib.dump(rf_model, model_path) | ||
print(f"Model saved to {model_path}") | ||
|
||
# select random test genes and save them | ||
selected_indices = np.random.choice(len(X_test_raw), size=10, replace=False) | ||
selected_genes = [X_test_raw[i] for i in selected_indices] | ||
selected_labels = [y_test[i] for i in selected_indices] | ||
|
||
test_genes_df = pd.DataFrame({ | ||
"gene": selected_genes, | ||
"label": selected_labels | ||
}) | ||
test_gene_path = os.path.join(test_gene_dir, f"test_genes_chunk{chunk_size}_k{k}.csv") | ||
test_genes_df.to_csv(test_gene_path, index=False) | ||
print(f"Test genes saved to {test_gene_path}") | ||
|
||
return accuracy |
Oops, something went wrong.