Skip to content

Commit

Permalink
added deriv calc for comparisons. Even though completly useless
Browse files Browse the repository at this point in the history
  • Loading branch information
konrad.kraemer committed Jun 4, 2024
1 parent 85afc5e commit 08107fe
Show file tree
Hide file tree
Showing 4 changed files with 241 additions and 34 deletions.
72 changes: 72 additions & 0 deletions include/etr_bits/Core/Concepts.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,78 @@ concept IsTimes = requires(T t) {
TimesTrait>::value;
};

template <typename T>
concept IsEqual = requires(T t) {
typename std::remove_reference<decltype(t)>::type::CaseTrait;
typename std::remove_reference<decltype(t)>::type::TypeTrait;
requires std::is_same<
typename std::remove_reference<decltype(t)>::type::CaseTrait,
BinaryTrait>::value;
requires std::is_same<
typename std::remove_reference<decltype(t)>::type::TypeTrait,
EqualTrait>::value;
};

template <typename T>
concept IsSmaller = requires(T t) {
typename std::remove_reference<decltype(t)>::type::CaseTrait;
typename std::remove_reference<decltype(t)>::type::TypeTrait;
requires std::is_same<
typename std::remove_reference<decltype(t)>::type::CaseTrait,
BinaryTrait>::value;
requires std::is_same<
typename std::remove_reference<decltype(t)>::type::TypeTrait,
SmallerTrait>::value;
};

template <typename T>
concept IsLarger = requires(T t) {
typename std::remove_reference<decltype(t)>::type::CaseTrait;
typename std::remove_reference<decltype(t)>::type::TypeTrait;
requires std::is_same<
typename std::remove_reference<decltype(t)>::type::CaseTrait,
BinaryTrait>::value;
requires std::is_same<
typename std::remove_reference<decltype(t)>::type::TypeTrait,
LargerTrait>::value;
};

template <typename T>
concept IsLargerEqual = requires(T t) {
typename std::remove_reference<decltype(t)>::type::CaseTrait;
typename std::remove_reference<decltype(t)>::type::TypeTrait;
requires std::is_same<
typename std::remove_reference<decltype(t)>::type::CaseTrait,
BinaryTrait>::value;
requires std::is_same<
typename std::remove_reference<decltype(t)>::type::TypeTrait,
LargerEqualTrait>::value;
};

template <typename T>
concept IsSmallerEqual = requires(T t) {
typename std::remove_reference<decltype(t)>::type::CaseTrait;
typename std::remove_reference<decltype(t)>::type::TypeTrait;
requires std::is_same<
typename std::remove_reference<decltype(t)>::type::CaseTrait,
BinaryTrait>::value;
requires std::is_same<
typename std::remove_reference<decltype(t)>::type::TypeTrait,
SmallerEqualTrait>::value;
};

template <typename T>
concept IsUnequal = requires(T t) {
typename std::remove_reference<decltype(t)>::type::CaseTrait;
typename std::remove_reference<decltype(t)>::type::TypeTrait;
requires std::is_same<
typename std::remove_reference<decltype(t)>::type::CaseTrait,
BinaryTrait>::value;
requires std::is_same<
typename std::remove_reference<decltype(t)>::type::TypeTrait,
UnEqualTrait>::value;
};

