RFCT Code cleanups
- Remove unused imports
- Simplify code with f-strings instead of format()
- Better variable names
- Update multiprocessing usage to always use the 'spawn' method (for future-proofing; see the sketch below)
luispedro committed Feb 16, 2025
1 parent 708deb7 commit 4247565
Showing 5 changed files with 50 additions and 36 deletions.
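
The 'spawn' bullet corresponds to a single added line in SemiBin/main.py below, Pool = mp.get_context('spawn').Pool. A minimal, self-contained sketch of that pattern (the worker function and the numbers are illustrative, not SemiBin code):

    import multiprocessing as mp

    # Bind Pool to an explicit 'spawn' context instead of the platform default
    # ('fork' on Linux): spawned workers start a fresh interpreter and do not
    # inherit threads or locks from the parent process.
    Pool = mp.get_context('spawn').Pool

    def square(x):  # must live at module top level so 'spawn' workers can import it
        return x * x

    if __name__ == '__main__':  # the __main__ guard is required under 'spawn'
        with Pool(processes=4) as pool:
            print(pool.map(square, range(8)))  # [0, 1, 4, 9, 16, 25, 36, 49]

Placing the assignment at module level next to the imports, as the commit does, means every later Pool(...) call in main.py transparently uses the 'spawn' start method.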

SemiBin/cluster.py (2 changes: 0 additions & 2 deletions)

@@ -5,7 +5,6 @@
 import tempfile

 from .utils import write_bins, cal_num_bins
-from .fasta import fasta_iter

 # This is the default in the igraph package
 NR_INFOMAP_TRIALS = 10
@@ -276,7 +275,6 @@ def cluster(logger, model, data, device, is_combined,
     max_node: max percentage of contigs considered in binning
     """
     import pandas as pd
-    import numpy as np
     embedding, contig_labels = run_embed_infomap(logger, model, data,
             device=device, max_edges=args.max_edges, max_node=args.max_node,
             is_combined=is_combined, n_sample=n_sample,

SemiBin/generate_coverage.py (9 changes: 4 additions & 5 deletions)

@@ -79,12 +79,12 @@ def calculate_coverage(depth_stream, bam_file, must_link_threshold, edge=75, is_
         }, index=contigs), None


-def generate_cov(bam_file, bam_index, out, threshold,
+def generate_cov(bam_file: str, bam_index, out, threshold,
                  is_combined, contig_threshold, logger, sep = None):
     """
     Call bedtools and generate coverage file
-    bam_file: bam files used
+    bam_file: bam file used
     out: output
     threshold: threshold of contigs that will be binned
     is_combined: if using abundance feature in deep learning. True: use
@@ -126,16 +126,15 @@ def generate_cov(bam_file, bam_index, out, threshold,
     if sep is None:
         abun_scale = (contig_cov.mean() / 100).apply(np.ceil) * 100
         contig_cov = contig_cov.div(abun_scale)
-    with atomic_write(os.path.join(out, '{}_data_cov.csv'.format(bam_name)), overwrite=True) as ofile:
+    with atomic_write(os.path.join(out, f'{bam_name}_data_cov.csv'), overwrite=True) as ofile:
         contig_cov.to_csv(ofile)

     if is_combined:
         must_link_contig_cov = must_link_contig_cov.apply(lambda x: x + 1e-5)
         if sep is None:
             abun_split_scale = (must_link_contig_cov.mean() / 100).apply(np.ceil) * 100
             must_link_contig_cov = must_link_contig_cov.div(abun_split_scale)
-
-        with atomic_write(os.path.join(out, '{}_data_split_cov.csv'.format(bam_name)), overwrite=True) as ofile:
+        with atomic_write(os.path.join(out, f'{bam_name}_data_split_cov.csv'), overwrite=True) as ofile:
             must_link_contig_cov.to_csv(ofile)

     return bam_file
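
The two atomic_write changes above are pure f-string conversions; the file name they build is unchanged. A quick sketch of the equivalence (the bam_name value is made up):

    bam_name = 'input.sorted.bam_0'  # illustrative value

    old_style = '{}_data_cov.csv'.format(bam_name)  # str.format(), before
    new_style = f'{bam_name}_data_cov.csv'          # f-string, after

    assert old_style == new_style == 'input.sorted.bam_0_data_cov.csv'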

SemiBin/main.py (41 changes: 22 additions & 19 deletions)

@@ -1,20 +1,23 @@
 import argparse
 import logging
 import os
-from multiprocessing.pool import Pool
+from os import path
+import multiprocessing as mp
 import subprocess
 from .atomicwrite import atomic_write
 import shutil
 import sys
 from itertools import groupby
 from . import utils
 from .utils import validate_normalize_args, get_must_link_threshold, generate_cannot_link, \
-        set_random_seed, process_fasta, split_data, get_model_path, extract_bams
+        set_random_seed, process_fasta, split_data, get_model_path, maybe_crams2bams
 from .generate_coverage import generate_cov, combine_cov, generate_cov_from_abundances
 from .generate_kmer import generate_kmer_features_from_fasta
 from .fasta import fasta_iter
 from .semibin_version import __version__

+Pool = mp.get_context('spawn').Pool
+

 def parse_args(args, is_semibin2):
     # BooleanOptionalAction is available in Python 3.9; before that, we fall back on the default
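
The BooleanOptionalAction comment in the last context line above refers to a version-guard pattern along these lines (a sketch only: the actual fallback code is outside this diff, and the --compression flag is a made-up example):

    import argparse

    parser = argparse.ArgumentParser()
    if hasattr(argparse, 'BooleanOptionalAction'):  # Python >= 3.9
        # one declaration yields both --compression and --no-compression
        parser.add_argument('--compression', default=True,
                            action=argparse.BooleanOptionalAction)
    else:  # older Python: fall back on a plain negated flag
        parser.add_argument('--no-compression', dest='compression',
                            default=True, action='store_false')

    print(parser.parse_args(['--no-compression']).compression)  # False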
@@ -725,9 +728,9 @@ def predict_taxonomy(logger, contig_fasta, cannot_name,
                 os.path.join(tdir, 'contig_DB')],
             stdout=None,
             )
-    except:
+    except Exception as e:
         sys.stderr.write(
-            f"Error: Running mmseqs createdb fail\n")
+            f"Error: Running mmseqs createdb failed (error: {e})\n")
         sys.exit(1)
     if os.path.exists(os.path.join(output, 'mmseqs_contig_annotation')):
         shutil.rmtree(os.path.join(output, 'mmseqs_contig_annotation'))
@@ -746,9 +749,9 @@
             check=True,
             stdout=None,
             )
-    except:
+    except Exception as e:
         sys.stderr.write(
-            f"Error: Running mmseqs taxonomy fail\n")
+            f"Error: Running mmseqs taxonomy failed (error: {e})\n")
         sys.exit(1)
     taxonomy_results_fname = os.path.join(output,
                             'mmseqs_contig_annotation',
@@ -785,7 +788,7 @@ def generate_sequence_features_single(logger, contig_fasta,

     if bams is None and abundances is None and not only_kmer:
         sys.stderr.write(
-            f"Error: You need to specify input BAM files or abundance files if you want to calculate coverage features.\n")
+            f"Error: You need to specify input BAM files or abundance files to calculate coverage features.\n")
         sys.exit(1)

     if (bams is not None or abundances is not None) and only_kmer:
@@ -891,7 +894,7 @@ def generate_sequence_features_multi(logger, args):

     # Generate contig file for every sample
     sample_list = []
-    contig_length_list = []
+    contig_lengths = []

     os.makedirs(os.path.join(args.output, 'samples'), exist_ok=True)

@@ -904,16 +907,16 @@ def fasta_sample_iter(fn):
             yield sample_name, contig_name, seq

     for sample_name, contigs in groupby(fasta_sample_iter(args.contig_fasta), lambda sn_cn_seq : sn_cn_seq[0]):
-        with open(os.path.join(args.output, 'samples', '{}.fa'.format(sample_name)), 'wt') as out:
+        with utils.possibly_compressed_write(os.path.join(args.output, 'samples', f'{sample_name}.fa')) as out:
             for _, contig_name, seq in contigs:
                 out.write(f'>{contig_name}\n{seq}\n')
-                contig_length_list.append(len(seq))
+                contig_lengths.append(len(seq))
         sample_list.append(sample_name)
     if len(sample_list) != len(set(sample_list)):
         logger.error(f'Concatenated FASTA file {args.contig_fasta} not in expected format. Samples should follow each other.')
         sys.exit(1)

-    must_link_threshold = get_must_link_threshold(contig_length_list) if args.ml_threshold is None else args.ml_threshold
+    must_link_threshold = get_must_link_threshold(contig_lengths) if args.ml_threshold is None else args.ml_threshold
     binning_threshold = {}
     for sample in sample_list:
         binning_threshold[sample] = utils.compute_min_length(
@@ -943,16 +946,16 @@ def fasta_sample_iter(fn):
         logger.info(f'Processed: {s}')

     for bam_index, bam_file in enumerate(args.bams):
-        if not os.path.exists(os.path.join(os.path.join(args.output, 'samples'), '{}_data_cov.csv'.format(
-                os.path.split(bam_file)[-1] + '_{}'.format(bam_index)))):
+        if not path.exists(path.join(args.output, 'samples',
+                f'{path.split(bam_file)[-1]}_{bam_index}_data_cov.csv')):
             sys.stderr.write(
-                f"Error: Generating coverage file fail\n")
+                f"Error: Generating coverage file failed (for BAM file {bam_file})\n")
             sys.exit(1)
         if is_combined:
-            if not os.path.exists(os.path.join(os.path.join(args.output, 'samples'), '{}_data_split_cov.csv'.format(
-                    os.path.split(bam_file)[-1] + '_{}'.format(bam_index)))):
+            if not path.exists(path.join(args.output, 'samples',
+                    f'{path.split(bam_file)[-1]}_{bam_index}_data_split_cov.csv')):
                 sys.stderr.write(
-                    f"Error: Generating coverage file fail\n")
+                    f"Error: Generating split coverage file failed (for BAM file {bam_file})\n")
                 sys.exit(1)

     # Generate cov features for every sample
@@ -1471,8 +1474,8 @@ def main2(raw_args=None, is_semibin2=True):
     set_random_seed(args.random_seed)

     with tempfile.TemporaryDirectory() as tdir:
-        if hasattr(args, 'bams'):
-            args.bams = extract_bams(args.bams, args.contig_fasta, args.num_process, tdir)
+        if hasattr(args, 'bams') and args.bams is not None:
+            args.bams = maybe_crams2bams(args.bams, args.contig_fasta, args.num_process, tdir)

         if args.cmd in ['generate_cannot_links', 'generate_sequence_features_single', 'bin','single_easy_bin', 'bin_long']:
             binned_short, must_link_threshold, contig_dict = process_fasta(args.contig_fasta, args.ratio)

SemiBin/utils.py (24 changes: 21 additions & 3 deletions)

@@ -545,6 +545,8 @@ def process_fasta(fasta_path, ratio):
     binned_short = contig_bp_2500 / whole_contig_bp < ratio
     must_link_threshold = get_must_link_threshold(contig_length_list)
     if not contig_dict:
+        import logging
+        logger = logging.getLogger('SemiBin2')
         logger.warning(f'No contigs in {fasta_path}')
     return binned_short, must_link_threshold, contig_dict

@@ -627,11 +629,26 @@ def n50_l50(sizes):
     return n50, l50+1


-def extract_bams(bams, contig_fasta : str, num_process : int, odir : str): # bams : list[str] is not available on Python 3.7
+def maybe_crams2bams(bams, contig_fasta : str, num_process : int, odir : str): # bams : list[str] is not available on Python 3.7
     '''
-    extract_bams converts CRAM to BAM
+    maybe_crams2bams converts CRAM to BAM
+
+    Parameters
+    ----------
+    bams : list of str
+        List of BAM/CRAM files
+    contig_fasta : str
+        Contig FASTA file
+    num_process : int
+        Number of processes to use
+    odir : str
+        Output directory for extracted BAM files
+
+    Returns
+    -------
+    obams : list of str
+        List of BAM files (extracted if CRAM or original)
     '''
-    if bams is None: return None
     rs = []
     for bam in bams:
         if bam.endswith('.cram'):
@@ -649,6 +666,7 @@ def extract_bams(bams, contig_fasta : str, num_process : int, odir : str): # bams : list[str] is not available on Python 3.7
             rs.append(bam)
     return rs

+
 def compute_min_length(min_length, fafile, ratio):
     if min_length is not None: return min_length
     binned_short ,_ ,_ = process_fasta(fafile, ratio)
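
Given the docstring added above, a usage sketch for maybe_crams2bams (the file names are illustrative; the signature is the documented one):

    import tempfile
    from SemiBin.utils import maybe_crams2bams

    with tempfile.TemporaryDirectory() as tdir:
        # CRAM inputs are converted to BAM files inside tdir (decoding CRAM
        # requires the reference FASTA); BAM inputs are passed through as-is.
        bams = maybe_crams2bams(['sample1.cram', 'sample2.bam'],
                                'contigs.fasta', 4, tdir)
        # bams[0] is a newly written .bam under tdir; bams[1] is 'sample2.bam'

Note that the None guard moved to the caller: main.py now checks args.bams is not None before calling, which is why the if bams is None: return None line and its test case were dropped.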

test/test_utils.py (10 changes: 3 additions & 7 deletions)

@@ -1,4 +1,4 @@
-from SemiBin.utils import get_must_link_threshold, get_marker, split_data, n50_l50, extract_bams, norm_abundance
+from SemiBin.utils import get_must_link_threshold, get_marker, split_data, n50_l50, maybe_crams2bams, norm_abundance
 from hypothesis import given, strategies as st
 from io import StringIO
 import numpy as np
@@ -122,18 +122,14 @@ def test_n50_l50(sizes):

 def test_extract_bams(tmpdir):
     single_sample_input = 'test/single_sample_data'
-    rs = extract_bams([f'{single_sample_input}/input.cram'],
+    rs = maybe_crams2bams([f'{single_sample_input}/input.cram'],
            f'{single_sample_input}/input.fasta',
            2,
            tmpdir)
     assert len(rs) == 1
     assert rs[0].endswith('.bam')
     assert rs[0].startswith(str(tmpdir))
-    rs = extract_bams(None,
-           f'{single_sample_input}/input.fasta',
-           2,
-           tmpdir)
-    assert rs is None


 def test_norm_abundance():
     assert not norm_abundance(np.random.randn(10, 136))
