From 011476b590cd00a0c7cb8119021068525764444d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eirik=20Kol=C3=A5s?= Date: Sat, 13 Jul 2024 10:05:55 +0200 Subject: [PATCH] started making State type independent on a specific enum --- .../models/dynamic_models.hpp | 8 +++--- .../vortex_filtering/models/imm_model.hpp | 2 +- .../include/vortex_filtering/models/state.hpp | 27 ++++++++++++------- vortex-filtering/test/state_test.cpp | 15 ++++++----- vortex-filtering/test/types_test.cpp | 2 +- 5 files changed, 32 insertions(+), 22 deletions(-) diff --git a/vortex-filtering/include/vortex_filtering/models/dynamic_models.hpp b/vortex-filtering/include/vortex_filtering/models/dynamic_models.hpp index 50de10d3..d5217ce4 100644 --- a/vortex-filtering/include/vortex_filtering/models/dynamic_models.hpp +++ b/vortex-filtering/include/vortex_filtering/models/dynamic_models.hpp @@ -59,7 +59,7 @@ class ConstantPosition : public interface::DynamicModelLTV<2, UNUSED, 2> { using T = Types_xuv; using S = StateName; - using StateT = State; + using StateT = State; /** Constant Position Model in 2D * x = [x, y] @@ -119,7 +119,7 @@ class ConstantVelocity : public interface::DynamicModelLTV<4, UNUSED, 2> { using T = vortex::Types_xuv; using S = StateName; - using StateT = State; + using StateT = State; using Vec_s = Eigen::Matrix; @@ -194,7 +194,7 @@ class ConstantAcceleration : public interface::DynamicModelLTV<3 * 2, UNUSED, 2 using T = vortex::Types_xv; using S = StateName; - using StateT = State; + using StateT = State; using Vec_s = Eigen::Matrix; using Mat_ss = Eigen::Matrix; @@ -270,7 +270,7 @@ class CoordinatedTurn : public interface::DynamicModelCTLTV<5, UNUSED, 3> { using T = vortex::Types_xv; using S = StateName; - using StateT = State; + using StateT = State; /** (Nearly) Coordinated Turn Model in 2D. (Nearly constant speed, nearly constant turn rate) * State = [x, y, x_dot, y_dot, omega] diff --git a/vortex-filtering/include/vortex_filtering/models/imm_model.hpp b/vortex-filtering/include/vortex_filtering/models/imm_model.hpp index a14a7251..cbc6f057 100644 --- a/vortex-filtering/include/vortex_filtering/models/imm_model.hpp +++ b/vortex-filtering/include/vortex_filtering/models/imm_model.hpp @@ -47,11 +47,11 @@ template class Imm static constexpr bool MIN_DIM_x = std::min(N_DIMS_x); static constexpr size_t N_MODELS = sizeof...(DynModels); + using StateName = decltype(DynModels::StateT::STATE_NAMES); using StateNames = std::tuple...>; static constexpr StateNames ALL_STATE_NAMES = {{DynModels::StateT::STATE_NAMES}...}; - using DynModTuple = std::tuple; using GaussTuple_x = std::tuple::Gauss_x...>; using Vec_n = Eigen::Vector; diff --git a/vortex-filtering/include/vortex_filtering/models/state.hpp b/vortex-filtering/include/vortex_filtering/models/state.hpp index b018984f..b7777661 100644 --- a/vortex-filtering/include/vortex_filtering/models/state.hpp +++ b/vortex-filtering/include/vortex_filtering/models/state.hpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include @@ -17,8 +18,13 @@ struct StateMinMax { double min; double max; }; + +template + requires std::is_integral_v || std::is_enum_v using StateMap = std::map; +template + requires std::is_integral_v || std::is_enum_v struct StateLocation { StateName name; size_t start_index; @@ -27,7 +33,7 @@ struct StateLocation { bool operator==(const StateLocation &other) const { return name == other.name && start_index == other.start_index && end_index == other.end_index; } }; -template constexpr auto index_of(const R &range, StateName needle) +template constexpr auto index_of(const R &range, T needle) { auto it = std::ranges::find(range, needle); if (it == std::ranges::end(range)) @@ -35,13 +41,16 @@ template constexpr auto index_of(const R &range, StateName needle) return std::ranges::distance(std::ranges::begin(range), it); } -template class State : public vortex::prob::MultiVarGauss { +template + requires std::is_integral_v || std::is_enum_v +class State : public vortex::prob::MultiVarGauss { public: static constexpr size_t N_STATES = sizeof...(Sn); static constexpr std::array STATE_NAMES = {Sn...}; using T = vortex::Types_n; + using StateLoc = StateLocation; State(const T::Vec_n &mean, const T::Mat_nn &cov) : vortex::prob::MultiVarGauss(mean, cov) @@ -88,8 +97,8 @@ template class State : public vortex::prob::MultiVarGauss STATE_MAP = []() { - std::array state_map = {}; + static constexpr std::array STATE_MAP = []() { + std::array state_map = {}; size_t start_index = 0; size_t map_index = 0; @@ -107,7 +116,7 @@ template class State : public vortex::prob::MultiVarGauss requires(has_state_name(S)) @@ -117,7 +126,7 @@ template class State : public vortex::prob::MultiVarGauss::Vec_n mean_of() const { - constexpr StateLocation sm = state_loc(S); + constexpr StateLoc sm = state_loc(S); return this->mean().template segment(sm.start_index); } @@ -125,7 +134,7 @@ template class State : public vortex::prob::MultiVarGauss::Vec_n &mean) { - constexpr StateLocation sm = state_loc(S); + constexpr StateLoc sm = state_loc(S); this->mean().template segment(sm.start_index) = mean; } @@ -133,7 +142,7 @@ template class State : public vortex::prob::MultiVarGauss::Mat_nn cov_of() const { - constexpr StateLocation sm = state_loc(S); + constexpr StateLoc sm = state_loc(S); return this->cov().template block(sm.start_index, sm.start_index); } @@ -141,7 +150,7 @@ template class State : public vortex::prob::MultiVarGauss::Mat_nn &cov) { - constexpr StateLocation sm = state_loc(S); + constexpr StateLoc sm = state_loc(S); this->cov().template block(sm.start_index, sm.start_index) = cov; } diff --git a/vortex-filtering/test/state_test.cpp b/vortex-filtering/test/state_test.cpp index 29e5d558..728673d1 100644 --- a/vortex-filtering/test/state_test.cpp +++ b/vortex-filtering/test/state_test.cpp @@ -8,7 +8,7 @@ TEST(State, typeChecks) using namespace vortex; using S = StateName; - using StateT = State; + using StateT = State; ASSERT_EQ(StateT::N_STATES, 4); ASSERT_EQ(StateT::UNIQUE_STATES_COUNT, 3); @@ -31,7 +31,7 @@ TEST(State, init) using namespace vortex; using S = StateName; - using StateT = State; + using StateT = State; auto x = prob::Gauss4d::Standard(); StateT state(x); @@ -46,7 +46,7 @@ TEST(State, getMean) using namespace vortex; using S = StateName; - using StateT = State; + using StateT = State; auto x = prob::Gauss4d::Standard(); StateT state(x); @@ -61,7 +61,7 @@ TEST(State, getCov) using namespace vortex; using S = StateName; - using StateT = State; + using StateT = State; auto x = prob::Gauss4d::Standard(); StateT state(x); @@ -76,14 +76,15 @@ TEST(State, setMean) using namespace vortex; using S = StateName; - using StateT = State; + using StateT = State; auto x = prob::Gauss4d::Standard(); StateT state(x); StateT::T::Vec_n mean = StateT::T::Vec_n::Random(); StateT::T::Mat_nn cov = StateT::T::Mat_nn::Random(); - cov = 0.5 * (cov + cov.transpose()).eval(); + + cov = 0.5 * (cov + cov.transpose()).eval(); cov += StateT::T::Mat_nn::Identity() * StateT::N_STATES; StateT::T::Gauss_n x_new = {mean, cov}; @@ -99,7 +100,7 @@ TEST(State, setCov) using namespace vortex; using S = StateName; - using StateT = State; + using StateT = State; auto x = prob::Gauss4d::Standard(); StateT state(x); diff --git a/vortex-filtering/test/types_test.cpp b/vortex-filtering/test/types_test.cpp index 427404cc..695689d0 100644 --- a/vortex-filtering/test/types_test.cpp +++ b/vortex-filtering/test/types_test.cpp @@ -40,7 +40,7 @@ TEST(Concepts, MultiVarGaussLike) static_assert(!vortex::concepts::MultiVarGaussLike); using S = vortex::StateName; - using StateT = vortex::State; + using StateT = vortex::State; static_assert(vortex::concepts::MultiVarGaussLike); ASSERT_TRUE(true);