Skip to content

Commit

Permalink
started making State type independent on a specific enum
Browse files Browse the repository at this point in the history
  • Loading branch information
EirikKolas committed Jul 14, 2024
1 parent 28fe538 commit 011476b
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class ConstantPosition : public interface::DynamicModelLTV<2, UNUSED, 2> {
using T = Types_xuv<N_DIM_x, N_DIM_u, N_DIM_v>;

using S = StateName;
using StateT = State<S::position, S::position>;
using StateT = State<S, S::position, S::position>;

/** Constant Position Model in 2D
* x = [x, y]
Expand Down Expand Up @@ -119,7 +119,7 @@ class ConstantVelocity : public interface::DynamicModelLTV<4, UNUSED, 2> {
using T = vortex::Types_xuv<N_DIM_x, N_DIM_u, N_DIM_v>;

using S = StateName;
using StateT = State<S::position, S::position, S::velocity, S::velocity>;
using StateT = State<S, S::position, S::position, S::velocity, S::velocity>;


using Vec_s = Eigen::Matrix<double, N_SPATIAL_DIM, 1>;
Expand Down Expand Up @@ -194,7 +194,7 @@ class ConstantAcceleration : public interface::DynamicModelLTV<3 * 2, UNUSED, 2
using T = vortex::Types_xv<N_STATES, N_DIM_v>;

using S = StateName;
using StateT = State<S::position, S::position, S::velocity, S::velocity, S::acceleration, S::acceleration>;
using StateT = State<S, S::position, S::position, S::velocity, S::velocity, S::acceleration, S::acceleration>;

using Vec_s = Eigen::Matrix<double, N_SPATIAL_DIM, 1>;
using Mat_ss = Eigen::Matrix<double, N_SPATIAL_DIM, N_SPATIAL_DIM>;
Expand Down Expand Up @@ -270,7 +270,7 @@ class CoordinatedTurn : public interface::DynamicModelCTLTV<5, UNUSED, 3> {
using T = vortex::Types_xv<N_DIM_x, N_DIM_v>;

using S = StateName;
using StateT = State<S::position, S::position, S::velocity, S::velocity, S::turn_rate>;
using StateT = State<S, S::position, S::position, S::velocity, S::velocity, S::turn_rate>;

/** (Nearly) Coordinated Turn Model in 2D. (Nearly constant speed, nearly constant turn rate)
* State = [x, y, x_dot, y_dot, omega]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@ template <vortex::concepts::DynamicModelWithDefinedSizes... DynModels> class Imm
static constexpr bool MIN_DIM_x = std::min(N_DIMS_x);
static constexpr size_t N_MODELS = sizeof...(DynModels);

using StateName = decltype(DynModels::StateT::STATE_NAMES);
using StateNames = std::tuple<std::array<StateName, DynModels::N_DIM_x>...>;

static constexpr StateNames ALL_STATE_NAMES = {{DynModels::StateT::STATE_NAMES}...};


using DynModTuple = std::tuple<DynModels...>;
using GaussTuple_x = std::tuple<typename Types_x<DynModels::N_DIM_x>::Gauss_x...>;
using Vec_n = Eigen::Vector<double, N_MODELS>;
Expand Down
27 changes: 18 additions & 9 deletions vortex-filtering/include/vortex_filtering/models/state.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <array>
#include <iostream>
#include <map>
#include <type_traits>
#include <vortex_filtering/probability/multi_var_gauss.hpp>
#include <vortex_filtering/types/type_aliases.hpp>

Expand All @@ -17,8 +18,13 @@ struct StateMinMax {
double min;
double max;
};

template <typename StateName>
requires std::is_integral_v<StateName> || std::is_enum_v<StateName>
using StateMap = std::map<StateName, StateMinMax>;

template <typename StateName>
requires std::is_integral_v<StateName> || std::is_enum_v<StateName>
struct StateLocation {
StateName name;
size_t start_index;
Expand All @@ -27,21 +33,24 @@ struct StateLocation {
bool operator==(const StateLocation &other) const { return name == other.name && start_index == other.start_index && end_index == other.end_index; }
};

template <typename R> constexpr auto index_of(const R &range, StateName needle)
template <typename T, typename R> constexpr auto index_of(const R &range, T needle)
{
auto it = std::ranges::find(range, needle);
if (it == std::ranges::end(range))
throw std::logic_error("Element not found!");
return std::ranges::distance(std::ranges::begin(range), it);
}

template <StateName... Sn> class State : public vortex::prob::MultiVarGauss<sizeof...(Sn)> {
template <typename StateName, StateName... Sn>
requires std::is_integral_v<StateName> || std::is_enum_v<StateName>
class State : public vortex::prob::MultiVarGauss<sizeof...(Sn)> {
public:
static constexpr size_t N_STATES = sizeof...(Sn);

static constexpr std::array<StateName, N_STATES> STATE_NAMES = {Sn...};

using T = vortex::Types_n<N_STATES>;
using StateLoc = StateLocation<StateName>;

State(const T::Vec_n &mean, const T::Mat_nn &cov)
: vortex::prob::MultiVarGauss<N_STATES>(mean, cov)
Expand Down Expand Up @@ -88,8 +97,8 @@ template <StateName... Sn> class State : public vortex::prob::MultiVarGauss<size
return unique_state_names;
}();

static constexpr std::array<StateLocation, UNIQUE_STATES_COUNT> STATE_MAP = []() {
std::array<StateLocation, UNIQUE_STATES_COUNT> state_map = {};
static constexpr std::array<StateLoc, UNIQUE_STATES_COUNT> STATE_MAP = []() {
std::array<StateLoc, UNIQUE_STATES_COUNT> state_map = {};

size_t start_index = 0;
size_t map_index = 0;
Expand All @@ -107,7 +116,7 @@ template <StateName... Sn> class State : public vortex::prob::MultiVarGauss<size
static constexpr bool has_state_name(StateName S) { return std::find(UNIQUE_STATE_NAMES.begin(), UNIQUE_STATE_NAMES.end(), S) != UNIQUE_STATE_NAMES.end(); }

public:
static constexpr StateLocation state_loc(StateName S) { return STATE_MAP[index_of(UNIQUE_STATE_NAMES, S)]; }
static constexpr StateLoc state_loc(StateName S) { return STATE_MAP[index_of(UNIQUE_STATE_NAMES, S)]; }

template <StateName S>
requires(has_state_name(S))
Expand All @@ -117,31 +126,31 @@ template <StateName... Sn> class State : public vortex::prob::MultiVarGauss<size
requires(has_state_name(S))
T_n<S>::Vec_n mean_of() const
{
constexpr StateLocation sm = state_loc(S);
constexpr StateLoc sm = state_loc(S);
return this->mean().template segment<sm.size()>(sm.start_index);
}

template <StateName S>
requires(has_state_name(S))
void set_mean_of(const T_n<S>::Vec_n &mean)
{
constexpr StateLocation sm = state_loc(S);
constexpr StateLoc sm = state_loc(S);
this->mean().template segment<sm.size()>(sm.start_index) = mean;
}

template <StateName S>
requires(has_state_name(S))
T_n<S>::Mat_nn cov_of() const
{
constexpr StateLocation sm = state_loc(S);
constexpr StateLoc sm = state_loc(S);
return this->cov().template block<sm.size(), sm.size()>(sm.start_index, sm.start_index);
}

template <StateName S>
requires(has_state_name(S))
void set_cov_of(const T_n<S>::Mat_nn &cov)
{
constexpr StateLocation sm = state_loc(S);
constexpr StateLoc sm = state_loc(S);
this->cov().template block<sm.size(), sm.size()>(sm.start_index, sm.start_index) = cov;
}

Expand Down
15 changes: 8 additions & 7 deletions vortex-filtering/test/state_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ TEST(State, typeChecks)
using namespace vortex;

using S = StateName;
using StateT = State<S::position, S::velocity, S::velocity, S::acceleration>;
using StateT = State<S, S::position, S::velocity, S::velocity, S::acceleration>;

ASSERT_EQ(StateT::N_STATES, 4);
ASSERT_EQ(StateT::UNIQUE_STATES_COUNT, 3);
Expand All @@ -31,7 +31,7 @@ TEST(State, init)
using namespace vortex;

using S = StateName;
using StateT = State<S::position, S::velocity, S::velocity, S::acceleration>;
using StateT = State<S, S::position, S::velocity, S::velocity, S::acceleration>;

auto x = prob::Gauss4d::Standard();
StateT state(x);
Expand All @@ -46,7 +46,7 @@ TEST(State, getMean)
using namespace vortex;

using S = StateName;
using StateT = State<S::position, S::velocity, S::velocity, S::acceleration>;
using StateT = State<S, S::position, S::velocity, S::velocity, S::acceleration>;

auto x = prob::Gauss4d::Standard();
StateT state(x);
Expand All @@ -61,7 +61,7 @@ TEST(State, getCov)
using namespace vortex;

using S = StateName;
using StateT = State<S::position, S::velocity, S::velocity, S::acceleration>;
using StateT = State<S, S::position, S::velocity, S::velocity, S::acceleration>;

auto x = prob::Gauss4d::Standard();
StateT state(x);
Expand All @@ -76,14 +76,15 @@ TEST(State, setMean)
using namespace vortex;

using S = StateName;
using StateT = State<S::position, S::velocity, S::velocity, S::acceleration>;
using StateT = State<S, S::position, S::velocity, S::velocity, S::acceleration>;

auto x = prob::Gauss4d::Standard();
StateT state(x);

StateT::T::Vec_n mean = StateT::T::Vec_n::Random();
StateT::T::Mat_nn cov = StateT::T::Mat_nn::Random();
cov = 0.5 * (cov + cov.transpose()).eval();

cov = 0.5 * (cov + cov.transpose()).eval();
cov += StateT::T::Mat_nn::Identity() * StateT::N_STATES;

StateT::T::Gauss_n x_new = {mean, cov};
Expand All @@ -99,7 +100,7 @@ TEST(State, setCov)
using namespace vortex;

using S = StateName;
using StateT = State<S::position, S::velocity, S::velocity, S::acceleration>;
using StateT = State<S, S::position, S::velocity, S::velocity, S::acceleration>;

auto x = prob::Gauss4d::Standard();
StateT state(x);
Expand Down
2 changes: 1 addition & 1 deletion vortex-filtering/test/types_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ TEST(Concepts, MultiVarGaussLike)
static_assert(!vortex::concepts::MultiVarGaussLike<Gauss2d, 3>);

using S = vortex::StateName;
using StateT = vortex::State<S::position, S::position, S::velocity, S::velocity>;
using StateT = vortex::State<S, S::position, S::position, S::velocity, S::velocity>;
static_assert(vortex::concepts::MultiVarGaussLike<StateT, StateT::N_STATES>);

ASSERT_TRUE(true);
Expand Down

0 comments on commit 011476b

Please sign in to comment.