diff --git a/deepmd/pt/utils/utils.py b/deepmd/pt/utils/utils.py
index 6ce4f5d6fc..de0f8bc86d 100644
--- a/deepmd/pt/utils/utils.py
+++ b/deepmd/pt/utils/utils.py
@@ -21,28 +21,31 @@ class ActivationFn(torch.nn.Module):
     def __init__(self, activation: Optional[str]) -> None:
         super().__init__()
-        self.activation: str = activation if activation is not None else "linear"
+        self.activation = self.get_activation_fn(activation)
+
+    def get_activation_fn(self, activation: Optional[str]):
+        activation = activation.lower() if activation else "none"
+        if activation == "linear" or activation == "none":
+            return lambda x: x
+        elif activation == "relu":
+            return F.relu
+        elif activation == "gelu" or activation == "gelu_tf":
+            return lambda x: F.gelu(x, approximate="tanh")
+        elif activation == "tanh":
+            return F.tanh
+        elif activation == "relu6":
+            return F.relu6
+        elif activation == "softplus":
+            return F.softplus
+        elif activation == "sigmoid":
+            return F.sigmoid
+        else:
+            raise RuntimeError(f"activation function {activation} not supported")
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Returns the tensor after applying activation function corresponding to `activation`."""
         # See jit supported types: https://pytorch.org/docs/stable/jit_language_reference.html#supported-type
-
-        if self.activation.lower() == "relu":
-            return F.relu(x)
-        elif self.activation.lower() == "gelu" or self.activation.lower() == "gelu_tf":
-            return F.gelu(x, approximate="tanh")
-        elif self.activation.lower() == "tanh":
-            return torch.tanh(x)
-        elif self.activation.lower() == "relu6":
-            return F.relu6(x)
-        elif self.activation.lower() == "softplus":
-            return F.softplus(x)
-        elif self.activation.lower() == "sigmoid":
-            return torch.sigmoid(x)
-        elif self.activation.lower() == "linear" or self.activation.lower() == "none":
-            return x
-        else:
-            raise RuntimeError(f"activation function {self.activation} not supported")
+        return self.activation(x)
 
 
 @overload
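
For reference, a minimal standalone sketch of the pattern this change introduces: the activation string is resolved to a callable once in `__init__` via `get_activation_fn`, and `forward` simply applies it, instead of re-matching the string on every call. The class name `ActivationFnSketch` and the self-test below are illustrative only and are not part of the deepmd codebase.

```python
# Standalone sketch of the dispatch-once pattern above (illustrative only).
from typing import Callable, Optional

import torch
import torch.nn.functional as F


class ActivationFnSketch(torch.nn.Module):
    """Resolve the activation name to a callable at construction time."""

    def __init__(self, activation: Optional[str]) -> None:
        super().__init__()
        self.activation = self.get_activation_fn(activation)

    def get_activation_fn(
        self, activation: Optional[str]
    ) -> Callable[[torch.Tensor], torch.Tensor]:
        # None is treated as the identity ("none"), matching the diff above.
        activation = activation.lower() if activation else "none"
        if activation in ("linear", "none"):
            return lambda x: x
        elif activation == "relu":
            return F.relu
        elif activation in ("gelu", "gelu_tf"):
            # The tanh approximation mirrors the TensorFlow-style GELU.
            return lambda x: F.gelu(x, approximate="tanh")
        elif activation == "tanh":
            return torch.tanh
        elif activation == "relu6":
            return F.relu6
        elif activation == "softplus":
            return F.softplus
        elif activation == "sigmoid":
            return torch.sigmoid
        else:
            raise RuntimeError(f"activation function {activation} not supported")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The callable was chosen once in __init__; just apply it.
        return self.activation(x)


if __name__ == "__main__":
    x = torch.randn(4, 3)
    fn = ActivationFnSketch("gelu")
    # The resolved callable matches the direct functional call.
    assert torch.allclose(fn(x), F.gelu(x, approximate="tanh"))
```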