From 10fad3828a6ecb999f2d7e256e29c3db1bb0d8fd Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sun, 5 Jul 2020 15:59:38 -0700 Subject: [PATCH] fix implementation, Q and R is shared across iterations --- mogrifier/mogrifier.py | 12 +++++++++--- setup.py | 2 +- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/mogrifier/mogrifier.py b/mogrifier/mogrifier.py index c00068c..8723f49 100644 --- a/mogrifier/mogrifier.py +++ b/mogrifier/mogrifier.py @@ -16,7 +16,10 @@ class Mogrifier(nn.Module): def __init__(self, dim, iters = 5, factorize_k = None): super().__init__() self.dim = dim - self.weights = nn.ModuleList([weight(dim, dim, factorize_k) for _ in range(iters)]) + self.iters = iters + + self.Q = weight(dim, dim, factorize_k) + self.R = weight(dim, dim, factorize_k) if iters > 1 else None def forward(self, x, h): shape = x.shape @@ -25,8 +28,11 @@ def forward(self, x, h): x, h = map(lambda t: t.reshape(-1, dim), (x, h)) - for ind, W in enumerate(self.weights): - if (ind % 2) == 0: + for ind in range(1, self.iters + 1): + is_odd = (ind % 2) == 1 + W = self.Q if is_odd else self.R + + if is_odd: x = 2 * W(h).sigmoid() * x else: h = 2 * W(x).sigmoid() * h diff --git a/setup.py b/setup.py index 73f6499..abf2b74 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'mogrifier', packages = find_packages(), - version = '0.0.2', + version = '0.0.3', license='MIT', description = 'Implementation of Mogrifier circuit from Deepmind', author = 'Phil Wang',