diff --git a/lstm_train.py b/lstm_train.py index 7c6a0a6..393f29d 100644 --- a/lstm_train.py +++ b/lstm_train.py @@ -27,7 +27,7 @@ SEQUENCE_LEN = 10 MIN_WORD_FREQUENCY = 10 STEP = 1 -BATCH_SIZE = 128 +BATCH_SIZE = 32 SIMPLE_MODEL = True @@ -161,8 +161,8 @@ def on_epoch_end(epoch, logs): corpus = sys.argv[1] examples = sys.argv[2] - if not os.path.isdir('./checkpoints/s15/'): - os.makedirs('./checkpoints/s15/') + if not os.path.isdir('./checkpoints/'): + os.makedirs('./checkpoints/') with io.open(corpus, encoding='utf-8') as f: text = f.read().lower().replace('\n', ' \n ') @@ -210,7 +210,7 @@ def on_epoch_end(epoch, logs): model = get_model(SIMPLE_MODEL) model.compile(loss='categorical_crossentropy', optimizer="adam", metrics=['accuracy']) - file_path = "./checkpoints/s15/LSTM_LYRICS_words%d_sequence%d_simple%r_minfreq%d_epoch{epoch:02d}_loss{loss:.4f}" % ( + file_path = "./checkpoints/LSTM_LYRICS_words%d_sequence%d_simple%r_minfreq%d_epoch{epoch:02d}_loss{loss:.4f}" % ( len(words), SEQUENCE_LEN, SIMPLE_MODEL,