template <typename T>
concept IsSinus = requires(T t) {
typename std::remove_reference<decltype(t)>::type::CaseTrait;
Expand Down
31 changes: 25 additions & 6 deletions include/etr_bits/Core/Traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,15 +143,18 @@ struct EqualDerivTrait {
template <typename L = BaseType, typename R = BaseType>
static inline bool
f(L a,
R b) { // issue: add this to documentationion for package authors
R b) { // TODO: add this to documentationion for package authors
if (fabs(a - b) < 1E-3) {
return true;
} else {
return false;
}
}

static inline bool fDeriv() { return false; }
template <typename L = BaseType, typename R = BaseType>
static inline std::common_type<L, R>::type fDeriv(L l, R r) {
return false;
}
};
struct SmallerDerivTrait {
template <typename L = BaseType, typename R = BaseType>
Expand All @@ -162,7 +165,10 @@ struct SmallerDerivTrait {
return false;
}
}
static inline bool fDeriv() { return false; }
template <typename L = BaseType, typename R = BaseType>
static inline std::common_type<L, R>::type fDeriv(L l, R r) {
return false;
}
};
struct SmallerEqualDerivTrait {
template <typename L = BaseType, typename R = BaseType>
Expand All @@ -173,7 +179,10 @@ struct SmallerEqualDerivTrait {
return false;
}
}
static inline bool fDeriv() { return false; }
template <typename L = BaseType, typename R = BaseType>
static inline std::common_type<L, R>::type fDeriv(L l, R r) {
return false;
}
};
struct LargerDerivTrait {
template <typename L = BaseType, typename R = BaseType>
Expand All @@ -184,7 +193,10 @@ struct LargerDerivTrait {
return false;
}
}
static inline bool fDeriv() { return false; }
template <typename L = BaseType, typename R = BaseType>
static inline std::common_type<L, R>::type fDeriv(L l, R r) {
return false;
}
};
struct LargerEqualDerivTrait {
template <typename L = BaseType, typename R = BaseType>
Expand All @@ -195,6 +207,10 @@ struct LargerEqualDerivTrait {
return false;
}
}
template <typename L = BaseType, typename R = BaseType>
static inline std::common_type<L, R>::type fDeriv(L l, R r) {
return false;
}
};
struct UnEqualDerivTrait {
template <typename L = BaseType, typename R = BaseType>
Expand All @@ -205,7 +221,10 @@ struct UnEqualDerivTrait {
return false;
}
}
static inline bool fDeriv() { return false; }
template <typename L = BaseType, typename R = BaseType>
static inline std::common_type<L, R>::type fDeriv(L l, R r) {
return false;
}
};

