From c8bf83afd8f6ebfa0e6e375b232e58178a4a191c Mon Sep 17 00:00:00 2001 From: Yunli Qi Date: Mon, 26 Feb 2024 14:58:15 +0000 Subject: [PATCH 1/5] pre_process set --- epios/post_process.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/epios/post_process.py b/epios/post_process.py index 9127fb1..65a9dd4 100644 --- a/epios/post_process.py +++ b/epios/post_process.py @@ -488,10 +488,11 @@ def _wrapper_Region_AgeRegion(self, sampling_method, sample_size, time_sample, n else: if sampling_method == 'AgeRegion': sampler_class = SamplerAgeRegion(data=self.demo_data, data_store_path=data_store_path, - num_age_group=num_age_group, + num_age_group=num_age_group, pre_process=False, age_group_width=age_group_width) else: - sampler_class = SamplerRegion(data=self.demo_data, data_store_path=data_store_path) + sampler_class = SamplerRegion(data=self.demo_data, data_store_path=data_store_path, + pre_process=False) try: people = sampler_class.sample(sample_size=sample_size, additional_sample=additional_sample) except NameError: @@ -726,9 +727,11 @@ def _wrapper_Age_Base(self, sampling_method, sample_size, time_sample, else: # After the data process, we can directly read files processed at the first time if sampling_method == 'Age': sampler_class = SamplerAge(data=self.demo_data, data_store_path=data_store_path, - num_age_group=num_age_group, age_group_width=age_group_width) + num_age_group=num_age_group, age_group_width=age_group_width, + pre_process=False) else: - sampler_class = Sampler(data=self.demo_data, data_store_path=data_store_path) + sampler_class = Sampler(data=self.demo_data, data_store_path=data_store_path, + pre_process=False) people = sampler_class.sample(sample_size=sample_size) # Get the results of each people sampled From c1fd21803ec656bea041b9acfca9d7577ea805a4 Mon Sep 17 00:00:00 2001 From: Yunli Qi Date: Mon, 26 Feb 2024 15:22:25 +0000 Subject: [PATCH 2/5] Region Sampler bug fixed --- epios/sampler_region.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/epios/sampler_region.py b/epios/sampler_region.py index 7c49ca1..bd12590 100644 --- a/epios/sampler_region.py +++ b/epios/sampler_region.py @@ -99,7 +99,7 @@ def multinomial_draw(self, n: int, prob: list): cap_region = [] record_cap_region = [] for i in range(len(prob)): - cap_region.append(min(max(n * prob[i] + 0.005 * n, 1), + cap_region.append(min(n * prob[i] + 0.005 * n + 1, self.geoinfo[self.geoinfo['cell'] == i]['Susceptible'].sum())) record_cap_region.append(self.geoinfo[self.geoinfo['cell'] == i]['Susceptible'].sum()) cap_region = [cap_region, list(np.arange(len(cap_region)))] From 11026fe9c6f748850858f432b0fdff7fea1082f9 Mon Sep 17 00:00:00 2001 From: Yunli Qi Date: Mon, 26 Feb 2024 21:18:54 +0000 Subject: [PATCH 3/5] Plot Error Fixed --- epios/__init__.py | 1 + epios/post_process.py | 6 ++++-- epios/tests/test_re_scaler.py | 2 +- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/epios/__init__.py b/epios/__init__.py index 4d8b0b0..212818b 100644 --- a/epios/__init__.py +++ b/epios/__init__.py @@ -5,3 +5,4 @@ from .sampler_region import SamplerRegion # noqa from .sampler_age_region import SamplerAgeRegion # noqa from .post_process import PostProcess # noqa +from .re_scaler import ReScaler diff --git a/epios/post_process.py b/epios/post_process.py index 65a9dd4..9c3765b 100644 --- a/epios/post_process.py +++ b/epios/post_process.py @@ -659,7 +659,8 @@ def _wrapper_Region_AgeRegion(self, sampling_method, sample_size, time_sample, n # Plot the figure if gen_plot: - plt.plot(time_sample, infected_rate) + infected_population = np.array(infected_rate) * len(self.demo_data) + plt.plot(time_sample, infected_population) plt.xlabel('Time') plt.ylabel('Population') plt.xlim(0, max(time_sample)) @@ -744,7 +745,8 @@ def _wrapper_Age_Base(self, sampling_method, sample_size, time_sample, # Plot the figure if gen_plot: - plt.plot(time_sample, infected_rate) + infected_population = np.array(infected_rate) * len(self.demo_data) + plt.plot(time_sample, infected_population) plt.xlabel('Time') plt.ylabel('Population') plt.xlim(0, max(time_sample)) diff --git a/epios/tests/test_re_scaler.py b/epios/tests/test_re_scaler.py index fa6d8d6..6038f5e 100644 --- a/epios/tests/test_re_scaler.py +++ b/epios/tests/test_re_scaler.py @@ -1,7 +1,7 @@ from unittest import TestCase from numpy.random import rand from numpy import array -from epios.re_scaler import ReScaler +from re_scaler import ReScaler class TestRS(TestCase): From 3d4712675249fd0ce8a86da9cd9215fdc932a622 Mon Sep 17 00:00:00 2001 From: Yunli Qi Date: Mon, 26 Feb 2024 21:30:20 +0000 Subject: [PATCH 4/5] Flake 8 Fix --- epios/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/epios/__init__.py b/epios/__init__.py index 212818b..503c888 100644 --- a/epios/__init__.py +++ b/epios/__init__.py @@ -5,4 +5,4 @@ from .sampler_region import SamplerRegion # noqa from .sampler_age_region import SamplerAgeRegion # noqa from .post_process import PostProcess # noqa -from .re_scaler import ReScaler +from .re_scaler import ReScaler # noqa From 3c860a01ac6b2c75e3eb5ec3b7b278bbb9a3d069 Mon Sep 17 00:00:00 2001 From: Yunli Qi Date: Tue, 27 Feb 2024 11:40:20 +0000 Subject: [PATCH 5/5] Overlapping figures fixed --- epios/post_process.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/epios/post_process.py b/epios/post_process.py index 9c3765b..b256973 100644 --- a/epios/post_process.py +++ b/epios/post_process.py @@ -444,6 +444,7 @@ def _compare(self, time_sample, gen_plot=False, scale_method: str = 'proportiona # Find the difference between estimated infection level and the real one diff = np.array(true_result) - result_scaled if gen_plot: + plt.figure() plt.plot(time_sample, result_scaled, label='Predicted result', linestyle='--') plt.plot(time_sample, true_result, label='True result') plt.plot(time_sample, np.abs(diff), label='Absolute difference') @@ -659,6 +660,7 @@ def _wrapper_Region_AgeRegion(self, sampling_method, sample_size, time_sample, n # Plot the figure if gen_plot: + plt.figure() infected_population = np.array(infected_rate) * len(self.demo_data) plt.plot(time_sample, infected_population) plt.xlabel('Time') @@ -745,6 +747,7 @@ def _wrapper_Age_Base(self, sampling_method, sample_size, time_sample, # Plot the figure if gen_plot: + plt.figure() infected_population = np.array(infected_rate) * len(self.demo_data) plt.plot(time_sample, infected_population) plt.xlabel('Time')