Skip to content

Commit

Permalink
[trajoptlib] Upgrade Sleipnir for faster autodiff and less memory usa…
Browse files Browse the repository at this point in the history
…ge (#1136)
  • Loading branch information
calcmogul authored Jan 9, 2025
1 parent 64d2b7d commit f9108be
Show file tree
Hide file tree
Showing 2 changed files with 244 additions and 15 deletions.
4 changes: 2 additions & 2 deletions trajoptlib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ set(BUILD_EXAMPLES OFF)
fetchcontent_declare(
Sleipnir
GIT_REPOSITORY https://github.com/SleipnirGroup/Sleipnir
# main on 2025-01-07
GIT_TAG fb0c46d547cb7c7801cbbbb6bddd5d7f86061a8f
# main on 2025-01-08
GIT_TAG 3b3df2fe485743ab73f395158d0225ad87d9e67c
PATCH_COMMAND
git apply ${CMAKE_CURRENT_SOURCE_DIR}/cmake/0001-Downgrade-to-C-20.patch
UPDATE_DISCONNECTED 1
Expand Down
255 changes: 242 additions & 13 deletions trajoptlib/cmake/0001-Downgrade-to-C-20.patch
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Tyler Veness <[email protected]>
Date: Thu, 20 Jun 2024 12:30:54 -0700
Date: Wed, 8 Jan 2025 23:39:34 -0800
Subject: [PATCH] Downgrade to C++20

---
.styleguide | 1 +
CMakeLists.txt | 10 ++++++++++
cmake/modules/CompilerFlags.cmake | 2 +-
include/.styleguide | 1 +
include/sleipnir/util/Print.hpp | 27 ++++++++++++++-------------
5 files changed, 27 insertions(+), 14 deletions(-)
.styleguide | 1 +
CMakeLists.txt | 10 +++++
cmake/modules/CompilerFlags.cmake | 2 +-
include/.styleguide | 1 +
include/sleipnir/autodiff/Expression.hpp | 57 +++++++++++-------------
include/sleipnir/util/Print.hpp | 27 +++++------
6 files changed, 53 insertions(+), 45 deletions(-)

diff --git a/.styleguide b/.styleguide
index fc51b044a61dd842dca50d013c52baed00612545..2476200a8763368007c3d4029bf326619a7bf0fa 100644
index 2cf272a115102886134455a87ea5c7dcf7188283..76e25473e77ab7838509eaeeefa7170a06507f8c 100644
--- a/.styleguide
+++ b/.styleguide
@@ -17,6 +17,7 @@ modifiableFileExclude {
Expand All @@ -24,7 +25,7 @@ index fc51b044a61dd842dca50d013c52baed00612545..2476200a8763368007c3d4029bf32661
^sleipnir/
}
diff --git a/CMakeLists.txt b/CMakeLists.txt
index e9668b53d6a0317cf2367b212f5013cbb2aa7ad5..1db191b6492a5f10b657c42c7959a73ef72724e3 100644
index 04567acbda381aa07914214cbca0a4419fc065ed..5efa843f6e38195fe0717e971ca99cfce0137b68 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -114,6 +114,16 @@ endif()
Expand All @@ -45,7 +46,7 @@ index e9668b53d6a0317cf2367b212f5013cbb2aa7ad5..1db191b6492a5f10b657c42c7959a73e
# Catch2 dependency
fetchcontent_declare(
diff --git a/cmake/modules/CompilerFlags.cmake b/cmake/modules/CompilerFlags.cmake
index 9f3fad1ce359897d6e87d5c37efae51d55bcf160..4331baf9394a7f27bde51906c9dd14b8543e4af4 100644
index bab9de4de2c79abb3195cdd8b3a8b56da665e900..8caa4b1d7c7f4be463b2541fba91561428193d51 100644
--- a/cmake/modules/CompilerFlags.cmake
+++ b/cmake/modules/CompilerFlags.cmake
@@ -22,7 +22,7 @@ macro(compiler_flags target)
Expand All @@ -67,12 +68,240 @@ index 8fb61fdf9cc5ceff633d3126f0579eef25a1326f..6a7f8ed28f9cb037c9746a7e0ef5e110
^Eigen/
+ ^fmt/
}
diff --git a/include/sleipnir/autodiff/Expression.hpp b/include/sleipnir/autodiff/Expression.hpp
index d418bdb201aa01a6ee39c3b2442b717d5e21b137..11ca0149e914b3d194b5a207bea5a72273320669 100644
--- a/include/sleipnir/autodiff/Expression.hpp
+++ b/include/sleipnir/autodiff/Expression.hpp
@@ -371,8 +371,8 @@ struct BinaryMinusExpression : Expression {
* @param lhs Binary operator's left operand.
* @param rhs Binary operator's right operand.
*/
- constexpr BinaryMinusExpression(ExpressionType type, ExpressionPtr lhs,
- ExpressionPtr rhs)
+ BinaryMinusExpression(ExpressionType type, ExpressionPtr lhs,
+ ExpressionPtr rhs)
: Expression{type, std::move(lhs), std::move(rhs)} {
value = BinaryMinusExpression::Value(args[0]->value, args[1]->value);
}
@@ -406,8 +406,8 @@ struct BinaryPlusExpression : Expression {
* @param lhs Binary operator's left operand.
* @param rhs Binary operator's right operand.
*/
- constexpr BinaryPlusExpression(ExpressionType type, ExpressionPtr lhs,
- ExpressionPtr rhs)
+ BinaryPlusExpression(ExpressionType type, ExpressionPtr lhs,
+ ExpressionPtr rhs)
: Expression{type, std::move(lhs), std::move(rhs)} {
value = BinaryPlusExpression::Value(args[0]->value, args[1]->value);
}
@@ -475,8 +475,7 @@ struct DivExpression : Expression {
* @param lhs Binary operator's left operand.
* @param rhs Binary operator's right operand.
*/
- constexpr DivExpression(ExpressionType type, ExpressionPtr lhs,
- ExpressionPtr rhs)
+ DivExpression(ExpressionType type, ExpressionPtr lhs, ExpressionPtr rhs)
: Expression{type, std::move(lhs), std::move(rhs)} {
value = DivExpression::Value(args[0]->value, args[1]->value);
}
@@ -512,8 +511,7 @@ struct MultExpression : Expression {
* @param lhs Binary operator's left operand.
* @param rhs Binary operator's right operand.
*/
- constexpr MultExpression(ExpressionType type, ExpressionPtr lhs,
- ExpressionPtr rhs)
+ MultExpression(ExpressionType type, ExpressionPtr lhs, ExpressionPtr rhs)
: Expression{type, std::move(lhs), std::move(rhs)} {
value = MultExpression::Value(args[0]->value, args[1]->value);
}
@@ -550,7 +548,7 @@ struct UnaryMinusExpression : Expression {
* @param type The expression's type.
* @param lhs Unary operator's operand.
*/
- constexpr UnaryMinusExpression(ExpressionType type, ExpressionPtr lhs)
+ UnaryMinusExpression(ExpressionType type, ExpressionPtr lhs)
: Expression{type, std::move(lhs)} {
value = UnaryMinusExpression::Value(args[0]->value, 0.0);
}
@@ -628,7 +626,7 @@ struct AbsExpression : Expression {
* @param type The expression's type.
* @param lhs Unary operator's operand.
*/
- constexpr AbsExpression(ExpressionType type, ExpressionPtr lhs)
+ AbsExpression(ExpressionType type, ExpressionPtr lhs)
: Expression{type, std::move(lhs)} {
value = AbsExpression::Value(args[0]->value, 0.0);
}
@@ -688,7 +686,7 @@ struct AcosExpression : Expression {
* @param type The expression's type.
* @param lhs Unary operator's operand.
*/
- constexpr AcosExpression(ExpressionType type, ExpressionPtr lhs)
+ AcosExpression(ExpressionType type, ExpressionPtr lhs)
: Expression{type, std::move(lhs)} {
value = AcosExpression::Value(args[0]->value, 0.0);
}
@@ -736,7 +734,7 @@ struct AsinExpression : Expression {
* @param type The expression's type.
* @param lhs Unary operator's operand.
*/
- constexpr AsinExpression(ExpressionType type, ExpressionPtr lhs)
+ AsinExpression(ExpressionType type, ExpressionPtr lhs)
: Expression{type, std::move(lhs)} {
value = AsinExpression::Value(args[0]->value, 0.0);
}
@@ -784,7 +782,7 @@ struct AtanExpression : Expression {
* @param type The expression's type.
* @param lhs Unary operator's operand.
*/
- constexpr AtanExpression(ExpressionType type, ExpressionPtr lhs)
+ AtanExpression(ExpressionType type, ExpressionPtr lhs)
: Expression{type, std::move(lhs)} {
value = AtanExpression::Value(args[0]->value, 0.0);
}
@@ -832,8 +830,7 @@ struct Atan2Expression : Expression {
* @param lhs Binary operator's left operand.
* @param rhs Binary operator's right operand.
*/
- constexpr Atan2Expression(ExpressionType type, ExpressionPtr lhs,
- ExpressionPtr rhs)
+ Atan2Expression(ExpressionType type, ExpressionPtr lhs, ExpressionPtr rhs)
: Expression{type, std::move(lhs), std::move(rhs)} {
value = Atan2Expression::Value(args[0]->value, args[1]->value);
}
@@ -893,7 +890,7 @@ struct CosExpression : Expression {
* @param type The expression's type.
* @param lhs Unary operator's operand.
*/
- constexpr CosExpression(ExpressionType type, ExpressionPtr lhs)
+ CosExpression(ExpressionType type, ExpressionPtr lhs)
: Expression{type, std::move(lhs)} {
value = CosExpression::Value(args[0]->value, 0.0);
}
@@ -939,7 +936,7 @@ struct CoshExpression : Expression {
* @param type The expression's type.
* @param lhs Unary operator's operand.
*/
- constexpr CoshExpression(ExpressionType type, ExpressionPtr lhs)
+ CoshExpression(ExpressionType type, ExpressionPtr lhs)
: Expression{type, std::move(lhs)} {
value = CoshExpression::Value(args[0]->value, 0.0);
}
@@ -985,7 +982,7 @@ struct ErfExpression : Expression {
* @param type The expression's type.
* @param lhs Unary operator's operand.
*/
- constexpr ErfExpression(ExpressionType type, ExpressionPtr lhs)
+ ErfExpression(ExpressionType type, ExpressionPtr lhs)
: Expression{type, std::move(lhs)} {
value = ErfExpression::Value(args[0]->value, 0.0);
}
@@ -1034,7 +1031,7 @@ struct ExpExpression : Expression {
* @param type The expression's type.
* @param lhs Unary operator's operand.
*/
- constexpr ExpExpression(ExpressionType type, ExpressionPtr lhs)
+ ExpExpression(ExpressionType type, ExpressionPtr lhs)
: Expression{type, std::move(lhs)} {
value = ExpExpression::Value(args[0]->value, 0.0);
}
@@ -1083,8 +1080,7 @@ struct HypotExpression : Expression {
* @param lhs Binary operator's left operand.
* @param rhs Binary operator's right operand.
*/
- constexpr HypotExpression(ExpressionType type, ExpressionPtr lhs,
- ExpressionPtr rhs)
+ HypotExpression(ExpressionType type, ExpressionPtr lhs, ExpressionPtr rhs)
: Expression{type, std::move(lhs), std::move(rhs)} {
value = HypotExpression::Value(args[0]->value, args[1]->value);
}
@@ -1143,7 +1139,7 @@ struct LogExpression : Expression {
* @param type The expression's type.
* @param lhs Unary operator's operand.
*/
- constexpr LogExpression(ExpressionType type, ExpressionPtr lhs)
+ LogExpression(ExpressionType type, ExpressionPtr lhs)
: Expression{type, std::move(lhs)} {
value = LogExpression::Value(args[0]->value, 0.0);
}
@@ -1190,7 +1186,7 @@ struct Log10Expression : Expression {
* @param type The expression's type.
* @param lhs Unary operator's operand.
*/
- constexpr Log10Expression(ExpressionType type, ExpressionPtr lhs)
+ Log10Expression(ExpressionType type, ExpressionPtr lhs)
: Expression{type, std::move(lhs)} {
value = Log10Expression::Value(args[0]->value, 0.0);
}
@@ -1241,8 +1237,7 @@ struct PowExpression : Expression {
* @param lhs Binary operator's left operand.
* @param rhs Binary operator's right operand.
*/
- constexpr PowExpression(ExpressionType type, ExpressionPtr lhs,
- ExpressionPtr rhs)
+ PowExpression(ExpressionType type, ExpressionPtr lhs, ExpressionPtr rhs)
: Expression{type, std::move(lhs), std::move(rhs)} {
value = PowExpression::Value(args[0]->value, args[1]->value);
}
@@ -1332,7 +1327,7 @@ struct SignExpression : Expression {
* @param type The expression's type.
* @param lhs Unary operator's operand.
*/
- constexpr SignExpression(ExpressionType type, ExpressionPtr lhs)
+ SignExpression(ExpressionType type, ExpressionPtr lhs)
: Expression{type, std::move(lhs)} {
value = SignExpression::Value(args[0]->value, 0.0);
}
@@ -1386,7 +1381,7 @@ struct SinExpression : Expression {
* @param type The expression's type.
* @param lhs Unary operator's operand.
*/
- constexpr SinExpression(ExpressionType type, ExpressionPtr lhs)
+ SinExpression(ExpressionType type, ExpressionPtr lhs)
: Expression{type, std::move(lhs)} {
value = SinExpression::Value(args[0]->value, 0.0);
}
@@ -1433,7 +1428,7 @@ struct SinhExpression : Expression {
* @param type The expression's type.
* @param lhs Unary operator's operand.
*/
- constexpr SinhExpression(ExpressionType type, ExpressionPtr lhs)
+ SinhExpression(ExpressionType type, ExpressionPtr lhs)
: Expression{type, std::move(lhs)} {
value = SinhExpression::Value(args[0]->value, 0.0);
}
@@ -1480,7 +1475,7 @@ struct SqrtExpression : Expression {
* @param type The expression's type.
* @param lhs Unary operator's operand.
*/
- constexpr SqrtExpression(ExpressionType type, ExpressionPtr lhs)
+ SqrtExpression(ExpressionType type, ExpressionPtr lhs)
: Expression{type, std::move(lhs)} {
value = SqrtExpression::Value(args[0]->value, 0.0);
}
@@ -1529,7 +1524,7 @@ struct TanExpression : Expression {
* @param type The expression's type.
* @param lhs Unary operator's operand.
*/
- constexpr TanExpression(ExpressionType type, ExpressionPtr lhs)
+ TanExpression(ExpressionType type, ExpressionPtr lhs)
: Expression{type, std::move(lhs)} {
value = TanExpression::Value(args[0]->value, 0.0);
}
@@ -1577,7 +1572,7 @@ struct TanhExpression : Expression {
* @param type The expression's type.
* @param lhs Unary operator's operand.
*/
- constexpr TanhExpression(ExpressionType type, ExpressionPtr lhs)
+ TanhExpression(ExpressionType type, ExpressionPtr lhs)
: Expression{type, std::move(lhs)} {
value = TanhExpression::Value(args[0]->value, 0.0);
}
diff --git a/include/sleipnir/util/Print.hpp b/include/sleipnir/util/Print.hpp
index 339320bce6d017ca85025060ba445b2f025bb225..fcf2e69bfb5a081cd915bdded3caa80cd9c38518 100644
index a746cb77b70f095bb15f4c493295cb925bc74cd3..484d1b2bec7148c5b9affccbf554c7df2b954cc0 100644
--- a/include/sleipnir/util/Print.hpp
+++ b/include/sleipnir/util/Print.hpp
@@ -2,53 +2,54 @@

@@ -3,52 +3,53 @@
#pragma once

#include <cstdio>
Expand Down

0 comments on commit f9108be

Please sign in to comment.