Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add discrete wrapper #11

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/autoppl/algorithm/mh.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ inline void mh_posterior__(ModelType& model,
std::discrete_distribution disc_sampler({alpha, 1-2*alpha, alpha});
auto cand = disc_sampler(gen) - 1 + curr; // new candidate in curr + [-1, 0, 1]
// TODO: refactor common logic
if (dist.min() <= cand && cand <= dist.max()) { // if within dist bound
if (static_cast<int>(dist.min()) <= cand && cand <= static_cast<int>(dist.max())) { // if within dist bound
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wait this is not the correct casting - this will cause round issues when dist is continuous distribution. What you could do is something like static_cast<value_t>(...) <= static_cast<value_t>(cand) and you can get value_t from util::dist_expr_traits<whatever the distribution type is>::value_t.

var.set_value(cand);
++n_swaps;
}
Expand Down
23 changes: 23 additions & 0 deletions include/autoppl/expr_builder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <autoppl/expression/distribution/uniform.hpp>
#include <autoppl/expression/distribution/normal.hpp>
#include <autoppl/expression/distribution/bernoulli.hpp>
#include <autoppl/expression/distribution/discrete.hpp>

namespace ppl {

Expand Down Expand Up @@ -156,6 +157,28 @@ inline constexpr auto bernoulli(ProbType&& p_expr)
return expr::Bernoulli(wrap_p_expr);
}

/*
* Builds a Discrete expression only when the parameter
* is a valid discrete distribution parameter type.
* See var_expr.hpp for more information.
*/
// template <template<class> class Container, class WeightType, class = std::enable_if_t<
// details::is_valid_dist_param_v<WeightType>
// > >
// inline constexpr auto discrete(Container<WeightType> && w_expr)
// {
// using weight_t = details::convert_to_param_t<WeightType>;
// std::vector<weight_t> wrap_w_expr(w_expr.begin(), w_expr.end());
// return expr::Discrete(wrap_w_expr);
// }

template <class WeightType, class = std::enable_if_t<details::is_valid_dist_param_v<WeightType>>>
inline constexpr auto discrete(std::initializer_list<WeightType>&& w_expr) {
using weight_t = details::convert_to_param_t<WeightType>;
std::vector<weight_t> wrap_w_expr(w_expr.begin(), w_expr.end());
return expr::Discrete(wrap_w_expr);
}

////////////////////////////////////////////////////////
// Model Expression Builder
////////////////////////////////////////////////////////
Expand Down
40 changes: 27 additions & 13 deletions include/autoppl/expression/distribution/discrete.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,31 @@
#include <cmath>
#include <numeric>
#include <algorithm>
#include <autoppl/util/var_expr_traits.hpp>
#include <autoppl/util/dist_expr_traits.hpp>

namespace ppl {
namespace expr {

// TODO: will most likely not be used as an expression
template <typename weight_type>
struct Discrete

template <class weight_type>
struct Discrete : util::DistExpr<Discrete<weight_type>>
{
using value_t = uint64_t;
static_assert(util::assert_is_var_expr_v<weight_type>);

using value_t = util::disc_param_t;
using weight_value_t = typename util::var_expr_traits<weight_type>::value_t;
using dist_value_t = double;

Discrete() { weights_ = {1}; }

Discrete(std::initializer_list<weight_type> weights)
: weights_{ weights }
Discrete(std::vector<weight_type> weights)
: weights_{ weights }
{ normalize_weights(weights_.begin(), weights_.end()); }

template <class Iter>
Discrete(Iter begin, Iter end)
: weights_{ begin, end }
{ normalize_weights(begin, end); }
{ normalize_weights(weights_.begin(), weights_.end()); }

template <class GeneratorType>
value_t sample(GeneratorType& gen) const
Expand All @@ -46,17 +50,27 @@ struct Discrete
return std::log(weights(i));
}

inline dist_value_t weights(value_t i) const { return static_cast<dist_value_t>(weights_[i]); }
inline weight_value_t weights(value_t i) const { return weights_[i].get_value(); }
weight_value_t min() const {
return 0;
}
weight_value_t max() const {
return weights_.size() - 1;
}


private:
std::vector<weight_type> weights_;
template <class Iter>
void normalize_weights(Iter begin, Iter end){
// check that weights are positive, not empty, and normalize the weights
// check that weights are positive, not empty, sort and normalize the weights
assert(std::distance(begin, end) > 0);
assert(std::all_of(begin, end, [](const weight_type& n){ return n > 0; }));
dist_value_t total = std::accumulate(begin, end, 0.0);
std::for_each(begin, end, [total](weight_type& n){n /= total; });
assert(std::all_of(begin, end, [](const weight_type& w_var){ return w_var.get_value() > 0; }));
weight_value_t total = std::accumulate(begin, end, 0.0,
[] (int tmp_total, weight_type weight) { return tmp_total + weight.get_value(); });
if (total != 1) {
std::for_each(begin, end, [total](weight_type& w_var){w_var.set_value(w_var.get_value() / total); });
}
}
};

Expand Down
3 changes: 2 additions & 1 deletion include/autoppl/expression/variable/constant.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@ struct Constant : util::VarExpr<Constant<ValueType>>
{}
explicit operator value_t() const { return get_value(); }
value_t get_value() const { return c_; }
void set_value(value_t value) { c_ = value; }

private:
private:
value_t c_;
};

Expand Down
2 changes: 1 addition & 1 deletion include/autoppl/util/dist_expr_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ using cont_param_t = double;
/*
* Discrete distribution expressions can be constructed with this type.
*/
using disc_param_t = int64_t;
using disc_param_t = uint64_t;

/*
* Traits for Distribution Expression classes.
Expand Down
69 changes: 56 additions & 13 deletions test/expression/distribution/discrete_unittest.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include <autoppl/expression/distribution/discrete.hpp>

#include <autoppl/expr_builder.hpp>
#include <testutil/mock_types.hpp>
#include <testutil/sample_tools.hpp>
#include <cmath>
#include <array>

Expand All @@ -10,22 +12,43 @@ namespace expr {

struct discrete_dist_fixture : ::testing::Test {
protected:
std::vector<double> weights {0.1, 0.2, 0.3, 0.4};
Discrete<double> dist1 = {1.0, 2.0, 3.0, 4.0};
Discrete<double> dist2;
Discrete<double> dist3 = Discrete(weights.begin(), weights.end());
using value_t = typename MockVarExpr::value_t;

std::vector<MockVarExpr> weights_norm {MockVarExpr{0.1}, MockVarExpr{0.2}, MockVarExpr{0.3}, MockVarExpr{0.4}};
std::vector<MockVarExpr> weights {MockVarExpr{1}, MockVarExpr{2}, MockVarExpr{3}, MockVarExpr{4}};

Discrete<MockVarExpr> dist1 = Discrete(weights);
Discrete<MockVarExpr> dist2;
Discrete<MockVarExpr> dist3 = Discrete(weights.begin(), weights.end());
Discrete<MockVarExpr> dist4 = ppl::discrete({MockVarExpr{1}, MockVarExpr{2}, MockVarExpr{3}, MockVarExpr{4}});
};

TEST_F(discrete_dist_fixture, ctor){
static_assert(util::assert_is_dist_expr_v<Discrete<MockVarExpr>>);
}

TEST_F(discrete_dist_fixture, discrete_check_params) {
EXPECT_DOUBLE_EQ(dist1.weights(0), static_cast<value_t>(weights_norm[0]));
EXPECT_DOUBLE_EQ(dist1.weights(1), static_cast<value_t>(weights_norm[1]));
EXPECT_DOUBLE_EQ(dist1.weights(2), static_cast<value_t>(weights_norm[2]));
EXPECT_DOUBLE_EQ(dist1.weights(3), static_cast<value_t>(weights_norm[3]));

EXPECT_DOUBLE_EQ(dist3.weights(0), static_cast<value_t>(weights_norm[0]));
EXPECT_DOUBLE_EQ(dist3.weights(1), static_cast<value_t>(weights_norm[1]));
EXPECT_DOUBLE_EQ(dist3.weights(2), static_cast<value_t>(weights_norm[2]));
EXPECT_DOUBLE_EQ(dist3.weights(3), static_cast<value_t>(weights_norm[3]));
}

TEST_F(discrete_dist_fixture, default_cstor_test){
EXPECT_DOUBLE_EQ(dist2.weights(0), 1.0);
EXPECT_DOUBLE_EQ(dist2.pdf(0), 1.0);
}

TEST_F(discrete_dist_fixture, sanity_Discrete_test) {
EXPECT_DOUBLE_EQ(dist1.weights(0), weights[0]);
EXPECT_DOUBLE_EQ(dist1.weights(1), weights[1]);
EXPECT_DOUBLE_EQ(dist1.weights(2), weights[2]);
EXPECT_DOUBLE_EQ(dist1.weights(3), weights[3]);
EXPECT_DOUBLE_EQ(dist1.weights(0), weights_norm[0].get_value());
EXPECT_DOUBLE_EQ(dist1.weights(1), weights_norm[1].get_value());
EXPECT_DOUBLE_EQ(dist1.weights(2), weights_norm[2].get_value());
EXPECT_DOUBLE_EQ(dist1.weights(3), weights_norm[3].get_value());
}

TEST_F(discrete_dist_fixture, simple_Discrete) {
Expand All @@ -47,10 +70,10 @@ TEST_F(discrete_dist_fixture, Discrete_sampling) {
}

TEST_F(discrete_dist_fixture, sanity_Discrete_iter_test) {
EXPECT_DOUBLE_EQ(dist3.weights(0), weights[0]);
EXPECT_DOUBLE_EQ(dist3.weights(1), weights[1]);
EXPECT_DOUBLE_EQ(dist3.weights(2), weights[2]);
EXPECT_DOUBLE_EQ(dist3.weights(3), weights[3]);
EXPECT_DOUBLE_EQ(dist3.weights(0), weights_norm[0].get_value());
EXPECT_DOUBLE_EQ(dist3.weights(1), weights_norm[1].get_value());
EXPECT_DOUBLE_EQ(dist3.weights(2), weights_norm[2].get_value());
EXPECT_DOUBLE_EQ(dist3.weights(3), weights_norm[3].get_value());
}

TEST_F(discrete_dist_fixture, simple_Discrete_iter) {
Expand All @@ -70,5 +93,25 @@ TEST_F(discrete_dist_fixture, Discrete_sampling_iter) {
EXPECT_TRUE(sample == 0 || sample == 1 || sample == 2 || sample == 3);
}
}

TEST_F(discrete_dist_fixture, Discrete_max_min) {
EXPECT_EQ(dist1.min(), 0);
EXPECT_EQ(dist1.max(), weights_norm.size() - 1);
}

TEST_F(discrete_dist_fixture, discrete_wrapper) {
EXPECT_DOUBLE_EQ(dist4.weights(0), weights_norm[0].get_value());
EXPECT_DOUBLE_EQ(dist4.weights(1), weights_norm[1].get_value());
EXPECT_DOUBLE_EQ(dist4.weights(2), weights_norm[2].get_value());
EXPECT_DOUBLE_EQ(dist4.weights(3), weights_norm[3].get_value());
}

int main() {
std::vector<MockVarExpr> weights {MockVarExpr{1}, MockVarExpr{2}, MockVarExpr{3}, MockVarExpr{4}};

auto dist = Discrete(weights);
std::cout << dist.pdf(0) << '\n';
return 0;
}
} // namespace expr
} // namespace ppl
18 changes: 18 additions & 0 deletions test/test_disc.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#include "gtest/gtest.h"
#include <autoppl/expression/distribution/discrete.hpp>
#include <autoppl/expr_builder.hpp>
#include <testutil/mock_types.hpp>


namespace ppl {


int main() {
std::vector<MockVarExpr> weights {MockVarExpr{1}, MockVarExpr{2}, MockVarExpr{3}, MockVarExpr{4}};

auto dist1 = Discrete(weights);
std::cout << dist.pdf(0) << std::end;
}


}