diff --git a/src/config/lstm_baseline.yaml b/src/config/lstm_baseline.yaml index e6dfee8..354636a 100644 --- a/src/config/lstm_baseline.yaml +++ b/src/config/lstm_baseline.yaml @@ -3,16 +3,17 @@ model_module_name: 'models.lstm_baseline' model_class_name: 'LSTMBaseline' n_train: !!int 60000 -n_decay: !!int 10000 -print_every_n: !!int 1000 -val_every_n: !!float 5000 +print_every_n: !!int 100 +val_every_n: !!float 1000 n_val: !!int 600 n_test: !!int 600 n_samples: !!int 20 -lr: !!float 5e-3 +lr: !!float 1e-3 +patience_iters: !!int 7 max_grad_norm: !!int 5 batch_size: !!int 5 -embedding_size: !!int 250 +embedding_size: !!int 200 n_layers: !!int 1 hidden_size: !!int 200 +use_sentinel: !!bool false diff --git a/src/config/lstm_cont_cache.yaml b/src/config/lstm_cont_cache.yaml new file mode 100644 index 0000000..7d7f817 --- /dev/null +++ b/src/config/lstm_cont_cache.yaml @@ -0,0 +1,23 @@ +name: 'lstm_baseline' +model_module_name: 'models.lstm_cont_cache' +model_class_name: 'LSTMContCache' + +n_train: !!int 60000 +n_decay: !!int 10000 +print_every_n: !!int 1000 +val_every_n: !!float 1000 +n_val: !!int 600 +n_test: !!int 600 +n_samples: !!int 20 + +lr: !!float 1e-3 +patience_iters: !!int 5 +max_grad_norm: !!int 5 +batch_size: !!int 5 +embedding_size: !!int 250 +n_layers: !!int 1 +hidden_size: !!int 200 +subseq_len: !!int 10 + +theta: !!float 0.2 +lambda: !!float 0.05 diff --git a/src/config/lstm_dynamic_eval.yaml b/src/config/lstm_dynamic_eval.yaml new file mode 100644 index 0000000..a01fc67 --- /dev/null +++ b/src/config/lstm_dynamic_eval.yaml @@ -0,0 +1,20 @@ +name: 'lstm_baseline' +model_module_name: 'models.lstm_dynamic_eval' +model_class_name: 'LSTMDynamicEval' + +n_train: !!int 60000 +n_decay: !!int 10000 +print_every_n: !!int 1000 +val_every_n: !!float 1000 +n_val: !!int 600 +n_test: !!int 600 +n_samples: !!int 20 + +lr: !!float 5e-1 +patience_iters: !!int 5 +max_grad_norm: !!int 5 +batch_size: !!int 5 +embedding_size: !!int 250 +n_layers: !!int 1 +hidden_size: !!int 200 +subseq_len: !!int 10 diff --git a/src/config/lstm_enc_dec.yaml b/src/config/lstm_enc_dec.yaml new file mode 100644 index 0000000..524130c --- /dev/null +++ b/src/config/lstm_enc_dec.yaml @@ -0,0 +1,25 @@ +name: 'lstm_seq2seq' +model_module_name: 'models.lstm_enc_dec' +model_class_name: 'LSTMEncDec' + +n_train: !!int 60000 +print_every_n: !!int 100 +val_every_n: !!float 1000 +n_val: !!int 600 +n_test: !!int 600 +n_samples: !!int 20 + +lr: !!float 1e-3 +patience_iters: !!int 7 +max_grad_norm: !!int 5 +batch_size: !!int 5 +embedding_size: !!int 200 +n_layers: !!int 1 +hidden_size: !!int 200 + +stop_grad: !!bool true + +enc_size: !!int 32 +use_sentinel: !!bool false +use_film: !!bool false +decode_support: !!bool false diff --git a/src/config/lstm_maml.yaml b/src/config/lstm_maml.yaml new file mode 100644 index 0000000..d4dcf67 --- /dev/null +++ b/src/config/lstm_maml.yaml @@ -0,0 +1,22 @@ +name: 'lstm_fomaml' +model_module_name: 'models.lstm_maml' +model_class_name: 'LSTMMAML' + +n_train: !!int 60000 +n_decay: !!int 10000 +print_every_n: !!int 1000 +val_every_n: !!float 1000 +n_val: !!int 600 +n_test: !!int 600 +n_samples: !!int 20 + +lr: !!float 1e-1 +meta_lr: !!float 1e-3 +patience_iters: !!int 5 +n_update: !!int 3 +max_grad_norm: !!int 5 +batch_size: !!int 5 +embedding_size: !!int 250 +n_layers: !!int 1 +hidden_size: !!int 200 +stop_grad: !!bool true diff --git a/src/config/lyrics.yaml b/src/config/lyrics.yaml index 9e07f2a..fd9cdc1 100644 --- a/src/config/lyrics.yaml +++ b/src/config/lyrics.yaml @@ -1,4 +1,8 @@ dataset: 
'lyrics' dataset_path: '../raw-data/lyrics/lyrics_data/' splits: ['train', 'val', 'test'] +word_min_times: 5 +max_unk_percent: 0.05 +min_len: 10 max_len: 50 +eval_len: 50 diff --git a/src/config/midi.yaml b/src/config/midi.yaml index 8eb33aa..1162e25 100644 --- a/src/config/midi.yaml +++ b/src/config/midi.yaml @@ -1,4 +1,4 @@ dataset: 'midi' dataset_path: '../raw-data/freemidi/freemidi_data/' splits: ['train', 'val', 'test'] -max_len: 50 +max_len: 50 # [50, 250, 500] diff --git a/src/data/base_loader.py b/src/data/base_loader.py index 4b27cc6..33846d0 100644 --- a/src/data/base_loader.py +++ b/src/data/base_loader.py @@ -12,8 +12,11 @@ class Loader(object): """A class for turning data into a sequence of tokens. """ - def __init__(self, max_len, dtype=np.int32, persist=True): + def __init__(self, min_len, max_len, max_unk_pecent, + dtype=np.int32, persist=True): + self.min_len = min_len self.max_len = max_len + self.max_unk_percent = max_unk_pecent self.dtype = dtype self.persist = persist @@ -34,8 +37,12 @@ def get_num_tokens(self): def validate(self, filepath): try: - self.load(filepath) - return True + # Must have at least one valid stanza for whole song to be valid + np_tokens, _ = self.load(filepath) + if np.shape(np_tokens)[0] > 0: + return True + else: + return False except OSError: return False except KeyError: @@ -50,15 +57,37 @@ def validate(self, filepath): return False def load(self, filepath): - npfile = '%s.%s.npy' % (filepath, self.max_len) + npfile = '%s.%s.npz' % (filepath, self.max_len) if self.persist and os.path.isfile(npfile): - return np.load(npfile).astype(self.dtype) + data = np.load(npfile) + numpy_tokens = data['tokens'].astype(self.dtype) + numpy_seq_lens = data['seq_lens'].astype(self.dtype) else: data = self.read(filepath) - tokens = self.tokenize(data) - numpy_tokens = np.zeros(self.max_len, dtype=self.dtype) - for token_index in range(min(self.max_len, len(tokens))): - numpy_tokens[token_index] = tokens[token_index] + all_tokens = self.tokenize(data) + # Filter stanzas + # Keep all stanzas that are >= min length in length + all_tokens = list( + filter(lambda x: len(x) >= self.min_len, all_tokens)) + # Keep all stanzas that are <= max length in length + all_tokens = list( + filter(lambda x: len(x) <= self.max_len, all_tokens)) + # Keep all stanzas that have < max unk% of tokens + all_tokens = list(filter( + lambda x: self.get_unk_percent(x) < self.max_unk_percent, + all_tokens)) + + n_stanzas = len(all_tokens) + numpy_tokens = np.zeros( + (n_stanzas, self.max_len), dtype=self.dtype) + numpy_seq_lens = np.array(list( + map(lambda x: len(x), all_tokens) + )) + for i, stanza_tokens in enumerate(all_tokens): + for j in range(min(self.max_len, len(stanza_tokens))): + numpy_tokens[i][j] = all_tokens[i][j] + if self.persist: - np.save(npfile, numpy_tokens) - return numpy_tokens + np.savez(npfile, tokens=numpy_tokens, seq_lens=numpy_seq_lens) + + return numpy_tokens, numpy_seq_lens diff --git a/src/data/dataset.py b/src/data/dataset.py index 5bd97d1..6f42637 100644 --- a/src/data/dataset.py +++ b/src/data/dataset.py @@ -7,21 +7,20 @@ import multiprocessing import itertools try: - from urllib import quote, unquote # python 2 + from urllib import quote, unquote # python 2 except ImportError: - from urllib.parse import quote, unquote # python 3 - + from urllib.parse import quote, unquote # python 3 import numpy as np +from data.lib import get_random log = logging.getLogger('few-shot') logging.basicConfig(level=logging.INFO) class Metadata(object): - """An object for 
tracking the metadata associated with a configuration of - the sampler. - """ + """Track the metadata associated with a configuration of the sampler.""" + def __init__(self, root, name): self.dir = os.path.join(root, name) self.open_files = {} @@ -38,7 +37,8 @@ def lines(self, filename): def write(self, filename, line): if filename not in self.open_files: - self.open_files[filename] = open(os.path.join(self.dir, filename), 'a') + self.open_files[filename] = open( + os.path.join(self.dir, filename), 'a') self.open_files[filename].write(line) def close(self): @@ -82,25 +82,28 @@ class Dataset(object): seed (int or None): the random seed which is used for shuffling the artists. """ - def __init__(self, root, split, loader, metadata, split_proportions=(8,1,1), - persist=True, cache=True, validate=True, min_songs=0, parallel=False, - valid_songs_file='valid_songs.csv', seed=None): + + def __init__(self, root, loader, metadata, artists_in_split, seed=None, + persist=True, cache=True, validate=True, min_songs=0, + artists_file='artists.csv', + valid_songs_file='valid_songs.csv'): self.root = root self.cache = cache self.cache_data = {} self.loader = loader self.metadata = metadata + self.random = get_random(seed) self.artists = [] self.valid_songs_file = valid_songs_file valid_songs = {} - artist_in_split = [] + valid_artists_in_split = [] # If we're both validating and using persistence, load any validation # data from disk. The format of the validation file is just a CSV # with two entries: artist and song. The artist is the name of the # artist (i.e. the directory (e.g. 'K_s Choice')) and the song is # the song file name (e.g. 'ironflowers.mid'). - if validate and persist: + if validate and persist and self.metadata.exists(valid_songs_file): for line in self.metadata.lines(valid_songs_file): artist, song = line.rstrip('\n').split(',', 1) artist = unquote(artist) @@ -109,20 +112,15 @@ def __init__(self, root, split, loader, metadata, split_proportions=(8,1,1), valid_songs[artist] = set() valid_songs[artist].add(song) - if persist and self.metadata.exists('%s.csv' % split): - artists_in_split = [] - for line in self.metadata.lines('%s.csv' % split): - artists_in_split.append(line.rstrip('\n')) + valid_artists_in_split = artists_in_split else: dirs = [] - all_artists = [] skipped_count = 0 - pool = multiprocessing.Pool(multiprocessing.cpu_count()) - for artist in os.listdir(root): + for artist in artists_in_split: if os.path.isdir(os.path.join(root, artist)): songs = os.listdir(os.path.join(root, artist)) - songs = [song for song in songs if loader.is_song(song)] + songs = [s for s in songs if loader.is_song(s)] if len(songs) > 0: dirs.append(artist) @@ -130,58 +128,45 @@ def __init__(self, root, split, loader, metadata, split_proportions=(8,1,1), progress_logger = ProgressLogger(num_dirs) for artist_index, artist in enumerate(dirs): + # log.info("Processing artist %s" % artist) songs = os.listdir(os.path.join(root, artist)) # We only want .txt and .mid files. Filter all others. 
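# A minimal numpy sketch of the filtering/padding/persistence contract that
# Loader.load in base_loader.py above now follows: keep stanzas with
# min_len <= len <= max_len and an unk share below max_unk_percent, pad them
# into a [n_stanzas, max_len] matrix, and cache the tokens together with their
# true lengths in one .npz file. All thresholds, token values, and the file
# name are illustrative, not taken from the repo.
import numpy as np

min_len, max_len, max_unk_percent, unk = 2, 5, 0.5, 99
stanzas = [[4, 7, 1], [3], [99, 99, 2], [5, 6, 7, 8, 2, 1]]  # tokenized stanzas

def unk_percent(tokens):
    return sum(t == unk for t in tokens) / float(len(tokens))

kept = [s for s in stanzas
        if min_len <= len(s) <= max_len and unk_percent(s) < max_unk_percent]

tokens = np.zeros((len(kept), max_len), dtype=np.int32)
seq_lens = np.array([len(s) for s in kept], dtype=np.int32)
for i, s in enumerate(kept):
    tokens[i, :len(s)] = s

# tokens and seq_lens are saved side by side so the true lengths survive the
# cache round trip, mirroring the np.savez / np.load pair in load().
np.savez('example_song.5.npz', tokens=tokens, seq_lens=seq_lens)
data = np.load('example_song.5.npz')
assert (data['tokens'] == tokens).all() and (data['seq_lens'] == seq_lens).all()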
- songs = [song for song in songs if loader.is_song(song)] + songs = [s for s in songs if loader.is_song(s)] # populate `valid_songs[artist]` if validate: progress_logger.maybe_log(artist_index) if artist not in valid_songs: valid_songs[artist] = set() - songs_to_validate = [song for song in songs if song not in valid_songs[artist]] - song_files = [os.path.join(root, artist, song) for song in songs_to_validate] - if parallel: - mapped = pool.map(loader.validate, song_files) - else: - mapped = map(loader.validate, song_files) + songs_to_validate = [s for s in songs if s not in valid_songs[artist]] + song_files = [os.path.join(root, artist, s) for s in songs_to_validate] + mapped = map(loader.validate, song_files) validated = itertools.compress(songs_to_validate, mapped) for song in validated: - song_file = os.path.join(root, artist, song) + # song_file = os.path.join(root, artist, song) if persist: line = '%s,%s\n' % (quote(artist), quote(song)) self.metadata.write(self.valid_songs_file, line) + # log.info("Validated song %s" % song) valid_songs[artist].add(song) else: valid_songs[artist] = set(songs) if len(valid_songs[artist]) >= min_songs: - all_artists.append(artist) + valid_artists_in_split.append(artist) else: skipped_count += 1 - pool.close() - pool.join() + if skipped_count > 0: log.info("%s artists don't have K+K'=%s songs. Using %s artists" % ( - skipped_count, min_songs, len(all_artists))) - train_count = int(float(split_proportions[0]) / sum(split_proportions) * len(all_artists)) - val_count = int(float(split_proportions[1]) / sum(split_proportions) * len(all_artists)) - # Use RandomState(seed) so that shuffles with the same set of - # artists will result in the same shuffle on different computers. - np.random.RandomState(seed).shuffle(all_artists) + skipped_count, min_songs, len(valid_artists_in_split))) + if persist: - self.metadata.write('train.csv', '\n'.join(all_artists[:train_count])) - self.metadata.write('val.csv', '\n'.join(all_artists[train_count:train_count+val_count])) - self.metadata.write('test.csv', '\n'.join(all_artists[train_count+val_count:])) - if split == 'train': - artists_in_split = all_artists[:train_count] - elif split == 'val': - artists_in_split = all_artists[train_count:train_count+val_count] - else: - artists_in_split = all_artists[train_count+val_count:] + metadata.write( + artists_file, '\n'.join(valid_artists_in_split)) self.metadata.close() - for artist in artists_in_split: + for artist in valid_artists_in_split: self.artists.append(ArtistDataset(artist, list(valid_songs[artist]))) def load(self, artist, song): @@ -192,11 +177,21 @@ def load(self, artist, song): artist (str): the name of the artist directory. e.g. 
`"tool"` """ if self.cache and (artist, song) in self.cache_data: - return self.cache_data[(artist, song)] + tokens, seq_lens = self.cache_data[(artist, song)] + else: + tokens, seq_lens = self.loader.load( + os.path.join(self.root, artist, song)) + self.cache_data[(artist, song)] = (tokens, seq_lens) + + if tokens.ndim == 1: + data = (tokens, seq_lens) + elif tokens.ndim == 2: + idx = self.random.choice(range(np.shape(tokens)[0]), size=1)[0] + data = (tokens[idx], seq_lens[idx]) else: - data = self.loader.load(os.path.join(self.root, artist, song)) - self.cache_data[(artist, song)] = data - return data + raise RuntimeError( + "Token matrix has dimension > 2: %d", tokens.ndim) + return data def __len__(self): return len(self.artists) diff --git a/src/data/episode.py b/src/data/episode.py index 8cf053b..d507eef 100644 --- a/src/data/episode.py +++ b/src/data/episode.py @@ -1,21 +1,30 @@ #!/usr/bin/python3 import os -import time import logging import yaml - import numpy as np -from numpy.random import RandomState from data.midi_loader import MIDILoader from data.lyrics_loader import LyricsLoader from data.dataset import Dataset, Metadata +from data.lib import get_random + +log = logging.getLogger('few-shot') +logging.basicConfig(level=logging.INFO) class Episode(object): - def __init__(self, support, query): + def __init__(self, support, support_seq_len, query, query_seq_len, + other_query=None, other_query_seq_len=None, + metadata_support=None, metadata_query=None): self.support = support + self.support_seq_len = support_seq_len self.query = query + self.query_seq_len = query_seq_len + self.other_query = other_query + self.other_query_seq_len = other_query_seq_len + self.metadata_support = metadata_support + self.metadata_query = metadata_query class SQSampler(object): @@ -34,12 +43,23 @@ def __init__(self, support_size, query_size, random): def sample(self, artist): sample = self.random.choice( artist, - size=self.support_size+self.query_size, + size=self.support_size + self.query_size, replace=False) query = sample[:self.query_size] support = sample[self.query_size:] return query, support + def sample_from_artists(self, artists, n): + ret = [] + for artist in artists: + sample = self.random.choice( + artist, + size=n, + replace=False) + ret += sample.tolist() + + return ret + class EpisodeSampler(object): def __init__(self, dataset, batch_size, support_size, query_size, max_len, @@ -50,8 +70,10 @@ def __init__(self, dataset, batch_size, support_size, query_size, max_len, self.query_size = query_size self.max_len = max_len self.dtype = dtype - self.random = get_random(seed) + self.seed = seed + self.random = get_random(self.seed) self.sq_sampler = SQSampler(support_size, query_size, self.random) + self.dataset.random = self.random def __len__(self): return len(self.data) @@ -59,28 +81,187 @@ def __len__(self): def __repr__(self): return 'EpisodeSampler("%s", "%s")' % (self.root, self.split) + def reset_seed(self): + self.random = get_random(self.seed) + self.sq_sampler.random = self.random + self.dataset.random = self.random + + def get_artists_episode(self, artists): + batch_size = len(artists) + support = np.zeros( + (batch_size, self.support_size, self.max_len), dtype=self.dtype) + support_seq_len = np.zeros( + (batch_size, self.support_size), dtype=self.dtype) + query = np.zeros( + (batch_size, self.query_size, self.max_len), dtype=self.dtype) + query_seq_len = np.zeros( + (batch_size, self.query_size), dtype=self.dtype) + + metadata_support = {} + metadata_query = {} + for 
batch_index, artist in enumerate(artists): + query_songs, support_songs = self.sq_sampler.sample(artist) + metadata_support[artist.name] = support_songs.tolist() + metadata_query[artist.name] = query_songs.tolist() + + for support_index, song in enumerate(support_songs): + parsed_song, parsed_len = self.dataset.load(artist.name, song) + support[batch_index, support_index, :] = parsed_song + support_seq_len[batch_index, support_index] = parsed_len + for query_index, song in enumerate(query_songs): + parsed_song, parsed_len = self.dataset.load(artist.name, song) + query[batch_index, query_index, :] = parsed_song + query_seq_len[batch_index, query_index] = parsed_len + + return Episode(support, support_seq_len, query, query_seq_len, + metadata_support=metadata_support, + metadata_query=metadata_query) + def get_episode(self): - support = np.zeros((self.batch_size, self.support_size, self.max_len), dtype=self.dtype) - query = np.zeros((self.batch_size, self.query_size, self.max_len), dtype=self.dtype) - artists = self.random.choice(self.dataset, size=self.batch_size, replace=False) + support = np.zeros( + (self.batch_size, self.support_size, self.max_len), + dtype=self.dtype) + support_seq_len = np.zeros( + (self.batch_size, self.support_size), dtype=self.dtype) + query = np.zeros( + (self.batch_size, self.query_size, self.max_len), dtype=self.dtype) + query_seq_len = np.zeros( + (self.batch_size, self.query_size), dtype=self.dtype) + artists = self.random.choice( + self.dataset, size=self.batch_size, replace=False) + + metadata_support = {} + metadata_query = {} + for batch_index, artist in enumerate(artists): + query_songs, support_songs = self.sq_sampler.sample(artist) + metadata_support[artist.name] = support_songs.tolist() + metadata_query[artist.name] = query_songs.tolist() + + for support_index, song in enumerate(support_songs): + parsed_song, parsed_len = self.dataset.load(artist.name, song) + support[batch_index, support_index, :] = parsed_song + support_seq_len[batch_index, support_index] = parsed_len + for query_index, song in enumerate(query_songs): + parsed_song, parsed_len = self.dataset.load(artist.name, song) + query[batch_index, query_index, :] = parsed_song + query_seq_len[batch_index, query_index] = parsed_len + + return Episode(support, support_seq_len, query, query_seq_len, + metadata_support=metadata_support, + metadata_query=metadata_query) + + def get_episode_with_other_artists(self): + support = np.zeros( + (self.batch_size, self.support_size, self.max_len), + dtype=self.dtype) + support_seq_len = np.zeros( + (self.batch_size, self.support_size), dtype=self.dtype) + query = np.zeros( + (self.batch_size, self.query_size, self.max_len), dtype=self.dtype) + query_seq_len = np.zeros( + (self.batch_size, self.query_size), dtype=self.dtype) + other_query = np.zeros( + (self.batch_size, self.query_size, self.max_len), dtype=self.dtype) + other_query_seq_len = np.zeros( + (self.batch_size, self.query_size), dtype=self.dtype) + artists = self.random.choice( + self.dataset, size=self.batch_size, replace=False) + + metadata_support = {} + metadata_query = {} for batch_index, artist in enumerate(artists): query_songs, support_songs = self.sq_sampler.sample(artist) + metadata_support[artist.name] = support_songs.tolist() + metadata_query[artist.name] = query_songs.tolist() + for support_index, song in enumerate(support_songs): - parsed_song = self.dataset.load(artist.name, song) - support[batch_index,support_index,:] = parsed_song + parsed_song, parsed_len = 
self.dataset.load(artist.name, song) + support[batch_index, support_index, :] = parsed_song + support_seq_len[batch_index, support_index] = parsed_len for query_index, song in enumerate(query_songs): - parsed_song = self.dataset.load(artist.name, song) - query[batch_index,query_index,:] = parsed_song - return Episode(support, query) + parsed_song, parsed_len = self.dataset.load(artist.name, song) + query[batch_index, query_index, :] = parsed_song + query_seq_len[batch_index, query_index] = parsed_len + + other_artists = self.random.choice( + self.dataset, size=self.query_size + 1, replace=False) + other_artists = other_artists.tolist() + if artist in other_artists: + other_artists.remove(artist) + + other_artists = other_artists[:self.query_size] + other_songs = self.sq_sampler.sample_from_artists(other_artists, 1) + + other_artists_and_songs = zip(other_artists, other_songs) + for index, (other_artist, song) in enumerate(other_artists_and_songs): + parsed_song, parsed_len = self.dataset.load( + other_artist.name, song) + other_query[batch_index, index, :] = parsed_song + other_query_seq_len[batch_index, index] = parsed_len + + return Episode(support, support_seq_len, query, query_seq_len, + other_query, other_query_seq_len, + metadata_support, metadata_query) def get_num_unique_words(self): return self.dataset.loader.get_num_tokens() + def get_unk_token(self): + return self.dataset.loader.get_unk_token() + + def get_start_token(self): + return self.dataset.loader.get_start_token() + + def get_stop_token(self): + return self.dataset.loader.get_stop_token() + def detokenize(self, numpy_data): return self.dataset.loader.detokenize(numpy_data) -def load_sampler_from_config(config): - """Create an EpisodeSampler from a yaml config.""" + +def create_split(root, loader, metadata, seed, + split_proportions=(8, 1, 1), persist=True): + train_exists = metadata.exists('train.csv') + val_exists = metadata.exists('test.csv') + test_exists = metadata.exists('val.csv') + + if train_exists and val_exists and test_exists: + artist_splits = {} + for split in ['train', 'test', 'val']: + artist_splits[split] = [] + for line in metadata.lines('%s.csv' % split): + artist_splits[split].append(line.rstrip('\n')) + + return artist_splits + + all_artists = [] + for artist in os.listdir(root): + if os.path.isdir(os.path.join(root, artist)): + songs = os.listdir(os.path.join(root, artist)) + songs = [s for s in songs if loader.is_song(s)] + if len(songs) > 0: + all_artists.append(artist) + + train_count = int(float(split_proportions[0]) / + sum(split_proportions) * len(all_artists)) + val_count = int(float(split_proportions[1]) / + sum(split_proportions) * len(all_artists)) + + # Use RandomState(seed) so that shuffles with the same set of + # artists will result in the same shuffle on different computers. 
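# Small standalone sketch of the deterministic artist split that create_split
# performs here: with the same seed, np.random.RandomState(seed).shuffle orders
# the artists identically on any machine, so the train/val/test split stays
# reproducible. Artist names and proportions below are invented.
import numpy as np

def split_artists(all_artists, seed, proportions=(8, 1, 1)):
    artists = list(all_artists)
    train_count = int(float(proportions[0]) / sum(proportions) * len(artists))
    val_count = int(float(proportions[1]) / sum(proportions) * len(artists))
    np.random.RandomState(seed).shuffle(artists)
    return (artists[:train_count],
            artists[train_count:train_count + val_count],
            artists[train_count + val_count:])

artists = ['artist_%02d' % i for i in range(10)]
assert split_artists(artists, seed=0) == split_artists(artists, seed=0)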
+ np.random.RandomState(seed).shuffle(all_artists) + split_train = all_artists[:train_count] + split_val = all_artists[train_count:train_count + val_count] + split_test = all_artists[train_count + val_count:] + + return { + 'train': split_train, + 'val': split_val, + 'test': split_test + } + + +def load_all_samplers_from_config(config): if isinstance(config, str): config = yaml.load(open(config, 'r')) elif isinstance(config, dict): @@ -92,53 +273,87 @@ def load_sampler_from_config(config): 'query_size', 'support_size', 'batch_size', + 'min_len', 'max_len', 'dataset', - 'split' - ] - optional_keys = [ - 'train_proportion', - 'val_proportion', - 'test_proportion', - 'persist', - 'cache', - 'seed', - 'dataset_seed' ] for key in required_keys: if key not in config: raise RuntimeError('required config key "%s" not found' % key) - props = ( - config.get('train_proportion', 8), - config.get('val_proportion', 1), - config.get('test_proportion', 1) - ) + root = config['dataset_path'] if not os.path.isdir(root): raise RuntimeError('required data directory %s does not exist' % root) - metadata_dir = 'few_shot_metadata_%s_%s' % (config['dataset'], config['max_len']) + metadata_dir = 'few_shot_metadata_%s_%s' % (config['dataset'], + config['max_len']) metadata = Metadata(root, metadata_dir) + if config['dataset'] == 'lyrics': - loader = LyricsLoader(config['max_len'], metadata=metadata) - parallel = False + loader = LyricsLoader( + config['min_len'], config['max_len'], config['max_unk_percent'], + metadata=metadata, persist=False) elif config['dataset'] == 'midi': loader = MIDILoader(config['max_len']) - parallel = False else: raise RuntimeError('unknown dataset "%s"' % config['dataset']) + artist_splits = create_split( + root, loader, metadata, seed=config.get('dataset_seed', 0)) + + if not loader.read_vocab(): + log.info("Building vocabulary using train") + config['split'] = 'train' + config['validate'] = True + config['persist'] = False + config['cache'] = False + load_sampler_from_config(config, metadata, artist_splits, loader) + loader.prune(config['word_min_times']) + log.info("Vocabulary pruned!") + + loader.persist = True + episode_sampler = {} + for split in config['splits']: + log.info("Encoding %s split using pruned vocabulary" % split) + config['split'] = split + config['validate'] = True + config['persist'] = True + config['cache'] = True + episode_sampler[split] = load_sampler_from_config( + config, metadata, artist_splits, loader) + + return episode_sampler + + +def load_sampler_from_config(config, metadata, artist_splits, loader=None): + """Create an EpisodeSampler from a yaml config.""" + # Force batch_size of 1 for evaluation + if config['split'] in ['val', 'test']: + config['batch_size'] = 1 + + root = config['dataset_path'] + if not os.path.isdir(root): + raise RuntimeError('required data directory %s does not exist' % root) + + if loader is None: + if config['dataset'] == 'lyrics': + loader = LyricsLoader(config['max_len'], metadata=metadata) + elif config['dataset'] == 'midi': + loader = MIDILoader(config['max_len']) + else: + raise RuntimeError('unknown dataset "%s"' % config['dataset']) + dataset = Dataset( root, - config['split'], loader, metadata, - split_proportions=props, + artist_splits[config['split']], + seed=config.get('seed', None), cache=config.get('cache', True), persist=config.get('persist', True), validate=config.get('validate', True), - min_songs=config['support_size']+config['query_size'], - parallel=parallel, - seed=config.get('dataset_seed', 0) + 
min_songs=config['support_size'] + config['query_size'], + artists_file='%s.csv' % config['split'], + valid_songs_file='valid_songs_%s.csv' % config['split'] ) return EpisodeSampler( dataset, @@ -147,10 +362,3 @@ def load_sampler_from_config(config): config['query_size'], config['max_len'], seed=config.get('seed', None)) - - -def get_random(seed): - if seed is not None: - return RandomState(seed) - else: - return np.random diff --git a/src/data/lib.py b/src/data/lib.py new file mode 100644 index 0000000..4c8bfad --- /dev/null +++ b/src/data/lib.py @@ -0,0 +1,9 @@ +import numpy as np +from numpy.random import RandomState + + +def get_random(seed): + if seed is not None: + return RandomState(seed) + else: + return np.random diff --git a/src/data/lyrics_loader.py b/src/data/lyrics_loader.py index 338b493..cdf319e 100644 --- a/src/data/lyrics_loader.py +++ b/src/data/lyrics_loader.py @@ -7,6 +7,7 @@ import nltk import numpy as np +import re from data.base_loader import Loader @@ -14,6 +15,35 @@ log = logging.getLogger("few-shot") +def custom_tokenizer(str): + """Return list of stanzas where each stanza contains list of tokens.""" + str = str.lower() + # Remove lines indicating anything in brackets, like "[chorus:]" + str = re.sub("[\[].*?[\]]", "", str) + # Handle windows new spaces + str = str.replace("\r", "\n") + # Remove apostrophe's at end and beginning of words, like "fallin'" + str = re.sub(r"'([^A-Za-z])", r"\1", str) + str = re.sub(r"([^A-Za-z])'", r"\1", str) + + tokens = [] + for stanza in str.split("\n\n"): + t = [] + for sent in stanza.split("\n"): + s = nltk.word_tokenize(sent) + if len(s) > 0: + if s[-1] in [',', '.']: + del s[-1] + + t.extend(s) + t.append("\n".encode('unicode_escape')) + + if len(t) > 0: + tokens.append(t) + + return tokens + + class LyricsLoader(Loader): """Objects of this class parse lyrics files and persist word IDs. @@ -28,16 +58,50 @@ class LyricsLoader(Loader): to a file. If the file already exists, the tokenizer will bootstrap from the file. 
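# Rough usage sketch for custom_tokenizer defined above. The lyric text is
# invented and nltk's punkt tokenizer data is assumed to be installed; the
# expected output is what the regex/word_tokenize pipeline should produce,
# not a recorded run.
from data.lyrics_loader import custom_tokenizer

raw = "[Chorus:]\nFallin' down,\nDown again.\n\nSecond stanza here."
stanzas = custom_tokenizer(raw)
# One token list per stanza, lower-cased, bracketed lines and trailing
# "," / "." removed, with an escaped newline marker closing each line:
# [['fallin', 'down', b'\\n', 'down', 'again', b'\\n'],
#  ['second', 'stanza', 'here', b'\\n']]
print(stanzas)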
""" - def __init__(self, max_len, metadata, tokenizer=nltk.word_tokenize, - persist=True, dtype=np.int32): - super(LyricsLoader, self).__init__(max_len, dtype=dtype) + + def __init__(self, min_len, max_len, max_unk_percent, + metadata, tokenizer=custom_tokenizer, + persist=True, dtype=np.int32): + super(LyricsLoader, self).__init__( + min_len, max_len, max_unk_percent, dtype=dtype) self.tokenizer = tokenizer self.metadata = metadata + self.persist = persist + self.pruned = False + self.word_to_id = {} + self.word_to_cnt = {} self.id_to_word = {} self.highest_word_id = -1 - # read persisted word ids - if persist: + + # read persisted word ids (if they exist) + self.read_vocab() + + def is_song(self, filepath): + return filepath.endswith('.txt') + + def prune(self, n): + log.info('Vocab size before pruning: %d' % len(self.word_to_cnt.keys())) + # Delete words that have less than n frequency + for word in self.word_to_cnt.keys(): + if self.word_to_cnt[word] < n: + del self.word_to_cnt[word] + + # Recreate word_to_id and id_to_word based on pruned dictionary + self.word_to_id = {} + self.id_to_word = {} + for i, word in enumerate(self.word_to_cnt.keys()): + self.word_to_id[word] = i + self.id_to_word[i] = word + + self.highest_word_id = len(self.word_to_cnt.keys()) - 1 + self.pruned = True + + log.info('Vocab size after pruning: %d' % len(self.word_to_cnt.keys())) + self.write_vocab() + + def read_vocab(self): + if self.metadata.exists('word_ids.csv'): log.info('Loading lyrics metadata...') for line in self.metadata.lines('word_ids.csv'): row = line.rstrip('\n').split(',', 1) @@ -47,8 +111,18 @@ def __init__(self, max_len, metadata, tokenizer=nltk.word_tokenize, if word_id > self.highest_word_id: self.highest_word_id = word_id - def is_song(self, filepath): - return filepath.endswith('.txt') + self.pruned = True + return True + else: + return False + + def write_vocab(self): + log.info('writing lyrics metadata...') + for word in self.word_to_id: + self.metadata.write( + 'word_ids.csv', + '%s,%s\n' % (self.word_to_id[word], word) + ) def read(self, filepath): """Read a file. @@ -57,39 +131,84 @@ def read(self, filepath): filepath (str): path to the lyrics file. e.g. "/home/user/lyrics_data/tool/lateralus.txt" """ - return ''.join(codecs.open(filepath, 'r', errors='ignore').readlines()) + return ''.join(codecs.open(filepath, 'U', errors='ignore').readlines()) def get_num_tokens(self): + # +3 because: + # (self.highest_word_id + 1) is unknown token + # (self.highest_word_id + 2) is stop token + # (self.highest_word_id + 3) is start token (but we don't include this + # because we never predict start token) + return self.highest_word_id + 3 + + def get_unk_token(self): + # unk token is self.highest_word_id + 1 return self.highest_word_id + 1 + def get_stop_token(self): + # stop token is self.highest_word_id + 2 + return self.highest_word_id + 2 + + def get_start_token(self): + # start token is self.highest_word_id + 3 + return self.highest_word_id + 3 + + def get_unk_percent(self, list_of_tokens): + unk_token = self.get_unk_token() + num_unk = len(list(filter(lambda x: x == unk_token, list_of_tokens))) + return float(num_unk) / len(list_of_tokens) + def tokenize(self, raw_lyrics): - """Turns a string of lyrics data into a numpy array of int "word" IDs. + """Turn a string of lyrics data into a numpy array of int "word" IDs. 
Arguments: raw_lyrics (str): Stringified lyrics data """ - tokens = [] - for token in self.tokenizer(raw_lyrics): - if token not in self.word_to_id: - self.highest_word_id += 1 - self.word_to_id[token] = self.highest_word_id - self.id_to_word[self.highest_word_id] = token - if self.persist: - self.metadata.write( - 'word_ids.csv', - '%s,%s\n' % (self.highest_word_id, token) - ) - tokens.append(self.word_to_id[token]) - return tokens + all_tokens = [] + for stanza in self.tokenizer(raw_lyrics): + + stanza_tokens = [] + for token in stanza: + if not self.pruned: + if token not in self.word_to_id: + self.highest_word_id += 1 + self.word_to_id[token] = self.highest_word_id + self.word_to_cnt[token] = 0 + self.id_to_word[self.highest_word_id] = token + else: + self.word_to_cnt[token] += 1 + + token_value = self.word_to_id[token] + else: + if token not in self.word_to_id: + token_value = self.get_unk_token() + else: + token_value = self.word_to_id[token] + + stanza_tokens.append(token_value) + + all_tokens.append(stanza_tokens) + + return all_tokens def detokenize(self, numpy_data): ret = '' for token in numpy_data: - word = self.id_to_word[token] - if word == "n't": - ret += word - elif word not in string.punctuation and not word.startswith("'"): - ret += " " + word + if token == self.get_stop_token(): + ret += " [stop]" + break + elif token == self.get_start_token(): + ret += "[start]" + elif token == self.get_unk_token(): + ret += " [unk]" else: - ret += word + word = self.id_to_word[token] + if word == "n't": + ret += word + elif word not in string.punctuation and not word.startswith("'"): + ret += " " + word + elif word == "\n".encode('unicode_escape'): + ret += "\n" + else: + ret += word return "".join(ret).strip() diff --git a/src/evaluation/sampler.py b/src/evaluation/sampler.py new file mode 100644 index 0000000..1adb7ec --- /dev/null +++ b/src/evaluation/sampler.py @@ -0,0 +1,15 @@ +import numpy as np + + +class Sampler(object): + + def __init__(self, _type='random'): + self._type = _type + + def sample(self, p): + if self._type == 'random': + return np.random.choice(len(p), p=p) + elif self._type == 'argmax': + return np.argmax(p) + else: + raise ValueError('Sample type %s not recognized' % self._type) diff --git a/src/models/base_model.py b/src/models/base_model.py index b553ae3..eb68de9 100644 --- a/src/models/base_model.py +++ b/src/models/base_model.py @@ -60,27 +60,113 @@ def flatten_first_two_dims(token_array): return np.reshape(token_array, (shape[0] * shape[1], shape[2])) -def convert_tokens_to_input_and_target(token_array, start_word=None): - """Convert token_array to input and target to use for model for +def convert_tokens_to_input_and_target(token_array, seq_len_array, + start_word, end_word, + flatten_batch=True, x_and_y=True): + if token_array.ndim == 3 and flatten_batch: + X = flatten_first_two_dims(token_array) + else: + X = token_array + + if x_and_y: + if X.ndim == 2: + # Y is X + end_word but have to add an extra column to Y + # make sure end_word correctly ends each sequence + # in case of maximal sequence length + Y = np.copy(X) + n_rows = np.shape(X)[0] + extra_column = np.zeros((n_rows, 1)) + Y = np.concatenate([Y, extra_column], axis=1) + Y[range(n_rows), seq_len_array.flatten()] = end_word + + # X_new is start_word + X + start_word_column = np.full( + shape=[n_rows, 1], fill_value=start_word) + X_new = np.concatenate([start_word_column, X], axis=1) + elif X.ndim == 3: + Y = np.copy(X) + n_rows, n_cols = np.shape(X)[0], np.shape(X)[1] + extra_column = 
np.zeros((n_rows, n_cols, 1)) + Y = np.concatenate([Y, extra_column], axis=2) + indices_arr = np.indices((n_rows, n_cols)) + seq_len_array = seq_len_array.flatten() + idx1_flat = indices_arr[0].flatten() + idx2_flat = indices_arr[1].flatten() + Y[idx1_flat, idx2_flat, seq_len_array] = end_word + + start_word_column = np.full( + shape=[n_rows, n_cols, 1], fill_value=start_word) + X_new = np.concatenate([start_word_column, X], axis=2) + + return X_new, Y + """ + else: + if X.ndim == 3: + n_rows, n_cols = np.shape(X)[0], np.shape(X)[1] + start_word_column = np.full( + shape=[n_rows, n_cols, 1], fill_value=start_word) + X_new = np.concatenate([start_word_column, X], axis=2) + + Y = np.copy(X_new) + x_last_column = np.expand_dims(X[:, :, -1], 2) + # print(Y.shape) + # print(x_last_column.shape) + Y = np.concatenate([Y, x_last_column], axis=2) + + n_rows, n_cols = np.shape(X)[0], np.shape(X)[1] + indices_arr = np.indices((n_rows, n_cols)) + seq_len_array = (seq_len_array + 1).flatten() + idx1_flat = indices_arr[0].flatten() + idx2_flat = indices_arr[1].flatten() + Y[idx1_flat, idx2_flat, seq_len_array] = end_word + + return Y + """ + +""" +def convert_tokens_to_input_and_target(token_array, start_word=None, + flatten_batch=True): + Convert token_array to input and target to use for model for sequence generation. If start_word is given, add to start of each sequence of tokens. Input is token_array without last item; Target is token_array without first item. Arguments: - token_array (numpy int array): tokens array of size [B,S,N] where - B is batch_size, S is number of songs, N is size of the song + token_array (numpy int array): tokens array of size [B,S,N] or [S, N] + where B is batch_size, S is number of songs, N is size of the song start_word (int): token to use for start word - """ - X = flatten_first_two_dims(token_array) - - if start_word is None: - Y = np.copy(X[:, 1:]) - X_new = X[:, :-1] + flatten_batch (boolean): whether to flatten along the batch dimension + Returns: + X_new (numpy int array): input tokens of size either + a. [B*S,N] if token_array is [B,S,N] and flatten_batch=True + b. [B,S,N] if token_array is [B,S,N] and flatten_batch=False + c. 
[S,N] if token_array is [S,N] + Y (numpy int array): output tokens of same size as X_new + + if token_array.ndim == 3 and flatten_batch: + X = flatten_first_two_dims(token_array) else: - Y = np.copy(X) - start_word_column = np.full( - shape=[np.shape(X)[0], 1], fill_value=start_word) - X_new = np.concatenate([start_word_column, X[:, :-1]], axis=1) + X = token_array + + if X.ndim == 2: + if start_word is None: + Y = np.copy(X[:, 1:]) + X_new = X[:, :-1] + else: + Y = np.copy(X) + start_word_column = np.full( + shape=[np.shape(X)[0], 1], fill_value=start_word) + X_new = np.concatenate([start_word_column, X[:, :-1]], axis=1) + elif X.ndim == 3: + if start_word is None: + Y = np.copy(X[:, :, 1:]) + X_new = X[:, :, :-1] + else: + Y = np.copy(X) + start_word_column = np.full( + shape=[np.shape(X)[0], np.shape(X)[1], 1], fill_value=start_word) + X_new = np.concatenate([start_word_column, X[:, :, :-1]], axis=2) return X_new, Y +""" diff --git a/src/models/fast_weights_lstm.py b/src/models/fast_weights_lstm.py new file mode 100644 index 0000000..07873f3 --- /dev/null +++ b/src/models/fast_weights_lstm.py @@ -0,0 +1,134 @@ +import tensorflow as tf +import numpy as np + +from models.tf_model import TFModel +from models.lstm_cell import LSTMCell, sep + + +class FastWeightsLSTM(TFModel): + """Defines functions for lstm-models that work via a fast weights mechanism.""" + + def _define_placeholders(self): + self._support_batch_size = tf.placeholder(tf.int32, shape=()) + self._query_batch_size = tf.placeholder(tf.int32, shape=()) + self._support_seq_length = tf.placeholder(tf.int32, [None]) + self._query_seq_length = tf.placeholder(tf.int32, [None]) + self._supportX = tf.placeholder( + tf.int32, [None, None, self._time_steps]) + self._supportY = tf.placeholder( + tf.int32, [None, None, self._time_steps]) + self._queryX = tf.placeholder( + tf.int32, [None, None, self._time_steps]) + self._queryY = tf.placeholder( + tf.int32, [None, None, self._time_steps]) + + def _build_weights(self): + """Contruct and return all weights for LSTM.""" + weights = {} + + embedding = tf.get_variable( + self._embedding_var_name, [self._input_size, self._embd_size]) + weights[self._embedding_var_name] = embedding + + for i in range(self._n_layers): + weights.update(self._build_cell_weights(i)) + + softmax_w = tf.get_variable( + 'softmax_w', [self._hidden_size, self._input_size]) + softmax_b = tf.get_variable('softmax_b', [self._input_size]) + weights['softmax_w'] = softmax_w + weights['softmax_b'] = softmax_b + return weights + + def _build_cell_weights(self, n): + """Construct and return all weights for single LSTM cell.""" + weights = {} + n = str(n) + + vocabulary_size = self._embd_size + n_units = self._hidden_size + weights[sep(n, 'kernel')] = tf.get_variable( + sep(n, 'kernel'), shape=[vocabulary_size + n_units, 4 * n_units]) + weights[sep(n, 'bias')] = tf.get_variable( + sep(n, 'bias'), + shape=[4 * n_units], + initializer=tf.constant_initializer(0.0)) + + return weights + + def _model(self, weights, X, batch_size, seq_length, initial_state=None): + """LSTM model that accepts dynamic weights. 
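# Tiny numpy check of the new convert_tokens_to_input_and_target behaviour
# from base_model.py above, for the 2-D case: the start word is prepended to
# every input row, and the end word is written one position past each
# sequence's true length in the target. Token values are arbitrary.
import numpy as np
from models.base_model import convert_tokens_to_input_and_target

tokens = np.array([[4, 5, 6, 0],    # true length 3, zero-padded to max_len=4
                   [7, 8, 0, 0]])   # true length 2
seq_lens = np.array([3, 2])
start_word, end_word = 10, 11

X, Y = convert_tokens_to_input_and_target(tokens, seq_lens, start_word, end_word)
assert (X == np.array([[10, 4, 5, 6, 0], [10, 7, 8, 0, 0]])).all()
assert (Y == np.array([[4, 5, 6, 11, 0], [7, 8, 11, 0, 0]])).all()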
+ + Arguments: + weights: (tensor) weights for LSTM + X: input sequences for LSTM of size [B, N] where B is batch_size + and N is length of sequences + batch_size: batch_size of input + seq_length: list of size batch_size indicating sequence length of + each sequence + initial_state: (optional) intial state to begin LSTM processing + Returns: + logits: (tensor) output logits of size [B, N, C] where C is number + of outputs classes + hidden_states: (tensor) hidden states of LSTM of size [B, N, H] + where H is hidden state size + final_state: (tensor) final hidden state of LSTM of size [B, H] + """ + def make_cell(n): + return LSTMCell( + n, self._embd_size, self._hidden_size, weights=weights) + + embedding = weights[self._embedding_var_name] + inputs = tf.nn.embedding_lookup(embedding, X) + + cell = tf.contrib.rnn.MultiRNNCell( + [make_cell(i) for i in range(self._n_layers)]) + + if initial_state is None: + initial_state = cell.zero_state(batch_size, dtype=tf.float32) + + # tf.nn.static_rnn not working so dynamic_rnn + hidden_states, final_state = tf.nn.dynamic_rnn( + cell, inputs, + initial_state=initial_state, sequence_length=seq_length + ) + output = tf.reshape(hidden_states, [-1, self._hidden_size]) + + # Reshape logits to be a 3-D tensor for sequence loss + softmax_w = weights['softmax_w'] + softmax_b = weights['softmax_b'] + logits = tf.nn.xw_plus_b(output, softmax_w, softmax_b) + logits = tf.reshape( + logits, [batch_size, self._time_steps, self._input_size]) + + return logits, hidden_states, final_state + + def _loss_fxn(self, logits, Y): + """Sequence loss function for logits and target Y.""" + return tf.contrib.seq2seq.sequence_loss( + logits, + Y, + tf.ones_like(Y, dtype=tf.float32), + average_across_timesteps=True, + average_across_batch=True) + + def _get_grads(self, loss, weights): + """Get gradient for loss w/r/t weights.""" + grads = tf.gradients(loss, list(weights.values())) + grads, _ = tf.clip_by_global_norm(grads, self._max_grad_norm) + if self._stop_grad: + grads = [tf.stop_gradient(grad) for grad in grads] + + return dict(zip(weights.keys(), grads)) + + def _get_update(self, weights, gradients, lr): + """Update weights using gradients and lr.""" + weight_updates = [weights[key] - lr * gradients[key] + for key in weights.keys()] + return dict(zip(weights.keys(), weight_updates)) + + def train(self, episode): + raise NotImplementedError() + + def eval(self, episode): + raise NotImplementedError() diff --git a/src/models/lstm_baseline.py b/src/models/lstm_baseline.py index 03da629..4320866 100644 --- a/src/models/lstm_baseline.py +++ b/src/models/lstm_baseline.py @@ -2,11 +2,13 @@ import tensorflow as tf from models.tf_model import TFModel +from models.nn_lib import LSTM, make_cell, get_sentinel_prob, num_stable_log,\ + seq_loss, get_ndcg from models.base_model import convert_tokens_to_input_and_target class LSTMBaseline(TFModel): - """LSTM language model + """LSTM language model. Trained on songs from the meta-training set. During evaluation, ignore each episode's support set and evaluate only on query set. 
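# Conceptual numpy version of the dictionary-style fast-weights step that
# _get_grads/_get_update above express in TensorFlow: every named weight is
# moved one SGD step and the updated dict is fed back into _model. The
# weight shapes and gradient values below are invented.
import numpy as np

weights = {'embedding': np.ones((3, 2)), 'softmax_w': np.ones((2, 3))}
grads = {'embedding': np.full((3, 2), 0.5), 'softmax_w': np.full((2, 3), -1.0)}
lr = 0.1

fast_weights = {k: weights[k] - lr * grads[k] for k in weights}
assert np.allclose(fast_weights['embedding'], 0.95)
assert np.allclose(fast_weights['softmax_w'], 1.1)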
@@ -15,18 +17,13 @@ class LSTMBaseline(TFModel): def __init__(self, config): super(LSTMBaseline, self).__init__(config) - def _define_placedholders(self): - # Add start word that starts every song - # Adding start word increases the size of vocabulary by 1 - self._start_word = self._config['input_size'] - self._input_size = self._config['input_size'] + 1 - - self._time_steps = self._config['max_len'] + def _define_placeholders(self): self._embd_size = self._config['embedding_size'] self._hidden_size = self._config['hidden_size'] self._n_layers = self._config['n_layers'] self._lr = self._config['lr'] self._max_grad_norm = self._config['max_grad_norm'] + self._embedding_var_name = 'embedding' self._batch_size = tf.placeholder(tf.int32, shape=()) self._seq_length = tf.placeholder(tf.int32, [None]) @@ -35,50 +32,59 @@ def _define_placedholders(self): self._target = tf.placeholder( tf.int32, [None, self._time_steps]) - def _build_graph(self): - embedding = tf.get_variable( - 'embedding', [self._input_size, self._embd_size]) - inputs = tf.nn.embedding_lookup(embedding, self._words) - inputs = tf.unstack(inputs, axis=1) - - def make_cell(): - return tf.contrib.rnn.BasicLSTMCell( - self._hidden_size, forget_bias=1., state_is_tuple=True) + self._is_training = tf.placeholder_with_default( + True, shape=(), name='is_training') + # self._max_token_len = tf.placeholder(tf.int32, shape=()) + def _build_lstm(self): + embedding = tf.get_variable( + self._embedding_var_name, [self._input_size, self._embd_size]) self._cell = tf.contrib.rnn.MultiRNNCell( - [make_cell() for _ in range(self._n_layers)]) + [make_cell(0, self._embd_size, self._hidden_size)]) self._initial_state = self._cell.zero_state( self._batch_size, dtype=tf.float32) - outputs, state = tf.nn.static_rnn( - self._cell, inputs, initial_state=self._initial_state, - sequence_length=self._seq_length) - self._state = state - - output = tf.concat(outputs, 1) - self._output = tf.reshape(output, [-1, self._hidden_size]) - - softmax_w = tf.get_variable( - 'softmax_w', [self._hidden_size, self._input_size]) - softmax_b = tf.get_variable('softmax_b', [self._input_size]) - # Reshape logits to be a 3-D tensor for sequence loss - logits = tf.nn.xw_plus_b(self._output, softmax_w, softmax_b) + + # outputs: [batch_size, time_step, hidden_size] + # state: [batch_size, hidden_size] + self._hidden_states_orig, self._final_state = LSTM( + self._cell, self._words, embedding, + self._seq_length, self._batch_size, self._initial_state + ) + + # [batch_size * time_step, hidden_size] + self._hidden_states = tf.reshape( + self._hidden_states_orig, [-1, self._hidden_size]) + logits = tf.matmul(self._hidden_states, embedding, transpose_b=True) + + # [batch_size, time_step, input_size] logits = tf.reshape( logits, [self._batch_size, self._time_steps, self._input_size]) - self._logits = logits - self._prob = tf.nn.softmax(self._logits) - - self._avg_neg_log = tf.contrib.seq2seq.sequence_loss( - logits, - self._target, - tf.ones([self._batch_size, self._time_steps], dtype=tf.float32), - average_across_timesteps=True, - average_across_batch=True) - - lr = tf.train.exponential_decay( - self._lr, - self._global_step, - self._config['n_decay'], 0.5, staircase=False - ) + if not self._config['use_sentinel']: + self._logits = logits + self._prob = tf.nn.softmax(self._logits) + else: + prob_vocab = tf.nn.softmax(logits) + g, prob_cache = get_sentinel_prob( + self._target, self._hidden_states, self._batch_size, + self._time_steps, self._hidden_size, self._input_size) + self._prob = 
tf.multiply(g, prob_vocab) + prob_cache + self._logits = num_stable_log(self._prob) + """ + max_token_len = tf.tile([self._max_token_len], [self._batch_size]) + self._neg_log = seq_loss( + self._logits, self._target, + tf.minimum(max_token_len, self._seq_length), self._time_steps, + avg_batch=False) + """ + self._neg_log = seq_loss( + self._logits, self._target, + self._seq_length, self._time_steps, + avg_batch=False) + self._avg_neg_log = tf.reduce_mean(self._neg_log) + + def _build_graph(self): + self._build_lstm() + lr = self._lr optimizer = tf.train.AdamOptimizer(lr) grads, _ = tf.clip_by_global_norm(tf.gradients(self._avg_neg_log, self.get_vars()), @@ -89,21 +95,28 @@ def make_cell(): def train(self, episode): """Concatenate query and support sets to train.""" X, Y = convert_tokens_to_input_and_target( - episode.support, self._start_word) + episode.support, episode.support_seq_len, + self._start_word, self._end_word) X2, Y2 = convert_tokens_to_input_and_target( - episode.query, self._start_word) + episode.query, episode.query_seq_len, + self._start_word, self._end_word) X = np.concatenate([X, X2]) Y = np.concatenate([Y, Y2]) + support_seq_len = episode.support_seq_len.flatten() + query_seq_len = episode.query_seq_len.flatten() + seq_len = np.concatenate([support_seq_len, query_seq_len]) feed_dict = {} feed_dict[self._words] = X feed_dict[self._target] = Y feed_dict[self._batch_size] = np.shape(X)[0] - feed_dict[self._seq_length] = [np.shape(X)[1]] * np.shape(X)[0] + # adding stop word adds +1 to sequence length + feed_dict[self._seq_length] = seq_len + 1 _, loss = self._sess.run([self._train_op, self._avg_neg_log], feed_dict=feed_dict) - if self._summary_writer: + + if self._summary_writer is not None: summary = tf.Summary(value=[ tf.Summary.Value(tag='Train/loss', simple_value=loss)]) @@ -112,17 +125,78 @@ def train(self, episode): return loss + def eval_ndcg(self, episode): + # Evaluate NDCG ranking metric + + if np.shape(episode.support)[0] > 1: + episode.support = episode.support[0:1, :, :] + episode.query = episode.query[0:1, :, :] + episode.other_query = episode.other_query[0:1, :, :] + episode.support_seq_len = episode.support_seq_len[0:1, :] + episode.query_seq_len = episode.query_seq_len[0:1, :] + episode.other_query_seq_len = episode.other_query_seq_len[0:1, :] + + # Ignore support set and evaluate only on query set. 
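# Hedged sketch of the ranking check eval_ndcg performs: the episode artist's
# own query songs get relevance 1, songs by other artists get 0, and all of
# them are ranked by ascending NLL (more likely songs first). get_ndcg itself
# lives in models.nn_lib, which is not part of this diff, so the function
# below is only an assumed reference implementation and the NLL values are
# invented.
import numpy as np

def ndcg_at_k(relevance, nll, k):
    order = np.argsort(nll)                  # most likely items first
    gains = relevance[order][:k]
    discounts = 1.0 / np.log2(np.arange(2, k + 2))
    dcg = np.sum(gains * discounts)
    ideal = np.sum(np.sort(relevance)[::-1][:k] * discounts)
    return dcg / ideal if ideal > 0 else 0.0

rel = np.array([1., 1., 0., 0.])      # two own-artist songs, two other-artist songs
nll = np.array([2.1, 3.5, 3.0, 4.2])  # one impostor scores better than a true song
print(ndcg_at_k(rel, nll, k=2))       # < 1.0 because of that mix-up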
+ X, Y = convert_tokens_to_input_and_target( + episode.query, episode.query_seq_len, + self._start_word, self._end_word) + X_other, Y_other = convert_tokens_to_input_and_target( + episode.other_query, episode.other_query_seq_len, + self._start_word, self._end_word) + + query_seq_len = episode.query_seq_len.flatten() + feed_dict = {} + feed_dict[self._words] = X + feed_dict[self._target] = Y + feed_dict[self._batch_size] = np.shape(X)[0] + # adding stop word makes sequences longer by +1 + feed_dict[self._seq_length] = query_seq_len + 1 + feed_dict[self._is_training] = False + # feed_dict[self._max_token_len] = self._config['eval_len'] + nll, avg_nll = self._sess.run( + [self._neg_log, self._avg_neg_log], feed_dict=feed_dict) + + other_query_seq_len = episode.other_query_seq_len.flatten() + feed_dict = {} + feed_dict[self._words] = X_other + feed_dict[self._target] = Y_other + feed_dict[self._batch_size] = np.shape(X_other)[0] + # adding stop word makes sequences longer by +1 + feed_dict[self._seq_length] = other_query_seq_len + 1 + feed_dict[self._is_training] = False + # feed_dict[self._max_token_len] = self._config['eval_len'] + nll_other, _ = self._sess.run( + [self._neg_log, self._avg_neg_log], feed_dict=feed_dict) + + rel_scores = np.ones(shape=np.shape(nll)) + rel_scores_other = np.zeros(shape=np.shape(nll_other)) + + ndcg = get_ndcg( + np.concatenate([rel_scores, rel_scores_other]), + np.concatenate([nll, nll_other]), + rank_position=np.shape(nll)[0]) + + return ndcg + def eval(self, episode): - """Ignore support set and evaluate only on query set.""" + # Ignore support set and evaluate only on query set. + X, Y = convert_tokens_to_input_and_target( - episode.query, self._start_word) + episode.query, episode.query_seq_len, + self._start_word, self._end_word) + query_seq_len = episode.query_seq_len.flatten() feed_dict = {} feed_dict[self._words] = X feed_dict[self._target] = Y feed_dict[self._batch_size] = np.shape(X)[0] - feed_dict[self._seq_length] = [np.shape(X)[1]] * np.shape(X)[0] - avg_neg_log = self._sess.run(self._avg_neg_log, feed_dict=feed_dict) + # adding stop word makes sequences longer by +1 + feed_dict[self._seq_length] = query_seq_len + 1 + feed_dict[self._is_training] = False + # feed_dict[self._max_token_len] = self._config['max_len'] + avg_neg_log = self._sess.run( + self._avg_neg_log, feed_dict=feed_dict) + if self._summary_writer is not None: summary = tf.Summary(value=[ tf.Summary.Value(tag='Eval/Avg_NLL', @@ -147,10 +221,10 @@ def sample(self, support_set, num): feed_dict[self._seq_length] = [1] feed_dict[self._initial_state] = state - probs, state = self._sess.run([self._prob, self._state], + probs, state = self._sess.run([self._prob, self._final_state], feed_dict=feed_dict) p = probs[0][0] - word = np.argmax(p) + word = self._sampler.sample(p) pred_words.append(word) return pred_words diff --git a/src/models/lstm_cell.py b/src/models/lstm_cell.py new file mode 100644 index 0000000..88d6f56 --- /dev/null +++ b/src/models/lstm_cell.py @@ -0,0 +1,68 @@ +import tensorflow as tf + +from tensorflow.python.ops.rnn_cell_impl import LSTMStateTuple + + +def sep(*args): + return '/'.join(args) + + +class LSTMCell(tf.contrib.rnn.BasicLSTMCell): + + def __init__(self, n, input_size, num_units, weights=None, + state_is_tuple=True, forget_bias=1): + self._n = str(n) + self._num_units = num_units = num_units + self._state_is_tuple = state_is_tuple + self._forget_bias = forget_bias + + if weights is None: + self._kernel = tf.get_variable( + sep(self._n, "kernel"), + 
shape=[input_size + num_units, 4 * num_units]) + self._bias = tf.get_variable( + sep(self._n, "bias"), + shape=[4 * num_units], + initializer=tf.constant_initializer(0.0)) + else: + self._kernel = weights[sep(self._n, "kernel")] + self._bias = weights[sep(self._n, "bias")] + + self._forget_bias_tensor = tf.constant( + self._forget_bias, dtype=tf.float32) + + def __call__(self, inputs, state): + """Borrowed from + https://github.com/tensorflow/tensorflow/blob/r1.9/tensorflow/python/ops/rnn_cell_impl.py#L614""" + + if self._state_is_tuple: + c, h = state + else: + c, h = tf.split(value=state, num_or_size_splits=2, axis=1) + + i = inputs + o = h + gate_inputs = tf.matmul( + tf.concat([i, o], 1), self._kernel) + gate_inputs = tf.nn.bias_add(gate_inputs, self._bias) + + # i = input_gate, j = new_input, f = forget_gate, o = output_gate + i, j, f, o = tf.split( + value=gate_inputs, num_or_size_splits=4, axis=1) + + # Note that using `add` and `multiply` instead of `+` and `*` gives a + # performance improvement. So using those at the cost of readability. + add = tf.add + multiply = tf.multiply + sigmoid = tf.sigmoid + activation = tf.tanh + new_c = add(multiply(c, sigmoid(add(f, self._forget_bias_tensor))), + multiply(sigmoid(i), activation(j))) + new_h = multiply(activation(new_c), sigmoid(o)) + + if self._state_is_tuple: + new_state = LSTMStateTuple(new_c, new_h) + else: + new_state = tf.concat([new_c, new_h], 1) + + return new_h, new_state diff --git a/src/models/lstm_cont_cache.py b/src/models/lstm_cont_cache.py new file mode 100644 index 0000000..e260c4f --- /dev/null +++ b/src/models/lstm_cont_cache.py @@ -0,0 +1,168 @@ +import numpy as np +import tensorflow as tf + +from models.fast_weights_lstm import FastWeightsLSTM +from models.base_model import convert_tokens_to_input_and_target + + +def _log(probs, eps=1e-7): + _epsilon = eps + return tf.log(tf.clip_by_value(probs, _epsilon, 1. - _epsilon)) + + +class LSTMContCache(FastWeightsLSTM): + + def __init__(self, config): + super(LSTMContCache, self).__init__(config) + + def _define_placeholders(self): + self._embd_size = self._config['embedding_size'] + self._hidden_size = self._config['hidden_size'] + self._n_layers = self._config['n_layers'] + self._lr = self._config['lr'] + self._max_grad_norm = self._config['max_grad_norm'] + self._embedding_var_name = 'embedding' + + # Hyperparameters Lambda and Theta as defined in continuous cache paper + self._lambda = self._config['lambda'] + self._theta = self._config['theta'] + + super(LSTMContCache, self)._define_placeholders() + + def _build_graph(self): + self.weights = self._build_weights() + elems = (self._supportX, self._supportY, self._queryX, self._queryY) + self._test_avg_neg_log = tf.map_fn( + self._use_cont_cache, elems=elems, dtype=tf.float32) + self._test_avg_neg_log = tf.reduce_mean(self._test_avg_neg_log) + lr = self._config['lr'] + optimizer = tf.train.AdamOptimizer(lr) + self._gvs = gvs = optimizer.compute_gradients(self._test_avg_neg_log) + self._train_op = optimizer.apply_gradients(gvs, self._global_step) + + def _use_cont_cache(self, _input): + """Use continuous cache computed on support set for query set output. + + Compute cache on hidden states of LSTM processed on support set + and use cache to produce output distribution for query set. 
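# Numpy sketch of the cache mixture _use_cont_cache builds below, following
# the continuous-cache scheme the docstring refers to: query hidden states
# are matched against support hidden states, the theta-scaled similarities
# are softmaxed and scattered onto the support targets, and the result is
# blended with the LSTM's own softmax via lambda. All sizes and values are
# toy numbers.
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

n_support, n_query, hidden, vocab = 4, 2, 3, 5
theta, lam = 0.2, 0.05
rng = np.random.RandomState(0)

h_support = rng.randn(n_support, hidden)                       # support hidden states
y_support = np.eye(vocab)[rng.randint(vocab, size=n_support)]  # one-hot support targets
h_query = rng.randn(n_query, hidden)                           # query hidden states
lstm_prob = softmax(rng.randn(n_query, vocab))                 # LSTM's own distribution

sim = h_query @ h_support.T                    # [n_query, n_support]
cache_prob = softmax(theta * sim) @ y_support  # [n_query, vocab]
prob = (1.0 - lam) * lstm_prob + lam * cache_prob
assert np.allclose(prob.sum(axis=1), 1.0)      # still a valid distribution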
+ Arguments: + supportX: support set input of size [S,N] where S is number of songs + and N is size of songs + supportY: support set target of size [S,N] + queryX: query set input of size [S',N] where S' is number of songs + and N is size of songs + queryY: query set target of size [S',N] + Returns: + loss on query set using output distribution that is mixture of + regular LSTM distribution and cache distribution + + """ + supportX, supportY, queryX, queryY = _input + + _, all_hidden_train, _ = self._model( + self.weights, supportX, + self._support_batch_size, self._support_seq_length) + # convert train hidden states from + # [batch_size, time_step, n_hidden] + # => [n_hidden, support_batch_size * time_step] + all_hidden_train = tf.reshape(all_hidden_train, [-1, self._hidden_size]) + all_hidden_train = tf.transpose(all_hidden_train) + # convert from [support_batch_size, time_step] + # => [support_batch_size * time_step, input_size] + supportY_one_hot = tf.one_hot( + tf.reshape(supportY, [-1]), self._input_size) + + logits, all_hidden_test, _ = self._model( + self.weights, queryX, + self._query_batch_size, self._query_seq_length) + # distribution according to LSTM + lstm_prob = tf.nn.softmax(logits) + # test hidden states: [query_batch_size * time_step, n_hidden] + all_hidden_test = tf.reshape(all_hidden_test, [-1, self._hidden_size]) + + # [query_batch_size * time_step, support_batch_size * time_step] + sim_scores = tf.matmul( + all_hidden_test, all_hidden_train) + # [query_batch_size * time_step, support_batch_size * time_step] + p = tf.nn.softmax(self._theta * sim_scores) + # [query_batch_size * time_step, input_size] + p = tf.matmul(p, supportY_one_hot) + # [query_batch_size, time_step, input_size] + cache_prob = tf.reshape( + p, [self._query_batch_size, self._time_steps, self._input_size]) + + # final distribution is mixture of LSTM and cache distributions + prob = (1. 
- self._lambda) * lstm_prob + self._lambda * cache_prob + + # convert prob distribution to logits for loss function + return self._loss_fxn(_log(prob), queryY) + """ + return tf.contrib.seq2seq.sequence_loss( + _log(prob), + queryY, + tf.ones_like(queryY, dtype=tf.float32), + average_across_timesteps=True, + average_across_batch=True) + """ + + def train(self, episode): + """Use support set to compute cache and train on loss on query set.""" + feed_dict = {} + support_batch_size = np.shape(episode.support)[1] + query_batch_size = np.shape(episode.query)[1] + + supportX, supportY = convert_tokens_to_input_and_target( + episode.support, self._start_word, flatten_batch=False) + queryX, queryY = convert_tokens_to_input_and_target( + episode.query, self._start_word, flatten_batch=False) + feed_dict[self._supportX] = supportX + feed_dict[self._supportY] = supportY + feed_dict[self._queryX] = queryX + feed_dict[self._queryY] = queryY + feed_dict[self._support_batch_size] = support_batch_size + feed_dict[self._query_batch_size] = query_batch_size + feed_dict[self._support_seq_length] = [np.shape(supportX)[2]] * np.shape(supportX)[1] + feed_dict[self._query_seq_length] = [np.shape(queryX)[2]] * np.shape(queryX)[1] + + _, loss = self._sess.run( + [self._train_op, self._test_avg_neg_log], feed_dict=feed_dict) + + if self._summary_writer: + summary = tf.Summary(value=[ + tf.Summary.Value(tag='Train/loss', + simple_value=loss)]) + self._summary_writer.add_summary(summary, self._train_calls) + self._train_calls += 1 + + return loss + + def eval(self, episode): + """Use support set to compute cache and evaluate on query set.""" + feed_dict = {} + support_batch_size = np.shape(episode.support)[1] + query_batch_size = np.shape(episode.query)[1] + + supportX, supportY = convert_tokens_to_input_and_target( + episode.support, self._start_word, flatten_batch=False) + queryX, queryY = convert_tokens_to_input_and_target( + episode.query, self._start_word, flatten_batch=False) + feed_dict[self._supportX] = supportX + feed_dict[self._supportY] = supportY + feed_dict[self._queryX] = queryX + feed_dict[self._queryY] = queryY + feed_dict[self._support_batch_size] = support_batch_size + feed_dict[self._query_batch_size] = query_batch_size + feed_dict[self._support_seq_length] = [np.shape(supportX)[2]] * np.shape(supportX)[1] + feed_dict[self._query_seq_length] = [np.shape(queryX)[2]] * np.shape(queryX)[1] + + avg_neg_log = self._sess.run( + self._test_avg_neg_log, feed_dict=feed_dict) + + if self._summary_writer: + summary = tf.Summary(value=[ + tf.Summary.Value(tag='Eval/Avg_NLL', + simple_value=avg_neg_log)]) + self._summary_writer.add_summary(summary, self._eval_calls) + self._eval_calls += 1 + + return avg_neg_log diff --git a/src/models/lstm_dynamic_eval.py b/src/models/lstm_dynamic_eval.py new file mode 100644 index 0000000..b204e08 --- /dev/null +++ b/src/models/lstm_dynamic_eval.py @@ -0,0 +1,113 @@ +import numpy as np +import tensorflow as tf + +from models.fast_weights_lstm import FastWeightsLSTM +from models.base_model import convert_tokens_to_input_and_target + + +def convert_to_subsequences(sequences, subseq_len): + """Break sequence into bunch of subsequences based on desired subseq_len. 
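+
+    Example (shapes):
+        >>> seqs = np.arange(20).reshape(2, 10)
+        >>> convert_to_subsequences(seqs, 5).shape
+        (2, 2, 5)
+    The sequence length is assumed to be divisible by subseq_len.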
+ + Arguments: + sequences: (numpy int array) of size [S,N], where S is number of + sequences and N is size of each sequence + subseq_len: (int) size of subsequences desired to break sequence into + Returns: + sequences: (numpy int array) of size [S,n_subseq,subseq_len], where S is + number of sequences, n_subseq is number of subsequences that make + up whole sequence, and subseq_len is the size of each subsequence + """ + n_songs, n_dim = np.shape(sequences)[0], np.shape(sequences)[1] + + n_subseq = n_dim // subseq_len + sequences = np.reshape( + sequences, [n_songs, n_subseq, subseq_len]) + return sequences + + +class LSTMDynamicEval(FastWeightsLSTM): + + def __init__(self, config): + super(LSTMDynamicEval, self).__init__(config) + + def _define_placeholders(self): + # Overwrite time steps as we are operating on subsequences + assert_msg = 'Sequence length %d is not divisible by subsequence Length %d'\ + % (self._config['max_len'], self._config['subseq_len']) + assert self._config['max_len'] % self._time_steps == 0, assert_msg + self._time_steps = self._config['subseq_len'] + self._n_subseq = self._config['max_len'] // self._time_steps + + self._embd_size = self._config['embedding_size'] + self._hidden_size = self._config['hidden_size'] + self._n_layers = self._config['n_layers'] + self._lr = self._config['lr'] + self._max_grad_norm = self._config['max_grad_norm'] + self._stop_grad = True + self._embedding_var_name = 'embedding' + + self._seq_length = tf.placeholder(tf.int32, [None]) + self._words = tf.placeholder( + tf.int32, [None, None, self._time_steps]) + self._target = tf.placeholder( + tf.int32, [None, None, self._time_steps]) + + def _build_graph(self): + self.weights = self._build_weights() + elems = (self._words, self._target) + self._avg_neg_log = tf.map_fn( + self._dynamic_eval, elems=elems, dtype=tf.float32) + self._avg_neg_log = tf.reduce_mean(self._avg_neg_log) + + def _dynamic_eval(self, _input): + """Perform dynamic evaluation on a single sequence. + + Iterate through each subsequence and update parameters on loss of + each subsequence. Loss is computed on each subsequence before update + step. 
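+
+        This mirrors dynamic evaluation (Krause et al., 2017): the loss of
+        each subsequence is recorded with the current fast weights, and only
+        afterwards is a gradient step with learning rate `lr` taken, so later
+        subsequences are scored by a model adapted to the earlier ones. The
+        adaptation lives inside the graph and is not persisted, so every
+        sequence starts again from the base weights.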
+ Arguments: + _input: tuple containing words & target, where + words: (tensor) of size (n_subseq, subseq_len) + target: (tensor) of size (n_subseq, subseq_len) + Returns: + tensor of losses on each subsequence + + """ + words, target = _input + train_losses = [] + + fast_weights = self.weights + final_state = None + for i in range(self._n_subseq): + X = tf.slice(words, begin=[i, 0], size=[1, -1]) + Y = tf.slice(target, begin=[i, 0], size=[1, -1]) + + logits, _, final_state = self._model( + fast_weights, X, 1, self._seq_length, final_state) + loss_train = self._loss_fxn(logits, Y) + train_losses.append(loss_train) + + grads = self._get_grads(loss_train, fast_weights) + fast_weights = self._get_update(fast_weights, grads, self._lr) + + return tf.stack(loss_train) + + def train(self, episode): + raise NotImplementedError() + + def eval(self, episode): + """Ignore support set and perform dynamic evaluation on query set.""" + X, Y = convert_tokens_to_input_and_target( + episode.query, self._start_word, flatten_batch=True) + subseq_len = self._config['subseq_len'] + X = convert_to_subsequences(X, subseq_len) + Y = convert_to_subsequences(Y, subseq_len) + + feed_dict = {} + feed_dict[self._words] = X + feed_dict[self._target] = Y + feed_dict[self._seq_length] = [np.shape(X)[2]] + avg_neg_log = self._sess.run( + self._avg_neg_log, feed_dict=feed_dict) + + return avg_neg_log diff --git a/src/models/lstm_enc_dec.py b/src/models/lstm_enc_dec.py new file mode 100644 index 0000000..b10f23c --- /dev/null +++ b/src/models/lstm_enc_dec.py @@ -0,0 +1,391 @@ +import tensorflow as tf +import numpy as np +from tensorflow.contrib.layers import fully_connected +from models.nn_lib import LSTM, make_cell, get_sentinel_prob, num_stable_log,\ + seq_loss, get_ndcg, make_cell_film, LSTMFilm + +from models.tf_model import TFModel +from models.base_model import convert_tokens_to_input_and_target + + +def word_dropout(tokens, unknown_token, dropout_p): + """With probability dropout_p, replace tokens with unknown token. + + Args: + tokens: np array of size [B,N,S] + unknown_token: int + dropout_p: float + """ + if dropout_p > 0: + original_shape = tokens.shape + temp = tokens.flatten() + bernoulli_sample = np.random.binomial(n=1, p=dropout_p, size=temp.size) + idxs = np.where(bernoulli_sample == 1) + temp[idxs] = unknown_token + temp = temp.reshape(original_shape) + return temp + else: + return tokens + + +class LSTMEncDec(TFModel): + """LSTM language model which conditions on support set to evaluate query set. + + Trained on episodes from the meta-training set. During evaluation, + use each episode's support set to produce encoding that is then + used to condition when evaluating on query set. 
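+
+    Concretely, an encoder LSTM reads the support sequences; its final cell
+    and hidden states are passed through a fully connected layer and
+    mean-pooled into a single `enc_size` vector. The decoder LSTM's initial
+    state is produced from this encoding, and the encoding is also provided
+    at every decoding step, either concatenated to the input embedding or,
+    when `use_film` is set, through FiLM modulation of the gates. Output
+    logits reuse the input embedding matrix, and `use_sentinel` optionally
+    mixes in a pointer-sentinel distribution over previous tokens.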
+ """ + + def __init__(self, config): + super(LSTMEncDec, self).__init__(config) + + def _define_placeholders(self): + self._embd_size = self._config['embedding_size'] + self._hidden_size = self._config['hidden_size'] + self._n_layers = self._config['n_layers'] + self._lr = self._config['lr'] + self._max_grad_norm = self._config['max_grad_norm'] + self._embedding_var_name = 'embedding' + self._enc_size = self._config['enc_size'] + + self._support_size = tf.placeholder(tf.int32, shape=()) + self._query_size = tf.placeholder(tf.int32, shape=()) + self._support_seq_length = tf.placeholder(tf.int32, [None, None]) + self._query_seq_length = tf.placeholder(tf.int32, [None, None]) + + self._supportX = tf.placeholder( + tf.int32, [None, None, self._time_steps]) + self._supportY = tf.placeholder( + tf.int32, [None, None, self._time_steps]) + self._queryX = tf.placeholder( + tf.int32, [None, None, self._time_steps]) + self._queryY = tf.placeholder( + tf.int32, [None, None, self._time_steps]) + + self._is_training = tf.placeholder_with_default( + True, shape=(), name='is_training') + # self._max_token_len = tf.placeholder(tf.int32, shape=()) + + def _build_graph(self): + elems = (self._supportX, self._supportY, self._support_seq_length, + self._queryX, self._queryY, self._query_seq_length) + self._all_neg_log, self._query_neg_log, self._prob, self._enc \ + = tf.map_fn(self._train_episode, elems=elems, + dtype=(tf.float32, tf.float32, tf.float32, tf.float32)) + self._all_avg_neg_log = tf.reduce_mean(self._all_neg_log) + self._query_avg_neg_log = tf.reduce_mean(self._query_neg_log) + optimizer = tf.train.AdamOptimizer(self._lr) + self._gvs = gvs = optimizer.compute_gradients(self._all_avg_neg_log) + self._train_op = optimizer.apply_gradients(gvs, self._global_step) + + def _train_episode(self, _input): + supportX, supportY, support_seq_length,\ + queryX, queryY, query_seq_length = _input + + enc = self._encode( + supportY, + self._support_size, support_seq_length) + + # Option of whether to decode only query OR support & query using + # support encoding + if self._config['decode_support']: + X = tf.concat([supportX, queryX], axis=0) + Y = tf.concat([supportY, queryY], axis=0) + size = self._support_size + self._query_size + seq_length = tf.concat( + [support_seq_length, query_seq_length], axis=0) + + logits_both = self._decode( + X, + Y, + size, + seq_length, + enc + ) + loss_both = seq_loss( + logits_both, Y, + seq_length, self._time_steps, + avg_batch=False) + + logits_query = tf.slice( + logits_both, [self._support_size, 0, 0], [-1, -1, -1]) + loss_query = tf.contrib.seq2seq.sequence_loss( + logits_query, + queryY, + query_seq_length, + average_across_timesteps=True, + average_across_batch=True) + prob_query = tf.nn.softmax(logits_query) + + return loss_both, loss_query, prob_query, enc + else: + X = queryX + Y = queryY + size = self._query_size + seq_length = query_seq_length + + logits = self._decode( + X, + Y, + size, + seq_length, + enc + ) + """ + max_token_len = tf.tile([self._max_token_len], [self._query_size]) + loss = seq_loss( + logits, Y, + tf.minimum(max_token_len, query_seq_length), self._time_steps, + avg_batch=False) + """ + loss = seq_loss( + logits, Y, + seq_length, self._time_steps, + avg_batch=False) + prob_query = tf.nn.softmax(logits) + return loss, loss, prob_query, enc + + def _encode(self, X, size, seq_length): + with tf.variable_scope('shared'): + embedding = tf.get_variable( + self._embedding_var_name, [self._input_size, self._embd_size]) + + with 
tf.variable_scope('lstm1'): + cell = tf.contrib.rnn.MultiRNNCell( + [make_cell(0, self._embd_size, self._hidden_size)]) + initial_state = cell.zero_state( + self._support_size, dtype=tf.float32) + + # [support_size, time_steps, hidden_size] + _, final_state = LSTM( + cell, X, embedding, + seq_length, size, initial_state, scope='lstm1' + ) + final_state = final_state[0] + # [support_size, 2*hidden_size] + hidden_concat = tf.concat([final_state.c, final_state.h], axis=-1) + pool_fxn = tf.reduce_mean + + """ + # Pool + FC + # [1, 2*hidden_size] + # mean_hidden = tf.reduce_mean(hidden_concat, axis=0, keep_dims=True) + pool_hidden = pool_fxn(hidden_concat, axis=0, keep_dims=True) + # [1, enc_size] + # enc = fully_connected( + # mean_hidden, self._enc_size, activation_fn=tf.nn.tanh) + enc = fully_connected( + pool_hidden, self._enc_size, activation_fn=tf.nn.tanh) + """ + + # FC + Pool (seems to work better) + # [support_size, enc_size] + enc = fully_connected( + hidden_concat, self._enc_size) + # [1, enc_size] + enc = pool_fxn(enc, axis=0, keepdims=True) + + return enc + + def _decode(self, X, Y, size, seq_length, enc): + enc = tf.tile(enc, [size, 1]) + initial_state_h = fully_connected( + enc, self._hidden_size, activation_fn=tf.nn.tanh) + initial_state_c = fully_connected( + enc, self._hidden_size, activation_fn=tf.nn.tanh) + initial_state = (tf.contrib.rnn.LSTMStateTuple( + initial_state_c, initial_state_h),) + enc = tf.tile(tf.expand_dims(enc, axis=1), [1, self._time_steps, 1]) + + with tf.variable_scope('shared', reuse=True): + embedding = tf.get_variable( + self._embedding_var_name, [self._input_size, self._embd_size]) + + LSTMclass = None + with tf.variable_scope('lstm2'): + if self._config['use_film']: + list_of_cells = [make_cell_film( + 0, self._embd_size, self._enc_size, self._hidden_size)] + LSTMclass = LSTMFilm + else: + list_of_cells = [make_cell( + 0, self._embd_size + self._enc_size, self._hidden_size)] + LSTMclass = LSTM + + cell = tf.contrib.rnn.MultiRNNCell(list_of_cells) + + # [n_query, time_step, n_hidden] + hidden_states, _ = LSTMclass( + cell, X, embedding, + seq_length, size, initial_state, enc, scope='lstm2' + ) + + # [n_query*time_step, n_hidden] + hidden_states = tf.reshape( + hidden_states, [-1, self._hidden_size]) + logits = tf.matmul(hidden_states, embedding, transpose_b=True) + logits = tf.reshape( + logits, [size, self._time_steps, self._input_size]) + if not self._config['use_sentinel']: + return logits + + prob_vocab = tf.nn.softmax(logits) + g, prob_cache = get_sentinel_prob( + Y, hidden_states, size, + self._time_steps, self._hidden_size, self._input_size) + prob = tf.multiply(g, prob_vocab) + prob_cache + return num_stable_log(prob) + + def train(self, episode): + """MAML training objective involves input of support and query sets.""" + feed_dict = {} + support_size = np.shape(episode.support)[1] + query_size = np.shape(episode.query)[1] + support_seq_len = episode.support_seq_len + query_seq_len = episode.query_seq_len + + supportX, supportY = convert_tokens_to_input_and_target( + episode.support, episode.support_seq_len, + self._start_word, self._end_word, + flatten_batch=False) + queryX, queryY = convert_tokens_to_input_and_target( + episode.query, episode.query_seq_len, + self._start_word, self._end_word, + flatten_batch=False) + + feed_dict[self._supportX] = supportX + feed_dict[self._supportY] = supportY + feed_dict[self._queryX] = queryX + feed_dict[self._queryY] = queryY + feed_dict[self._support_size] = support_size + feed_dict[self._query_size] = 
query_size + feed_dict[self._support_seq_length] = support_seq_len + 1 + feed_dict[self._query_seq_length] = query_seq_len + 1 + # feed_dict[self._max_token_len] = self._config['max_len'] + + _, loss = self._sess.run( + [self._train_op, self._test_avg_neg_log], feed_dict=feed_dict) + + if self._summary_writer: + summary = tf.Summary(value=[ + tf.Summary.Value(tag='Train/loss', + simple_value=loss)]) + self._summary_writer.add_summary(summary, self._train_calls) + self._train_calls += 1 + + return loss + + def eval_ndcg(self, episode): + # Evaluate NDCG ranking metric + + if np.shape(episode.support)[0] > 1: + episode.support = episode.support[0:1, :, :] + episode.query = episode.query[0:1, :, :] + episode.other_query = episode.other_query[0:1, :, :] + episode.support_seq_len = episode.support_seq_len[0:1, :] + episode.query_seq_len = episode.query_seq_len[0:1, :] + episode.other_query_seq_len = episode.other_query_seq_len[0:1, :] + + feed_dict = {} + support_size = np.shape(episode.support)[1] + query_size = np.shape(episode.query)[1] + support_seq_len = episode.support_seq_len + query_seq_len = episode.query_seq_len + other_query_seq_len = episode.other_query_seq_len + + supportX, supportY = convert_tokens_to_input_and_target( + episode.support, episode.support_seq_len, + self._start_word, self._end_word, + flatten_batch=False) + queryX, queryY = convert_tokens_to_input_and_target( + episode.query, episode.query_seq_len, + self._start_word, self._end_word, + flatten_batch=False) + queryX_other, queryY_other = convert_tokens_to_input_and_target( + episode.other_query, episode.other_query_seq_len, + self._start_word, self._end_word, + flatten_batch=False) + + feed_dict[self._supportX] = supportX + feed_dict[self._supportY] = supportY + feed_dict[self._queryX] = queryX + feed_dict[self._queryY] = queryY + feed_dict[self._support_size] = support_size + feed_dict[self._query_size] = query_size + feed_dict[self._support_seq_length] = support_seq_len + 1 + feed_dict[self._query_seq_length] = query_seq_len + 1 + feed_dict[self._is_training] = False + # feed_dict[self._max_token_len] = self._config['eval_len'] + + nll, avg_nll = self._sess.run( + [self._query_neg_log, self._query_avg_neg_log], feed_dict=feed_dict) + + feed_dict[self._supportX] = supportX + feed_dict[self._supportY] = supportY + feed_dict[self._queryX] = queryX_other + feed_dict[self._queryY] = queryY_other + feed_dict[self._support_size] = support_size + feed_dict[self._query_size] = query_size + feed_dict[self._support_seq_length] = support_seq_len + 1 + feed_dict[self._query_seq_length] = other_query_seq_len + 1 + feed_dict[self._is_training] = False + # feed_dict[self._max_token_len] = self._config['eval_len'] + + nll_other, _ = self._sess.run( + [self._query_neg_log, self._query_avg_neg_log], feed_dict=feed_dict) + + nll = nll.flatten() + nll_other = nll_other.flatten() + rel_scores = np.ones(shape=np.shape(nll)) + rel_scores_neg_songs = np.zeros(shape=np.shape(nll_other)) + + ndcg = get_ndcg( + np.concatenate([rel_scores, rel_scores_neg_songs]), + np.concatenate([nll, nll_other]), + rank_position=np.shape(nll)[0]) + + return ndcg + + def eval(self, episode): + # Use support set to produce encoding that is then used to condition + # LSTM when evaluating corresponding support set + feed_dict = {} + support_size = np.shape(episode.support)[1] + query_size = np.shape(episode.query)[1] + support_seq_len = episode.support_seq_len + query_seq_len = episode.query_seq_len + + supportX, supportY = convert_tokens_to_input_and_target( + 
episode.support, episode.support_seq_len, + self._start_word, self._end_word, + flatten_batch=False) + queryX, queryY = convert_tokens_to_input_and_target( + episode.query, episode.query_seq_len, + self._start_word, self._end_word, + flatten_batch=False) + + feed_dict[self._supportX] = supportX + feed_dict[self._supportY] = supportY + feed_dict[self._queryX] = queryX + feed_dict[self._queryY] = queryY + feed_dict[self._support_size] = support_size + feed_dict[self._query_size] = query_size + feed_dict[self._support_seq_length] = support_seq_len + 1 + feed_dict[self._query_seq_length] = query_seq_len + 1 + feed_dict[self._is_training] = False + # feed_dict[self._max_token_len] = self._config['max_len'] + + avg_neg_log = self._sess.run( + self._query_avg_neg_log, feed_dict=feed_dict) + + if self._summary_writer: + summary = tf.Summary(value=[ + tf.Summary.Value(tag='Eval/Avg_NLL', + simple_value=avg_neg_log)]) + self._summary_writer.add_summary(summary, self._eval_calls) + self._eval_calls += 1 + + return avg_neg_log + + def sample(self, support_set, num): + raise NotImplementedError() diff --git a/src/models/lstm_film_cell.py b/src/models/lstm_film_cell.py new file mode 100644 index 0000000..7dabf5e --- /dev/null +++ b/src/models/lstm_film_cell.py @@ -0,0 +1,69 @@ +import tensorflow as tf + +from tensorflow.python.ops.rnn_cell_impl import LSTMStateTuple +from models.lstm_cell import LSTMCell + + +def sep(*args): + return '/'.join(args) + + +class LSTMFilmCell(LSTMCell): + """Compute LSTM cell with FILM conditioning. + + Reference: + https://arxiv.org/abs/1709.07871 + """ + + def __init__(self, n, input_size, cond_size, num_units, weights=None, + state_is_tuple=True, forget_bias=1): + super(LSTMFilmCell, self).__init__( + n, input_size, num_units, weights, state_is_tuple, forget_bias) + self._cond_scale_kernel = tf.get_variable( + sep(self._n, "cond_scale"), + shape=[cond_size, 4 * num_units]) + self._cond_shift_kernel = tf.get_variable( + sep(self._n, "cond_shift"), + shape=[cond_size, 4 * num_units]) + + def __call__(self, inputs, state): + """Modified from + https://github.com/tensorflow/tensorflow/blob/r1.9/tensorflow/python/ops/rnn_cell_impl.py#L614""" + + if self._state_is_tuple: + c, h = state + else: + c, h = tf.split(value=state, num_or_size_splits=2, axis=1) + + # Note that using `add` and `multiply` instead of `+` and `*` gives a + # performance improvement. So using those at the cost of readability. 
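+        # FiLM conditioning: `cond` produces a per-unit scale and shift that
+        # modulate the gate pre-activations,
+        #     gate_inputs <- (cond @ W_scale + 1) * gate_inputs + cond @ W_shift
+        # The "+ 1" keeps the modulation close to identity while the
+        # conditioning kernels are still small.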
+ add = tf.add + multiply = tf.multiply + sigmoid = tf.sigmoid + activation = tf.tanh + + i, cond = inputs + o = h + gate_inputs = tf.matmul( + tf.concat([i, o], 1), self._kernel) + gate_inputs = tf.nn.bias_add(gate_inputs, self._bias) + cond_scale = tf.matmul( + cond, self._cond_scale_kernel) + 1 + cond_shift = tf.matmul( + cond, self._cond_shift_kernel) + gate_inputs = add(multiply(cond_scale, gate_inputs), cond_shift) + + # i = input_gate, j = new_input, f = forget_gate, o = output_gate + i, j, f, o = tf.split( + value=gate_inputs, num_or_size_splits=4, axis=1) + + new_c = add(multiply(c, sigmoid(add(f, self._forget_bias_tensor))), + multiply(sigmoid(i), activation(j))) + new_h = multiply(activation(new_c), sigmoid(o)) + + if self._state_is_tuple: + new_state = LSTMStateTuple(new_c, new_h) + else: + new_state = tf.concat([new_c, new_h], 1) + + return new_h, new_state diff --git a/src/models/lstm_maml.py b/src/models/lstm_maml.py new file mode 100644 index 0000000..41185d8 --- /dev/null +++ b/src/models/lstm_maml.py @@ -0,0 +1,125 @@ +import tensorflow as tf +import numpy as np + +from models.fast_weights_lstm import FastWeightsLSTM +from models.base_model import convert_tokens_to_input_and_target + + +class LSTMMAML(FastWeightsLSTM): + + def __init__(self, config): + super(LSTMMAML, self).__init__(config) + + def _define_placeholders(self): + self._embd_size = self._config['embedding_size'] + self._hidden_size = self._config['hidden_size'] + self._n_layers = self._config['n_layers'] + self._lr = self._config['lr'] + self._meta_lr = self._config['meta_lr'] + self._max_grad_norm = self._config['max_grad_norm'] + self._stop_grad = self._config['stop_grad'] + self._n_update = self._config['n_update'] + self._embedding_var_name = 'embedding' + + super(LSTMMAML, self)._define_placeholders() + + def _build_graph(self): + self.weights = self._build_weights() + elems = (self._supportX, self._supportY, self._queryX, self._queryY) + self._test_avg_neg_log = tf.map_fn( + self._train_episode, elems=elems, dtype=tf.float32) + self._test_avg_neg_log = tf.reduce_mean(self._test_avg_neg_log) + optimizer = tf.train.AdamOptimizer(self._meta_lr) + self._gvs = gvs = optimizer.compute_gradients(self._test_avg_neg_log) + self._train_op = optimizer.apply_gradients(gvs, self._global_step) + + def _train_episode(self, _input): + supportX, supportY, queryX, queryY = _input + train_losses = [] + + fast_weights = self.weights + for i in range(self._n_update): + logits, _, _ = self._model( + fast_weights, supportX, + self._support_batch_size, self._support_seq_length + ) + loss_train = self._loss_fxn(logits, supportY) + train_losses.append(loss_train) + + grads = self._get_grads(loss_train, fast_weights) + fast_weights = self._get_update(fast_weights, grads, self._lr) + + logits, _, _ = self._model( + fast_weights, supportX, + self._support_batch_size, self._support_seq_length) + loss_train = self._loss_fxn(logits, supportY) + train_losses.append(loss_train) + + logits, _, _ = self._model( + fast_weights, queryX, + self._query_batch_size, self._query_seq_length) + test_loss = self._loss_fxn(logits, queryY) + + return test_loss + + def train(self, episode): + """MAML training objective involves input of support and query sets.""" + feed_dict = {} + support_batch_size = np.shape(episode.support)[1] + query_batch_size = np.shape(episode.query)[1] + + supportX, supportY = convert_tokens_to_input_and_target( + episode.support, self._start_word, flatten_batch=False) + queryX, queryY = convert_tokens_to_input_and_target( + 
episode.query, self._start_word, flatten_batch=False) + feed_dict[self._supportX] = supportX + feed_dict[self._supportY] = supportY + feed_dict[self._queryX] = queryX + feed_dict[self._queryY] = queryY + feed_dict[self._support_batch_size] = support_batch_size + feed_dict[self._query_batch_size] = query_batch_size + feed_dict[self._support_seq_length] = [np.shape(supportX)[2]] * np.shape(supportX)[1] + feed_dict[self._query_seq_length] = [np.shape(queryX)[2]] * np.shape(queryX)[1] + + _, loss = self._sess.run( + [self._train_op, self._test_avg_neg_log], feed_dict=feed_dict) + + if self._summary_writer: + summary = tf.Summary(value=[ + tf.Summary.Value(tag='Train/loss', + simple_value=loss)]) + self._summary_writer.add_summary(summary, self._train_calls) + self._train_calls += 1 + + return loss + + def eval(self, episode): + """Perform gradients steps on support set and evaluate on query set.""" + feed_dict = {} + support_batch_size = np.shape(episode.support)[1] + query_batch_size = np.shape(episode.query)[1] + + supportX, supportY = convert_tokens_to_input_and_target( + episode.support, self._start_word, flatten_batch=False) + queryX, queryY = convert_tokens_to_input_and_target( + episode.query, self._start_word, flatten_batch=False) + feed_dict[self._supportX] = supportX + feed_dict[self._supportY] = supportY + feed_dict[self._queryX] = queryX + feed_dict[self._queryY] = queryY + feed_dict[self._support_batch_size] = support_batch_size + feed_dict[self._query_batch_size] = query_batch_size + feed_dict[self._support_seq_length] = [np.shape(supportX)[2]] * np.shape(supportX)[1] + feed_dict[self._query_seq_length] = [np.shape(queryX)[2]] * np.shape(queryX)[1] + + avg_neg_log = self._sess.run( + self._test_avg_neg_log, feed_dict=feed_dict) + + if self._summary_writer: + summary = tf.Summary(value=[ + tf.Summary.Value(tag='Eval/Avg_NLL', + simple_value=avg_neg_log)]) + self._summary_writer.add_summary(summary, self._eval_calls) + self._eval_calls += 1 + + return avg_neg_log diff --git a/src/models/nn_lib.py b/src/models/nn_lib.py new file mode 100644 index 0000000..4df9fd2 --- /dev/null +++ b/src/models/nn_lib.py @@ -0,0 +1,228 @@ +import tensorflow as tf +import numpy as np +from models.lstm_cell import LSTMCell +from models.lstm_film_cell import LSTMFilmCell + + +def num_stable_log(probs, eps=1e-7): + _epsilon = eps + return tf.log(tf.clip_by_value(probs, _epsilon, 1. 
- _epsilon)) + + +def make_cell(n, embd_size, hidden_size): + """Get LSTM cell.""" + """ + # Use Tensorflow Cell + return tf.contrib.rnn.BasicLSTMCell( + self._hidden_size, forget_bias=1., state_is_tuple=True) + """ + return LSTMCell(n, embd_size, hidden_size) + + +def make_cell_film(n, embd_size, cond_size, hidden_size): + """Get LSTM FILM-based cell.""" + return LSTMFilmCell(n, embd_size, cond_size, hidden_size) + + +def seq_loss(logits, Y, seq_length, time_steps, + avg_timesteps=True, avg_batch=True): + return tf.contrib.seq2seq.sequence_loss( + logits, + Y, + tf.sequence_mask(seq_length, time_steps, dtype=tf.float32), + average_across_timesteps=avg_timesteps, + average_across_batch=avg_batch) + + +def LSTM(cell, + onehot_X, + embedding, + seq_length, + batch_size, + initial_state, + enc=None, + scope="", + reuse=False): + + # [batch_size, time_step, embd_size] + inputs = tf.nn.embedding_lookup(embedding, onehot_X) + + if enc is not None: + inputs = tf.concat([inputs, enc], axis=-1) + + # outputs: [batch_size, time_step, hidden_size] + # state: [batch_size, hidden_size] + with tf.variable_scope(scope, reuse=reuse): + outputs, state = tf.nn.dynamic_rnn( + cell, inputs, initial_state=initial_state, + sequence_length=seq_length + ) + + return outputs, state + + +def LSTMFilm(cell, + onehot_X, + embedding, + seq_length, + batch_size, + initial_state, + cond, + scope="", + reuse=False): + + # [batch_size, time_step, embd_size] + inputs = tf.nn.embedding_lookup(embedding, onehot_X) + + # outputs: [batch_size, time_step, hidden_size] + # state: [batch_size, hidden_size] + with tf.variable_scope(scope, reuse=reuse): + outputs, state = tf.nn.dynamic_rnn( + cell, (inputs, cond), initial_state=initial_state, + sequence_length=seq_length + ) + + return outputs, state + + +def get_logits(hidden_states, emb_matrix, hidden_size): + hidden_states = tf.reshape(hidden_states, [-1, hidden_size]) + return tf.matmul(hidden_states, emb_matrix, transpose_b=True) + + +def masked_softmax(logits, mask): + """Masked softmax over dim 1. + + Args: + logits: (N, L) + mask: (N, L) + Returns: + probabilities (N, L) + """ + indices = tf.where(mask) + values = tf.gather_nd(logits, indices) + denseShape = tf.cast(tf.shape(logits), tf.int64) + sparseResult = tf.sparse_softmax( + tf.SparseTensor(indices, values, denseShape)) + result = tf.scatter_nd( + sparseResult.indices, sparseResult.values, sparseResult.dense_shape) + result.set_shape(logits.shape) + + return result + + +def get_seq_mask(size, time_steps): + """Get mask that doesn't allow items at curr time to be dependent on future. + + Args: + size: 1st dimension of requested output + time_steps: number of time steps being considered + Returns: + M: [size, time_steps, time_steps] tensor where for each of size + matrices, the t^th row is only True for the previous (t-1) entries + """ + # [time_steps] + lengths = tf.range(time_steps, dtype=tf.int32) + # [time_steps, time_steps] + M = tf.sequence_mask(lengths, time_steps) + # [size, time_steps, time_steps] + M = tf.tile(tf.expand_dims(M, 0), [size, 1, 1]) + + # [size*time_steps, time_steps] + M = tf.reshape(M, [-1, time_steps]) + return M + + +def get_sentinel_prob(X, hidden_states, size, + time_steps, hidden_size, input_size, + embd_size=None): + """Get probability according to pointer-sentinel mixture model. 
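+
+    The returned pair implements the pointer-sentinel mixture
+        p(w) = g * p_vocab(w) + p_ptr(w),
+    where p_ptr attends only over earlier positions of the same sequence and
+    g is the gate mass left to the vocabulary softmax. Because g and the
+    pointer weights come from one joint softmax, p_ptr already sums to 1 - g.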
+ + Args: + X: [size, time_steps] tensor + hidden_states: [size*time_steps, hidden_size] tensor + size: size of 1st dimension of X + time_steps: size of 2nd dimension of X + hidden_size: size of hidden units + input_size: size of output units + cond: [size, embd_size] + Returns: + prob_b: [size, time_steps, 1] tensor + prob_sentinel: [size, time_steps, input_size] tensor + Reference: + https://arxiv.org/abs/1609.07843 + """ + queryW = tf.get_variable( + 'queryW', [hidden_size, hidden_size]) + queryb = tf.get_variable( + 'queryb', [hidden_size]) + queryS = tf.get_variable( + 'queryS', [hidden_size, 1]) + + # [size, time_steps, hidden_size] + hidden_states_r = tf.reshape( + hidden_states, [size, time_steps, hidden_size]) + # [size, time_steps, hidden_size] + queryW = tf.tile(tf.expand_dims(queryW, 0), [size, 1, 1]) + query = tf.nn.tanh(tf.matmul(hidden_states_r, queryW) + queryb) + # [size, time_steps, hidden_size] * [size, hidden_size, time_steps] + # = [size, time_steps, time_steps] + alpha = tf.matmul(query, tf.transpose(hidden_states_r, [0, 2, 1])) + # [size*time_steps, time_steps] + alpha = tf.reshape(alpha, [-1, time_steps]) + # [size*time_steps, hidden_size] + query = tf.reshape(query, [-1, hidden_size]) + + # [size*time_steps, 1] + g = tf.matmul(query, queryS) + # [size*time_steps, time_steps + 1] + alpha_with_g = tf.concat([alpha, g], axis=-1) + + # [size*time_steps, time_steps] + mask = get_seq_mask(size, time_steps) + # [size*time_steps, time_steps + 1] + mask_with_g = tf.concat( + [mask, tf.ones([size * time_steps, 1], tf.bool)], + axis=-1) + # [size*time_steps, time_steps + 1] + prob_with_g = masked_softmax( + alpha_with_g, + mask_with_g + ) + + # [size*time_steps, time_steps] + prob_ptr = tf.slice( + prob_with_g, [0, 0], [-1, time_steps]) + # [size*time_steps, 1] + prob_g = tf.slice( + prob_with_g, [0, time_steps], [-1, 1]) + prob_g = tf.reshape(prob_g, [size, time_steps, 1]) + prob_ptr = tf.reshape( + prob_ptr, [size, time_steps, time_steps]) + + # [size, time_steps, time_steps] * [size, time_steps, input_size] + # [size, time_steps, input_size] + onehot_X = tf.one_hot(X, input_size) + prob_ptr = tf.matmul(prob_ptr, onehot_X) + + return prob_g, prob_ptr + + +def get_ndcg(rel_scores, neg_logs, rank_position): + """Compute NDCG metric for neg log rankings of songs by a model. + + Given relevancy scores of songs, model negative log likelihoods of songs, + and the rank position at which to evaluate, compute NDCG. 
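+
+    Songs are ranked by ascending NLL (most likely first), then
+        DCG@p  = sum_{i <= p} rel_i / log2(i + 1)
+        NDCG@p = DCG@p / IDCG@p,
+    where IDCG is the DCG of the ideal, relevance-sorted ranking. For
+    example, rel_scores=[1, 0], neg_logs=[0.3, 0.7], rank_position=1 gives
+    NDCG = 1.0 because the relevant song has the lower NLL.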
+ [https://en.wikipedia.org/wiki/Discounted_cumulative_gain] + """ + p = rank_position + _, sorted_rel_scores = ( + list(t) for t in zip(*sorted(zip(neg_logs, rel_scores)))) + + idxs = np.array(range(1, len(sorted_rel_scores) + 1)) + dcg = np.sum(sorted_rel_scores[:p] / np.log2(idxs[:p] + 1)) + + ideal_scores = sorted(rel_scores, reverse=True) + idcg = np.sum(ideal_scores[:p] / np.log2(idxs[:p] + 1)) + + return dcg / idcg diff --git a/src/models/tf_model.py b/src/models/tf_model.py index 60997ff..c663b93 100644 --- a/src/models/tf_model.py +++ b/src/models/tf_model.py @@ -3,6 +3,7 @@ import pprint from models.base_model import BaseModel +from evaluation.sampler import Sampler PP = pprint.PrettyPrinter(depth=6) @@ -78,23 +79,40 @@ def optimistic_restore(session, save_file, only_load_trainable_vars=False, class TFModel(BaseModel): def __init__(self, config): - tf.set_random_seed(config['seed']) + if 'seed' in config: + tf.set_random_seed(config['seed']) super(TFModel, self).__init__(config) + # Set up checkpoint directory self._summary_writer = None if 'checkpt_dir' in config: self._summary_writer = tf.summary.FileWriter(config['checkpt_dir']) self._train_calls = 0 self._eval_calls = 0 - self._sess = start_session() + # Set up which sampler to use + if 'sampler_type' in config: + self._sampler = Sampler(config['sampler_type']) + else: + self._sampler = Sampler() + + # Have start word that starts every song + # Have end word that ends every song + self._start_word = self._config['start_token'] + self._end_word = self._config['stop_token'] + self._input_size = self._config['input_size'] + + # Adding end word increase sequence size by +1 + self._time_steps = self._config['max_len'] + 1 + + self._sess = start_session() with tf.variable_scope(self.name): self._global_step = tf.Variable(0, trainable=False) - self._define_placedholders() + self._define_placeholders() self._build_graph() self._saver = tf.train.Saver(self.get_vars(only_trainable=False), - max_to_keep=10) + max_to_keep=5) def get_vars(self, name=None, only_trainable=True): name = name or self.name diff --git a/src/models/unigram_model.py b/src/models/unigram_model.py index f04e428..f65d0fc 100644 --- a/src/models/unigram_model.py +++ b/src/models/unigram_model.py @@ -15,7 +15,7 @@ class UnigramModel(TFModel): def __init__(self, config): super(UnigramModel, self).__init__(config) - def _define_placedholders(self): + def _define_placeholders(self): self._input_size = self._config['input_size'] self._time_steps = self._config['max_len'] @@ -71,8 +71,8 @@ def sample(self, support_set, num): pred_words = [] for i in range(num): - prob = self._sess.run(self._prob_all) - word = np.argmax(prob) + p = self._sess.run(self._prob_all) + word = self._sampler.sample(p) pred_words.append(word) return pred_words diff --git a/src/train/train.py b/src/train/train.py index 30f2bee..b0239eb 100644 --- a/src/train/train.py +++ b/src/train/train.py @@ -1,12 +1,15 @@ import os +import sys import pprint import argparse import yaml +from tqdm import tqdm from importlib import import_module -from data.episode import load_sampler_from_config +from data.episode import load_all_samplers_from_config PP = pprint.PrettyPrinter(depth=6) +LOG_FILE = 'status.log' def load_model_from_config(config): @@ -24,7 +27,7 @@ def write_seq(seq, dir, name): seq.write(os.path.join(dir, name + '.mid')) -def evaluate(model, episode_sampler, n_episodes): +def evaluate_nll(model, episode_sampler, n_episodes): avg_nll = 0. 
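+    # Mean query-set NLL over n_episodes episodes; how the support set is
+    # used (cache, adaptation, encoding, or ignored) is left to model.eval().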
for i in range(n_episodes): episode = episode_sampler.get_episode() @@ -33,33 +36,81 @@ def evaluate(model, episode_sampler, n_episodes): return avg_nll / n_episodes +def evaluate_ndcg(model, episode_sampler, n_episodes): + avg_ndcg = 0. + for i in range(n_episodes): + episode = episode_sampler.get_episode_with_other_artists() + avg_ndcg += model.eval_ndcg(episode) + + return avg_ndcg / n_episodes + + +def write_samples(model, episode_sampler, samples_dir, n_samples, max_len): + if not os.path.exists(samples_dir): + os.makedirs(samples_dir) + + for i in range(n_samples): + curr_sample_dir = os.path.join(samples_dir, 'sample_%d' % i) + os.makedirs(curr_sample_dir) + + episode = episode_sampler.get_episode() + support_set = episode.support[0] + sample = model.sample(support_set, max_len) + + for j in range(support_set.shape[0]): + write_seq(episode_sampler.detokenize(support_set[j]), + curr_sample_dir, 'support_%d' % j) + + write_seq(episode_sampler.detokenize(sample), curr_sample_dir, + 'model_sample') + + +def create_log(checkpt_dir): + if checkpt_dir != '': + if not os.path.exists(checkpt_dir): + os.makedirs(checkpt_dir) + + sys.stdout = open(os.path.join(checkpt_dir, LOG_FILE), 'w', 0) + + parser = argparse.ArgumentParser(description='Train a model.') parser.add_argument('--data', dest='data', default='') parser.add_argument('--model', dest='model', default='') parser.add_argument('--task', dest='task', default='') +# parser.add_argument( +# '--use_negative_episodes', dest='use_negative_episodes', default=False) parser.add_argument('--checkpt_dir', dest='checkpt_dir', default='') parser.add_argument('--init_dir', dest='init_dir', default='') +parser.add_argument('--mode', dest='mode', default='train') args = parser.parse_args() def main(): + create_log(args.checkpt_dir) print('Args:') print(PP.pformat(vars(args))) config = yaml.load(open(args.data, 'r')) config.update(yaml.load(open(args.task, 'r'))) config.update(yaml.load(open(args.model, 'r'))) + # config['use_negative_episodes'] = args.use_negative_episodes config['dataset_path'] = os.path.abspath(config['dataset_path']) config['checkpt_dir'] = args.checkpt_dir print('Config:') print(PP.pformat(config)) + episode_sampler = load_all_samplers_from_config(config) + """ episode_sampler = {} for split in config['splits']: config['split'] = split episode_sampler[split] = load_sampler_from_config(config) + """ config['input_size'] = episode_sampler['train'].get_num_unique_words() + config['unk_token'] = episode_sampler['train'].get_unk_token() + config['start_token'] = episode_sampler['train'].get_start_token() + config['stop_token'] = episode_sampler['train'].get_stop_token() if not config['input_size'] > 0: raise RuntimeError( 'error reading data: %d unique tokens processed' % config['input_size']) @@ -73,57 +124,92 @@ def main(): n_samples = config['n_samples'] max_len = config['max_len'] + save_best_val = False + if 'patience_iters' in config: + save_best_val = True + patience_iters = config['patience_iters'] + model = load_model_from_config(config) model.recover_or_init(args.init_dir) - # Train model and evaluate - avg_nll = evaluate(model, episode_sampler['val'], n_val) - print("Iter: %d, val-nll: %.3e" % (0, avg_nll)) - - avg_loss = 0. 
- for i in range(1, n_train + 1): - episode = episode_sampler['train'].get_episode() - loss = model.train(episode) - avg_loss += loss - - if i % val_every_n == 0: - avg_nll = evaluate(model, episode_sampler['val'], n_val) - print("Iter: %d, val-nll: %.3e" % (i, avg_nll)) - - if args.checkpt_dir != '': - model.save(args.checkpt_dir) - - if i % print_every_n == 0: - print("Iter: %d, loss: %.3e" % (i, avg_loss / print_every_n)) - avg_loss = 0. - - # Evaluate model after training on training, validation, and test sets - avg_nll = evaluate(model, episode_sampler['train'], n_test) + if args.mode == 'train': + # Train model and evaluate + avg_nll = evaluate_nll(model, episode_sampler['val'], n_val) + print("Iter: %d, val-nll: %.3e" % (0, avg_nll)) + episode_sampler['val'].reset_seed() + avg_ndcg = evaluate_ndcg(model, episode_sampler['val'], n_val) + print("Iter: %d, val-ndcg: %.3e" % (0, avg_ndcg)) + episode_sampler['test'].reset_seed() + + avg_loss = 0. + best_val_nll = sys.float_info.max + for i in tqdm(range(1, n_train + 1)): + episode = episode_sampler['train'].get_episode() + loss = model.train(episode) + avg_loss += loss + + if i % print_every_n == 0: + print("Iter: %d, loss: %.3e" % (i, avg_loss / print_every_n)) + avg_loss = 0. + + if i % val_every_n == 0: + avg_nll = evaluate_nll(model, episode_sampler['val'], n_val) + print("Iter: %d, val-nll: %.3e" % (i, avg_nll)) + episode_sampler['val'].reset_seed() + + if save_best_val: + if avg_nll < best_val_nll: + best_val_nll = avg_nll + print("=> Found winner on validation set: %.3e" % best_val_nll) + if args.checkpt_dir != '': + model.save(args.checkpt_dir) + + # reset patience + patience_iters = config['patience_iters'] + # patience_iters = min( + # config['patience_iters'], patience_iters + 1) + else: + patience_iters -= 1 + print("=> Decreasing patience: %d" % patience_iters) + if patience_iters == 0: + print("=> Patience exhausted - ending training...") + break + else: + if args.checkpt_dir != '': + model.save(args.checkpt_dir) + + # Load best model + if args.checkpt_dir != '': + print("=> Loading winner on validation set") + model.recover_or_init(args.checkpt_dir) + + # Evaluate model NLL on training, validation, and test sets + episode_sampler['train'].reset_seed() + avg_nll = evaluate_nll(model, episode_sampler['train'], n_test) print("Train Avg NLL: %.3e" % (avg_nll)) - avg_nll = evaluate(model, episode_sampler['val'], n_test) + episode_sampler['val'].reset_seed() + avg_nll = evaluate_nll(model, episode_sampler['val'], n_test) print("Validation Avg NLL: %.3e" % (avg_nll)) - avg_nll = evaluate(model, episode_sampler['test'], n_test) + episode_sampler['test'].reset_seed() + avg_nll = evaluate_nll(model, episode_sampler['test'], n_test) print("Test Avg NLL: %.3e" % (avg_nll)) - # Generate samples from trained model for test episodes - samples_dir = os.path.join(args.checkpt_dir, 'samples') - if not os.path.exists(samples_dir): - os.makedirs(samples_dir) - - for i in range(n_samples): - curr_sample_dir = os.path.join(samples_dir, 'sample_%d' % i) - os.makedirs(curr_sample_dir) - - episode = episode_sampler['test'].get_episode() - support_set = episode.support[0] - sample = model.sample(support_set, max_len) + # Evaluate model ndcg + episode_sampler['train'].reset_seed() + avg_ndcg = evaluate_ndcg(model, episode_sampler['train'], n_test) + print("Train Avg ndcg: %.3e" % (avg_ndcg)) + episode_sampler['val'].reset_seed() + avg_ndcg = evaluate_ndcg(model, episode_sampler['val'], n_test) + print("Validation Avg ndcg: %.3e" % (avg_ndcg)) + 
episode_sampler['test'].reset_seed() + avg_ndcg = evaluate_ndcg(model, episode_sampler['test'], n_test) + print("Test Avg ndcg: %.3e" % (avg_ndcg)) - for j in range(support_set.shape[0]): - write_seq(episode_sampler['test'].detokenize(support_set[j]), - curr_sample_dir, 'support_%d' % j) - - write_seq(episode_sampler['test'].detokenize(sample), curr_sample_dir, - 'model_sample') + # Generate samples from trained model for test episodes + if args.checkpt_dir != '': + samples_dir = os.path.join(args.checkpt_dir, 'samples') + write_samples( + model, episode_sampler['test'], samples_dir, n_samples, max_len) if __name__ == '__main__':