diff --git a/R/get_prediction_plot.R b/R/get_prediction_plot.R
new file mode 100644
index 0000000..3ed7aa8
--- /dev/null
+++ b/R/get_prediction_plot.R
@@ -0,0 +1,159 @@
+#' Plot a histogram of cross-validated random forest predictions
+#'
+#' Repeatedly splits the modelling data into a training part and a test part,
+#' fits a random forest (`Target_Organ ~ .`) on the training part and records
+#' the predicted probability of the first class for every test row. The
+#' per-row mean of those probabilities is drawn as a histogram whose fill is
+#' taken from the first column of the data.
+#'
+#' @param Data Optional modelling data. It needs a `Target_Organ` column, and
+#'   its first column is read as the label. When `NULL`, both the data and
+#'   the `mtry` value come from `get_Data_formatted_for_ml_and_best.m()`;
+#'   when supplied, `randomForest()` keeps its default `mtry`.
+#' @param path_db,rat_studies,studyid_metadata,fake_study,use_xpt_file
+#'   Passed unchanged to `get_Data_formatted_for_ml_and_best.m()`; only used
+#'   when `Data` is `NULL`.
+#' @param Round,Impute,reps,holdback,hyperparameter_tuning Passed unchanged
+#'   to `get_Data_formatted_for_ml_and_best.m()`; only used when `Data` is
+#'   `NULL`.
+#' @param Undersample If `TRUE`, each training set is rebuilt from all rows
+#'   labelled 1 plus an equally sized sample, drawn with replacement, of the
+#'   rows labelled 0; training rows left out of it join the test set. Also
+#'   passed to `get_Data_formatted_for_ml_and_best.m()` when `Data` is `NULL`.
+#' @param error_correction_method Passed to
+#'   `get_Data_formatted_for_ml_and_best.m()`; must be 'Flip', 'Prune' or
+#'   'None'.
+#' @param testReps Number of train/test repetitions; at least 2 and at most
+#'   `nrow(Data)`.
+#'
+#' @return The ggplot object, invisibly. The plot is also printed.
+get_prediction_plot <- function(Data = NULL,
+                                path_db,
+                                rat_studies = FALSE,
+                                studyid_metadata = NULL,
+                                fake_study = FALSE,
+                                use_xpt_file = FALSE,
+                                Round = FALSE,
+                                Impute = FALSE,
+                                reps,
+                                holdback,
+                                Undersample = FALSE,
+                                hyperparameter_tuning = FALSE,
+                                error_correction_method,
+                                testReps) {
+  # With caller-supplied Data there is no tuned value, so best.m stays NULL
+  # and randomForest() uses its default mtry.
+  best.m <- NULL
+
+  if (is.null(Data)) {
+    data_and_best.m <- get_Data_formatted_for_ml_and_best.m(
+      path_db = path_db,
+      rat_studies = rat_studies,
+      studyid_metadata = studyid_metadata,
+      fake_study = fake_study,
+      use_xpt_file = use_xpt_file,
+      Round = Round,
+      Impute = Impute,
+      reps = reps,
+      holdback = holdback,
+      Undersample = Undersample,
+      hyperparameter_tuning = hyperparameter_tuning,
+      error_correction_method = error_correction_method
+    )
+    Data <- data_and_best.m[["Data"]]
+    best.m <- data_and_best.m[["best.m"]]
+  }
+
+  n_rows <- nrow(Data)
+  stopifnot(
+    is.numeric(testReps), length(testReps) == 1L,
+    testReps >= 2, testReps <= n_rows
+  )
+
+  # One column of predicted probabilities per repetition, addressed by row
+  # name; rows a repetition does not test stay NA.
+  predictions <- matrix(
+    NA_real_,
+    nrow = n_rows, ncol = testReps,
+    dimnames = list(rownames(Data), NULL)
+  )
+
+  # Every repetition but the last tests fold_size rows; the last one tests
+  # whatever is left.
+  fold_size <- floor(n_rows / testReps - 1)
+  remaining <- seq_len(n_rows)
+
+  for (i in seq_len(testReps)) {
+    if (i < testReps) {
+      # Draw only from rows no earlier repetition has tested. Sampling
+      # positions keeps a one-element pool from being read as 1:x.
+      testIndex <- remaining[sample.int(length(remaining), fold_size)]
+      remaining <- setdiff(remaining, testIndex)
+    } else {
+      testIndex <- remaining
+    }
+
+    # setdiff() rather than a negative index: Data[-integer(0), ] would
+    # select no rows at all.
+    train <- Data[setdiff(seq_len(n_rows), testIndex), , drop = FALSE]
+    test <- Data[testIndex, , drop = FALSE]
+
+    if (isTRUE(Undersample)) {
+      posIndex <- which(train[[1]] == 1)
+      negIndex <- which(train[[1]] == 0)
+      balanced <- c(
+        posIndex,
+        negIndex[sample.int(length(negIndex), length(posIndex), replace = TRUE)]
+      )
+      # Pick the left-out rows before `train` is overwritten; they are
+      # positions in the full training set.
+      unused <- setdiff(seq_len(nrow(train)), balanced)
+      test <- rbind(train[unused, , drop = FALSE], test)
+      train <- train[balanced, , drop = FALSE]
+    }
+
+    if (is.null(best.m)) {
+      rf <- randomForest::randomForest(
+        Target_Organ ~ .,
+        data = train,
+        importance = TRUE, ntree = 500, proximity = TRUE
+      )
+    } else {
+      rf <- randomForest::randomForest(
+        Target_Organ ~ .,
+        data = train, mtry = best.m,
+        importance = TRUE, ntree = 500, proximity = TRUE
+      )
+    }
+    print(rf)
+
+    # Keep the matrix form so row names survive a one-row test set.
+    probs <- stats::predict(rf, test, type = "prob")
+    predictions[rownames(probs), i] <- probs[, 1]
+  }
+
+  # cbind() turns a factor label into its integer codes.
+  # NOTE(review): the mapping below assumes code 1 is the LIVER class and
+  # code 2 the other one - confirm against the factor levels of the label.
+  histoData <- as.data.frame(
+    cbind(rowMeans(predictions, na.rm = TRUE), Data[[1]])
+  )
+  histoData[which(histoData[, 2] == 1), 2] <- "Y"
+  histoData[which(histoData[, 2] == 2), 2] <- "N"
+  colnames(histoData) <- c("Probability", "LIVER")
+
+  H <- ggplot2::ggplot(
+    histoData,
+    ggplot2::aes(x = Probability, fill = LIVER)
+  ) +
+    ggplot2::geom_histogram(
+      color = "#e9ecef", alpha = 0.6, position = "identity"
+    ) +
+    ggplot2::scale_fill_manual(values = c("#69b3a2", "#404080")) +
+    ggplot2::labs(
+      fill = "LIVER", x = "Model Prediction P(LIVER)", y = "Count"
+    )
+
+  print(H)
+  invisible(H)
+}
diff --git a/R/get_reprtree_from_rf_model .R b/R/get_reprtree_from_rf_model .R index 5defa1a..74e8498 100644 --- a/R/get_reprtree_from_rf_model .R +++ b/R/get_reprtree_from_rf_model .R @@ -13,7 +13,7 @@ get_reprtree_from_rf_model <- function ( Data=NULL, hyperparameter_tuning = FALSE, error_correction_method) { # = must be 'Flip' or "Prune' or 'None' - browser() + if(is.null(Data)){ data_and_best.m <- get_Data_formatted_for_ml_and_best.m(path_db=path_db, rat_studies=rat_studies, @@ -29,7 +29,7 @@ get_reprtree_from_rf_model <- function ( Data=NULL, error_correction_method=error_correction_method) # = must be 'Flip' or "Prune' or 'None' } - browser() + Data <- data_and_best.m[["Data"]] best.m <- data_and_best.m[["best.m"]] @@ -75,7 +75,7 @@ get_reprtree_from_rf_model <- function ( Data=NULL, # train <- train[balancedIndex, ] # } - + browser() #
Train a Random Forest model using the specified mtry value rfAll <- randomForest::randomForest(Target_Organ ~ ., data = Data, @@ -88,6 +88,15 @@ get_reprtree_from_rf_model <- function ( Data=NULL, train, metric='d2') - plot.reprtree(ReprTree(rfAll, train, metric='d2')) + #plot(ReprTree) + #library(reprtree) + + # Plot the first tree (k = 1) from the random forest + reprtree::plot.getTree(rforest = rfAll, k = 5, depth = 10)#, main = "Tree 1") + + #reprtree::plot.getTree(rfAll,train, ) + #plot(ReprTree(rfAll, train, metric = "d2")) + + #reprtree::plot.reprtree(ReprTree(rfAll, train, metric='d2')) }
diff --git a/inst/test_get_prediction_plot .R b/inst/test_get_prediction_plot .R
new file mode 100644
index 0000000..4dc1081
--- /dev/null
+++ b/inst/test_get_prediction_plot .R
@@ -0,0 +1,89 @@
+# Development script: run from the package root in a fresh R session.
+devtools::load_all(".")
+
+# Path to the SQLite database file
+path_db <- "C:/Users/MdAminulIsla.Prodhan/OneDrive - FDA/Documents/DATABASES/fake_merged_liver_not_liver.db"
+
+#path_db='C:/Users/MdAminulIsla.Prodhan/OneDrive - FDA/Documents/DATABASES/fake_xpt'
+#studyid_or_studyids <- list.dirs(path_db , full.names = TRUE, recursive = FALSE)
+#path_db = "C:/Users/MdAminulIsla.Prodhan/OneDrive - FDA/Documents/DATABASES/TestDB.db"
+studyid_metadata <- read.csv("C:/Users/MdAminulIsla.Prodhan/OneDrive - FDA/Documents/DATABASES/fake_80_MD.csv",
+                             header = TRUE, sep = ",", stringsAsFactors = FALSE)
+
+# ----------------------------------------------------
+# For this function we need ml format data
+#----------------------------------------------------
+prediction_plot <- get_prediction_plot(Data = NULL,
+                                       path_db = path_db,
+                                       rat_studies = FALSE,
+                                       studyid_metadata = studyid_metadata,
+                                       fake_study = TRUE,
+                                       use_xpt_file = FALSE,
+                                       Round = TRUE,
+                                       Impute = TRUE,
+                                       reps = 1,
+                                       holdback = 0.25,
+                                       Undersample = TRUE,
+                                       hyperparameter_tuning = FALSE,
+                                       error_correction_method = "None", # must be 'Flip', 'Prune' or 'None'
+                                       testReps = 3)
+
+
+# simple_rf_model <- get_rf_model_with_cv(Data = Data,
+#                                         Undersample = FALSE,
+#                                         best.m = NULL, # any numeric value or call function to get it
+#                                         testReps=2, # testReps must be at least 2;
+#                                         Type=1)
+
+
+
+# rf_with_intermediate <- get_imp_features_from_rf_model_with_cv(scores_df=Data,
+#                                         Undersample = TRUE,
+#                                         best.m = 4, # any numeric value or call function to get it
+#                                         testReps=2, # testReps must be at least 2;
+#                                         indeterminateUpper=0.75,
+#                                         indeterminateLower=0.25,
+#                                         Type=1,
+#                                         nTopImportance=20)
+
+
+
+
+
+
+
+
+#rf_model <- get_random_forest_model_amin2(Data=rf_Data)
+
+
+
+# # Create a connection to the database
+# dbtoken <- DBI::dbConnect(RSQLite::SQLite(), dbname = path_db)
+#
+# # Retrieve the STUDYID column from the dm table
+# query <- "SELECT STUDYID FROM dm"
+# studyid_data <- DBI::dbGetQuery(dbtoken, query)
+#
+# # Extract unique STUDYID values
+# unique_studyids <- unique(studyid_data$STUDYID)
+#
+# # Disconnect from the database
+# DBI::dbDisconnect(dbtoken)
+#
+# studyid_or_studyids <- unique_studyids
+
+#studyid_or_studyids <- list.dirs(path_db , full.names = TRUE, recursive = FALSE)
+
+# #@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
+# rm(list = ls())
+# devtools::load_all(".")
+# path_db <- "C:/Users/MdAminulIsla.Prodhan/OneDrive - FDA/Documents/DATABASES/fake_merged_liver_not_liver.db"
+# studyid_metadata_path <- "C:/Users/MdAminulIsla.Prodhan/OneDrive - FDA/Documents/DATABASES/fake_80_MD.csv"
+#
+# rfData_and_best_m <- get_rfData_and_best_m(
+#   path_db = path_db,
+#   studyid_metadata_path = studyid_metadata_path,
+#   fake_study = TRUE,
+#   Round = TRUE,
+#   Undersample = TRUE
+# )
diff --git a/inst/vvc_rf_functions/visualization.R b/inst/vvc_rf_functions/visualization.R index ee8810c..98f5e7a 100644 --- a/inst/vvc_rf_functions/visualization.R +++ b/inst/vvc_rf_functions/visualization.R @@ -45,3 +45,6 @@ H <- p <- histoData %>% ggplot2::scale_fill_manual(values=c("#69b3a2", "#404080")) + # theme_ipsum() + ggplot2::labs(fill = "LIVER", x = "Model Prediction
P(LIVER)", y = "Count") + + +