Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use rounding for treatments #203

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 32 additions & 13 deletions inst/include/host_pool.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,7 @@ class HostPool : public HostPoolInterface<RasterIndex>
* @brief Completely remove any hosts
*
* Removes hosts completely (as opposed to moving them to another pool).
* If mortality is not active, the *mortality* parameter is ignored.
*
* @param row Row index of the cell
* @param col Column index of the cell
Expand All @@ -523,9 +524,6 @@ class HostPool : public HostPoolInterface<RasterIndex>
* @param infected Number of infected hosts to remove.
* @param mortality Number of infected hosts in each mortality cohort.
*
* @note Counts are doubles, so that handling of floating point values is managed
* here in the same way as in the original treatment code.
*
* @note This does not remove resistant just like the original implementation in
* treatments.
*/
Expand Down Expand Up @@ -556,6 +554,11 @@ class HostPool : public HostPoolInterface<RasterIndex>
// Possibly reuse in the I->S removal.
if (infected <= 0)
return;
if (!mortality_tracker_vector_.size()) {
infected_(row, col) -= infected;
reset_total_host(row, col);
return;
}
if (mortality_tracker_vector_.size() != mortality.size()) {
throw std::invalid_argument(
"mortality is not the same size as the internal mortality tracker ("
Expand All @@ -564,8 +567,8 @@ class HostPool : public HostPoolInterface<RasterIndex>
+ std::to_string(row) + ", " + std::to_string(col) + ")");
}

