Cleaning up, changing names.
Jesse Myrberg committed Jul 27, 2017
1 parent f206e33 commit 393358e
Showing 7 changed files with 282 additions and 88 deletions.
35 changes: 32 additions & 3 deletions README.md
@@ -2,6 +2,35 @@

## Training steps
### 1. Fit a dictionary
python -m dict_train.py \
--dict-save-path './data/dicts/lemmatizer.dict' \
--dict-train-path './data/train
python -m dict_train ^
--dict-save-path ./data/dictionaries/lemmatizer.dict ^
--dict-train-path ./data/dictionaries/lemmatizer.vocab ^
--vocab-size 50 ^
--min-freq 0.0 ^
--max-freq 1.0 ^
--file-batch-size 8192 ^
--prune-every-n 200
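
For orientation, a minimal sketch of what this step does programmatically, based on src/dict_train.py as changed in this commit (the Dictionary constructor arguments and the save call are assumptions; paths are illustrative):

```python
# Sketch only: mirrors the fitting loop in src/dict_train.py after this commit.
from dictionary import Dictionary
from model_wrappers import doc_to_tokens
from data_utils import read_files_batched
from utils import get_path_files

files = get_path_files('./data/dictionaries/lemmatizer.vocab')       # training file(s)
model_dict = Dictionary(vocab_size=50, min_freq=0.0, max_freq=1.0)   # assumed constructor

# Stream the training files in batches and fit the dictionary incrementally.
train_gen = read_files_batched(files,
                               file_batch_size=8192,
                               file_batch_shuffle=False,
                               return_mode='array')
for docs in train_gen:
    tokens = [doc_to_tokens(doc) for doc in docs.flatten()]
    model_dict.fit_batch(tokens, prune_every_n=200)

model_dict.save('./data/dictionaries/lemmatizer.dict')                # assumed save method
```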

### 2. Create and train a new model
python -m model_train ^
--model-dir ./data/models/lemmatizer2 ^
--dict-path ./data/dictionaries/lemmatizer.dict ^
--train-data-path ./data/datasets/lemmatizer_train.csv ^
--optimizer 'adam' ^
--learning-rate 0.0001 ^
--dropout-rate 0.2 ^
--batch-size 128 ^
--file-batch-size 8192 ^
--max-file-pool-size 50 ^
--shuffle-files True ^
--shuffle-file-batches True ^
--save-every-n-batch 500 ^
--validate-every-n-batch 100 ^
--validation-data-path ./data/datasets/lemmatizer_validation.csv ^
--validate-n-rows 5000
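
Under the hood this maps onto the training loop in src/model_train.py; a condensed sketch (the Seq2Seq constructor arguments are assumptions, the other calls appear elsewhere in this commit):

```python
# Sketch of the core training loop from src/model_train.py (simplified).
from model_wrappers import Seq2Seq
from utils import get_path_files
from data_utils import read_files_cycled, rebatch

model = Seq2Seq(model_dir='./data/models/lemmatizer2',
                dict_path='./data/dictionaries/lemmatizer.dict')   # assumed constructor

# Cycle over the training files and regroup file batches into model batches.
files = get_path_files('./data/datasets/lemmatizer_train.csv')
file_gen = read_files_cycled(filenames=files,
                             max_file_pool_size=50,
                             file_batch_size=8192,
                             file_batch_shuffle=True)
train_gen = rebatch(file_gen,
                    in_batch_size_limit=8192 * 50,
                    out_batch_size=128,
                    shuffle=True,
                    flatten=True)

for batch_nb, batch in enumerate(train_gen):
    source_docs, target_docs = zip(*batch)
    loss, global_step = model.train(source_docs, target_docs,
                                    max_seq_len=100,
                                    save_every_n_batch=500)
```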

### 3. Make predictions on test set
python -m model_decode ^
--model-dir ./data/models/lemmatizer ^
--source-data-path ./data/datasets/lemmatizer_test.csv ^
--decoded-data-path ./data/decoded/lemmatizer_decoded_1.csv
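
Decoding can also be driven through the Seq2Seq wrapper, which returns one list of beam candidates per input document (see src/model_decode.py below); a minimal sketch with assumed constructor arguments and illustrative inputs:

```python
# Sketch: programmatic decoding with the Seq2Seq wrapper.
from model_wrappers import Seq2Seq

model = Seq2Seq(model_dir='./data/models/lemmatizer')   # assumed constructor

source_docs = ['koirien', 'kissoille']                  # illustrative inputs
decoded_docs = model.decode(source_docs)                # list of lists, beams as elements

for source, beams in zip(source_docs, decoded_docs):
    print(source, '->', beams[0])                       # beams[0] assumed to be the top candidate
```
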
5 changes: 5 additions & 0 deletions src/data_utils.py
@@ -73,6 +73,11 @@ def rebatch(batches,
np.random.shuffle(in_batches)
yield out_batch

def read_file(filename, nrows=None):
"""Read one file entirely."""
ar = pd.read_csv(filename, nrows=nrows).values
return ar

def read_file_batched(filename,
file_batch_size=8192,
file_batch_shuffle=False,
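
The new read_file helper loads a small file fully into memory (src/model_train.py below uses it for the validation set); a usage sketch with an illustrative path:

```python
# Sketch: read a CSV entirely into a numpy array via the new helper.
from data_utils import read_file

valid_data = read_file('./data/datasets/lemmatizer_validation.csv', nrows=5000)
# Assuming two-column (source, target) rows, unpack into parallel sequences:
valid_source_docs, valid_target_docs = zip(*valid_data)
```
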
12 changes: 6 additions & 6 deletions src/dict_train.py
@@ -5,7 +5,7 @@
import argparse

from dictionary import Dictionary
from seq2seq import doc_to_tokens
from model_wrappers import doc_to_tokens
from utils import get_path_files
from data_utils import read_files_batched

@@ -25,10 +25,10 @@
type=int, action='store',
help='Size of vocabulary')
parser.add_argument("--min-freq", default=0.0,
type=int, action='store',
type=float, action='store',
help='Minimum word frequency')
parser.add_argument("--max-freq", default=1.0,
type=int, action='store',
type=float, action='store',
help='Maximum word frequency')

# Training params (optional)
Expand All @@ -54,12 +54,12 @@ def train_dict(args):
# Batch generator
train_gen = read_files_batched(files,
file_batch_size=args.file_batch_size,
file_batch_shuffle=False)
file_batch_shuffle=False,
return_mode='array')

# Fit dictionary in batches
for docs in train_gen:
long_doc = " ".join(docs.flatten())
tokens = [[token] for token in doc_to_tokens(long_doc)]
tokens = [doc_to_tokens(doc) for doc in docs.flatten()]
model_dict.fit_batch(tokens, prune_every_n=args.prune_every_n)

# Save dict
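
The tokenization change above means each document is now tokenized on its own instead of the whole batch being joined into a single string first; an illustrative sketch (the exact doc_to_tokens output is an assumption):

```python
# Illustration of the changed token structure passed to Dictionary.fit_batch.
from model_wrappers import doc_to_tokens

docs = ['koirien juoksevat', 'kissoille']

# Old behaviour: join the batch, then wrap every token in its own list, e.g.
#   [['koirien'], ['juoksevat'], ['kissoille']]
# New behaviour: one token list per document, e.g.
#   [['koirien', 'juoksevat'], ['kissoille']]
tokens = [doc_to_tokens(doc) for doc in docs]
```
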
39 changes: 31 additions & 8 deletions src/decode.py → src/model_decode.py
@@ -3,9 +3,12 @@


import argparse
import os

from seq2seq import Seq2Seq
from utils import get_path_files
from datetime import datetime

from model_wrappers import Seq2Seq
from utils import get_path_files, create_folder
from data_utils import read_files_cycled, rebatch


@@ -59,53 +62,73 @@ def decode_model(args):
filenames=files,
max_file_pool_size=args.max_file_pool_size,
file_batch_size=args.file_batch_size,
file_batch_shuffle=args.shuffle_file_batches)
file_batch_shuffle=False)

# Decode batches
decode_gen = rebatch(
file_gen,
in_batch_size_limit=args.file_batch_size*args.max_file_pool_size,
out_batch_size=args.batch_size,
shuffle=args.shuffle_file_batches,
shuffle=False,
flatten=True)

# Decode
write_target = False
for batch_nb,batch in enumerate(decode_gen):

print('Batch number {}'.format(batch_nb))

if batch_nb == 0:

# Number of columns in batch
n_cols = len(batch[0])
print(n_cols)
if n_cols == 1:
source_docs = batch
elif n_cols == 2:
source_docs,target_docs = batch
source_docs,target_docs = zip(*batch)
write_target = True
else:
raise ValueError("Number of columns found %d not in [1,2]" \
% n_cols)

# Output file handle and headers
create_folder(os.path.dirname(args.decoded_data_path))
fout = open(args.decoded_data_path, 'w', encoding='utf8')
fout.write('source\t')
if write_target:
fout.write('target\t')
fout.write("\t".join([str(k) for k in args.beam_width])+'\n')
fout.write("\t".join([str(k) for k in range(args.beam_width)]))
fout.write('\n')

# Extra stream
# @TODO: Remove
fout2 = open('.'+args.decoded_data_path.split('.')[-2]+'_CLEAR.csv',
'w',
encoding='utf8')

# Get decoded documents: list of lists, with beams as elements
decoded_docs = model.decode(source_docs)

# Write beams to file
for i in len(decoded_docs):
for i in range(len(decoded_docs)):
fout.write(source_docs[i]+'\t')
if write_target:
fout.write(target_docs[i]+'\t')
for k in args.beam_width:
for k in range(args.beam_width):

if k == 0:
out_fmt = '[{}] Decode example {} {:>50s} --> {:<50s}'\
.format(str(datetime.now()), i, source_docs[i], decoded_docs[i][k])
fout2.write(out_fmt+'\n')

decoded_doc = decoded_docs[i][k]
fout.write(decoded_doc+'\t')
fout.write('\n')

fout.close()
fout2.close()


def main():
args = parser.parse_args()
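
The decoder writes tab-separated output with a source column, an optional target column, and one column per beam index (0 … beam_width-1); data rows end with a trailing tab. A sketch for inspecting the result, assuming the test file contained both source and target columns (path illustrative):

```python
# Sketch: read the decoded TSV written by model_decode.
import csv

with open('./data/decoded/lemmatizer_decoded_1.csv', encoding='utf8') as f:
    reader = csv.reader(f, delimiter='\t')
    header = next(reader)                    # ['source', 'target', '0', '1', ...]
    for row in reader:
        source, target = row[0], row[1]
        beams = [b for b in row[2:] if b]    # drop the empty field from the trailing tab
        print(source, '->', beams[0])
```
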
61 changes: 43 additions & 18 deletions src/train.py → src/model_train.py
@@ -9,9 +9,9 @@

import numpy as np

from seq2seq import Seq2Seq
from utils.utils import get_path_files
from utils.data_utils import read_files_cycled, rebatch
from model_wrappers import Seq2Seq
from utils import get_path_files
from data_utils import read_file, read_files_cycled, rebatch


parser = argparse.ArgumentParser('Model training')
@@ -30,25 +30,40 @@
parser.add_argument("--train-data-path", default=None,
type=str, action='store',
help='Training data folder or file')
parser.add_argument("--validation-data-path", default=None,
type=str, action='store',
help='Validation data folder or file')

# Model params (required for the first time)
# keep_every_n_hours = 1

# Model training params (optional)
parser.add_argument("--max-seq-len", default=100,
type=int, action='store',
help='Maximum sequence length')
parser.add_argument("--optimizer", default='adam',
type=str, action='store',
help='Optimizer to use')
parser.add_argument("--learning-rate", default=0.0001,
type=int, action='store',
type=float, action='store',
help='Learning rate of the optimizer')
parser.add_argument("--dropout-rate", default=0.2,
type=str, action='store',
type=float, action='store',
help='Hidden layer dropout')

# Training params (optional)
parser.add_argument("--batch-size", default=32,
type=int, action='store',
help='Batch size to feed into model')
parser.add_argument("--validate-every-n-batch", default=100,
type=int, action='store',
help='Validate model every n batch')
parser.add_argument("--validate-n-rows", default=None,
type=int, action='store',
help='Number of rows to read from validation file')
parser.add_argument("--validation-batch-size", default=32,
type=bool, action='store',
help='Validation batch size')
parser.add_argument("--file-batch-size", default=8192,
type=int, action='store',
help='Number of rows to read in-memory from each file')
@@ -61,9 +76,9 @@
parser.add_argument("--shuffle-file-batches", default=True,
type=bool, action='store',
help='Shuffle file batches before training')
parser.add_argument("--save-every-n", default=500,
parser.add_argument("--save-every-n-batch", default=1000,
type=int, action='store',
help='Maximum number of files to cycle at a time')
help='Save model checkpoint every n batch')


def train_model(args):
@@ -97,24 +112,34 @@ def train_model(args):
shuffle=args.shuffle_file_batches,
flatten=True)

if args.validation_data_path is not None:
valid_data = read_file(args.validation_data_path,
nrows=args.validate_n_rows)
valid_source_docs,valid_target_docs = zip(*valid_data)

# Train
start = time.clock()
for batch_nb,batch in enumerate(train_gen):
source_docs,target_docs = zip(*batch)
loss,global_step = model.train(source_docs, target_docs,
max_seq_len=args.max_seq_len)
loss,global_step = model.train(
source_docs, target_docs,
max_seq_len=args.max_seq_len,
save_every_n_batch=args.save_every_n_batch)

# Print progress
end = time.clock()
print('[{}] Step: {} - Samples: {} - Loss: {:<.3f} - Time {:<.3f}'.format(
str(datetime.now()), global_step, global_step*args.batch_size,
loss,round(end-start,3)))
samples = global_step*args.batch_size
print('[{}] Training step: {} - Samples: {} - Loss: {:<.3f} - Time {:<.3f}'\
.format(str(datetime.now()), global_step, samples, loss, round(end-start,3)))
start = end

if batch_nb % 512 == 0 and batch_nb > 0:
print('Evaluating...')
print('Source:',source_docs[0])
print('Target:',target_docs[0])
print('Prediction:',model.decode(source_docs[0:1]))

# Validation
if batch_nb % args.validate_every_n_batch == 0 and batch_nb > 0:
loss,global_step = model.eval(valid_source_docs, valid_target_docs)
end = time.clock()
print('[{}] Validation step: {} - Samples: {} - Loss: {:<.3f} - Time {:<.3f}'\
.format(str(datetime.now()), global_step, samples, loss, round(end-start,3)))
start = end

else:
print('No training files provided')