From 16e145ad3b32610ed88f9682eb260eaf48645b6b Mon Sep 17 00:00:00 2001 From: Luis Pedro Coelho Date: Sun, 15 Sep 2024 22:05:48 +1000 Subject: [PATCH] RFCT Simplify code by normalizing args --- SemiBin/main.py | 26 +++++++++----------------- SemiBin/utils.py | 16 +++++++++++----- test/test_train.py | 4 ++-- 3 files changed, 22 insertions(+), 24 deletions(-) diff --git a/SemiBin/main.py b/SemiBin/main.py index f99656f..43db35e 100644 --- a/SemiBin/main.py +++ b/SemiBin/main.py @@ -1078,7 +1078,7 @@ def fasta_sample_iter(fn): def training(logger, contig_fasta, data, data_split, cannot_link, *, output, device, mode, - args, training_type='semi'): + args): """ Training the model @@ -1098,7 +1098,7 @@ def training(logger, contig_fasta, logger.info('Start training from multiple samples.') is_combined = False - if training_type == 'semi': + if args.training_type == 'semi': binned_lengths = [] for fafile in contig_fasta: binned_lengths.append( @@ -1271,20 +1271,17 @@ def single_easy_binning(logger, args, binned_length, fasta = [args.contig_fasta] cannot_link = [os.path.join(args.output, 'cannot', 'cannot.txt')] - training_type = 'semi' else: fasta = None cannot_link = None - training_type = 'self' training(logger, fasta, [data_path], [data_split_path], cannot_link=cannot_link, output=args.output, device=device, mode='single', - args=args, - training_type=training_type) + args=args) binning_kwargs = { 'logger': logger, @@ -1347,18 +1344,15 @@ def multi_easy_binning(logger, args, device): sample_fasta = [sample_fasta] sample_cannot = [os.path.join( args.output, 'samples', sample, 'cannot', f'{sample}.txt')] - training_type = 'semi' else: sample_fasta = None sample_cannot = None - training_type = 'self' training(logger, sample_fasta, [sample_data], [sample_data_split], sample_cannot, output=os.path.join(args.output, 'samples', sample), device=device, mode='single', - args=args, - training_type=training_type) + args=args) binning_kwargs = { 'logger': logger, @@ -1476,7 +1470,7 @@ def main2(raw_args=None, is_semibin2=True): print(citation.CHICAGO) print(f'\nUse `SemiBin2 citation --help` to see all available citation formats') sys.exit(0) - if args.cmd in ['single_easy_bin', 'multi_easy_bin', 'train', 'train_semi', 'bin', 'train_self', 'bin_long']: + if args.cmd in ['single_easy_bin', 'multi_easy_bin', 'train_semi', 'bin', 'train_self', 'bin_long']: import torch if args.engine == 'cpu': device = torch.device("cpu") @@ -1488,7 +1482,7 @@ def main2(raw_args=None, is_semibin2=True): logger.info('Running with GPU.') else: device = torch.device("cpu") - logger.info('Did not detect GPU, using CPU.') + logger.warning('Did not detect GPU or CUDA was not installed/supported, using CPU.') if getattr(args, 'tmpdir', None) is not None: os.environ['TMPDIR'] = args.tmpdir @@ -1555,7 +1549,7 @@ def main2(raw_args=None, is_semibin2=True): elif args.cmd == 'generate_sequence_features_multi': generate_sequence_features_multi(logger, args) - elif args.cmd in ['train', 'train_semi']: + elif args.cmd == 'train_semi': training(logger, contig_fasta=args.contig_fasta, data=args.data, @@ -1564,8 +1558,7 @@ def main2(raw_args=None, is_semibin2=True): output=args.output, device=device, mode=args.mode, - args=args, - training_type='semi') + args=args) elif args.cmd == 'train_self': training(logger, @@ -1576,8 +1569,7 @@ def main2(raw_args=None, is_semibin2=True): output=args.output, device=device, mode=args.mode, - args=args, - training_type='self') + args=args) elif args.cmd == 'bin': diff --git a/SemiBin/utils.py b/SemiBin/utils.py index 0088df1..84bef04 100644 --- a/SemiBin/utils.py +++ b/SemiBin/utils.py @@ -99,8 +99,14 @@ def expect_file_list(fs): os.environ['NUMEXPR_MAX_THREADS'] = str(args.num_process) os.environ['OMP_NUM_THREADS'] = str(args.num_process) + if args.cmd in ['train', 'train_semi']: + args.cmd = 'train_semi' + args.training_type = 'semi' + + if args.cmd == 'train_self': + args.training_type = 'self' - if args.cmd in ['single_easy_bin', 'multi_easy_bin', 'train', 'bin']: + if args.cmd in ['single_easy_bin', 'multi_easy_bin', 'train_semi', 'bin']: if args.orf_finder not in ['prodigal', 'fraggenescan', 'fast-naive']: sys.stderr.write( f"Error: SemiBin only supports 'prodigal'/'fraggenescan'/'fast-naive' as the ORF finder (--orf-finder option).\n") @@ -144,7 +150,7 @@ def expect_file_list(fs): if args.abundances: expect_file_list(args.abundances) - if args.cmd in ['train', 'train_semi', 'train_self']: + if args.cmd in ['train_semi', 'train_self']: if not args.train_from_many: if len(args.data) > 1: sys.stderr.write( @@ -155,7 +161,7 @@ def expect_file_list(fs): sys.stderr.write( f"Error: Expected one data_split.csv file with single mode.\n") exit_with_error = True - if args.cmd in ['train_semi', 'train']: + if args.cmd == 'train_semi': if len(args.contig_fasta) > 1: sys.stderr.write( f"Error: Expected one fasta file with single mode.\n") @@ -172,7 +178,7 @@ def expect_file_list(fs): expect_file(args.data_split[0]) else: - if args.cmd in ['train_semi', 'train']: + if args.cmd == 'train_semi': assert len(args.contig_fasta) == len(args.data) == len(args.data_split) == len(args.cannot_link), 'Must input same number of fasta, data, data_split, cannot files!' expect_file_list(args.cannot_link) expect_file_list(args.contig_fasta) @@ -252,7 +258,7 @@ def expect_file_list(fs): if getattr(args, 'train_from_many', False): args.mode = 'several' - elif args.cmd in ['train', 'train_semi', 'train_self'] and not hasattr(args, 'mode'): + elif args.cmd in ['train_semi', 'train_self'] and not hasattr(args, 'mode'): args.mode = 'single' if getattr(args, 'write_pre_reclustering_bins', False) and \ diff --git a/test/test_train.py b/test/test_train.py index 9e6e468..0961c5b 100644 --- a/test/test_train.py +++ b/test/test_train.py @@ -22,6 +22,7 @@ def test_train(tmpdir): os.makedirs(odir) ofile = f'{odir}/model.h5' + args.training_type = 'semi' training(contig_fasta = ['test/train_data/input.fasta'], data = ['test/train_data/data.csv'], data_split = ['test/train_data/data_split.csv'], @@ -31,7 +32,6 @@ def test_train(tmpdir): device = 'cpu', mode = 'single', args = args, - training_type='semi' ) assert os.path.exists(ofile) @@ -40,6 +40,7 @@ def test_train_self(tmpdir): contig_dict = {h:seq for h,seq in fasta_iter('test/train_data/input.fasta')} odir = f'{tmpdir}/output_train_self' os.makedirs(odir) + args.training_type = 'self' training(contig_fasta = ['test/train_data/input.fasta'], data = ['test/train_data/data.csv'], data_split = ['test/train_data/data_split.csv'], @@ -49,7 +50,6 @@ def test_train_self(tmpdir): device = 'cpu', mode = 'single', args = args, - training_type='self' ) assert os.path.exists(f'{odir}/model.h5')