Skip to content

Commit

Permalink
further reduce dependence on std::vector
Browse files Browse the repository at this point in the history
  • Loading branch information
rileyjmurray committed Jan 20, 2025
1 parent e00afba commit b48dbc2
Showing 1 changed file with 17 additions and 8 deletions.
25 changes: 17 additions & 8 deletions RandLAPACK/misc/rl_linops.hh
Original file line number Diff line number Diff line change
Expand Up @@ -115,16 +115,20 @@ struct RegExplicitSymLinOp : public SymmetricLinearOperator<T> {
using scalar_t = T;

RegExplicitSymLinOp(
int64_t m, const T* A_buff, int64_t lda, vector<T> &arg_regs
int64_t m, const T* A_buff, int64_t lda, T* arg_regs, int64_t arg_num_regs
) : SymmetricLinearOperator<T>(m), A_buff(A_buff), lda(lda) {
randblas_require(lda >= m);
_eval_includes_reg = false;
num_regs = arg_regs.size();
num_regs = arg_num_regs;
num_regs = std::max(num_regs, (int64_t) 1);
regs = new T[num_regs]{};
std::copy(arg_regs.begin(), arg_regs.end(), regs);
std::copy(arg_regs, arg_regs, regs);
}

RegExplicitSymLinOp(
int64_t m, const T* A_buff, int64_t lda, vector<T> &arg_regs
) : RegExplicitSymLinOp<T>(m, A_buff, lda, arg_regs.data(), static_cast<int64_t>(arg_regs.size())) {}

~RegExplicitSymLinOp() {
if (regs != nullptr) delete [] regs;
}
Expand Down Expand Up @@ -244,11 +248,7 @@ struct SpectralPrecond {
num_regs = arg_num_regs;
}

void prep(vector<T> &eigvecs, vector<T> &eigvals, vector<T> &mus, int64_t arg_num_rhs) {
// assume eigvals are positive numbers sorted in decreasing order.
int64_t arg_num_regs = mus.size();
int64_t arg_dim_pre = eigvals.size();
reset_owned_buffers(arg_dim_pre, arg_num_rhs, arg_num_regs);
void set_D_from_eigs_and_regs(T* eigvals, T* mus) {
for (int64_t r = 0; r < num_regs; ++r) {
T mu_r = mus[r];
T* D_r = D + r*dim_pre;
Expand All @@ -257,6 +257,15 @@ struct SpectralPrecond {
D_r[i] = (numerator / (eigvals[i] + mu_r)) - 1.0;
}
}
return;
}

void prep(vector<T> &eigvecs, vector<T> &eigvals, vector<T> &mus, int64_t arg_num_rhs) {
// assume eigvals are positive numbers sorted in decreasing order.
int64_t arg_num_regs = mus.size();
int64_t arg_dim_pre = eigvals.size();
reset_owned_buffers(arg_dim_pre, arg_num_rhs, arg_num_regs);
set_D_from_eigs_and_regs(eigvals.data(), mus.data());
std::copy(eigvecs.begin(), eigvecs.end(), V);
return;
}
Expand Down

0 comments on commit b48dbc2

Please sign in to comment.