Skip to content

Commit

Permalink
Merge branch 'develop' into function_tags_and_examples
Browse files Browse the repository at this point in the history
  • Loading branch information
egillax committed Jan 28, 2025
2 parents 09e0204 + bf9b9b5 commit 6ec9e52
Show file tree
Hide file tree
Showing 11 changed files with 505 additions and 83 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ export(createPreprocessSettings)
export(createRandomForestFeatureSelection)
export(createRestrictPlpDataSettings)
export(createSampleSettings)
export(createSklearnModel)
export(createSplineSettings)
export(createStratifiedImputationSettings)
export(createStudyPopulation)
Expand Down
119 changes: 119 additions & 0 deletions R/ExistingSklearn.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# @file ExistingSklearn.R
#
# Copyright 2025 Observational Health Data Sciences and Informatics
#
# This file is part of PatientLevelPrediction
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


#' Plug an existing scikit learn python model into the
#' PLP framework
#'
#' @details
#' This function lets users add an existing scikit learn model that is saved as
#' model.pkl into PLP format. covariateMap is a mapping between standard
#' covariateIds and the model columns. The user also needs to specify the
#' covariate settings and population settings as these are used to determine
#' the standard PLP model design.
#'
#' @param modelLocation The location of the folder that contains the model as
#' model.pkl
#' @param covariateMap A data.frame with the columns: columnId and covariateId.
#' `covariateId` from FeatureExtraction is the standard OHDSI covariateId.
#' `columnId` is the column location the model expects that covariate to be in.
#' For example, if you had a column called 'age' in your model and this was the
#' 3rd column when fitting the model, then the values for columnId would be 3,
#' covariateId would be 1002 (the covariateId for age in years) and
#' @param covariateSettings The settings for the standardized covariates
#' @param populationSettings The settings for the population, this includes the
#' time-at-risk settings and inclusion criteria.
#' @param isPickle If the model should be saved as a pickle set this to TRUE if
#' it should be saved as json set this to FALSE.
#'
#' @return
#' An object of class plpModel, this is a list that contains:
#' model (the location of the model.pkl),
#' preprocessing (settings for mapping the covariateIds to the model
#' column mames),
#' modelDesign (specification of the model design),
#' trainDetails (information about the model fitting) and
#' covariateImportance.
#'
#' You can use the output as an input in PatientLevelPrediction::predictPlp to
#' apply the model and calculate the risk for patients.
#'
#' @export
createSklearnModel <- function(
modelLocation = "/model", # model needs to be saved here as "model.pkl"
covariateMap = data.frame(
columnId = 1:2,
covariateId = c(1, 2),
),
covariateSettings, # specify the covariates
populationSettings, # specify time at risk used to develop model
isPickle = TRUE) {
checkSklearn()
checkFileExists(modelLocation)
checkIsClass(covariateMap, "data.frame")
checkIsClass(covariateSettings, "covariateSettings")
checkIsClass(populationSettings, "populationSettings")
checkBoolean(isPickle)
checkDataframe(covariateMap, c("columnId", "covariateId"),
columnTypes = list(c("numeric", "integer"), c("numeric", "integer"))
)
existingModel <- list(model = "existingSklearn")
class(existingModel) <- "modelSettings"

plpModel <- list(
preprocessing = list(
tidyCovariates = NULL,
requireDenseMatrix = FALSE
),
covariateImportance = data.frame(
columnId = covariateMap$columnId,
covariateId = covariateMap$covariateId,
included = TRUE
),
modelDesign = PatientLevelPrediction::createModelDesign(
targetId = 1,
outcomeId = 2,
restrictPlpDataSettings = PatientLevelPrediction::createRestrictPlpDataSettings(),
covariateSettings = covariateSettings,
populationSettings = populationSettings,
sampleSettings = PatientLevelPrediction::createSampleSettings(),
preprocessSettings = PatientLevelPrediction::createPreprocessSettings(
minFraction = 0,
normalize = FALSE,
removeRedundancy = FALSE
),
modelSettings = existingModel,
splitSettings = PatientLevelPrediction::createDefaultSplitSetting()
),
model = modelLocation,
trainDetails = list(
analysisId = "existingSklearn",
developmentDatabase = "unknown",
developmentDatabaseId = "unknown",
trainingTime = -1,
modelName = "existingSklearn"
)
)

attr(plpModel, "modelType") <- "binary"
attr(plpModel, "saveType") <- "file"
attr(plpModel, "predictionFunction") <- "predictPythonSklearn"
attr(plpModel, "saveToJson") <- !isPickle
class(plpModel) <- "plpModel"
return(plpModel)
}
12 changes: 12 additions & 0 deletions R/HelperFunctions.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,18 @@ checkSurvivalPackages <- function() {
)
}

