
RFCT Simplify code by normalizing args
luispedro committed Sep 15, 2024
1 parent 4ae1b09 commit 16e145a
Showing 3 changed files with 22 additions and 24 deletions.
26 changes: 9 additions & 17 deletions SemiBin/main.py
@@ -1078,7 +1078,7 @@ def fasta_sample_iter(fn):
 def training(logger, contig_fasta,
         data, data_split, cannot_link,
         *, output, device, mode,
-        args, training_type='semi'):
+        args):
     """
     Training the model
@@ -1098,7 +1098,7 @@ def training(logger, contig_fasta,
         logger.info('Start training from multiple samples.')
         is_combined = False

-    if training_type == 'semi':
+    if args.training_type == 'semi':
         binned_lengths = []
         for fafile in contig_fasta:
             binned_lengths.append(
@@ -1271,20 +1271,17 @@ def single_easy_binning(logger, args, binned_length,
         fasta = [args.contig_fasta]
         cannot_link = [os.path.join(args.output, 'cannot', 'cannot.txt')]

-        training_type = 'semi'

     else:
         fasta = None
         cannot_link = None
-        training_type = 'self'
     training(logger, fasta,
             [data_path], [data_split_path],
             cannot_link=cannot_link,
             output=args.output,
             device=device,
             mode='single',
-            args=args,
-            training_type=training_type)
+            args=args)

     binning_kwargs = {
         'logger': logger,
@@ -1347,18 +1344,15 @@ def multi_easy_binning(logger, args, device):
             sample_fasta = [sample_fasta]
             sample_cannot = [os.path.join(
                 args.output, 'samples', sample, 'cannot', f'{sample}.txt')]
-            training_type = 'semi'
         else:
             sample_fasta = None
             sample_cannot = None
-            training_type = 'self'
         training(logger, sample_fasta,
                 [sample_data], [sample_data_split], sample_cannot,
                 output=os.path.join(args.output, 'samples', sample),
                 device=device,
                 mode='single',
-                args=args,
-                training_type=training_type)
+                args=args)

         binning_kwargs = {
             'logger': logger,
@@ -1476,7 +1470,7 @@ def main2(raw_args=None, is_semibin2=True):
             print(citation.CHICAGO)
         print(f'\nUse `SemiBin2 citation --help` to see all available citation formats')
         sys.exit(0)
-    if args.cmd in ['single_easy_bin', 'multi_easy_bin', 'train', 'train_semi', 'bin', 'train_self', 'bin_long']:
+    if args.cmd in ['single_easy_bin', 'multi_easy_bin', 'train_semi', 'bin', 'train_self', 'bin_long']:
         import torch
         if args.engine == 'cpu':
             device = torch.device("cpu")
@@ -1488,7 +1482,7 @@ def main2(raw_args=None, is_semibin2=True):
                 logger.info('Running with GPU.')
             else:
                 device = torch.device("cpu")
-                logger.info('Did not detect GPU, using CPU.')
+                logger.warning('Did not detect GPU or CUDA was not installed/supported, using CPU.')

     if getattr(args, 'tmpdir', None) is not None:
         os.environ['TMPDIR'] = args.tmpdir
@@ -1555,7 +1549,7 @@ def main2(raw_args=None, is_semibin2=True):
     elif args.cmd == 'generate_sequence_features_multi':
         generate_sequence_features_multi(logger, args)

-    elif args.cmd in ['train', 'train_semi']:
+    elif args.cmd == 'train_semi':
         training(logger,
             contig_fasta=args.contig_fasta,
             data=args.data,
@@ -1564,8 +1558,7 @@ def main2(raw_args=None, is_semibin2=True):
             output=args.output,
             device=device,
             mode=args.mode,
-            args=args,
-            training_type='semi')
+            args=args)

     elif args.cmd == 'train_self':
         training(logger,
@@ -1576,8 +1569,7 @@ def main2(raw_args=None, is_semibin2=True):
             output=args.output,
             device=device,
             mode=args.mode,
-            args=args,
-            training_type='self')
+            args=args)


     elif args.cmd == 'bin':
16 changes: 11 additions & 5 deletions SemiBin/utils.py
@@ -99,8 +99,14 @@ def expect_file_list(fs):
         os.environ['NUMEXPR_MAX_THREADS'] = str(args.num_process)
         os.environ['OMP_NUM_THREADS'] = str(args.num_process)

+    if args.cmd in ['train', 'train_semi']:
+        args.cmd = 'train_semi'
+        args.training_type = 'semi'
+
+    if args.cmd == 'train_self':
+        args.training_type = 'self'

-    if args.cmd in ['single_easy_bin', 'multi_easy_bin', 'train', 'bin']:
+    if args.cmd in ['single_easy_bin', 'multi_easy_bin', 'train_semi', 'bin']:
         if args.orf_finder not in ['prodigal', 'fraggenescan', 'fast-naive']:
             sys.stderr.write(
                 f"Error: SemiBin only supports 'prodigal'/'fraggenescan'/'fast-naive' as the ORF finder (--orf-finder option).\n")
@@ -144,7 +150,7 @@ def expect_file_list(fs):
     if args.abundances:
         expect_file_list(args.abundances)

-    if args.cmd in ['train', 'train_semi', 'train_self']:
+    if args.cmd in ['train_semi', 'train_self']:
         if not args.train_from_many:
             if len(args.data) > 1:
                 sys.stderr.write(
@@ -155,7 +161,7 @@ def expect_file_list(fs):
                 sys.stderr.write(
                     f"Error: Expected one data_split.csv file with single mode.\n")
                 exit_with_error = True
-            if args.cmd in ['train_semi', 'train']:
+            if args.cmd == 'train_semi':
                 if len(args.contig_fasta) > 1:
                     sys.stderr.write(
                         f"Error: Expected one fasta file with single mode.\n")
@@ -172,7 +178,7 @@ def expect_file_list(fs):
                 expect_file(args.data_split[0])

         else:
-            if args.cmd in ['train_semi', 'train']:
+            if args.cmd == 'train_semi':
                 assert len(args.contig_fasta) == len(args.data) == len(args.data_split) == len(args.cannot_link), 'Must input same number of fasta, data, data_split, cannot files!'
                 expect_file_list(args.cannot_link)
                 expect_file_list(args.contig_fasta)
@@ -252,7 +258,7 @@ def expect_file_list(fs):

     if getattr(args, 'train_from_many', False):
         args.mode = 'several'
-    elif args.cmd in ['train', 'train_semi', 'train_self'] and not hasattr(args, 'mode'):
+    elif args.cmd in ['train_semi', 'train_self'] and not hasattr(args, 'mode'):
         args.mode = 'single'

     if getattr(args, 'write_pre_reclustering_bins', False) and \
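
The utils.py hunk above is where the normalization happens: the legacy `train` command is folded into `train_semi` and `args.training_type` is recorded once, so the rest of the code only needs to inspect `args`. A minimal, self-contained sketch of that pattern (the helper name `normalize_training_args` and the bare `Namespace` are illustrative, not SemiBin's actual API):

import argparse

def normalize_training_args(args):
    # Fold the legacy 'train' alias into the canonical 'train_semi' command
    # and record the training mode on the namespace itself.
    if args.cmd in ['train', 'train_semi']:
        args.cmd = 'train_semi'
        args.training_type = 'semi'
    if args.cmd == 'train_self':
        args.training_type = 'self'
    return args

# The alias is rewritten once, up front; downstream code never sees 'train'.
args = normalize_training_args(argparse.Namespace(cmd='train'))
print(args.cmd, args.training_type)  # -> train_semi semi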
4 changes: 2 additions & 2 deletions test/test_train.py
@@ -22,6 +22,7 @@ def test_train(tmpdir):
     os.makedirs(odir)
     ofile = f'{odir}/model.h5'

+    args.training_type = 'semi'
     training(contig_fasta = ['test/train_data/input.fasta'],
             data = ['test/train_data/data.csv'],
             data_split = ['test/train_data/data_split.csv'],
@@ -31,7 +32,6 @@ def test_train(tmpdir):
             device = 'cpu',
             mode = 'single',
             args = args,
-            training_type='semi'
             )

     assert os.path.exists(ofile)
@@ -40,6 +40,7 @@ def test_train_self(tmpdir):
     contig_dict = {h:seq for h,seq in fasta_iter('test/train_data/input.fasta')}
     odir = f'{tmpdir}/output_train_self'
     os.makedirs(odir)
+    args.training_type = 'self'
     training(contig_fasta = ['test/train_data/input.fasta'],
             data = ['test/train_data/data.csv'],
             data_split = ['test/train_data/data_split.csv'],
@@ -49,7 +50,6 @@ def test_train_self(tmpdir):
             device = 'cpu',
             mode = 'single',
             args = args,
-            training_type='self'
             )

     assert os.path.exists(f'{odir}/model.h5')
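
As the test changes show, callers that invoke `training()` directly now set `training_type` on the args namespace instead of passing a keyword argument. A rough illustration with a stand-in function (not SemiBin's real `training()`, which also needs data files, an output directory, and a device):

from argparse import Namespace

def train(args):
    # Stand-in: like the refactored training(), it reads the mode from args
    # rather than from a separate training_type parameter.
    if args.training_type == 'semi':
        return 'semi-supervised training'
    return 'self-supervised training'

args = Namespace(cmd='train_self', training_type='self')
print(train(args))  # -> self-supervised training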
