Skip to content

Commit

Permalink
add Python interface; fix a few bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
mreineck committed Feb 18, 2025
1 parent 1180e27 commit e7bd659
Show file tree
Hide file tree
Showing 8 changed files with 93 additions and 13 deletions.
3 changes: 3 additions & 0 deletions include/finufft/finufft_core.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,9 @@ template<typename TF> struct FINUFFT_PLAN_T { // the main plan class, fully C++
// Remaining actions (not create/delete) in guru interface are now methods...
int setpts(BIGINT nj, TF *xj, TF *yj, TF *zj, BIGINT nk, TF *s, TF *t, TF *u);
int execute(std::complex<TF> *cj, std::complex<TF> *fk, bool adjoint = false) const;
// Convenience entry point for the adjoint transform: forwards to execute()
// with its `adjoint` flag switched on (execute() defaults it to false).
int execute_adjoint(std::complex<TF> *cj, std::complex<TF> *fk) const {
  constexpr bool run_adjoint = true;
  return execute(cj, fk, run_adjoint);
}
};

void finufft_default_opts_t(finufft_opts *o);
Expand Down
2 changes: 2 additions & 0 deletions include/finufft_eitherprec.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ FINUFFT_EXPORT int FINUFFT_CDECL FINUFFTIFY(_setpts)(
FINUFFT_FLT *zj, FINUFFT_BIGINT N, FINUFFT_FLT *s, FINUFFT_FLT *t, FINUFFT_FLT *u);
FINUFFT_EXPORT int FINUFFT_CDECL FINUFFTIFY(_execute)(
FINUFFT_PLAN plan, FINUFFT_CPX *weights, FINUFFT_CPX *result);
// Adjoint counterpart of _execute; takes the same (plan, weights, result)
// argument list and returns an int status code.
// NOTE(review): which of `weights` / `result` is read and which is written
// depends on the plan type -- the Python wrapper hands its output buffer to
// `weights` for plan types 1 and 3 and to `result` for type 2.
FINUFFT_EXPORT int FINUFFT_CDECL FINUFFTIFY(_execute_adjoint)(
FINUFFT_PLAN plan, FINUFFT_CPX *weights, FINUFFT_CPX *result);
FINUFFT_EXPORT int FINUFFT_CDECL FINUFFTIFY(_destroy)(FINUFFT_PLAN plan);

// ----------------- the 18 simple interfaces -------------------------------
Expand Down
2 changes: 1 addition & 1 deletion python/finufft/examples/guru1d1f.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
strt = time.time()

#plan
plan = fp.Plan(1,(N,),dtype='single')
plan = fp.Plan(1,(N,),dtype='complex64')

#set pts
plan.setpts(x)
Expand Down
2 changes: 1 addition & 1 deletion python/finufft/examples/guru2d1f.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

# instantiate the plan (note n_trans must be set here), also setting tolerance:
t0 = time.time()
plan = finufft.Plan(nufft_type, (N1, N2), eps=1e-4, n_trans=K, dtype='float32')
plan = finufft.Plan(nufft_type, (N1, N2), eps=1e-4, n_trans=K, dtype='complex64')

# set the nonuniform points
plan.setpts(x, y)
Expand Down
8 changes: 8 additions & 0 deletions python/finufft/finufft/_finufft.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,14 @@ class FinufftOpts(ctypes.Structure):
_executef.argtypes = [c_void_p, c_void_p, c_void_p]
_executef.restype = c_int

# Adjoint execute, double precision (C side works on c128 data):
# three opaque pointers in (plan, two data arrays), C int status out.
_execute_adjoint = lib.finufft_execute_adjoint
_execute_adjoint.restype = c_int
_execute_adjoint.argtypes = [c_void_p] * 3

# Adjoint execute, single precision (C side works on c64 data): same signature.
_execute_adjointf = lib.finufftf_execute_adjoint
_execute_adjointf.restype = c_int
_execute_adjointf.argtypes = [c_void_p] * 3