double mortality_total = 0;
for (size_t i = 0; i < mortality.size(); ++i) {
int mortality_total = 0;
for (size_t i = 0; i < mortality_tracker_vector_.size(); ++i) {
if (mortality_tracker_vector_[i](row, col) < mortality[i]) {
throw std::invalid_argument(
"Mortality value [" + std::to_string(i) + "] is too high ("
Expand All @@ -582,20 +585,20 @@ class HostPool : public HostPoolInterface<RasterIndex>
// and once we don't need to keep the exact same double to int results for
// tests. First condition always fails the tests. The second one may potentially
// fail.
if (false && infected != mortality_total) {
if (infected != mortality_total) {
throw std::invalid_argument(
"Total of removed mortality values differs from removed infected "
"count ("
+ std::to_string(mortality_total) + " != " + std::to_string(infected)
+ " for cell (" + std::to_string(row) + ", " + std::to_string(col)
+ ") for cell (" + std::to_string(row) + ", " + std::to_string(col)
+ ")");
}
if (false && infected_(row, col) < mortality_total) {
if (infected_(row, col) < mortality_total) {
throw std::invalid_argument(
"Total of removed mortality values is higher than current number "
"of infected hosts for cell ("
+ std::to_string(row) + ", " + std::to_string(col) + ") is too high ("
"of infected hosts ("
+ std::to_string(mortality_total) + " > " + std::to_string(infected)
+ ") for cell (" + std::to_string(row) + ", " + std::to_string(col)
+ ")");
}
infected_(row, col) -= infected;
Expand Down Expand Up @@ -700,6 +703,8 @@ class HostPool : public HostPoolInterface<RasterIndex>
/**
* @brief Make hosts resistant in a given cell
*
* If mortality is not active, the *mortality* parameter is ignored.
*
* @param row Row index of the cell
* @param col Column index of the cell
* @param susceptible Number of susceptible hosts to make resistant
Expand Down Expand Up @@ -746,6 +751,12 @@ class HostPool : public HostPoolInterface<RasterIndex>
total_resistant += exposed[i];
}
infected_(row, col) -= infected;
total_resistant += infected;
resistant_(row, col) += total_resistant;
if (!mortality_tracker_vector_.size()) {
reset_total_host(row, col);
return;
}
if (mortality_tracker_vector_.size() != mortality.size()) {
throw std::invalid_argument(
"mortality is not the same size as the internal mortality tracker ("
Expand All @@ -755,7 +766,7 @@ class HostPool : public HostPoolInterface<RasterIndex>
}
int mortality_total = 0;
// no simple zip in C++, falling back to indices
for (size_t i = 0; i < mortality.size(); ++i) {
for (size_t i = 0; i < mortality_tracker_vector_.size(); ++i) {
mortality_tracker_vector_[i](row, col) -= mortality[i];
mortality_total += mortality[i];
}
Expand All @@ -771,8 +782,7 @@ class HostPool : public HostPoolInterface<RasterIndex>
+ " for cell (" + std::to_string(row) + ", " + std::to_string(col)
+ "))");
}
total_resistant += infected;
resistant_(row, col) += total_resistant;
reset_total_host(row, col);
}

/**
Expand Down Expand Up @@ -975,6 +985,9 @@ class HostPool : public HostPoolInterface<RasterIndex>
/**
* @brief Get infected hosts in each mortality cohort at a given cell
*
* If mortality is not active, it returns number of all infected individuals
* in the first and only item of the vector.
*
* @param row Row index of the cell
* @param col Column index of the cell
*
Expand All @@ -983,6 +996,12 @@ class HostPool : public HostPoolInterface<RasterIndex>
std::vector<int> mortality_by_group_at(RasterIndex row, RasterIndex col) const
{
std::vector<int> all;

if (!mortality_tracker_vector_.size()) {
all.push_back(infected_at(row, col));
return all;
}

all.reserve(mortality_tracker_vector_.size());
for (const auto& raster : mortality_tracker_vector_)
all.push_back(raster(row, col));
Expand Down
44 changes: 19 additions & 25 deletions inst/include/treatments.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,15 +125,15 @@ class BaseTreatment : public AbstractTreatment<HostPool, FloatRaster>
}

// returning double allows identical results with the previous version
double get_treated(int i, int j, int count)
int get_treated(int i, int j, int count)
{
return get_treated(i, j, count, this->application_);
}

double get_treated(int i, int j, int count, TreatmentApplication application)
int get_treated(int i, int j, int count, TreatmentApplication application)
{
if (application == TreatmentApplication::Ratio) {
return count * this->map_(i, j);
return std::lround(count * this->map_(i, j));
}
else if (application == TreatmentApplication::AllInfectedInCell) {
return static_cast<bool>(this->map_(i, j)) ? count : 0;
Expand Down Expand Up @@ -173,20 +173,21 @@ class SimpleTreatment : public BaseTreatment<HostPool, FloatRaster>
for (auto indices : host_pool.suitable_cells()) {
int i = indices[0];
int j = indices[1];
int remove_susceptible = static_cast<int>(std::ceil(this->get_treated(
i, j, host_pool.susceptible_at(i, j), TreatmentApplication::Ratio)));
int remove_infected = static_cast<int>(
std::ceil(this->get_treated(i, j, host_pool.infected_at(i, j))));
int remove_susceptible = this->get_treated(
i, j, host_pool.susceptible_at(i, j), TreatmentApplication::Ratio);
// Treated infected are computed as a sum of treated in mortality groups.
int remove_infected = 0;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't remove infected individuals.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fails tests and leaves all infected in the cells. The susceptible call above works as it should though.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is more complex than I thought yesterday, i.e., it is not a simple oversight. This is an initial value of a sum. The infected ones are computed as a sum of removed individuals from mortality groups (line 184), and then they are passed to HostPool::completely_remove_hosts_at just like before (lines 191-197). The infected is passed as a 5th argument and simply subtracted in that function from the infected raster/matrix.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ncsu-landscape-dynamics/pops-core#225 now has the update which allows for treatments with no mortality. I added a low level test for that to pops-core.

std::vector<int> remove_mortality;
for (int count : host_pool.mortality_by_group_at(i, j)) {
remove_mortality.push_back(
static_cast<int>(std::ceil(this->get_treated(i, j, count))));
int remove = this->get_treated(i, j, count);
remove_mortality.push_back(remove);
remove_infected += remove;
}
// Will need to use infected directly if not mortality.

std::vector<int> remove_exposed;
for (int count : host_pool.exposed_by_group_at(i, j)) {
remove_exposed.push_back(
static_cast<int>(std::ceil(this->get_treated(i, j, count))));
remove_exposed.push_back(this->get_treated(i, j, count));
}
host_pool.completely_remove_hosts_at(
i,
Expand Down Expand Up @@ -240,26 +241,19 @@ class PesticideTreatment : public BaseTreatment<HostPool, FloatRaster>
for (auto indices : host_pool.suitable_cells()) {
int i = indices[0];
int j = indices[1];
// Given how the original code was written (everything was first converted
// to ints and subtractions happened only afterwards), this needs ints,
// not doubles to pass the r.pops.spread test (unlike the other code which
// did substractions before converting to ints), so the conversion to ints
// happened only later. Now get_treated returns double and floor or ceil is
// applied to the result to get the same results as before.
int susceptible_resistant = static_cast<int>(std::floor(this->get_treated(
i, j, host_pool.susceptible_at(i, j), TreatmentApplication::Ratio)));
int susceptible_resistant = this->get_treated(
i, j, host_pool.susceptible_at(i, j), TreatmentApplication::Ratio);
std::vector<int> resistant_exposed_list;
for (const auto& number : host_pool.exposed_by_group_at(i, j)) {
resistant_exposed_list.push_back(
static_cast<int>(std::floor(this->get_treated(i, j, number))));
resistant_exposed_list.push_back(this->get_treated(i, j, number));
}
int infected = 0;
std::vector<int> resistant_mortality_list;
for (const auto& number : host_pool.mortality_by_group_at(i, j)) {
resistant_mortality_list.push_back(
static_cast<int>(std::floor(this->get_treated(i, j, number))));
int remove = this->get_treated(i, j, number);
resistant_mortality_list.push_back(remove);
infected += remove;
}
int infected = static_cast<int>(
std::floor(this->get_treated(i, j, host_pool.infected_at(i, j))));
host_pool.make_resistant_at(
i,
j,
Expand Down
Loading