struct SinusDerivTrait {
Expand Down
56 changes: 55 additions & 1 deletion include/etr_bits/Vector/DerivativeCalc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ static constexpr auto walkTD() {
return produceQuarternyType<decltype(LT), decltype(RT), decltype(LDeriv),
decltype(RDeriv), QuarternaryTrait,
DivideByConstantDerivTrait>();

// TODO: add case if L is a constant. This has to be added to all constants
} else {
return produceQuarternyType<decltype(LT), decltype(RT), decltype(LDeriv),
decltype(RDeriv), QuarternaryTrait,
Expand All @@ -99,6 +99,60 @@ static constexpr auto walkTD() {
etr::MinusDerivTrait>();
}

template <typename TD>
requires IsEqual<TD>
static constexpr auto walkTD() {
constexpr auto LDeriv = walkTD<typename TD::typeTraitL>();
constexpr auto RDeriv = walkTD<typename TD::typeTraitR>();
return produceBinaryType<decltype(LDeriv), decltype(RDeriv), BinaryTrait,
EqualDerivTrait>();
}

template <typename TD>
requires IsUnequal<TD>
static constexpr auto walkTD() {
constexpr auto LDeriv = walkTD<typename TD::typeTraitL>();
constexpr auto RDeriv = walkTD<typename TD::typeTraitR>();
return produceBinaryType<decltype(LDeriv), decltype(RDeriv), BinaryTrait,
UnEqualDerivTrait>();
}

template <typename TD>
requires IsSmaller<TD>
static constexpr auto walkTD() {
constexpr auto LDeriv = walkTD<typename TD::typeTraitL>();
constexpr auto RDeriv = walkTD<typename TD::typeTraitR>();
return produceBinaryType<decltype(LDeriv), decltype(RDeriv), BinaryTrait,
SmallerDerivTrait>();
}

template <typename TD>
requires IsLarger<TD>
static constexpr auto walkTD() {
constexpr auto LDeriv = walkTD<typename TD::typeTraitL>();
constexpr auto RDeriv = walkTD<typename TD::typeTraitR>();
return produceBinaryType<decltype(LDeriv), decltype(RDeriv), BinaryTrait,
LargerDerivTrait>();
}

template <typename TD>
requires IsSmallerEqual<TD>
static constexpr auto walkTD() {
constexpr auto LDeriv = walkTD<typename TD::typeTraitL>();
constexpr auto RDeriv = walkTD<typename TD::typeTraitR>();
return produceBinaryType<decltype(LDeriv), decltype(RDeriv), BinaryTrait,
SmallerEqualDerivTrait>();
}

template <typename TD>
requires IsLargerEqual<TD>
static constexpr auto walkTD() {
constexpr auto LDeriv = walkTD<typename TD::typeTraitL>();
constexpr auto RDeriv = walkTD<typename TD::typeTraitR>();
return produceBinaryType<decltype(LDeriv), decltype(RDeriv), BinaryTrait,
LargerEqualDerivTrait>();
}

template <typename TD>
requires IsSinus<TD>
static constexpr auto walkTD() {
Expand Down
116 changes: 89 additions & 27 deletions tests/Derivatives_Tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,31 +95,93 @@ int main() {
print(get_derivs(b));
}

/*
std::cout << "\n"
<< "a = a - b" << "\n"
<< "b = a * a" << std::endl;
a = a - b;
b = a * a;
print(a, av);
print(b, av);
print(get_derivs(a));
print(get_derivs(b));
std::cout << "\n"
<< "a = b" << std::endl;
a = b;
print(a, av);
print(b, av);
print(get_derivs(a));
print(get_derivs(b));
std::cout << "\n"
<< "a = a / b" << std::endl;
a = a / scalarDeriv<3>(av, 3.14);
print(a, av);
print(b, av);
print(get_derivs(a));
print(get_derivs(b));
*/
// NOTE: test comparisons
{
std::cout << "test equal" << std::endl;
AllVars<2, 0, 0, 3> av(0, 0);
Vec<double, VarPointer<decltype(av), 0, 0>, VariableTypeTrait> a(av);
Vec<double, VarPointer<decltype(av), 1, 0>, VariableTypeTrait> b(av);
a = coca<0>(av, 3, 6, 9, 12);
b = coca<1>(av, 4, 5, 6, 7);
b = a == b;
print(a, av);
print(b, av);
print(get_derivs(a));
print(get_derivs(b));
}

// NOTE: test unequal
{
std::cout << "test unequal" << std::endl;
AllVars<2, 0, 0, 3> av(0, 0);
Vec<double, VarPointer<decltype(av), 0, 0>, VariableTypeTrait> a(av);
Vec<double, VarPointer<decltype(av), 1, 0>, VariableTypeTrait> b(av);
a = coca<0>(av, 3, 6, 9, 12);
b = coca<1>(av, 4, 5, 6, 7);
b = a != b;
print(a, av);
print(b, av);
print(get_derivs(a));
print(get_derivs(b));
}

// NOTE: test smaller
{
std::cout << "test smaller" << std::endl;
AllVars<2, 0, 0, 3> av(0, 0);
Vec<double, VarPointer<decltype(av), 0, 0>, VariableTypeTrait> a(av);
Vec<double, VarPointer<decltype(av), 1, 0>, VariableTypeTrait> b(av);
a = coca<0>(av, 3, 6, 9, 12);
b = coca<1>(av, 4, 5, 6, 7);
b = a < b;
print(a, av);
print(b, av);
print(get_derivs(a));
print(get_derivs(b));
}

// NOTE: test larger
{
std::cout << "test larger" << std::endl;
AllVars<2, 0, 0, 3> av(0, 0);
Vec<double, VarPointer<decltype(av), 0, 0>, VariableTypeTrait> a(av);
Vec<double, VarPointer<decltype(av), 1, 0>, VariableTypeTrait> b(av);
a = coca<0>(av, 3, 6, 9, 12);
b = coca<1>(av, 4, 5, 6, 7);
b = a > b;
print(a, av);
print(b, av);
print(get_derivs(a));
print(get_derivs(b));
}

// NOTE: test larger equal
{
std::cout << "test larger equal" << std::endl;
AllVars<2, 0, 0, 3> av(0, 0);
Vec<double, VarPointer<decltype(av), 0, 0>, VariableTypeTrait> a(av);
Vec<double, VarPointer<decltype(av), 1, 0>, VariableTypeTrait> b(av);
a = coca<0>(av, 3, 6, 9, 12);
b = coca<1>(av, 4, 5, 6, 7);
b = a >= b;
print(a, av);
print(b, av);
print(get_derivs(a));
print(get_derivs(b));
}

// NOTE: test smaller equal
{
std::cout << "test smaller equal" << std::endl;
AllVars<2, 0, 0, 3> av(0, 0);
Vec<double, VarPointer<decltype(av), 0, 0>, VariableTypeTrait> a(av);
Vec<double, VarPointer<decltype(av), 1, 0>, VariableTypeTrait> b(av);
a = coca<0>(av, 3, 6, 9, 12);
b = coca<1>(av, 4, 5, 6, 7);
b = a <= b;
print(a, av);
print(b, av);
print(get_derivs(a));
print(get_derivs(b));
}
}

0 comments on commit 08107fe

Please sign in to comment.