diff --git a/scaden/__main__.py b/scaden/__main__.py index e3deeba..580d27b 100644 --- a/scaden/__main__.py +++ b/scaden/__main__.py @@ -56,9 +56,11 @@ def cli(): type=click.Path(exists=True), required=True, metavar='') -@click.option('--train_datasets', - default='', - help='Datasets used for training. Uses all by default.') +@click.option( + '--train_datasets', + default='', + help= + 'Comma-separated list of datasets used for training. Uses all by default.') @click.option('--model_dir', default="./", help='Path to store the model in') @click.option('--batch_size', default=128, diff --git a/scaden/train.py b/scaden/train.py index 1d32d24..1694b54 100644 --- a/scaden/train.py +++ b/scaden/train.py @@ -48,7 +48,7 @@ def training(data_path, if train_datasets == '': train_datasets = [] else: - train_datasets = train_datasets.split() + train_datasets = train_datasets.split(',') print(f"Training on: {train_datasets}") # Training of M256 model