Commit e0e9df9

Move activation function selection out of forward()
caic99 committed Dec 26, 2024
1 parent 4645a43 commit e0e9df9
Showing 1 changed file with 21 additions and 18 deletions.
deepmd/pt/utils/utils.py
```diff
@@ -21,28 +21,31 @@
 class ActivationFn(torch.nn.Module):
     def __init__(self, activation: Optional[str]) -> None:
         super().__init__()
-        self.activation: str = activation if activation is not None else "linear"
+        self.activation = self.get_activation_fn(activation)
+
+    def get_activation_fn(self, activation: Optional[str]):
+        activation = activation.lower() if activation else "none"
+        if activation == "linear" or activation == "none":
+            return lambda x: x
+        elif activation == "relu":
+            return F.relu
+        elif activation == "gelu" or activation == "gelu_tf":
+            return lambda x: F.gelu(x, approximate="tanh")
+        elif activation == "tanh":
+            return F.tanh
+        elif activation == "relu6":
+            return F.relu6
+        elif activation == "softplus":
+            return F.softplus
+        elif activation == "sigmoid":
+            return F.sigmoid
+        else:
+            raise RuntimeError(f"activation function {activation} not supported")
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Returns the tensor after applying activation function corresponding to `activation`."""
         # See jit supported types: https://pytorch.org/docs/stable/jit_language_reference.html#supported-type
-
-        if self.activation.lower() == "relu":
-            return F.relu(x)
-        elif self.activation.lower() == "gelu" or self.activation.lower() == "gelu_tf":
-            return F.gelu(x, approximate="tanh")
-        elif self.activation.lower() == "tanh":
-            return torch.tanh(x)
-        elif self.activation.lower() == "relu6":
-            return F.relu6(x)
-        elif self.activation.lower() == "softplus":
-            return F.softplus(x)
-        elif self.activation.lower() == "sigmoid":
-            return torch.sigmoid(x)
-        elif self.activation.lower() == "linear" or self.activation.lower() == "none":
-            return x
-        else:
-            raise RuntimeError(f"activation function {self.activation} not supported")
+        return self.activation(x)
 
 
 @overload
```

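A minimal usage sketch of the refactored class (assuming a deepmd-kit source tree that includes this commit; `ActivationFn` and the import path are taken from the file above). The string-to-function dispatch now runs once in `__init__`, so `forward` only applies the stored callable:

```python
import torch

from deepmd.pt.utils.utils import ActivationFn

act = ActivationFn("gelu")  # if/elif dispatch happens once, at construction
x = torch.randn(4, 8)
y = act(x)                  # forward() now just returns self.activation(x)

# Names are lowercased first, and None falls back to "none" (the identity),
# so ActivationFn(None) passes tensors through unchanged.
ident = ActivationFn(None)
assert torch.equal(ident(x), x)
```

Resolving the name ahead of time also means an unsupported activation name now raises at construction rather than on the first forward pass.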