Skip to content

Commit

Permalink
Refactor compiletime derivatives
Browse files Browse the repository at this point in the history
  • Loading branch information
dschwen committed Jan 8, 2025
1 parent cf0cd9d commit fd40509
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 50 deletions.
113 changes: 65 additions & 48 deletions framework/include/utils/CompileTimeDerivatives.h
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,13 @@ class CTArrayRef : public CTBase
{
public:
CTArrayRef(const T & arr, const I & idx) : _arr(arr), _idx(idx) {}
auto operator()() const { return _arr[_idx]; }

// get the value type returned by operator[]
using ResultType = CTCleanType<decltype(std::declval<T>()[std::declval<I>()])>;
static_assert(!std::is_same_v<ResultType, void>,
"Instantiation of CTArrayRef was attempted for a non-subscriptable type.");
ResultType operator()() const { return _arr[_idx]; }

std::string print() const { return "[a" + printTag<tag>() + "[" + Moose::stringify(_idx) + "]]"; }

template <CTTag dtag>
Expand All @@ -412,11 +418,6 @@ class CTArrayRef : public CTBase
return CTNull<ResultType>();
}

// get the value type returned by operator[]
typedef CTCleanType<decltype((static_cast<T>(0))[0])> ResultType;
static_assert(!std::is_same_v<ResultType, void>,
"Instantiation of CTArrayRef was attempted for a non-subscriptable type.");