_destroy = lib.finufft_destroy
_destroy.argtypes = [c_void_p]
_destroy.restype = c_int
Expand Down
58 changes: 58 additions & 0 deletions python/finufft/finufft/_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,13 @@ def __init__(self,nufft_type,n_modes_or_dim,n_trans=1,eps=1e-6,isign=None,dtype=
self._makeplan = _finufft._makeplanf
self._setpts = _finufft._setptsf
self._execute = _finufft._executef
self._execute_adjoint = _finufft._execute_adjointf
self._destroy = _finufft._destroyf
else:
self._makeplan = _finufft._makeplan
self._setpts = _finufft._setpts
self._execute = _finufft._execute
self._execute_adjoint = _finufft._execute_adjoint
self._destroy = _finufft._destroy

ier = self._makeplan(nufft_type, dim, n_modes, isign, n_trans, eps,
Expand Down Expand Up @@ -305,6 +307,62 @@ def execute(self,data,out=None):

return _out

### execute_adjoint
def execute_adjoint(self,data,out=None):
    r"""
    Execute the adjoint of the planned transform.

    The input and output arrays swap roles relative to the plan type:

    - type 1: ``data`` is checked with ``valid_fshape`` against the plan's
      mode sizes; the result has ``nj`` entries along its last axis.
    - type 2: ``data`` is checked with ``valid_cshape`` against ``nj``; the
      result's trailing axes are the plan's mode sizes in reversed order.
    - type 3: ``data`` is checked with ``valid_cshape`` against ``nk``; the
      result has ``nj`` entries along its last axis.

    Args:
        data: input array; passed through ``_ensure_array_type`` with the
            plan's dtype before use.
        out: optional preallocated output array. When given, its shape is
            validated; when omitted, a zero-initialised C-ordered array of
            the plan's dtype is allocated, with leading (batch) axes taken
            from ``data``.

    Returns:
        The array handed to the library as output (the newly allocated one
        when ``out`` is None).

    Raises:
        Whatever ``err_handler`` raises for a non-zero library return code.
    """
    _data = _ensure_array_type(data, "data", self._dtype)
    _out = _ensure_array_type(out, "out", self._dtype, output=True)

    tp = self._type
    n_trans = self._n_trans
    nj = self._nj
    nk = self._nk
    dim = self._dim

    if tp==1 or tp==2:
        # pad the mode sizes with 1s so that three values are always available
        ms, mt, mu = [*self._n_modes, *([1]*(3-len(self._n_modes)))]

    # input shape and size check
    if tp==1:
        valid_fshape(data.shape,n_trans,dim,ms,mt,mu,None,2)
    if tp==2:
        valid_cshape(data.shape,nj,n_trans)
    if tp==3:
        valid_cshape(data.shape,nk,n_trans)

    # out shape and size check
    if out is not None:
        if tp==1:
            valid_cshape(out.shape,nj,n_trans)
        if tp==2:
            valid_fshape(out.shape,n_trans,dim,ms,mt,mu,None,1)
        if tp==3:
            valid_cshape(out.shape,nj,n_trans)

    # allocate out if None; zero-filled so that any element the library
    # does not write reads as 0 rather than as an arbitrary marker value
    if out is None:
        if tp==1:
            _out = np.zeros([*data.shape[:-dim], nj], dtype=self._dtype, order='C')
        if tp==2:
            _out = np.zeros([*data.shape[:-1], *self._n_modes[::-1]], dtype=self._dtype, order='C')
        if tp==3:
            _out = np.zeros([*data.shape[:-1], nj], dtype=self._dtype, order='C')

    # call execute based on type and precision type; the C entry point takes
    # (plan, cj, fk), so the output buffer goes in the cj slot for types 1
    # and 3 and in the fk slot for type 2
    if tp==1 or tp==3:
        ier = self._execute_adjoint(self._inner_plan,
                                    _out.ctypes.data_as(c_void_p),
                                    _data.ctypes.data_as(c_void_p))
    elif tp==2:
        ier = self._execute_adjoint(self._inner_plan,
                                    _data.ctypes.data_as(c_void_p),
                                    _out.ctypes.data_as(c_void_p))

    # check error
    if ier != 0:
        err_handler(ier)

    return _out


def __del__(self):
destroy(self)
Expand Down
6 changes: 6 additions & 0 deletions src/c_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ int finufft_execute(finufft_plan p, c128 *cj, c128 *fk) {
// Single-precision execute: recover the typed plan object behind the opaque
// handle and run its execute() on the given arrays.
int finufftf_execute(finufftf_plan p, c64 *cj, c64 *fk) {
  auto *plan = reinterpret_cast<FINUFFT_PLAN_T<f32> *>(p);
  return plan->execute(cj, fk);
}
// Double-precision adjoint execute: recover the typed plan object behind the
// opaque handle and run its execute_adjoint() on the given arrays.
int finufft_execute_adjoint(finufft_plan p, c128 *cj, c128 *fk) {
  auto *plan = reinterpret_cast<FINUFFT_PLAN_T<f64> *>(p);
  return plan->execute_adjoint(cj, fk);
}
// Single-precision adjoint execute: recover the typed plan object behind the
// opaque handle and run its execute_adjoint() on the given arrays.
int finufftf_execute_adjoint(finufftf_plan p, c64 *cj, c64 *fk) {
  auto *plan = reinterpret_cast<FINUFFT_PLAN_T<f32> *>(p);
  return plan->execute_adjoint(cj, fk);
}

int finufft_destroy(finufft_plan p)
// Free everything we allocated inside of finufft_plan pointed to by p.
Expand Down
25 changes: 14 additions & 11 deletions src/finufft_core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,8 @@ static int spreadinterpSortedBatch(int batchSize, const FINUFFT_PLAN_T<T> &p,

template<typename T>
static int deconvolveBatch(int batchSize, const FINUFFT_PLAN_T<T> &p,
std::complex<T> *fkBatch, std::complex<T> *fwBatch)
std::complex<T> *fkBatch, std::complex<T> *fwBatch,
bool adjoint)
/*
Type 1: deconvolves (amplifies) from each interior fw array in fwBatch
into each output array fk in fkBatch.
Expand All @@ -474,6 +475,8 @@ static int deconvolveBatch(int batchSize, const FINUFFT_PLAN_T<T> &p,
*/
{
// since deconvolveshuffle?d are single-thread, omp par seems to help here...
int dir = p.spopts.spread_direction;
if (adjoint) dir = 3 - dir;
#pragma omp parallel for num_threads(batchSize)
for (int i = 0; i < batchSize; i++) {
std::complex<T> *fwi = fwBatch + i * p.nf(); // start of i'th fw array in
Expand All @@ -482,16 +485,15 @@ static int deconvolveBatch(int batchSize, const FINUFFT_PLAN_T<T> &p,

// pick dim-specific routine from above; note prefactors hardcoded to 1.0...
if (p.dim == 1)
deconvolveshuffle1d(p.spopts.spread_direction, T(1), p.phiHat[0], p.mstu[0],
(T *)fki, p.nfdim[0], fwi, p.opts.modeord);
else if (p.dim == 2)
deconvolveshuffle2d(p.spopts.spread_direction, T(1), p.phiHat[0], p.phiHat[1],
p.mstu[0], p.mstu[1], (T *)fki, p.nfdim[0], p.nfdim[1], fwi,
deconvolveshuffle1d(dir, T(1), p.phiHat[0], p.mstu[0], (T *)fki, p.nfdim[0], fwi,
p.opts.modeord);
else if (p.dim == 2)
deconvolveshuffle2d(dir, T(1), p.phiHat[0], p.phiHat[1], p.mstu[0], p.mstu[1],
(T *)fki, p.nfdim[0], p.nfdim[1], fwi, p.opts.modeord);
else
deconvolveshuffle3d(p.spopts.spread_direction, T(1), p.phiHat[0], p.phiHat[1],
p.phiHat[2], p.mstu[0], p.mstu[1], p.mstu[2], (T *)fki,
p.nfdim[0], p.nfdim[1], p.nfdim[2], fwi, p.opts.modeord);
deconvolveshuffle3d(dir, T(1), p.phiHat[0], p.phiHat[1], p.phiHat[2], p.mstu[0],
p.mstu[1], p.mstu[2], (T *)fki, p.nfdim[0], p.nfdim[1],
p.nfdim[2], fwi, p.opts.modeord);
}
return 0;
}
Expand Down Expand Up @@ -1004,7 +1006,7 @@ int FINUFFT_PLAN_T<TF>::execute(std::complex<TF> *cj, std::complex<TF> *fk,
continue;
} else if (!opts.spreadinterponly) {
// amplify Fourier coeffs fk into 0-padded fw
deconvolveBatch<TF>(thisBatchSize, *this, fkb, fwBatch.data());
deconvolveBatch<TF>(thisBatchSize, *this, fkb, fwBatch.data(), adjoint);
t_deconv += timer.elapsedsec();
}
if (!opts.spreadinterponly) { // Do FFT unless spread/interp only...
Expand All @@ -1018,7 +1020,7 @@ int FINUFFT_PLAN_T<TF>::execute(std::complex<TF> *cj, std::complex<TF> *fk,
// STEP 3: (varies by type)
timer.restart();
if ((type == 1) != adjoint) { // deconvolve (amplify) fw and shuffle to fk
deconvolveBatch<TF>(thisBatchSize, *this, fkb, fwBatch.data());
deconvolveBatch<TF>(thisBatchSize, *this, fkb, fwBatch.data(), adjoint);
t_deconv += timer.elapsedsec();
} else { // interpolate unif fw grid to NU target pts
spreadinterpSortedBatch<TF>(thisBatchSize, *this, fwBatch_or_fkb, cjb, adjoint);
Expand Down Expand Up @@ -1085,6 +1087,7 @@ int FINUFFT_PLAN_T<TF>::execute(std::complex<TF> *cj, std::complex<TF> *fk,
// STEP 2: type 2 NUFFT from fw batch to user output fk array batch...
timer.restart();
// illegal possible shrink of ntrans *after* plan for smaller last batch:
// MR FIXME: this breaks immutability!
innerT2plan->ntrans = thisBatchSize; // do not try this at home!
/* (alarming that FFT not shrunk, but safe, because t2's fwBatch array
still the same size, as Andrea explained; just wastes a few flops) */
Expand Down

0 comments on commit e7bd659

Please sign in to comment.