Skip to content

Commit

Permalink
Completed RT curves
Browse files Browse the repository at this point in the history
  • Loading branch information
Kemuk committed Jan 27, 2025
1 parent 7722811 commit c587d50
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 29 deletions.
30 changes: 22 additions & 8 deletions R/simulation.r
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@ source("R/wrapper.R")
# Run complete simulation
run_complete_simulation <- function(output_dir="data/simulation_outputs",
output_file = "output.csv",
plot_file = "SIR_plot.png",
sir_plot_file = "SIR_plot.png",
rt_plot_file = "Rt_plot.png",
si_plot_file = "SerialInterval_plot.png",
use_toy_example = TRUE,
simulation_duration = 60,
initial_infected = 10) {
# Initialize environment
pe <- initialize_simulation_env()

# User defined variables - see README for instructions
# User-defined variables
input_dir <- ""
config_parameters <- "data/simple_parameters.json"
seed <- 42
Expand Down Expand Up @@ -54,7 +56,8 @@ run_complete_simulation <- function(output_dir="data/simulation_outputs",
infectiousness_output = TRUE,
compress = FALSE,
secondary_infections_output = TRUE,
generation_time_output = TRUE
generation_time_output = TRUE,
serial_interval_output = TRUE
)

# Select population creation function
Expand All @@ -64,14 +67,25 @@ run_complete_simulation <- function(output_dir="data/simulation_outputs",
# Run simulation
sim <- run_simulation(pe, sim_params, file_params, dem_file_params, population, inf_history_params, seed)

# Process data and create plot
# Process data
df_long <- process_simulation_data(file.path(output_dir, output_file))
plot <- create_sir_plot(df_long)

# Save plot
save_sir_plot(plot, file.path(output_dir, plot_file))
print(colnames(df_long))
print(df_long)

return(list(simulation = sim, data = df_long, plot = plot))
# Generate SIR plot
sir_plot <- create_sir_plot(df_long, display = TRUE)
save_sir_plot(sir_plot, file.path(output_dir, sir_plot_file))


plot_rt_curves("simulation_outputs/secondary_infections.csv")

# Generate Serial Interval plot
df_si <- calculate_serial_interval(df_long)
si_plot <- create_serial_interval_plot(df_si, display = TRUE)
save_sir_plot(si_plot, file.path(output_dir, si_plot_file))

return(list(simulation = sim, data = df_long, sir_plot = sir_plot, rt_plot = "", si_plot = ""))
}

results <- run_complete_simulation()
64 changes: 43 additions & 21 deletions R/wrapper.R
Original file line number Diff line number Diff line change
Expand Up @@ -128,34 +128,56 @@ save_sir_plot <- function(plot, filename, width = 10, height = 6, dpi = 300) {
)
}

# Calculate effective reproduction number (Rt)
calculate_rt <- function(df, window = 7) {
df_infected <- df[df$Status == "Infected", ]
df_infected$Rt <- c(NA, diff(df_infected$Count) / lag(df_infected$Count))
df_infected$Rt <- zoo::rollmean(df_infected$Rt, window, fill = NA)
return(df_infected)
save_sir_plot <- function(plot, filename, width = 10, height = 6, dpi = 300) {
ggsave(
filename = here(filename),
plot = plot,
width = width,
height = height,
dpi = dpi
)
}

# Create Rt plot
create_rt_plot <- function(df_rt, title = "Effective Reproduction Number (Rt)", display = TRUE) {
p <- ggplot(df_rt, aes(x = time, y = Rt)) +
geom_line(color = "orange") +
theme_minimal() +
plot_rt_curves <- function(file_path) {
# Check if file exists
if (!file.exists(file_path)) {
stop("The file does not exist. Please provide a valid file path.")
}

# Read the CSV file
data <- tryCatch({
read.csv(file_path)
}, error = function(e) {
stop("Error reading the file. Ensure it's a valid CSV.")
})

# Filter for "time" and "R_t" columns
if (!all(c("time", "R_t") %in% colnames(data))) {
stop("The CSV file must contain 'time' and 'R_t' columns.")
}
data <- data[, c("time", "R_t")]

# Remove rows with NaN values
data <- na.omit(data)

# Create the ggplot
gg <- ggplot(data, aes(x = time, y = R_t)) +
geom_line(color = "blue", size = 1) +
labs(
title = title,
title = "Reproduction Number (R_t) Over Time",
x = "Time",
y = "Rt"
y = "R_t"
) +
theme(
plot.title = element_text(hjust = 0.5),
panel.grid.minor = element_blank()
)
theme_minimal()

if (display) {
print(p)
}
# Print the plot
print(gg)
print("R_t plot generated successfully.")

return(p)
# Save the plot
save_sir_plot(gg, "simulation_outputs/rt_curve.png")

return(gg)
}

# Calculate serial intervals
Expand Down
Binary file modified data/simulation_outputs/SIR_plot.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit c587d50

Please sign in to comment.