Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite training and inference code to more modern pytorch, add some functionalities and models #20

Merged
merged 42 commits into from
Jul 18, 2021
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
8d52b51
Add simple dataset object for word-level files
annaproxy Jul 12, 2021
6116957
Add dutch dictionary for pre-training
annaproxy Jul 12, 2021
1c3900f
Add non-manual RNN using pytorch nn.RNN
annaproxy Jul 12, 2021
8120faf
Add vocabulary file for dutch dictionary (handy as a standard for fut…
annaproxy Jul 12, 2021
811b0c9
Rewrite generate_word to deal with BOS items, non one-hot and other s…
annaproxy Jul 12, 2021
91a38a3
A simple generating notebook to deal with the previous pretrained mod…
annaproxy Jul 12, 2021
fca5d65
Minor improvements in word generator. (But there may still be a bug)
annaproxy Jul 14, 2021
e6cac51
Add new training loop using pytorch dataloader and CE
annaproxy Jul 14, 2021
50296c8
Add pretrained NL model for who wants it
annaproxy Jul 14, 2021
6db972f
Lint data tools
annaproxy Jul 14, 2021
0bf66b4
Replace trainable embedding with one-hot for now. Add dropout.
annaproxy Jul 14, 2021
603623d
Add two notebooks, one for training, one for generating.
annaproxy Jul 14, 2021
8cdec33
Merge branch 'master' into anna-clean-notebook
annaproxy Jul 14, 2021
7823e86
Add functionality to load Anna's Model
annaproxy Jul 14, 2021
90d97eb
Add Anna's failed model
annaproxy Jul 14, 2021
9c69d7a
Only choose random letters. Return neatly formatted word no EOS tag
annaproxy Jul 14, 2021
e94f43b
Remove debug print
annaproxy Jul 14, 2021
0586cc7
app can now use Anna's model
annaproxy Jul 14, 2021
42aaa99
Put back supervision at every step
annaproxy Jul 14, 2021
812b789
Add anna pretrained models
annaproxy Jul 14, 2021
73e7574
Delete old models
annaproxy Jul 14, 2021
d22c0fc
Add generation notebook with less converged model
annaproxy Jul 14, 2021
f221ef5
Add training notebook with currently Dutch model generations (fun words)
annaproxy Jul 14, 2021
75e7ebf
Bookkeeping / cleaning in several files
annaproxy Jul 14, 2021
d5b12b8
Update app/ml_models/rnn/data_tools.py
annaproxy Jul 16, 2021
bb1829b
Change to absolute import statements.
Sasafrass Jul 16, 2021
ceaa907
Merge branch 'anna-clean-notebook' of github.com:Sasafrass/straattaal…
Sasafrass Jul 16, 2021
63a9769
Update app/ml_models/rnn/data_tools.py with docstring
annaproxy Jul 16, 2021
e21f95b
Add docstring for convert_to_string
annaproxy Jul 16, 2021
379ede3
Improve docstring of load_model
annaproxy Jul 16, 2021
33e4120
Albert docstring for RNNANNA
annaproxy Jul 16, 2021
968dffe
Add docstring for next_char
annaproxy Jul 16, 2021
c70bd99
Lint data_tools, uncomment <BOS> feeding
annaproxy Jul 16, 2021
8fb5ddf
Fix Albert docstring. Remove "hi" example
annaproxy Jul 16, 2021
524e716
Improve names and docstrings in rnn and train loop
annaproxy Jul 16, 2021
e09ed39
Docstring for train loop
annaproxy Jul 17, 2021
d1aec68
formatted with black
rgrouls Jul 17, 2021
d5ffcdb
black formatting
rgrouls Jul 17, 2021
e7af65a
added r prefix before escaped string
rgrouls Jul 17, 2021
40f2d61
Merge pull request #33 from Sasafrass/anna-clean-notebook-review
annaproxy Jul 17, 2021
a4710b0
Improve docstring for convert_to_string, but this won't matter once v…
annaproxy Jul 18, 2021
3c0b9cb
Temp hotfix, rename "rnn" back to "lstm" for legacy model loading
annaproxy Jul 18, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions app/api/slang.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,27 @@
from flask import jsonify
from app.api import bp

from app.ml_models.rnn.loaded_rnn_model import return_loaded_model
from app.ml_models.rnn.loaded_rnn_model import return_loaded_model, load_model
from app.ml_models.rnn.helpers import random_choice
from app.ml_models.rnn.generate import generate_word


@bp.route("/generate_slang", methods=["GET"])
def generate_slang():
"""Generate and return a new slang word."""
model, ALL_LETTERS = return_loaded_model()
N_LETTERS = len(ALL_LETTERS) + 1

# TODO Should not load model every time a word is queried
# I know nothing of flask, can we save the model upon starting the app?
Sasafrass marked this conversation as resolved.
Show resolved Hide resolved

model, dataset = load_model()

# TODO: Should check if the word is just memorized
new_word = generate_word(
model=model,
N_LETTERS=N_LETTERS,
ALL_LETTERS=ALL_LETTERS,
start_letter=random_choice(ALL_LETTERS),
maxn=20, # TODO: Fix this.
temp=0.3,
dataset=dataset,
start_letter='random',
max_len=20, # TODO: Fix this. (?)
temperature=0.3,
)

# TODO: Return a json containing the word.
Expand Down
Binary file added app/ml_models/rnn/2021_straattaal_epoch100.pt
Binary file not shown.
Binary file added app/ml_models/rnn/2021_straattaal_epoch200.pt
Binary file not shown.
50 changes: 50 additions & 0 deletions app/ml_models/rnn/data_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import torch
import os
from torch.utils.data import Dataset, DataLoader
from collections import Counter


class WordLevelDataset(Dataset):
def __init__(self,
prefix: str = '../../../data/',
annaproxy marked this conversation as resolved.
Show resolved Hide resolved
filename_dataset: str = 'straattaal.txt',
filename_vocab: str = 'vocabulary.txt'):
filename_dataset = os.path.join(prefix, filename_dataset)
filename_vocab = os.path.join(prefix, filename_vocab)

with open(filename_dataset, 'r', encoding='utf-8') as f:
lines = f.read().strip().lower()
self.words = [s.strip().replace('\t', '')
for s in lines.split("\n")]
with open(filename_vocab, 'r', encoding='utf-8') as f:
self.vocabulary = list(f.read())
self.vocabulary += ['<BOS>', '<EOS>']
self.vocabulary_size = len(self.vocabulary)
self.char_to_idx_dict = {ch: i for i, ch in enumerate(self.vocabulary)}
self.idx_to_char_dict = {i: ch for i, ch in enumerate(self.vocabulary)}

def __len__(self):
return len(self.words)

def __getitem__(self, i):
annaproxy marked this conversation as resolved.
Show resolved Hide resolved
s1 = [
self.char_to_idx_dict[z]
for z in ["<BOS>"] + list(self.words[i])
]
s2 = [
self.char_to_idx_dict[z]
for z in list(self.words[i]) + ["<EOS>"]
]
return torch.LongTensor(s1), torch.LongTensor(s2)

def convert_to_string(self, char_ix):
annaproxy marked this conversation as resolved.
Show resolved Hide resolved
annaproxy marked this conversation as resolved.
Show resolved Hide resolved
result = "".join(self.idx_to_char_dict[ix] for ix in char_ix)
return result


if __name__ == "__main__":
hi = WordLevelDataset('../../../data/', 'dutch.txt')
hi_loader = DataLoader(hi, 1)
for z in hi_loader:
print(z)
break
annaproxy marked this conversation as resolved.
Show resolved Hide resolved
98 changes: 59 additions & 39 deletions app/ml_models/rnn/generate.py
Original file line number Diff line number Diff line change
@@ -1,54 +1,74 @@
import torch
from app.ml_models.rnn.helpers import get_input_tensor
from app.ml_models.rnn.rnn_model import RNN
from random import choice as choose


def generate_word(
model: RNN,
N_LETTERS: int,
ALL_LETTERS: set,
start_letter: str = "a",
maxn: int = 20,
temp: float = 0,
) -> str:
"""Generate a new Slang word.
def next_char(out, temperature):
annaproxy marked this conversation as resolved.
Show resolved Hide resolved
# Softmax of the last dimension
if torch.distributions.Uniform(0, 1).sample() < temperature:
probs = torch.softmax((out), -1)
#probs = torch.softmax(temperature*(out), -1) # This is good for randomness (temperature < 1)
choice = torch.multinomial(probs.squeeze(0), 1)
else:
choice = torch.argmax(out, dim=2)
return choice


def generate_word(model, dataset, start_letter=None, max_len=20, temperature=0.25, device='cpu'):
"""Generate a new word.

Args:
model: Pre-trained Recurrent Neural Network model.
N_LETTERS: Number of unique letters found in the training corpus.
ALL_LETTERS: All letters found in the training data.
dataset: WordLevelDataset object
start_letter: Letter to start the word with.
maxn: Maximum number of letters to be used.
max_len: Maximum number of letters to be used.
temp: Temperature used for sampling.
device: torch device string
"""
# Evaluation mode
model.eval()
# no gradient
with torch.no_grad():
input = get_input_tensor(
start_letter,
N_LETTERS=N_LETTERS,
ALL_LETTERS=ALL_LETTERS,
)
hidden = model.initHidden()
output_name = start_letter
for i in range(maxn):
output, hidden = model(input[0], hidden)
if torch.distributions.Uniform(0, 1).sample() < temp:
probs = torch.softmax(output, 1)
dist = torch.distributions.Categorical(probs)
pick = dist.sample()
else:
topv, topi = output.topk(1)
pick = topi[0][0]
if pick == N_LETTERS - 1:
break
else:
letter = ALL_LETTERS[pick]
output_name += letter
input = get_input_tensor(
letter,
N_LETTERS=N_LETTERS,
ALL_LETTERS=ALL_LETTERS)

return output_name
# Hidden stuff initialized to None (pytorch makes this zeros automatically)
h = None
it = 0

# Always generate the Beginning of Word token first and feed it to the RNN
# TODO: Maybe don't ? It results in a lot of copying behaviour for small datasets
# idxs = torch.Tensor([dataset.char_to_idx_dict["<BOS>"]]
# ).long().unsqueeze(0).to(device)
# out, h = model(idxs, h)

choice = torch.Tensor([-99])
annaproxy marked this conversation as resolved.
Show resolved Hide resolved

# Generate a random choice from the vocabulary and put it in the to-be-fed IDXs
if start_letter == 'random':
letters_idx = torch.Tensor(
[dataset.char_to_idx_dict[choose("abcdefghijklmnopqrstuvwxyz")]]
).long().unsqueeze(0).to(device)

# Generate a random choice from the input
elif start_letter is not None:
letters_idx = torch.Tensor(
[dataset.char_to_idx_dict[choose(start_letter)]]
).long().unsqueeze(0).to(device)

# Let the RNN decide for this first round.
else:
choice = next_char(out, temperature)
letters_idx = choice.to(device)

# Check if the token is an EOS token.
while choice.item() != dataset.char_to_idx_dict["<EOS>"] and it < max_len:
annaproxy marked this conversation as resolved.
Show resolved Hide resolved
# Pass the latest character to the model, store new hidden stuff.
out, h = model(letters_idx[it:], h)
choice = next_char(out, temperature)
letters_idx = torch.cat((letters_idx, choice), 0)
it += 1

output_string = letters_idx.squeeze(1).tolist()
return dataset.convert_to_string(output_string).split('<EOS>')[0]


# # TODO: Move this piece of code to generate.py?
Expand Down
21 changes: 20 additions & 1 deletion app/ml_models/rnn/loaded_rnn_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,25 @@
import os
import torch
from app.ml_models.rnn.rnn_model import RNN
from app.ml_models.rnn.rnn_model import RNN, RNNAnna
from app.ml_models.rnn.data_tools import WordLevelDataset


def load_model(model_name: str = '2021_straattaal_epoch100.pt', device: str = 'cpu'):
"""
Args
model_name: Filename of the model
device: CUDA device name to map to, probably cpu
annaproxy marked this conversation as resolved.
Show resolved Hide resolved
"""
path = os.path.join(os.path.abspath(os.getcwd()),
"app", "ml_models", "rnn")
path = os.path.join(path, model_name)
dataset = WordLevelDataset('data/', 'straattaal.txt')
# TODO: Fix hardcoded hidden size
m = RNNAnna(dataset.vocabulary_size, 128)
annaproxy marked this conversation as resolved.
Show resolved Hide resolved
checkpoint = torch.load(path, map_location=torch.device(device))
m.load_state_dict(checkpoint['model_state_dict'])
m.eval()
return m, dataset
annaproxy marked this conversation as resolved.
Show resolved Hide resolved


def return_loaded_model():
Expand Down
Binary file added app/ml_models/rnn/pretrained_dutch_epoch3.pt
Binary file not shown.
23 changes: 23 additions & 0 deletions app/ml_models/rnn/rnn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,26 @@ def forward(self, input, hidden):

def initHidden(self):
return torch.zeros(1, self.hidden_size)


class RNNAnna(nn.Module):
annaproxy marked this conversation as resolved.
Show resolved Hide resolved
def __init__(
self,
vocab_size,
hidden_size,
train_embeddings=False,
annaproxy marked this conversation as resolved.
Show resolved Hide resolved
):
annaproxy marked this conversation as resolved.
Show resolved Hide resolved
super(RNNAnna, self).__init__()
self._embedding = nn.Embedding(vocab_size, vocab_size)
self._embedding.weight.data = torch.eye(vocab_size)
annaproxy marked this conversation as resolved.
Show resolved Hide resolved
self._embedding.weight.requires_grad = train_embeddings

self.lstm = nn.RNN(vocab_size, hidden_size, 1, batch_first=False)
annaproxy marked this conversation as resolved.
Show resolved Hide resolved
self.dropout = nn.Dropout(0.1)
self.final = nn.Linear(hidden_size, vocab_size)
self.hidden_size = hidden_size

def forward(self, x, hidden=None):
x = self._embedding(x)
out, hidden = self.lstm(x, hidden)
return self.final(self.dropout(out)), hidden
82 changes: 82 additions & 0 deletions app/ml_models/rnn/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import torch.nn as nn
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from app.ml_models.rnn.data_tools import WordLevelDataset
from app.ml_models.rnn.rnn_model import RNNAnna
from app.ml_models.rnn.generate import generate_word


def train(rnn,
dataloader,
dataset,
learning_rate=0.0005,
epochs=500,
device='cpu',
name='straattaal',
save_every=50,
print_every=10000):

annaproxy marked this conversation as resolved.
Show resolved Hide resolved
# With CrossEntropyLoss we don't need (manual) one-hot
criterion = nn.CrossEntropyLoss()

# Use SGD optimizer so we don't need manual param updates.
optimizer = torch.optim.SGD(
rnn.parameters(), lr=learning_rate, momentum=0.9)
for epoch in range(epochs):
total_loss = 0
rnn.train()
for i, (input_line_tensor, target_line_tensor) in tqdm(enumerate(dataloader), total=len(dataloader)):
annaproxy marked this conversation as resolved.
Show resolved Hide resolved
optimizer.zero_grad()
input_line_tensor = input_line_tensor.to(device)
target_line_tensor = target_line_tensor.to(device)

# Run model ye olde way
#output, _ = rnn(input_line_tensor)
#loss = criterion(output.permute(1, 2, 0), target_line_tensor.permute(1,0))

# Run model ye new way
loss = 0
hidden = None
for Z in range(input_line_tensor.size(1)):
annaproxy marked this conversation as resolved.
Show resolved Hide resolved
# TODO unsqueeze is necessary for batch size 1
# Make this generic for larger batch size (it will also be faster on bigger dataset)
output, hidden = rnn(
input_line_tensor[:, Z].unsqueeze(1), hidden)
l = criterion(output.permute(
1, 2, 0), target_line_tensor[:, Z].unsqueeze(1).permute(1, 0))
loss += l

loss.backward()
optimizer.step()
total_loss += loss.item()
annaproxy marked this conversation as resolved.
Show resolved Hide resolved
if (i+1) % print_every == 0:
print('Loss', total_loss / i)
for _ in range(10):
print('\t', generate_word(
rnn, dataset, start_letter="afhklmnopqrstu", temperature=0.3, device=device))
rnn.train()

# TODO plot loss... maybe.... store it somewhere.... im too lazy
if epoch % save_every == 0:
print('Loss', total_loss / i)
for _ in range(10):
print('\t', generate_word(
rnn, dataset, start_letter="abcdefghijklmnoprstuvwz", temperature=0.3, device=device))

# TODO Save this to some generic spot, not just aat cwd...
torch.save({
'epoch': epoch,
'model_state_dict': rnn.state_dict(),
'optimizer_state_dict': optimizer.state_dict()
}, f"{name}_statedict_{epoch}.pt")

annaproxy marked this conversation as resolved.
Show resolved Hide resolved

if __name__ == "__main__":
hi = WordLevelDataset('../../../data/', 'straattaal.txt')

# Currently only batch size 1 works
hi_loader = DataLoader(hi, 1, shuffle=True)
rnn = RNNAnna(hi.vocabulary_size, 64, 128)
train(rnn, hi_loader, hi)
Loading