Skip to content

Commit

Permalink
refactored ekf
Browse files Browse the repository at this point in the history
  • Loading branch information
EirikKolas committed Mar 10, 2024
1 parent f4eef65 commit d373ba5
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 59 deletions.
92 changes: 33 additions & 59 deletions vortex-filtering/include/vortex_filtering/filters/ekf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <tuple>
#include <vortex_filtering/models/dynamic_model_interfaces.hpp>
#include <vortex_filtering/models/sensor_model_interfaces.hpp>
#include <vortex_filtering/models/type_aliases.hpp>
#include <vortex_filtering/probability/multi_var_gauss.hpp>

namespace vortex::filter {
Expand All @@ -24,41 +25,12 @@ namespace vortex::filter {
template <models::concepts::DynamicModelLTV DynModT, models::concepts::SensorModelLTV SensModT> class EKF {
public:
static constexpr int N_DIM_x = DynModT::DynModI::N_DIM_x;
static constexpr int N_DIM_u = DynModT::DynModI::N_DIM_u;
static constexpr int N_DIM_z = SensModT::SensModI::N_DIM_z;
static constexpr int N_DIM_u = DynModT::DynModI::N_DIM_u;
static constexpr int N_DIM_v = DynModT::DynModI::N_DIM_v;
static constexpr int N_DIM_w = SensModT::SensModI::N_DIM_w;

using Vec_x = Eigen::Vector<double, N_DIM_x>;
using Vec_z = Eigen::Vector<double, N_DIM_z>;
using Vec_u = Eigen::Vector<double, N_DIM_u>;
using Vec_v = Eigen::Vector<double, N_DIM_v>;
using Vec_w = Eigen::Vector<double, N_DIM_w>;

using Mat_xx = Eigen::Matrix<double, N_DIM_x, N_DIM_x>;
using Mat_xz = Eigen::Matrix<double, N_DIM_x, N_DIM_z>;
using Mat_xv = Eigen::Matrix<double, N_DIM_x, N_DIM_v>;
using Mat_xw = Eigen::Matrix<double, N_DIM_x, N_DIM_w>;

using Mat_zx = Eigen::Matrix<double, N_DIM_z, N_DIM_x>;
using Mat_zz = Eigen::Matrix<double, N_DIM_z, N_DIM_z>;
using Mat_zw = Eigen::Matrix<double, N_DIM_z, N_DIM_w>;

using Mat_vx = Eigen::Matrix<double, N_DIM_v, N_DIM_x>;
using Mat_vv = Eigen::Matrix<double, N_DIM_v, N_DIM_v>;
using Mat_vw = Eigen::Matrix<double, N_DIM_v, N_DIM_w>;

using Mat_wx = Eigen::Matrix<double, N_DIM_w, N_DIM_x>;
using Mat_wv = Eigen::Matrix<double, N_DIM_w, N_DIM_v>;
using Mat_ww = Eigen::Matrix<double, N_DIM_w, N_DIM_w>;

using Gauss_x = prob::MultiVarGauss<N_DIM_x>;
using Gauss_z = prob::MultiVarGauss<N_DIM_z>;
using Gauss_v = prob::MultiVarGauss<N_DIM_v>;
using Gauss_w = prob::MultiVarGauss<N_DIM_w>;

using DynModTPtr = std::shared_ptr<DynModT>;
using SensModTPtr = std::shared_ptr<SensModT>;
using T = Types_xzuvw<N_DIM_x, N_DIM_z, N_DIM_u, N_DIM_v, N_DIM_w>;

EKF() = delete;

Expand All @@ -67,41 +39,41 @@ template <models::concepts::DynamicModelLTV DynModT, models::concepts::SensorMod
* @param sens_mod Sensor model
* @param dt Time step
* @param x_est_prev Previous state estimate
* @param u Vec_x Input. Not used, set to zero.
* @return std::pair<Gauss_x, Gauss_z> Predicted state, predicted measurement
* @param u T::Vec_x Input. Not used, set to zero.
* @return std::pair<T::Gauss_x, T::Gauss_z> Predicted state, predicted measurement
* @throws std::runtime_error if dyn_mod or sens_mod are not of the DynamicModelT or SensorModelT type
*/
static std::pair<Gauss_x, Gauss_z> predict(const DynModT &dyn_mod, const SensModT &sens_mod, double dt, const Gauss_x &x_est_prev,
const Vec_u &u = Vec_u::Zero())
static std::pair<typename T::Gauss_x, typename T::Gauss_z> predict(const DynModT &dyn_mod, const SensModT &sens_mod, double dt, const T::Gauss_x &x_est_prev,
const T::Vec_u &u = T::Vec_u::Zero())
{
Gauss_x x_est_pred = dyn_mod.pred_from_est(dt, x_est_prev, u);
Gauss_z z_est_pred = sens_mod.pred_from_est(x_est_pred);
typename T::Gauss_x x_est_pred = dyn_mod.pred_from_est(dt, x_est_prev, u);
typename T::Gauss_z z_est_pred = sens_mod.pred_from_est(x_est_pred);
return {x_est_pred, z_est_pred};
}

/** Perform one EKF update step
* @param sens_mod Sensor model
* @param x_est_pred Predicted state
* @param z_est_pred Predicted measurement
* @param z_meas Vec_z Measurement
* @param z_meas T::Vec_z Measurement
* @return MultivarGauss Updated state
* @throws std::runtime_error ifsens_mod is not of the SensorModelT type
*/
static Gauss_x update(const SensModT &sens_mod, const Gauss_x &x_est_pred, const Gauss_z &z_est_pred, const Vec_z &z_meas)
static T::Gauss_x update(const SensModT &sens_mod, const T::Gauss_x &x_est_pred, const T::Gauss_z &z_est_pred, const T::Vec_z &z_meas)
{
Mat_zx C = sens_mod.C(x_est_pred.mean()); // Measurement matrix
Mat_ww R = sens_mod.R(x_est_pred.mean()); // Measurement noise covariance
Mat_zw H = sens_mod.H(x_est_pred.mean()); // Measurement noise cross-covariance
Mat_xx P = x_est_pred.cov(); // State covariance
Mat_zz S_inv = z_est_pred.cov_inv(); // Inverse of the predicted measurement covariance
Mat_xx I = Mat_xx::Identity(N_DIM_x, N_DIM_x);
typename T::Mat_zx C = sens_mod.C(x_est_pred.mean()); // Measurement matrix
typename T::Mat_ww R = sens_mod.R(x_est_pred.mean()); // Measurement noise covariance
typename T::Mat_zw H = sens_mod.H(x_est_pred.mean()); // Measurement noise cross-covariance
typename T::Mat_xx P = x_est_pred.cov(); // State covariance
typename T::Mat_zz S_inv = z_est_pred.cov_inv(); // Inverse of the predicted measurement covariance
typename T::Mat_xx I = T::Mat_xx::Identity(N_DIM_x, N_DIM_x);

Mat_xz W = P * C.transpose() * S_inv; // Kalman gain
Vec_z innovation = z_meas - z_est_pred.mean();
typename T::Mat_xz W = P * C.transpose() * S_inv; // Kalman gain
typename T::Vec_z innovation = z_meas - z_est_pred.mean();

Vec_x state_upd_mean = x_est_pred.mean() + W * innovation;
typename T::Vec_x state_upd_mean = x_est_pred.mean() + W * innovation;
// Use the Joseph form of the covariance update to ensure positive definiteness
Mat_xx state_upd_cov = (I - W * C) * P * (I - W * C).transpose() + W * H * R * H.transpose() * W.transpose();
typename T::Mat_xx state_upd_cov = (I - W * C) * P * (I - W * C).transpose() + W * H * R * H.transpose() * W.transpose();

return {state_upd_mean, state_upd_cov};
}
Expand All @@ -111,32 +83,34 @@ template <models::concepts::DynamicModelLTV DynModT, models::concepts::SensorMod
* @param sens_mod Sensor model
* @param dt Time step
* @param x_est_prev Previous state estimate
* @param z_meas Vec_z Measurement
* @param u Vec_x Input
* @param z_meas T::Vec_z Measurement
* @param u T::Vec_x Input
* @return Updated state, predicted state, predicted measurement
*/
static std::tuple<Gauss_x, Gauss_x, Gauss_z> step(const DynModT &dyn_mod, const SensModT &sens_mod, double dt, const Gauss_x &x_est_prev, const Vec_z &z_meas,
const Vec_u &u = Vec_u::Zero())
static std::tuple<typename T::Gauss_x, typename T::Gauss_x, typename T::Gauss_z>
step(const DynModT &dyn_mod, const SensModT &sens_mod, double dt, const T::Gauss_x &x_est_prev, const T::Vec_z &z_meas, const T::Vec_u &u = T::Vec_u::Zero())
{
auto [x_est_pred, z_est_pred] = predict(dyn_mod, sens_mod, dt, x_est_prev, u);

Gauss_x x_est_upd = update(sens_mod, x_est_pred, z_est_pred, z_meas);
typename T::Gauss_x x_est_upd = update(sens_mod, x_est_pred, z_est_pred, z_meas);
return {x_est_upd, x_est_pred, z_est_pred};
}

[[deprecated("use const DynModT& and const SensModT&")]] static std::pair<Gauss_x, Gauss_z>
predict(DynModTPtr dyn_mod, SensModTPtr sens_mod, double dt, const Gauss_x &x_est_prev, const Vec_u &u = Vec_u::Zero())
[[deprecated("use const DynModT& and const SensModT&")]] static std::pair<typename T::Gauss_x, typename T::Gauss_z>
predict(std::shared_ptr<DynModT> dyn_mod, std::shared_ptr<SensModT> sens_mod, double dt, const T::Gauss_x &x_est_prev, const T::Vec_u &u = T::Vec_u::Zero())
{
return predict(*dyn_mod, *sens_mod, dt, x_est_prev, u);
}

[[deprecated("use const SensModT&")]] static Gauss_x update(SensModTPtr sens_mod, const Gauss_x &x_est_pred, const Gauss_z &z_est_pred, const Vec_z &z_meas)
[[deprecated("use const SensModT&")]] static T::Gauss_x update(std::shared_ptr<SensModT> sens_mod, const T::Gauss_x &x_est_pred, const T::Gauss_z &z_est_pred,
const T::Vec_z &z_meas)
{
return update(*sens_mod, x_est_pred, z_est_pred, z_meas);
}

[[deprecated("use const DynModT& and const SensModT&")]] static std::tuple<Gauss_x, Gauss_x, Gauss_z>
step(DynModTPtr dyn_mod, SensModTPtr sens_mod, double dt, const Gauss_x &x_est_prev, const Vec_z &z_meas, const Vec_u &u = Vec_u::Zero())
[[deprecated("use const DynModT& and const SensModT&")]] static std::tuple<typename T::Gauss_x, typename T::Gauss_x, typename T::Gauss_z>
step(std::shared_ptr<DynModT> dyn_mod, std::shared_ptr<SensModT> sens_mod, double dt, const T::Gauss_x &x_est_prev, const T::Vec_z &z_meas,
const T::Vec_u &u = T::Vec_u::Zero())
{
return step(*dyn_mod, *sens_mod, dt, x_est_prev, z_meas, u);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ template <size_t n_dim_x, size_t n_dim_z, size_t n_dim_u, size_t n_dim_v, size_t

} // namespace vortex


// Don't want you to use these macros outside of this file :)
#undef VORTEX_TYPES_1
#undef VORTEX_TYPES_2
#undef VOXTEX_TYPES_3
Expand Down

0 comments on commit d373ba5

Please sign in to comment.