Skip to content

Commit

Permalink
Add prox of complex ℓ₁ norm
Browse files Browse the repository at this point in the history
  • Loading branch information
tttapa committed Dec 7, 2023
1 parent 445d704 commit f66b5c6
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 0 deletions.
92 changes: 92 additions & 0 deletions src/alpaqa/include/alpaqa/functions/l1-norm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <alpaqa/config/config.hpp>
#include <alpaqa/functions/prox.hpp>
#include <alpaqa/util/lifetime.hpp>
#include <cassert>
#include <cmath>
#include <stdexcept>
Expand Down Expand Up @@ -75,4 +76,95 @@ struct L1Norm {
}
};

/// ℓ₁-norm of complex numbers.
/// @ingroup grp_Functions
/// @tparam Weight
/// Type of weighting factors. Either scalar or vector.
template <Config Conf, class Weight = typename Conf::real_t>
requires(std::is_same_v<Weight, typename Conf::real_t> ||
std::is_same_v<Weight, typename Conf::vec> ||
std::is_same_v<Weight, typename Conf::rvec> ||
std::is_same_v<Weight, typename Conf::crvec>)
struct L1NormComplex {
USING_ALPAQA_CONFIG(Conf);
using weight_t = Weight;
static constexpr bool scalar_weight = std::is_same_v<weight_t, real_t>;

L1NormComplex(weight_t λ) : λ{std::move(λ)} {
const char *msg = "L1NormComplex::λ must be nonnegative";
if constexpr (scalar_weight) {
if (λ < 0 || !std::isfinite(λ))
throw std::invalid_argument(msg);
} else {
if ((λ.array() < 0).any() || !λ.allFinite())
throw std::invalid_argument(msg);
}
}

L1NormComplex()
requires(scalar_weight)
: λ{1} {}
L1NormComplex()
requires(!scalar_weight)
= default;

weight_t λ;

real_t prox(crcmat in, rcmat out, real_t γ = 1) {
assert(in.cols() == 1);
assert(out.cols() == 1);
assert(in.size() == out.size());
const length_t n = in.size();
if constexpr (scalar_weight) {
assert(λ >= 0);
if (λ == 0) {
out = in;
return 0;
}
auto soft_thres = [γλ{γ * λ}](cplx_t x) {
auto mag = std::abs(x), arg = std::arg(x);
return mag <= γλ ? 0 : std::polar(mag - γλ, arg);
};
out = in.unaryExpr(soft_thres);
return λ * out.template lpNorm<1>();
} else {
if constexpr (std::is_same_v<weight_t, vec>)
if (λ.size() == 0)
λ = weight_t::Ones(n);
assert(λ.cols() == 1);
assert(in.size() == λ.size());
assert((λ.array() >= 0).all());
auto soft_thres = [γ](cplx_t x, real_t λ) {
real_t γλ = γ * λ;
auto mag = std::abs(x), arg = std::arg(x);
return mag <= γλ ? 0 : std::polar(mag - γλ, arg);
};
out = in.binaryExpr(λ, soft_thres);
return out.cwiseProduct(λ).template lpNorm<1>();
}
}

/// Note: a complex vector in ℂⁿ is represented by a real vector in ℝ²ⁿ.
real_t prox(crmat in, rmat out, real_t γ = 1) {
assert(in.rows() % 2 == 0);
assert(out.rows() % 2 == 0);
cmcmat cplx_in{
util::start_lifetime_as_array<cplx_t>(in.data(), in.size() / 2),
in.rows() / 2,
in.cols(),
};
mcmat cplx_out{
util::start_lifetime_as_array<cplx_t>(out.data(), out.size() / 2),
out.rows() / 2,
out.cols(),
};
return prox(cplx_in, cplx_out, γ);
}

friend real_t alpaqa_tag_invoke(tag_t<alpaqa::prox>, L1NormComplex &self,
crmat in, rmat out, real_t γ) {
return self.prox(std::move(in), std::move(out), γ);
}
};

} // namespace alpaqa::functions
49 changes: 49 additions & 0 deletions src/alpaqa/include/alpaqa/util/lifetime.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#pragma once

#include <memory>

#if __cpp_lib_start_lifetime_as >= 202207L

namespace alpaqa::util {
using std::start_lifetime_as;
using std::start_lifetime_as_array;
} // namespace alpaqa::util

#else

#include <cstring>
#include <new>
#include <type_traits>

namespace alpaqa::util {
template <class T>
requires std::is_trivially_copyable_v<T>
T *start_lifetime_as_array(void *p, size_t n) noexcept {
#if __cpp_lib_is_implicit_lifetime >= 202302L
static_assert(std::is_implicit_lifetime_v<T>);
#endif
return std::launder(static_cast<T *>(std::memmove(p, p, n * sizeof(T))));
}
template <class T>
requires std::is_trivially_copyable_v<T>
const T *start_lifetime_as_array(const void *p, size_t n) noexcept {
#if __cpp_lib_is_implicit_lifetime >= 202302L
static_assert(std::is_implicit_lifetime_v<T>);
#endif
static_cast<void>(n); // TODO
// best we can do without compiler support
return std::launder(static_cast<const T *>(p));
}
template <class T>
requires std::is_trivially_copyable_v<T>
T *start_lifetime_as(void *p) noexcept {
return start_lifetime_as_array<T>(p, 1);
}
template <class T>
requires std::is_trivially_copyable_v<T>
const T *start_lifetime_as(const void *p) noexcept {
return start_lifetime_as_array<T>(p, 1);
}
} // namespace alpaqa::util

#endif

0 comments on commit f66b5c6

Please sign in to comment.