Skip to content

Commit

Permalink
Add missing flags.
Browse files Browse the repository at this point in the history
  • Loading branch information
mrdrozdov committed Aug 11, 2019
1 parent 277709e commit 0c398fd
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pytorch/diora/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def make_batch_iterator(options, dset, shuffle=True, include_partial=False, filt
vocab_size = len(word2idx)

negative_sampler = None
if options.reconstruct_mode == 'margin':
if options.reconstruct_mode in ('margin', 'softmax'):
freq_dist = calculate_freq_dist(sentences, vocab_size)
negative_sampler = NegativeSampler(freq_dist=freq_dist, dist_power=options.freq_dist_power)
vocab_lst = [w for w, _ in sorted(word2idx.items(), key=lambda x: x[1])]
Expand Down
6 changes: 5 additions & 1 deletion pytorch/diora/net/diora.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,11 @@ def initialize_outside_root(self):
h = torch.matmul(self.inside_h[:, -1:], self.root_mat_out)
else:
h = self.root_vector_out_h.view(1, 1, D).expand(B, 1, D)
c = self.root_vector_out_c.view(1, 1, D).expand(B, 1, D)
if self.root_vector_out_c is None:
device = torch.cuda.current_device() if self.is_cuda else None
c = torch.full(h.shape, 0, dtype=torch.float32, device=device)
else:
c = self.root_vector_out_c.view(1, 1, D).expand(B, 1, D)

h = normalize_func(h)
c = normalize_func(c)
Expand Down
4 changes: 2 additions & 2 deletions pytorch/diora/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,15 +224,15 @@ def argument_parser():
parser.add_argument('--validation_filter_length', default=0, type=int)

# Model.
parser.add_argument('--arch', default='treelstm', choices=('treelstm',))
parser.add_argument('--arch', default='treelstm', choices=('treelstm', 'mlp', 'mlp-shared'))
parser.add_argument('--hidden_dim', default=10, type=int)
parser.add_argument('--normalize', default='unit', choices=('none', 'unit'))
parser.add_argument('--compress', action='store_true',
help='If true, then copy root from inside chart for outside. ' + \
'Otherwise, learn outside root as bias.')

# Model (Objective).
parser.add_argument('--reconstruct_mode', default='margin', choices=('margin',))
parser.add_argument('--reconstruct_mode', default='margin', choices=('margin', 'softmax'))

# Model (Embeddings).
parser.add_argument('--emb', default='w2v', choices=('w2v', 'elmo', 'both'))
Expand Down

0 comments on commit 0c398fd

Please sign in to comment.