Commit

first version; adjoint still completely untested

mreineck committed Feb 18, 2025
1 parent 04fbd67 commit 1180e27
Showing 6 changed files with 205 additions and 159 deletions.
13 changes: 11 additions & 2 deletions include/finufft/fft.h
@@ -88,6 +88,10 @@ template<> struct Finufft_FFT_plan<float> {
unlock();
}
void execute [[maybe_unused]] () { fftwf_execute(plan_); }
void execute [[maybe_unused]] (std::complex<float> *data) {
fftwf_execute_dft(plan_, reinterpret_cast<fftwf_complex *>(data),
reinterpret_cast<fftwf_complex *>(data));
}

static void forget_wisdom [[maybe_unused]] () { fftwf_forget_wisdom(); }
static void cleanup [[maybe_unused]] () { fftwf_cleanup(); }
@@ -152,6 +156,10 @@ template<> struct Finufft_FFT_plan<double> {
unlock();
}
void execute [[maybe_unused]] () { fftw_execute(plan_); }
void execute [[maybe_unused]] (std::complex<double> *data) {
fftw_execute_dft(plan_, reinterpret_cast<fftw_complex *>(data),
reinterpret_cast<fftw_complex *>(data));
}

static void forget_wisdom [[maybe_unused]] () { fftw_forget_wisdom(); }
static void cleanup [[maybe_unused]] () { fftw_cleanup(); }
@@ -179,7 +187,8 @@ static inline void finufft_fft_cleanup_threads [[maybe_unused]] () {
Finufft_FFT_plan<double>::cleanup_threads();
}
template<typename TF> struct FINUFFT_PLAN_T;
template<typename TF> std::vector<int> gridsize_for_fft(FINUFFT_PLAN_T<TF> *p);
template<typename TF> void do_fft(FINUFFT_PLAN_T<TF> *p);
template<typename TF> std::vector<int> gridsize_for_fft(const FINUFFT_PLAN_T<TF> &p);
template<typename TF>
void do_fft(const FINUFFT_PLAN_T<TF> &p, std::complex<TF> *fwBatch, bool adjoint);

#endif // FINUFFT_INCLUDE_FINUFFT_FFT_H
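
The new execute(std::complex<T>*) overloads wrap FFTW's "new-array execute" interface: a plan built once can be re-run on a different buffer of the same size, layout, and alignment, which supports moving the fine-grid scratch array out of the plan (see the finufft_core.h changes below). A minimal standalone sketch of that FFTW feature follows; the buffer names and the 1D size N are illustrative only, not taken from this diff.

#include <complex>
#include <fftw3.h>

void new_array_execute_demo(int N) {
  // Build the plan on one FFTW-aligned buffer...
  fftwf_complex *a = fftwf_alloc_complex(N);
  fftwf_plan pl    = fftwf_plan_dft_1d(N, a, a, FFTW_FORWARD, FFTW_ESTIMATE);
  // ...then apply the same plan, in place, to a second buffer with matching
  // size and alignment; this is the call the new overload forwards to.
  fftwf_complex *b = fftwf_alloc_complex(N);
  for (int i = 0; i < N; ++i) { b[i][0] = 1.0f; b[i][1] = 0.0f; }
  fftwf_execute_dft(pl, b, b);
  fftwf_destroy_plan(pl);
  fftwf_free(a);
  fftwf_free(b);
}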
7 changes: 1 addition & 6 deletions include/finufft/finufft_core.h
@@ -171,10 +171,6 @@ template<typename TF> struct FINUFFT_PLAN_T { // the main plan class, fully C++

std::array<std::vector<TF>, 3> phiHat; // FT of kernel in t1,2, on x,y,z-axis mode grid

// fwBatch: (batches of) fine working grid(s) for the FFT to plan & act on.
// Usually the largest internal array. Its allocator is 64-byte (cache-line) aligned:
std::vector<TC, xsimd::aligned_allocator<TC, 64>> fwBatch;

std::vector<BIGINT> sortIndices; // precomputed NU pt permutation, speeds spread/interp
bool didSort; // whether binsorting used (false: identity perm used)

@@ -187,7 +183,6 @@ template<typename TF> struct FINUFFT_PLAN_T { // the main plan class, fully C++
// arrays (no new allocs)
std::vector<TC> prephase; // pre-phase, for all input NU pts
std::vector<TC> deconv; // reciprocal of kernel FT, phase, all output NU pts
std::vector<TC> CpBatch; // working array of prephased strengths
std::array<std::vector<TF>, 3> XYZp; // internal primed NU points (x'_j, etc)
std::array<std::vector<TF>, 3> STUp; // internal primed targs (s'_k, etc)
type3params<TF> t3P; // groups together type 3 shift, scale, phase, parameters
@@ -201,7 +196,7 @@ template<typename TF> struct FINUFFT_PLAN_T { // the main plan class, fully C++

// Remaining actions (not create/delete) in guru interface are now methods...
int setpts(BIGINT nj, TF *xj, TF *yj, TF *zj, BIGINT nk, TF *s, TF *t, TF *u);
int execute(std::complex<TF> *cj, std::complex<TF> *fk);
int execute(std::complex<TF> *cj, std::complex<TF> *fk, bool adjoint = false) const;
};

void finufft_default_opts_t(finufft_opts *o);
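
Because fwBatch and CpBatch no longer live inside the plan (the FFT scratch buffer is now handed to do_fft by the caller, per the new declaration in fft.h), execute() can be const, and the adjoint transform is selected through the new defaulted flag instead of a separate plan. A hypothetical call site against the new signature; cj and fk are the strength and coefficient arrays named in the declaration, and the adjoint path is still completely untested per the commit message.

int ier  = plan->execute(cj, fk);                    // regular transform, behaviour unchanged
int ier2 = plan->execute(cj, fk, /*adjoint=*/true);  // adjoint variant introduced in this commit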
2 changes: 1 addition & 1 deletion include/finufft/spreadinterp.h
@@ -48,7 +48,7 @@ FINUFFT_EXPORT int FINUFFT_CDECL spreadinterpSorted(
const std::vector<BIGINT> &sort_indices, const UBIGINT N1, const UBIGINT N2,
const UBIGINT N3, T *data_uniform, const UBIGINT M, T *FINUFFT_RESTRICT kx,
T *FINUFFT_RESTRICT ky, T *FINUFFT_RESTRICT kz, T *FINUFFT_RESTRICT data_nonuniform,
const finufft_spread_opts &opts, int did_sort);
const finufft_spread_opts &opts, int did_sort, bool adjoint);
template<typename T>
FINUFFT_EXPORT T FINUFFT_CDECL evaluate_kernel(T x, const finufft_spread_opts &opts);
template<typename T>
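
spreadinterpSorted() gains the same trailing flag. A hypothetical call, reusing the parameter names from the declaration above; whether the flag swaps the spreading and interpolation directions is an assumption here, not something this diff states.

spreadinterpSorted(sort_indices, N1, N2, N3, data_uniform, M, kx, ky, kz,
                   data_nonuniform, opts, did_sort, /*adjoint=*/false);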
108 changes: 59 additions & 49 deletions src/fft.cpp
@@ -7,36 +7,39 @@ using namespace std;
#include "ducc0/fft/fftnd_impl.h"
#endif

template<typename TF> std::vector<int> gridsize_for_fft(FINUFFT_PLAN_T<TF> *p) {
template<typename TF> std::vector<int> gridsize_for_fft(const FINUFFT_PLAN_T<TF> &p) {
// local helper func returns a new int array of length dim, extracted from
// the finufft plan, that fftw_plan_many_dft needs as its 2nd argument.
if (p->dim == 1) return {(int)p->nfdim[0]};
if (p->dim == 2) return {(int)p->nfdim[1], (int)p->nfdim[0]};
// if (p->dim == 3)
return {(int)p->nfdim[2], (int)p->nfdim[1], (int)p->nfdim[0]};
if (p.dim == 1) return {(int)p.nfdim[0]};
if (p.dim == 2) return {(int)p.nfdim[1], (int)p.nfdim[0]};
// if (p.dim == 3)
return {(int)p.nfdim[2], (int)p.nfdim[1], (int)p.nfdim[0]};
}
template std::vector<int> gridsize_for_fft<float>(FINUFFT_PLAN_T<float> *p);
template std::vector<int> gridsize_for_fft<double>(FINUFFT_PLAN_T<double> *p);
template std::vector<int> gridsize_for_fft<float>(const FINUFFT_PLAN_T<float> &p);
template std::vector<int> gridsize_for_fft<double>(const FINUFFT_PLAN_T<double> &p);

template<typename TF> void do_fft(FINUFFT_PLAN_T<TF> *p) {
template<typename TF>
void do_fft(const FINUFFT_PLAN_T<TF> &p, std::complex<TF> *fwBatch, bool adjoint) {
#ifdef FINUFFT_USE_DUCC0
size_t nthreads = min<size_t>(MY_OMP_GET_MAX_THREADS(), p->opts.nthreads);
size_t nthreads = min<size_t>(MY_OMP_GET_MAX_THREADS(), p.opts.nthreads);
const auto ns = gridsize_for_fft(p);
vector<size_t> arrdims, axes;
arrdims.push_back(size_t(p->batchSize));
// FIXME: use thisBatchsize if it is smaller than p.batchSize!
arrdims.push_back(size_t(p.batchSize));
arrdims.push_back(size_t(ns[0]));
axes.push_back(1);
if (p->dim >= 2) {
if (p.dim >= 2) {
arrdims.push_back(size_t(ns[1]));
axes.push_back(2);
}
if (p->dim >= 3) {
if (p.dim >= 3) {
arrdims.push_back(size_t(ns[2]));
axes.push_back(3);
}
ducc0::vfmav<std::complex<TF>> data(p->fwBatch.data(), arrdims);
bool forward = (p.fftSign < 0) != adjoint;
ducc0::vfmav<std::complex<TF>> data(fwBatch, arrdims); // FIXME
#ifdef FINUFFT_NO_DUCC0_TWEAKS
ducc0::c2c(data, data, axes, p->fftSign < 0, TF(1), nthreads);
ducc0::c2c(data, data, axes, forward, TF(1), nthreads);
#else
/* For type 1 NUFFTs, only the low-frequency parts of the output fine grid are
going to be used, and for type 2 NUFFTs, the high frequency parts of the
Expand All @@ -46,67 +49,74 @@ template<typename TF> void do_fft(FINUFFT_PLAN_T<TF> *p) {
second axis we need to do (roughly) a fraction of 1/oversampling_factor
of all 1D FFTs, and for the last remaining axis the factor is
1/oversampling_factor^2. */
if (p->dim == 1) // 1D: no chance for FFT shortcuts
ducc0::c2c(data, data, axes, p->fftSign < 0, TF(1), nthreads);
else if (p->dim == 2) { // 2D: do partial FFTs
if (p->mstu[0] < 2) // something is weird, do standard FFT
ducc0::c2c(data, data, axes, p->fftSign < 0, TF(1), nthreads);
if (p.dim == 1) // 1D: no chance for FFT shortcuts
ducc0::c2c(data, data, axes, forward, TF(1), nthreads);
else if (p.dim == 2) { // 2D: do partial FFTs
if (p.mstu[0] < 2) // something is weird, do standard FFT
ducc0::c2c(data, data, axes, forward, TF(1), nthreads);
else {
size_t y_lo = size_t((p->mstu[0] + 1) / 2);
size_t y_hi = size_t(ns[1] - p->mstu[0] / 2);
size_t y_lo = size_t((p.mstu[0] + 1) / 2);
size_t y_hi = size_t(ns[1] - p.mstu[0] / 2);
// the next line is analogous to the Python statement "sub1 = data[:, :, :y_lo]"
auto sub1 = ducc0::subarray(data, {{}, {}, {0, y_lo}});
// the next line is analogous to the Python statement "sub2 = data[:, :, y_hi:]"
auto sub2 = ducc0::subarray(data, {{}, {}, {y_hi, ducc0::MAXIDX}});
if (p->type == 1) // spreading, not all parts of the output array are needed
if (p.type == 1) // spreading, not all parts of the output array are needed
// do axis 2 in full
ducc0::c2c(data, data, {2}, p->fftSign < 0, TF(1), nthreads);
ducc0::c2c(data, data, {2}, forward, TF(1), nthreads);
// do only parts of axis 1
ducc0::c2c(sub1, sub1, {1}, p->fftSign < 0, TF(1), nthreads);
ducc0::c2c(sub2, sub2, {1}, p->fftSign < 0, TF(1), nthreads);
if (p->type == 2) // interpolation, parts of the input array are zero
ducc0::c2c(sub1, sub1, {1}, forward, TF(1), nthreads);
ducc0::c2c(sub2, sub2, {1}, forward, TF(1), nthreads);
if (p.type == 2) // interpolation, parts of the input array are zero
// do axis 2 in full
ducc0::c2c(data, data, {2}, p->fftSign < 0, TF(1), nthreads);
ducc0::c2c(data, data, {2}, forward, TF(1), nthreads);
}
} else { // 3D
if ((p->mstu[0] < 2) || (p->mstu[1] < 2)) // something is weird, do standard FFT
ducc0::c2c(data, data, axes, p->fftSign < 0, TF(1), nthreads);
} else { // 3D
if ((p.mstu[0] < 2) || (p.mstu[1] < 2)) // something is weird, do standard FFT
ducc0::c2c(data, data, axes, forward, TF(1), nthreads);
else {
size_t z_lo = size_t((p->mstu[0] + 1) / 2);
size_t z_hi = size_t(ns[2] - p->mstu[0] / 2);
size_t y_lo = size_t((p->mstu[1] + 1) / 2);
size_t y_hi = size_t(ns[1] - p->mstu[1] / 2);
size_t z_lo = size_t((p.mstu[0] + 1) / 2);
size_t z_hi = size_t(ns[2] - p.mstu[0] / 2);
size_t y_lo = size_t((p.mstu[1] + 1) / 2);
size_t y_hi = size_t(ns[1] - p.mstu[1] / 2);
auto sub1 = ducc0::subarray(data, {{}, {}, {}, {0, z_lo}});
auto sub2 = ducc0::subarray(data, {{}, {}, {}, {z_hi, ducc0::MAXIDX}});
auto sub3 = ducc0::subarray(sub1, {{}, {}, {0, y_lo}, {}});
auto sub4 = ducc0::subarray(sub1, {{}, {}, {y_hi, ducc0::MAXIDX}, {}});
auto sub5 = ducc0::subarray(sub2, {{}, {}, {0, y_lo}, {}});
auto sub6 = ducc0::subarray(sub2, {{}, {}, {y_hi, ducc0::MAXIDX}, {}});
if (p->type == 1) { // spreading, not all parts of the output array are needed
if (p.type == 1) { // spreading, not all parts of the output array are needed
// do axis 3 in full
ducc0::c2c(data, data, {3}, p->fftSign < 0, TF(1), nthreads);
ducc0::c2c(data, data, {3}, forward, TF(1), nthreads);
// do only parts of axis 2
ducc0::c2c(sub1, sub1, {2}, p->fftSign < 0, TF(1), nthreads);
ducc0::c2c(sub2, sub2, {2}, p->fftSign < 0, TF(1), nthreads);
ducc0::c2c(sub1, sub1, {2}, forward, TF(1), nthreads);
ducc0::c2c(sub2, sub2, {2}, forward, TF(1), nthreads);
}
// do even smaller parts of axis 1
ducc0::c2c(sub3, sub3, {1}, p->fftSign < 0, TF(1), nthreads);
ducc0::c2c(sub4, sub4, {1}, p->fftSign < 0, TF(1), nthreads);
ducc0::c2c(sub5, sub5, {1}, p->fftSign < 0, TF(1), nthreads);
ducc0::c2c(sub6, sub6, {1}, p->fftSign < 0, TF(1), nthreads);
if (p->type == 2) { // interpolation, parts of the input array are zero
ducc0::c2c(sub3, sub3, {1}, forward, TF(1), nthreads);
ducc0::c2c(sub4, sub4, {1}, forward, TF(1), nthreads);
ducc0::c2c(sub5, sub5, {1}, forward, TF(1), nthreads);
ducc0::c2c(sub6, sub6, {1}, forward, TF(1), nthreads);
if (p.type == 2) { // interpolation, parts of the input array are zero
// do only parts of axis 2
ducc0::c2c(sub1, sub1, {2}, p->fftSign < 0, TF(1), nthreads);
ducc0::c2c(sub2, sub2, {2}, p->fftSign < 0, TF(1), nthreads);
ducc0::c2c(sub1, sub1, {2}, forward, TF(1), nthreads);
ducc0::c2c(sub2, sub2, {2}, forward, TF(1), nthreads);
// do axis 3 in full
ducc0::c2c(data, data, {3}, p->fftSign < 0, TF(1), nthreads);
ducc0::c2c(data, data, {3}, forward, TF(1), nthreads);
}
}
}
#endif
#else
p->fftPlan->execute(); // if thisBatchSize<batchSize it wastes some flops
// FIXME: the "adjoint" emulation is a crude band-aid
if (adjoint)
for (BIGINT i = 0; i < p.batchSize * p.nf(); ++i) fwBatch[i] = conj(fwBatch[i]);
p.fftPlan->execute(fwBatch); // if thisBatchSize<batchSize it wastes some flops
if (adjoint)
for (BIGINT i = 0; i < p.batchSize * p.nf(); ++i) fwBatch[i] = conj(fwBatch[i]);
#endif
}
template void do_fft<float>(FINUFFT_PLAN_T<float> *p);
template void do_fft<double>(FINUFFT_PLAN_T<double> *p);
template void do_fft<float>(const FINUFFT_PLAN_T<float> &p, std::complex<float> *fwBatch,
bool adjoint);
template void do_fft<double>(const FINUFFT_PLAN_T<double> &p,
std::complex<double> *fwBatch, bool adjoint);
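
For the partial-FFT shortcut described in the comment inside do_fft: only about a 1/sigma fraction of the 1D FFTs is needed along the second axis (sigma being the oversampling factor), and about 1/sigma^2 along the last axis. A back-of-the-envelope count for a single 2D type-1 fine grid with sigma = 2; the sizes are illustrative only.

// Counting 1D FFTs for one nf1 x nf2 fine grid with sigma = 2, so the number
// of modes kept along the first axis is ms ~ nf1 / 2.
std::size_t nf1 = 512, nf2 = 512, ms = nf1 / 2;
std::size_t naive    = nf2 + nf1;   // all 1D FFTs along both axes
std::size_t shortcut = nf2 + ms;    // one axis in full, only ~ms FFTs on the other
// shortcut / naive = 768 / 1024 = 0.75: roughly 25% fewer 1D FFTs in 2D;
// in 3D the last axis needs only ~1/sigma^2 of its FFTs, so the saving grows.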
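
Two adjoint mechanisms coexist in do_fft. In the DUCC0 branch the exponent sign is simply flipped, forward = (p.fftSign < 0) != adjoint, since the adjoint of a DFT with exponent sign s is the DFT with sign -s. The FFTW branch cannot re-sign an already built plan, so (as its FIXME admits) it falls back on the identity conj(F_s(conj(x))) = F_{-s}(x), conjugating before and after the fixed-sign execute. A tiny self-contained check of that identity with a naive DFT; the length and tolerance are arbitrary.

#include <cassert>
#include <cmath>
#include <complex>
#include <cstddef>
#include <vector>

// Naive length-N DFT with exponent sign s (+1 or -1), just to verify the identity.
static std::vector<std::complex<double>> dft(const std::vector<std::complex<double>> &x, int s) {
  const std::size_t N = x.size();
  const double pi = std::acos(-1.0);
  std::vector<std::complex<double>> y(N);
  for (std::size_t k = 0; k < N; ++k)
    for (std::size_t j = 0; j < N; ++j)
      y[k] += x[j] * std::polar(1.0, s * 2.0 * pi * double(j * k) / double(N));
  return y;
}

int main() {
  std::vector<std::complex<double>> x = {{1, 2}, {-3, 0.5}, {0, -1}, {4, 4}};
  auto direct = dft(x, +1);                 // DFT with sign +1
  auto xc = x;
  for (auto &v : xc) v = std::conj(v);      // conjugate the input...
  auto trick = dft(xc, -1);                 // ...run the fixed sign -1 DFT...
  for (auto &v : trick) v = std::conj(v);   // ...and conjugate the output
  for (std::size_t k = 0; k < x.size(); ++k)
    assert(std::abs(direct[k] - trick[k]) < 1e-12);
  return 0;
}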