checkSklearn <- function() {
rlang::check_installed(
"reticulate",
reason = "Reticulate is required to use the Python models"
)
tryCatch({
reticulate::import("sklearn")
}, error = function(e) {
stop("scikit-learn in a python environment reachable by reticulate is required to use the Python models")
})
}

#' Create a temporary model location
#' @return A string for the location of the temporary model location
#' @export
Expand Down
13 changes: 1 addition & 12 deletions R/Logging.R
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ createLog <- function(
logName = "PLP Log",
saveDirectory = getwd(),
logFileName = paste0("plpLog", as.Date(Sys.Date(), "%Y%m%d"), ".txt")) {
checkFileExists(saveDirectory, createIfNot = TRUE)
createDir(saveDirectory)

logFileName <- gsub("[[:punct:]]", "", logFileName)

Expand All @@ -85,17 +85,6 @@ createLog <- function(
return(logger)
}

checkFileExists <- function(
saveDirectory,
createIfNot = TRUE) {
dirExists <- dir.exists(saveDirectory)
if (!dirExists && createIfNot) {
ParallelLogger::logInfo(paste0("Creating save directory at: ", saveDirectory))
dir.create(saveDirectory, recursive = TRUE)
}
return(invisible(dirExists))
}

closeLog <- function(logger) {
# stop logger
ParallelLogger::unregisterLogger(logger)
Expand Down
84 changes: 84 additions & 0 deletions R/ParamChecks.R
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,87 @@ checkIsEqual <- function(parameter, value) {
}
return(TRUE)
}

checkFileType <- function(parameter, fileType) {
name <- deparse(substitute(parameter))
if (!grepl(fileType, parameter)) {
ParallelLogger::logError(paste0(name, " should be a ", fileType, " file"))
stop(paste0(name, " is not a ", fileType, " file"))
}
return(TRUE)
}

createDir <- function(
saveDirectory,
createIfNot = TRUE) {
dirExists <- dir.exists(saveDirectory)
if (!dirExists && createIfNot) {
ParallelLogger::logInfo(paste0("Creating save directory at: ", saveDirectory))
dir.create(saveDirectory, recursive = TRUE)
}
return(invisible(dirExists))
}

checkFileExists <- function(parameter) {
name <- deparse(substitute(parameter))
if (!file.exists(parameter)) {
ParallelLogger::logError(paste0(name, " does not exist"))
stop(paste0(name, " does not exist"))
}
return(TRUE)
}

checkDataframe <- function(parameter, columns, columnTypes) {
name <- deparse(substitute(parameter))
# Check if 'parameter' is a dataframe
if (!is.data.frame(parameter)) {
ParallelLogger::logError(paste0(name, " should be a dataframe"))
stop(paste0(name, " is not a dataframe"))
}

# Check if all specified columns exist in the dataframe
if (!all(columns %in% names(parameter))) {
ParallelLogger::logError(paste0("Column names of ", name, " are not correct"))
stop(paste0("Column names of ", name, " are not correct"))
}

# Ensure 'columnTypes' is a list with the same length as 'columns'
if (length(columnTypes) != length(columns)) {
stop("The length of 'columnTypes' must be equal to the length of 'columns'")
}

# Extract the classes of the specified columns
colClasses <- sapply(parameter[columns], function(x) class(x)[1])

# Check each column's class against its acceptable types
typeCheck <- mapply(function(colClass, acceptableTypes) {
colClass %in% acceptableTypes
},
colClass = colClasses,
acceptableTypes = columnTypes)

# If any column doesn't match its acceptable types, throw an error
if (!all(typeCheck)) {
errorCols <- columns[!typeCheck]
expectedTypes <- columnTypes[!typeCheck]
actualTypes <- colClasses[!typeCheck]

# Construct detailed error messages for each problematic column
errorMessages <- mapply(function(col, expTypes, actType) {
paste0(
"Column '", col, "' should be of type(s): ", paste(expTypes, collapse = ", "),
" but is of type '", actType, "'."
)
},
col = errorCols,
expTypes = expectedTypes,
actType = actualTypes,
SIMPLIFY = FALSE)

# Log and stop with the error messages
ParallelLogger::logError(paste0("Column types of ", name, " are not correct"))
stop(paste0("Column types of ", name, " are not correct.\n",
paste(errorMessages, collapse = "\n")))
}
return(TRUE)
}
65 changes: 6 additions & 59 deletions R/SklearnClassifierSettings.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,7 @@ setAdaBoost <- function(nEstimators = list(10, 50, 200),
learningRate = list(1, 0.5, 0.1),
algorithm = list("SAMME"),
seed = sample(1000000, 1)) {
rlang::check_installed(
"reticulate",
reason = "Reticulate is required to use the Python models"
)
tryCatch({
reticulate::import("sklearn")
}, error = function(e) {
stop("scikit-learn in a python environment reachable by reticulate is required to use the Python models")
})

checkSklearn()
checkIsClass(seed[[1]], c("numeric", "integer"))
checkIsClass(nEstimators, "list")
checkIsClass(learningRate, "list")
Expand Down Expand Up @@ -138,19 +129,7 @@ setDecisionTree <- function(criterion = list("gini"),
minImpurityDecrease = list(10^-7),
classWeight = list(NULL),
seed = sample(1000000, 1)) {
rlang::check_installed(
"reticulate",
reason = "Reticulate is required to use the Python models"
)
tryCatch({
reticulate::import("sklearn")
}, error = function(e) {
stop("scikit-learn in a python environment reachable by reticulate is required to use the Python models")
})
if (!inherits(x = seed[[1]], what = c("numeric", "integer"))) {
stop("Invalid seed")
}

checkSklearn()
checkIsClass(criterion, "list")
checkIsClass(splitter, "list")
checkIsClass(maxDepth, "list")
Expand Down Expand Up @@ -412,15 +391,7 @@ setMLP <- function(hiddenLayerSizes = list(c(100), c(20)),
epsilon = list(0.00000001),
nIterNoChange = list(10),
seed = sample(100000, 1)) {
rlang::check_installed(
"reticulate",
reason = "Reticulate is required to use the Python models"
)
tryCatch({
reticulate::import("sklearn")
}, error = function(e) {
stop("scikit-learn in a python environment reachable by reticulate is required to use the Python models")
})
checkSklearn()
checkIsClass(seed, c("numeric", "integer"))
checkIsClass(hiddenLayerSizes, c("list"))
checkIsClass(activation, c("list"))
Expand Down Expand Up @@ -566,15 +537,7 @@ setNaiveBayes <- function() {
pythonModule = "sklearn.naive_bayes",
pythonClass = "GaussianNB"
)
rlang::check_installed(
"reticulate",
reason = "Reticulate is required to use the Python models"
)
tryCatch({
reticulate::import("sklearn")
}, error = function(e) {
stop("scikit-learn in a python environment reachable by reticulate is required to use the Python models")
})
checkSklearn()
attr(param, "saveToJson") <- TRUE
attr(param, "saveType") <- "file"

Expand Down Expand Up @@ -635,15 +598,7 @@ setRandomForest <- function(ntrees = list(100, 500),
nJobs = list(NULL),
classWeight = list(NULL),
seed = sample(100000, 1)) {
rlang::check_installed(
"reticulate",
reason = "Reticulate is required to use the Python models"
)
tryCatch({
reticulate::import("sklearn")
}, error = function(e) {
stop("scikit-learn in a python environment reachable by reticulate is required to use the Python models")
})
checkSklearn()
checkIsClass(seed, c("numeric", "integer"))
checkIsClass(ntrees, c("list"))
checkIsClass(criterion, c("list"))
Expand Down Expand Up @@ -797,15 +752,7 @@ setSVM <- function(C = list(1, 0.9, 2, 0.1),
classWeight = list(NULL),
cacheSize = 500,
seed = sample(100000, 1)) {
rlang::check_installed(
"reticulate",
reason = "Reticulate is required to use the Python models"
)
tryCatch({
reticulate::import("sklearn")
}, error = function(e) {
stop("Cannot import scikit-learn in python. scikit-learn in a python environment reachable by reticulate is required to use the Python models. Please check your python setup with reticulate::py_config() followed by reticulate::import('sklearn')")
})
checkSklearn()
checkIsClass(seed, c("numeric", "integer"))
checkIsClass(cacheSize, c("numeric", "integer"))
checkIsClass(C, c("list"))
Expand Down
Loading

0 comments on commit 6ec9e52

Please sign in to comment.