Commit 36f8cf3
Yoon Kim committed on Apr 7, 2019 · 1 parent 456b1e6
Showing 16 changed files with 4,730 additions and 1 deletion.
.gitignore
@@ -0,0 +1,20 @@
*.pt
*.amat
*.mat
*.out
*.out~
*.pyc
*.pt~
.gitignore~
*.out~
*.sh
*.sh~
*.py~
*.json
*.json~
*.model
*.h5
*.tar.gz
*.hdf5
*.dict
*.pkl
COLLINS.prm
@@ -0,0 +1,66 @@
##------------------------------------------##
## Debug mode                               ##
##   0: No debugging                        ##
##   1: print data for individual sentence  ##
##------------------------------------------##
DEBUG 0

##------------------------------------------##
## MAX error                                ##
##   Number of errors to stop the process.  ##
##   This is useful if there could be       ##
##   tokenization errors.                   ##
##   The process will stop when this number ##
##   of errors is accumulated.              ##
##------------------------------------------##
MAX_ERROR 10

##------------------------------------------##
## Cut-off length for statistics            ##
##   At the end of evaluation, the          ##
##   statistics for the sentences of length ##
##   less than or equal to this number will ##
##   be shown, on top of the statistics     ##
##   for all the sentences.                 ##
##------------------------------------------##
CUTOFF_LEN 10

##------------------------------------------##
## unlabeled or labeled bracketing          ##
##   0: unlabeled bracketing                ##
##   1: labeled bracketing                  ##
##------------------------------------------##
LABELED 0

##------------------------------------------##
## Delete labels                            ##
##   list of labels to be ignored.          ##
##   If it is a pre-terminal label, delete  ##
##   the word along with the brackets.      ##
##   If it is a non-terminal label, just    ##
##   delete the brackets (don't delete      ##
##   the children).                         ##
##------------------------------------------##
DELETE_LABEL TOP
DELETE_LABEL -NONE-
DELETE_LABEL ,
DELETE_LABEL :
DELETE_LABEL ``
DELETE_LABEL ''
DELETE_LABEL .

##------------------------------------------##
## Delete labels for length calculation     ##
##   list of labels to be ignored for       ##
##   length calculation purposes            ##
##------------------------------------------##
DELETE_LABEL_FOR_LENGTH -NONE-

##------------------------------------------##
## Equivalent labels, words                 ##
##   the pairs are considered equivalent.   ##
##   This is non-directional.               ##
##------------------------------------------##
EQ_LABEL ADVP PRT

# EQ_WORD Example example
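As a concrete illustration of these settings (the example is illustrative, not part of the committed file): with `LABELED 0` and the punctuation pre-terminals listed under `DELETE_LABEL`, a gold tree such as `(TOP (S (NP (DT The) (NN cat)) (VP (VBZ sleeps)) (. .)))` is scored as an unlabeled bracketing over "The cat sleeps". The `TOP` bracket is dropped and the period is deleted along with its brackets, so only the remaining span boundaries, not their labels, are compared against the predicted tree.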
README.md
@@ -1 +1,98 @@
# urnng
# Unsupervised Recurrent Neural Network Grammars

This is an implementation of the paper:
[Unsupervised Recurrent Neural Network Grammars](https://arxiv.org/abs/1904.03746)
Yoon Kim, Alexander Rush, Adhiguna Kuncoro, Chris Dyer, Gabor Melis
NAACL 2019

## Dependencies
The code was tested with `python 3.6` and `pytorch 1.0`.

## Data
Sample train/val/test data is in the `data/` folder. These are the standard Penn Treebank (PTB) splits.
First preprocess the data:
```
python preprocess.py --trainfile data/train.txt --valfile data/valid.txt --testfile data/test.txt \
  --outputfile data/ptb --vocabminfreq 1 --lowercase 0 --replace_num 0 --batchsize 16
```
Running this will save the following files in the `data/` folder: `ptb-train.pkl`, `ptb-val.pkl`,
`ptb-test.pkl`, and `ptb.dict`. Here `ptb.dict` is the word-to-index mapping; you can change the
output folder/name via the `--outputfile` argument.

## Training
To train the URNNG:
```
python train.py --train_file data/ptb-train.pkl --val_file data/ptb-val.pkl --save_path urnng.pt \
  --mode unsupervised --gpu 0
```
where `--save_path` is where you want to save the model, and `--gpu 0` selects the first GPU
visible to PyTorch (the mapping from PyTorch GPU indices to your cluster's physical GPU indices may vary).
Training should take 2 to 4 days depending on your setup.
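If you are unsure which physical device PyTorch's index 0 maps to on your machine, a quick check with plain PyTorch (nothing repo-specific) is:
```
import torch

# list the GPUs as PyTorch sees them; index 0 here is what --gpu 0 selects
for i in range(torch.cuda.device_count()):
    print(i, torch.cuda.get_device_name(i))
```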

To train the RNNG:
```
python train.py --train_file data/ptb-train.pkl --val_file data/ptb-val.pkl --save_path rnng.pt \
  --mode supervised --train_q_epochs 18 --gpu 0
```

For fine-tuning:
```
python train.py --train_file data/ptb-train.pkl --val_file data/ptb-val.pkl --save_path rnng-urnng.pt \
  --mode unsupervised --lr 0.1 --train_q_epochs 10 --epochs 10 --gpu 0 --kl_warmup 0
```

To train the LM:
```
python train_lm.py --train_file data/ptb-train.pkl --val_file data/ptb-val.pkl --test_file data/ptb-test.pkl --save_path lm.pt
```

## Evaluation
To evaluate perplexity with importance sampling on the test set:
```
python eval_ppl.py --model_file urnng.pt --test_file data/ptb-test.pkl --samples 1000 \
  --is_temp 2 --gpu 0
```
The argument `--samples` sets the number of importance-weighted samples, and `--is_temp` flattens
the inference network's distribution (footnote 14 in the paper).
The same evaluation code works for the RNNG.
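For intuition, the estimator behind `eval_ppl.py` is standard importance sampling of the marginal likelihood. A minimal sketch (the function below is illustrative, not the repo's actual API):
```
import math
import torch

def iw_log_marginal(log_joint, log_q):
    """Importance-weighted estimate of log p(x) from K sampled trees z_k ~ q(z | x).

    log_joint: K values of log p(x, z_k) under the generative model
    log_q:     K values of log q(z_k | x) under the (possibly temperature-flattened)
               inference network
    """
    k = log_joint.size(0)
    return torch.logsumexp(log_joint - log_q, dim=0) - math.log(k)

# perplexity is then exp(-(sum of log-marginals) / (total number of words))
```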

For LM evaluation:
```
python train_lm.py --train_from lm.pt --test_file data/ptb-test.pkl --test 1
```

To evaluate F1, first parse the test set:
```
python parse.py --model_file urnng.pt --data_file data/ptb-test.txt --out_file pred-parse.txt \
  --gold_file gold-parse.txt --gpu 0
```
This outputs the predicted parse trees to `pred-parse.txt`. It also outputs a version of the gold
parses to `gold-parse.txt` to be used as input for `evalb`, since `parse.py` ignores sentences with
only trivial spans. Note that the corpus/sentence F1 printed here does not correspond to the results
reported in the paper, since punctuation is not ignored at this stage.

Finally, download/install `evalb`, available [here](https://nlp.cs.nyu.edu/evalb).
Then run:
```
evalb -p COLLINS.prm gold-parse.txt pred-parse.txt
```
where `COLLINS.prm` is the parameter file (provided in this repo) that tells `evalb` to ignore
punctuation and to evaluate unlabeled F1.
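Concretely, with `LABELED 0` and punctuation under `DELETE_LABEL`, `evalb` is computing unlabeled span F1 after punctuation is removed. A rough sketch of that computation (an illustration, not `evalb`'s actual code; trees are assumed to be nested `(label, children-or-word)` tuples):
```
DELETE = {"TOP", "-NONE-", ",", ":", "``", "''", "."}   # as in COLLINS.prm

def spans(tree, start=0, out=None):
    """Collect unlabeled (start, end) spans, skipping the labels listed in DELETE."""
    if out is None:
        out = set()
    label, children = tree
    if isinstance(children, str):                        # pre-terminal: a single word
        return (start, out) if label in DELETE else (start + 1, out)
    end = start
    for child in children:
        end, out = spans(child, end, out)
    if label not in DELETE and end > start:              # keep the bracket, ignore its label
        out.add((start, end))
    return end, out

def unlabeled_f1(pred_tree, gold_tree):
    pred = spans(pred_tree)[1]
    gold = spans(gold_tree)[1]
    tp = len(pred & gold)
    p = tp / max(len(pred), 1)
    r = tp / max(len(gold), 1)
    return 2 * p * r / max(p + r, 1e-8)
```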

## Note
Some of the preprocessing details are slightly different from the original paper. In particular,
in this implementation we replace singleton words with a single `<unk>` token instead of using the
Berkeley parser's mapping rules. This results in slightly lower perplexity for all models, since
the vocabulary size is smaller. Here are the results I get in this setting:

- RNNLM: 89.2
- RNNG: 83.7
- URNNG: 85.1
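The singleton replacement described above amounts to something like the following (a sketch; the actual `preprocess.py` implementation may differ):
```
from collections import Counter

def replace_singletons(sentences, unk="<unk>"):
    """Replace words that occur exactly once in the training data with a single <unk> token."""
    counts = Counter(w for sent in sentences for w in sent)
    return [[w if counts[w] > 1 else unk for w in sent] for sent in sentences]
```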

## Acknowledgements
Some of our preprocessing and evaluation code is based on the following repositories:
- [Recurrent Neural Network Grammars](https://github.com/clab/rnng)
- [Parsing Reading Predict Network](https://github.com/yikangshen/PRPN)

## License
MIT
@@ -0,0 +1,203 @@
#!/usr/bin/env python3

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import itertools
import utils
import random


class ConstituencyTreeCRF(nn.Module):
    def __init__(self):
        super(ConstituencyTreeCRF, self).__init__()
        self.huge = 1e9

    def logadd(self, x, y):
        # numerically stable log(exp(x) + exp(y))
        d = torch.max(x, y)
        return torch.log(torch.exp(x - d) + torch.exp(y - d)) + d

    def logsumexp(self, x, dim=1):
        # numerically stable log-sum-exp along `dim`
        d = torch.max(x, dim)[0]
        return torch.log(torch.exp(x - d.unsqueeze(dim).expand_as(x)).sum(dim)) + d

    def _init_table(self, scores):
        # initialize the dynamic programming table of inside scores
        batch_size = scores.size(0)
        n = scores.size(1)
        self.alpha = [[scores.new(batch_size).fill_(-self.huge) for _ in range(n)] for _ in range(n)]

    def _forward(self, scores):
        # inside step: alpha[s][t] = logsumexp_u (alpha[s][u] + alpha[u+1][t] + scores[:, s, t])
        batch_size = scores.size(0)
        n = scores.size(1)
        self._init_table(scores)
        for i in range(n):
            self.alpha[i][i] = scores[:, i, i]
        for k in np.arange(1, n + 1):
            for s in range(n):
                t = s + k
                if t > n - 1:
                    break
                tmp = [self.alpha[s][u] + self.alpha[u + 1][t] + scores[:, s, t] for u in np.arange(s, t)]
                tmp = torch.stack(tmp, 1)
                self.alpha[s][t] = self.logsumexp(tmp, 1)

    def _backward(self, scores):
        # outside step: beta[s][t] accumulates (in log space) all ways the span (s, t)
        # can occur as the left or right child of a larger span
        batch_size = scores.size(0)
        n = scores.size(1)
        self.beta = [[None for _ in range(n)] for _ in range(n)]
        self.beta[0][n - 1] = scores.new(batch_size).fill_(0)
        for k in np.arange(n - 1, 0, -1):
            for s in range(n):
                t = s + k
                if t > n - 1:
                    break
                for u in np.arange(s, t):
                    if s < u + 1:
                        tmp = self.beta[s][t] + self.alpha[u + 1][t] + scores[:, s, t]
                        if self.beta[s][u] is None:
                            self.beta[s][u] = tmp
                        else:
                            self.beta[s][u] = self.logadd(self.beta[s][u], tmp)
                    if u + 1 < t + 1:
                        tmp = self.beta[s][t] + self.alpha[s][u] + scores[:, s, t]
                        if self.beta[u + 1][t] is None:
                            self.beta[u + 1][t] = tmp
                        else:
                            self.beta[u + 1][t] = self.logadd(self.beta[u + 1][t], tmp)

    def _marginal(self, scores):
        # span log-marginals: inside + outside - log partition function
        batch_size = scores.size(0)
        n = scores.size(1)
        self.log_marginal = [[None for _ in range(n)] for _ in range(n)]
        log_Z = self.alpha[0][n - 1]
        for s in range(n):
            for t in np.arange(s, n):
                self.log_marginal[s][t] = self.alpha[s][t] + self.beta[s][t] - log_Z

    def _entropy(self, scores):
        # entropy of the tree distribution, computed bottom-up with the same
        # recursion structure as the inside pass (requires _forward to have been run)
        batch_size = scores.size(0)
        n = scores.size(1)
        self.entropy = [[None for _ in range(n)] for _ in range(n)]
        for i in range(n):
            self.entropy[i][i] = scores.new(batch_size).fill_(0)
        for k in np.arange(1, n + 1):
            for s in range(n):
                t = s + k
                if t > n - 1:
                    break
                score = []
                prev_ent = []
                for u in np.arange(s, t):
                    score.append(self.alpha[s][u] + self.alpha[u + 1][t])
                    prev_ent.append(self.entropy[s][u] + self.entropy[u + 1][t])
                score = torch.stack(score, 1)
                prev_ent = torch.stack(prev_ent, 1)
                log_prob = F.log_softmax(score, dim=1)
                prob = log_prob.exp()
                entropy = ((prev_ent - log_prob) * prob).sum(1)
                self.entropy[s][t] = entropy

    def _sample(self, scores, alpha=None, argmax=False):
        # top-down ancestral sampling of a binary tree from p(tree | sent)
        # (or the argmax split at each step); also returns the spans and a bracketed string
        if alpha is None:
            self._forward(scores)
            alpha = self.alpha
        batch_size = scores.size(0)
        n = scores.size(1)
        tree = scores.new(batch_size, n, n).zero_()
        all_log_probs = []
        tree_brackets = []
        spans = []
        for b in range(batch_size):
            sampled = [(0, n - 1)]
            span = [(0, n - 1)]
            queue = [(0, n - 1)]  # (start, end)
            log_probs = []
            tree_str = get_span_str(0, n - 1)
            while len(queue) > 0:
                node = queue.pop(0)
                start, end = node
                left_parent = get_span_str(start, None)
                right_parent = get_span_str(None, end)
                score = []
                score_idx = []
                for u in np.arange(start, end):
                    score.append(alpha[start][u][b] + alpha[u + 1][end][b])
                    score_idx.append([(start, u), (u + 1, end)])
                score = torch.stack(score, 0)
                log_prob = F.log_softmax(score, dim=0)
                if argmax:
                    sample = torch.max(log_prob, 0)[1]
                else:
                    prob = log_prob.exp()
                    sample = torch.multinomial(prob, 1)
                sample_idx = score_idx[sample.item()]
                log_probs.append(log_prob[sample.item()])
                for idx in sample_idx:
                    if idx[0] != idx[1]:
                        queue.append(idx)
                        span.append(idx)
                    sampled.append(idx)
                left_child = '(' + get_span_str(sample_idx[0][0], sample_idx[0][1])
                right_child = get_span_str(sample_idx[1][0], sample_idx[1][1]) + ')'
                if sample_idx[0][0] != sample_idx[0][1]:
                    tree_str = tree_str.replace(left_parent, left_child)
                if sample_idx[1][0] != sample_idx[1][1]:
                    tree_str = tree_str.replace(right_parent, right_child)
            all_log_probs.append(torch.stack(log_probs, 0).sum(0))
            tree_brackets.append(tree_str)
            spans.append(span[::-1])
            for idx in sampled:
                tree[b][idx[0]][idx[1]] = 1

        all_log_probs = torch.stack(all_log_probs, 0)
        return tree, all_log_probs, tree_brackets, spans

    def _viterbi(self, scores):
        # CKY algorithm: best-scoring binary tree per batch element
        batch_size = scores.size(0)
        n = scores.size(1)
        self.max_scores = scores.new(batch_size, n, n).fill_(-self.huge)
        self.bp = scores.new(batch_size, n, n).zero_()
        self.argmax = scores.new(batch_size, n, n).zero_()
        self.spans = [[] for _ in range(batch_size)]
        for i in range(n):
            self.max_scores[:, i, i] = scores[:, i, i]
        for k in np.arange(1, n):
            for s in np.arange(n):
                t = s + k
                if t > n - 1:
                    break
                for u in np.arange(s, t):
                    tmp = self.max_scores[:, s, u] + self.max_scores[:, u + 1, t] + scores[:, s, t]
                    self.bp[:, s, t][self.max_scores[:, s, t] < tmp] = int(u)
                    self.max_scores[:, s, t] = torch.max(self.max_scores[:, s, t], tmp)
        for b in range(batch_size):
            self._backtrack(b, 0, n - 1)
        return self.max_scores[:, 0, n - 1], self.argmax, self.spans

    def _backtrack(self, b, s, t):
        # recover the argmax tree from the backpointers
        u = int(self.bp[b][s][t])
        self.argmax[b][s][t] = 1
        if s == t:
            return None
        else:
            self.spans[b].insert(0, (s, t))
            self._backtrack(b, s, u)
            self._backtrack(b, u + 1, t)
        return None


def get_span_str(start=None, end=None):
    # placeholder span strings used by _sample when building the bracketed tree
    assert(start is not None or end is not None)
    if start is None:
        return ' ' + str(end) + ')'
    elif end is None:
        return '(' + str(start) + ' '
    else:
        return ' (' + str(start) + ' ' + str(end) + ') '
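A minimal usage sketch for the class above (assuming it is importable from this module, whose file name is not shown in this diff; `scores` is taken to be a `(batch, n, n)` tensor of span log-potentials, which matches how the methods index it):
```
import torch

crf = ConstituencyTreeCRF()
scores = torch.randn(2, 5, 5)                 # (batch, n, n) span log-potentials

crf._forward(scores)                          # inside pass: fills crf.alpha
crf._backward(scores)                         # outside pass: fills crf.beta
crf._marginal(scores)                         # span log-marginals in crf.log_marginal
log_Z = crf.alpha[0][4]                       # log partition function, one value per batch element

best_score, argmax_trees, spans = crf._viterbi(scores)    # CKY/Viterbi parse
tree, log_probs, brackets, sampled = crf._sample(scores)  # sample trees from the CRF
```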