diff --git a/deepmd/pt/model/network/mlp.py b/deepmd/pt/model/network/mlp.py index 31162fe80e..22675d6163 100644 --- a/deepmd/pt/model/network/mlp.py +++ b/deepmd/pt/model/network/mlp.py @@ -8,6 +8,7 @@ import numpy as np import torch import torch.nn as nn +import torch.nn.functional as F from deepmd.pt.utils import ( env, @@ -202,18 +203,14 @@ def forward( ori_prec = xx.dtype if not env.DP_DTYPE_PROMOTION_STRICT: xx = xx.to(self.prec) - yy = ( - torch.matmul(xx, self.matrix) + self.bias - if self.bias is not None - else torch.matmul(xx, self.matrix) - ) - yy = self.activate(yy).clone() + yy = F.linear(xx, self.matrix.t(), self.bias) + yy = self.activate(yy) yy = yy * self.idt if self.idt is not None else yy if self.resnet: if xx.shape[-1] == yy.shape[-1]: - yy += xx + yy = yy + xx elif 2 * xx.shape[-1] == yy.shape[-1]: - yy += torch.concat([xx, xx], dim=-1) + yy = yy + torch.concat([xx, xx], dim=-1) else: yy = yy if not env.DP_DTYPE_PROMOTION_STRICT: