Experiments with new models #30

Open. Wants to merge 14 commits into base: master.
11 changes: 6 additions & 5 deletions src/config/lstm_baseline.yaml
@@ -3,16 +3,17 @@ model_module_name: 'models.lstm_baseline'
model_class_name: 'LSTMBaseline'

n_train: !!int 60000
n_decay: !!int 10000
print_every_n: !!int 1000
val_every_n: !!float 5000
print_every_n: !!int 100
val_every_n: !!float 1000
n_val: !!int 600
n_test: !!int 600
n_samples: !!int 20

lr: !!float 5e-3
lr: !!float 1e-3
patience_iters: !!int 7
max_grad_norm: !!int 5
batch_size: !!int 5
embedding_size: !!int 250
embedding_size: !!int 200
n_layers: !!int 1
hidden_size: !!int 200
use_sentinel: !!bool false
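A side note on the !!int/!!float tags used throughout these configs: PyYAML follows YAML 1.1, where an exponent-form scalar without a decimal point (e.g. 1e-3) loads as a string, so the explicit tags are what keep these values numeric. A minimal loading sketch, assuming PyYAML:

```python
import yaml

with open('src/config/lstm_baseline.yaml') as f:
    config = yaml.safe_load(f)

# Without the !!float tag, YAML 1.1 would read '1e-3' as the string
# '1e-3' (its float grammar requires a decimal point).
assert isinstance(config['lr'], float)
assert isinstance(config['n_train'], int)
```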
23 changes: 23 additions & 0 deletions src/config/lstm_cont_cache.yaml
@@ -0,0 +1,23 @@
name: 'lstm_baseline'
model_module_name: 'models.lstm_cont_cache'
model_class_name: 'LSTMContCache'

n_train: !!int 60000
n_decay: !!int 10000
print_every_n: !!int 1000
val_every_n: !!float 1000
n_val: !!int 600
n_test: !!int 600
n_samples: !!int 20

lr: !!float 1e-3
patience_iters: !!int 5
max_grad_norm: !!int 5
batch_size: !!int 5
embedding_size: !!int 250
n_layers: !!int 1
hidden_size: !!int 200
subseq_len: !!int 10

theta: !!float 0.2
lambda: !!float 0.05
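For context, 'theta' and 'lambda' here match the two hyperparameters of a continuous cache language model (Grave et al., 2017): theta scales the attention over cached hidden states and lambda interpolates the cache distribution with the LSTM softmax. A hedged sketch of that mixture, assuming PyTorch; the function and argument names are illustrative, not taken from this PR:

```python
import torch
import torch.nn.functional as F

def cache_mix(logits, h_t, cache_h, cache_words, vocab_size,
              theta=0.2, lam=0.05):
    # The LSTM's usual next-token distribution.
    p_vocab = F.softmax(logits, dim=-1)
    # Attention over the N cached hidden states; theta controls flatness.
    weights = F.softmax(theta * cache_h @ h_t, dim=0)
    # Scatter attention mass onto the token that followed each state.
    p_cache = torch.zeros(vocab_size).index_add_(0, cache_words, weights)
    # lambda interpolates model and cache.
    return (1 - lam) * p_vocab + lam * p_cache
```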
20 changes: 20 additions & 0 deletions src/config/lstm_dynamic_eval.yaml
@@ -0,0 +1,20 @@
name: 'lstm_baseline'
model_module_name: 'models.lstm_dynamic_eval'
model_class_name: 'LSTMDynamicEval'

n_train: !!int 60000
n_decay: !!int 10000
print_every_n: !!int 1000
val_every_n: !!float 1000
n_val: !!int 600
n_test: !!int 600
n_samples: !!int 20

lr: !!float 5e-1
patience_iters: !!int 5
max_grad_norm: !!int 5
batch_size: !!int 5
embedding_size: !!int 250
n_layers: !!int 1
hidden_size: !!int 200
subseq_len: !!int 10
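The unusually high learning rate (5e-1) together with subseq_len fits dynamic evaluation (Krause et al., 2018), where the model keeps taking SGD steps on the test stream as it is scored, so later subsequences are predicted with adapted weights. A sketch under that assumption; names are illustrative:

```python
import torch

def dynamic_eval(model, loss_fn, tokens, subseq_len=10, lr=5e-1):
    opt = torch.optim.SGD(model.parameters(), lr=lr)
    total_loss, n_tokens = 0.0, 0
    for start in range(0, tokens.size(0) - 1, subseq_len):
        y = tokens[start + 1:start + 1 + subseq_len]
        x = tokens[start:start + y.size(0)]
        loss = loss_fn(model(x), y)    # score the subsequence first ...
        total_loss += loss.item() * y.size(0)
        n_tokens += y.size(0)
        opt.zero_grad()
        loss.backward()                # ... then adapt on what was scored
        opt.step()
    return total_loss / n_tokens
```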
25 changes: 25 additions & 0 deletions src/config/lstm_enc_dec.yaml
@@ -0,0 +1,25 @@
name: 'lstm_seq2seq'
model_module_name: 'models.lstm_enc_dec'
model_class_name: 'LSTMEncDec'

n_train: !!int 60000
print_every_n: !!int 100
val_every_n: !!float 1000
n_val: !!int 600
n_test: !!int 600
n_samples: !!int 20

lr: !!float 1e-3
patience_iters: !!int 7
max_grad_norm: !!int 5
batch_size: !!int 5
embedding_size: !!int 200
n_layers: !!int 1
hidden_size: !!int 200

stop_grad: !!bool true

enc_size: !!int 32
use_sentinel: !!bool false
use_film: !!bool false
decode_support: !!bool false
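use_film with enc_size: 32 suggests FiLM-style conditioning (Perez et al., 2018), in which the support-set encoding is projected to a per-feature scale and shift applied to the decoder's hidden state. A sketch of such a layer, assuming PyTorch; the class name is illustrative:

```python
import torch.nn as nn

class FiLMCondition(nn.Module):
    """Map a support-set encoding to a feature-wise affine transform
    (gamma, beta) and apply it to a decoder hidden state."""
    def __init__(self, enc_size=32, hidden_size=200):
        super().__init__()
        self.proj = nn.Linear(enc_size, 2 * hidden_size)

    def forward(self, hidden, enc):
        gamma, beta = self.proj(enc).chunk(2, dim=-1)
        return gamma * hidden + beta
```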
22 changes: 22 additions & 0 deletions src/config/lstm_maml.yaml
@@ -0,0 +1,22 @@
name: 'lstm_fomaml'
model_module_name: 'models.lstm_maml'
model_class_name: 'LSTMMAML'

n_train: !!int 60000
n_decay: !!int 10000
print_every_n: !!int 1000
val_every_n: !!float 1000
n_val: !!int 600
n_test: !!int 600
n_samples: !!int 20

lr: !!float 1e-1
meta_lr: !!float 1e-3
patience_iters: !!int 5
n_update: !!int 3
max_grad_norm: !!int 5
batch_size: !!int 5
embedding_size: !!int 250
n_layers: !!int 1
hidden_size: !!int 200
stop_grad: !!bool true
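The name 'lstm_fomaml' plus stop_grad: true points at first-order MAML: lr would be the inner-loop rate, meta_lr the outer-loop rate, and n_update the number of inner steps per task. A hedged sketch of one meta-update, assuming PyTorch; the helper and its signature are illustrative:

```python
import copy
import torch

def fomaml_step(model, meta_opt, loss_fn, support, query,
                inner_lr=1e-1, n_update=3):
    (x_s, y_s), (x_q, y_q) = support, query
    # Inner loop: adapt a clone on the support set.
    fast = copy.deepcopy(model)
    inner_opt = torch.optim.SGD(fast.parameters(), lr=inner_lr)
    for _ in range(n_update):
        inner_opt.zero_grad()
        loss_fn(fast(x_s), y_s).backward()
        inner_opt.step()
    # Outer loop: first-order MAML (stop_grad) applies the adapted
    # model's query-set gradients directly to the original weights.
    fast.zero_grad()
    loss_fn(fast(x_q), y_q).backward()
    meta_opt.zero_grad()
    for p, fp in zip(model.parameters(), fast.parameters()):
        p.grad = fp.grad.clone()
    meta_opt.step()
```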
4 changes: 4 additions & 0 deletions src/config/lyrics.yaml
@@ -1,4 +1,8 @@
dataset: 'lyrics'
dataset_path: '../raw-data/lyrics/lyrics_data/'
splits: ['train', 'val', 'test']
word_min_times: 5
max_unk_percent: 0.05
min_len: 10
max_len: 50
eval_len: 50
2 changes: 1 addition & 1 deletion src/config/midi.yaml
@@ -1,4 +1,4 @@
dataset: 'midi'
dataset_path: '../raw-data/freemidi/freemidi_data/'
splits: ['train', 'val', 'test']
max_len: 50
max_len: 50 # [50, 250, 500]
51 changes: 40 additions & 11 deletions src/data/base_loader.py
@@ -12,8 +12,11 @@
class Loader(object):
"""A class for turning data into a sequence of tokens.
"""
def __init__(self, max_len, dtype=np.int32, persist=True):
def __init__(self, min_len, max_len, max_unk_percent,
dtype=np.int32, persist=True):
self.min_len = min_len
self.max_len = max_len
self.max_unk_percent = max_unk_percent
self.dtype = dtype
self.persist = persist

@@ -34,8 +37,12 @@ def get_num_tokens(self):

def validate(self, filepath):
try:
self.load(filepath)
return True
# Must have at least one valid stanza for the whole song to be valid
np_tokens, _ = self.load(filepath)
return np.shape(np_tokens)[0] > 0
except OSError:
return False
except KeyError:
@@ -50,15 +57,37 @@ def validate(self, filepath):
return False

def load(self, filepath):
npfile = '%s.%s.npy' % (filepath, self.max_len)
npfile = '%s.%s.npz' % (filepath, self.max_len)
if self.persist and os.path.isfile(npfile):
return np.load(npfile).astype(self.dtype)
data = np.load(npfile)
numpy_tokens = data['tokens'].astype(self.dtype)
numpy_seq_lens = data['seq_lens'].astype(self.dtype)
else:
data = self.read(filepath)
tokens = self.tokenize(data)
numpy_tokens = np.zeros(self.max_len, dtype=self.dtype)
for token_index in range(min(self.max_len, len(tokens))):
numpy_tokens[token_index] = tokens[token_index]
all_tokens = self.tokenize(data)
# Keep stanzas with min_len <= length <= max_len and an
# unknown-token rate below max_unk_percent.
all_tokens = [
    t for t in all_tokens
    if self.min_len <= len(t) <= self.max_len
    and self.get_unk_percent(t) < self.max_unk_percent]

n_stanzas = len(all_tokens)
numpy_tokens = np.zeros(
    (n_stanzas, self.max_len), dtype=self.dtype)
numpy_seq_lens = np.array([len(t) for t in all_tokens])
for i, stanza_tokens in enumerate(all_tokens):
    numpy_tokens[i, :len(stanza_tokens)] = stanza_tokens

if self.persist:
np.save(npfile, numpy_tokens)
return numpy_tokens
np.savez(npfile, tokens=numpy_tokens, seq_lens=numpy_seq_lens)

return numpy_tokens, numpy_seq_lens
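get_unk_percent is called above but not defined in this diff. One plausible implementation, assuming the loader knows its unknown-token id (the unk_id attribute here is hypothetical):

```python
def get_unk_percent(self, tokens):
    # Fraction of a stanza's tokens that are unknown; load() drops
    # stanzas at or above max_unk_percent (0.05 in lyrics.yaml).
    if len(tokens) == 0:
        return 1.0
    return sum(1 for t in tokens if t == self.unk_id) / float(len(tokens))
```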
97 changes: 46 additions & 51 deletions src/data/dataset.py
@@ -7,21 +7,20 @@
import multiprocessing
import itertools
try:
from urllib import quote, unquote # python 2
from urllib import quote, unquote # python 2
except ImportError:
from urllib.parse import quote, unquote # python 3

from urllib.parse import quote, unquote # python 3
import numpy as np

from data.lib import get_random

log = logging.getLogger('few-shot')
logging.basicConfig(level=logging.INFO)


class Metadata(object):
"""An object for tracking the metadata associated with a configuration of
the sampler.
"""
"""Track the metadata associated with a configuration of the sampler."""

def __init__(self, root, name):
self.dir = os.path.join(root, name)
self.open_files = {}
@@ -38,7 +37,8 @@ def lines(self, filename):

def write(self, filename, line):
if filename not in self.open_files:
self.open_files[filename] = open(os.path.join(self.dir, filename), 'a')
self.open_files[filename] = open(
os.path.join(self.dir, filename), 'a')
self.open_files[filename].write(line)

def close(self):
@@ -82,25 +82,28 @@ class Dataset(object):
seed (int or None): the random seed which is used for shuffling the
artists.
"""
def __init__(self, root, split, loader, metadata, split_proportions=(8,1,1),
persist=True, cache=True, validate=True, min_songs=0, parallel=False,
valid_songs_file='valid_songs.csv', seed=None):

def __init__(self, root, loader, metadata, artists_in_split, seed=None,
persist=True, cache=True, validate=True, min_songs=0,
artists_file='artists.csv',
valid_songs_file='valid_songs.csv'):
self.root = root
self.cache = cache
self.cache_data = {}
self.loader = loader
self.metadata = metadata
self.random = get_random(seed)
self.artists = []
self.valid_songs_file = valid_songs_file
valid_songs = {}
artist_in_split = []
valid_artists_in_split = []

# If we're both validating and using persistence, load any validation
# data from disk. The format of the validation file is just a CSV
# with two entries: artist and song. The artist is the directory
# name (e.g. 'K_s Choice') and the song is the song file name
# (e.g. 'ironflowers.mid').
if validate and persist:
if validate and persist and self.metadata.exists(valid_songs_file):
for line in self.metadata.lines(valid_songs_file):
artist, song = line.rstrip('\n').split(',', 1)
artist = unquote(artist)
@@ -109,79 +112,61 @@ def __init__(self, root, split, loader, metadata, split_proportions=(8,1,1),
valid_songs[artist] = set()
valid_songs[artist].add(song)

if persist and self.metadata.exists('%s.csv' % split):
artists_in_split = []
for line in self.metadata.lines('%s.csv' % split):
artists_in_split.append(line.rstrip('\n'))
valid_artists_in_split = artists_in_split
else:
dirs = []
all_artists = []
skipped_count = 0
pool = multiprocessing.Pool(multiprocessing.cpu_count())

for artist in os.listdir(root):
for artist in artists_in_split:
if os.path.isdir(os.path.join(root, artist)):
songs = os.listdir(os.path.join(root, artist))
songs = [song for song in songs if loader.is_song(song)]
songs = [s for s in songs if loader.is_song(s)]
if len(songs) > 0:
dirs.append(artist)

num_dirs = len(dirs)
progress_logger = ProgressLogger(num_dirs)

for artist_index, artist in enumerate(dirs):
# log.info("Processing artist %s" % artist)
songs = os.listdir(os.path.join(root, artist))
# We only want .txt and .mid files. Filter all others.
songs = [song for song in songs if loader.is_song(song)]
songs = [s for s in songs if loader.is_song(s)]
# populate `valid_songs[artist]`
if validate:
progress_logger.maybe_log(artist_index)
if artist not in valid_songs:
valid_songs[artist] = set()
songs_to_validate = [song for song in songs if song not in valid_songs[artist]]
song_files = [os.path.join(root, artist, song) for song in songs_to_validate]
if parallel:
mapped = pool.map(loader.validate, song_files)
else:
mapped = map(loader.validate, song_files)
songs_to_validate = [s for s in songs if s not in valid_songs[artist]]
song_files = [os.path.join(root, artist, s) for s in songs_to_validate]
mapped = map(loader.validate, song_files)
validated = itertools.compress(songs_to_validate, mapped)
for song in validated:
song_file = os.path.join(root, artist, song)
# song_file = os.path.join(root, artist, song)
if persist:
line = '%s,%s\n' % (quote(artist), quote(song))
self.metadata.write(self.valid_songs_file, line)
# log.info("Validated song %s" % song)
valid_songs[artist].add(song)
else:
valid_songs[artist] = set(songs)

if len(valid_songs[artist]) >= min_songs:
all_artists.append(artist)
valid_artists_in_split.append(artist)
else:
skipped_count += 1
pool.close()
pool.join()

if skipped_count > 0:
log.info("%s artists don't have K+K'=%s songs. Using %s artists" % (
skipped_count, min_songs, len(all_artists)))
train_count = int(float(split_proportions[0]) / sum(split_proportions) * len(all_artists))
val_count = int(float(split_proportions[1]) / sum(split_proportions) * len(all_artists))
# Use RandomState(seed) so that shuffles with the same set of
# artists will result in the same shuffle on different computers.
np.random.RandomState(seed).shuffle(all_artists)
skipped_count, min_songs, len(valid_artists_in_split)))

if persist:
self.metadata.write('train.csv', '\n'.join(all_artists[:train_count]))
self.metadata.write('val.csv', '\n'.join(all_artists[train_count:train_count+val_count]))
self.metadata.write('test.csv', '\n'.join(all_artists[train_count+val_count:]))
if split == 'train':
artists_in_split = all_artists[:train_count]
elif split == 'val':
artists_in_split = all_artists[train_count:train_count+val_count]
else:
artists_in_split = all_artists[train_count+val_count:]
metadata.write(
artists_file, '\n'.join(valid_artists_in_split))

self.metadata.close()

for artist in artists_in_split:
for artist in valid_artists_in_split:
self.artists.append(ArtistDataset(artist, list(valid_songs[artist])))

def load(self, artist, song):
@@ -192,11 +177,21 @@ def load(self, artist, song):
artist (str): the name of the artist directory. e.g. `"tool"`
"""
if self.cache and (artist, song) in self.cache_data:
return self.cache_data[(artist, song)]
tokens, seq_lens = self.cache_data[(artist, song)]
else:
tokens, seq_lens = self.loader.load(
os.path.join(self.root, artist, song))
self.cache_data[(artist, song)] = (tokens, seq_lens)

if tokens.ndim == 1:
data = (tokens, seq_lens)
elif tokens.ndim == 2:
idx = self.random.choice(range(np.shape(tokens)[0]), size=1)[0]
data = (tokens[idx], seq_lens[idx])
else:
data = self.loader.load(os.path.join(self.root, artist, song))
self.cache_data[(artist, song)] = data
return data
raise RuntimeError(
"Token matrix has dimension > 2: %d" % tokens.ndim)
return data

def __len__(self):
return len(self.artists)
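With this change, Dataset no longer derives the train/val/test split itself; the caller passes the split's artist list in. A hedged usage sketch; the CSV name and the seed value are illustrative:

```python
# Read a previously persisted split, then build a Dataset over it.
train_artists = [line.rstrip('\n') for line in metadata.lines('train.csv')]
train_data = Dataset('../raw-data/lyrics/lyrics_data/', loader, metadata,
                     artists_in_split=train_artists, seed=42)
```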