Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor compiletime derivatives #29666

Draft
wants to merge 1 commit into
base: next
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 65 additions & 48 deletions framework/include/utils/CompileTimeDerivatives.h
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,13 @@ class CTArrayRef : public CTBase
{
public:
CTArrayRef(const T & arr, const I & idx) : _arr(arr), _idx(idx) {}
auto operator()() const { return _arr[_idx]; }

// get the value type returned by operator[]
using ResultType = CTCleanType<decltype(std::declval<T>()[std::declval<I>()])>;
static_assert(!std::is_same_v<ResultType, void>,
"Instantiation of CTArrayRef was attempted for a non-subscriptable type.");
ResultType operator()() const { return _arr[_idx]; }

std::string print() const { return "[a" + printTag<tag>() + "[" + Moose::stringify(_idx) + "]]"; }

template <CTTag dtag>
Expand All @@ -412,11 +418,6 @@ class CTArrayRef : public CTBase
return CTNull<ResultType>();
}

// get the value type returned by operator[]
typedef CTCleanType<decltype((static_cast<T>(0))[0])> ResultType;
static_assert(!std::is_same_v<ResultType, void>,
"Instantiation of CTArrayRef was attempted for a non-subscriptable type.");

protected:
const T & _arr;
const I & _idx;
Expand All @@ -440,23 +441,20 @@ class CTAdd : public CTBinary<L, R>
{
public:
CTAdd(L left, R right) : CTBinary<L, R>(left, right) {}
using typename CTBinary<L, R>::ResultType;

ResultType operator()() const
auto operator()() const
{
// compile time optimization to skip null terms
if constexpr (std::is_base_of<CTNullBase, L>::value && std::is_base_of<CTNullBase, R>::value)
return ResultType(0);

if constexpr (std::is_base_of<CTNullBase, L>::value)
return _right();

if constexpr (std::is_base_of<CTNullBase, R>::value)
else if constexpr (std::is_base_of<CTNullBase, R>::value)
return _left();

else
return _left() + _right();
}

std::string print() const { return this->printParens(this, "+"); }
constexpr static int precedence() { return 6; }

Expand All @@ -478,17 +476,16 @@ class CTSub : public CTBinary<L, R>
{
public:
CTSub(L left, R right) : CTBinary<L, R>(left, right) {}
using typename CTBinary<L, R>::ResultType;

ResultType operator()() const
auto operator()() const
{
if constexpr (std::is_base_of<CTNullBase, L>::value && std::is_base_of<CTNullBase, R>::value)
return ResultType(0);
return decltype(_left() - _right())(0);

if constexpr (std::is_base_of<CTNullBase, L>::value)
else if constexpr (std::is_base_of<CTNullBase, L>::value)
return -_right();

if constexpr (std::is_base_of<CTNullBase, R>::value)
else if constexpr (std::is_base_of<CTNullBase, R>::value)
return _left();

else
Expand Down Expand Up @@ -516,20 +513,19 @@ class CTMul : public CTBinary<L, R>
{
public:
CTMul(L left, R right) : CTBinary<L, R>(left, right) {}
using typename CTBinary<L, R>::ResultType;

ResultType operator()() const
auto operator()() const
{
if constexpr (std::is_base_of<CTNullBase, L>::value || std::is_base_of<CTNullBase, R>::value)
return ResultType(0);
return decltype(_left() * _right())(0);

if constexpr (std::is_base_of<CTOneBase, L>::value && std::is_base_of<CTOneBase, R>::value)
return ResultType(1);
else if constexpr (std::is_base_of<CTOneBase, L>::value && std::is_base_of<CTOneBase, R>::value)
return decltype(_left() * _right())(1);

if constexpr (std::is_base_of<CTOneBase, L>::value)
else if constexpr (std::is_base_of<CTOneBase, L>::value)
return _right();

if constexpr (std::is_base_of<CTOneBase, R>::value)
else if constexpr (std::is_base_of<CTOneBase, R>::value)
return _left();

else
Expand All @@ -556,17 +552,18 @@ class CTDiv : public CTBinary<L, R>
{
public:
CTDiv(L left, R right) : CTBinary<L, R>(left, right) {}
using typename CTBinary<L, R>::ResultType;

ResultType operator()() const
auto operator()() const
{
if constexpr (std::is_base_of<CTOneBase, R>::value)
return _left();

if constexpr (std::is_base_of<CTNullBase, L>::value && !std::is_base_of<CTNullBase, R>::value)
return ResultType(0);
else if constexpr (std::is_base_of<CTNullBase, L>::value &&
!std::is_base_of<CTNullBase, R>::value)
return decltype(_left() / _right())(0);

return _left() / _right();
else
return _left() / _right();
}
std::string print() const { return this->printParens(this, "/"); }
constexpr static int precedence() { return 5; }
Expand Down Expand Up @@ -601,9 +598,8 @@ class CTCompare : public CTBinary<L, R>
{
public:
CTCompare(L left, R right) : CTBinary<L, R>(left, right) {}
typedef bool ResultType;

ResultType operator()() const
auto operator()() const
{
if constexpr (C == CTComparisonEnum::Less)
return _left() < _right();
Expand Down Expand Up @@ -639,7 +635,7 @@ class CTCompare : public CTBinary<L, R>
template <CTTag dtag>
auto D() const
{
return CTNull<ResultType>();
return CTNull<decltype(std::declval<CTCompare<C, L, R>>())>();
}

using CTBinary<L, R>::_left;
Expand Down Expand Up @@ -668,20 +664,21 @@ class CTPow : public CTBinary<L, R>
{
public:
CTPow(L left, R right) : CTBinary<L, R>(left, right) {}
using typename CTBinary<L, R>::ResultType;

ResultType operator()() const
auto operator()() const
{
if constexpr (std::is_base_of<CTNullBase, L>::value)
return ResultType(0);
return decltype(std::pow(_left(), _right()))(0);

if constexpr (std::is_base_of<CTOneBase, L>::value || std::is_base_of<CTNullBase, R>::value)
return ResultType(1);
else if constexpr (std::is_base_of<CTOneBase, L>::value ||
std::is_base_of<CTNullBase, R>::value)
return decltype(std::pow(_left(), _right()))(1);

if constexpr (std::is_base_of<CTOneBase, R>::value)
else if constexpr (std::is_base_of<CTOneBase, R>::value)
return _left();

return std::pow(_left(), _right());
else
return std::pow(_left(), _right());
}
std::string print() const { return "pow(" + _left.print() + "," + _right.print() + ")"; }

Expand All @@ -690,7 +687,7 @@ class CTPow : public CTBinary<L, R>
{
if constexpr (std::is_base_of<CTNullBase, decltype(_left.template D<dtag>())>::value &&
std::is_base_of<CTNullBase, decltype(_right.template D<dtag>())>::value)
return CTNull<ResultType>();
return CTNull<decltype(std::pow(_left(), _right()))>();

else if constexpr (std::is_base_of<CTNullBase, decltype(_left.template D<dtag>())>::value)
return pow(_left, _right) * _right.template D<dtag>() * log(_left);
Expand Down Expand Up @@ -732,15 +729,14 @@ class CTIPow : public CTUnary<B>
{
public:
CTIPow(B base) : CTUnary<B>(base) {}
using typename CTUnary<B>::ResultType;

ResultType operator()() const
auto operator()() const
{
if constexpr (std::is_base_of<CTNullBase, B>::value)
return ResultType(0);
return decltype(libMesh::Utility::pow<E>(_arg()))(0);

else if constexpr (std::is_base_of<CTOneBase, B>::value || E == 0)
return ResultType(1);
return decltype(libMesh::Utility::pow<E>(_arg()))(1);

else if constexpr (E == 1)
return _arg();
Expand All @@ -760,7 +756,7 @@ class CTIPow : public CTUnary<B>
return _arg.template D<dtag>();

else if constexpr (E == 0)
return CTNull<ResultType>();
return CTNull<decltype(libMesh::Utility::pow<E>(_arg()))>();

else
return pow<E - 1>(_arg) * E * _arg.template D<dtag>();
Expand Down Expand Up @@ -835,7 +831,6 @@ CT_OPERATOR_BINARY(!=, CTCompareUnequal)
} \
std::string print() const { return #name "(" + _arg.print() + ")"; } \
constexpr static int precedence() { return 2; } \
using typename CTUnary<T>::ResultType; \
using CTUnary<T>::_arg; \
}; \
template <typename T> \
Expand Down Expand Up @@ -880,7 +875,6 @@ CT_SIMPLE_UNARY_FUNCTION(atan, 1.0 / (pow<2>(_arg) + 1.0) * _arg.template D<dtag
} \
std::string print() const { return #name "(" + _left.print() + ", " + _right.print() + ")"; } \
constexpr static int precedence() { return 2; } \
using typename CTBinary<L, R>::ResultType; \
using CTBinary<L, R>::_left; \
using CTBinary<L, R>::_right; \
};
Expand Down Expand Up @@ -937,7 +931,7 @@ class CTStandardDeviation : public CTBase

protected:
template <int R, std::size_t... Is>
ResultType rowMul(std::index_sequence<Is...>, const std::array<ResultType, N> & d) const
auto rowMul(std::index_sequence<Is...>, const std::array<ResultType, N> & d) const
{
return ((_covariance(R, Is) * d[Is]) + ...);
}
Expand Down Expand Up @@ -973,4 +967,27 @@ makeStandardDeviation(const T & f, const CTMatrix<Real, N, N> covariance)
covariance);
}

/**
* Helper function to take a derivative
*/
template <CTTag dtag, CTTag... dtags, typename T>
auto
diff(T result)
{
if constexpr (sizeof...(dtags) == 0)
return result.template D<dtag>();
else
return diff<dtags...>(result.template D<dtag>());
}

/**
* Helper function to evaluate a result
*/
template <typename T>
auto
eval(T result)
{
return result();
}

} // namespace CompileTimeDerivatives
19 changes: 17 additions & 2 deletions unit/src/CompileTimeDerivativesTest.C
Original file line number Diff line number Diff line change
Expand Up @@ -207,13 +207,18 @@ TEST(CompileTimeDerivativesTest, variable_reference)
const auto X = makeRef<dX>(x);
const auto result = X * X + 100.0;

x = 5;
x = 5.0;
EXPECT_EQ(result(), 125.0);
EXPECT_EQ(result.D<dX>()(), 10.0);

x = 3;
x = 3.0;
EXPECT_EQ(result(), 109.0);
EXPECT_EQ(eval(result), 109.0); // alternative syntax for evaluation

EXPECT_EQ(result.D<dX>()(), 6.0);
EXPECT_EQ(eval(diff<dX>(result)), 6.0); // alternative derivative syntax

EXPECT_EQ(eval(diff<dX, dX>(result)), 2.0);
}

TEST(CompileTimeDerivativesTest, vector_reference)
Expand Down Expand Up @@ -323,3 +328,13 @@ TEST(CompileTimeDerivativesTest, conditional)
EXPECT_EQ(result(), 5 * vx);
}
}

TEST(CompileTimeDerivativesTest, typeRefs)
{
const RealVectorValue va(1, 2, 3), vb(4, 6, 8);
const auto [a, b] = makeRefs<30>(va, vb);

// matching order
EXPECT_EQ(&va, &a());
EXPECT_EQ(&vb, &b());
}