Skip to content

Commit

Permalink
rf updated
Browse files Browse the repository at this point in the history
  • Loading branch information
aminuldu07 committed Dec 29, 2024
1 parent 159ee56 commit 43c135c
Show file tree
Hide file tree
Showing 8 changed files with 743 additions and 80 deletions.
162 changes: 162 additions & 0 deletions R/get_Data_formatted_for_ml.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@





get_Data_formatted_for_ml <- function(path_db,
rat_studies=FALSE,
studyid_metadata=NULL,
fake_study = FALSE,
use_xpt_file = FALSE,
Round = FALSE,
Impute = FALSE,
reps,
holdback,
Undersample = FALSE,
hyperparameter_tuning = FALSE,
error_correction_method # = must be 'Flip' or "Prune' or 'None'
){

# Process the database to retrieve the vector of "STUDYIDs"-------------
if(use_xpt_file){

studyid_or_studyids <- list.dirs(path_db , full.names = TRUE, recursive = FALSE)

} else {

if (fake_study) {
# Helper function to fetch data from SQLite database
fetch_domain_data <- function(db_connection, domain_name) {
# Convert domain name to uppercase
domain_name <- toupper(domain_name)
# Create SQL query statement
query_statement <- paste0('SELECT * FROM ', domain_name)
# Execute query and fetch the data
query_result <- DBI::dbGetQuery(db_connection, statement = query_statement)
# Return the result
query_result
}
# Establish a connection to the SQLite database
db_connection <- DBI::dbConnect(RSQLite::SQLite(), dbname = path_db)

# Fetch data for required domains
dm <- fetch_domain_data(db_connection, 'dm')

# Close the database connection
DBI::dbDisconnect(db_connection)

# get the studyids from the dm table
studyid_or_studyids <- as.vector(unique(dm$STUDYID)) # unique STUDYIDS from DM table

# Filter the fake data for the "rat_studies"
if(rat_studies){

studyid_or_studyids <- studyid_or_studyids
}

#--------------------------------------------------------------------
#-----------we can set logic here for rat studies in "fake data"----
#--------------------------------------------------------------------

} else {
# For the real data in sqlite database
# filter for the repeat-dose and parallel studyids

studyid_or_studyids <- get_repeat_dose_parallel_studyids(path_db=path_db,
rat_studies = rat_studies)

}
}


# process the database to get the "studyid_metadata"------------
if(is.null(studyid_metadata)) {
if(fake_study) {
# Extract study ID metadata
studyid_metadata <- dm[, "STUDYID", drop=FALSE]

# Remove duplicates based on STUDYID
studyid_metadata <- studyid_metadata[!duplicated(studyid_metadata$STUDYID), , drop =FALSE]

# Add a new column for Target_Organ
studyid_metadata$Target_Organ <- NA

# assign "Target_Organ" column values randomly
# randomly 50% of the value is Liver and rest are not_Liver
set.seed(123) # Set seed for reproducibility
rows_number <- nrow(studyid_metadata) # Number of rows

# Randomly sample 50% for "Liver" and rest for "not_Liver"
studyid_metadata$Target_Organ <- sample(c("Liver", "not_Liver"), size = rows_number, replace = TRUE, prob = c(0.5, 0.5))

# View the result

} else {

# create "studyid_metadata" data frame from "studyid_or_studyids" vector
studyid_metadata <- data.frame(STUDYID = studyid_or_studyids)

# Remove duplicates based on STUDYID
studyid_metadata <- studyid_metadata[!duplicated(studyid_metadata$STUDYID), , drop = FALSE]

# Add a new column for Target_Organ
studyid_metadata$Target_Organ <- NA

# assign "Target_Organ" column values randomly
# randomly 50% of the value is Liver and rest are not_Liver
set.seed(123) # Set seed for reproducibility
rows_number <- nrow(studyid_metadata) # Number of rows

# Randomly sample 50% for "Liver" and rest for "not_Liver"
studyid_metadata$Target_Organ <- sample(c("Liver", "not_Liver"), size = rows_number, replace = TRUE, prob = c(0.5, 0.5))

}
}


#-----------------------------------------------------------------------
# if studyid_metadata is not provided then use the data frame to
# creae a data frame with two columns "STUDYID" and "Target_Organ"

#-------------------------------------------------------------------------

# get_liver_om_lb_mi_tox_score_list(
calculated_liver_scores <- get_liver_om_lb_mi_tox_score_list(studyid_or_studyids = studyid_or_studyids,
path_db = path_db,
fake_study = fake_study,
use_xpt_file = use_xpt_file,
output_individual_scores = TRUE,
output_zscore_by_USUBJID = FALSE)

# Harmonize the column
column_harmonized_liverscr_df <- get_col_harmonized_scores_df(liver_score_data_frame = calculated_liver_scores,
Round = Round)



rfData_and_best_m <- get_ml_data_and_tuned_hyperparameters( scores_df = column_harmonized_liverscr_df,
studyid_metadata = studyid_metadata,
Impute = Impute,
Round = Round,
reps=reps,
holdback=holdback,
Undersample = Undersample,
hyperparameter_tuning = hyperparameter_tuning,
error_correction_method = error_correction_method)



rfData <- rfData_and_best_m[["rfData"]]

# best.m input handling------------------------------------------------
# if(is.null(best.m)){
# best.m <- rfData_and_best_m[["best.m"]]
# } else {
# best.m <- best.m
# }



return(Data = rfData)

}
214 changes: 214 additions & 0 deletions R/get_imp_features_from_rf_model_cv_imp.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@

