From 4c3ec34def59abde2651689da4d6c358eec2f8c7 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Mon, 15 May 2023 13:41:54 -0700
Subject: [PATCH] fix alibi with flash attention

---
 MEGABYTE_pytorch/attend.py | 1 +
 setup.py                   | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/MEGABYTE_pytorch/attend.py b/MEGABYTE_pytorch/attend.py
index a464e2f..ddfebb8 100644
--- a/MEGABYTE_pytorch/attend.py
+++ b/MEGABYTE_pytorch/attend.py
@@ -101,6 +101,7 @@ def flash_attn(self, q, k, v, mask = None, attn_bias = None):
             if exists(mask):
                 attn_bias = attn_bias.masked_fill(~mask, mask_value)
 
+            mask = attn_bias
             causal = False
 
         # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
diff --git a/setup.py b/setup.py
index 6e49c6b..dd77e0f 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'MEGABYTE-pytorch',
   packages = find_packages(),
-  version = '0.0.4',
+  version = '0.0.5',
   license='MIT',
   description = 'MEGABYTE - Pytorch',
   long_description_content_type = 'text/markdown',
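
Context for the one-line fix above: PyTorch 2.0's F.scaled_dot_product_attention accepts a float attn_mask that is added to the attention scores, so an ALiBi bias can reach the flash attention kernel by being passed through that argument; the patch does this by reassigning `mask = attn_bias` before the call. Below is a minimal sketch of that pattern, not the library's exact code; the function name `flash_attn_with_alibi` and the tensor shapes are illustrative assumptions.

# Minimal sketch (illustrative, not the library's exact code): the ALiBi bias is
# handed to PyTorch 2.0 flash attention as the float attn_mask of
# F.scaled_dot_product_attention.

import torch
import torch.nn.functional as F

def flash_attn_with_alibi(q, k, v, attn_bias, mask = None):
    # q, k, v: (batch, heads, seq, dim_head); attn_bias: (heads, seq_q, seq_k) ALiBi slopes
    if mask is not None:
        # fold a (batch, seq_k) padding mask (True = keep) into the float bias,
        # mirroring the masked_fill in the patched code path
        mask_value = -torch.finfo(q.dtype).max / 2   # headroom so scores + bias stay finite
        attn_bias = attn_bias.unsqueeze(0).expand(q.shape[0], -1, -1, -1)
        attn_bias = attn_bias.masked_fill(~mask[:, None, None, :], mask_value)

    # a float attn_mask acts as an additive bias inside scaled_dot_product_attention,
    # so the ALiBi slopes (and any padding mask folded into them) are applied there;
    # is_causal stays False because the bias / mask already encodes what is attendable
    return F.scaled_dot_product_attention(q, k, v, attn_mask = attn_bias, is_causal = False)

# usage sketch
b, h, n, d = 2, 8, 128, 64
q, k, v = (torch.randn(b, h, n, d) for _ in range(3))
alibi_bias = torch.randn(h, n, n)                 # stand-in for real ALiBi slopes
padding_mask = torch.ones(b, n, dtype = torch.bool)
out = flash_attn_with_alibi(q, k, v, alibi_bias, mask = padding_mask)
assert out.shape == (b, h, n, d)

Note that when no bias is supplied, leaving is_causal = True lets scaled_dot_product_attention use its built-in causal fast path, which is why the patched code only sets causal = False on the bias branch.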