protected:
const T & _arr;
const I & _idx;
Expand All @@ -440,23 +441,20 @@ class CTAdd : public CTBinary<L, R>
{
public:
CTAdd(L left, R right) : CTBinary<L, R>(left, right) {}
using typename CTBinary<L, R>::ResultType;

ResultType operator()() const
auto operator()() const
{
// compile time optimization to skip null terms
if constexpr (std::is_base_of<CTNullBase, L>::value && std::is_base_of<CTNullBase, R>::value)
return ResultType(0);

if constexpr (std::is_base_of<CTNullBase, L>::value)
return _right();

if constexpr (std::is_base_of<CTNullBase, R>::value)
else if constexpr (std::is_base_of<CTNullBase, R>::value)
return _left();

else
return _left() + _right();
}

std::string print() const { return this->printParens(this, "+"); }
constexpr static int precedence() { return 6; }

Expand All @@ -478,17 +476,16 @@ class CTSub : public CTBinary<L, R>
{
public:
CTSub(L left, R right) : CTBinary<L, R>(left, right) {}
using typename CTBinary<L, R>::ResultType;

ResultType operator()() const
auto operator()() const
{
if constexpr (std::is_base_of<CTNullBase, L>::value && std::is_base_of<CTNullBase, R>::value)
return ResultType(0);
return decltype(_left() - _right())(0);

if constexpr (std::is_base_of<CTNullBase, L>::value)
else if constexpr (std::is_base_of<CTNullBase, L>::value)
return -_right();

if constexpr (std::is_base_of<CTNullBase, R>::value)
else if constexpr (std::is_base_of<CTNullBase, R>::value)
return _left();

else
Expand Down Expand Up @@ -516,20 +513,19 @@ class CTMul : public CTBinary<L, R>
{
public:
CTMul(L left, R right) : CTBinary<L, R>(left, right) {}
using typename CTBinary<L, R>::ResultType;

ResultType operator()() const
auto operator()() const
{
if constexpr (std::is_base_of<CTNullBase, L>::value || std::is_base_of<CTNullBase, R>::value)
return ResultType(0);
return decltype(_left() * _right())(0);

if constexpr (std::is_base_of<CTOneBase, L>::value && std::is_base_of<CTOneBase, R>::value)
return ResultType(1);
else if constexpr (std::is_base_of<CTOneBase, L>::value && std::is_base_of<CTOneBase, R>::value)
return decltype(_left() * _right())(1);

if constexpr (std::is_base_of<CTOneBase, L>::value)
else if constexpr (std::is_base_of<CTOneBase, L>::value)
return _right();

if constexpr (std::is_base_of<CTOneBase, R>::value)
else if constexpr (std::is_base_of<CTOneBase, R>::value)
return _left();

else
Expand All @@ -556,17 +552,18 @@ class CTDiv : public CTBinary<L, R>
{
public:
CTDiv(L left, R right) : CTBinary<L, R>(left, right) {}
using typename CTBinary<L, R>::ResultType;

ResultType operator()() const
auto operator()() const
{
if constexpr (std::is_base_of<CTOneBase, R>::value)
return _left();

if constexpr (std::is_base_of<CTNullBase, L>::value && !std::is_base_of<CTNullBase, R>::value)
return ResultType(0);
else if constexpr (std::is_base_of<CTNullBase, L>::value &&
!std::is_base_of<CTNullBase, R>::value)
return decltype(_left() / _right())(0);

return _left() / _right();
else
return _left() / _right();
}
std::string print() const { return this->printParens(this, "/"); }
constexpr static int precedence() { return 5; }
Expand Down Expand Up @@ -601,9 +598,8 @@ class CTCompare : public CTBinary<L, R>
{
public:
CTCompare(L left, R right) : CTBinary<L, R>(left, right) {}
typedef bool ResultType;

ResultType operator()() const
auto operator()() const
{
if constexpr (C == CTComparisonEnum::Less)
return _left() < _right();
Expand Down Expand Up @@ -639,7 +635,7 @@ class CTCompare : public CTBinary<L, R>
template <CTTag dtag>
auto D() const
{
return CTNull<ResultType>();
return CTNull<decltype(std::declval<CTCompare<C, L, R>>())>();
}

using CTBinary<L, R>::_left;
Expand Down Expand Up @@ -668,20 +664,21 @@ class CTPow : public CTBinary<L, R>
{
public:
CTPow(L left, R right) : CTBinary<L, R>(left, right) {}
using typename CTBinary<L, R>::ResultType;

ResultType operator()() const
auto operator()() const
{
if constexpr (std::is_base_of<CTNullBase, L>::value)
return ResultType(0);
return decltype(std::pow(_left(), _right()))(0);

if constexpr (std::is_base_of<CTOneBase, L>::value || std::is_base_of<CTNullBase, R>::value)
return ResultType(1);
else if constexpr (std::is_base_of<CTOneBase, L>::value ||
std::is_base_of<CTNullBase, R>::value)
return decltype(std::pow(_left(), _right()))(1);

if constexpr (std::is_base_of<CTOneBase, R>::value)
else if constexpr (std::is_base_of<CTOneBase, R>::value)
return _left();

return std::pow(_left(), _right());
else
return std::pow(_left(), _right());
}
std::string print() const { return "pow(" + _left.print() + "," + _right.print() + ")"; }

Expand All @@ -690,7 +687,7 @@ class CTPow : public CTBinary<L, R>
{
if constexpr (std::is_base_of<CTNullBase, decltype(_left.template D<dtag>())>::value &&
std::is_base_of<CTNullBase, decltype(_right.template D<dtag>())>::value)
return CTNull<ResultType>();
return CTNull<decltype(std::pow(_left(), _right()))>();

else if constexpr (std::is_base_of<CTNullBase, decltype(_left.template D<dtag>())>::value)
return pow(_left, _right) * _right.template D<dtag>() * log(_left);
Expand Down Expand Up @@ -732,15 +729,14 @@ class CTIPow : public CTUnary<B>
{
public:
CTIPow(B base) : CTUnary<B>(base) {}
using typename CTUnary<B>::ResultType;

ResultType operator()() const
auto operator()() const
{
if constexpr (std::is_base_of<CTNullBase, B>::value)
return ResultType(0);
return decltype(libMesh::Utility::pow<E>(_arg()))(0);

else if constexpr (std::is_base_of<CTOneBase, B>::value || E == 0)
return ResultType(1);
return decltype(libMesh::Utility::pow<E>(_arg()))(1);

else if constexpr (E == 1)
return _arg();
Expand All @@ -760,7 +756,7 @@ class CTIPow : public CTUnary<B>
return _arg.template D<dtag>();

else if constexpr (E == 0)
return CTNull<ResultType>();
return CTNull<decltype(libMesh::Utility::pow<E>(_arg()))>();

else
return pow<E - 1>(_arg) * E * _arg.template D<dtag>();
Expand Down Expand Up @@ -835,7 +831,6 @@ CT_OPERATOR_BINARY(!=, CTCompareUnequal)
} \
std::string print() const { return #name "(" + _arg.print() + ")"; } \
constexpr static int precedence() { return 2; } \
using typename CTUnary<T>::ResultType; \
using CTUnary<T>::_arg; \
}; \
template <typename T> \
Expand Down Expand Up @@ -880,7 +875,6 @@ CT_SIMPLE_UNARY_FUNCTION(atan, 1.0 / (pow<2>(_arg) + 1.0) * _arg.template D<dtag
} \
std::string print() const { return #name "(" + _left.print() + ", " + _right.print() + ")"; } \
constexpr static int precedence() { return 2; } \
using typename CTBinary<L, R>::ResultType; \
using CTBinary<L, R>::_left; \
using CTBinary<L, R>::_right; \
};
Expand Down Expand Up @@ -937,7 +931,7 @@ class CTStandardDeviation : public CTBase

protected:
template <int R, std::size_t... Is>
ResultType rowMul(std::index_sequence<Is...>, const std::array<ResultType, N> & d) const
auto rowMul(std::index_sequence<Is...>, const std::array<ResultType, N> & d) const
{
return ((_covariance(R, Is) * d[Is]) + ...);
}
Expand Down Expand Up @@ -973,4 +967,27 @@ makeStandardDeviation(const T & f, const CTMatrix<Real, N, N> covariance)
covariance);
}

/**
* Helper function to take a derivative
*/
template <CTTag dtag,CTTag... dtags, typename T>
auto
diff(T result)
{
if constexpr (sizeof...(dtags) == 0)
return result.template D<dtag>();
else
return diff<dtags...>(result.template D<dtag>());
}

/**
* Helper function to evaluate a result
*/
template <typename T>
auto
eval(T result)
{
return result();
}

} // namespace CompileTimeDerivatives
19 changes: 17 additions & 2 deletions unit/src/CompileTimeDerivativesTest.C
Original file line number Diff line number Diff line change
Expand Up @@ -207,13 +207,18 @@ TEST(CompileTimeDerivativesTest, variable_reference)
const auto X = makeRef<dX>(x);
const auto result = X * X + 100.0;

x = 5;
x = 5.0;
EXPECT_EQ(result(), 125.0);
EXPECT_EQ(result.D<dX>()(), 10.0);

x = 3;
x = 3.0;
EXPECT_EQ(result(), 109.0);
EXPECT_EQ(eval(result), 109.0); // alternative syntax for evaluation

EXPECT_EQ(result.D<dX>()(), 6.0);
EXPECT_EQ(eval(diff<dX>(result)), 6.0); // alternative derivative syntax

EXPECT_EQ(eval(diff<dX, dX>(result)), 2.0);
}

TEST(CompileTimeDerivativesTest, vector_reference)
Expand Down Expand Up @@ -323,3 +328,13 @@ TEST(CompileTimeDerivativesTest, conditional)
EXPECT_EQ(result(), 5 * vx);
}
}

TEST(CompileTimeDerivativesTest, typeRefs)
{
const RealVectorValue va(1,2,3), vb(4,6,8);
const auto [a, b] = makeRefs<30>(va, vb);

// matching order
EXPECT_EQ(&va, &a());
EXPECT_EQ(&vb, &b());
}

0 comments on commit fd40509

Please sign in to comment.