Skip to content

Commit

Permalink
get_prediction_plot works
Browse files Browse the repository at this point in the history
  • Loading branch information
aminuldu07 committed Dec 30, 2024
1 parent 8fb4e96 commit 892543e
Show file tree
Hide file tree
Showing 4 changed files with 265 additions and 4 deletions.
160 changes: 160 additions & 0 deletions R/get_prediction_plot.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
get_prediction_plot <- function(Data=NULL,
path_db,
rat_studies=FALSE,
studyid_metadata=NULL,
fake_study = FALSE,
use_xpt_file = FALSE,
Round = FALSE,
Impute = FALSE,
reps,
holdback,
Undersample = FALSE,
hyperparameter_tuning = FALSE,
error_correction_method,
testReps){


if(is.null(Data)){
data_and_best.m <- get_Data_formatted_for_ml_and_best.m(path_db=path_db,
rat_studies=rat_studies,
studyid_metadata=studyid_metadata,
fake_study = fake_study,
use_xpt_file = use_xpt_file,
Round = Round,
Impute = Impute,
reps=reps,
holdback=holdback,
Undersample = Undersample,
hyperparameter_tuning = hyperparameter_tuning,
error_correction_method=error_correction_method) # = must be 'Flip' or "Prune' or 'None'

}

Data <- data_and_best.m[["Data"]]
best.m <- data_and_best.m[["best.m"]]




rfData <- Data
#---------------------------------------------------------------------
# Initialize model performance metric trackers------------------------
#---------------------------------------------------------------------

# custom function definition
`%ni%` <- Negate('%in%')

Sensitivity <- NULL
Specificity <- NULL
PPV <- NULL
NPV <- NULL
Prevalence <- NULL
Accuracy <- NULL
#nRemoved <- NULL


#-----------------doing cross-validation--------------------------
#-----------------------------------------------------------------
#------------------------------------------------------------------

#-----create and prepare "`rfTestData data` frame" for storing predictions----
rfTestData <- rfData

#replaces the existing column names with simple numeric identifiers
colnames(rfTestData) <- seq(ncol(rfTestData))

#emptying the data frame.
for (j in seq(ncol(rfTestData))) {
rfTestData[,j] <- NA
}

#prepares rfTestData to maintain a consistent structure with the necessary
#columns for storing predictions in subsequent iterations of the loop
rfTestData <- rfTestData[,1:2] # Keep structure for predictions

#remove 'gini' from the previous iteration
#if (exists('gini')) {rm(gini)}


#-------------------------------------------------------------------
# model building and testing----------------------------------------
#-------------------------------------------------------------------


# Iterate through test repetitions----------------------------------
for (i in seq(testReps)) {
if (i == 1) {
sampleIndicies <- seq(nrow(rfData))
}
if (i < testReps) {
ind <- sample(seq(nrow(rfData)), floor((nrow(rfData)/testReps)-1), replace = F)
sampleIndicies <- sampleIndicies[which(sampleIndicies %ni% ind)]
} else {
ind <- sampleIndicies
}

trainIndex <- which(seq(nrow(rfData)) %ni% ind)
testIndex <- ind

# ind <- sample(2, nrow(rfData), replace = T, prob = c((1- testHoldBack), testHoldBack))
train <- rfData[trainIndex,]

#train_data_two <- train

test <- rfData[testIndex,]

# rfAll <- randomForest::randomForest(Target_Organ ~ ., data=rfData, mytry = best.m,
# importance = F, ntree = 500, proximity = T)


# Perform under sampling if enabled
if (Undersample == T) {
posIndex <- which(train[,1] == 1)
nPos <- length(posIndex)
# trainIndex <- c(posIndex, sample(which(train[,1] == 0), nPos, replace = F))
trainIndex <- c(posIndex, sample(which(train[,1] == 0), nPos, replace = T))
train <- train[trainIndex,]
test <- rbind(train[-trainIndex,], test)
}

#train_data_two <- train


#model building with current iteration train data
# Train Random Forest model--------------------------------------------
rf <- randomForest::randomForest(Target_Organ ~ ., data=train, mytry = best.m,
importance = T, ntree = 500, proximity = T)

print(rf)

#----------------------------------------------------------------------
#predictions with current model with current test data
# @___________________this_line_has_problems_______
# Predict probabilities on test data
#----------------------------------------------------------------------

p2r <- stats::predict(rf, test, type = 'prob')[,1]

#Store these predictions in a structured data frame
rfTestData[names(p2r), i] <- as.numeric(p2r)

#Rounding the Predictions:
p2r <- round(p2r)
}


#-------------------------------------------------------
histoData <- as.data.frame(cbind(rowMeans(rfTestData, na.rm = T), rfData[,1]))
histoData[which(histoData[,2] == 1), 2] <- 'Y'
histoData[which(histoData[,2] == 2), 2] <- 'N'
colnames(histoData) <- c('Probability', 'LIVER')

H <- p <- histoData %>%
ggplot2::ggplot( ggplot2::aes(x=Probability, fill=LIVER)) +
ggplot2::geom_histogram( color="#e9ecef", alpha=0.6, position = 'identity') +
ggplot2::scale_fill_manual(values=c("#69b3a2", "#404080")) +
ggplot2::labs(fill = "LIVER", x = "Model Prediction P(LIVER)", y = "Count")

print(H)

}
17 changes: 13 additions & 4 deletions R/get_reprtree_from_rf_model .R
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ get_reprtree_from_rf_model <- function ( Data=NULL,
hyperparameter_tuning = FALSE,
error_correction_method) { # = must be 'Flip' or "Prune' or 'None'

browser()

if(is.null(Data)){
data_and_best.m <- get_Data_formatted_for_ml_and_best.m(path_db=path_db,
rat_studies=rat_studies,
Expand All @@ -29,7 +29,7 @@ get_reprtree_from_rf_model <- function ( Data=NULL,
error_correction_method=error_correction_method) # = must be 'Flip' or "Prune' or 'None'

}
browser()

Data <- data_and_best.m[["Data"]]
best.m <- data_and_best.m[["best.m"]]

Expand Down Expand Up @@ -75,7 +75,7 @@ get_reprtree_from_rf_model <- function ( Data=NULL,
# train <- train[balancedIndex, ]
# }


browser()
# Train a Random Forest model using the specified mtry value
rfAll <- randomForest::randomForest(Target_Organ ~ .,
data = Data,
Expand All @@ -88,6 +88,15 @@ get_reprtree_from_rf_model <- function ( Data=NULL,
train,
metric='d2')

plot.reprtree(ReprTree(rfAll, train, metric='d2'))
#plot(ReprTree)
#library(reprtree)

# Plot the first tree (k = 1) from the random forest
reprtree::plot.getTree(rforest = rfAll, k = 5, depth = 10)#, main = "Tree 1")

#reprtree::plot.getTree(rfAll,train, )
#plot(ReprTree(rfAll, train, metric = "d2"))

#reprtree::plot.reprtree(ReprTree(rfAll, train, metric='d2'))

}
89 changes: 89 additions & 0 deletions inst/test_get_prediction_plot .R
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
rm(list = ls())
devtools::load_all(".")

# Initialize a connection to the SQLite database
path_db='C:/Users/MdAminulIsla.Prodhan/OneDrive - FDA/Documents/DATABASES/fake_merged_liver_not_liver.db'

#path_db='C:/Users/MdAminulIsla.Prodhan/OneDrive - FDA/Documents/DATABASES/fake_xpt'
#studyid_or_studyids <- list.dirs(path_db , full.names = TRUE, recursive = FALSE)
#path_db = "C:/Users/MdAminulIsla.Prodhan/OneDrive - FDA/Documents/DATABASES/TestDB.db"
studyid_metadata <- read.csv("C:/Users/MdAminulIsla.Prodhan/OneDrive - FDA/Documents/DATABASES/fake_80_MD.csv",
header = TRUE, sep = ",", stringsAsFactors = FALSE)

# ----------------------------------------------------
# For this function we need ml format data
#----------------------------------------------------
prediction_plot <- get_prediction_plot(Data=NULL,
path_db=path_db,
rat_studies=FALSE,
studyid_metadata=studyid_metadata,
fake_study = TRUE,
use_xpt_file = FALSE,
Round = TRUE,
Impute = TRUE,
reps=1,
holdback=0.25,
Undersample =TRUE,
hyperparameter_tuning = FALSE,
error_correction_method = 'None', # = must be 'Flip' or "Prune' or 'None'
testReps=3)


# simple_rf_model <- get_rf_model_with_cv(Data = Data,
# Undersample = FALSE,
# best.m = NULL, # any numeric value or call function to get it
# testReps=2, # testRps must be at least 2;
# Type=1)



# rf_with_intermediate <- get_imp_features_from_rf_model_with_cv(scores_df=Data,
# Undersample = TRUE,
# best.m = 4, # any numeric value or call function to get it
# testReps=2, # testRps must be at least 2;
# indeterminateUpper=0.75,
# indeterminateLower=0.25,
# Type=1,
# nTopImportance=20)








#rf_model <- get_random_forest_model_amin2(Data=rf_Data)



# # Create a connection to the database
# dbtoken <- DBI::dbConnect(RSQLite::SQLite(), dbname = path_db)
#
# # Retrieve the STUDYID column from the dm table
# query <- "SELECT STUDYID FROM dm"
# studyid_data <- DBI::dbGetQuery(dbtoken, query)
#
# # Extract unique STUDYID values
# unique_studyids <- unique(studyid_data$STUDYID)
#
# # Disconnect from the database
# DBI::dbDisconnect(dbtoken)
#
# studyid_or_studyids <- unique_studyids

#studyid_or_studyids <- list.dirs(path_db , full.names = TRUE, recursive = FALSE)

# #@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
# rm(list = ls())
# devtools::load_all(".")
# path_db <- "C:/Users/MdAminulIsla.Prodhan/OneDrive - FDA/Documents/DATABASES/fake_merged_liver_not_liver.db"
# studyid_metadata_path <- "C:/Users/MdAminulIsla.Prodhan/OneDrive - FDA/Documents/DATABASES/fake_80_MD.csv"
#
# rfData_and_best_m <- get_rfData_and_best_m(
# path_db = path_db,
# studyid_metadata_path = studyid_metadata_path,
# fake_study = TRUE,
# Round = TRUE,
# Undersample = TRUE
# )
3 changes: 3 additions & 0 deletions inst/vvc_rf_functions/visualization.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,6 @@ H <- p <- histoData %>%
ggplot2::scale_fill_manual(values=c("#69b3a2", "#404080")) +
# theme_ipsum() +
ggplot2::labs(fill = "LIVER", x = "Model Prediction P(LIVER)", y = "Count")



0 comments on commit 892543e

Please sign in to comment.