Yoon Kim committed Apr 9, 2019
1 parent 5e846a4 commit af949de
Showing 7 changed files with 20 additions and 29 deletions.
18 changes: 9 additions & 9 deletions README.md
@@ -1,7 +1,7 @@
# Unsupervised Recurrent Neural Network Grammars

This is an implementation of the paper:
- [Unsupervised Recurrent Neural Network Grammars](https://arxiv.org/pdf/1804.0000.pdf)
+ [Unsupervised Recurrent Neural Network Grammars](https://arxiv.org/abs/1904.03746)
Yoon Kim, Alexander Rush, Lei Yu, Adhiguna Kuncoro, Chris Dyer, Gabor Melis
NAACL 2019

@@ -29,7 +29,7 @@ python train.py --train_file data/ptb-train.pkl --val_file data/ptb-val.pkl --sa
```
where `save_path` is where you want to save the model, and `gpu 0` is for using the first GPU
in the cluster (the mapping from PyTorch GPU index to your cluster's GPU index may vary).
- Training should take 2 to 4 days depending on your setup.
+ Training should take 2 to 3 days depending on your setup.

To train the RNNG:
```
@@ -41,7 +41,7 @@ For fine-tuning:
```
python train.py --train_from rnng.pt --train_file data/ptb-train.pkl --val_file data/ptb-val.pkl
--save_path rnng-urnng.pt --mode unsupervised --lr 0.1 --train_q_epochs 10 --epochs 10
- --gpu 0 --kl_warmup 0
+ --min_epochs 6 --gpu 0 --kl_warmup 0
```

To train the LM:
@@ -58,7 +58,7 @@ python eval_ppl.py --model_file urnng.pt --test_file data/ptb-test.pkl --samples
```
The argument `samples` sets the number of importance-weighted samples, and `is_temp` flattens
the inference network's distribution (footnote 14 in the paper).
- The same evalulation code will work for RNNG.
+ The same evaluation code will work for RNNG.
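For intuition, here is a minimal sketch of the importance-sampling estimate these flags control. The function and its names are hypothetical, not the repo's actual code:
```
import math
import torch

def importance_weighted_ppl(log_joint, log_q, num_words):
    # log_joint: (K,) log p(x, z_k) under the generative model, for trees z_k ~ q(z|x)
    # log_q:     (K,) log q(z_k | x) under the (temperature-flattened) inference network
    k = log_joint.size(0)
    # log-mean-exp of the importance weights lower-bounds log p(x)
    log_px = torch.logsumexp(log_joint - log_q, dim=0) - math.log(k)
    return math.exp(-log_px.item() / num_words)
```
Dividing the CRF scores by `is_temp` before sampling flattens q, so the samples cover more of the tree space.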

For LM evaluation:
```
@@ -85,13 +85,13 @@ punctuation and evaluate on unlabeled F1.
Note that some details of the preprocessing are slightly different from the original
paper. In particular, in this implementation we replace singleton words with a single `<unk>` token
instead of using the Berkeley parser's mapping rules (a sketch of this replacement follows the
results below). This results in slightly lower perplexity
- for all models, since the vocabulary size is smaller. Here are the results I get
+ for all models, since the vocabulary size is smaller. Here are the perplexity numbers I get
in this setting:

- RNNLM: 89.2
- RNNG: 83.7
- URNNG: 85.1 (F1: 38.4)
+ - RNNG --> URNNG: 82.5
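Here is a minimal sketch of the singleton replacement described above; it is illustrative only, not the repo's actual preprocessing code:
```
from collections import Counter

def replace_singletons(train_sents, unk="<unk>"):
    # words seen only once in training become <unk>, shrinking the vocabulary
    counts = Counter(w for sent in train_sents for w in sent)
    return [[w if counts[w] > 1 else unk for w in sent] for sent in train_sents]
```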

## Acknowledgements
Some of our preprocessing and evaluation code is based on the following repositories:
2 changes: 2 additions & 0 deletions data.py
@@ -35,6 +35,8 @@ def __getitem__(self, idx):
binary_tree = [d[3] for d in other_data]
spans = [d[5] for d in other_data]
batch_size = self.batch_size[idx].item()
+ # by default, we return sents with <s> </s> tokens
+ # hence we subtract 2 from length as these are (by default) not counted for evaluation
data_batch = [sents[:, :length], length-2, batch_size, actions,
spans, binary_tree, other_data]
return data_batch
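To make the new comment's length convention concrete, a toy example (hypothetical sentence):
```
sent = ["<s>", "the", "dog", "barks", "</s>"]
length = len(sent)        # 5, counts <s> and </s>
eval_length = length - 2  # 3 words are actually scored in evaluation
```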
3 changes: 2 additions & 1 deletion eval_ppl.py
@@ -28,7 +28,7 @@
parser.add_argument('--test_file', default='data/ptb-test.pkl')
parser.add_argument('--model_file', default='')
parser.add_argument('--is_temp', default=2., type=float, help='divide scores by is_temp before CRF')
- parser.add_argument('--samples', default=1000, type=int, help='samples for IWAE calculation')
+ parser.add_argument('--samples', default=1000, type=int, help='samples for IS calculation')
parser.add_argument('--count_eos_ppl', default=0, type=int, help='whether to count eos in val PPL')
parser.add_argument('--gpu', default=2, type=int, help='which gpu to use')
parser.add_argument('--seed', default=3435, type=int)
@@ -57,6 +57,7 @@ def main(args):
for i in list(reversed(range(len(data)))):
sents, length, batch_size, gold_actions, gold_spans, gold_binary_trees, other_data = data[i]
if length == 1:
+ # length 1 sents are ignored since our generative model requires sents of length >= 2
continue
if args.count_eos_ppl == 1:
length += 1
21 changes: 4 additions & 17 deletions models.py
@@ -50,7 +50,7 @@ def __init__(self, i_dim = 200,
h_dim = 0,
num_layers = 1,
dropout = 0):
super(SeqLSTM, self).__init__()
self.i_dim = i_dim
self.h_dim = h_dim
self.num_layers = num_layers
@@ -60,8 +60,6 @@ def __init__(self, i_dim = 200,
self.dropout_layer = nn.Dropout(dropout)

def forward(self, x, prev_h = None):
- #x = b x i_dim
- #prev_h = [(h_l, c_l) for l layers]
if prev_h is None:
prev_h = [(x.new(x.size(0), self.h_dim).fill_(0),
x.new(x.size(0), self.h_dim).fill_(0)) for _ in range(self.num_layers)]
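For readers skimming the diff, here is a self-contained sketch of the per-timestep stacked-LSTM pattern that `SeqLSTM.forward` implements; names and defaults are illustrative, not the repo's exact code:
```
import torch
import torch.nn as nn

class StackedLSTMStep(nn.Module):
    # advances a stack of LSTMCells by one timestep, feeding each layer's
    # (dropped-out) hidden state to the layer above
    def __init__(self, i_dim=200, h_dim=200, num_layers=2, dropout=0.0):
        super().__init__()
        dims = [i_dim] + [h_dim] * num_layers
        self.layers = nn.ModuleList(
            [nn.LSTMCell(dims[l], dims[l + 1]) for l in range(num_layers)])
        self.dropout_layer = nn.Dropout(dropout)

    def forward(self, x, prev_h=None):
        if prev_h is None:
            prev_h = [(x.new_zeros(x.size(0), cell.hidden_size),
                       x.new_zeros(x.size(0), cell.hidden_size))
                      for cell in self.layers]
        curr_h, inp = [], x
        for cell, hc in zip(self.layers, prev_h):
            h, c = cell(inp, hc)
            curr_h.append((h, c))
            inp = self.dropout_layer(h)
        return curr_h
```
A full sequence pass would call this once per token, threading `curr_h` back in as `prev_h`.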
@@ -79,17 +77,12 @@ def forward(self, x, prev_h = None):
return curr_h

class TreeLSTM(nn.Module):
- def __init__(self, dim = 200,
-              e_dim = 0,
-              dropout = 0):
+ def __init__(self, dim = 200):
super(TreeLSTM, self).__init__()
self.dim = dim
- self.e_dim = e_dim
- self.linear = nn.Linear(dim*2 + e_dim, dim*5)
+ self.linear = nn.Linear(dim*2, dim*5)

def forward(self, x1, x2, e=None):
- #x = (h, c). h, c = b x dim. hidden/cell states of children
- #e = b x e_dim. external information vector
if not isinstance(x1, tuple):
x1 = (x1, None)
h1, c1 = x1
@@ -102,13 +95,7 @@ def forward(self, x1, x2, e=None):
c1 = torch.zeros_like(h1)
if c2 is None:
c2 = torch.zeros_like(h2)
- if self.e_dim == 0:
-     concat = torch.cat([h1, h2], 1)
- else:
-     if e is None:
-         concat = torch.cat([h1, h2, torch.zeros_like(h1)], 1)
-     else:
-         concat = torch.cat([h1, h2, e], 1)
+ concat = torch.cat([h1, h2], 1)
all_sum = self.linear(concat)
i, f1, f2, o, g = all_sum.split(self.dim, 1)

Empty file added pred-test.txt
3 changes: 2 additions & 1 deletion train.py
@@ -138,6 +138,7 @@ def main(args):
kl_pen = min(1., kl_pen + kl_warmup_batch)
sents, length, batch_size, gold_actions, gold_spans, gold_binary_trees, other_data = train_data[i]
if length == 1:
+ # we ignore length 1 sents during training/eval since we work with binary trees only
continue
sents = sents.cuda()
b += 1
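The `kl_pen` update at the top of this hunk is a linear KL-annealing schedule. Conceptually, with made-up numbers rather than the repo's defaults:
```
kl_warmup_epochs = 2.0            # e.g. --kl_warmup 2
batches_per_epoch = 1000          # hypothetical
kl_warmup_batch = 1.0 / (kl_warmup_epochs * batches_per_epoch)
kl_pen = 0.0
for step in range(5000):
    kl_pen = min(1.0, kl_pen + kl_warmup_batch)  # ramps 0 -> 1, then stays at 1
    # loss = reconstruction_term + kl_pen * kl_term
```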
@@ -250,7 +251,7 @@ def eval(data, model, samples = 0, count_eos_ppl = 0):
with torch.no_grad():
for i in list(reversed(range(len(data)))):
sents, length, batch_size, gold_actions, gold_spans, gold_binary_trees, other_data = data[i]
- if length == 1:
+ if length == 1: # length 1 sents are ignored since URNNG needs at least length 2 sents
continue
if args.count_eos_ppl == 1:
tree_length = length
2 changes: 1 addition & 1 deletion train_lm.py
@@ -147,7 +147,7 @@ def eval(data, model, count_eos_ppl = 0):
with torch.no_grad():
for i in list(reversed(range(len(data)))):
sents, length, batch_size, gold_actions, gold_spans, gold_binary_trees, other_data = data[i]
- if length == 1:
+ if length == 1: # we ignore length 1 sents in URNNG eval so do this for LM too
continue
if args.count_eos_ppl == 1:
length += 1
