Skip to content

Commit

Permalink
Merge branch 'dev' into infe_prof_and_same_diff
Browse files Browse the repository at this point in the history
  • Loading branch information
AntonioVitoMastromarino authored Mar 4, 2024
2 parents db2a1a3 + 8a9434a commit da50cc1
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 8 deletions.
20 changes: 14 additions & 6 deletions epios/post_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,7 @@ def _compare(self, time_sample, gen_plot=False, scale_method: str = 'proportiona
# Find the difference between estimated infection level and the real one
diff = np.array(true_result) - result_scaled
if gen_plot:
plt.figure()
plt.plot(time_sample, result_scaled, label='Predicted result', linestyle='--')
plt.plot(time_sample, true_result, label='True result')
plt.plot(time_sample, np.abs(diff), label='Absolute difference')
Expand Down Expand Up @@ -488,10 +489,11 @@ def _wrapper_Region_AgeRegion(self, sampling_method, sample_size, time_sample, n
else:
if sampling_method == 'AgeRegion':
sampler_class = SamplerAgeRegion(data=self.demo_data, data_store_path=data_store_path,
num_age_group=num_age_group,
num_age_group=num_age_group, pre_process=False,
age_group_width=age_group_width)
else:
sampler_class = SamplerRegion(data=self.demo_data, data_store_path=data_store_path)
sampler_class = SamplerRegion(data=self.demo_data, data_store_path=data_store_path,
pre_process=False)
try:
people = sampler_class.sample(sample_size=sample_size, additional_sample=additional_sample)
except NameError:
Expand Down Expand Up @@ -658,7 +660,9 @@ def _wrapper_Region_AgeRegion(self, sampling_method, sample_size, time_sample, n

# Plot the figure
if gen_plot:
plt.plot(time_sample, infected_rate)
plt.figure()
infected_population = np.array(infected_rate) * len(self.demo_data)
plt.plot(time_sample, infected_population)
plt.xlabel('Time')
plt.ylabel('Population')
plt.xlim(0, max(time_sample))
Expand Down Expand Up @@ -728,9 +732,11 @@ def _wrapper_Age_Base(self, sampling_method, sample_size, time_sample,
else: # After the data process, we can directly read files processed at the first time
if sampling_method == 'Age':
sampler_class = SamplerAge(data=self.demo_data, data_store_path=data_store_path,
num_age_group=num_age_group, age_group_width=age_group_width)
num_age_group=num_age_group, age_group_width=age_group_width,
pre_process=False)
else:
sampler_class = Sampler(data=self.demo_data, data_store_path=data_store_path)
sampler_class = Sampler(data=self.demo_data, data_store_path=data_store_path,
pre_process=False)
people = sampler_class.sample(sample_size=sample_size)

X = SamplingMaker(non_resp_rate=0, data=self.time_data,
Expand All @@ -742,7 +748,9 @@ def _wrapper_Age_Base(self, sampling_method, sample_size, time_sample,

# Plot the figure
if gen_plot:
plt.plot(time_sample, infected_rate)
plt.figure()
infected_population = np.array(infected_rate) * len(self.demo_data)
plt.plot(time_sample, infected_population)
plt.xlabel('Time')
plt.ylabel('Population')
plt.xlim(0, max(time_sample))
Expand Down
2 changes: 1 addition & 1 deletion epios/sampler_region.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def multinomial_draw(self, n: int, prob: list):
cap_region = []
record_cap_region = []
for i in range(len(prob)):
cap_region.append(min(max(n * prob[i] + 0.005 * n, 1),
cap_region.append(min(n * prob[i] + 0.005 * n + 1,
self.geoinfo[self.geoinfo['cell'] == i]['Susceptible'].sum()))
record_cap_region.append(self.geoinfo[self.geoinfo['cell'] == i]['Susceptible'].sum())
cap_region = [cap_region, list(np.arange(len(cap_region)))]
Expand Down
2 changes: 1 addition & 1 deletion epios/tests/test_re_scaler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from unittest import TestCase
from numpy.random import rand
from numpy import array
from epios.re_scaler import ReScaler
from re_scaler import ReScaler


class TestRS(TestCase):
Expand Down

0 comments on commit da50cc1

Please sign in to comment.