get_imp_features_from_rf_model_with_cv <- function(scores_df=NULL,
Undersample = FALSE,
best.m = NULL, # any numeric value or call function to get it
testReps, # testRps must be at least 2;
indeterminateUpper,
indeterminateLower,
Type,
nTopImportance) {

rfData <- scores_df
#---------------------------------------------------------------------
# Initialize model performance metric trackers------------------------
#---------------------------------------------------------------------

# custom function definition
`%ni%` <- Negate('%in%')

Sensitivity <- NULL
Specificity <- NULL
PPV <- NULL
NPV <- NULL
Prevalence <- NULL
Accuracy <- NULL
nRemoved <- NULL


#-----------------doing cross-validation--------------------------
#-----------------------------------------------------------------
#------------------------------------------------------------------

#-----create and prepare "`rfTestData data` frame" for storing predictions----
rfTestData <- rfData

#replaces the existing column names with simple numeric identifiers
colnames(rfTestData) <- seq(ncol(rfTestData))

#emptying the data frame.
for (j in seq(ncol(rfTestData))) {
rfTestData[,j] <- NA
}

#prepares rfTestData to maintain a consistent structure with the necessary
#columns for storing predictions in subsequent iterations of the loop
rfTestData <- rfTestData[,1:2] # Keep structure for predictions

#remove 'gini' from the previous iteration
if (exists('gini')) {rm(gini)}


#-------------------------------------------------------------------
# model building and testing----------------------------------------
#-------------------------------------------------------------------


# Iterate through test repetitions----------------------------------
for (i in seq(testReps)) {
if (i == 1) {
sampleIndicies <- seq(nrow(rfData))
}
if (i < testReps) {
ind <- sample(seq(nrow(rfData)), floor((nrow(rfData)/testReps)-1), replace = F)
sampleIndicies <- sampleIndicies[which(sampleIndicies %ni% ind)]
} else {
ind <- sampleIndicies
}

trainIndex <- which(seq(nrow(rfData)) %ni% ind)
testIndex <- ind

# ind <- sample(2, nrow(rfData), replace = T, prob = c((1- testHoldBack), testHoldBack))
train <- rfData[trainIndex,]

#train_data_two <- train

test <- rfData[testIndex,]

# rfAll <- randomForest::randomForest(Target_Organ ~ ., data=rfData, mytry = best.m,
# importance = F, ntree = 500, proximity = T)


# Perform under sampling if enabled
if (Undersample == T) {
posIndex <- which(train[,1] == 1)
nPos <- length(posIndex)
# trainIndex <- c(posIndex, sample(which(train[,1] == 0), nPos, replace = F))
trainIndex <- c(posIndex, sample(which(train[,1] == 0), nPos, replace = T))
train <- train[trainIndex,]
test <- rbind(train[-trainIndex,], test)
}

#train_data_two <- train

browser()
#model building with current iteration train data
# Train Random Forest model--------------------------------------------
rf <- randomForest::randomForest(Target_Organ ~ ., data=train, mytry = best.m,
importance = T, ntree = 500, proximity = T)

print(rf)

#----------------------------------------------------------------------
#predictions with current model with current test data
# @___________________this_line_has_problems_______
# Predict probabilities on test data
#----------------------------------------------------------------------

p2r <- stats::predict(rf, test, type = 'prob')[,1]

#Store these predictions in a structured data frame
rfTestData[names(p2r), i] <- as.numeric(p2r)


#--------------------------------------------------------------------------
#--------------------------------------------------------------------------
#--------------------------------------------------------------------------
#Identifying Indeterminate Predictions (Tracking Indeterminate Predictions)
#Keeps track of the proportion of indeterminate predictions in each iteration
#Proportion Tracking
#------------------------------------------------------------------------
#------------------------------------------------------------------------

indeterminateIndex <- which((p2r < indeterminateUpper)&(p2r > indeterminateLower))

#Calculating the Proportion of Indeterminate Predictions
#Sets the indeterminate predictions to NA, effectively marking them
#as missing or invalid.
nRemoved <- c(nRemoved, length(indeterminateIndex)/length(p2r))

#Handling Indeterminate Predictions
p2r[indeterminateIndex] <- NA

#Rounding the Predictions:
p2r <- round(p2r)


# Compute confusion matrix and extract metrics using "caret" package----

Results <- caret::confusionMatrix(factor(p2r, levels = c(1, 0)), factor(test$Target_Organ, levels = c(1, 0)))
Sensitivity <- c(Sensitivity, Results$byClass[['Sensitivity']])
Specificity <- c(Specificity, Results$byClass[['Specificity']])
PPV <- c(PPV, Results$byClass[['Pos Pred Value']])
NPV <- c(NPV, Results$byClass[['Neg Pred Value']])
Prevalence <- c(Prevalence, Results$byClass[['Prevalence']])
Accuracy <- c(Accuracy, Results$byClass[['Balanced Accuracy']])


# Aggregate Gini importance scores
giniTmp <- randomForest::importance(rf, type = Type)
if (exists('gini')) {
gini <- cbind(gini, giniTmp)
} else {
gini <- giniTmp
}
}


#------------------------------------------------------------------------
# Performance Summary
#-------------------------------------------------------------------------

PerformanceMatrix <- cbind(Sensitivity,
Specificity,
PPV,
NPV,
Prevalence,
Accuracy,
nRemoved)
PerformanceSummary <- colMeans(PerformanceMatrix, na.rm = T)
print(PerformanceSummary)

#-------------------------------------------------------------------------
# Feature Importance------------------------------------------------------
#-------------------------------------------------------------------------

print("Feature Importance (Mean Decrease):")
print(sort(rowMeans(gini), decreasing = T))


#-------------------------------------------------------------------------
# Top Important Features--------------------------------------------------
#--------------------------------------------------------------------------
imp <- as.matrix(rowMeans(gini)[1:nTopImportance])
if (Type == 1) {
colnames(imp) <- 'MeanDecreaseAccuracy'
} else {
colnames(imp) <- 'MeanDecreaseGini'
}
ord <- order(imp[,1])

# #------------------------------------------------------------------------
# # Dotchart for top Variable Importance
# #------------------------------------------------------------------------
# dotchart(imp[ord, 1], xlab = colnames(imp)[1], ylab = "",
# main = paste0('Top ', nrow(imp), ' - Variable Importance'))#, xlim = c(xmin, max(imp[, i])))
# # varImpPlot(rf,
# # sort = T,
# # n.var = 20,
# # main = "Top 20 - Variable Importance")
print(".............................................................................")
print(PerformanceSummary)

return(list(
performance_metrics = PerformanceSummary, # Aggregated performance metrics
feature_importance = imp, # Top n features by importance
raw_results = list( # Raw data for debugging or extended analysis
sensitivity = Sensitivity,
specificity = Specificity,
accuracy = Accuracy,
gini_scores = gini
)
))

}
Loading

0 comments on commit 43c135c

Please sign in to comment.