Skip to content

Commit

Permalink
Check for zero residual Steihaug
Browse files Browse the repository at this point in the history
  • Loading branch information
tttapa committed Dec 17, 2023
1 parent 04ba67f commit 720d2e9
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions src/alpaqa/include/alpaqa/accelerators/steihaugcg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <alpaqa/config/config.hpp>
#include <alpaqa/util/alloc-check.hpp>
#include <cmath>

namespace alpaqa {

Expand All @@ -18,14 +19,14 @@ struct SteihaugCGParams {
/// \min\left(\mathrm{tol\_scale\_root},\; \sqrt{\|g\|}\right)
/// \right)
/// @f]
real_t tol_scale = 1;
real_t tol_scale = 1;
/// Determines the tolerance for termination of the algorithm.
/// See @ref tol_scale.
real_t tol_scale_root = real_t(0.5);
real_t tol_scale_root = real_t(0.5);
/// Determines the tolerance for termination of the algorithm.
/// Prevents the use of huge tolerances if the gradient norm is still large.
/// See @ref tol_scale.
real_t tol_max = inf<config_t>;
real_t tol_max = inf<config_t>;
/// Limit the number of CG iterations to @f$ \lfloor n \cdot
/// \mathrm{max\_iter\_factor} \rceil @f$, where @f$ n @f$ is the number
/// of free variables of the problem.
Expand Down Expand Up @@ -113,7 +114,11 @@ struct SteihaugCG {
}

real_t alpha = r_sq / dBd;
s = z + alpha * d;
if (!std::isfinite(alpha)) {
s.setConstant(NaN<config_t>);
return NaN<config_t>;
}
s = z + alpha * d;
if (s.norm() >= trust_radius) {
// Find t >= 0 to get the boundary point such that
// ||z + t d|| == trust_radius
Expand All @@ -124,7 +129,8 @@ struct SteihaugCG {
}
r += alpha * Bd;
real_t r_next_sq = r.squaredNorm();
if (std::sqrt(r_next_sq) < tolerance || i > max_iter)
real_t r_next = std::sqrt(r_next_sq);
if (r_next < tolerance || r_next == 0 || i > max_iter)
return eval(s);
real_t beta_next = r_next_sq / r_sq;
r_sq = r_next_sq;
Expand Down

0 comments on commit 720d2e9

Please sign in to comment.