-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathnoise.py
136 lines (110 loc) · 4.39 KB
/
noise.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import numpy as np
import re
import shutil
from tqdm import tqdm
from fairseq.tokenizer import tokenize_line
class NoiseInjector(object):
    """Corrupt a token sequence with random add / shuffle / replace / delete noise.

    Each call to :meth:`inject_noise` applies the four noise operations in a
    random order. It returns the corrupted tokens together with an alignment
    from surviving tokens back to their original positions.

    Args:
        corpus: sequence of token sequences. Words for replacement and
            insertion are sampled from it. It is kept by reference and never
            modified.
        shuffle_sigma: scale of the Gaussian jitter added to each position
            index before re-sorting. Values below 1e-6 disable shuffling.
        replace_mean, replace_std: mean / std of the Beta distribution from
            which the per-call replacement probability is drawn.
        delete_mean, delete_std: same, for the per-call deletion probability.
        add_mean, add_std: same, for the per-call insertion probability.
    """

    def __init__(self, corpus, shuffle_sigma=0.5,
                 replace_mean=0.1, replace_std=0.03,
                 delete_mean=0.1, delete_std=0.03,
                 add_mean=0.1, add_std=0.03):
        # READ-ONLY, do not modify
        self.corpus = corpus
        # Blank lines tokenize to empty sequences. Drawing a word from one would
        # call randint(0) and crash, so _random_word skips them. Remember up
        # front whether any usable word exists so that skipping cannot spin
        # forever on a corpus of blank lines.
        self._has_words = any(len(ex) > 0 for ex in corpus)
        self.shuffle_sigma = shuffle_sigma
        self.replace_a, self.replace_b = self._solve_ab_given_mean_var(replace_mean, replace_std**2)
        self.delete_a, self.delete_b = self._solve_ab_given_mean_var(delete_mean, delete_std**2)
        self.add_a, self.add_b = self._solve_ab_given_mean_var(add_mean, add_std**2)

    @staticmethod
    def _solve_ab_given_mean_var(mean, var):
        """Return Beta shape parameters (a, b) with the given mean and variance.

        Method-of-moments solution. Both values are positive (as
        ``np.random.beta`` requires) only when ``0 < mean < 1`` and
        ``0 < var < mean * (1 - mean)``.
        """
        a = mean * mean * (1. - mean) / var - mean
        b = (1. - mean) * (mean * (1. - mean) / var - 1.)
        return a, b

    def _random_word(self):
        """Return a random word from a random non-empty corpus line, or None.

        A line is drawn uniformly, then a word uniformly within it. Empty lines
        are rejected and redrawn. When the corpus has no blank lines this
        consumes exactly the same random numbers as a single draw. Returns
        None if the corpus contains no words at all.
        """
        if not self._has_words:
            return None
        corpus = self.corpus
        while True:
            ex = corpus[np.random.randint(len(corpus))]
            if len(ex) > 0:
                return ex[np.random.randint(len(ex))]

    def _shuffle_func(self, tgt):
        """Locally shuffle *tgt* by sorting positions jittered with N(0, sigma)."""
        if self.shuffle_sigma < 1e-6:
            return tgt
        shuffle_key = [i + np.random.normal(loc=0, scale=self.shuffle_sigma) for i in range(len(tgt))]
        new_idx = np.argsort(shuffle_key)
        return [tgt[i] for i in new_idx]

    def _replace_func(self, tgt):
        """Replace each pair, with a Beta-drawn probability, by a random corpus word.

        Replaced entries are marked with source index -1 so they produce no
        alignment. If the corpus has no words, the original pair is kept.
        """
        replace_ratio = np.random.beta(self.replace_a, self.replace_b)
        ret = []
        rnd = np.random.random(len(tgt))
        for p, r in zip(tgt, rnd):
            if r < replace_ratio:
                rnd_word = self._random_word()
                if rnd_word is not None:
                    ret.append((-1, rnd_word))
                    continue
            ret.append(p)
        return ret

    def _delete_func(self, tgt):
        """Drop each pair with a per-call probability drawn from the delete Beta."""
        delete_ratio = np.random.beta(self.delete_a, self.delete_b)
        rnd = np.random.random(len(tgt))
        return [p for p, r in zip(tgt, rnd) if r >= delete_ratio]

    def _add_func(self, tgt):
        """Insert, with a Beta-drawn probability, a random corpus word before each pair.

        Inserted entries carry source index -1. Nothing is inserted if the
        corpus has no words.
        """
        add_ratio = np.random.beta(self.add_a, self.add_b)
        ret = []
        rnd = np.random.random(len(tgt))
        for p, r in zip(tgt, rnd):
            if r < add_ratio:
                rnd_word = self._random_word()
                if rnd_word is not None:
                    ret.append((-1, rnd_word))
            ret.append(p)
        return ret

    def _parse(self, pairs):
        """Split (orig_idx, word) pairs into words and 'noisy_idx-orig_idx' alignments.

        Pairs with a negative original index (inserted / replaced words) get no
        alignment entry.
        """
        art = [w for _, w in pairs]
        align = ['{}-{}'.format(si, ti) for si, (ti, _) in enumerate(pairs) if ti >= 0]
        return art, align

    def inject_noise(self, tokens):
        """Return ``(noisy_tokens, alignment)`` for the token sequence *tokens*.

        *tokens* is the sequence to corrupt; the script passes
        ``tokenize_line`` output (presumably a list of str). ``alignment`` is a
        list of ``'i-j'`` strings meaning ``noisy_tokens[i] == tokens[j]``.
        """
        funcs = [self._add_func, self._shuffle_func, self._replace_func, self._delete_func]
        # Apply the noise operations in a different random order each call.
        np.random.shuffle(funcs)
        # Track each word's original position through the transformations.
        pairs = list(enumerate(tokens))
        for f in funcs:
            pairs = f(pairs)
        return self._parse(pairs)
