Skip to content

Commit

Permalink
clean up nested types. Change member algorithm naming conventions.
Browse files Browse the repository at this point in the history
  • Loading branch information
rileyjmurray committed Feb 1, 2025
1 parent e4098de commit bf3fa12
Show file tree
Hide file tree
Showing 7 changed files with 46 additions and 40 deletions.
11 changes: 7 additions & 4 deletions RandLAPACK/comps/rl_preconditioners.hh
Original file line number Diff line number Diff line change
Expand Up @@ -288,20 +288,23 @@ RandBLAS::RNGState<RNG> nystrom_pc_data(
int64_t num_syps_passes = 3,
int64_t num_steps_power_iter_error_est = 10
) {
RandLAPACK::SYPS<T, RNG> SYPS(num_syps_passes, 1, false, false);
using SYPS_t = RandLAPACK::SYPS<T, RNG>;
using Orth_t = RandLAPACK::HQRQ<T>;
using SYRF_t = RandLAPACK::SYRF<SYPS_t, Orth_t>;
SYPS_t SYPS(num_syps_passes, 1, false, false);
// ^ Define a symmetric power sketch algorithm.
// (*) Stabilize power iteration with pivoted-LU after every
// mulitplication with A.
// (*) Do not check condition numbers or log to std::out.
RandLAPACK::HQRQ<T> Orth(false, false);
Orth_t Orth(false, false);
// ^ Define an orthogonalizer for a symmetric rangefinder.
// (*) Get a dense representation of Q from Householder QR.
// (*) Do not check condition numbers or log to std::out.
RandLAPACK::SYRF<RandLAPACK::SYPS<T, RNG>> SYRF(SYPS, Orth, false, false);
SYRF_t SYRF(SYPS, Orth, false, false);
// ^ Define the symmetric rangefinder algorithm.
// (*) Use power sketching followed by Householder orthogonalization.
// (*) Do not check condition numbers or log to std::out.
RandLAPACK::REVD2<RandLAPACK::SYRF<RandLAPACK::SYPS<T, RNG>>> NystromAlg(SYRF, num_steps_power_iter_error_est, false);
RandLAPACK::REVD2<SYRF_t> NystromAlg(SYRF, num_steps_power_iter_error_est, false);
// ^ Define the algorithm for low-rank approximation via Nystrom.
// (*) Handle accuracy requests by estimating ||A - V diag(eigvals) V'||
// with "num_steps_power_iter_error_est" steps of power iteration.
Expand Down
6 changes: 3 additions & 3 deletions RandLAPACK/comps/rl_qb.hh
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class QB : public QBalg<T, RNG> {
RandLAPACK::Stabilization<T> &orth_obj,
bool verb,
bool orth
) : RF_Obj(rf_obj), Orth_Obj(orth_obj) {
) : RF_Obj(rf_obj), orth(orth_obj) {
verbose = verb;
orth_check = orth;
}
Expand Down Expand Up @@ -124,7 +124,7 @@ class QB : public QBalg<T, RNG> {

public:
RandLAPACK::RangeFinder<T, RNG> &RF_Obj;
RandLAPACK::Stabilization<T> &Orth_Obj;
RandLAPACK::Stabilization<T> &orth;
bool verbose;
bool orth_check;
};
Expand Down Expand Up @@ -211,7 +211,7 @@ int QB<T, RNG>::call(
// Q_i = orth(Q_i - Q(Q'Q_i))
blas::gemm(Layout::ColMajor, Op::Trans, Op::NoTrans, curr_sz, b_sz, m, 1.0, Q, m, Q_i, m, 0.0, QtQi, next_sz);
blas::gemm(Layout::ColMajor, Op::NoTrans, Op::NoTrans, m, b_sz, curr_sz, -1.0, Q, m, QtQi, next_sz, 1.0, Q_i, m);
this->Orth_Obj.call(m, b_sz, Q_i);
this->orth.call(m, b_sz, Q_i);
}

//B_i' = A' * Q_i'
Expand Down
6 changes: 3 additions & 3 deletions RandLAPACK/comps/rl_rf.hh
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class RF : public RangeFinder<T, RNG> {
RandLAPACK::Stabilization<T> &orth_obj,
bool verb,
bool cond
) : RS_Obj(rs_obj), Orth_Obj(orth_obj) {
) : RS_Obj(rs_obj), orth(orth_obj) {
verbose = verb;
cond_check = cond;
}
Expand Down Expand Up @@ -94,7 +94,7 @@ class RF : public RangeFinder<T, RNG> {
public:
// Instantiated in the constructor
RandLAPACK::RowSketcher<T, RNG> &RS_Obj;
RandLAPACK::Stabilization<T> &Orth_Obj;
RandLAPACK::Stabilization<T> &orth;
bool verbose;
bool cond_check;

Expand Down Expand Up @@ -127,7 +127,7 @@ int RF<T, RNG>::call(
// Writes into this->cond_nums
this->cond_nums.push_back(util::cond_num_check(m, k, Q, this->verbose));

if(this->Orth_Obj.call(m, k, Q))
if(this->orth.call(m, k, Q))
return 2; // Orthogonalization failed

// Normal termination
Expand Down
16 changes: 8 additions & 8 deletions RandLAPACK/comps/rl_syrf.hh
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@ concept SymmetricRangeFinderConcept =
};


template <typename SYPS_t>
template <typename SYPS_t, typename Orth_t>
class SYRF {
public:
using T = typename SYPS_t::scalar_t;
using T = typename SYPS_t::scalar_t;
using RNG = typename SYPS_t::RNG_t;
SYPS_t &SYPS_Obj;
RandLAPACK::Stabilization<T> &Orth_Obj;
SYPS_t &syps;
Orth_t &orth;
bool verbose;
bool cond_check;
std::vector<T> cond_work_mat;
Expand All @@ -44,10 +44,10 @@ class SYRF {

SYRF(
SYPS_t &syps_obj,
RandLAPACK::Stabilization<T> &orth_obj,
Orth_t &orth_obj,
bool verb = false,
bool cond = false
) : SYPS_Obj(syps_obj), Orth_Obj(orth_obj) {
) : syps(syps_obj), orth(orth_obj) {
verbose = verb;
cond_check = cond;
}
Expand Down Expand Up @@ -110,7 +110,7 @@ class SYRF {
RandBLAS::util::safe_scal(m * k, (T) 0.0, work_buff, 1);

T* Q_dat = util::upsize(m * k, Q);
SYPS_Obj.call(A, k, state, work_buff, Q_dat);
syps.call(A, k, state, work_buff, Q_dat);

// Q = orth(A * Omega)
A(Layout::ColMajor, k, (T) 1.0, work_buff, m, (T) 0.0, Q_dat, m);
Expand All @@ -121,7 +121,7 @@ class SYRF {
util::cond_num_check(m, k, Q.data(), this->verbose)
);
}
if(this->Orth_Obj.call(m, k, Q.data()))
if(this->orth.call(m, k, Q.data()))
throw std::runtime_error("Orthogonalization failed.");

if (!callers_work_buff)
Expand Down
6 changes: 3 additions & 3 deletions RandLAPACK/drivers/rl_revd2.hh
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class REVD2 {
public:
using T = typename SYRF_t::T;
using RNG = typename SYRF_t::RNG;
SYRF_t &SYRF_Obj;
SYRF_t &syrf;
int error_est_p;
bool verbose;

Expand All @@ -90,7 +90,7 @@ class REVD2 {
SYRF_t &syrf_obj,
int error_est_power_iters,
bool verb = false
) : SYRF_Obj(syrf_obj) {
) : syrf(syrf_obj) {
error_est_p = error_est_power_iters;
verbose = verb;
}
Expand Down Expand Up @@ -165,7 +165,7 @@ class REVD2 {

// Construnct a sketching operator
// If CholeskyQR is used for stab/orth here, RF can fail
this->SYRF_Obj.call(A, k, this->Omega, state, symrf_work_dat);
this->syrf.call(A, k, this->Omega, state, symrf_work_dat);

// Y = A * Omega
A(Layout::ColMajor, k, 1.0, Omega_dat, m, 0.0, Y_dat, m);
Expand Down
17 changes: 10 additions & 7 deletions test/comps/test_syrf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,19 +46,22 @@ class TestSYRF : public ::testing::Test

template <typename T, typename RNG>
struct algorithm_objects {
RandLAPACK::SYPS<T, RNG> SYPS;
RandLAPACK::HQRQ<T> Orth_RF;
RandLAPACK::SYRF<RandLAPACK::SYPS<T, RNG>> SYRF;
using SYPS_t = RandLAPACK::SYPS<T, RNG>;
using Orth_t = RandLAPACK::HQRQ<T>;
using SYRF_t = RandLAPACK::SYRF<SYPS_t, Orth_t>;
SYPS_t syps;
Orth_t orth;
SYRF_t syrf;

algorithm_objects(
bool verbose,
bool cond_check,
int64_t p,
int64_t passes_per_iteration
) :
SYPS(p, passes_per_iteration, verbose, cond_check),
Orth_RF(cond_check, verbose),
SYRF(SYPS, Orth_RF, verbose, cond_check)
syps(p, passes_per_iteration, verbose, cond_check),
orth(cond_check, verbose),
syrf(syps, orth, verbose, cond_check)
{}
};

Expand Down Expand Up @@ -96,7 +99,7 @@ class TestSYRF : public ::testing::Test
auto m = all_data.row;
auto k = all_data.rank;

all_algs.SYRF.call(Uplo::Upper, m, all_data.A.data(), k, all_data.Q, state, NULL);
all_algs.syrf.call(Uplo::Upper, m, all_data.A.data(), k, all_data.Q, state, NULL);

// Reassing pointers because Q, B have been resized
T* Q_dat = all_data.Q.data();
Expand Down
24 changes: 12 additions & 12 deletions test/drivers/test_revd2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,12 @@ class TestREVD2 : public ::testing::Test
template <typename T, typename RNG>
struct algorithm_objects {
using SYPS_t = RandLAPACK::SYPS<T, RNG>;
using SYRF_t = RandLAPACK::SYRF<SYPS_t>;
using Orth_t = RandLAPACK::HQRQ<T>;
SYPS_t SYPS;
Orth_t Orth;
SYRF_t SYRF;
RandLAPACK::REVD2<SYRF_t> REVD2;
using SYRF_t = RandLAPACK::SYRF<SYPS_t, Orth_t>;
SYPS_t syps;
Orth_t orth;
SYRF_t syrf;
RandLAPACK::REVD2<SYRF_t> revd2;


algorithm_objects(
Expand All @@ -92,10 +92,10 @@ class TestREVD2 : public ::testing::Test
int64_t passes_per_syps_stabilization,
int64_t num_steps_power_iter_error_est
) :
SYPS(num_syps_passes, passes_per_syps_stabilization, verbose, cond_check),
Orth(cond_check, verbose),
SYRF(SYPS, Orth, verbose, cond_check),
REVD2(SYRF, num_steps_power_iter_error_est, verbose)
syps(num_syps_passes, passes_per_syps_stabilization, verbose, cond_check),
orth(cond_check, verbose),
syrf(syps, orth, verbose, cond_check),
revd2(syrf, num_steps_power_iter_error_est, verbose)
{}
};

Expand Down Expand Up @@ -152,7 +152,7 @@ class TestREVD2 : public ::testing::Test
auto m = all_data.dim;

int64_t k = k_start;
all_algs.REVD2.call(blas::Uplo::Upper, m, all_data.A.data(), k, tol, all_data.V, all_data.eigvals, state);
all_algs.revd2.call(blas::Uplo::Upper, m, all_data.A.data(), k, tol, all_data.V, all_data.eigvals, state);

T* E_dat = RandLAPACK::util::upsize(k * k, all_data.E);
T* Buf_dat = RandLAPACK::util::upsize(m * k, all_data.Buf);
Expand Down Expand Up @@ -190,8 +190,8 @@ class TestREVD2 : public ::testing::Test
auto m = all_data.dim;

int64_t k = k_start;
all_algs.REVD2.call(blas::Uplo::Upper, m, all_data.A_u.data(), k, tol, all_data.V_u, all_data.eigvals_u, state);
all_algs.REVD2.call(blas::Uplo::Lower, m, all_data.A_l.data(), k, tol, all_data.V_l, all_data.eigvals_l, state);
all_algs.revd2.call(blas::Uplo::Upper, m, all_data.A_u.data(), k, tol, all_data.V_u, all_data.eigvals_u, state);
all_algs.revd2.call(blas::Uplo::Lower, m, all_data.A_l.data(), k, tol, all_data.V_l, all_data.eigvals_l, state);

T* E_u_dat = RandLAPACK::util::upsize(k * k, all_data.E_u);
T* E_l_dat = RandLAPACK::util::upsize(k * k, all_data.E_l);
Expand Down

0 comments on commit bf3fa12

Please sign in to comment.