Skip to content

Commit

Permalink
update docs and fix multi write with psis turned off so that multiple…
Browse files Browse the repository at this point in the history
… ends do not get written once
  • Loading branch information
SteveBronder committed Jan 9, 2025
1 parent e00d78e commit a2f023d
Show file tree
Hide file tree
Showing 10 changed files with 78 additions and 33 deletions.
2 changes: 1 addition & 1 deletion lib/stan_math
20 changes: 20 additions & 0 deletions src/stan/callbacks/multi_writer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class multi_writer {
: output_(std::forward<Args>(args)...) {}

multi_writer();

/**
* @tparam T Any type accepted by a `writer` overload
* @param[in] x A value to write to the output streams
Expand All @@ -42,6 +43,14 @@ class multi_writer {
stan::math::for_each([](auto&& output) { output(); }, output_);
}

/**
* Checks if all underlying writers are nonnull.
*/
inline bool is_nonnull() const noexcept {
return stan::math::apply([](auto&&... output) { return (output.is_nonnull() && ...); },
output_);
}

/**
* Get the underlying stream
*/
Expand All @@ -54,6 +63,17 @@ class multi_writer {
std::tuple<std::reference_wrapper<Writers>...> output_;
};

namespace internal {
template <typename T>
struct is_multi_writer : std::false_type {};

template <typename... Types>
struct is_multi_writer<multi_writer<Types...>> : std::true_type {};
}

template <typename T>
inline constexpr bool is_multi_writer_v = internal::is_multi_writer<std::decay_t<T>>::value;

} // namespace callbacks
} // namespace stan

Expand Down
3 changes: 3 additions & 0 deletions src/stan/callbacks/stream_writer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ class stream_writer : public writer {
output_ << comment_prefix_ << message << std::endl;
}

/**
* Checks if stream is valid.
*/
virtual bool is_nonnull() const noexcept { return output_.good(); }

private:
Expand Down
3 changes: 3 additions & 0 deletions src/stan/callbacks/tee_writer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ class tee_writer final : public writer {
writer2_(message);
}

/**
* Checks if both streams are valid.
*/
virtual bool is_nonnull() const noexcept {
return writer1_.is_nonnull() && writer2_.is_nonnull();
}
Expand Down
3 changes: 3 additions & 0 deletions src/stan/callbacks/unique_stream_writer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ class unique_stream_writer final : public writer {
*output_ << comment_prefix_ << message << std::endl;
}

/**
* Checks if stream is valid.
*/
bool is_nonnull() const noexcept { return output_ != nullptr; }