def save_file(filename, contents):
    """Write each item of *contents* to *filename* as one space-joined line.

    Args:
        filename: output path; the file is created or truncated.
        contents: iterable of token sequences (each an iterable of str).
    """
    # Write UTF-8 explicitly. The corpus is read as UTF-8, so relying on the
    # platform default encoding can raise on non-ASCII tokens.
    with open(filename, 'w', encoding='utf-8') as ofile:
        ofile.writelines(' '.join(content) + '\n' for content in contents)
def noise(filename, ofile_suffix):
    """Build one noised copy of the corpus in *filename* and write three parallel files.

    Each line is read as UTF-8 (undecodable bytes are dropped) and tokenized
    with ``tokenize_line``. Every tokenized line is then corrupted by a
    NoiseInjector built over the whole corpus. The following files are written:

    * ``<ofile_suffix>.src``     -- the noisy token sequences,
    * ``<ofile_suffix>.tgt``     -- the original token sequences,
    * ``<ofile_suffix>.forward`` -- the ``'src_idx-tgt_idx'`` alignments.
    """
    # Close the input file deterministically. errors='ignore' deliberately
    # skips undecodable bytes rather than aborting the whole run.
    with open(filename, encoding='utf-8', errors='ignore') as ifile:
        tgts = [tokenize_line(line.strip()) for line in ifile]
    noise_injector = NoiseInjector(tgts)
    srcs = []
    aligns = []
    for tgt in tqdm(tgts):
        src, align = noise_injector.inject_noise(tgt)
        srcs.append(src)
        aligns.append(align)
    save_file('{}.src'.format(ofile_suffix), srcs)
    save_file('{}.tgt'.format(ofile_suffix), tgts)
    save_file('{}.forward'.format(ofile_suffix), aligns)
import argparse
parser=argparse.ArgumentParser()
parser.add_argument('-c', '--corpus', type=str, default='data/train_1b.tgt')
parser.add_argument('-o', '--output-dir', type=str, default='data_art')
parser.add_argument('-e', '--epoch', type=int, default=10)
parser.add_argument('-s', '--seed', type=int, default=2468)
args = parser.parse_args()
np.random.seed(args.seed)
if __name__ == '__main__':
print("epoch={}, seed={}".format(args.epoch, args.seed))
filename = args.corpus.split('/')[-1].split('.')[0]
ofile_suffix = f'{args.output_dir}/{filename}_{args.epoch}'
noise(args.corpus, ofile_suffix)