diff --git a/cmd/dwirecon.cpp b/cmd/dwirecon.cpp new file mode 100644 index 0000000000..8207e7f648 --- /dev/null +++ b/cmd/dwirecon.cpp @@ -0,0 +1,372 @@ +/* Copyright (c) 2017-2019 Daan Christiaens + * + * MRtrix and this add-on module are distributed in the hope + * that it will be useful, but WITHOUT ANY WARRANTY; without + * even the implied warranty of MERCHANTABILITY or FITNESS + * FOR A PARTICULAR PURPOSE. + */ + +#include +#include + +#include "adapter/extract.h" +#include "command.h" +#include "dwi/gradient.h" +#include "dwi/shells.h" +#include "file/matrix.h" +#include "image.h" +#include "math/SH.h" +#include "phase_encoding.h" + +#include "dwi/svr/qspacebasis.h" +#include "dwi/svr/recon.h" + +constexpr int DEFAULT_LMAX = 4; +constexpr float DEFAULT_SSPW = 1.0F; +constexpr float DEFAULT_REG = 1e-3F; +constexpr float DEFAULT_ZREG = 1e-3F; +constexpr float DEFAULT_TOL = 1e-4F; +constexpr int DEFAULT_MAXITER = 10; + +using namespace MR; +using namespace App; + +// clang-format off +void usage () +{ + AUTHOR = "Daan Christiaens (daan.christiaens@kcl.ac.uk)"; + + SYNOPSIS = "Reconstruct DWI signal from a series of scattered slices with associated motion parameters."; + + DESCRIPTION + + ""; + + ARGUMENTS + + Argument ("DWI", "the input DWI image.").type_image_in() + + Argument ("SH", "the output spherical harmonics coefficients image.").type_image_out(); + + + OPTIONS + + Option ("motion", "The motion parameters associated with input slices or volumes. " + "These are supplied as a matrix of 6 columns encoding the rigid " + "transformations w.r.t. scanner space in se(3) Lie algebra." ) + + Argument ("file").type_file_in() + + + Option ("rf", "Basis functions for the radial (multi-shell) domain, provided as matrices in which " + "rows correspond with shells and columns with SH harmonic bands.").allow_multiple() + + Argument ("b").type_file_in() + + + Option ("lmax", "The maximum harmonic order for the output series. (default = " + str(DEFAULT_LMAX) + ")") + + Argument ("order").type_integer(0, 30) + + + Option ("weights", "Slice weights, provided as a matrix of dimensions Nslices x Nvols.") + + Argument ("W").type_file_in() + + + Option ("voxweights", "Voxel weights, provided as an image of same dimensions as dMRI data.") + + Argument ("W").type_image_in() + + + Option ("ssp", "Slice sensitivity profile, either as text file or as a scalar slice thickness for a " + "Gaussian SSP, relative to the voxel size. (default = " + str(DEFAULT_SSPW) + ")") + + Argument ("w").type_text() + + + Option ("reg", "Isotropic Laplacian regularization. (default = " + str(DEFAULT_REG) + ")") + + Argument ("l").type_float() + + + Option ("zreg", "Regularization in the slice direction. (default = " + str(DEFAULT_ZREG) + ")") + + Argument ("l").type_float() + + + Option ("template", "Template header to determine the reconstruction grid.") + + Argument ("header").type_image_in() + + + DWI::GradImportOptions() + + + PhaseEncoding::ImportOptions + + + DWI::ShellsOption + + + OptionGroup ("Output options") + + + Option ("spred", + "output source prediction of all scattered slices. (useful for diagnostics)") + + Argument ("out").type_image_out() + + + Option ("padding", "zero-padding output coefficients to given dimension.") + + Argument ("rank").type_integer(0) + + + Option ("complete", "complete (zero-filled) source prediction.") + + + OptionGroup ("CG Optimization options") + + + Option ("tolerance", "the tolerance on the conjugate gradient solver. (default = " + str(DEFAULT_TOL) + ")") + + Argument ("t").type_float(0.0, 1.0) + + + Option ("maxiter", + "the maximum number of iterations of the conjugate gradient solver. (default = " + str(DEFAULT_MAXITER) + ")") + + Argument ("n").type_integer(1) + + + Option ("init", + "initial guess of the reconstruction parameters.") + + Argument ("img").type_image_in(); + +} +// clang-format on + +typedef float value_type; + +void run() { + // Load input data + auto dwi = Image::open(argument[0]).with_direct_io({1, 2, 3, 4}); + + // Read motion parameters + auto opt = get_options("motion"); + Eigen::MatrixXf motion; + if (!opt.empty()) + motion = File::Matrix::load_matrix(opt[0][0]); + else + motion = Eigen::MatrixXf::Zero(dwi.size(3), 6); + + // Check dimensions + if ((motion.size() != 0) && (motion.cols() != 6)) + throw Exception("No. columns in motion parameters must equal 6."); + if ((motion.size() != 0) && (((dwi.size(3) * dwi.size(2)) % motion.rows()) != 0)) + throw Exception("No. rows in motion parameters does not match image dimensions."); + + // Select shells + auto grad = DWI::get_DW_scheme(dwi); + DWI::Shells shells(grad); + shells.select_shells(false, false, false); + + // Read multi-shell basis + int lmax = 0; + std::vector rf; + opt = get_options("rf"); + for (size_t k = 0; k < opt.size(); k++) { + Eigen::MatrixXf const t = File::Matrix::load_matrix(opt[k][0]); + if (t.rows() != shells.count()) + throw Exception("No. shells does not match no. rows in basis function " + opt[k][0] + "."); + lmax = std::max(2 * (static_cast(t.cols()) - 1), lmax); + rf.push_back(t); + } + + // Read slice weights + Eigen::MatrixXf W = Eigen::MatrixXf::Ones(dwi.size(2), dwi.size(3)); + opt = get_options("weights"); + if (!opt.empty()) { + W = File::Matrix::load_matrix(opt[0][0]); + if (W.rows() != dwi.size(2) || W.cols() != dwi.size(3)) + throw Exception("Weights matrix dimensions don't match image dimensions."); + } + + // Get volume indices + std::vector idx; + if (rf.empty()) { + idx = shells.largest().get_volumes(); + } else { + for (size_t k = 0; k < shells.count(); k++) + idx.insert(idx.end(), shells[k].get_volumes().begin(), shells[k].get_volumes().end()); + std::sort(idx.begin(), idx.end()); + } + + // Select subset + auto dwisub = Adapter::make(dwi, 3, container_cast>(idx)); + + Eigen::MatrixXf gradsub(idx.size(), grad.cols()); + for (ssize_t i = 0; i < idx.size(); i++) + gradsub.row(i) = grad.row(static_cast(idx[i])).template cast(); + + ssize_t const ne = motion.rows() / dwi.size(3); + Eigen::MatrixXf motionsub(ne * idx.size(), 6); + for (ssize_t i = 0; i < idx.size(); i++) + for (ssize_t j = 0; j < ne; j++) + motionsub.row(i * ne + j) = motion.row(static_cast(idx[i]) * ne + j); + + Eigen::MatrixXf Wsub(W.rows(), idx.size()); + for (ssize_t i = 0; i < idx.size(); i++) + Wsub.col(i) = W.col(static_cast(idx[i])); + + // SSP + DWI::SVR::SSP ssp(DEFAULT_SSPW); + opt = get_options("ssp"); + if (!opt.empty()) { + std::string const t = opt[0][0]; + try { + ssp = DWI::SVR::SSP(std::stof(t)); + } catch (std::invalid_argument &e) { + try { + Eigen::VectorXf const v = File::Matrix::load_vector(t); + ssp = DWI::SVR::SSP(v); + } catch (...) { + throw Exception("Invalid argument for SSP."); + } + } + } + + // Read voxel weights + Eigen::VectorXf Wvox = Eigen::VectorXf::Ones(dwisub.size(0) * dwisub.size(1) * dwisub.size(2) * dwisub.size(3)); + opt = get_options("voxweights"); + if (!opt.empty()) { + auto voxweights = Image::open(opt[0][0]); + check_dimensions(dwisub, voxweights, 0, 4); + ssize_t j = 0; + for (auto l = Loop("loading voxel weights data", {0, 1, 2, 3})(voxweights); l; l++, j++) { + Wvox[j] = voxweights.value(); + } + } + + // Other parameters + if (rf.empty()) + lmax = get_option_value("lmax", DEFAULT_LMAX); + else + lmax = std::min(lmax, static_cast(get_option_value("lmax", lmax))); + + float const reg = get_option_value("reg", DEFAULT_REG); + float const zreg = get_option_value("zreg", DEFAULT_ZREG); + + value_type const tol = get_option_value("tolerance", DEFAULT_TOL); + ssize_t const maxiter = get_option_value("maxiter", DEFAULT_MAXITER); + + DWI::SVR::QSpaceBasis const qbasis{gradsub, lmax, rf, motionsub}; + + ssize_t const ncoefs = static_cast(qbasis.get_ncoefs()); + size_t const padding = get_option_value("padding", Math::SH::NforL(lmax)); + if (padding < Math::SH::NforL(lmax)) + throw Exception("user-provided padding too small."); + + // Create source header - needed due to stride handling + Header srchdr(dwisub); + Stride::set(srchdr, {1, 2, 3, 4}); + DWI::set_DW_scheme(srchdr, gradsub); + srchdr.datatype() = DataType::Float32; + srchdr.sanitise(); + + // Create recon header + Header rechdr(dwisub); + opt = get_options("template"); + if (!opt.empty()) { + rechdr = Header::open(opt[0][0]); + } + rechdr.ndim() = 4; + rechdr.size(3) = ncoefs; + Stride::set(rechdr, {2, 3, 4, 1}); + rechdr.datatype() = DataType::Float32; + rechdr.sanitise(); + + // Create mapping + DWI::SVR::ReconMapping const map(rechdr, srchdr, qbasis, motionsub, ssp); + + // Set up scattered data matrix + INFO("initialise reconstruction matrix"); + DWI::SVR::ReconMatrix R(map, reg, zreg); + R.setWeights(Wsub); + + R.setVoxelWeights(Wvox); + + // Read input data to vector (this enforces positive strides!) + Eigen::VectorXf y(R.rows()); + y.setZero(); + ssize_t j = 0; + for (auto lv = Loop("loading image data", {0, 1, 2, 3})(dwisub); lv; lv++, j++) { + float const w = Wsub(dwisub.index(2), dwisub.index(3)) * Wvox[j]; + y[j] = std::sqrt(w) * dwisub.value(); + } + + // Fit scattered data in basis... + INFO("initialise conjugate gradient solver"); + + Eigen::LeastSquaresConjugateGradient cg; + cg.compute(R); + + cg.setTolerance(tol); + cg.setMaxIterations(maxiter); + + // Solve y = M x + Eigen::VectorXf x(R.cols()); + opt = get_options("init"); + if (!opt.empty()) { + // load initialisation + auto init = Image::open(opt[0][0]).with_direct_io({3, 4, 5, 2, 1}); + check_dimensions(rechdr, init, 0, 3); + if ((init.size(3) != shells.count()) || (init.size(4) < Math::SH::NforL(lmax))) + throw Exception("dimensions of init image don't match."); + // init vector + Eigen::VectorXf x0(R.cols()); + // convert from mssh + Eigen::VectorXf c(shells.count() * Math::SH::NforL(lmax)); + Eigen::MatrixXf x2mssh(c.size(), ncoefs); + x2mssh.setZero(); + for (int k = 0; k < shells.count(); k++) + x2mssh.middleRows(static_cast(k * Math::SH::NforL(lmax)), static_cast(Math::SH::NforL(lmax))) = qbasis.getShellBasis(k).transpose(); + auto mssh2x = x2mssh.fullPivHouseholderQr(); + ssize_t j = 0; + ssize_t k = 0; + for (auto l = Loop("loading initialisation", {0, 1, 2})(init); l; l++, j += ncoefs) { + k = 0; + for (auto l2 = Loop(3)(init); l2; l2++) { + for (init.index(4) = 0; init.index(4) < Math::SH::NforL(lmax); init.index(4)++) + c[k++] = std::isfinite(static_cast(init.value())) ? init.value() : 0.0F; + } + x0.segment(j, ncoefs) = mssh2x.solve(c); + } + INFO("solve from given starting point"); + x = cg.solveWithGuess(y, x0); + } else { + INFO("solve from zero starting point"); + x = cg.solve(y); + } + + CONSOLE("CG: #iterations: " + str(cg.iterations())); + CONSOLE("CG: estimated error: " + str(cg.error())); + + // Write result to output file + Header msshhdr(rechdr); + msshhdr.ndim() = 5; + msshhdr.size(3) = static_cast(shells.count()); + msshhdr.size(4) = static_cast(padding); + Stride::set_from_command_line(msshhdr, {3, 4, 5, 2, 1}); + msshhdr.datatype() = DataType::from_command_line(DataType::Float32); + PhaseEncoding::set_scheme(msshhdr, Eigen::MatrixXf()); + // store b-values and counts + { + std::stringstream ss; + for (auto b : shells.get_bvalues()) + ss << b << ","; + std::string const key = "shells"; + std::string val = ss.str(); + val.erase(val.length() - 1); + msshhdr.keyval()[key] = val; + } + { + std::stringstream ss; + for (auto c : shells.get_counts()) + ss << c << ","; + std::string const key = "shellcounts"; + std::string val = ss.str(); + val.erase(val.length() - 1); + msshhdr.keyval()[key] = val; + } + + auto out = Image::create(argument[1], msshhdr); + + j = 0; + Eigen::VectorXf c(ncoefs); + Eigen::VectorXf sh(padding); + sh.setZero(); + for (auto l = Loop("writing result to image", {0, 1, 2})(out); l; l++, j += ncoefs) { + c = x.segment(j, ncoefs); + for (int k = 0; k < shells.count(); k++) { + out.index(3) = k; + sh.head(Math::SH::NforL(lmax)) = qbasis.getShellBasis(k).transpose() * c; + out.row(4) = sh; + } + } + + // Output source prediction + bool const complete = !get_options("complete").empty(); + opt = get_options("spred"); + if (!opt.empty()) { + srchdr.size(3) = (complete) ? dwi.size(3) : dwisub.size(3); + auto spred = Image::create(opt[0][0], srchdr); + auto recon = ImageView(rechdr, x.data()); + map.x2y(recon, spred); + } +} diff --git a/cmd/dwislicealign.cpp b/cmd/dwislicealign.cpp new file mode 100644 index 0000000000..03460078e7 --- /dev/null +++ b/cmd/dwislicealign.cpp @@ -0,0 +1,131 @@ +/* Copyright (c) 2017-2019 Daan Christiaens + * + * MRtrix and this add-on module are distributed in the hope + * that it will be useful, but WITHOUT ANY WARRANTY; without + * even the implied warranty of MERCHANTABILITY or FITNESS + * FOR A PARTICULAR PURPOSE. + */ + +#include "command.h" +#include "dwi/gradient.h" +#include "file/matrix.h" +#include "image.h" +#include "thread_queue.h" + +#include "dwi/svr/psf.h" +#include "dwi/svr/register.h" + +constexpr float DEFAULT_SSPW = 1.0F; + +using namespace MR; +using namespace App; + +// clang-format off +void usage () +{ + AUTHOR = "Daan Christiaens (daan.christiaens@kcl.ac.uk)"; + + SYNOPSIS = "Register multi-shell spherical harmonics image to DWI slices or volumes."; + + DESCRIPTION + + "This command takes DWI data and a multi-shell spherical harmonics (MSSH) signal " + "prediction to estimate subject motion parameters with volume-to-slice registration."; + + ARGUMENTS + + Argument ("data", "the input DWI data.").type_image_in() + + + Argument ("mssh", "the input MSSH prediction.").type_image_in() + + + Argument ("out", "the output motion parameters.").type_file_out(); + + OPTIONS + + Option ("mask", "image mask") + + Argument ("m").type_image_in() + + + Option ("mb", "multiband factor. (default = 0; v2v registration)") + + Argument ("factor").type_integer(0) + + + Option ("ssp", "SSP vector or slice thickness in voxel units (default = 1).") + + Argument ("w").type_text() + + + Option ("init", "motion initialisation") + + Argument ("motion").type_file_in() + + + Option ("maxiter", "maximum no. iterations for the registration") + + Argument ("n").type_integer(0) + + + DWI::GradImportOptions(); + +} +// clang-format on + +using value_type = float; + +void run() { + // input data + auto data = Image::open(argument[0]); + auto grad = DWI::get_DW_scheme(data); + + // input template + auto mssh = Image::open(argument[1]); + if (mssh.ndim() != 5) + throw Exception("5-D MSSH image expected."); + + // index shells + auto bvals = parse_floats(mssh.keyval().find("shells")->second); + + // mask + auto mask = Image(); + auto opt = get_options("mask"); + if (!opt.empty()) { + mask = Image::open(opt[0][0]); + check_dimensions(data, mask, 0, 3); + } + + // multiband factor + size_t mb = get_option_value("mb", 0); + if (mb == 0 || mb == data.size(2)) { + mb = data.size(2); + INFO("volume-to-volume registration."); + } else { + if (data.size(2) % mb != 0) + throw Exception("multiband factor invalid."); + } + + // SSP + DWI::SVR::SSP ssp(DEFAULT_SSPW); + opt = get_options("ssp"); + if (!opt.empty()) { + std::string const t = opt[0][0]; + try { + ssp = DWI::SVR::SSP(std::stof(t)); + } catch (std::invalid_argument &e) { + try { + Eigen::VectorXf const v = File::Matrix::load_vector(t); + ssp = DWI::SVR::SSP(v); + } catch (...) { + throw Exception("Invalid argument for SSP."); + } + } + } + + // settings and initialisation + size_t const niter = get_option_value("maxiter", 0); + Eigen::MatrixXf init(data.size(3), 6); + init.setZero(); + opt = get_options("init"); + if (!opt.empty()) { + init = File::Matrix::load_matrix(opt[0][0]); + if ((init.cols() != 6) || (((data.size(3) * data.size(2)) % init.rows()) != 0)) + throw Exception("dimension mismatch in motion initialisaton."); + } + + // run registration + DWI::SVR::SliceAlignSource source(data.size(3), data.size(2), mb, grad, bvals, init); + DWI::SVR::SliceAlignPipe pipe(data, mssh, mask, mb, niter, ssp); + DWI::SVR::SliceAlignSink sink(data.size(3), data.size(2), mb); + Thread::run_queue(source, DWI::SVR::SliceIdx(), Thread::multi(pipe), DWI::SVR::SliceIdx(), sink); + + // output + File::Matrix::save_matrix(sink.get_motion(), argument[2]); +} diff --git a/cmd/dwisliceoutliergmm.cpp b/cmd/dwisliceoutliergmm.cpp new file mode 100644 index 0000000000..c47fa47f23 --- /dev/null +++ b/cmd/dwisliceoutliergmm.cpp @@ -0,0 +1,295 @@ +/* Copyright (c) 2017-2019 Daan Christiaens + * + * MRtrix and this add-on module are distributed in the hope + * that it will be useful, but WITHOUT ANY WARRANTY; without + * even the implied warranty of MERCHANTABILITY or FITNESS + * FOR A PARTICULAR PURPOSE. + */ + +#include "algo/threaded_loop.h" +#include "command.h" +#include "dwi/gradient.h" +#include "dwi/shells.h" +#include "file/matrix.h" +#include "image.h" +#include "interp/nearest.h" +#include "math/median.h" +#include "math/rng.h" + +#include "dwi/svr/param.h" + +using namespace MR; +using namespace App; + +// clang-format off +void usage () +{ + AUTHOR = "Daan Christiaens (daan.christiaens@kcl.ac.uk)"; + + SYNOPSIS = "Detect and reweigh outlier slices in DWI image."; + + DESCRIPTION + + "This command takes DWI data and a signal prediction to calculate " + "slice inlier probabilities using Bayesian GMM modelling."; + + ARGUMENTS + + Argument ("in", "the input DWI data.").type_image_in() + + Argument ("pred", "the input signal prediction").type_image_in() + + Argument ("out", "the output slice weights.").type_file_out(); + + OPTIONS + + Option ("mb", "multiband factor (default = 1)") + + Argument ("f").type_integer(1) + + + Option ("mask", "image mask") + + Argument ("m").type_image_in() + + + Option ("motion", "rigid motion parameters (used for masking)") + + Argument ("param").type_file_in() + + + Option ("export_error", "export RMSE matrix, scaled by the median error in each shell.") + + Argument ("E").type_file_out() + + + DWI::GradImportOptions(); + +} +// clang-format on + +using value_type = float; + +/** + * @brief RMSE Functor + */ +class RMSErrorFunctor { +public: + RMSErrorFunctor(const Image &in, const Image &mask, const Eigen::MatrixXf &mot, const int mb = 1) + : nv(in.size(3)), + nz(in.size(2)), + ne(nz / mb), + T0(in), + mask(mask, false), + motion(mot), + E(new Eigen::MatrixXf(nz, nv)), + N(new Eigen::MatrixXi(nz, nv)) { + E->setZero(); + N->setZero(); + } + + void operator()(Image &data, Image &pred) { + ssize_t const v = data.get_index(3); + ssize_t const z = data.get_index(2); + // Get transformation for masking. Note that the MB-factor of the motion table and the OR settings can be different. + ssize_t const ne_mot = motion.rows() / nv; + transform_type const T{DWI::SVR::se3exp(motion.row(v * ne_mot + z % ne_mot)).cast()}; + // Calculate slice error + value_type e = 0.0; + int n = 0; + Eigen::Vector3d pos; + for (auto l = Loop(0, 2)(data, pred); l; l++) { + if (mask.valid()) { + assign_pos_of(data, 0, 3).to(pos); + mask.scanner(T * T0.voxel2scanner * pos); + if (!mask.value()) + continue; + } + value_type const d = data.value() - pred.value(); + e += d * d; + n++; + } + (*E)(z, v) = e; + (*N)(z, v) = n; + } + + Eigen::MatrixXf result() const { + Eigen::MatrixXf Emb(ne, nv); + Emb.setZero(); + Eigen::MatrixXi Nmb(ne, nv); + Nmb.setZero(); + for (ssize_t b = 0; b < nz / ne; b++) { + Emb += E->block(b * ne, 0, ne, nv); + Nmb += N->block(b * ne, 0, ne, nv); + } + Eigen::MatrixXf const R = (Nmb.array() > 0).select(Emb.cwiseQuotient(Nmb.cast()), Eigen::MatrixXf::Zero(ne, nv)); + return R.cwiseSqrt(); + } + +private: + const ssize_t nv, nz, ne; + const Transform T0; + Interp::Nearest> mask; + const Eigen::MatrixXf motion; + + std::shared_ptr E; + std::shared_ptr N; +}; + +/** + * @brief 2-component Gaussian Mixture Model + */ +template class GMModel { +public: + using float_t = T; + using VecType = Eigen::Matrix; + + GMModel(const int max_iters = 50, const float_t eps = 1e-3, const float_t reg_covar = 1e-6) + : niter(max_iters), tol(eps), reg(reg_covar) {} + + /** + * Fit GMM to vector x. + */ + void fit(const VecType &x) { + // initialise + init(x); + float_t ll; + float_t ll0 = -std::numeric_limits::infinity(); + // Expectation-Maximization + for (int n = 0; n < niter; n++) { + ll = e_step(x); + m_step(x); + // check convergence + if (std::fabs(ll - ll0) < tol) + break; + ll0 = ll; + } + } + + /** + * Get posterior probability of the fit. + */ + VecType posterior() const { return Rin.array().exp(); } + +private: + const int niter; + const float_t tol, reg; + + float_t Min, Mout; + float_t Sin, Sout; + float_t Pin, Pout; + VecType Rin, Rout; + + // initialise inlier and outlier classes. + inline void init(const VecType &x) { + float_t med = median(x); + float_t mad = median(Eigen::abs(x.array() - med)) * 1.4826; + Min = med; + Mout = med + 1.0; // shift +1 only valid for log-Gaussians, + Sin = mad; + Sout = mad + 1.0; // corresp. to approx. x 3 med/mad error. + Pin = 0.9; + Pout = 0.1; + } + + // E-step: update sample log-responsabilies and return log-likelohood. + inline float_t e_step(const VecType &x) { + Rin = log_gaussian(x, Min, Sin); + Rin = Rin.array() + std::log(Pin); + Rout = log_gaussian(x, Mout, Sout); + Rout = Rout.array() + std::log(Pout); + VecType log_prob_norm = Eigen::log(Rin.array().exp() + Rout.array().exp()); + Rin -= log_prob_norm; + Rout -= log_prob_norm; + return log_prob_norm.mean(); + } + + // M-step: update component mean and variance. + inline void m_step(const VecType &x) { + VecType w1 = Rin.array().exp() + std::numeric_limits::epsilon(); + VecType w2 = Rout.array().exp() + std::numeric_limits::epsilon(); + Pin = w1.mean(); + Pout = w2.mean(); + Min = average(x, w1); + Mout = average(x, w2); + Sin = std::sqrt(average((x.array() - Min).square(), w1) + reg); + Sout = std::sqrt(average((x.array() - Min).square(), w2) + reg); + } + + inline VecType log_gaussian(const VecType &x, float_t mu = 0., float_t sigma = 1.) const { + VecType resp = x.array() - mu; + resp /= sigma; + resp = resp.array().square() + std::log(2 * M_PI); + resp *= -0.5f; + resp = resp.array() - std::log(sigma); + return resp; + } + + inline float_t median(const VecType &x) const { + std::vector vec(x.size()); + for (size_t i = 0; i < x.size(); i++) + vec[i] = x[i]; + return Math::median(vec); + } + + inline float_t average(const VecType &x, const VecType &w) const { return x.dot(w) / w.sum(); } +}; + +void run() { + auto data = Image::open(argument[0]); + auto pred = Image::open(argument[1]); + check_dimensions(data, pred, 0, 4); + + auto mask = Image(); + auto opt = get_options("mask"); + if (!opt.empty()) { + mask = Image::open(opt[0][0]); + check_dimensions(data, mask, 0, 3); + } else { + throw Exception("mask is required."); + } + + Eigen::MatrixXf motion(data.size(3), 6); + motion.setZero(); + opt = get_options("motion"); + if (!opt.empty()) { + motion = File::Matrix::load_matrix(opt[0][0]); + if (motion.cols() != 6 || ((data.size(3) * data.size(2)) % motion.rows())) + throw Exception("dimension mismatch in motion initialisaton."); + } + + int mb = get_option_value("mb", 1); + if (data.size(2) % mb) + throw Exception("Multiband factor incompatible with image dimensions."); + + auto grad = DWI::get_DW_scheme(data); + DWI::Shells shells(grad); + + // Compute RMSE of each slice + RMSErrorFunctor rmse(data, mask, motion, mb); + ThreadedLoop("Computing root-mean-squared error", data, 2, 4).run(rmse, data, pred); + Eigen::MatrixXf E = rmse.result(); + + opt = get_options("export_error"); + if (!opt.empty()) { + File::Matrix::save_matrix(E.replicate(mb, 1), opt[0][0]); + } + + // Compute weights + Eigen::MatrixXf W = Eigen::MatrixXf::Ones(E.rows(), E.cols()); + GMModel gmm; + for (size_t s = 0; s < shells.count(); s++) { + // Log-residuals + int k = 0; + Eigen::VectorXf res(E.rows() * shells[s].count()); + for (size_t v : shells[s].get_volumes()) + res.segment(E.rows() * (k++), E.rows()) = E.col(v); + // clip at non-zero minimum + value_type nzmin = res.redux([](value_type a, value_type b) { + if (a > 0) + return (b > 0) ? std::min(a, b) : a; + else + return (b > 0) ? b : std::numeric_limits::infinity(); + }); + Eigen::VectorXf logres = res.array().max(nzmin).log(); + // Fit GMM + gmm.fit(logres); + // Save posterior probabilities + Eigen::VectorXf p = gmm.posterior(); + k = 0; + for (size_t v : shells[s].get_volumes()) + W.col(v) = p.segment(E.rows() * (k++), E.rows()); + } + + // Output + Eigen::ArrayXXf Wfull = 1e6 * W.replicate(mb, 1); + Wfull = 1e-6 * Wfull.round(); + File::Matrix::save_matrix(Wfull, argument[2]); +} diff --git a/cmd/mrfieldunwarp.cpp b/cmd/mrfieldunwarp.cpp new file mode 100644 index 0000000000..5e1089f989 --- /dev/null +++ b/cmd/mrfieldunwarp.cpp @@ -0,0 +1,181 @@ +/* Copyright (c) 2017-2019 Daan Christiaens + * + * MRtrix and this add-on module are distributed in the hope + * that it will be useful, but WITHOUT ANY WARRANTY; without + * even the implied warranty of MERCHANTABILITY or FITNESS + * FOR A PARTICULAR PURPOSE. + */ + +#include + +#include "command.h" +#include "file/matrix.h" +#include "image.h" +#include "interp/cubic.h" +#include "interp/linear.h" +#include "phase_encoding.h" +#include "transform.h" + +#include "dwi/svr/param.h" + +using namespace MR; +using namespace App; + +// clang-format off +void usage () +{ + AUTHOR = "Daan Christiaens (daan.christiaens@kcl.ac.uk)"; + + SYNOPSIS = "Unwarp an EPI image according to its susceptibility field."; + + DESCRIPTION + + "This command takes EPI data and a field map in Hz, and inverts the distortion introduced " + "by the B0 field inhomogeneity. The command can also take motion parameters for each volume " + "or slice, but does not invert the motion. The motion parameters are only used to align the " + "field with the moving subject." + + + "If the field map is estimated using FSL Topup, make sure to use the --fmap output " + "(the field map in Hz) instead of the spline coefficient representation saved by default."; + + ARGUMENTS + + Argument ("input", + "the input image.").type_image_in () + + Argument ("field", + "the B0 field map in Hz.").type_file_in () + + Argument ("output", + "the field-unwrapped image.").type_image_out (); + + OPTIONS + + Option ("motion", + "rigid motion parameters per volume or slice, applied to the field.") + + Argument("T").type_file_in() + + + Option ("fidx", + "index of the input volume to which the field is aligned. (default = none)") + + Argument("vol").type_integer(0) + + + Option ("nomodulation", "disable Jacobian intensity modulation") + + + PhaseEncoding::ImportOptions + + + DataType::options(); +} +// clang-format on + +using value_type = float; + +class FieldUnwarp { +public: + FieldUnwarp(const Image &data, + const Image &field, + const Eigen::MatrixXd &petable, + const Eigen::MatrixXd &motion, + const int fidx = -1, + const bool nomod = false) + : dinterp(data, 0.0f), + finterp(field, 0.0f), + PE(petable.leftCols<3>()), + motion(motion.leftCols<6>()), + T0(data), + nv(data.size(3)), + nz(data.size(2)), + ne(motion.rows() / nv), + nomod(nomod) { + PE.array().colwise() *= petable.col(3).array(); + if ((nv * nz) % motion.rows()) + throw Exception("Motion parameters incompatible with data dimensions."); + Tf = Transform(field).scanner2voxel * T0.voxel2scanner; + if (fidx >= 0 && fidx < nv) + Tf = Tf * get_Ts2r_avg(fidx).inverse(); + } + + void operator()(Image &out) { + size_t v = out.index(3), z = out.index(2); + transform_type Ts2r = Tf * get_Ts2r(v, z); + dinterp.index(3) = v; + Eigen::Vector3d vox, pos, RdB0; + Eigen::RowVector3f dB0; + value_type B0, jac = 1.0; + for (auto l = Loop(0, 2)(out); l; l++) { + assign_pos_of(out).to(vox); + finterp.voxel(Ts2r * vox); + finterp.value_and_gradient(B0, dB0); + RdB0 = Ts2r.rotation().transpose() * dB0.transpose().cast(); + pos = vox + B0 * PE.row(v).transpose(); + dinterp.voxel(pos); + if (!nomod) + jac = 1.0 + 2. * PE.row(v) * RdB0; + out.value() = jac * dinterp.value(); + } + } + +private: + Interp::Linear> dinterp; + Interp::LinearInterp, Interp::LinearInterpProcessingType::ValueAndDerivative> finterp; + Eigen::Matrix PE; + Eigen::Matrix motion; + Transform T0; + transform_type Tf; + size_t nv, nz, ne; + bool nomod; + + inline transform_type get_transform(const Eigen::VectorXd &p) const { + transform_type T(DWI::SVR::se3exp(p).cast()); + return T; + } + + inline transform_type get_Ts2r(const size_t v, const size_t z) const { + transform_type Ts2r = T0.scanner2voxel * get_transform(motion.row(v * ne + z % ne)) * T0.voxel2scanner; + return Ts2r; + } + + inline transform_type get_Ts2r_avg(const size_t v) const { + transform_type Ts2r = + T0.scanner2voxel * get_transform(motion.block(v * ne, 0, ne, 6).colwise().mean()) * T0.voxel2scanner; + return Ts2r; + } +}; + +void run() { + auto data = Image::open(argument[0]); + auto field = Image::open(argument[1]); + if (not voxel_grids_match_in_scanner_space(data, field)) { + WARN("Field map voxel grid does not match the input data. " + "If the field map was estimated using FSL TOPUP, make sure to use the --fmap output " + "(the field map in Hz) instead of the spline coefficient representation."); + } + + auto petable = PhaseEncoding::get_scheme(data); + if (petable.rows() != data.size(3)) + throw Exception("Invalid PE table."); + // ----------------------- // TODO: Eddy uses a reverse LR axis for storing + petable.col(0) *= -1; // the PE table, akin to the gradient table. + // ----------------------- // Fix in the eddy import/export functions in core. + + // Apply rigid rotation to field. + auto opt = get_options("motion"); + Eigen::MatrixXd motion; + if (!opt.empty()) + motion = File::Matrix::load_matrix(opt[0][0]); + else + motion = Eigen::MatrixXd::Zero(data.size(3), 6); + + // field alignment + int fidx = get_option_value("fidx", -1); + if (fidx >= data.size(3)) + throw Exception("field index invalid."); + + // other options + opt = get_options("nomodulation"); + bool nomod = !opt.empty(); + + // Save output + Header header(data); + header.datatype() = DataType::from_command_line(DataType::Float32); + + auto out = Image::create(argument[2], header); + + // Loop through shells + FieldUnwarp func(data, field, petable, motion, fidx, nomod); + ThreadedLoop("unwarping field", out, {2, 3}).run(func, out); +} diff --git a/cmd/mssh2amp.cpp b/cmd/mssh2amp.cpp new file mode 100644 index 0000000000..72cd216130 --- /dev/null +++ b/cmd/mssh2amp.cpp @@ -0,0 +1,139 @@ +/* Copyright (c) 2017-2019 Daan Christiaens + * + * MRtrix and this add-on module are distributed in the hope + * that it will be useful, but WITHOUT ANY WARRANTY; without + * even the implied warranty of MERCHANTABILITY or FITNESS + * FOR A PARTICULAR PURPOSE. + */ + +#include + +#include "command.h" +#include "dwi/gradient.h" +#include "file/matrix.h" +#include "image.h" +#include "math/SH.h" + +using namespace MR; +using namespace App; + +// clang-format off +void usage () +{ + AUTHOR = "Daan Christiaens (daan.christiaens@kcl.ac.uk) and " + "David Raffelt (david.raffelt@florey.edu.au)"; + + SYNOPSIS = "Evaluate the amplitude of a 5-D image of multi-shell " + "spherical harmonic functions along specified directions."; + + ARGUMENTS + + Argument ("input", + "the input image consisting of spherical harmonic (SH) " + "coefficients.").type_image_in () + + Argument ("gradient", + "the gradient encoding along which the SH functions will " + "be sampled (directions + shells)").type_file_in () + + Argument ("output", + "the output image consisting of the amplitude of the SH " + "functions along the specified directions.").type_image_out (); + + OPTIONS + + Option ("transform", + "rigid transformation, applied to the gradient table.") + + Argument("T").type_file_in() + + + Option ("nonnegative", + "cap all negative amplitudes to zero") + + + Stride::Options + + DataType::options(); +} +// clang-format on + +using value_type = float; + +class MSSH2Amp { +public: + template + MSSH2Amp(const MatrixType &dirs, const size_t lmax, const std::vector &idx, bool nonneg) + : SHT(Math::SH::init_transform_cart(dirs.template cast(), lmax)), + bidx(idx), + nonnegative(nonneg), + sh(SHT.cols()), + amp(SHT.rows()) {} + + void operator()(Image &in, Image &out) { + sh = in.row(4); + amp = SHT * sh; + if (nonnegative) + amp = amp.cwiseMax(value_type(0.0)); + for (size_t j = 0; j < amp.size(); j++) { + out.index(3) = bidx[j]; + out.value() = amp[j]; + } + } + +private: + const Eigen::Matrix SHT; + const std::vector &bidx; + const bool nonnegative; + Eigen::Matrix sh, amp; +}; + +template inline std::vector get_indices(const VectorType &blist, const value_type bval) { + std::vector indices; + for (size_t j = 0; j < blist.size(); j++) + if ((blist[j] > bval - DWI_SHELLS_EPSILON) && (blist[j] < bval + DWI_SHELLS_EPSILON)) + indices.push_back(j); + return indices; +} + +void run() { + auto mssh = Image::open(argument[0]); + if (mssh.ndim() != 5) + throw Exception("5-D MSSH image expected."); + + Header header(mssh); + auto bvals = parse_floats(header.keyval().find("shells")->second); + + Eigen::Matrix grad; + grad = File::Matrix::load_matrix(argument[1]).leftCols<4>(); + // copied from core/dwi/gradient.cpp; refactor upon merge + Eigen::Array squared_norms = grad.leftCols(3).rowwise().squaredNorm(); + for (ssize_t row = 0; row != grad.rows(); ++row) { + if (squared_norms[row]) + grad.row(row).template head<3>().array() /= std::sqrt(squared_norms[row]); + } + + // Apply rigid rotation to gradient table. + transform_type T; + T.setIdentity(); + auto opt = get_options("transform"); + if (!opt.empty()) + T = File::Matrix::load_transform(opt[0][0]); + + grad.leftCols<3>() = grad.leftCols<3>() * T.rotation().transpose(); + + // Save output + header.ndim() = 4; + header.size(3) = grad.rows(); + DWI::set_DW_scheme(header, grad); + Stride::set_from_command_line(header, Stride::contiguous_along_axis(3)); + header.datatype() = DataType::from_command_line(DataType::Float32); + + auto amp_data = Image::create(argument[2], header); + + // Loop through shells + for (size_t k = 0; k < bvals.size(); k++) { + mssh.index(3) = k; + auto idx = get_indices(grad.col(3), bvals[k]); + if (idx.empty()) + continue; + Eigen::MatrixXd directions(idx.size(), 3); + for (size_t i = 0; i < idx.size(); i++) { + directions.row(i) = grad.row(idx[i]).template head<3>(); + } + MSSH2Amp mssh2amp(directions, Math::SH::LforN(mssh.size(4)), idx, get_options("nonnegative").size()); + ThreadedLoop("computing amplitudes", mssh, 0, 3, 2).run(mssh2amp, mssh, amp_data); + } +} diff --git a/cmd/msshsvd.cpp b/cmd/msshsvd.cpp new file mode 100644 index 0000000000..7f65510ee9 --- /dev/null +++ b/cmd/msshsvd.cpp @@ -0,0 +1,189 @@ +/* Copyright (c) 2017-2019 Daan Christiaens + * + * MRtrix and this add-on module are distributed in the hope + * that it will be useful, but WITHOUT ANY WARRANTY; without + * even the implied warranty of MERCHANTABILITY or FITNESS + * FOR A PARTICULAR PURPOSE. + */ + +#include "command.h" +#include "file/matrix.h" +#include "image.h" +#include "math/SH.h" + +#include +#include + +using namespace MR; +using namespace App; + +// clang-format off +void usage () +{ + AUTHOR = "Daan Christiaens (daan.christiaens@kcl.ac.uk)"; + + SYNOPSIS = "SH-SVD decomposition of multi-shell SH data."; + + DESCRIPTION + + "This command expects a 5-D MSSH image (shells on 4th dimension; " + "SH coefficients in the 5th). The command will compute the optimal " + "orthonormal basis for representing the input data using the singular " + "value decomposition across shells and SH frequency bands (SH-SVD)." + + + "Optionally, the command can output the low-rank projection of the input. " + "The rank is set with the parameter -lmax. For lmax=4,2,0 (default), the data " + "is projected onto components of order 4, 2, and 0, leading to a rank = 22."; + + ARGUMENTS + + Argument ("in", "the input MSSH data.").type_image_in() + + + Argument ("rf", "the output basis functions.").type_file_out().allow_multiple(); + + OPTIONS + + Option ("mask", "image mask") + + Argument ("m").type_file_in() + + + Option ("lmax", "maximum SH order per component (default = 4,2,0)") + + Argument ("order").type_sequence_int() + + + Option ("lbreg", "Laplace-Beltrami regularisation weight (default = 0)") + + Argument ("lambda").type_float(0.0) + + + Option ("weights", "vector of weights per shell (default = ones)") + + Argument ("w").type_file_in() + + + Option ("proj", "output low-rank MSSH projection") + + Argument ("mssh").type_image_out(); + +} +// clang-format on + +using value_type = float; + +class SHSVDProject { +public: + SHSVDProject(const int l, const Eigen::MatrixXf &P) : l(l), P(P) {} + + void operator()(Image &in, Image &out) { + for (size_t k = Math::SH::NforL(l - 2); k < Math::SH::NforL(l); k++) { + in.index(4) = k; + out.index(4) = k; + out.row(3) = P * Eigen::VectorXf(in.row(3)); + } + } + +private: + const int l; + const Eigen::MatrixXf &P; +}; + +void run() { + auto in = Image::open(argument[0]).with_direct_io({3, 4}); + + auto mask = Image(); + auto opt = get_options("mask"); + if (!opt.empty()) { + mask = Image::open(opt[0][0]); + check_dimensions(in, mask, 0, 3); + } + + int nshells = in.size(3); + size_t nvox = in.size(0) * in.size(1) * in.size(2); + + std::vector lmax; + opt = get_options("lmax"); + if (!opt.empty()) { + lmax = opt[0][0].as_sequence_int(); + } else { + lmax = {4, 2, 0}; + } + std::sort(lmax.begin(), lmax.end(), std::greater()); // sort in place + if (Math::SH::NforL(lmax[0]) > in.size(4)) + throw Exception("lmax too large for input dimension."); + + int nrf = lmax.size(); + if (argument.size() != nrf + 1) + throw Exception("no. output arguments does not match desired lmax."); + if (nrf > nshells) + throw Exception("no. basis functions can't exceed no. shells."); + std::vector rf; + for (int l : lmax) { + rf.push_back(Eigen::MatrixXf::Zero(nshells, l / 2 + 1)); + } + + Eigen::VectorXf W(nshells); + W.setOnes(); + auto key = in.keyval().find("shellcounts"); + opt = get_options("weights"); + if (!opt.empty()) { + W = File::Matrix::load_vector(opt[0][0]).cwiseSqrt(); + if (W.size() != nshells) + throw Exception("provided weights do not match the no. shells."); + } else if (key != in.keyval().end()) { + size_t i = 0; + for (auto w : parse_floats(key->second)) + W[i++] = std::sqrt(w); + } + W /= W.mean(); + + // LB regularization + float lam = get_option_value("lbreg", 0.0f); + Eigen::VectorXf lb(nshells); + lb.setZero(); + auto bvals = parse_floats(in.keyval().find("shells")->second); + for (size_t i = 0; i < nshells; i++) { + float b = bvals[i] / 1000.f; + lb[i] = (b < 1e-2) ? 0.0f : lam / (b * b); + } + + auto proj = Image(); + opt = get_options("proj"); + bool pout = !opt.empty(); + if (pout) { + Header header(in); + proj = Image::create(opt[0][0], header); + } + + // Compute SVD per SH order l + for (int l = 0; l <= lmax[0]; l += 2) { + // define LB filter + Eigen::VectorXf lbfilt = Math::pow2(l * (l + 1)) * lb; + lbfilt.array() += 1.0f; + lbfilt = lbfilt.cwiseInverse(); + // load data to matrix + Eigen::MatrixXf Sl(nshells, nvox * (2 * l + 1)); + Sl.setZero(); + auto loop = Loop(0, 3); + size_t i = 0; + for (auto v = loop(in); v; v++) { + if (mask.valid()) { + assign_pos_of(in).to(mask); + if (!mask.value()) + continue; + } + for (size_t j = Math::SH::NforL(l - 2); j < Math::SH::NforL(l); j++, i++) { + in.index(4) = j; + Sl.col(i) = lbfilt.asDiagonal() * Eigen::VectorXf(in.row(3)); + } + } + // low-rank SVD + Eigen::JacobiSVD svd(W.asDiagonal() * Sl.leftCols(i), Eigen::ComputeFullU); + int rank = std::upper_bound(lmax.begin(), lmax.end(), l, std::greater()) - lmax.begin(); + Eigen::MatrixXf U = svd.matrixU().leftCols(rank); + // save basis functions + for (int n = 0; n < rank; n++) { + rf[n].col(l / 2) = W.asDiagonal().inverse() * U.col(n); + } + // save low-rank projection + if (pout) { + Eigen::MatrixXf P = W.asDiagonal().inverse() * U * U.adjoint() * W.asDiagonal(); + SHSVDProject func(l, P); + ThreadedLoop(in, 0, 3).run(func, in, proj); + } + } + + // Write basis functions to file + for (int n = 0; n < nrf; n++) { + File::Matrix::save_matrix(rf[n], argument[n + 1]); + } +} diff --git a/python/mrtrix3/commands/dwimotioncorrect.py b/python/mrtrix3/commands/dwimotioncorrect.py new file mode 100755 index 0000000000..5f5e5817fd --- /dev/null +++ b/python/mrtrix3/commands/dwimotioncorrect.py @@ -0,0 +1,333 @@ +# Copyright (c) 2017-2019 Daan Christiaens +# +# MRtrix and this add-on module are distributed in the hope +# that it will be useful, but WITHOUT ANY WARRANTY; without +# even the implied warranty of MERCHANTABILITY or FITNESS +# FOR A PARTICULAR PURPOSE. +# +# MOTION CORRECTION FOR DWI VOLUME SERIES +# +# This script performs volume-to-series and slice-to-series registration +# of diffusion-weighted images for motion correction in the brain. +# +# Author: Daan Christiaens +# King's College London +# daan.christiaens@kcl.ac.uk +# + +from mrtrix3 import app, image, path, run, MRtrixError +import json + + +DEFAULT_CONFIG = """{ + "global": { + "rec-reg": 0.005, + "rec-zreg": 0.001, + "svr": true, + "rec-iter": 3, + "reg-iter": 10, + "reg-scale": 1.0, + "lbreg": 0.001 + }, + "epochs": [ + { + "svr": false, + "reg-scale": 3.0 + },{ + "svr": false, + "reg-scale": 2.4 + },{ + "reg-scale": 1.9 + },{ + "reg-scale": 1.5 + },{ + "reg-scale": 1.2 + },{ + "rec-iter": 10 + } + ] +}""" + + + +def usage(cmdline): #pylint: disable=unused-variable + # base + cmdline.set_author('Daan Christiaens (daan.christiaens@kcl.ac.uk)') + cmdline.set_synopsis('Perform motion correction in a dMRI dataset') + cmdline.add_description('Volume-level and/or slice-level motion correction for dMRI, ' + 'based on the SHARD representation for multi-shell data. ') + # arguments + cmdline.add_argument('input', type=app.Parser.ImageIn(), help='The input image series to be corrected') + cmdline.add_argument('output', type=app.Parser.ImageOut(), help='The output multi-shell SH coefficients') + # options + options = cmdline.add_argument_group('Options for the dwimotioncorrect script') + options.add_argument('-mask', type=app.Parser.ImageIn(), help='Manually provide a mask image for motion correction') + options.add_argument('-lmax', help='SH basis order per shell (default = maximum with at least 30% oversampling)') + options.add_argument('-rlmax', help='Reduced basis order per component for registration (default = 4,2,0 or lower if needed)') + options.add_argument('-reg', help='Regularization for dwirecon (overrides config; default = 0.005)') + options.add_argument('-zreg', help='Regularization for dwirecon (overrides config; default = 0.001)') + options.add_argument('-setup', type=app.Parser.FileIn(), help='Configuration setup file (json structured)') + options.add_argument('-mb', help='Multiband factor (default = 1)') + options.add_argument('-sorder', help='Slice order (default = 2,1, for odd-even)') + options.add_argument('-sspwidth', help='Slice thickness for Gaussian SSP (default = 1)') + options.add_argument('-sspfile', type=app.Parser.FileIn(), help='Slice sensitivity profile as vector') + options.add_argument('-fieldmap', type=app.Parser.ImageIn(), help='B0 field map for distortion correction') + options.add_argument('-fieldidx', help='Index of volume to which field map is aligned (default = 0)') + options.add_argument('-pe_table', type=app.Parser.FileIn(), help='Phase encoding table in MRtrix format') + options.add_argument('-pe_eddy', nargs=2, type=app.Parser.FileIn(), help='Phase encoding table in FSL acqp/index format') + options.add_argument('-priorweights', type=app.Parser.FileIn(), help='Import prior slice weights') + options.add_argument('-fixedweights', type=app.Parser.FileIn(), help='Import fixed slice weights') + options.add_argument('-fixedmotion', type=app.Parser.FileIn(), help='Import fixed motion traces') + options.add_argument('-driftfilter', action='store_true', help='Filter motion for drift') + options.add_argument('-voxelweights', type=app.Parser.ImageIn(), help='Import fixed voxel weights') + options.add_argument('-export_motion', type=app.Parser.FileOut(), help='Export rigid motion parameters') + options.add_argument('-export_weights', type=app.Parser.FileOut(), help='Export slice weights') + app.add_dwgrad_import_options(cmdline) + + + +def execute(): #pylint: disable=unused-variable + try: + import scipy + except ImportError: + raise MRtrixError('SciPy required: https://www.scipy.org/install.html') + # import metadata + grad_import_option = app.dwgrad_import_options() + pe_import_option = [] + if app.ARGS.pe_table: + pe_import_option = ['-import_pe_table', app.ARGS.pe_table] + elif app.ARGS.pe_eddy: + pe_import_option = ['-import_pe_eddy', app.ARGS.pe_eddy[0], app.ARGS.pe_eddy[1]] + # check output path - code no longer works; now done internally?? + #app.check_output_path(app.ARGS.output) + #if app.ARGS.export_motion: + # app.check_output_path(app.ARGS.export_motion) + #if app.ARGS.export_weights: + # app.check_output_path(app.ARGS.export_weights) + # prepare working directory + app.activate_scratch_dir() + run.command(['mrconvert', app.ARGS.input, 'in.mif'] + grad_import_option + pe_import_option) + if app.ARGS.mask: + run.command(['mrconvert', app.ARGS.mask, 'mask.mif']) + if app.ARGS.fieldmap: + run.command(['mrconvert', app.ARGS.fieldmap, 'field.mif']) + #app.goto_scratch_dir() + + # Make sure it's actually a DWI that's been passed + header = image.Header('in.mif') + dwi_sizes = header.size() + if len(dwi_sizes) != 4: + raise MRtrixError('Input image must be a 4D image') + DW_scheme = image.mrinfo('in.mif', 'dwgrad').split('\n') + if len(DW_scheme) != int(dwi_sizes[3]): + raise MRtrixError('Input image does not contain valid DW gradient scheme') + + # Check PE table if field map is passed. + if app.ARGS.fieldmap: + PE_scheme = image.mrinfo('in.mif', 'petable').split('\n') + if len(PE_scheme) != int(dwi_sizes[3]): + raise MRtrixError('Input image does not contain valid phase encoding scheme') + + # Generate a brain mask if required, or check the mask if provided + if app.ARGS.mask: + if not image.match(header, 'mask.mif', up_to_dim=3): + raise MRtrixError('Provided mask image does not match input DWI') + else: + run.command('dwi2mask in.mif mask.mif') + + # Image dimensions + dims = list(map(int, header.size())) + vox = list(map(float, header.spacing()[:3])) + vu = round((vox[0]+vox[1])/2., 1) + shells = [int(round(float(s))) for s in image.mrinfo('in.mif', 'shell_bvalues').split()] + shell_sizes = [int(s) for s in image.mrinfo('in.mif', 'shell_sizes').split()] + + # Set default lmax + def get_max_sh_degree(N, oversampling_factor=1.3): + for l in range(0,10,2): + if (l+3)*(l+4)/2 * oversampling_factor > N: + return l + return 8 + + lmax = [get_max_sh_degree(n) for b, n in zip(shells, shell_sizes)] + if app.ARGS.lmax: + lmax = [int(l) for l in app.ARGS.lmax.split(',')] + if len(lmax) != len(shells): + raise MRtrixError('No. lmax must match no. shells.') + + # Set rlmax + rlmax = [min(r,max(l-2,0)) for r, l in zip([4,2,0], sorted(lmax, reverse=True))] + if app.ARGS.rlmax: + rlmax = [int(l) for l in app.ARGS.rlmax.split(',')] + if len(rlmax) > len(lmax) or max(rlmax) > max(lmax): + raise MRtrixError('-rlmax invalid.') + + # Configuration + config = json.loads(DEFAULT_CONFIG) + if app.ARGS.setup: + with open(app.ARGS.setup) as f: + customconfig = json.load(f) + # use defaults for unspecified config options + customconfig['global'] = {**config['global'], **customconfig['global']} + # update config + config.update(customconfig) + + # Override regularization options if provided + if app.ARGS.reg: + config['global']['rec-reg'] = float(app.ARGS.reg) + if app.ARGS.zreg: + config['global']['rec-zreg'] = float(app.ARGS.zreg) + + # SSP option + ssp_option = '' + if app.ARGS.sspfile: + run.command(f'cp {app.ARGS.sspfile} ssp.txt') + ssp_option = ' -ssp ssp.txt' + elif app.ARGS.sspwidth: + ssp_option = ' -ssp ' + app.ARGS.sspwidth + + # Initialise radial basis with RF per shell. + rfs = [] + for k, l in enumerate(lmax): + fn = 'rf'+str(k+1)+'.txt' + with open(fn, 'w') as f: + for s in range(len(shells)): + i = '1' if k==s else '0' + f.write(' '.join([i,]*(l//2+1)) + '\n') + rfs += [fn,] + + redrfs = ['redrf'+str(k+1)+'.txt' for k in range(len(rlmax))] + + # Force max no. threads + nthr = '' + if app.ARGS.nthreads: + nthr = ' -nthreads ' + str(app.ARGS.nthreads) + + # Set multiband factor + mb = 1 + if app.ARGS.mb: + mb = int(app.ARGS.mb) + + # Slice order + motfilt_option = '' + if app.ARGS.sorder: + p,s = app.ARGS.sorder.split(',') + motfilt_option += ' -packs ' + p + ' -shift ' + s + if app.ARGS.driftfilter: + motfilt_option += ' -driftfilt' + + # Import fixed slice weights + if app.ARGS.priorweights: + run.command(f'cp {app.ARGS.priorweights} priorweights.txt') + if app.ARGS.fixedweights: + run.command(f'cp {app.ARGS.fixedweights} sliceweights.txt') + # Import fixed motion traces + if app.ARGS.fixedmotion: + run.command(f'cp {app.ARGS.fixedmotion} motion.txt') + + # Import voxel weights + if app.ARGS.voxelweights: + # mrcalc "hack" to check image dimensions & copy PE table + run.command(f'mrcalc in.mif 0 -mult {app.ARGS.voxelweights} -add voxelweights.mif') + + + # Variable input file name + global inputfn, voxweightsfn + inputfn = 'in.mif' + voxweightsfn = 'voxelweights.mif' + + + + # __________ Function definitions __________ + + def reconstep(k, conf): + rcmd = 'dwirecon {} recon-{}.mif -spred spred.mif'.format(inputfn, k) + rcmd += ' -maxiter {} -reg {} -zreg {}'.format(conf['rec-iter'], conf['rec-reg'], conf['rec-zreg']) + rcmd += ssp_option + ' -rf ' + ' -rf '.join(rfs) + if k>0: + rcmd += ' -motion motion.txt -weights sliceweights.txt -init recon-{}.mif'.format(k-1) + elif app.ARGS.priorweights: + rcmd += ' -weights priorweights.txt' + if app.ARGS.voxelweights: + rcmd += ' -voxweights ' + voxweightsfn + rcmd += ' -force' + nthr + run.command(rcmd) + + + def sliceweightstep(k): + if app.ARGS.fixedweights: + return; + mask_opt = ' -mask mask.mif' + (' -motion motion.txt' if k>0 else '') + run.command('dwisliceoutliergmm ' + inputfn + ' spred.mif -mb ' + str(mb) + mask_opt + ' sliceweights.txt -force') + if app.ARGS.priorweights: + import numpy as np + W0 = np.loadtxt('priorweights.txt') + W1 = np.loadtxt('sliceweights.txt') + np.savetxt('sliceweights.txt', W1 * W0) + + + def basisupdatestep(k, conf): + lboption = ' -lbreg {}'.format(conf['lbreg']) if conf['lbreg']>0 else '' + run.command('msshsvd recon-' + str(k) + '.mif -mask mask.mif -lmax ' + ','.join([str(l) for l in sorted(lmax)[::-1]]) + lboption + ' ' + ' '.join(rfs) + ' -force') + + + def rankreduxstep(k, conf): + lboption = ' -lbreg {}'.format(conf['lbreg']) if conf['lbreg']>0 else '' + fwhmoption = ' -fwhm {}'.format(round(vu*conf['reg-scale'], 2)) + run.command('mrfilter recon-' + str(k) + '.mif smooth -' + fwhmoption + ' | ' + + 'msshsvd - -mask mask.mif -lmax ' + ','.join([str(l) for l in rlmax]) + + lboption + ' ' + ' '.join(redrfs) + ' -proj rankredux.mif -force') + + + def registrationstep(k, conf): + if app.ARGS.fixedmotion: + return; + rcmd = 'dwislicealign {} rankredux.mif motion.txt -mask mask.mif -maxiter {}'.format(inputfn, conf['reg-iter']) + rcmd += ' -mb ' + (str(mb) if conf['svr'] else '0') + (' -init motion.txt' if k>0 else '') + rcmd += ssp_option + ' -force' + nthr + run.command(rcmd) + run.command('motionfilter motion.txt sliceweights.txt motion.txt -medfilt 5' + motfilt_option + ' -force') + + + def fieldalignstep(k): + global inputfn, voxweightsfn + if app.ARGS.fieldmap: + cmdopt = ' -motion motion.txt -fidx ' + ('0' if not app.ARGS.fieldidx else app.ARGS.fieldidx) + ' -force' if k > 0 else '' + run.command('mrfieldunwarp in.mif field.mif unwarped.mif' + cmdopt) + if app.ARGS.voxelweights: + run.command('mrfieldunwarp -nomodulation voxelweights.mif field.mif voxelweights-unwarped.mif' + cmdopt) + inputfn = 'unwarped.mif' + voxweightsfn = 'voxelweights-unwarped.mif' + + + # __________ Motion correction __________ + + nepochs = len(config['epochs']) - 1 + for k, epoch in enumerate(config['epochs']): + liveconfig = {**config['global']} + liveconfig.update(epoch) + # update template + fieldalignstep(k) + reconstep(k, liveconfig) + # end on last recon step + if k == nepochs: + break + # basis update and rank reduction + basisupdatestep(k, liveconfig) + rankreduxstep(k, liveconfig) + # slice weights + sliceweightstep(k) + # register template to volumes + registrationstep(k, liveconfig) + + + # __________ Copy outputs __________ + + run.command(['mrconvert', f'recon-{nepochs}.mif', app.ARGS.output], mrconvert_keyval=f'recon-{nepochs}.mif', force=app.FORCE_OVERWRITE) + if app.ARGS.export_motion: + run.command(f'cp motion.txt {app.ARGS.export_motion}') + if app.ARGS.export_weights: + run.command(f'cp sliceweights.txt {app.ARGS.export_weights}') + + + diff --git a/python/mrtrix3/commands/motionfilter.py b/python/mrtrix3/commands/motionfilter.py new file mode 100755 index 0000000000..e8e53a524e --- /dev/null +++ b/python/mrtrix3/commands/motionfilter.py @@ -0,0 +1,70 @@ +# Copyright (c) 2017-2019 Daan Christiaens +# +# MRtrix and this add-on module are distributed in the hope +# that it will be useful, but WITHOUT ANY WARRANTY; without +# even the implied warranty of MERCHANTABILITY or FITNESS +# FOR A PARTICULAR PURPOSE. +# +# Author: Daan Christiaens +# King's College London +# daan.christiaens@kcl.ac.uk +# + +import numpy as np +from scipy.linalg import logm, expm +from scipy.signal import medfilt + + +def getsliceorder(n, p=2, s=1): + return np.array([j for k in range(0,p) for j in range((k*s)%p,n,p)], dtype=int) + + +def usage(cmdline): #pylint: disable=unused-variable + from mrtrix3 import app #pylint: disable=no-name-in-module, import-outside-toplevel + cmdline.set_author('Daan Christiaens (daan.christiaens@kcl.ac.uk)') + cmdline.set_synopsis('Filtering a series of rigid motion parameters') + cmdline.add_description('This command applies a filter on a timeseries of rigid motion parameters.' + ' This is used in dwimotioncorrect to correct severe registration errors.') + cmdline.add_argument('input', + type=app.Parser.FileIn(), + help='The input motion file') + cmdline.add_argument('weights', + type=app.Parser.FileIn(), + help='The input weight matrix') + cmdline.add_argument('output', + type=app.Parser.FileOut(), + help='The output motion file') + cmdline.add_argument('-packs', type=int, default=2, help='no. slice packs') + cmdline.add_argument('-shift', type=int, default=1, help='slice shift') + cmdline.add_argument('-medfilt', type=int, default=1, help='median filtering kernel size (default = 1; disabled)') + cmdline.add_argument('-driftfilt', action='store_true', help='drift filter slice packs') + + +def execute(): #pylint: disable=unused-variable + from mrtrix3 import MRtrixError #pylint: disable=no-name-in-module, import-outside-toplevel + from mrtrix3 import app, image #pylint: disable=no-name-in-module, import-outside-toplevel + # read inputs + M = np.loadtxt(app.ARGS.input) + W = np.clip(np.loadtxt(app.ARGS.weights), 1e-6, None) + # set up slice order + nv = W.shape[1] + ne = M.shape[0]//nv + sliceorder = getsliceorder(ne, app.ARGS.packs, app.ARGS.shift) + isliceorder = np.argsort(sliceorder) + # reorder + M1 = np.reshape(M.reshape((nv,ne,6))[:,sliceorder,:], (-1,6)) + W1 = np.reshape(np.mean(W.reshape((-1, ne, nv)), axis=0)[sliceorder,:], (-1,1)); + # filter: t'_i = w * t_i + (1-w) * (t'_(i-1) + t'_(i+1)) / 2 + A = np.eye(nv*ne) - 0.5 * (1.-W1) * (np.eye(nv*ne, k=1) + np.eye(nv*ne, k=-1)) + A[0,1] = W1[0] - 1 + A[-1,-2] = W1[-1] - 1 + M2 = np.linalg.solve(A, W1 * M1) + # median filtering + if ne > 1: + M2 = medfilt(M2, (app.ARGS.medfilt, 1)) + if app.ARGS.driftfilt: + M2 = M2.reshape((nv,ne,6)) - np.median(M2.reshape((nv,ne,6)), 0) + # reorder output + M3 = np.reshape(M2.reshape((nv,ne,6))[:,isliceorder,:], (-1,6)) + np.savetxt(app.ARGS.output, M3, fmt='%.6f') + diff --git a/python/mrtrix3/commands/motionstats.py b/python/mrtrix3/commands/motionstats.py new file mode 100755 index 0000000000..787689d834 --- /dev/null +++ b/python/mrtrix3/commands/motionstats.py @@ -0,0 +1,117 @@ +# Copyright (c) 2017-2019 Daan Christiaens +# +# MRtrix and this add-on module are distributed in the hope +# that it will be useful, but WITHOUT ANY WARRANTY; without +# even the implied warranty of MERCHANTABILITY or FITNESS +# FOR A PARTICULAR PURPOSE. +# +# Author: Daan Christiaens +# King's College London +# daan.christiaens@kcl.ac.uk +# + +import math +import numpy as np +from scipy.linalg import logm, expm +import matplotlib.pyplot as plt + + +def getsliceorder(n, p=2, s=1): + return np.array([j for k in range(0,p) for j in range((k*s)%p,n,p)], dtype=int) + + +def tr2lie(T): + r = np.zeros((6,)) + try: + L = logm(T) + R = 0.5 * (L - L.T) + r[:3] = L[:3,3] + r[3:] = R[2,1], R[0,2], R[1,0] + except: + pass + return r + + +def lie2tr(r): + L = np.zeros((4,4)) + L[:3,3] = r[:3] + L[2,1] = r[3] ; L[1,2] = -r[3] + L[0,2] = r[4] ; L[2,0] = -r[4] + L[1,0] = r[5] ; L[0,1] = -r[5] + T = expm(L) + return T + + +def tr2euler(T): + sy = math.hypot(T[0,0], T[1,0]) + if sy > 1e-6: + roll = math.atan2(T[2,1], T[2,2]) + pitch = math.atan2(-T[2,0], sy) + yaw = math.atan2(T[1,0], T[0,0]) + else: + roll = math.atan2(-T[1,2], T[1,1]) + pitch = math.atan2(-T[2,0], sy) + yaw = 0.0 + return np.array([T[0,3], T[1,3], T[2,3], yaw, pitch, roll]) + + +def usage(cmdline): #pylint: disable=unused-variable + from mrtrix3 import app #pylint: disable=no-name-in-module, import-outside-toplevel + cmdline.set_author('Daan Christiaens (daan.christiaens@kcl.ac.uk)') + cmdline.set_synopsis('Calculate motion and outlier statistics') + cmdline.add_description('This command calculates statistics of the subject translation and rotation' + ' and of the detected slice outliers.') + cmdline.add_argument('input', + type=app.Parser.FileIn(), + help='The input motion file') + cmdline.add_argument('weights', + type=app.Parser.FileIn(), + help='The input weight matrix') + cmdline.add_argument('-packs', type=int, default=2, help='no. slice packs') + cmdline.add_argument('-shift', type=int, default=1, help='slice shift') + cmdline.add_argument('-plot', action='store_true', help='plot motion trajectory') + cmdline.add_argument('-grad', type=app.Parser.FileIn(), help='dMRI gradient table') + cmdline.add_argument('-dispersion', type=app.Parser.FileIn(), help='output gradient dispersion to file') + + +def execute(): #pylint: disable=unused-variable + from mrtrix3 import MRtrixError #pylint: disable=no-name-in-module, import-outside-toplevel + from mrtrix3 import app, image #pylint: disable=no-name-in-module, import-outside-toplevel + # read inputs + M0 = np.loadtxt(app.ARGS.input) + W = np.loadtxt(app.ARGS.weights) + # set up slice order + nv = W.shape[1] + ne = M0.shape[0]//nv + sliceorder = getsliceorder(ne, app.ARGS.packs, app.ARGS.shift) + isliceorder = np.argsort(sliceorder) + # reorder + M = np.reshape(M0.reshape((nv,ne,6))[:,sliceorder,:], (-1,6)) + dM = np.diff(M, axis=0) + # motion stats + mtra = np.mean(np.sqrt(np.sum(dM[:,:3]**2, axis=1))) + mrot = np.mean(np.sqrt(np.sum(dM[:,3:]**2, axis=1))) * 180./np.pi + # outlier stats + orratio = 1.0 - np.sum(W) / (W.shape[0] * W.shape[1]) + # print stats + print('{:f} {:f} {:f}'.format(mtra, mrot, orratio)) + # plot trajectory + if app.ARGS.plot: + T = np.array([tr2euler(lie2tr(m)) for m in M]) + ax1 = plt.subplot(2,1,1); plt.plot(T[:,:3]); plt.ylabel('translation'); plt.legend(['x', 'y', 'z']); + ax2 = plt.subplot(2,1,2, sharex=ax1); plt.plot(T[:,3:]); plt.ylabel('rotation'); plt.legend(['yaw', 'pitch', 'roll']); + plt.xlim(0, nv*ne); plt.xlabel('time'); plt.tight_layout(); + plt.show(); + # intra-volume gradient scatter + if app.ARGS.dispersion: + if app.ARGS.grad: + grad = np.loadtxt(app.ARGS.grad) + bvec = grad[:,:3] / np.linalg.norm(grad[:,:3], axis=1)[:,np.newaxis] + r = np.array([[np.dot(lie2tr(m)[:3,:3], v) for m in mvol] for mvol, v in zip(M.reshape((nv,ne,6)), bvec)]) + rm = np.sum(r, axis=1); rm /= np.linalg.norm(rm, axis=1)[:,np.newaxis] + rd = np.einsum('vzi,vi->vz', r, rm) + dispersion = np.degrees(2*np.arccos(np.sqrt(np.mean(rd**2, axis=1)))) + np.savetxt(app.ARGS.dispersion, dispersion[np.newaxis,:], fmt='%.4f') + else: + raise MRtrixError('No diffusion gradient table provided') + diff --git a/src/dwi/svr/mapping.h b/src/dwi/svr/mapping.h new file mode 100644 index 0000000000..2908b924fb --- /dev/null +++ b/src/dwi/svr/mapping.h @@ -0,0 +1,276 @@ +/* Copyright (c) 2017-2019 Daan Christiaens + * + * MRtrix and this add-on module are distributed in the hope + * that it will be useful, but WITHOUT ANY WARRANTY; without + * even the implied warranty of MERCHANTABILITY or FITNESS + * FOR A PARTICULAR PURPOSE. + */ + +#pragma once + +#include +#include + +#include "adapter/base.h" +#include "algo/threaded_loop.h" +#include "dwi/shells.h" +#include "header.h" +#include "interp/cubic.h" +#include "interp/linear.h" +#include "math/SH.h" +#include "transform.h" +#include "types.h" + +#include "dwi/svr/param.h" +#include "dwi/svr/psf.h" +#include "dwi/svr/qspacebasis.h" + +namespace MR::Interp { +template class LinearAdjoint : public Linear { +public: + using typename Linear::value_type; + using Linear::clamp; + using Linear::P; + using Linear::factors; + + LinearAdjoint(const ImageType &parent, value_type outofbounds = 0) : Linear(parent, outofbounds) {} + + //! Add value to local region by interpolation weights. + void adjoint_add(value_type val) { + if (Base::out_of_bounds) + return; + + ssize_t c[] = {ssize_t(std::floor(P[0])), ssize_t(std::floor(P[1])), ssize_t(std::floor(P[2]))}; + + size_t i(0); + for (ssize_t z = 0; z < 2; ++z) { + ImageType::index(2) = clamp(c[2] + z, ImageType::size(2)); + for (ssize_t y = 0; y < 2; ++y) { + ImageType::index(1) = clamp(c[1] + y, ImageType::size(1)); + for (ssize_t x = 0; x < 2; ++x) { + ImageType::index(0) = clamp(c[0] + x, ImageType::size(0)); + ImageType::adjoint_add(factors[i++] * val); + } + } + } + } +}; + +template class CubicAdjoint : public Cubic { +public: + using typename Cubic::value_type; + using Cubic::clamp; + using Cubic::P; + using Cubic::weights_vec; + + CubicAdjoint(const ImageType &parent, value_type outofbounds = 0) : Cubic(parent, outofbounds) {} + + //! Add value to local region by interpolation weights. + void adjoint_add(value_type val) { + if (Base::out_of_bounds) + return; + + ssize_t c[] = {ssize_t(std::floor(P[0]) - 1), ssize_t(std::floor(P[1]) - 1), ssize_t(std::floor(P[2]) - 1)}; + + size_t i(0); + for (ssize_t z = 0; z < 4; ++z) { + ImageType::index(2) = clamp(c[2] + z, ImageType::size(2)); + for (ssize_t y = 0; y < 4; ++y) { + ImageType::index(1) = clamp(c[1] + y, ImageType::size(1)); + for (ssize_t x = 0; x < 4; ++x) { + ImageType::index(0) = clamp(c[0] + x, ImageType::size(0)); + ImageType::adjoint_add(weights_vec[i++] * val); + } + } + } + } +}; +} // namespace MR::Interp + +namespace MR::DWI::SVR { +template class MotionMapping : public Adapter::Base, ImageType> { +public: + using base_type = Adapter::Base, ImageType>; + using value_type = typename ImageType::value_type; + using vector_type = typename Eigen::Matrix; + + using base_type::parent; + + MotionMapping(const ImageType &projection, const Header &source, const Eigen::MatrixXf &rigid, const SSP &ssp) + : base_type(projection), + interp(projection, 0.0f), + yhdr(source), + motion(rigid), + ssp(ssp), + Tr(projection), + Ts(source), + Ts2r(Ts.scanner2voxel * Tr.voxel2scanner) {} + + // Adapter attributes ----------------------------------------------- + size_t ndim() const { return interp.ndim(); } + int size(size_t axis) const { return (axis < 3) ? yhdr.size(axis) : interp.size(axis); } + default_type spacing(size_t axis) const { return (axis < 3) ? yhdr.spacing(axis) : interp.spacing(axis); } + const transform_type &transform() const { return yhdr.transform(); } + const std::string &name() const { return yhdr.name(); } + + ssize_t get_index(size_t axis) const { return (axis < 3) ? x[axis] : interp.index(axis); } + void move_index(size_t axis, ssize_t increment) { + if (axis < 3) + x[axis] += increment; + else + interp.index(axis) += increment; + } + void reset() { + x[0] = x[1] = x[2] = 0; + for (size_t n = 3; n < interp.ndim(); ++n) + interp.index(n) = 0; + } + // ------------------------------------------------------------------ + + value_type value() { + value_type res = 0; + for (int z = -ssp.size(); z <= ssp.size(); z++) { + Eigen::Vector3d pr = Ts2r * Eigen::Vector3d(x[0], x[1], x[2] + z); + for (int k = 0; k < 3; k++) + pr[k] = clampdim(pr[k], k); + interp.voxel(pr); + res += ssp(z) * interp.value(); + } + return res; + } + + void adjoint_add(value_type val) { + for (int z = -ssp.size(); z <= ssp.size(); z++) { + Eigen::Vector3d pr = Ts2r * Eigen::Vector3d(x[0], x[1], x[2] + z); + for (int k = 0; k < 3; k++) + pr[k] = clampdim(pr[k], k); + interp.voxel(pr); + interp.adjoint_add(ssp(z) * val); + } + } + + void set_shotidx(size_t idx) { + interp.set_shotidx(idx); + Ts2r = Tr.scanner2voxel * get_transform(motion.row(idx)) * Ts.voxel2scanner; + } + +private: + Interp::CubicAdjoint interp; + const Header &yhdr; + Eigen::MatrixXf motion; + SSP ssp; + ssize_t x[3]; + const Transform Tr, Ts; + transform_type Ts2r; // vox-to-vox transform, mapping vectors in source space to recon space + + FORCE_INLINE transform_type get_transform(const Eigen::VectorXf &p) const { + transform_type T(se3exp(p).cast()); + return T; + } + + FORCE_INLINE default_type clampdim(default_type r, size_t axis) const { + return (r < 0) ? 0 : (r > parent().size(axis) - 1) ? parent().size(axis) - 1 : r; + } +}; + +class ReconMapping { +public: + ReconMapping(const Header &recon, + const Header &source, + const QSpaceBasis &basis, + const Eigen::MatrixXf &rigid, + const SSP &ssp) + : xhdr(recon), + yhdr(source), + ne(rigid.rows() / source.size(3)), + outer_axes({2, 3}), + slice_axes({0, 1}), + qbasis(basis), + motion(rigid), + ssp(ssp) { + INFO("Multiband factor " + str(source.size(2) / ne) + " detected."); + } + + const Header &xheader() const { return xhdr; } + const Header &yheader() const { return yhdr; } + + size_t rows() const { return voxel_count(yhdr); } + size_t cols() const { return voxel_count(xhdr); } + + template void x2y(const ImageType1 &X, ImageType2 &Y) const { + // create adapters + auto qmap = Adapter::makecached(X, qbasis); + auto spatialmap = Adapter::make(qmap, yhdr, motion, ssp); + + // define per-slice mapping + struct MapSliceX2Y { + ImageType2 out; + decltype(spatialmap) pred; + size_t ne; + const std::vector &axouter; + const std::vector &axslice; + // define slice-wise operation + void operator()(Iterator &pos) { + size_t z = pos.index(2); + size_t v = pos.index(3); + if (z < ne) { + assign_pos_of(pos, axouter).to(out); + pred.set_shotidx(v * ne + z % ne); + for (size_t zz = z; zz < out.size(2); zz += ne) { + out.index(2) = pred.index(2) = zz; + for (auto i = Loop(axslice)(out, pred); i; ++i) + out.value() += pred.value(); + } + } + } + } func = {Y, spatialmap, ne, outer_axes, slice_axes}; + + // run across all slices + ThreadedLoop("forward projection", Y, outer_axes, slice_axes).run_outer(func); + } + + template void y2x(ImageType1 &X, const ImageType2 &Y) const { + // create adapters + auto qmap = Adapter::makecached_add(X, qbasis); + auto spatialmap = Adapter::make(qmap, yhdr, motion, ssp); + + // define per-slice mapping + struct MapSliceY2X { + ImageType2 in; + decltype(spatialmap) pred; + size_t ne; + const std::vector &axouter; + const std::vector &axslice; + // define slice-wise operation + void operator()(Iterator &pos) { + size_t z = pos.index(2); + size_t v = pos.index(3); + if (z < ne) { + assign_pos_of(pos, axouter).to(in); + pred.set_shotidx(v * ne + z % ne); + for (size_t zz = z; zz < in.size(2); zz += ne) { + in.index(2) = pred.index(2) = zz; + for (auto i = Loop(axslice)(in, pred); i; ++i) + pred.adjoint_add(in.value()); + } + pred.set_shotidx(0); // trigger delayed write back + } + } + } func = {Y, spatialmap, ne, outer_axes, slice_axes}; + + // run across all slices + ThreadedLoop("transpose projection", Y, outer_axes, slice_axes).run_outer(func); + } + +private: + const Header &xhdr, yhdr; + const size_t ne; + const std::vector outer_axes; + const std::vector slice_axes; + + const QSpaceBasis qbasis; + const Eigen::MatrixXf motion; + const SSP ssp; +}; + +} // namespace MR::DWI::SVR diff --git a/src/dwi/svr/param.h b/src/dwi/svr/param.h new file mode 100644 index 0000000000..8e303e4dcb --- /dev/null +++ b/src/dwi/svr/param.h @@ -0,0 +1,45 @@ +/* Copyright (c) 2017-2019 Daan Christiaens + * + * MRtrix and this add-on module are distributed in the hope + * that it will be useful, but WITHOUT ANY WARRANTY; without + * even the implied warranty of MERCHANTABILITY or FITNESS + * FOR A PARTICULAR PURPOSE. + */ + +#pragma once + +#include +#include + +namespace MR::DWI::SVR { +/* Exponential Lie mapping on SE(3). */ +template Eigen::Matrix4f se3exp(const VectorType &v) { + Eigen::Matrix4f A, T; + A.setZero(); + A(0, 3) = v[0]; + A(1, 3) = v[1]; + A(2, 3) = v[2]; + A(2, 1) = v[3]; + A(1, 2) = -v[3]; + A(0, 2) = v[4]; + A(2, 0) = -v[4]; + A(1, 0) = v[5]; + A(0, 1) = -v[5]; + T = A.exp(); + return T; +} + +/* Logarithmic Lie mapping on SE(3). */ +Eigen::Matrix se3log(const Eigen::Matrix4f &T) { + Eigen::Matrix4f A = T.log(); + Eigen::Matrix v; + v[0] = A(0, 3); + v[1] = A(1, 3); + v[2] = A(2, 3); + v[3] = (A(2, 1) - A(1, 2)) / 2; + v[4] = (A(0, 2) - A(2, 0)) / 2; + v[5] = (A(1, 0) - A(0, 1)) / 2; + return v; +} + +} // namespace MR::DWI::SVR diff --git a/src/dwi/svr/psf.h b/src/dwi/svr/psf.h new file mode 100644 index 0000000000..2d0498a08c --- /dev/null +++ b/src/dwi/svr/psf.h @@ -0,0 +1,57 @@ +/* Copyright (c) 2017-2019 Daan Christiaens + * + * MRtrix and this add-on module are distributed in the hope + * that it will be useful, but WITHOUT ANY WARRANTY; without + * even the implied warranty of MERCHANTABILITY or FITNESS + * FOR A PARTICULAR PURPOSE. + */ + +#pragma once + +#include + +#include "types.h" + +namespace MR::DWI::SVR { + +/** + * 1-D Slice Sensitivity Profile. + */ +template struct SSP { +public: + SSP(const T fwhm = 1) : n(std::floor(fwhm / scale)), values(2 * n + 1) { + for (int z = -n; z <= n; z++) + values[n + z] = gaussian(z, fwhm / scale); + normalise_values(); + } + + template SSP(const VectorType &vec) : n(vec.size() / 2), values(2 * n + 1) { + for (size_t i = 0; i < values.size(); i++) + values[i] = vec[i]; + normalise_values(); + } + + inline T operator()(const int z) const { return values[n + z]; } + + inline int size() const { return n; } + +private: + int n; + std::vector values; + static constexpr T scale = 2.35482; // 2.sqrt(2.ln(2)); + + inline T gaussian(T x, T sigma) const { + T y = x / sigma; + return std::exp(-0.5 * y * y); + } + + inline void normalise_values() { + T norm = 0; + for (int z = -n; z <= n; z++) + norm += values[n + z]; + for (int z = -n; z <= n; z++) + values[n + z] /= norm; + } +}; + +} // namespace MR::DWI::SVR diff --git a/src/dwi/svr/qspacebasis.h b/src/dwi/svr/qspacebasis.h new file mode 100644 index 0000000000..d6512307d9 --- /dev/null +++ b/src/dwi/svr/qspacebasis.h @@ -0,0 +1,280 @@ +/* Copyright (c) 2017-2019 Daan Christiaens + * + * MRtrix and this add-on module are distributed in the hope + * that it will be useful, but WITHOUT ANY WARRANTY; without + * even the implied warranty of MERCHANTABILITY or FITNESS + * FOR A PARTICULAR PURPOSE. + */ + +#pragma once + +#include +#include +#include + +#include "adapter/base.h" +#include "algo/loop.h" +#include "dwi/shells.h" +#include "math/SH.h" +#include "types.h" + +#include "dwi/svr/param.h" + +namespace MR::Adapter { +template class ReadCache : public Adapter::Base, ImageType> { +public: + using base_type = Adapter::Base, ImageType>; + using value_type = typename ImageType::value_type; + + using base_type::parent; + + ReadCache(const ImageType &parent) : base_type(parent) { + Header hdr(parent); + buffer = Image::scratch(hdr, "temporary buffer"); + } + + ReadCache(const ReadCache &other) : ReadCache(other.parent()) {} + + FORCE_INLINE ssize_t get_index(size_t axis) const { return buffer.get_index(axis); } + FORCE_INLINE void move_index(size_t axis, ssize_t increment) { buffer.move_index(axis, increment); } + FORCE_INLINE void reset() { buffer.reset(); } + + void flush() { + // clear buffer + reset(); + std::fill_n(buffer.address(), voxel_count(buffer), value_type(NAN)); + } + + FORCE_INLINE value_type value() { + value_type val = *buffer.address(); + if (!std::isfinite(val)) + load(val); + return val; + } + + FORCE_INLINE void set_shotidx(size_t idx) { + flush(); + parent().set_shotidx(idx); + } + +private: + Image buffer; + + FORCE_INLINE void load(value_type &val) { + assign_pos_of(buffer).to(parent()); + *buffer.address() = val = parent().value(); + } +}; + +template class WriteCache : public Adapter::Base, ImageType> { +public: + using base_type = Adapter::Base, ImageType>; + using value_type = typename ImageType::value_type; + + using base_type::parent; + + WriteCache(const ImageType &parent) : base_type(parent) { + Header hdr(parent); + buffer = Image::scratch(hdr, "temporary buffer"); + // initialise lock image + static_assert(sizeof(std::atomic_flag) == sizeof(uint8_t), "std::atomic_flag expected to be 1 byte"); + lock = Image::scratch(hdr, "temporary buffer lock"); + } + + WriteCache(const WriteCache &other) : base_type(other.parent()), lock(other.lock) { + Header hdr(other.parent()); + buffer = Image::scratch(hdr, "temporary buffer"); + } + + FORCE_INLINE ssize_t get_index(size_t axis) const { return buffer.get_index(axis); } + FORCE_INLINE void move_index(size_t axis, ssize_t increment) { buffer.move_index(axis, increment); } + FORCE_INLINE void reset() { buffer.reset(); } + + void flush() { + // delayed write back + for (auto l = Loop()(buffer); l; l++) { + if (buffer.value()) { + assign_pos_of(buffer).to(parent(), lock); + std::atomic_flag *flag = reinterpret_cast(lock.address()); + while (flag->test_and_set(std::memory_order_acquire)) + ; + parent().adjoint_add(buffer.value()); + flag->clear(std::memory_order_release); + } + } + // clear buffer + reset(); + std::fill_n(buffer.address(), voxel_count(buffer), value_type(0)); + } + + FORCE_INLINE void adjoint_add(value_type val) { *buffer.address() += val; } + + FORCE_INLINE void set_shotidx(size_t idx) { + flush(); + parent().set_shotidx(idx); + } + +private: + Image buffer; + Image lock; +}; + +template