Skip to content

Commit

Permalink
uses Eigen's internal multiindexing for multi indexing on vectors and…
Browse files Browse the repository at this point in the history
… matrices
  • Loading branch information
SteveBronder committed Jan 8, 2024
1 parent 706b751 commit e46a24b
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 95 deletions.
194 changes: 100 additions & 94 deletions src/stan/model/indexing/rvalue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,18 +151,20 @@ inline auto rvalue(Vec&& v, const char* name, index_uni idx) {
* @throw std::invalid_argument If the value size isn't the same as
* the indexed size.
*/
template <typename EigVec, require_eigen_vector_t<EigVec>* = nullptr>
inline auto rvalue(EigVec&& v, const char* name, const index_multi& idx) {
template <typename EigVec, typename MultiIndex, require_eigen_vector_t<EigVec>* = nullptr,
require_same_t<MultiIndex, index_multi>* = nullptr>
inline auto rvalue(EigVec&& v, const char* name, MultiIndex&& idx) {
using fwd_t = decltype(stan::math::to_ref(std::forward<EigVec>(v)));
for (auto idx_i : idx.ns_) {
math::check_range("vector[multi] indexing", name, v.size(), idx_i);
}
return stan::math::make_holder(
[name, &idx](auto& v_ref) {
return plain_type_t<EigVec>::NullaryExpr(
idx.ns_.size(), [name, &idx, &v_ref](Eigen::Index i) {
math::check_range("vector[multi] indexing", name, v_ref.size(),
idx.ns_[i]);
return v_ref.coeff(idx.ns_[i] - 1);
});
},
stan::math::to_ref(v));
[name](auto&& v_ref, auto&& idx_inner) {
Eigen::Map<Eigen::Array<int, -1, 1>> idx2(idx_inner.ns_.data(), idx_inner.ns_.size());
return std::forward<decltype(v_ref)>(v_ref)(idx2 - 1);
},
std::forward<fwd_t>(stan::math::to_ref(std::forward<EigVec>(v))),
std::forward<MultiIndex>(idx));
}

/**
Expand Down Expand Up @@ -262,21 +264,19 @@ inline auto rvalue(Mat&& x, const char* name, index_uni idx) {
* @param[in] idx A multi index for selecting a set of rows.
* @throw std::out_of_range If any of the indices are out of bounds.
*/
template <typename EigMat, require_eigen_dense_dynamic_t<EigMat>* = nullptr>
inline plain_type_t<EigMat> rvalue(EigMat&& x, const char* name,
const index_multi& idx) {
template <typename EigMat, typename MultiIndex, require_eigen_dense_dynamic_t<EigMat>* = nullptr,
require_same_t<MultiIndex, index_multi>* = nullptr>
inline auto rvalue(EigMat&& x, const char* name,
MultiIndex&& idx) {
for (int i = 0; i < idx.ns_.size(); ++i) {
math::check_range("matrix[multi] row indexing", name, x.rows(), idx.ns_[i]);
}
return stan::math::make_holder(
[&idx](auto& x_ref) {
return plain_type_t<EigMat>::NullaryExpr(
idx.ns_.size(), x_ref.cols(),
[&idx, &x_ref](Eigen::Index i, Eigen::Index j) {
return x_ref.coeff(idx.ns_[i] - 1, j);
});
[&idx](auto&& x_ref, auto&& idx_inner) {
using vec_map = Eigen::Map<Eigen::Array<int, -1, 1>>;
return x_ref((vec_map(idx_inner.ns_.data(), idx_inner.ns_.size()) - 1), Eigen::all);
},
stan::math::to_ref(x));
stan::math::to_ref(x), std::forward<MultiIndex>(idx));
}

/**
Expand Down Expand Up @@ -425,23 +425,23 @@ inline auto rvalue(Mat&& x, const char* name, index_uni row_idx,
* @param[in] col_idx multi index for selecting cols.
* @throw std::out_of_range If any of the indices are out of bounds.
*/
template <typename EigMat, require_eigen_dense_dynamic_t<EigMat>* = nullptr>
inline Eigen::Matrix<value_type_t<EigMat>, 1, Eigen::Dynamic> rvalue(
EigMat&& x, const char* name, index_uni row_idx,
const index_multi& col_idx) {
template <typename EigMat, typename MultiIndex, require_eigen_dense_dynamic_t<EigMat>* = nullptr,
require_same_t<MultiIndex, index_multi>* = nullptr>
inline auto rvalue(EigMat&& x, const char* name, index_uni row_idx,
MultiIndex&& col_idx) {
math::check_range("matrix[uni, multi] row indexing", name, x.rows(),
row_idx.n_);
for (auto idx_i : col_idx.ns_) {
math::check_range("matrix[uni, multi] column indexing", name, x.cols(),
idx_i);
}
return stan::math::make_holder(
[name, row_idx, &col_idx](auto& x_ref) {
return Eigen::Matrix<value_type_t<EigMat>, 1, Eigen::Dynamic>::
NullaryExpr(col_idx.ns_.size(), [name, row_i = row_idx.n_ - 1,
&col_idx, &x_ref](Eigen::Index i) {
math::check_range("matrix[uni, multi] column indexing", name,
x_ref.cols(), col_idx.ns_[i]);
return x_ref.coeff(row_i, col_idx.ns_[i] - 1);
});
[name, row_idx](auto&& x_ref, auto&& col_idx_inner) {
using vec_map = Eigen::Map<Eigen::Array<int, -1, 1>>;
return x_ref(row_idx.n_ - 1, (vec_map(col_idx_inner.ns_.data(), col_idx_inner.ns_.size()) - 1));
},
stan::math::to_ref(x));
stan::math::to_ref(std::forward<EigMat>(x)),
std::forward<MultiIndex>(col_idx));
}

/**
Expand All @@ -457,25 +457,22 @@ inline Eigen::Matrix<value_type_t<EigMat>, 1, Eigen::Dynamic> rvalue(
* @param[in] col_idx uni index for selecting cols.
* @throw std::out_of_range If any of the indices are out of bounds.
*/
template <typename EigMat, require_eigen_dense_dynamic_t<EigMat>* = nullptr>
inline Eigen::Matrix<value_type_t<EigMat>, Eigen::Dynamic, 1> rvalue(
EigMat&& x, const char* name, const index_multi& row_idx,
template <typename EigMat, typename MultiIndex, require_eigen_dense_dynamic_t<EigMat>* = nullptr,
require_same_t<MultiIndex, index_multi>* = nullptr>
inline auto rvalue(EigMat&& x, const char* name, MultiIndex&& row_idx,
index_uni col_idx) {
math::check_range("matrix[multi, uni] column indexing", name, x.cols(),
col_idx.n_);

for (auto idx_i : row_idx.ns_) {
math::check_range("matrix[uni, multi] row indexing", name, x.rows(),
idx_i);
}
return stan::math::make_holder(
[name, &row_idx, col_idx](auto& x_ref) {
return Eigen::Matrix<value_type_t<EigMat>, Eigen::Dynamic, 1>::
NullaryExpr(row_idx.ns_.size(),
[name, &row_idx, col_i = col_idx.n_ - 1,
&x_ref](Eigen::Index i) {
math::check_range("matrix[multi, uni] row indexing",
name, x_ref.rows(), row_idx.ns_[i]);
return x_ref.coeff(row_idx.ns_[i] - 1, col_i);
});
[name, col_idx](auto&& x_ref, auto&& row_idx_inner) {
using vec_map = Eigen::Map<Eigen::Array<int, -1, 1>>;
return x_ref((vec_map(row_idx_inner.ns_.data(), row_idx_inner.ns_.size()) - 1), col_idx.n_ - 1);
},
stan::math::to_ref(x));
stan::math::to_ref(x), std::forward<MultiIndex>(row_idx));
}

/**
Expand All @@ -491,26 +488,30 @@ inline Eigen::Matrix<value_type_t<EigMat>, Eigen::Dynamic, 1> rvalue(
* @param[in] col_idx multi index for selecting cols.
* @return Result of indexing matrix.
*/
template <typename EigMat, require_eigen_dense_dynamic_t<EigMat>* = nullptr>
inline plain_type_t<EigMat> rvalue(EigMat&& x, const char* name,
const index_multi& row_idx,
const index_multi& col_idx) {
const auto& x_ref = stan::math::to_ref(x);
template <typename EigMat, typename RowIndexMulti, typename ColIndexMulti,
require_eigen_dense_dynamic_t<EigMat>* = nullptr,
require_all_same_t<RowIndexMulti, ColIndexMulti, index_multi> * = nullptr>
inline auto rvalue(EigMat&& x, const char* name, RowIndexMulti&& row_idx,
ColIndexMulti&& col_idx) {
const Eigen::Index rows = row_idx.ns_.size();
const Eigen::Index cols = col_idx.ns_.size();
plain_type_t<EigMat> x_ret(rows, cols);
for (Eigen::Index j = 0; j < cols; ++j) {
for (Eigen::Index i = 0; i < rows; ++i) {
const Eigen::Index m = row_idx.ns_[i];
const Eigen::Index n = col_idx.ns_[j];
math::check_range("matrix[multi,multi] row indexing", name, x_ref.rows(),
m);
math::check_range("matrix[multi,multi] column indexing", name,
x_ref.cols(), n);
x_ret.coeffRef(i, j) = x_ref.coeff(m - 1, n - 1);
}
for (auto idx_i : row_idx.ns_) {
math::check_range("matrix[uni, multi] row indexing", name, x.rows(),
idx_i);
}
return x_ret;
for (auto idx_j : col_idx.ns_) {
math::check_range("matrix[uni, multi] col indexing", name, x.cols(),
idx_j);
}
return stan::math::make_holder(
[name](auto&& x_ref, auto&& row_idx_inner, auto&& col_idx_inner) {
using vec_map = Eigen::Map<Eigen::Array<int, -1, 1>>;
return x_ref((vec_map(row_idx_inner.ns_.data(), row_idx_inner.ns_.size()) - 1),
(vec_map(col_idx_inner.ns_.data(), col_idx_inner.ns_.size()) - 1));
},
stan::math::to_ref(std::forward<EigMat>(x)),
std::forward<RowIndexMulti>(row_idx),
std::forward<ColIndexMulti>(col_idx));
}

/**
Expand Down Expand Up @@ -547,22 +548,27 @@ inline auto rvalue(Mat&& x, const char* name, const Idx& row_idx,
* @param[in] col_idx multi index for selecting cols.
* @return Result of indexing matrix.
*/
template <typename EigMat, typename Idx,
template <typename EigMat, typename Idx, typename MultiIndex,
require_eigen_dense_dynamic_t<EigMat>* = nullptr,
require_not_same_t<std::decay_t<Idx>, index_uni>* = nullptr>
inline plain_type_t<EigMat> rvalue(EigMat&& x, const char* name,
const Idx& row_idx,
const index_multi& col_idx) {
const auto& x_ref = stan::math::to_ref(x);
const int rows = rvalue_index_size(row_idx, x_ref.rows());
plain_type_t<EigMat> x_ret(rows, col_idx.ns_.size());
for (int j = 0; j < col_idx.ns_.size(); ++j) {
const Eigen::Index n = col_idx.ns_[j];
math::check_range("matrix[..., multi] column indexing", name, x_ref.cols(),
n);
x_ret.col(j) = rvalue(x_ref.col(n - 1), name, row_idx);
require_not_same_t<std::decay_t<Idx>, index_uni>* = nullptr,
require_same_t<MultiIndex, index_multi>* = nullptr>
inline auto rvalue(EigMat&& x, const char* name,
Idx&& row_idx,
MultiIndex&& col_idx) {
for (auto idx_j : col_idx.ns_) {
math::check_range("matrix[..., multi] column indexing", name, x.cols(),
idx_j);
}
return x_ret;
return stan::math::make_holder(
[name](auto&& x_ref, auto&& row_idx_inner, auto&& col_idx_inner) {
using vec_map = Eigen::Map<Eigen::Array<int, -1, 1>>;
return rvalue(x_ref(Eigen::all,
(vec_map(col_idx_inner.ns_.data(), col_idx_inner.ns_.size()) - 1)),
name, std::forward<decltype(row_idx_inner)>(row_idx_inner));
},
stan::math::to_ref(std::forward<EigMat>(x)),
std::forward<Idx>(row_idx),
std::forward<MultiIndex>(col_idx));
}

/**
Expand All @@ -578,9 +584,9 @@ inline plain_type_t<EigMat> rvalue(EigMat&& x, const char* name,
* @throw std::out_of_range If any of the indices are out of bounds.
*/
template <typename Mat, typename Idx, require_dense_dynamic_t<Mat>* = nullptr>
inline auto rvalue(Mat&& x, const char* name, const Idx& row_idx,
inline auto rvalue(Mat&& x, const char* name, Idx&& row_idx,
index_omni /*col_idx*/) {
return rvalue(std::forward<Mat>(x), name, row_idx);
return rvalue(std::forward<Mat>(x), name, std::forward<Idx>(row_idx));
}

/**
Expand All @@ -599,12 +605,12 @@ inline auto rvalue(Mat&& x, const char* name, const Idx& row_idx,
* @throw std::out_of_range If any of the indices are out of bounds.
*/
template <typename Mat, typename Idx, require_dense_dynamic_t<Mat>* = nullptr>
inline auto rvalue(Mat&& x, const char* name, const Idx& row_idx,
inline auto rvalue(Mat&& x, const char* name, Idx&& row_idx,
index_min col_idx) {
const Eigen::Index col_size = x.cols() - (col_idx.min_ - 1);
math::check_range("matrix[..., min] column indexing", name, x.cols(),
col_idx.min_);
return rvalue(x.rightCols(col_size), name, row_idx);
return rvalue(x.rightCols(col_size), name, std::forward<Idx>(row_idx));
}

/**
Expand All @@ -623,14 +629,14 @@ inline auto rvalue(Mat&& x, const char* name, const Idx& row_idx,
* @throw std::out_of_range If any of the indices are out of bounds.
*/
template <typename Mat, typename Idx, require_dense_dynamic_t<Mat>* = nullptr>
inline auto rvalue(Mat&& x, const char* name, const Idx& row_idx,
inline auto rvalue(Mat&& x, const char* name, Idx&& row_idx,
index_max col_idx) {
if (col_idx.max_ > 0) {
math::check_range("matrix[..., max] column indexing", name, x.cols(),
col_idx.max_);
return rvalue(x.leftCols(col_idx.max_), name, row_idx);
return rvalue(x.leftCols(col_idx.max_), name, std::forward<Idx>(row_idx));
} else {
return rvalue(x.leftCols(0), name, row_idx);
return rvalue(x.leftCols(0), name, std::forward<Idx>(row_idx));
}
}

Expand All @@ -650,7 +656,7 @@ inline auto rvalue(Mat&& x, const char* name, const Idx& row_idx,
* @return Result of indexing matrix.
*/
template <typename Mat, typename Idx, require_dense_dynamic_t<Mat>* = nullptr>
inline auto rvalue(Mat&& x, const char* name, const Idx& row_idx,
inline auto rvalue(Mat&& x, const char* name, Idx&& row_idx,
index_min_max col_idx) {
math::check_range("matrix[..., min_max] min column indexing", name, x.cols(),
col_idx.min_);
Expand All @@ -659,9 +665,9 @@ inline auto rvalue(Mat&& x, const char* name, const Idx& row_idx,
math::check_range("matrix[..., min_max] max column indexing", name,
x.cols(), col_idx.max_);
return rvalue(x.middleCols(col_start, col_idx.max_ - col_start), name,
row_idx);
std::forward<Idx>(row_idx));
} else {
return rvalue(x.middleCols(col_start, 0), name, row_idx);
return rvalue(x.middleCols(col_start, 0), name, std::forward<Idx>(row_idx));
}
}

Expand All @@ -685,16 +691,16 @@ template <typename StdVec, typename... Idxs,
require_std_vector_t<StdVec>* = nullptr,
require_not_t<std::is_lvalue_reference<StdVec&&>>* = nullptr>
inline auto rvalue(StdVec&& v, const char* name, index_uni idx1,
const Idxs&... idxs) {
Idxs&&... idxs) {
math::check_range("array[uni, ...] index", name, v.size(), idx1.n_);
return rvalue(std::move(v[idx1.n_ - 1]), name, idxs...);
return rvalue(std::move(v[idx1.n_ - 1]), name, std::forward<Idxs>(idxs)...);
}
template <typename StdVec, typename... Idxs,
require_std_vector_t<StdVec>* = nullptr>
inline auto rvalue(StdVec& v, const char* name, index_uni idx1,
const Idxs&... idxs) {
Idxs&&... idxs) {
math::check_range("array[uni, ...] index", name, v.size(), idx1.n_);
return rvalue(v[idx1.n_ - 1], name, idxs...);
return rvalue(v[idx1.n_ - 1], name, std::forward<Idxs>(idxs)...);
}

/**
Expand Down Expand Up @@ -749,7 +755,7 @@ template <typename StdVec, typename Idx1, typename... Idxs,
require_std_vector_t<StdVec>* = nullptr,
require_not_same_t<Idx1, index_uni>* = nullptr>
inline auto rvalue(StdVec&& v, const char* name, const Idx1& idx1,
const Idxs&... idxs) {
Idxs&&... idxs) {
using inner_type = plain_type_t<decltype(
rvalue(v[rvalue_at(0, idx1) - 1], name, idxs...))>;
const auto index_size = rvalue_index_size(idx1, v.size());
Expand Down
2 changes: 1 addition & 1 deletion src/test/unit/model/indexing/rvalue_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -719,7 +719,7 @@ void vector_multi_test() {
test_out_of_range(v, index_multi(ns));
}

TEST(ModelIndexing, rvalueVectorMulti) { vector_multi_test<Eigen::VectorXd>(); }
TEST(ModelIndexing, x) { vector_multi_test<Eigen::VectorXd>(); }

TEST(ModelIndexing, rvalueRowVectorMulti) {
vector_multi_test<Eigen::RowVectorXd>();
Expand Down

0 comments on commit e46a24b

Please sign in to comment.