From 4c98d0b41f38ee638a979064856ae06fc1aec8b6 Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Wed, 26 Jul 2023 09:39:37 -1000
Subject: [PATCH] [MLP] Edit ParallelGatedMlp

---
 flash_attn/modules/mlp.py | 31 ++++++++++++++++++-------------
 1 file changed, 18 insertions(+), 13 deletions(-)

diff --git a/flash_attn/modules/mlp.py b/flash_attn/modules/mlp.py
index 0e74d19b6..d237253a7 100644
--- a/flash_attn/modules/mlp.py
+++ b/flash_attn/modules/mlp.py
@@ -11,10 +11,9 @@
     ColumnParallelLinear, RowParallelLinear = None, None
 
 try:
-    from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP, ColumnParallelLinear, RowParallelLinear
+    from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP
 except ImportError:
     FusedMLP, ParallelFusedMLP = None, None
-    ColumnParallelLinear, RowParallelLinear = None, None
 
 
 class Mlp(nn.Module):
@@ -87,25 +86,31 @@ def forward(self, x):
         return y if not self.return_residual else (y, x)
 
 
-class ParallelGatedMlp(GatedMlp):
+class ParallelGatedMlp(nn.Module):
     """ Parallel GatedMlp """
 
-    def __init__(self, in_features, process_group, hidden_features=None, out_features=None, activation=F.sigmoid,
-                 bias1=True, bias2=True, multiple_of=256, return_residual=False,
+    def __init__(self, in_features, process_group, hidden_features=None, out_features=None,
+                 activation=F.sigmoid, bias1=True, bias2=True, multiple_of=256,
                  sequence_parallel=True, device=None, dtype=None):
         factory_kwargs = {'device': device, 'dtype': dtype}
-        super().__init__(in_features, hidden_features=hidden_features, out_features=out_features, activation=activation,
-                         bias1=bias1, bias2=bias2, multiple_of=multiple_of, return_residual=return_residual,
-                         device=device, dtype=dtype)
+        super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or int(8 * in_features / 3)
         hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
-
         if ColumnParallelLinear is None or RowParallelLinear is None:
             raise ImportError('fused_dense is not installed')
-        self.fc1 = ColumnParallelLinear(in_features, 2 * hidden_features, process_group,
-                                        bias=bias1,
+        self.fc1 = ColumnParallelLinear(in_features, 2 * hidden_features, process_group, bias=bias1,
                                         sequence_parallel=sequence_parallel, **factory_kwargs)
-        self.fc2 = RowParallelLinear(hidden_features, out_features, process_group,
-                                     bias=bias2,
+        self.activation = activation
+        self.fc2 = RowParallelLinear(hidden_features, out_features, process_group, bias=bias2,
                                      sequence_parallel=sequence_parallel, **factory_kwargs)
+
+    def forward(self, x):
+        y = self.fc1(x)
+        if self.activation == F.sigmoid:  # Special case for GLU
+            y = F.glu(y, dim=-1)
+        else:
+            y, gate = y.chunk(2, dim=-1)
+            y = y * self.activation(gate)
+        y = self.fc2(y)
+        return y
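
For reference only (not part of the patch): a minimal usage sketch of the edited ParallelGatedMlp. The single-rank process group, model dimensions, dtype, and rendezvous address below are illustrative assumptions, not taken from this commit; a real tensor-parallel run would launch one process per GPU (e.g. via torchrun) and pass the corresponding group. It assumes the flash_attn fused_dense CUDA extension is installed, since ParallelGatedMlp raises ImportError otherwise.

import torch
import torch.distributed as dist
import torch.nn.functional as F

from flash_attn.modules.mlp import ParallelGatedMlp

# Illustrative single-rank setup (world_size=1); tensor parallelism only
# shards the weights when the group has more than one rank.
dist.init_process_group(backend="nccl", init_method="tcp://127.0.0.1:29500",
                        rank=0, world_size=1)
process_group = dist.new_group(ranks=[0])

mlp = ParallelGatedMlp(
    in_features=1024,            # assumed model dim, not from the patch
    process_group=process_group,
    activation=F.sigmoid,        # default; takes the fused F.glu path in forward
    sequence_parallel=False,     # keep the full sequence on this rank
    device="cuda",
    dtype=torch.float16,
)

x = torch.randn(2, 128, 1024, device="cuda", dtype=torch.float16)
y = mlp(x)  # output keeps the input shape: (batch, seqlen, in_features)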