private:
Expand Down
3 changes: 3 additions & 0 deletions src/stan/callbacks/writer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ class writer {
*/
virtual void operator()(const std::string& message) {}

/**
* Checks if stream is valid.
*/
virtual bool is_nonnull() const noexcept { return false; }
/**
* Writes multiple rows and columns of values in csv format.
Expand Down
6 changes: 3 additions & 3 deletions src/stan/services/pathfinder/multi.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ inline int pathfinder_lbfgs_multi(
multi_writer_t multi_param_writer(
single_path_parameter_writer[iter], safe_write);
auto pathfinder_ret
= stan::services::pathfinder::pathfinder_lbfgs_single<true>(
= stan::services::pathfinder::pathfinder_lbfgs_single<false>(
model, *(init[iter]), random_seed, stride_id + iter,
init_radius, history_size, init_alpha, tol_obj,
tol_rel_obj, tol_grad, tol_rel_grad, tol_param,
Expand All @@ -181,12 +181,12 @@ inline int pathfinder_lbfgs_multi(
init_writers[iter], multi_param_writer,
single_path_diagnostic_writer[iter], calculate_lp,
psis_resample);
if (unlikely(std::get<0>(pathfinder_ret) != error_codes::OK)) {
if (pathfinder_ret.first != error_codes::OK) {
logger.error(std::string("Pathfinder iteration: ")
+ std::to_string(iter) + " failed.");
return;
}
lp_calls += std::get<2>(pathfinder_ret);
lp_calls += pathfinder_ret.second;
}
}
});
Expand Down
65 changes: 39 additions & 26 deletions src/stan/services/pathfinder/single.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -459,20 +459,19 @@ inline taylor_approx_t taylor_approximation(
* matrix of samples, and an unsigned integer for number of times the log prob
* functions was called
*/
template <bool ReturnLpSamples, typename EigVec,
std::enable_if_t<ReturnLpSamples>* = nullptr>
template <bool ReturnLpSamples, typename EigVec, typename ParamWriter>
inline auto ret_pathfinder(int return_code, EigVec&& elbo_est,
const std::atomic<size_t>& lp_calls) {
return std::make_tuple(return_code, std::forward<EigVec>(elbo_est),
lp_calls.load());
const std::atomic<size_t>& lp_calls, ParamWriter&& /* params */) noexcept {
if constexpr (ReturnLpSamples) {
return std::make_tuple(return_code, std::forward<EigVec>(elbo_est),
lp_calls.load());
} else if constexpr (stan::callbacks::is_multi_writer_v<ParamWriter>) {
return std::pair(return_code, lp_calls.load());
} else {
return return_code;
}
}

template <bool ReturnLpSamples, typename EigVec,
std::enable_if_t<!ReturnLpSamples>* = nullptr>
inline auto ret_pathfinder(int return_code, EigVec&& elbo_est,
const std::atomic<size_t>& lp_calls) noexcept {
return return_code;
}

/**
* Estimate the approximate draws given the taylor approximation.
Expand Down Expand Up @@ -618,7 +617,7 @@ inline auto pathfinder_lbfgs_single(
} catch (const std::exception& e) {
logger.error(path_num + e.what());
return internal::ret_pathfinder<ReturnLpSamples>(error_codes::SOFTWARE,
internal::elbo_est_t{}, 0);
internal::elbo_est_t{}, 0, parameter_writer);
}

const auto num_parameters = cont_vector.size();
Expand Down Expand Up @@ -820,7 +819,7 @@ inline auto pathfinder_lbfgs_single(
} else {
logger.error(e.what());
return internal::ret_pathfinder<ReturnLpSamples>(
error_codes::SOFTWARE, internal::elbo_est_t{}, 0);
error_codes::SOFTWARE, internal::elbo_est_t{}, 0, parameter_writer);
}
}
}
Expand All @@ -836,7 +835,7 @@ inline auto pathfinder_lbfgs_single(
+ " Optimization failed to start, pathfinder cannot be run.");
return internal::ret_pathfinder<ReturnLpSamples>(
error_codes::SOFTWARE, internal::elbo_est_t{},
std::atomic<size_t>{num_evals + lbfgs.grad_evals()});
std::atomic<size_t>{num_evals + lbfgs.grad_evals()}, parameter_writer);
} else {
logger.warn(prefix_err_msg +
" Stan will still attempt pathfinder but may fail or produce "
Expand All @@ -848,20 +847,20 @@ inline auto pathfinder_lbfgs_single(
"Failure: None of the LBFGS iterations completed "
"successfully");
return internal::ret_pathfinder<ReturnLpSamples>(
error_codes::SOFTWARE, internal::elbo_est_t{}, num_evals);
error_codes::SOFTWARE, internal::elbo_est_t{}, num_evals, parameter_writer);
} else {
if (refresh != 0) {
logger.info(path_num + "Best Iter: [" + std::to_string(best_iteration)
+ "] ELBO (" + std::to_string(elbo_best.elbo) + ")"
+ " evaluations: (" + std::to_string(num_evals) + ")");
}
}
if (ReturnLpSamples && psis_resample && calculate_lp) {
if constexpr (ReturnLpSamples) {
internal::elbo_est_t est_draws = internal::est_approx_draws<false>(
lp_fun, constrain_fun, rng, taylor_approx_best, num_draws,
taylor_approx_best.alpha, path_num, logger, calculate_lp);
return internal::ret_pathfinder<ReturnLpSamples>(
error_codes::OK, std::move(est_draws), num_evals + est_draws.fn_calls);
error_codes::OK, std::move(est_draws), num_evals + est_draws.fn_calls, parameter_writer);
} else {
std::vector<std::string> names;
names.push_back("lp_approx__");
Expand All @@ -871,7 +870,7 @@ inline auto pathfinder_lbfgs_single(
parameter_writer(names);
Eigen::Matrix<double, 1, Eigen::Dynamic> constrained_draws_vec(
names.size());
constrained_draws_vec(2) = stride_id - (ReturnLpSamples ? 1 : 0);
constrained_draws_vec(2) = stride_id - ((stride_id == 0) ? 0 : 1);
Eigen::Array<double, Eigen::Dynamic, 1> lp_ratio;
auto&& elbo_draws = elbo_best.repeat_draws;
auto&& elbo_lp_ratio = elbo_best.lp_ratio;
Expand Down Expand Up @@ -950,17 +949,31 @@ inline auto pathfinder_lbfgs_single(
}
lp_ratio = std::move(elbo_best.lp_ratio.head(num_draws));
}
parameter_writer();
const auto end_pathfinder_time = std::chrono::steady_clock::now();
const double pathfinder_delta_time = stan::services::util::duration_diff(
start_pathfinder_time, end_pathfinder_time);
std::string pathfinder_time_str = "Elapsed Time: ";
pathfinder_time_str += std::to_string(pathfinder_delta_time)
+ std::string(" seconds (Pathfinder)");
parameter_writer(pathfinder_time_str);
parameter_writer();
return internal::ret_pathfinder<ReturnLpSamples>(
error_codes::OK, internal::elbo_est_t{}, num_evals);
// For multi pathfinder, multi would write multiple end times
if constexpr (stan::callbacks::is_multi_writer_v<ParamWriter>) {
auto&& single_stream = std::get<0>(parameter_writer.get_stream());
single_stream();
std::string pathfinder_time_str = "Elapsed Time: ";
pathfinder_time_str += std::to_string(pathfinder_delta_time)
+ std::string(" seconds (Pathfinder)");
single_stream(pathfinder_time_str);
single_stream();
return internal::ret_pathfinder<ReturnLpSamples>(
error_codes::OK, internal::elbo_est_t{}, num_evals, parameter_writer);
} else {
parameter_writer();
std::string pathfinder_time_str = "Elapsed Time: ";
pathfinder_time_str += std::to_string(pathfinder_delta_time)
+ std::string(" seconds (Pathfinder)");
parameter_writer(pathfinder_time_str);
parameter_writer();
return internal::ret_pathfinder<ReturnLpSamples>(
error_codes::OK, internal::elbo_est_t{}, num_evals, parameter_writer);
}

}
}

Expand Down
2 changes: 1 addition & 1 deletion src/test/unit/services/pathfinder/eight_schools_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ TEST_F(ServicesPathfinderEightSchools, single) {

Eigen::MatrixXd param_vals = parameter.get_eigen_state_values();
for (auto&& x_i : param_vals.col(2)) {
EXPECT_EQ(x_i, stride_id);
EXPECT_EQ(x_i, stride_id - 1);
}
auto param_tmp = param_vals(Eigen::all, param_indices);
auto mean_sd_pair = stan::test::get_mean_sd(param_tmp);
Expand Down
4 changes: 2 additions & 2 deletions src/test/unit/services/pathfinder/normal_glm_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ TEST_F(ServicesPathfinderGLM, single) {
"", "");
Eigen::MatrixXd param_vals = parameter.get_eigen_state_values();
for (auto&& x_i : param_vals.col(2)) {
EXPECT_EQ(x_i, stride_id);
EXPECT_EQ(x_i, stride_id - 1);
}

auto param_tmp = param_vals(Eigen::all, param_indices);
Expand Down Expand Up @@ -164,7 +164,7 @@ TEST_F(ServicesPathfinderGLM, single_noreturnlp) {
EXPECT_EQ(11, param_vals.cols());
EXPECT_EQ(500, param_vals.rows());
for (auto&& x_i : param_vals.col(2)) {
EXPECT_EQ(x_i, stride_id);
EXPECT_EQ(x_i, stride_id - 1);
}
for (Eigen::Index i = 0; i < num_elbo_draws; ++i) {
EXPECT_FALSE(std::isnan(param_vals.coeff(num_draws + i, 1)))
Expand Down

0 comments on commit a2f023d

Please sign in to comment.