-
Notifications
You must be signed in to change notification settings - Fork 56
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[trajoptlib] Upgrade Sleipnir for faster autodiff and less memory usa…
…ge (#1136)
- Loading branch information
Showing
2 changed files
with
244 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,18 +1,19 @@ | ||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 | ||
From: Tyler Veness <[email protected]> | ||
Date: Thu, 20 Jun 2024 12:30:54 -0700 | ||
Date: Wed, 8 Jan 2025 23:39:34 -0800 | ||
Subject: [PATCH] Downgrade to C++20 | ||
|
||
--- | ||
.styleguide | 1 + | ||
CMakeLists.txt | 10 ++++++++++ | ||
cmake/modules/CompilerFlags.cmake | 2 +- | ||
include/.styleguide | 1 + | ||
include/sleipnir/util/Print.hpp | 27 ++++++++++++++------------- | ||
5 files changed, 27 insertions(+), 14 deletions(-) | ||
.styleguide | 1 + | ||
CMakeLists.txt | 10 +++++ | ||
cmake/modules/CompilerFlags.cmake | 2 +- | ||
include/.styleguide | 1 + | ||
include/sleipnir/autodiff/Expression.hpp | 57 +++++++++++------------- | ||
include/sleipnir/util/Print.hpp | 27 +++++------ | ||
6 files changed, 53 insertions(+), 45 deletions(-) | ||
|
||
diff --git a/.styleguide b/.styleguide | ||
index fc51b044a61dd842dca50d013c52baed00612545..2476200a8763368007c3d4029bf326619a7bf0fa 100644 | ||
index 2cf272a115102886134455a87ea5c7dcf7188283..76e25473e77ab7838509eaeeefa7170a06507f8c 100644 | ||
--- a/.styleguide | ||
+++ b/.styleguide | ||
@@ -17,6 +17,7 @@ modifiableFileExclude { | ||
|
@@ -24,7 +25,7 @@ index fc51b044a61dd842dca50d013c52baed00612545..2476200a8763368007c3d4029bf32661 | |
^sleipnir/ | ||
} | ||
diff --git a/CMakeLists.txt b/CMakeLists.txt | ||
index e9668b53d6a0317cf2367b212f5013cbb2aa7ad5..1db191b6492a5f10b657c42c7959a73ef72724e3 100644 | ||
index 04567acbda381aa07914214cbca0a4419fc065ed..5efa843f6e38195fe0717e971ca99cfce0137b68 100644 | ||
--- a/CMakeLists.txt | ||
+++ b/CMakeLists.txt | ||
@@ -114,6 +114,16 @@ endif() | ||
|
@@ -45,7 +46,7 @@ index e9668b53d6a0317cf2367b212f5013cbb2aa7ad5..1db191b6492a5f10b657c42c7959a73e | |
# Catch2 dependency | ||
fetchcontent_declare( | ||
diff --git a/cmake/modules/CompilerFlags.cmake b/cmake/modules/CompilerFlags.cmake | ||
index 9f3fad1ce359897d6e87d5c37efae51d55bcf160..4331baf9394a7f27bde51906c9dd14b8543e4af4 100644 | ||
index bab9de4de2c79abb3195cdd8b3a8b56da665e900..8caa4b1d7c7f4be463b2541fba91561428193d51 100644 | ||
--- a/cmake/modules/CompilerFlags.cmake | ||
+++ b/cmake/modules/CompilerFlags.cmake | ||
@@ -22,7 +22,7 @@ macro(compiler_flags target) | ||
|
@@ -67,12 +68,240 @@ index 8fb61fdf9cc5ceff633d3126f0579eef25a1326f..6a7f8ed28f9cb037c9746a7e0ef5e110 | |
^Eigen/ | ||
+ ^fmt/ | ||
} | ||
diff --git a/include/sleipnir/autodiff/Expression.hpp b/include/sleipnir/autodiff/Expression.hpp | ||
index d418bdb201aa01a6ee39c3b2442b717d5e21b137..11ca0149e914b3d194b5a207bea5a72273320669 100644 | ||
--- a/include/sleipnir/autodiff/Expression.hpp | ||
+++ b/include/sleipnir/autodiff/Expression.hpp | ||
@@ -371,8 +371,8 @@ struct BinaryMinusExpression : Expression { | ||
* @param lhs Binary operator's left operand. | ||
* @param rhs Binary operator's right operand. | ||
*/ | ||
- constexpr BinaryMinusExpression(ExpressionType type, ExpressionPtr lhs, | ||
- ExpressionPtr rhs) | ||
+ BinaryMinusExpression(ExpressionType type, ExpressionPtr lhs, | ||
+ ExpressionPtr rhs) | ||
: Expression{type, std::move(lhs), std::move(rhs)} { | ||
value = BinaryMinusExpression::Value(args[0]->value, args[1]->value); | ||
} | ||
@@ -406,8 +406,8 @@ struct BinaryPlusExpression : Expression { | ||
* @param lhs Binary operator's left operand. | ||
* @param rhs Binary operator's right operand. | ||
*/ | ||
- constexpr BinaryPlusExpression(ExpressionType type, ExpressionPtr lhs, | ||
- ExpressionPtr rhs) | ||
+ BinaryPlusExpression(ExpressionType type, ExpressionPtr lhs, | ||
+ ExpressionPtr rhs) | ||
: Expression{type, std::move(lhs), std::move(rhs)} { | ||
value = BinaryPlusExpression::Value(args[0]->value, args[1]->value); | ||
} | ||
@@ -475,8 +475,7 @@ struct DivExpression : Expression { | ||
* @param lhs Binary operator's left operand. | ||
* @param rhs Binary operator's right operand. | ||
*/ | ||
- constexpr DivExpression(ExpressionType type, ExpressionPtr lhs, | ||
- ExpressionPtr rhs) | ||
+ DivExpression(ExpressionType type, ExpressionPtr lhs, ExpressionPtr rhs) | ||
: Expression{type, std::move(lhs), std::move(rhs)} { | ||
value = DivExpression::Value(args[0]->value, args[1]->value); | ||
} | ||
@@ -512,8 +511,7 @@ struct MultExpression : Expression { | ||
* @param lhs Binary operator's left operand. | ||
* @param rhs Binary operator's right operand. | ||
*/ | ||
- constexpr MultExpression(ExpressionType type, ExpressionPtr lhs, | ||
- ExpressionPtr rhs) | ||
+ MultExpression(ExpressionType type, ExpressionPtr lhs, ExpressionPtr rhs) | ||
: Expression{type, std::move(lhs), std::move(rhs)} { | ||
value = MultExpression::Value(args[0]->value, args[1]->value); | ||
} | ||
@@ -550,7 +548,7 @@ struct UnaryMinusExpression : Expression { | ||
* @param type The expression's type. | ||
* @param lhs Unary operator's operand. | ||
*/ | ||
- constexpr UnaryMinusExpression(ExpressionType type, ExpressionPtr lhs) | ||
+ UnaryMinusExpression(ExpressionType type, ExpressionPtr lhs) | ||
: Expression{type, std::move(lhs)} { | ||
value = UnaryMinusExpression::Value(args[0]->value, 0.0); | ||
} | ||
@@ -628,7 +626,7 @@ struct AbsExpression : Expression { | ||
* @param type The expression's type. | ||
* @param lhs Unary operator's operand. | ||
*/ | ||
- constexpr AbsExpression(ExpressionType type, ExpressionPtr lhs) | ||
+ AbsExpression(ExpressionType type, ExpressionPtr lhs) | ||
: Expression{type, std::move(lhs)} { | ||
value = AbsExpression::Value(args[0]->value, 0.0); | ||
} | ||
@@ -688,7 +686,7 @@ struct AcosExpression : Expression { | ||
* @param type The expression's type. | ||
* @param lhs Unary operator's operand. | ||
*/ | ||
- constexpr AcosExpression(ExpressionType type, ExpressionPtr lhs) | ||
+ AcosExpression(ExpressionType type, ExpressionPtr lhs) | ||
: Expression{type, std::move(lhs)} { | ||
value = AcosExpression::Value(args[0]->value, 0.0); | ||
} | ||
@@ -736,7 +734,7 @@ struct AsinExpression : Expression { | ||
* @param type The expression's type. | ||
* @param lhs Unary operator's operand. | ||
*/ | ||
- constexpr AsinExpression(ExpressionType type, ExpressionPtr lhs) | ||
+ AsinExpression(ExpressionType type, ExpressionPtr lhs) | ||
: Expression{type, std::move(lhs)} { | ||
value = AsinExpression::Value(args[0]->value, 0.0); | ||
} | ||
@@ -784,7 +782,7 @@ struct AtanExpression : Expression { | ||
* @param type The expression's type. | ||
* @param lhs Unary operator's operand. | ||
*/ | ||
- constexpr AtanExpression(ExpressionType type, ExpressionPtr lhs) | ||
+ AtanExpression(ExpressionType type, ExpressionPtr lhs) | ||
: Expression{type, std::move(lhs)} { | ||
value = AtanExpression::Value(args[0]->value, 0.0); | ||
} | ||
@@ -832,8 +830,7 @@ struct Atan2Expression : Expression { | ||
* @param lhs Binary operator's left operand. | ||
* @param rhs Binary operator's right operand. | ||
*/ | ||
- constexpr Atan2Expression(ExpressionType type, ExpressionPtr lhs, | ||
- ExpressionPtr rhs) | ||
+ Atan2Expression(ExpressionType type, ExpressionPtr lhs, ExpressionPtr rhs) | ||
: Expression{type, std::move(lhs), std::move(rhs)} { | ||
value = Atan2Expression::Value(args[0]->value, args[1]->value); | ||
} | ||
@@ -893,7 +890,7 @@ struct CosExpression : Expression { | ||
* @param type The expression's type. | ||
* @param lhs Unary operator's operand. | ||
*/ | ||
- constexpr CosExpression(ExpressionType type, ExpressionPtr lhs) | ||
+ CosExpression(ExpressionType type, ExpressionPtr lhs) | ||
: Expression{type, std::move(lhs)} { | ||
value = CosExpression::Value(args[0]->value, 0.0); | ||
} | ||
@@ -939,7 +936,7 @@ struct CoshExpression : Expression { | ||
* @param type The expression's type. | ||
* @param lhs Unary operator's operand. | ||
*/ | ||
- constexpr CoshExpression(ExpressionType type, ExpressionPtr lhs) | ||
+ CoshExpression(ExpressionType type, ExpressionPtr lhs) | ||
: Expression{type, std::move(lhs)} { | ||
value = CoshExpression::Value(args[0]->value, 0.0); | ||
} | ||
@@ -985,7 +982,7 @@ struct ErfExpression : Expression { | ||
* @param type The expression's type. | ||
* @param lhs Unary operator's operand. | ||
*/ | ||
- constexpr ErfExpression(ExpressionType type, ExpressionPtr lhs) | ||
+ ErfExpression(ExpressionType type, ExpressionPtr lhs) | ||
: Expression{type, std::move(lhs)} { | ||
value = ErfExpression::Value(args[0]->value, 0.0); | ||
} | ||
@@ -1034,7 +1031,7 @@ struct ExpExpression : Expression { | ||
* @param type The expression's type. | ||
* @param lhs Unary operator's operand. | ||
*/ | ||
- constexpr ExpExpression(ExpressionType type, ExpressionPtr lhs) | ||
+ ExpExpression(ExpressionType type, ExpressionPtr lhs) | ||
: Expression{type, std::move(lhs)} { | ||
value = ExpExpression::Value(args[0]->value, 0.0); | ||
} | ||
@@ -1083,8 +1080,7 @@ struct HypotExpression : Expression { | ||
* @param lhs Binary operator's left operand. | ||
* @param rhs Binary operator's right operand. | ||
*/ | ||
- constexpr HypotExpression(ExpressionType type, ExpressionPtr lhs, | ||
- ExpressionPtr rhs) | ||
+ HypotExpression(ExpressionType type, ExpressionPtr lhs, ExpressionPtr rhs) | ||
: Expression{type, std::move(lhs), std::move(rhs)} { | ||
value = HypotExpression::Value(args[0]->value, args[1]->value); | ||
} | ||
@@ -1143,7 +1139,7 @@ struct LogExpression : Expression { | ||
* @param type The expression's type. | ||
* @param lhs Unary operator's operand. | ||
*/ | ||
- constexpr LogExpression(ExpressionType type, ExpressionPtr lhs) | ||
+ LogExpression(ExpressionType type, ExpressionPtr lhs) | ||
: Expression{type, std::move(lhs)} { | ||
value = LogExpression::Value(args[0]->value, 0.0); | ||
} | ||
@@ -1190,7 +1186,7 @@ struct Log10Expression : Expression { | ||
* @param type The expression's type. | ||
* @param lhs Unary operator's operand. | ||
*/ | ||
- constexpr Log10Expression(ExpressionType type, ExpressionPtr lhs) | ||
+ Log10Expression(ExpressionType type, ExpressionPtr lhs) | ||
: Expression{type, std::move(lhs)} { | ||
value = Log10Expression::Value(args[0]->value, 0.0); | ||
} | ||
@@ -1241,8 +1237,7 @@ struct PowExpression : Expression { | ||
* @param lhs Binary operator's left operand. | ||
* @param rhs Binary operator's right operand. | ||
*/ | ||
- constexpr PowExpression(ExpressionType type, ExpressionPtr lhs, | ||
- ExpressionPtr rhs) | ||
+ PowExpression(ExpressionType type, ExpressionPtr lhs, ExpressionPtr rhs) | ||
: Expression{type, std::move(lhs), std::move(rhs)} { | ||
value = PowExpression::Value(args[0]->value, args[1]->value); | ||
} | ||
@@ -1332,7 +1327,7 @@ struct SignExpression : Expression { | ||
* @param type The expression's type. | ||
* @param lhs Unary operator's operand. | ||
*/ | ||
- constexpr SignExpression(ExpressionType type, ExpressionPtr lhs) | ||
+ SignExpression(ExpressionType type, ExpressionPtr lhs) | ||
: Expression{type, std::move(lhs)} { | ||
value = SignExpression::Value(args[0]->value, 0.0); | ||
} | ||
@@ -1386,7 +1381,7 @@ struct SinExpression : Expression { | ||
* @param type The expression's type. | ||
* @param lhs Unary operator's operand. | ||
*/ | ||
- constexpr SinExpression(ExpressionType type, ExpressionPtr lhs) | ||
+ SinExpression(ExpressionType type, ExpressionPtr lhs) | ||
: Expression{type, std::move(lhs)} { | ||
value = SinExpression::Value(args[0]->value, 0.0); | ||
} | ||
@@ -1433,7 +1428,7 @@ struct SinhExpression : Expression { | ||
* @param type The expression's type. | ||
* @param lhs Unary operator's operand. | ||
*/ | ||
- constexpr SinhExpression(ExpressionType type, ExpressionPtr lhs) | ||
+ SinhExpression(ExpressionType type, ExpressionPtr lhs) | ||
: Expression{type, std::move(lhs)} { | ||
value = SinhExpression::Value(args[0]->value, 0.0); | ||
} | ||
@@ -1480,7 +1475,7 @@ struct SqrtExpression : Expression { | ||
* @param type The expression's type. | ||
* @param lhs Unary operator's operand. | ||
*/ | ||
- constexpr SqrtExpression(ExpressionType type, ExpressionPtr lhs) | ||
+ SqrtExpression(ExpressionType type, ExpressionPtr lhs) | ||
: Expression{type, std::move(lhs)} { | ||
value = SqrtExpression::Value(args[0]->value, 0.0); | ||
} | ||
@@ -1529,7 +1524,7 @@ struct TanExpression : Expression { | ||
* @param type The expression's type. | ||
* @param lhs Unary operator's operand. | ||
*/ | ||
- constexpr TanExpression(ExpressionType type, ExpressionPtr lhs) | ||
+ TanExpression(ExpressionType type, ExpressionPtr lhs) | ||
: Expression{type, std::move(lhs)} { | ||
value = TanExpression::Value(args[0]->value, 0.0); | ||
} | ||
@@ -1577,7 +1572,7 @@ struct TanhExpression : Expression { | ||
* @param type The expression's type. | ||
* @param lhs Unary operator's operand. | ||
*/ | ||
- constexpr TanhExpression(ExpressionType type, ExpressionPtr lhs) | ||
+ TanhExpression(ExpressionType type, ExpressionPtr lhs) | ||
: Expression{type, std::move(lhs)} { | ||
value = TanhExpression::Value(args[0]->value, 0.0); | ||
} | ||
diff --git a/include/sleipnir/util/Print.hpp b/include/sleipnir/util/Print.hpp | ||
index 339320bce6d017ca85025060ba445b2f025bb225..fcf2e69bfb5a081cd915bdded3caa80cd9c38518 100644 | ||
index a746cb77b70f095bb15f4c493295cb925bc74cd3..484d1b2bec7148c5b9affccbf554c7df2b954cc0 100644 | ||
--- a/include/sleipnir/util/Print.hpp | ||
+++ b/include/sleipnir/util/Print.hpp | ||
@@ -2,53 +2,54 @@ | ||
|
||
@@ -3,52 +3,53 @@ | ||
#pragma once | ||
|
||
#include <cstdio> | ||
|