diff --git a/framework/include/utils/CompileTimeDerivatives.h b/framework/include/utils/CompileTimeDerivatives.h index 17d19b733cca..6131d2ae000b 100644 --- a/framework/include/utils/CompileTimeDerivatives.h +++ b/framework/include/utils/CompileTimeDerivatives.h @@ -400,7 +400,13 @@ class CTArrayRef : public CTBase { public: CTArrayRef(const T & arr, const I & idx) : _arr(arr), _idx(idx) {} - auto operator()() const { return _arr[_idx]; } + + // get the value type returned by operator[] + using ResultType = CTCleanType()[std::declval()])>; + static_assert(!std::is_same_v, + "Instantiation of CTArrayRef was attempted for a non-subscriptable type."); + ResultType operator()() const { return _arr[_idx]; } + std::string print() const { return "[a" + printTag() + "[" + Moose::stringify(_idx) + "]]"; } template @@ -412,11 +418,6 @@ class CTArrayRef : public CTBase return CTNull(); } - // get the value type returned by operator[] - typedef CTCleanType(0))[0])> ResultType; - static_assert(!std::is_same_v, - "Instantiation of CTArrayRef was attempted for a non-subscriptable type."); - protected: const T & _arr; const I & _idx; @@ -440,23 +441,20 @@ class CTAdd : public CTBinary { public: CTAdd(L left, R right) : CTBinary(left, right) {} - using typename CTBinary::ResultType; - ResultType operator()() const + auto operator()() const { // compile time optimization to skip null terms - if constexpr (std::is_base_of::value && std::is_base_of::value) - return ResultType(0); - if constexpr (std::is_base_of::value) return _right(); - if constexpr (std::is_base_of::value) + else if constexpr (std::is_base_of::value) return _left(); else return _left() + _right(); } + std::string print() const { return this->printParens(this, "+"); } constexpr static int precedence() { return 6; } @@ -478,17 +476,16 @@ class CTSub : public CTBinary { public: CTSub(L left, R right) : CTBinary(left, right) {} - using typename CTBinary::ResultType; - ResultType operator()() const + auto operator()() const { if constexpr (std::is_base_of::value && std::is_base_of::value) - return ResultType(0); + return decltype(_left() - _right())(0); - if constexpr (std::is_base_of::value) + else if constexpr (std::is_base_of::value) return -_right(); - if constexpr (std::is_base_of::value) + else if constexpr (std::is_base_of::value) return _left(); else @@ -516,20 +513,19 @@ class CTMul : public CTBinary { public: CTMul(L left, R right) : CTBinary(left, right) {} - using typename CTBinary::ResultType; - ResultType operator()() const + auto operator()() const { if constexpr (std::is_base_of::value || std::is_base_of::value) - return ResultType(0); + return decltype(_left() * _right())(0); - if constexpr (std::is_base_of::value && std::is_base_of::value) - return ResultType(1); + else if constexpr (std::is_base_of::value && std::is_base_of::value) + return decltype(_left() * _right())(1); - if constexpr (std::is_base_of::value) + else if constexpr (std::is_base_of::value) return _right(); - if constexpr (std::is_base_of::value) + else if constexpr (std::is_base_of::value) return _left(); else @@ -556,17 +552,18 @@ class CTDiv : public CTBinary { public: CTDiv(L left, R right) : CTBinary(left, right) {} - using typename CTBinary::ResultType; - ResultType operator()() const + auto operator()() const { if constexpr (std::is_base_of::value) return _left(); - if constexpr (std::is_base_of::value && !std::is_base_of::value) - return ResultType(0); + else if constexpr (std::is_base_of::value && + !std::is_base_of::value) + return decltype(_left() / _right())(0); - return _left() / _right(); + else + return _left() / _right(); } std::string print() const { return this->printParens(this, "/"); } constexpr static int precedence() { return 5; } @@ -601,9 +598,8 @@ class CTCompare : public CTBinary { public: CTCompare(L left, R right) : CTBinary(left, right) {} - typedef bool ResultType; - ResultType operator()() const + auto operator()() const { if constexpr (C == CTComparisonEnum::Less) return _left() < _right(); @@ -639,7 +635,7 @@ class CTCompare : public CTBinary template auto D() const { - return CTNull(); + return CTNull>())>(); } using CTBinary::_left; @@ -668,20 +664,21 @@ class CTPow : public CTBinary { public: CTPow(L left, R right) : CTBinary(left, right) {} - using typename CTBinary::ResultType; - ResultType operator()() const + auto operator()() const { if constexpr (std::is_base_of::value) - return ResultType(0); + return decltype(std::pow(_left(), _right()))(0); - if constexpr (std::is_base_of::value || std::is_base_of::value) - return ResultType(1); + else if constexpr (std::is_base_of::value || + std::is_base_of::value) + return decltype(std::pow(_left(), _right()))(1); - if constexpr (std::is_base_of::value) + else if constexpr (std::is_base_of::value) return _left(); - return std::pow(_left(), _right()); + else + return std::pow(_left(), _right()); } std::string print() const { return "pow(" + _left.print() + "," + _right.print() + ")"; } @@ -690,7 +687,7 @@ class CTPow : public CTBinary { if constexpr (std::is_base_of())>::value && std::is_base_of())>::value) - return CTNull(); + return CTNull(); else if constexpr (std::is_base_of())>::value) return pow(_left, _right) * _right.template D() * log(_left); @@ -732,15 +729,14 @@ class CTIPow : public CTUnary { public: CTIPow(B base) : CTUnary(base) {} - using typename CTUnary::ResultType; - ResultType operator()() const + auto operator()() const { if constexpr (std::is_base_of::value) - return ResultType(0); + return decltype(libMesh::Utility::pow(_arg()))(0); else if constexpr (std::is_base_of::value || E == 0) - return ResultType(1); + return decltype(libMesh::Utility::pow(_arg()))(1); else if constexpr (E == 1) return _arg(); @@ -760,7 +756,7 @@ class CTIPow : public CTUnary return _arg.template D(); else if constexpr (E == 0) - return CTNull(); + return CTNull(_arg()))>(); else return pow(_arg) * E * _arg.template D(); @@ -835,7 +831,6 @@ CT_OPERATOR_BINARY(!=, CTCompareUnequal) } \ std::string print() const { return #name "(" + _arg.print() + ")"; } \ constexpr static int precedence() { return 2; } \ - using typename CTUnary::ResultType; \ using CTUnary::_arg; \ }; \ template \ @@ -880,7 +875,6 @@ CT_SIMPLE_UNARY_FUNCTION(atan, 1.0 / (pow<2>(_arg) + 1.0) * _arg.template D::ResultType; \ using CTBinary::_left; \ using CTBinary::_right; \ }; @@ -937,7 +931,7 @@ class CTStandardDeviation : public CTBase protected: template - ResultType rowMul(std::index_sequence, const std::array & d) const + auto rowMul(std::index_sequence, const std::array & d) const { return ((_covariance(R, Is) * d[Is]) + ...); } @@ -973,4 +967,27 @@ makeStandardDeviation(const T & f, const CTMatrix covariance) covariance); } +/** + * Helper function to take a derivative + */ +template +auto +diff(T result) +{ + if constexpr (sizeof...(dtags) == 0) + return result.template D(); + else + return diff(result.template D()); +} + +/** + * Helper function to evaluate a result + */ +template +auto +eval(T result) +{ + return result(); +} + } // namespace CompileTimeDerivatives diff --git a/unit/src/CompileTimeDerivativesTest.C b/unit/src/CompileTimeDerivativesTest.C index bc71bfccd762..7ffdfd6dcbea 100644 --- a/unit/src/CompileTimeDerivativesTest.C +++ b/unit/src/CompileTimeDerivativesTest.C @@ -207,13 +207,18 @@ TEST(CompileTimeDerivativesTest, variable_reference) const auto X = makeRef(x); const auto result = X * X + 100.0; - x = 5; + x = 5.0; EXPECT_EQ(result(), 125.0); EXPECT_EQ(result.D()(), 10.0); - x = 3; + x = 3.0; EXPECT_EQ(result(), 109.0); + EXPECT_EQ(eval(result), 109.0); // alternative syntax for evaluation + EXPECT_EQ(result.D()(), 6.0); + EXPECT_EQ(eval(diff(result)), 6.0); // alternative derivative syntax + + EXPECT_EQ(eval(diff(result)), 2.0); } TEST(CompileTimeDerivativesTest, vector_reference) @@ -323,3 +328,13 @@ TEST(CompileTimeDerivativesTest, conditional) EXPECT_EQ(result(), 5 * vx); } } + +TEST(CompileTimeDerivativesTest, typeRefs) +{ + const RealVectorValue va(1,2,3), vb(4,6,8); + const auto [a, b] = makeRefs<30>(va, vb); + + // matching order + EXPECT_EQ(&va, &a()); + EXPECT_EQ(&vb, &b()); +}