From 1adb2eafbb554cc694a2ee8b4600b8894cb982fa Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Sun, 7 Aug 2022 12:31:34 -0700
Subject: [PATCH] softmax is already stable, also start enforcing it to be
 always float32

---
 setup.py                         | 2 +-
 x_transformers/x_transformers.py | 6 +-----
 2 files changed, 2 insertions(+), 6 deletions(-)

diff --git a/setup.py b/setup.py
index f6ed4bde..eb89099f 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '0.32.2',
+  version = '0.32.3',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
diff --git a/x_transformers/x_transformers.py b/x_transformers/x_transformers.py
index 474909e7..11c63d76 100644
--- a/x_transformers/x_transformers.py
+++ b/x_transformers/x_transformers.py
@@ -66,10 +66,6 @@ def l2norm(t, groups = 1):
     t = F.normalize(t, p = 2, dim = -1)
     return rearrange(t, '... g d -> ... (g d)')
 
-def stable_softmax(t, dim = -1):
-    t = t - t.amax(dim = dim, keepdim = True).detach()
-    return F.softmax(t, dim = dim)
-
 # init helpers
 
 def init_zero_(layer):
@@ -573,7 +569,7 @@ def __init__(
         self.sparse_topk = sparse_topk
 
         # entmax
-        self.attn_fn = entmax15 if use_entmax15 else stable_softmax
+        self.attn_fn = entmax15 if use_entmax15 else partial(F.softmax, dtype = torch.float32)
 
         # add memory key / values
         self.num_mem_kv = num_mem_kv
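
Note (not part of the patch): a minimal sketch of why the removed stable_softmax helper is redundant. PyTorch's F.softmax already subtracts the per-row max internally, and passing dtype = torch.float32 upcasts the input before the softmax is computed, so the new partial(F.softmax, dtype = torch.float32) matches the old helper while forcing the computation into float32. The tensor shapes and tolerance below are illustrative assumptions, not taken from the repo.

    import torch
    import torch.nn.functional as F
    from functools import partial

    # the helper removed by this patch
    def stable_softmax(t, dim = -1):
        t = t - t.amax(dim = dim, keepdim = True).detach()
        return F.softmax(t, dim = dim)

    # the replacement: plain softmax, always computed in float32
    attn_fn = partial(F.softmax, dtype = torch.float32)

    # large half-precision attention logits, the case the old helper guarded against
    logits = (torch.randn(2, 8, 128, 128) * 50).half()

    out_old = stable_softmax(logits.float())
    out_new = attn_fn(logits, dim = -1)

    print(out_new.dtype)                                  # torch.float32
    print(torch.allclose(out_old, out_new, atol = 1e-6))  # True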