Skip to content

Commit

Permalink
fix implementation, Q and R is shared across iterations
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jul 5, 2020
1 parent 633afff commit 10fad38
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
12 changes: 9 additions & 3 deletions mogrifier/mogrifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ class Mogrifier(nn.Module):
def __init__(self, dim, iters = 5, factorize_k = None):
super().__init__()
self.dim = dim
self.weights = nn.ModuleList([weight(dim, dim, factorize_k) for _ in range(iters)])
self.iters = iters

self.Q = weight(dim, dim, factorize_k)
self.R = weight(dim, dim, factorize_k) if iters > 1 else None

def forward(self, x, h):
shape = x.shape
Expand All @@ -25,8 +28,11 @@ def forward(self, x, h):

x, h = map(lambda t: t.reshape(-1, dim), (x, h))

for ind, W in enumerate(self.weights):
if (ind % 2) == 0:
for ind in range(1, self.iters + 1):
is_odd = (ind % 2) == 1
W = self.Q if is_odd else self.R

if is_odd:
x = 2 * W(h).sigmoid() * x
else:
h = 2 * W(x).sigmoid() * h
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'mogrifier',
packages = find_packages(),
version = '0.0.2',
version = '0.0.3',
license='MIT',
description = 'Implementation of Mogrifier circuit from Deepmind',
author = 'Phil Wang',
Expand Down

0 comments on commit 10fad38

Please sign in to comment.