Yoon Kim committed Apr 7, 2019
1 parent 456b1e6 commit 36f8cf3
Showing 16 changed files with 4,730 additions and 1 deletion.
20 changes: 20 additions & 0 deletions .gitignore
@@ -0,0 +1,20 @@
*.pt
*.amat
*.mat
*.out
*.out~
*.pyc
*.pt~
.gitignore~
*.out~
*.sh
*.sh~
*.py~
*.json
*.json~
*.model
*.h5
*.tar.gz
*.hdf5
*.dict
*.pkl
66 changes: 66 additions & 0 deletions COLLINS.prm
@@ -0,0 +1,66 @@
##------------------------------------------##
## Debug mode ##
## 0: No debugging ##
## 1: print data for individual sentence ##
##------------------------------------------##
DEBUG 0

##------------------------------------------##
##  MAX error                               ##
##    Number of errors at which to stop    ##
##    the process. This is useful if       ##
##    there could be tokenization errors.  ##
##    The process will stop once this      ##
##    many errors have accumulated.        ##
##------------------------------------------##
MAX_ERROR 10

##------------------------------------------##
## Cut-off length for statistics ##
## At the end of evaluation, the ##
##    statistics for the sentences of length##
## less than or equal to this number will##
## be shown, on top of the statistics ##
## for all the sentences ##
##------------------------------------------##
CUTOFF_LEN 10

##------------------------------------------##
## unlabeled or labeled bracketing ##
## 0: unlabeled bracketing ##
## 1: labeled bracketing ##
##------------------------------------------##
LABELED 0

##------------------------------------------##
## Delete labels ##
## list of labels to be ignored. ##
## If it is a pre-terminal label, delete ##
## the word along with the brackets. ##
## If it is a non-terminal label, just ##
## delete the brackets (don't delete ##
##    children).                            ##
##------------------------------------------##
DELETE_LABEL TOP
DELETE_LABEL -NONE-
DELETE_LABEL ,
DELETE_LABEL :
DELETE_LABEL ``
DELETE_LABEL ''
DELETE_LABEL .

##------------------------------------------##
## Delete labels for length calculation ##
## list of labels to be ignored for ##
## length calculation purpose ##
##------------------------------------------##
DELETE_LABEL_FOR_LENGTH -NONE-

##------------------------------------------##
## Equivalent labels, words ##
## the pairs are considered equivalent ##
## This is non-directional. ##
##------------------------------------------##
EQ_LABEL ADVP PRT

# EQ_WORD Example example
99 changes: 98 additions & 1 deletion README.md
@@ -1 +1,98 @@
# urnng
# Unsupervised Recurrent Neural Network Grammars

This is an implementation of the paper:
[Unsupervised Recurrent Neural Network Grammars](https://arxiv.org/pdf/1904.03746.pdf)
Yoon Kim, Alexander Rush, Adhiguna Kuncoro, Chris Dyer, Gabor Melis
NAACL 2019

## Dependencies
The code was tested in `python 3.6` and `pytorch 1.0`.

## Data
Sample train/val/test data is in the `data/` folder. These are the standard datasets from PTB.
First preprocess the data:
```
python preprocess.py --trainfile data/train.txt --valfile data/valid.txt --testfile data/test.txt
--outputfile data/ptb --vocabminfreq 1 --lowercase 0 --replace_num 0 --batchsize 16
```
Running this will save the following files in the `data/` folder: `ptb-train.pkl`, `ptb-val.pkl`,
`ptb-test.pkl`, `ptb.dict`. Here `ptb.dict` is the word-idx mapping, and you can change the
output folder/name by changing the argument to `--outputfile`.
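
As a quick sanity check, you can load the preprocessed output. This is purely illustrative: the structure of the pickled batches (and the format of `ptb.dict`) is whatever `preprocess.py` writes, so consult that script for the real layout.
```
import pickle

# Assumes standard pickle serialization; see preprocess.py for the actual layout.
with open('data/ptb-train.pkl', 'rb') as f:
    train_data = pickle.load(f)
print(type(train_data))
```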

## Training
To train the URNNG:
```
python train.py --train_file data/ptb-train.pkl --val_file data/ptb-val.pkl --save_path urnng.pt
--mode unsupervised --gpu 0
```
where `--save_path` is where you want to save the model, and `--gpu 0` is for using the first GPU
in the cluster (the mapping from PyTorch GPU index to your cluster's GPU index may vary).
Training should take 2 to 4 days depending on your setup.

To train the RNNG:
```
python train.py --train_file data/ptb-train.pkl --val_file data/ptb-val.pkl --save_path rnng.pt
--mode supervised --train_q_epochs 18 --gpu 0
```

For fine-tuning:
```
python train.py --train_file data/ptb-train.pkl --val_file data/ptb-val.pkl --save_path rnng-urnng.pt
--mode unsupervised --lr 0.1 --train_q_epochs 10 --epochs 10 --gpu 0 --kl_warmup 0
```

To train the LM:
```
python train_lm.py --train_file data/ptb-train.pkl --val_file data/ptb-val.pkl --test_file data/ptb-test.pkl --save_path lm.pt
```

## Evaluation
To evaluate perplexity with importance sampling on the test set:
```
python eval_ppl.py --model_file urnng.pt --test_file data/ptb-test.pkl --samples 1000
--is_temp 2 --gpu 0
```
The argument `--samples` gives the number of importance-weighted samples, and `--is_temp` flattens
the inference network's distribution (footnote 14 in the paper).
The same evaluation code works for the RNNG.
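
For intuition, here is a minimal sketch of the estimator this evaluation computes (not the internals of `eval_ppl.py`; names are illustrative). Given K samples z_k drawn from the (flattened) inference network q, the marginal log-likelihood of a sentence x is estimated as log (1/K) sum_k p(x, z_k) / q(z_k | x):
```
import math
import torch

def iw_log_marginal(log_joint, log_q):
    # log_joint[k] = log p(x, z_k); log_q[k] = log q(z_k | x), for K samples z_k ~ q.
    # Returns the importance-weighted estimate log (1/K) * sum_k p(x, z_k) / q(z_k | x).
    K = log_joint.size(0)
    return torch.logsumexp(log_joint - log_q, dim=0) - math.log(K)
```
Perplexity is then `exp(-total_log_likelihood / num_tokens)` accumulated over the test set.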

For LM evaluation:
```
python train_lm.py --train_from lm.pt --test_file data/ptb-test.pkl --test 1
```

To evaluate F1, first we need to parse the test set:
```
python parse.py --model_file urnng.pt --data_file data/ptb-test.txt --out_file pred-parse.txt
--gold_file gold-parse.txt --gpu 0
```
This will output the predicted parse trees into `pred-parse.txt`. We also output a version
of the gold parses to `gold-parse.txt` to be used as input for `evalb`, since `parse.py` ignores sentences with only trivial spans. Note that the corpus/sentence F1 printed here does not correspond to the results reported in the paper, since the evaluation here does not ignore punctuation.

Finally, download/install `evalb`, available [here](https://nlp.cs.nyu.edu/evalb).
Then run:
```
evalb -p COLLINS.prm gold-parse.txt pred-parse.txt
```
where `COLLINS.prm` is the parameter file (provided in this repo) that tells `evalb` to ignore
punctuation and evaluate on unlabeled F1.

## Note
Note that some details of the preprocessing are slightly different from the original
paper. In particular, in this implementation we replace singleton words with a single `<unk>` token
instead of using the Berkeley parser's mapping rules (see the sketch below the results). This results
in slightly lower perplexity for all models, since the vocabulary size is smaller. Here are the
results I get in this setting:

- RNNLM: 89.2
- RNNG: 83.7
- URNNG: 85.1
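
A minimal sketch of the singleton replacement described above (a hypothetical helper, not the actual code in `preprocess.py`; the real threshold is controlled by `--vocabminfreq`):
```
from collections import Counter

def unkify(train_sents, min_count=2, unk='<unk>'):
    # Replace words appearing fewer than min_count times in the training
    # data (singletons when min_count=2) with a single <unk> token.
    counts = Counter(w for sent in train_sents for w in sent)
    return [[w if counts[w] >= min_count else unk for w in sent]
            for sent in train_sents]
```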


## Acknowledgements
Some of our preprocessing and evaluation code is based on the following repositories:
- [Recurrent Neural Network Grammars](https://github.com/clab/rnng)
- [Parsing Reading Predict Network](https://github.com/yikangshen/PRPN)

## License
MIT
203 changes: 203 additions & 0 deletions TreeCRF.py
@@ -0,0 +1,203 @@
#!/usr/bin/env python3

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import itertools
import utils
import random

class ConstituencyTreeCRF(nn.Module):
def __init__(self):
super(ConstituencyTreeCRF, self).__init__()
self.huge = 1e9

def logadd(self, x, y):
d = torch.max(x,y)
return torch.log(torch.exp(x-d) + torch.exp(y-d)) + d

def logsumexp(self, x, dim=1):
d = torch.max(x, dim)[0]
return torch.log(torch.exp(x - d.unsqueeze(dim).expand_as(x)).sum(dim)) + d

def _init_table(self, scores):
# initialize dynamic programming table
batch_size = scores.size(0)
n = scores.size(1)
self.alpha = [[scores.new(batch_size).fill_(-self.huge) for _ in range(n)] for _ in range(n)]

def _forward(self, scores):
#inside step
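        # Inside recurrence over span widths k: for each span (s, t),
        #   alpha[s][t] = logsumexp_u( alpha[s][u] + alpha[u+1][t] + scores[:, s, t] )
        # where the split point u ranges over [s, t).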
batch_size = scores.size(0)
n = scores.size(1)
self._init_table(scores)
for i in range(n):
self.alpha[i][i] = scores[:, i, i]
for k in np.arange(1, n+1):
for s in range(n):
t = s + k
if t > n-1:
break
tmp = [self.alpha[s][u] + self.alpha[u+1][t] + scores[:, s, t] for u in np.arange(s,t)]
tmp = torch.stack(tmp, 1)
self.alpha[s][t] = self.logsumexp(tmp, 1)

def _backward(self, scores):
#outside step
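        # Outside recurrence: beta[s][t] accumulates, over every parent span
        # containing (s, t) as a child, the parent's outside score plus the
        # sibling span's inside score plus the parent's span score.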
batch_size = scores.size(0)
n = scores.size(1)
self.beta = [[None for _ in range(n)] for _ in range(n)]
self.beta[0][n-1] = scores.new(batch_size).fill_(0)
for k in np.arange(n-1, 0, -1):
for s in range(n):
t = s + k
if t > n-1:
break
for u in np.arange(s, t):
if s < u+1:
tmp = self.beta[s][t] + self.alpha[u+1][t] + scores[:, s, t]
if self.beta[s][u] is None:
self.beta[s][u] = tmp
else:
self.beta[s][u] = self.logadd(self.beta[s][u], tmp)
if u+1 < t+1:
tmp = self.beta[s][t] + self.alpha[s][u] + scores[:, s, t]
if self.beta[u+1][t] is None:
self.beta[u+1][t] = tmp
else:
self.beta[u+1][t] = self.logadd(self.beta[u+1][t], tmp)

def _marginal(self, scores):
batch_size = scores.size(0)
n = scores.size(1)
self.log_marginal = [[None for _ in range(n)] for _ in range(n)]
log_Z = self.alpha[0][n-1]
for s in range(n):
for t in np.arange(s, n):
self.log_marginal[s][t] = self.alpha[s][t] + self.beta[s][t] - log_Z

def _entropy(self, scores):
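        # Entropy recursion: H[s][t] = sum_u p(u) * (H[s][u] + H[u+1][t] - log p(u)),
        # where p(u) is the posterior over split points of (s, t) given the inside scores.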
batch_size = scores.size(0)
n = scores.size(1)
self.entropy = [[None for _ in range(n)] for _ in range(n)]
for i in range(n):
self.entropy[i][i] = scores.new(batch_size).fill_(0)
for k in np.arange(1, n+1):
for s in range(n):
t = s + k
if t > n-1:
break
score = []
prev_ent = []
for u in np.arange(s, t):
score.append(self.alpha[s][u] + self.alpha[u+1][t])
prev_ent.append(self.entropy[s][u] + self.entropy[u+1][t])
score = torch.stack(score, 1)
prev_ent = torch.stack(prev_ent, 1)
log_prob = F.log_softmax(score, dim = 1)
prob = log_prob.exp()
entropy = ((prev_ent - log_prob)*prob).sum(1)
self.entropy[s][t] = entropy


def _sample(self, scores, alpha = None, argmax = False):
# sample from p(tree | sent)
# also get the spans
if alpha is None:
self._forward(scores)
alpha = self.alpha
batch_size = scores.size(0)
n = scores.size(1)
tree = scores.new(batch_size, n, n).zero_()
all_log_probs = []
tree_brackets = []
spans = []
for b in range(batch_size):
sampled = [(0, n-1)]
span = [(0, n-1)]
queue = [(0, n-1)] #start, end
log_probs = []
tree_str = get_span_str(0, n-1)
while len(queue) > 0:
node = queue.pop(0)
start, end = node
left_parent = get_span_str(start, None)
right_parent = get_span_str(None, end)
score = []
score_idx = []
for u in np.arange(start, end):
score.append(alpha[start][u][b] + alpha[u+1][end][b])
score_idx.append([(start, u), (u+1, end)])
score = torch.stack(score, 0)
log_prob = F.log_softmax(score, dim = 0)
if argmax:
sample = torch.max(log_prob, 0)[1]
else:
                    prob = log_prob.exp()
                    sample = torch.multinomial(prob, 1)
sample_idx = score_idx[sample.item()]
log_probs.append(log_prob[sample.item()])
for idx in sample_idx:
if idx[0] != idx[1]:
queue.append(idx)
span.append(idx)
sampled.append(idx)
left_child = '(' + get_span_str(sample_idx[0][0], sample_idx[0][1])
right_child = get_span_str(sample_idx[1][0], sample_idx[1][1]) + ')'
if sample_idx[0][0] != sample_idx[0][1]:
tree_str = tree_str.replace(left_parent, left_child)
if sample_idx[1][0] != sample_idx[1][1]:
tree_str = tree_str.replace(right_parent, right_child)
all_log_probs.append(torch.stack(log_probs, 0).sum(0))
tree_brackets.append(tree_str)
spans.append(span[::-1])
for idx in sampled:
tree[b][idx[0]][idx[1]] = 1

all_log_probs = torch.stack(all_log_probs, 0)
return tree, all_log_probs, tree_brackets, spans

def _viterbi(self, scores):
# cky algorithm
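        # Viterbi variant of the inside pass: logsumexp is replaced by max,
        # and bp stores the best split point of each span for backtracking.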
batch_size = scores.size(0)
n = scores.size(1)
self.max_scores = scores.new(batch_size, n, n).fill_(-self.huge)
self.bp = scores.new(batch_size, n, n).zero_()
self.argmax = scores.new(batch_size, n, n).zero_()
self.spans = [[] for _ in range(batch_size)]
for i in range(n):
self.max_scores[:, i, i] = scores[:, i, i]
for k in np.arange(1, n):
for s in np.arange(n):
t = s + k
if t > n-1:
break
for u in np.arange(s, t):
tmp = self.max_scores[:, s, u] + self.max_scores[:, u+1, t] + scores[:, s, t]
self.bp[:, s, t][self.max_scores[:, s, t] < tmp] = int(u)
self.max_scores[:, s, t] = torch.max(self.max_scores[:, s, t], tmp)
for b in range(batch_size):
self._backtrack(b, 0, n-1)
return self.max_scores[:, 0, n-1], self.argmax, self.spans

def _backtrack(self, b, s, t):
u = int(self.bp[b][s][t])
self.argmax[b][s][t] = 1
if s == t:
return None
else:
self.spans[b].insert(0, (s,t))
self._backtrack(b, s, u)
self._backtrack(b, u+1, t)
return None

def get_span_str(start = None, end = None):
assert(start is not None or end is not None)
if start is None:
return ' ' + str(end) + ')'
elif end is None:
return '(' + str(start) + ' '
else:
return ' (' + str(start) + ' ' + str(end) + ') '
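
For reference, a minimal usage sketch of `ConstituencyTreeCRF` (random span scores, where `scores[b, s, t]` scores the span from token s to token t; shapes follow the code above):
```
import torch
from TreeCRF import ConstituencyTreeCRF

crf = ConstituencyTreeCRF()
scores = torch.randn(2, 5, 5)  # batch of 2 sentences, 5 tokens each

crf._forward(scores)           # inside pass; fills crf.alpha
log_Z = crf.alpha[0][4]        # log partition function per batch element

# Sample trees (with spans and bracketings), or take the Viterbi parse.
tree, log_probs, brackets, spans = crf._sample(scores)
best_score, argmax_tree, best_spans = crf._viterbi(scores)
```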