diff --git a/RandLAPACK/misc/rl_linops.hh b/RandLAPACK/misc/rl_linops.hh index 43f21fe3..db22cd6e 100644 --- a/RandLAPACK/misc/rl_linops.hh +++ b/RandLAPACK/misc/rl_linops.hh @@ -115,16 +115,20 @@ struct RegExplicitSymLinOp : public SymmetricLinearOperator { using scalar_t = T; RegExplicitSymLinOp( - int64_t m, const T* A_buff, int64_t lda, vector &arg_regs + int64_t m, const T* A_buff, int64_t lda, T* arg_regs, int64_t arg_num_regs ) : SymmetricLinearOperator(m), A_buff(A_buff), lda(lda) { randblas_require(lda >= m); _eval_includes_reg = false; - num_regs = arg_regs.size(); + num_regs = arg_num_regs; num_regs = std::max(num_regs, (int64_t) 1); regs = new T[num_regs]{}; - std::copy(arg_regs.begin(), arg_regs.end(), regs); + std::copy(arg_regs, arg_regs, regs); } + RegExplicitSymLinOp( + int64_t m, const T* A_buff, int64_t lda, vector &arg_regs + ) : RegExplicitSymLinOp(m, A_buff, lda, arg_regs.data(), static_cast(arg_regs.size())) {} + ~RegExplicitSymLinOp() { if (regs != nullptr) delete [] regs; } @@ -244,11 +248,7 @@ struct SpectralPrecond { num_regs = arg_num_regs; } - void prep(vector &eigvecs, vector &eigvals, vector &mus, int64_t arg_num_rhs) { - // assume eigvals are positive numbers sorted in decreasing order. - int64_t arg_num_regs = mus.size(); - int64_t arg_dim_pre = eigvals.size(); - reset_owned_buffers(arg_dim_pre, arg_num_rhs, arg_num_regs); + void set_D_from_eigs_and_regs(T* eigvals, T* mus) { for (int64_t r = 0; r < num_regs; ++r) { T mu_r = mus[r]; T* D_r = D + r*dim_pre; @@ -257,6 +257,15 @@ struct SpectralPrecond { D_r[i] = (numerator / (eigvals[i] + mu_r)) - 1.0; } } + return; + } + + void prep(vector &eigvecs, vector &eigvals, vector &mus, int64_t arg_num_rhs) { + // assume eigvals are positive numbers sorted in decreasing order. + int64_t arg_num_regs = mus.size(); + int64_t arg_dim_pre = eigvals.size(); + reset_owned_buffers(arg_dim_pre, arg_num_rhs, arg_num_regs); + set_D_from_eigs_and_regs(eigvals.data(), mus.data()); std::copy(eigvecs.begin(), eigvecs.end(), V); return; }