Skip to content

Commit

Permalink
tested KRILL implementation for full kernel ridge regression problems
Browse files Browse the repository at this point in the history
  • Loading branch information
rileyjmurray committed Feb 10, 2025
1 parent 02acc22 commit 4f93c19
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 71 deletions.
53 changes: 12 additions & 41 deletions RandLAPACK/drivers/rl_krill.hh
Original file line number Diff line number Diff line change
Expand Up @@ -12,57 +12,44 @@
#include <limits>
#include <vector>

/**
*
* TODO:
* (1) finish and test krill_restricted_rpchol
* (2) write and test a krill_restricted function that accepts the centers as inputs
* in advance.
* (3) See also, rl_preconditioners.hh
*
*/

namespace RandLAPACK {

/**
* Fun thing about the name KRILLx:
*
* we can do KRILLrs for KRILL with lockstep PCG for regularization sweep.
*
* we can do KRILLb (?) for "random lifting + block" version.
*/


template <typename T, typename FUNC, typename SEMINORM, typename STATE>
STATE krill_full_rpchol(
int64_t n, FUNC &G, std::vector<T> &H, std::vector<T> &X, T tol,
int64_t n, FUNC &G, int64_t ell, const T* H, T* X, T tol,
STATE state, SEMINORM &seminorm, int64_t rpchol_block_size = -1, int64_t max_iters = 20, int64_t k = -1
) {
using std::vector;
int64_t mu_size = G.num_ops;
vector<T> mus(mu_size);
std::vector<T> mus(mu_size);
std::copy(G.regs, G.regs + mu_size, mus.data());
int64_t ell = ((int64_t) H.size()) / n;
randblas_require(ell * n == (int64_t) H.size());
randblas_require(mu_size == 1 || mu_size == ell);

if (rpchol_block_size < 0)
rpchol_block_size = std::min((int64_t) 64, n/4);
if (k < 0)
k = (int64_t) std::sqrt(n);

vector<T> V(n*k, 0.0);
vector<T> eigvals(k, 0.0);
std::vector<T> V(n*k, 0.0);
std::vector<T> eigvals(k, 0.0);
G.set_eval_includes_reg(false);
state = rpchol_pc_data(n, G, k, rpchol_block_size, V.data(), eigvals.data(), state);
linops::SpectralPrecond<T> invP(n);
invP.prep(V, eigvals, mus, ell);
G.set_eval_includes_reg(true);
pcg(G, H.data(), ell, seminorm, tol, max_iters, invP, X.data(), true);
pcg(G, H, ell, seminorm, tol, max_iters, invP, X, true);

return state;
}

/**
* TODO:
* (1) write and test krill_restricted_rpchol, documented below.
* (2) write and test a krill_restricted function that accepts the centers as inputs in advance.
*
*/

/**
* We start with a regularized kernel linear operator G and target data H.
* We use "K" to denote the unregularized version of G, which can be accessed
Expand Down Expand Up @@ -139,25 +126,9 @@ STATE krill_full_rpchol(
// //
// // That second identity can be written as MM' = G(inds, inds) for M = V(inds, :).
// //


// vector<T> M(k * k);
// _rpchol_impl::pack_selected_rows(blas::Layout::ColMajor, n, k, V.data(), inds, M.data());
// //
// //
// //

// linops::SpectralPrecond<T> invP(n);
// // invP.prep(V, eigvals, mus, ell);
// return state;
// }

// template <typename T, typename FUNC, typename STATE>
// STATE krill_block(
//
// ) {
//
// }


} // end namespace RandLAPACK
65 changes: 38 additions & 27 deletions RandLAPACK/misc/rl_pdkernels.hh
Original file line number Diff line number Diff line change
Expand Up @@ -185,50 +185,63 @@ void block_arrowhead_multiply(int64_t k, int64_t ell, int64_t n, const T* A, con

namespace linops {

using std::vector;

/***
* It might be practical to have one class that handles several different kinds of kernels.
* A representation of num_ops >= 1 regularized Radial Basis Function (RBF)
* kernel matrices, differing from one another only on their diagonals.
*
* Every constituent matrix represented by this object is dim-by-dim.
*
* For i != j, the (i,j)-th matrix entry is equal to
* exp(-||X(:,i) - X(:,j)||_2^2 / bandwidth),
* where X is a matrix with dim columns and bandwidth is some
* positive number.
*
* The diagonals of this object's constituent matrices are all
* proportional to the identity. For matrix k (0 <= k < num_ops),
* the constant of proportionality is (1+regs[k]).
*/
template <typename T>
struct RBFKernelMatrix {
// squared exp kernel linear operator
const int64_t dim;
const int64_t m;
/***
* X is a rows_x-by-dim matrix stored in column major format with
* leading dimension equal to rows_x. Each column of X is interpreted
* as a datapoint in "rows_x" dimensional space.
*/
const T* X;
const int64_t rows_x;
T bandwidth;
int64_t num_ops;
vector<T> regs;
T* regs;

vector<T> _sq_colnorms_x;
vector<T> _eval_work1;
vector<T> _eval_work2;
std::vector<T> _sq_colnorms_x;
std::vector<T> _eval_work1;
std::vector<T> _eval_work2;
bool _eval_includes_reg;
int64_t _eval_block_size;

using scalar_t = T;

RBFKernelMatrix(
int64_t dim, const T* X, int64_t rows_x, T bandwidth, vector<T> &regs
) : dim(dim), m(dim), X(X), rows_x(rows_x), bandwidth(bandwidth), regs(regs), _sq_colnorms_x(m), _eval_work1{}, _eval_work2{} {
num_ops = regs.size();
for (int64_t i = 0; i < m; ++i) {
int64_t dim, const T* X, int64_t rows_x, T bandwidth, std::vector<T> &argregs
) : dim(dim), X(X), rows_x(rows_x), bandwidth(bandwidth), regs(argregs.data()), _sq_colnorms_x(dim), _eval_work1{}, _eval_work2{} {
num_ops = argregs.size();
for (int64_t i = 0; i < dim; ++i) {
_sq_colnorms_x[i] = std::pow(blas::nrm2(rows_x, X + i*rows_x, 1), 2);
}
_eval_block_size = std::min(m / ((int64_t) 4), (int64_t) 512);
_eval_work1.resize(_eval_block_size * m);
_eval_block_size = std::min(dim / ((int64_t) 4), (int64_t) 512);
_eval_work1.resize(_eval_block_size * dim);
_eval_includes_reg = false;
return;
}

void _prep_eval_work1(int64_t rows_ksub, int64_t cols_ksub, int64_t ro_ksub, int64_t co_ksub) {
randblas_require(rows_ksub * cols_ksub <= (int64_t) _eval_work1.size());
squared_exp_kernel_submatrix(
rows_x, this->m, X, _sq_colnorms_x.data(),
rows_x, dim, X, _sq_colnorms_x.data(),
rows_ksub, cols_ksub, _eval_work1.data(), ro_ksub, co_ksub, bandwidth
);
num_ops = regs.size();
}

void set_eval_includes_reg(bool eir) {
Expand All @@ -237,15 +250,15 @@ struct RBFKernelMatrix {

void operator()(blas::Layout layout, int64_t n, T alpha, T* const B, int64_t ldb, T beta, T* C, int64_t ldc) {
randblas_require(layout == blas::Layout::ColMajor);
randblas_require(ldb >= this->m);
randblas_require(ldc >= this->m);
randblas_require(ldb >= dim);
randblas_require(ldc >= dim);

_eval_work2.resize(this->m * n);
_eval_work2.resize(dim * n);
for (int64_t i = 0; i < n; ++i) {
blas::scal(this->m, beta, C + i*ldc, 1);
blas::scal(dim, beta, C + i*ldc, 1);
}
int64_t done = 0;
int64_t todo = this->m;
int64_t todo = dim;
while (todo > 0) {
int64_t k = std::min(_eval_block_size, todo);
_prep_eval_work1(k, todo, done, done);
Expand All @@ -262,12 +275,10 @@ struct RBFKernelMatrix {
todo -= k;
}
if (_eval_includes_reg) {
int64_t num_regs = this->regs.size();
randblas_require(num_regs == 1 || n == num_regs);
T* regsp = regs.data();
randblas_require(num_ops == 1 || n == num_ops);
for (int64_t i = 0; i < n; ++i) {
T coeff = alpha * regsp[std::min(i, num_regs - 1)];
blas::axpy(this->m, coeff, B + i*ldb, 1, C + i*ldc, 1);
T coeff = alpha * regs[std::min(i, num_ops - 1)];
blas::axpy(dim, coeff, B + i*ldb, 1, C + i*ldc, 1);
}
}
return;
Expand All @@ -276,7 +287,7 @@ struct RBFKernelMatrix {
inline T operator()(int64_t i, int64_t j) {
T val = squared_exp_kernel(rows_x, X + i*rows_x, X + j*rows_x, bandwidth);
if (_eval_includes_reg && i == j) {
randblas_require(regs.size() == 1);
randblas_require(num_ops == 1);
val += regs[0];
}
return val;
Expand Down
1 change: 1 addition & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ if (GTest_FOUND)
comps/test_rf.cc
comps/test_syrf.cc
comps/test_rpchol.cc
drivers/test_krill.cc
drivers/test_rsvd.cc
drivers/test_cqrrpt.cc
drivers/test_bqrrp.cc
Expand Down
8 changes: 5 additions & 3 deletions test/drivers/test_krill.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ class TestKrillIsh: public ::testing::Test {
void run_common(T mu_min, vector<T> &V, vector<T> &lambda, RegExplicitSymLinOp<T> &G_linop) {
RandLAPACK::linops::SpectralPrecond<T> invP(m);
vector<T> mus {mu_min, mu_min/10, mu_min/100};
G_linop.regs = mus;
G_linop.regs = new T[3];
G_linop.num_ops = 3;
std::copy(mus.begin(), mus.end(), G_linop.regs);
G_linop.set_eval_includes_reg(true);
invP.prep(V, lambda, mus, mus.size());
int64_t s = mus.size();
Expand Down Expand Up @@ -127,7 +129,7 @@ class TestKrillx: public ::testing::Test {
template <typename RELO>
void run_krill_separable(int key_index, RELO &G_linop, int64_t k) {
using T = typename RELO::scalar_t;
int64_t s = G_linop.num_regs;
int64_t s = G_linop.num_ops;

vector<T> X_star(m*s, 0.0);
vector<T> X_init(m*s, 0.0);
Expand All @@ -147,7 +149,7 @@ class TestKrillx: public ::testing::Test {
int64_t rpc_blocksize = 16;
RNGState state2(keys[key_index]);
RandLAPACK::krill_full_rpchol(
m, G_linop, H, X_init, tol, state2, seminorm, rpc_blocksize, max_iters, k
m, G_linop, s, H.data(), X_init.data(), tol, state2, seminorm, rpc_blocksize, max_iters, k
);
T tol_scale = std::sqrt((T)m);
T atol = tol_scale * std::pow(std::numeric_limits<T>::epsilon(), 0.5);
Expand Down

0 comments on commit 4f93c19

Please sign in to comment.