From a2f023df88c78628043f417e369ee1da2f82f2cd Mon Sep 17 00:00:00 2001 From: Steve Bronder Date: Thu, 9 Jan 2025 16:06:17 -0500 Subject: [PATCH] update docs and fix multi write with psis turned off so that multiple ends do not get written once --- lib/stan_math | 2 +- src/stan/callbacks/multi_writer.hpp | 20 ++++++ src/stan/callbacks/stream_writer.hpp | 3 + src/stan/callbacks/tee_writer.hpp | 3 + src/stan/callbacks/unique_stream_writer.hpp | 3 + src/stan/callbacks/writer.hpp | 3 + src/stan/services/pathfinder/multi.hpp | 6 +- src/stan/services/pathfinder/single.hpp | 65 +++++++++++-------- .../pathfinder/eight_schools_test.cpp | 2 +- .../services/pathfinder/normal_glm_test.cpp | 4 +- 10 files changed, 78 insertions(+), 33 deletions(-) diff --git a/lib/stan_math b/lib/stan_math index 42d94c4840..9bae49db12 160000 --- a/lib/stan_math +++ b/lib/stan_math @@ -1 +1 @@ -Subproject commit 42d94c4840f681806ae0e0134120a4077a29e46c +Subproject commit 9bae49db12815e6906765788663ff6b6f04769fc diff --git a/src/stan/callbacks/multi_writer.hpp b/src/stan/callbacks/multi_writer.hpp index e29c2c95e7..882e32c999 100644 --- a/src/stan/callbacks/multi_writer.hpp +++ b/src/stan/callbacks/multi_writer.hpp @@ -30,6 +30,7 @@ class multi_writer { : output_(std::forward(args)...) {} multi_writer(); + /** * @tparam T Any type accepted by a `writer` overload * @param[in] x A value to write to the output streams @@ -42,6 +43,14 @@ class multi_writer { stan::math::for_each([](auto&& output) { output(); }, output_); } + /** + * Checks if all underlying writers are nonnull. + */ + inline bool is_nonnull() const noexcept { + return stan::math::apply([](auto&&... output) { return (output.is_nonnull() && ...); }, + output_); + } + /** * Get the underlying stream */ @@ -54,6 +63,17 @@ class multi_writer { std::tuple...> output_; }; +namespace internal { +template +struct is_multi_writer : std::false_type {}; + +template +struct is_multi_writer> : std::true_type {}; +} + +template +inline constexpr bool is_multi_writer_v = internal::is_multi_writer>::value; + } // namespace callbacks } // namespace stan diff --git a/src/stan/callbacks/stream_writer.hpp b/src/stan/callbacks/stream_writer.hpp index 72f901625a..62c1719adf 100644 --- a/src/stan/callbacks/stream_writer.hpp +++ b/src/stan/callbacks/stream_writer.hpp @@ -68,6 +68,9 @@ class stream_writer : public writer { output_ << comment_prefix_ << message << std::endl; } + /** + * Checks if stream is valid. + */ virtual bool is_nonnull() const noexcept { return output_.good(); } private: diff --git a/src/stan/callbacks/tee_writer.hpp b/src/stan/callbacks/tee_writer.hpp index 9bf1434de5..9a0b264f53 100644 --- a/src/stan/callbacks/tee_writer.hpp +++ b/src/stan/callbacks/tee_writer.hpp @@ -49,6 +49,9 @@ class tee_writer final : public writer { writer2_(message); } + /** + * Checks if both streams are valid. + */ virtual bool is_nonnull() const noexcept { return writer1_.is_nonnull() && writer2_.is_nonnull(); } diff --git a/src/stan/callbacks/unique_stream_writer.hpp b/src/stan/callbacks/unique_stream_writer.hpp index 01c205bb87..7e0be90d24 100644 --- a/src/stan/callbacks/unique_stream_writer.hpp +++ b/src/stan/callbacks/unique_stream_writer.hpp @@ -114,6 +114,9 @@ class unique_stream_writer final : public writer { *output_ << comment_prefix_ << message << std::endl; } + /** + * Checks if stream is valid. + */ bool is_nonnull() const noexcept { return output_ != nullptr; } private: diff --git a/src/stan/callbacks/writer.hpp b/src/stan/callbacks/writer.hpp index 1458a4bdbb..e752f17535 100644 --- a/src/stan/callbacks/writer.hpp +++ b/src/stan/callbacks/writer.hpp @@ -46,6 +46,9 @@ class writer { */ virtual void operator()(const std::string& message) {} + /** + * Checks if stream is valid. + */ virtual bool is_nonnull() const noexcept { return false; } /** * Writes multiple rows and columns of values in csv format. diff --git a/src/stan/services/pathfinder/multi.hpp b/src/stan/services/pathfinder/multi.hpp index a88fc7bb46..97780a897c 100644 --- a/src/stan/services/pathfinder/multi.hpp +++ b/src/stan/services/pathfinder/multi.hpp @@ -172,7 +172,7 @@ inline int pathfinder_lbfgs_multi( multi_writer_t multi_param_writer( single_path_parameter_writer[iter], safe_write); auto pathfinder_ret - = stan::services::pathfinder::pathfinder_lbfgs_single( + = stan::services::pathfinder::pathfinder_lbfgs_single( model, *(init[iter]), random_seed, stride_id + iter, init_radius, history_size, init_alpha, tol_obj, tol_rel_obj, tol_grad, tol_rel_grad, tol_param, @@ -181,12 +181,12 @@ inline int pathfinder_lbfgs_multi( init_writers[iter], multi_param_writer, single_path_diagnostic_writer[iter], calculate_lp, psis_resample); - if (unlikely(std::get<0>(pathfinder_ret) != error_codes::OK)) { + if (pathfinder_ret.first != error_codes::OK) { logger.error(std::string("Pathfinder iteration: ") + std::to_string(iter) + " failed."); return; } - lp_calls += std::get<2>(pathfinder_ret); + lp_calls += pathfinder_ret.second; } } }); diff --git a/src/stan/services/pathfinder/single.hpp b/src/stan/services/pathfinder/single.hpp index 4932990386..4c08fdfcb4 100644 --- a/src/stan/services/pathfinder/single.hpp +++ b/src/stan/services/pathfinder/single.hpp @@ -459,20 +459,19 @@ inline taylor_approx_t taylor_approximation( * matrix of samples, and an unsigned integer for number of times the log prob * functions was called */ -template * = nullptr> +template inline auto ret_pathfinder(int return_code, EigVec&& elbo_est, - const std::atomic& lp_calls) { - return std::make_tuple(return_code, std::forward(elbo_est), - lp_calls.load()); + const std::atomic& lp_calls, ParamWriter&& /* params */) noexcept { + if constexpr (ReturnLpSamples) { + return std::make_tuple(return_code, std::forward(elbo_est), + lp_calls.load()); + } else if constexpr (stan::callbacks::is_multi_writer_v) { + return std::pair(return_code, lp_calls.load()); + } else { + return return_code; + } } -template * = nullptr> -inline auto ret_pathfinder(int return_code, EigVec&& elbo_est, - const std::atomic& lp_calls) noexcept { - return return_code; -} /** * Estimate the approximate draws given the taylor approximation. @@ -618,7 +617,7 @@ inline auto pathfinder_lbfgs_single( } catch (const std::exception& e) { logger.error(path_num + e.what()); return internal::ret_pathfinder(error_codes::SOFTWARE, - internal::elbo_est_t{}, 0); + internal::elbo_est_t{}, 0, parameter_writer); } const auto num_parameters = cont_vector.size(); @@ -820,7 +819,7 @@ inline auto pathfinder_lbfgs_single( } else { logger.error(e.what()); return internal::ret_pathfinder( - error_codes::SOFTWARE, internal::elbo_est_t{}, 0); + error_codes::SOFTWARE, internal::elbo_est_t{}, 0, parameter_writer); } } } @@ -836,7 +835,7 @@ inline auto pathfinder_lbfgs_single( + " Optimization failed to start, pathfinder cannot be run."); return internal::ret_pathfinder( error_codes::SOFTWARE, internal::elbo_est_t{}, - std::atomic{num_evals + lbfgs.grad_evals()}); + std::atomic{num_evals + lbfgs.grad_evals()}, parameter_writer); } else { logger.warn(prefix_err_msg + " Stan will still attempt pathfinder but may fail or produce " @@ -848,7 +847,7 @@ inline auto pathfinder_lbfgs_single( "Failure: None of the LBFGS iterations completed " "successfully"); return internal::ret_pathfinder( - error_codes::SOFTWARE, internal::elbo_est_t{}, num_evals); + error_codes::SOFTWARE, internal::elbo_est_t{}, num_evals, parameter_writer); } else { if (refresh != 0) { logger.info(path_num + "Best Iter: [" + std::to_string(best_iteration) @@ -856,12 +855,12 @@ inline auto pathfinder_lbfgs_single( + " evaluations: (" + std::to_string(num_evals) + ")"); } } - if (ReturnLpSamples && psis_resample && calculate_lp) { + if constexpr (ReturnLpSamples) { internal::elbo_est_t est_draws = internal::est_approx_draws( lp_fun, constrain_fun, rng, taylor_approx_best, num_draws, taylor_approx_best.alpha, path_num, logger, calculate_lp); return internal::ret_pathfinder( - error_codes::OK, std::move(est_draws), num_evals + est_draws.fn_calls); + error_codes::OK, std::move(est_draws), num_evals + est_draws.fn_calls, parameter_writer); } else { std::vector names; names.push_back("lp_approx__"); @@ -871,7 +870,7 @@ inline auto pathfinder_lbfgs_single( parameter_writer(names); Eigen::Matrix constrained_draws_vec( names.size()); - constrained_draws_vec(2) = stride_id - (ReturnLpSamples ? 1 : 0); + constrained_draws_vec(2) = stride_id - ((stride_id == 0) ? 0 : 1); Eigen::Array lp_ratio; auto&& elbo_draws = elbo_best.repeat_draws; auto&& elbo_lp_ratio = elbo_best.lp_ratio; @@ -950,17 +949,31 @@ inline auto pathfinder_lbfgs_single( } lp_ratio = std::move(elbo_best.lp_ratio.head(num_draws)); } - parameter_writer(); const auto end_pathfinder_time = std::chrono::steady_clock::now(); const double pathfinder_delta_time = stan::services::util::duration_diff( start_pathfinder_time, end_pathfinder_time); - std::string pathfinder_time_str = "Elapsed Time: "; - pathfinder_time_str += std::to_string(pathfinder_delta_time) - + std::string(" seconds (Pathfinder)"); - parameter_writer(pathfinder_time_str); - parameter_writer(); - return internal::ret_pathfinder( - error_codes::OK, internal::elbo_est_t{}, num_evals); + // For multi pathfinder, multi would write multiple end times + if constexpr (stan::callbacks::is_multi_writer_v) { + auto&& single_stream = std::get<0>(parameter_writer.get_stream()); + single_stream(); + std::string pathfinder_time_str = "Elapsed Time: "; + pathfinder_time_str += std::to_string(pathfinder_delta_time) + + std::string(" seconds (Pathfinder)"); + single_stream(pathfinder_time_str); + single_stream(); + return internal::ret_pathfinder( + error_codes::OK, internal::elbo_est_t{}, num_evals, parameter_writer); + } else { + parameter_writer(); + std::string pathfinder_time_str = "Elapsed Time: "; + pathfinder_time_str += std::to_string(pathfinder_delta_time) + + std::string(" seconds (Pathfinder)"); + parameter_writer(pathfinder_time_str); + parameter_writer(); + return internal::ret_pathfinder( + error_codes::OK, internal::elbo_est_t{}, num_evals, parameter_writer); + } + } } diff --git a/src/test/unit/services/pathfinder/eight_schools_test.cpp b/src/test/unit/services/pathfinder/eight_schools_test.cpp index f60cc7c859..7f556a85f2 100644 --- a/src/test/unit/services/pathfinder/eight_schools_test.cpp +++ b/src/test/unit/services/pathfinder/eight_schools_test.cpp @@ -174,7 +174,7 @@ TEST_F(ServicesPathfinderEightSchools, single) { Eigen::MatrixXd param_vals = parameter.get_eigen_state_values(); for (auto&& x_i : param_vals.col(2)) { - EXPECT_EQ(x_i, stride_id); + EXPECT_EQ(x_i, stride_id - 1); } auto param_tmp = param_vals(Eigen::all, param_indices); auto mean_sd_pair = stan::test::get_mean_sd(param_tmp); diff --git a/src/test/unit/services/pathfinder/normal_glm_test.cpp b/src/test/unit/services/pathfinder/normal_glm_test.cpp index 194200f4e3..5cc252494f 100644 --- a/src/test/unit/services/pathfinder/normal_glm_test.cpp +++ b/src/test/unit/services/pathfinder/normal_glm_test.cpp @@ -98,7 +98,7 @@ TEST_F(ServicesPathfinderGLM, single) { "", ""); Eigen::MatrixXd param_vals = parameter.get_eigen_state_values(); for (auto&& x_i : param_vals.col(2)) { - EXPECT_EQ(x_i, stride_id); + EXPECT_EQ(x_i, stride_id - 1); } auto param_tmp = param_vals(Eigen::all, param_indices); @@ -164,7 +164,7 @@ TEST_F(ServicesPathfinderGLM, single_noreturnlp) { EXPECT_EQ(11, param_vals.cols()); EXPECT_EQ(500, param_vals.rows()); for (auto&& x_i : param_vals.col(2)) { - EXPECT_EQ(x_i, stride_id); + EXPECT_EQ(x_i, stride_id - 1); } for (Eigen::Index i = 0; i < num_elbo_draws; ++i) { EXPECT_FALSE(std::isnan(param_vals.coeff(num_draws + i, 1)))