diff --git a/R/simulation.r b/R/simulation.r index 5e2addd..9329ed3 100644 --- a/R/simulation.r +++ b/R/simulation.r @@ -3,14 +3,16 @@ source("R/wrapper.R") # Run complete simulation run_complete_simulation <- function(output_dir="data/simulation_outputs", output_file = "output.csv", - plot_file = "SIR_plot.png", + sir_plot_file = "SIR_plot.png", + rt_plot_file = "Rt_plot.png", + si_plot_file = "SerialInterval_plot.png", use_toy_example = TRUE, simulation_duration = 60, initial_infected = 10) { # Initialize environment pe <- initialize_simulation_env() - # User defined variables - see README for instructions + # User-defined variables input_dir <- "" config_parameters <- "data/simple_parameters.json" seed <- 42 @@ -54,7 +56,8 @@ run_complete_simulation <- function(output_dir="data/simulation_outputs", infectiousness_output = TRUE, compress = FALSE, secondary_infections_output = TRUE, - generation_time_output = TRUE + generation_time_output = TRUE, + serial_interval_output = TRUE ) # Select population creation function @@ -64,14 +67,25 @@ run_complete_simulation <- function(output_dir="data/simulation_outputs", # Run simulation sim <- run_simulation(pe, sim_params, file_params, dem_file_params, population, inf_history_params, seed) - # Process data and create plot + # Process data df_long <- process_simulation_data(file.path(output_dir, output_file)) - plot <- create_sir_plot(df_long) - # Save plot - save_sir_plot(plot, file.path(output_dir, plot_file)) + print(colnames(df_long)) + print(df_long) - return(list(simulation = sim, data = df_long, plot = plot)) + # Generate SIR plot + sir_plot <- create_sir_plot(df_long, display = TRUE) + save_sir_plot(sir_plot, file.path(output_dir, sir_plot_file)) + + + plot_rt_curves("simulation_outputs/secondary_infections.csv") + + # Generate Serial Interval plot + df_si <- calculate_serial_interval(df_long) + si_plot <- create_serial_interval_plot(df_si, display = TRUE) + save_sir_plot(si_plot, file.path(output_dir, si_plot_file)) + + return(list(simulation = sim, data = df_long, sir_plot = sir_plot, rt_plot = "", si_plot = "")) } results <- run_complete_simulation() \ No newline at end of file diff --git a/R/wrapper.R b/R/wrapper.R index 40e4b0b..f536945 100644 --- a/R/wrapper.R +++ b/R/wrapper.R @@ -128,34 +128,56 @@ save_sir_plot <- function(plot, filename, width = 10, height = 6, dpi = 300) { ) } -# Calculate effective reproduction number (Rt) -calculate_rt <- function(df, window = 7) { - df_infected <- df[df$Status == "Infected", ] - df_infected$Rt <- c(NA, diff(df_infected$Count) / lag(df_infected$Count)) - df_infected$Rt <- zoo::rollmean(df_infected$Rt, window, fill = NA) - return(df_infected) +save_sir_plot <- function(plot, filename, width = 10, height = 6, dpi = 300) { + ggsave( + filename = here(filename), + plot = plot, + width = width, + height = height, + dpi = dpi + ) } -# Create Rt plot -create_rt_plot <- function(df_rt, title = "Effective Reproduction Number (Rt)", display = TRUE) { - p <- ggplot(df_rt, aes(x = time, y = Rt)) + - geom_line(color = "orange") + - theme_minimal() + +plot_rt_curves <- function(file_path) { + # Check if file exists + if (!file.exists(file_path)) { + stop("The file does not exist. Please provide a valid file path.") + } + + # Read the CSV file + data <- tryCatch({ + read.csv(file_path) + }, error = function(e) { + stop("Error reading the file. Ensure it's a valid CSV.") + }) + + # Filter for "time" and "R_t" columns + if (!all(c("time", "R_t") %in% colnames(data))) { + stop("The CSV file must contain 'time' and 'R_t' columns.") + } + data <- data[, c("time", "R_t")] + + # Remove rows with NaN values + data <- na.omit(data) + + # Create the ggplot + gg <- ggplot(data, aes(x = time, y = R_t)) + + geom_line(color = "blue", size = 1) + labs( - title = title, + title = "Reproduction Number (R_t) Over Time", x = "Time", - y = "Rt" + y = "R_t" ) + - theme( - plot.title = element_text(hjust = 0.5), - panel.grid.minor = element_blank() - ) + theme_minimal() - if (display) { - print(p) - } + # Print the plot + print(gg) + print("R_t plot generated successfully.") - return(p) + # Save the plot + save_sir_plot(gg, "simulation_outputs/rt_curve.png") + + return(gg) } # Calculate serial intervals diff --git a/data/simulation_outputs/SIR_plot.png b/data/simulation_outputs/SIR_plot.png index 92acab1..5713871 100644 Binary files a/data/simulation_outputs/SIR_plot.png and b/data/simulation_outputs/SIR_plot.png differ