-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathactivationFuncsFactory.cpp
52 lines (43 loc) · 1.39 KB
/
activationFuncsFactory.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#include <algorithm>
#include <cmath>
#include <functional>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
// Callable signature shared by every activation: one scalar in, one scalar out.
using ActivationFunction = std::function<double(double)>;
// A derivative has exactly the same callable signature as the activation itself.
using ActivationDerivative = ActivationFunction;
// (activation, derivative) pair, the value returned by getActivationFunctions().
using ActivationPair = std::pair<ActivationFunction, ActivationDerivative>;
// Logistic sigmoid: 1 / (1 + e^(-x)), mapping any real x into [0, 1]
// (the endpoints are reached only through floating-point rounding).
//
// The formula is split on the sign of x so std::exp is only ever called with a
// non-positive argument. The naive 1 / (1 + exp(-x)) form overflows exp(-x)
// to +inf for x below about -709.78, raising FE_OVERFLOW. It also rounds
// results that are still representable (denormal) down to exactly 0.
double sigmoid(double x) {
    if (x >= 0.0) {
        return 1.0 / (1.0 + std::exp(-x));
    }
    // x < 0 (or NaN): use the equivalent e^x / (1 + e^x) so the exponent stays <= 0.
    const double ex = std::exp(x);
    return ex / (1.0 + ex);
}
// Derivative of the logistic sigmoid, via the identity
// sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x)); sigmoid is evaluated only once.
double sigmoid_derivative(double x) {
    const double activated = sigmoid(x);
    const double complement = 1.0 - activated;
    return activated * complement;
}
// Rectified linear unit: returns x for strictly positive inputs and +0.0
// otherwise. Negatives, zero, -0.0 and NaN all map to +0.0.
double relu(double x) {
    if (x > 0.0) {
        return x;
    }
    return 0.0;
}
// Derivative of relu: 1.0 for strictly positive inputs, 0.0 otherwise.
// At the kink (x == 0) the derivative is taken to be 0.0; NaN inputs also give 0.0.
double relu_derivative(double x) {
    if (x > 0) {
        return 1.0;
    }
    return 0.0;
}
// Hyperbolic tangent activation: (e^x - e^(-x)) / (e^x + e^(-x)), range [-1, 1].
// Uses std::tanh from <cmath> rather than relying on ::tanh also being exposed
// in the global namespace, which <cmath> does not guarantee. std::tanh saturates
// to +/-1 for large |x|, where evaluating the exponential formula directly would overflow.
double tanhAct(double input) {
    return std::tanh(input);
}
// Derivative of the tanh activation: d/dx tanh(x) = 1 - tanh(x)^2.
// Reuses tanhAct so both functions share one tanh implementation.
double tanh_derivative(double input) {
    const double t = tanhAct(input);
    const double slope = 1.0 - t * t;
    return slope;
}
// Looks up an activation function and its derivative by name.
//
// Recognised names (exact, case-sensitive match): "sigmoid", "relu", "tanh".
// Returns a copy of the stored (activation, derivative) pair, so the caller
// owns an independent copy.
// Throws std::invalid_argument if `name` is not a registered activation.
ActivationPair getActivationFunctions(const std::string& name) {
    // Built once on first call (function-local static initialisation is
    // thread-safe since C++11) and never modified afterwards, hence const.
    static const std::unordered_map<std::string, ActivationPair> activation_map = {
        {"sigmoid", {sigmoid, sigmoid_derivative}},
        {"relu", {relu, relu_derivative}},
        {"tanh", {tanhAct, tanh_derivative}},
        // ... Add other mappings as needed ...
    };

    const auto it = activation_map.find(name);
    if (it == activation_map.end()) {
        throw std::invalid_argument("Unknown activation function: " + name);
    }
    return it->second;
}
// Example usage (requires C++17 structured bindings):
//   auto [actFunc, actDeriv] = getActivationFunctions("sigmoid");