From df7e40d4f1384b99417c99b9f99c51431677900c Mon Sep 17 00:00:00 2001 From: Pieter Pas Date: Thu, 7 Dec 2023 13:37:36 +0100 Subject: [PATCH] =?UTF-8?q?Add=20prox=20of=20complex=20=E2=84=93=E2=82=81?= =?UTF-8?q?=20norm?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../include/alpaqa/functions/l1-norm.hpp | 92 +++++++++++++++++++ src/alpaqa/include/alpaqa/util/lifetime.hpp | 49 ++++++++++ 2 files changed, 141 insertions(+) create mode 100644 src/alpaqa/include/alpaqa/util/lifetime.hpp diff --git a/src/alpaqa/include/alpaqa/functions/l1-norm.hpp b/src/alpaqa/include/alpaqa/functions/l1-norm.hpp index b87e7fd438..342159aacd 100644 --- a/src/alpaqa/include/alpaqa/functions/l1-norm.hpp +++ b/src/alpaqa/include/alpaqa/functions/l1-norm.hpp @@ -2,6 +2,7 @@ #include #include +#include #include #include #include @@ -75,4 +76,95 @@ struct L1Norm { } }; +/// ℓ₁-norm of complex numbers. +/// @ingroup grp_Functions +/// @tparam Weight +/// Type of weighting factors. Either scalar or vector. +template + requires(std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v) +struct L1NormComplex { + USING_ALPAQA_CONFIG(Conf); + using weight_t = Weight; + static constexpr bool scalar_weight = std::is_same_v; + + L1NormComplex(weight_t λ) : λ{std::move(λ)} { + const char *msg = "L1NormComplex::λ must be nonnegative"; + if constexpr (scalar_weight) { + if (λ < 0 || !std::isfinite(λ)) + throw std::invalid_argument(msg); + } else { + if ((λ.array() < 0).any() || !λ.allFinite()) + throw std::invalid_argument(msg); + } + } + + L1NormComplex() + requires(scalar_weight) + : λ{1} {} + L1NormComplex() + requires(!scalar_weight) + = default; + + weight_t λ; + + real_t prox(crcmat in, rcmat out, real_t γ = 1) { + assert(in.cols() == 1); + assert(out.cols() == 1); + assert(in.size() == out.size()); + const length_t n = in.size(); + if constexpr (scalar_weight) { + assert(λ >= 0); + if (λ == 0) { + out = in; + return 0; + } + auto soft_thres = [γλ{γ * λ}](cplx_t x) { + auto mag = std::abs(x), arg = std::arg(x); + return mag <= γλ ? 0 : std::polar(mag - γλ, arg); + }; + out = in.unaryExpr(soft_thres); + return λ * out.template lpNorm<1>(); + } else { + if constexpr (std::is_same_v) + if (λ.size() == 0) + λ = weight_t::Ones(n); + assert(λ.cols() == 1); + assert(in.size() == λ.size()); + assert((λ.array() >= 0).all()); + auto soft_thres = [γ](cplx_t x, real_t λ) { + real_t γλ = γ * λ; + auto mag = std::abs(x), arg = std::arg(x); + return mag <= γλ ? 0 : std::polar(mag - γλ, arg); + }; + out = in.binaryExpr(λ, soft_thres); + return out.cwiseProduct(λ).template lpNorm<1>(); + } + } + + /// Note: a complex vector in ℂⁿ is represented by a real vector in ℝ²ⁿ. + real_t prox(crmat in, rmat out, real_t γ = 1) { + assert(in.rows() % 2 == 0); + assert(out.rows() % 2 == 0); + cmcmat cplx_in{ + util::start_lifetime_as_array(in.data(), in.size() / 2), + in.rows() / 2, + in.cols(), + }; + mcmat cplx_out{ + util::start_lifetime_as_array(out.data(), out.size() / 2), + out.rows() / 2, + out.cols(), + }; + return prox(cplx_in, cplx_out, γ); + } + + friend real_t alpaqa_tag_invoke(tag_t, L1NormComplex &self, + crmat in, rmat out, real_t γ) { + return self.prox(std::move(in), std::move(out), γ); + } +}; + } // namespace alpaqa::functions diff --git a/src/alpaqa/include/alpaqa/util/lifetime.hpp b/src/alpaqa/include/alpaqa/util/lifetime.hpp new file mode 100644 index 0000000000..00d32726d7 --- /dev/null +++ b/src/alpaqa/include/alpaqa/util/lifetime.hpp @@ -0,0 +1,49 @@ +#pragma once + +#include + +#if __cpp_lib_start_lifetime_as >= 202207L + +namespace alpaqa::util { +using std::start_lifetime_as; +using std::start_lifetime_as_array; +} // namespace alpaqa::util + +#else + +#include +#include +#include + +namespace alpaqa::util { +template + requires std::is_trivially_copyable_v +T *start_lifetime_as_array(void *p, size_t n) noexcept { +#if __cpp_lib_is_implicit_lifetime >= 202302L + static_assert(std::is_implicit_lifetime_v); +#endif + return std::launder(static_cast(std::memmove(p, p, n * sizeof(T)))); +} +template + requires std::is_trivially_copyable_v +const T *start_lifetime_as_array(const void *p, size_t n) noexcept { +#if __cpp_lib_is_implicit_lifetime >= 202302L + static_assert(std::is_implicit_lifetime_v); +#endif + static_cast(n); // TODO + // best we can do without compiler support + return std::launder(static_cast(p)); +} +template + requires std::is_trivially_copyable_v +T *start_lifetime_as(void *p) noexcept { + return start_lifetime_as_array(p, 1); +} +template + requires std::is_trivially_copyable_v +const T *start_lifetime_as(const void *p) noexcept { + return start_lifetime_as_array(p, 1); +} +} // namespace alpaqa::util + +#endif