#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (c) 2019-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import random

import numpy as np
import ot

from utils import *
parser = argparse.ArgumentParser(description='Unsupervised alignment of multiple sets of word embeddings to a common space')
parser.add_argument('--embdir', default='data/', type=str)
parser.add_argument('--outdir', default='output/', type=str)
parser.add_argument('--lglist', default='en-fr-es-it-pt-de-pl-ru-da-nl-cs', type=str,
                    help='list of languages. The first element is the pivot. Example: en-fr-es to align English, French and Spanish, with English as the pivot.')
parser.add_argument('--maxload', default=20000, type=int, help='maximum number of loaded vectors')
parser.add_argument('--uniform', action='store_true', help='switch to uniform probability of picking language pairs')
# optimization parameters for the square loss
parser.add_argument('--epoch', default=2, type=int, help='number of epochs for the square loss')
parser.add_argument('--niter', default=500, type=int, help='maximum number of iterations per epoch for the square loss')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate for the square loss')
parser.add_argument('--bsz', default=500, type=int, help='batch size for the square loss')
# optimization parameters for the RCSLS loss
parser.add_argument('--altepoch', default=100, type=int, help='number of epochs for the RCSLS loss')
parser.add_argument('--altlr', default=25, type=float, help='learning rate for the RCSLS loss')
parser.add_argument('--altbsz', default=1000, type=int, help='batch size for the RCSLS loss')
args = parser.parse_args()
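
# Example invocation (a sketch; assumes fastText vectors named wiki.<lg>.vec,
# e.g. wiki.en.vec and wiki.fr.vec, are present in --embdir):
#
#   python unsup_multialign.py --embdir data/ --outdir output/ --lglist en-fr-es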

###### SPECIFIC FUNCTIONS ######

def getknn(sc, x, y, k=10):
    # mean similarity of the k nearest neighbors in the score matrix sc,
    # together with its gradient with respect to the mapped vectors
    sidx = np.argpartition(sc, -k, axis=1)[:, -k:]
    ytopk = y[sidx.flatten(), :]
    ytopk = ytopk.reshape(sidx.shape[0], sidx.shape[1], y.shape[1])
    f = np.sum(sc[np.arange(sc.shape[0])[:, None], sidx])
    df = np.dot(ytopk.sum(1).T, x)
    return f / k, df / k

def rcsls(Xi, Xj, Zi, Zj, R, knn=10):
    # RCSLS loss and gradient for a batch (Xi, Xj), with Zi and Zj as the
    # reference sets used for the nearest-neighbor penalty terms
    X_trans = np.dot(Xi, R.T)
    f = 2 * np.sum(X_trans * Xj)
    df = 2 * np.dot(Xj.T, Xi)
    fk0, dfk0 = getknn(np.dot(X_trans, Zj.T), Xi, Zj, knn)
    fk1, dfk1 = getknn(np.dot(np.dot(Zi, R.T), Xj.T).T, Xj, Zi, knn)
    f = f - fk0 - fk1
    df = df - dfk0 - dfk1.T
    return -f / Xi.shape[0], -df.T / Xi.shape[0]

def GWmatrix(emb0):
    # intra-language cost matrix used by Gromov-Wasserstein
    N = np.shape(emb0)[0]
    N2 = .5 * np.linalg.norm(emb0, axis=1).reshape(1, N)
    C2 = np.tile(N2.transpose(), (1, N)) + np.tile(N2, (N, 1))
    C2 -= np.dot(emb0, emb0.T)
    return C2

def gromov_wasserstein(x_src, x_tgt, C2):
    # initial mapping: entropic Gromov-Wasserstein coupling, then Procrustes
    N = x_src.shape[0]
    C1 = GWmatrix(x_src)
    # the epsilon keyword belongs to POT's entropic solver
    M = ot.gromov.entropic_gromov_wasserstein(C1, C2, np.ones(N), np.ones(N),
                                              'square_loss', epsilon=0.55,
                                              max_iter=100, tol=1e-4)
    return procrustes(np.dot(M, x_tgt), x_src)

def align(EMB, TRANS, lglist, args):
    nmax, l = args.maxload, len(lglist)

    # create a list of language pairs to sample from
    # (default: higher probability of picking a pair containing the pivot;
    # with --uniform: uniform probability over all pairs)
    samples = []
    for i in range(l):
        for j in range(l):
            if j == i:
                continue
            if j > 0 and not args.uniform:
                samples.append((0, j))
            if i > 0 and not args.uniform:
                samples.append((i, 0))
            samples.append((i, j))

    # optimization of the L2 loss
    print('start optimizing L2 loss')
    lr0, bsz, nepoch, niter = args.lr, args.bsz, args.epoch, args.niter
    for epoch in range(nepoch):
        print("start epoch %d / %d" % (epoch + 1, nepoch))
        ones = np.ones(bsz)
        f, fold, nb, lr = 0.0, 0.0, 0.0, lr0
        for it in range(niter):
            if it > 1 and f > fold + 1e-3:
                lr /= 2
            if lr < .05:
                break
            fold = f
            f, nb = 0.0, 0.0
            for k in range(100 * (l - 1)):
                (i, j) = random.choice(samples)
                embi = EMB[i][np.random.permutation(nmax)[:bsz], :]
                embj = EMB[j][np.random.permutation(nmax)[:bsz], :]
                # entropy-regularized OT plan between the two mapped batches
                perm = ot.sinkhorn(ones, ones,
                                   np.linalg.multi_dot([embi, -TRANS[i], TRANS[j].T, embj.T]),
                                   reg=0.025, stopThr=1e-3)
                grad = np.linalg.multi_dot([embi.T, perm, embj])
                f -= np.trace(np.linalg.multi_dot([TRANS[i].T, grad, TRANS[j]])) / embi.shape[0]
                nb += 1
                # gradient step, projected back onto the orthogonal group
                if i > 0:
                    TRANS[i] = proj_ortho(TRANS[i] + lr * np.dot(grad, TRANS[j]))
                if j > 0:
                    TRANS[j] = proj_ortho(TRANS[j] + lr * np.dot(grad.transpose(), TRANS[i]))
            print("iter %d / %d - epoch %d - loss: %.5f lr: %.4f" % (it, niter, epoch + 1, f / nb, lr))
        print("end of epoch %d - loss: %.5f - lr: %.4f" % (epoch + 1, f / max(nb, 1), lr))
        # fewer iterations and larger batches as training progresses
        niter, bsz = max(int(niter / 2), 2), min(1000, bsz * 2)

    # optimization of the RCSLS loss
    print('start optimizing RCSLS loss')
    f, fold, nb, lr = 0.0, 0.0, 0.0, args.altlr
    for epoch in range(args.altepoch):
        if epoch > 1 and f - fold > -1e-4 * abs(fold):
            lr /= 2
        if lr < 1e-1:
            break
        fold = f
        f, nb = 0.0, 0.0
        for k in range(round(nmax / args.altbsz) * 10 * (l - 1)):
            (i, j) = random.choice(samples)
            sgdidx = np.random.choice(nmax, size=args.altbsz, replace=False)
            embi = EMB[i][sgdidx, :]
            embj = EMB[j][:nmax, :]
            # crude alignment approximation: greedy nearest-neighbor matching
            T = np.dot(TRANS[i], TRANS[j].T)
            scores = np.linalg.multi_dot([embi, T, embj.T])
            perm = np.zeros_like(scores)
            perm[np.arange(len(scores)), scores.argmax(1)] = 1
            embj = np.dot(perm, embj)
            # normalization over a subset of embeddings for speed up
            fi, grad = rcsls(embi, embj, embi, embj, T.T)
            f += fi
            nb += 1
            if i > 0:
                TRANS[i] = proj_ortho(TRANS[i] - lr * np.dot(grad, TRANS[j]))
            if j > 0:
                TRANS[j] = proj_ortho(TRANS[j] - lr * np.dot(grad.transpose(), TRANS[i]))
        print("epoch %d - loss: %.5f - lr: %.4f" % (epoch + 1, f / max(nb, 1), lr))
    return TRANS

def convex_init(X, Y, niter=100, reg=0.05, apply_sqrt=False):
    # convex relaxation of the alignment over doubly-stochastic matrices,
    # refined into an orthogonal matrix with Procrustes
    n, d = X.shape
    if apply_sqrt:
        # sqrt_eig is provided by utils
        X, Y = sqrt_eig(X), sqrt_eig(Y)
    K_X, K_Y = np.dot(X, X.T), np.dot(Y, Y.T)
    K_Y *= np.linalg.norm(K_X) / np.linalg.norm(K_Y)
    K2_X, K2_Y = np.dot(K_X, K_X), np.dot(K_Y, K_Y)
    P = np.ones([n, n]) / float(n)
    for it in range(1, niter + 1):
        G = np.dot(P, K2_X) + np.dot(K2_Y, P) - 2 * np.dot(K_Y, np.dot(P, K_X))
        q = ot.sinkhorn(np.ones(n), np.ones(n), G, reg, stopThr=1e-3)
        alpha = 2.0 / float(2.0 + it)
        P = alpha * q + (1.0 - alpha) * P
    return procrustes(np.dot(P, X), Y).T
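
# Note: convex_init is not called below; it is an alternative to the
# Gromov-Wasserstein initialization. A hypothetical seeding of one mapping
# (names follow the main section below; check the orientation returned by
# procrustes in utils before relying on this):
#
#   TRANS[i] = convex_init(EMB[i][:2000, :], EMB[0][:2000, :], apply_sqrt=True)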

###### MAIN ######

lglist = args.lglist.split('-')
l = len(lglist)

# load embeddings
EMB = {}
for i in range(l):
    fn = args.embdir + '/wiki.' + lglist[i] + '.vec'
    _, vecs = load_vectors(fn, maxload=args.maxload)
    EMB[i] = vecs

# initialization: Gromov-Wasserstein on a subset of each vocabulary
print("Computing initial bilingual mapping with Gromov-Wasserstein...")
TRANS = {}
maxinit = 2000
emb0 = EMB[0][:maxinit, :]
C0 = GWmatrix(emb0)
TRANS[0] = np.eye(300)  # the pivot is mapped by the identity (300-dim vectors)
for i in range(1, l):
    print("init " + lglist[i])
    embi = EMB[i][:maxinit, :]
    TRANS[i] = gromov_wasserstein(embi, emb0, C0)

# align
align(EMB, TRANS, lglist, args)

# save
print('saving matrices in ' + args.outdir)
languages = ''.join(lglist)
for i in range(l):
    save_matrix(args.outdir + '/W-' + languages + '-' + lglist[i], TRANS[i])
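
# Using the saved matrices (a sketch; save_matrix's on-disk format is defined
# in utils, so load with a matching reader). In align(), languages i and j are
# compared via embi . TRANS[i] . TRANS[j].T, so each matrix maps its language
# into the shared pivot space:
#
#   W = ...                   # matrix saved above for language i
#   x_aligned = np.dot(x, W)  # row-vectors x of language i, in the common space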