diff --git a/NAMESPACE b/NAMESPACE index c2f7f620..86c5b495 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -33,6 +33,7 @@ export(createPreprocessSettings) export(createRandomForestFeatureSelection) export(createRestrictPlpDataSettings) export(createSampleSettings) +export(createSklearnModel) export(createSplineSettings) export(createStratifiedImputationSettings) export(createStudyPopulation) diff --git a/R/ExistingSklearn.R b/R/ExistingSklearn.R new file mode 100644 index 00000000..456f8579 --- /dev/null +++ b/R/ExistingSklearn.R @@ -0,0 +1,119 @@ +# @file ExistingSklearn.R +# +# Copyright 2025 Observational Health Data Sciences and Informatics +# +# This file is part of PatientLevelPrediction +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +#' Plug an existing scikit learn python model into the +#' PLP framework +#' +#' @details +#' This function lets users add an existing scikit learn model that is saved as +#' model.pkl into PLP format. covariateMap is a mapping between standard +#' covariateIds and the model columns. The user also needs to specify the +#' covariate settings and population settings as these are used to determine +#' the standard PLP model design. +#' +#' @param modelLocation The location of the folder that contains the model as +#' model.pkl +#' @param covariateMap A data.frame with the columns: columnId and covariateId. +#' `covariateId` from FeatureExtraction is the standard OHDSI covariateId. 
+#' `columnId` is the column location the model expects that covariate to be in. +#' For example, if you had a column called 'age' in your model and this was the +#' 3rd column when fitting the model, then the values for columnId would be 3 and +#' covariateId would be 1002 (the covariateId for age in years). +#' @param covariateSettings The settings for the standardized covariates +#' @param populationSettings The settings for the population, this includes the +#' time-at-risk settings and inclusion criteria. +#' @param isPickle If the model should be saved as a pickle set this to TRUE if +#' it should be saved as json set this to FALSE. +#' +#' @return +#' An object of class plpModel, this is a list that contains: +#' model (the location of the model.pkl), +#' preprocessing (settings for mapping the covariateIds to the model +#' column names), +#' modelDesign (specification of the model design), +#' trainDetails (information about the model fitting) and +#' covariateImportance. +#' +#' You can use the output as an input in PatientLevelPrediction::predictPlp to +#' apply the model and calculate the risk for patients. 
+#' +#' @export +createSklearnModel <- function( + modelLocation = "/model", # model needs to be saved here as "model.pkl" + covariateMap = data.frame( + columnId = 1:2, + covariateId = c(1, 2) + ), + covariateSettings, # specify the covariates + populationSettings, # specify time at risk used to develop model + isPickle = TRUE) { + checkSklearn() + checkFileExists(modelLocation) + checkIsClass(covariateMap, "data.frame") + checkIsClass(covariateSettings, "covariateSettings") + checkIsClass(populationSettings, "populationSettings") + checkBoolean(isPickle) + checkDataframe(covariateMap, c("columnId", "covariateId"), + columnTypes = list(c("numeric", "integer"), c("numeric", "integer")) + ) + existingModel <- list(model = "existingSklearn") + class(existingModel) <- "modelSettings" + + plpModel <- list( + preprocessing = list( + tidyCovariates = NULL, + requireDenseMatrix = FALSE + ), + covariateImportance = data.frame( + columnId = covariateMap$columnId, + covariateId = covariateMap$covariateId, + included = TRUE + ), + modelDesign = PatientLevelPrediction::createModelDesign( + targetId = 1, + outcomeId = 2, + restrictPlpDataSettings = PatientLevelPrediction::createRestrictPlpDataSettings(), + covariateSettings = covariateSettings, + populationSettings = populationSettings, + sampleSettings = PatientLevelPrediction::createSampleSettings(), + preprocessSettings = PatientLevelPrediction::createPreprocessSettings( + minFraction = 0, + normalize = FALSE, + removeRedundancy = FALSE + ), + modelSettings = existingModel, + splitSettings = PatientLevelPrediction::createDefaultSplitSetting() + ), + model = modelLocation, + trainDetails = list( + analysisId = "existingSklearn", + developmentDatabase = "unknown", + developmentDatabaseId = "unknown", + trainingTime = -1, + modelName = "existingSklearn" + ) + ) + + attr(plpModel, "modelType") <- "binary" + attr(plpModel, "saveType") <- "file" + attr(plpModel, "predictionFunction") <- "predictPythonSklearn" + attr(plpModel, 
"saveToJson") <- !isPickle + class(plpModel) <- "plpModel" + return(plpModel) +} diff --git a/R/HelperFunctions.R b/R/HelperFunctions.R index 25917e2c..c9a19b6d 100644 --- a/R/HelperFunctions.R +++ b/R/HelperFunctions.R @@ -32,6 +32,18 @@ checkSurvivalPackages <- function() { ) } +checkSklearn <- function() { + rlang::check_installed( + "reticulate", + reason = "Reticulate is required to use the Python models" + ) + tryCatch({ + reticulate::import("sklearn") + }, error = function(e) { + stop("scikit-learn in a python environment reachable by reticulate is required to use the Python models") + }) +} + #' Create a temporary model location #' @return A string for the location of the temporary model location #' @export diff --git a/R/Logging.R b/R/Logging.R index 3f1c8f89..e0768bd5 100644 --- a/R/Logging.R +++ b/R/Logging.R @@ -64,7 +64,7 @@ createLog <- function( logName = "PLP Log", saveDirectory = getwd(), logFileName = paste0("plpLog", as.Date(Sys.Date(), "%Y%m%d"), ".txt")) { - checkFileExists(saveDirectory, createIfNot = TRUE) + createDir(saveDirectory) logFileName <- gsub("[[:punct:]]", "", logFileName) @@ -85,17 +85,6 @@ createLog <- function( return(logger) } -checkFileExists <- function( - saveDirectory, - createIfNot = TRUE) { - dirExists <- dir.exists(saveDirectory) - if (!dirExists && createIfNot) { - ParallelLogger::logInfo(paste0("Creating save directory at: ", saveDirectory)) - dir.create(saveDirectory, recursive = TRUE) - } - return(invisible(dirExists)) -} - closeLog <- function(logger) { # stop logger ParallelLogger::unregisterLogger(logger) diff --git a/R/ParamChecks.R b/R/ParamChecks.R index 34dc9924..2fb26dd0 100644 --- a/R/ParamChecks.R +++ b/R/ParamChecks.R @@ -108,3 +108,87 @@ checkIsEqual <- function(parameter, value) { } return(TRUE) } + +checkFileType <- function(parameter, fileType) { + name <- deparse(substitute(parameter)) + if (!grepl(fileType, parameter)) { + ParallelLogger::logError(paste0(name, " should be a ", fileType, " file")) + 
stop(paste0(name, " is not a ", fileType, " file")) + } + return(TRUE) +} + +createDir <- function( + saveDirectory, + createIfNot = TRUE) { + dirExists <- dir.exists(saveDirectory) + if (!dirExists && createIfNot) { + ParallelLogger::logInfo(paste0("Creating save directory at: ", saveDirectory)) + dir.create(saveDirectory, recursive = TRUE) + } + return(invisible(dirExists)) +} + +checkFileExists <- function(parameter) { + name <- deparse(substitute(parameter)) + if (!file.exists(parameter)) { + ParallelLogger::logError(paste0(name, " does not exist")) + stop(paste0(name, " does not exist")) + } + return(TRUE) +} + +checkDataframe <- function(parameter, columns, columnTypes) { + name <- deparse(substitute(parameter)) + # Check if 'parameter' is a dataframe + if (!is.data.frame(parameter)) { + ParallelLogger::logError(paste0(name, " should be a dataframe")) + stop(paste0(name, " is not a dataframe")) + } + + # Check if all specified columns exist in the dataframe + if (!all(columns %in% names(parameter))) { + ParallelLogger::logError(paste0("Column names of ", name, " are not correct")) + stop(paste0("Column names of ", name, " are not correct")) + } + + # Ensure 'columnTypes' is a list with the same length as 'columns' + if (length(columnTypes) != length(columns)) { + stop("The length of 'columnTypes' must be equal to the length of 'columns'") + } + + # Extract the classes of the specified columns + colClasses <- sapply(parameter[columns], function(x) class(x)[1]) + + # Check each column's class against its acceptable types + typeCheck <- mapply(function(colClass, acceptableTypes) { + colClass %in% acceptableTypes + }, + colClass = colClasses, + acceptableTypes = columnTypes) + + # If any column doesn't match its acceptable types, throw an error + if (!all(typeCheck)) { + errorCols <- columns[!typeCheck] + expectedTypes <- columnTypes[!typeCheck] + actualTypes <- colClasses[!typeCheck] + + # Construct detailed error messages for each problematic column + 
errorMessages <- mapply(function(col, expTypes, actType) { + paste0( + "Column '", col, "' should be of type(s): ", paste(expTypes, collapse = ", "), + " but is of type '", actType, "'." + ) + }, + col = errorCols, + expTypes = expectedTypes, + actType = actualTypes, + SIMPLIFY = FALSE) + + # Log and stop with the error messages + ParallelLogger::logError(paste0("Column types of ", name, " are not correct")) + stop(paste0("Column types of ", name, " are not correct.\n", + paste(errorMessages, collapse = "\n"))) + } + return(TRUE) +} diff --git a/R/SklearnClassifierSettings.R b/R/SklearnClassifierSettings.R index 6fb879f1..4822c8c3 100644 --- a/R/SklearnClassifierSettings.R +++ b/R/SklearnClassifierSettings.R @@ -29,16 +29,7 @@ setAdaBoost <- function(nEstimators = list(10, 50, 200), learningRate = list(1, 0.5, 0.1), algorithm = list("SAMME"), seed = sample(1000000, 1)) { - rlang::check_installed( - "reticulate", - reason = "Reticulate is required to use the Python models" - ) - tryCatch({ - reticulate::import("sklearn") - }, error = function(e) { - stop("scikit-learn in a python environment reachable by reticulate is required to use the Python models") - }) - + checkSklearn() checkIsClass(seed[[1]], c("numeric", "integer")) checkIsClass(nEstimators, "list") checkIsClass(learningRate, "list") @@ -138,19 +129,7 @@ setDecisionTree <- function(criterion = list("gini"), minImpurityDecrease = list(10^-7), classWeight = list(NULL), seed = sample(1000000, 1)) { - rlang::check_installed( - "reticulate", - reason = "Reticulate is required to use the Python models" - ) - tryCatch({ - reticulate::import("sklearn") - }, error = function(e) { - stop("scikit-learn in a python environment reachable by reticulate is required to use the Python models") - }) - if (!inherits(x = seed[[1]], what = c("numeric", "integer"))) { - stop("Invalid seed") - } - + checkSklearn() checkIsClass(criterion, "list") checkIsClass(splitter, "list") checkIsClass(maxDepth, "list") @@ -412,15 +391,7 @@ 
setMLP <- function(hiddenLayerSizes = list(c(100), c(20)), epsilon = list(0.00000001), nIterNoChange = list(10), seed = sample(100000, 1)) { - rlang::check_installed( - "reticulate", - reason = "Reticulate is required to use the Python models" - ) - tryCatch({ - reticulate::import("sklearn") - }, error = function(e) { - stop("scikit-learn in a python environment reachable by reticulate is required to use the Python models") - }) + checkSklearn() checkIsClass(seed, c("numeric", "integer")) checkIsClass(hiddenLayerSizes, c("list")) checkIsClass(activation, c("list")) @@ -566,15 +537,7 @@ setNaiveBayes <- function() { pythonModule = "sklearn.naive_bayes", pythonClass = "GaussianNB" ) - rlang::check_installed( - "reticulate", - reason = "Reticulate is required to use the Python models" - ) - tryCatch({ - reticulate::import("sklearn") - }, error = function(e) { - stop("scikit-learn in a python environment reachable by reticulate is required to use the Python models") - }) + checkSklearn() attr(param, "saveToJson") <- TRUE attr(param, "saveType") <- "file" @@ -635,15 +598,7 @@ setRandomForest <- function(ntrees = list(100, 500), nJobs = list(NULL), classWeight = list(NULL), seed = sample(100000, 1)) { - rlang::check_installed( - "reticulate", - reason = "Reticulate is required to use the Python models" - ) - tryCatch({ - reticulate::import("sklearn") - }, error = function(e) { - stop("scikit-learn in a python environment reachable by reticulate is required to use the Python models") - }) + checkSklearn() checkIsClass(seed, c("numeric", "integer")) checkIsClass(ntrees, c("list")) checkIsClass(criterion, c("list")) @@ -797,15 +752,7 @@ setSVM <- function(C = list(1, 0.9, 2, 0.1), classWeight = list(NULL), cacheSize = 500, seed = sample(100000, 1)) { - rlang::check_installed( - "reticulate", - reason = "Reticulate is required to use the Python models" - ) - tryCatch({ - reticulate::import("sklearn") - }, error = function(e) { - stop("Cannot import scikit-learn in python. 
scikit-learn in a python environment reachable by reticulate is required to use the Python models. Please check your python setup with reticulate::py_config() followed by reticulate::import('sklearn')") - }) + checkSklearn() checkIsClass(seed, c("numeric", "integer")) checkIsClass(cacheSize, c("numeric", "integer")) checkIsClass(C, c("list")) diff --git a/R/uploadToDatabase.R b/R/uploadToDatabase.R index 78ca525b..9aa798c2 100644 --- a/R/uploadToDatabase.R +++ b/R/uploadToDatabase.R @@ -173,7 +173,7 @@ insertResultsToSqlite <- function( #' @param testFile (used for testing) The location of an sql file with the table creation code #' #' @return -#' Returns NULL but creates the required tables into the specified database schema(s). +#' Returns NULL but creates or deletes the required tables in the specified database schema(s). #' #' @export createPlpResultTables <- function( @@ -252,14 +252,27 @@ createPlpResultTables <- function( } else { ParallelLogger::logInfo("PLP result tables already exist") } - - # then migrate - ParallelLogger::logInfo("PLP result migrration being applied") - migrateDataModel( - connectionDetails = connectionDetails, # input is connection - databaseSchema = resultSchema, - tablePrefix = tablePrefix - ) + + if (!(createTables == FALSE && deleteTables == TRUE)) { + # then migrate, unless only deleting + ParallelLogger::logInfo("PLP result migration being applied") + migrateDataModel( + connectionDetails = connectionDetails, # input is connection + databaseSchema = resultSchema, + tablePrefix = tablePrefix + ) + } else { + ParallelLogger::logInfo("Deleting PLP migration tables") + migrationTableNames <- c("MIGRATION", "PACKAGE_VERSION") + deleteTables( + conn = conn, + databaseSchema = resultSchema, + targetDialect = targetDialect, + tempEmulationSchema = tempEmulationSchema, + tableNames = migrationTableNames, + tablePrefix = tablePrefix + ) + } } #' Populate the PatientLevelPrediction results tables diff --git a/man/createSklearnModel.Rd 
b/man/createSklearnModel.Rd new file mode 100644 index 00000000..f66e8cca --- /dev/null +++ b/man/createSklearnModel.Rd @@ -0,0 +1,57 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/ExistingSklearn.R +\name{createSklearnModel} +\alias{createSklearnModel} +\title{Plug an existing scikit learn python model into the +PLP framework} +\usage{ +createSklearnModel( + modelLocation = "/model", + covariateMap = data.frame(columnId = 1:2, covariateId = c(1, 2)), + covariateSettings, + populationSettings, + isPickle = TRUE +) +} +\arguments{ +\item{modelLocation}{The location of the folder that contains the model as +model.pkl} + +\item{covariateMap}{A data.frame with the columns: columnId and covariateId. +`covariateId` from FeatureExtraction is the standard OHDSI covariateId. +`columnId` is the column location the model expects that covariate to be in. +For example, if you had a column called 'age' in your model and this was the +3rd column when fitting the model, then the values for columnId would be 3 and +covariateId would be 1002 (the covariateId for age in years).} + +\item{covariateSettings}{The settings for the standardized covariates} + +\item{populationSettings}{The settings for the population, this includes the +time-at-risk settings and inclusion criteria.} + +\item{isPickle}{If the model should be saved as a pickle set this to TRUE if +it should be saved as json set this to FALSE.} +} +\value{ +An object of class plpModel, this is a list that contains: + model (the location of the model.pkl), + preprocessing (settings for mapping the covariateIds to the model + column names), + modelDesign (specification of the model design), + trainDetails (information about the model fitting) and + covariateImportance. + +You can use the output as an input in PatientLevelPrediction::predictPlp to +apply the model and calculate the risk for patients. 
+} +\description{ +Plug an existing scikit learn python model into the +PLP framework +} +\details{ +This function lets users add an existing scikit learn model that is saved as +model.pkl into PLP format. covariateMap is a mapping between standard +covariateIds and the model columns. The user also needs to specify the +covariate settings and population settings as these are used to determine +the standard PLP model design. +} diff --git a/tests/testthat/test-UploadToDatabase.R b/tests/testthat/test-UploadToDatabase.R index 3f4f0399..1d9bfeac 100644 --- a/tests/testthat/test-UploadToDatabase.R +++ b/tests/testthat/test-UploadToDatabase.R @@ -183,6 +183,18 @@ test_that("database deletion", { databaseSchema = ohdsiDatabaseSchema, tableName = paste0(appendRandom("test"), "_PERFORMANCES") )) + + expect_false(DatabaseConnector::existsTable( + connection = conn, + databaseSchema = ohdsiDatabaseSchema, + tableName = paste0(appendRandom("test"), "_migration") + )) + + expect_false(DatabaseConnector::existsTable( + connection = conn, + databaseSchema = ohdsiDatabaseSchema, + tableName = paste0(appendRandom("test"), "_package_version") + )) }) # disconnect diff --git a/tests/testthat/test-existingModel.R b/tests/testthat/test-existingModel.R new file mode 100644 index 00000000..fea47023 --- /dev/null +++ b/tests/testthat/test-existingModel.R @@ -0,0 +1,148 @@ +test_that("Create existing sklearn works", { + expect_error(createSklearnModel("existing")) + # create a file model.pkl for testing + file.create("model.pkl") + covariateSettings <- + FeatureExtraction::createCovariateSettings(useDemographicsAge = TRUE) + populationSettings <- createStudyPopulationSettings() + # dataframe wrong type + expect_error(createSklearnModel( + modelLocation = "model.pkl", + covariateMap = list( + columnId = "columnId", + covariateId = c(1) + ), + covariateSettings = covariateSettings, + populationSettings = populationSettings + )) + # dataframe wrong column names + 
expect_error(createSklearnModel( + modelLocation = "model.pkl", + covariateMap = data.frame( + columnId = c(1), + notCovariateId = c(1002) + ), + covariateSettings = covariateSettings, + populationSettings = populationSettings + )) + # dataframe wrong column types + expect_error(createSklearnModel( + modelLocation = "model.pkl", + covariateMap = data.frame( + columnId = 1, + covariateId = "2" + ), + covariateSettings = covariateSettings, + populationSettings = populationSettings + )) + + model <- createSklearnModel( + modelLocation = "model.pkl", + covariateMap = data.frame( + columnId = c(1, 2), + covariateId = c(1002, 1003) + ), + covariateSettings = covariateSettings, + populationSettings = populationSettings + ) + expect_equal(attr(model, "modelType"), "binary") + expect_equal(attr(model, "saveType"), "file") + expect_equal(attr(model, "predictionFunction"), "predictPythonSklearn") + expect_equal(attr(model, "saveToJson"), FALSE) + expect_equal(class(model), "plpModel") + unlink("model.pkl") +}) + +test_that("existing sklearn model works", { + skip_if_not_installed("reticulate") + skip_on_cran() + # fit a simple sklearn model with plp + modelSettings <- setDecisionTree( + criterion = list("gini"), + splitter = list("best"), + maxDepth = list(as.integer(4)), + minSamplesSplit = list(2), + minSamplesLeaf = list(10), + minWeightFractionLeaf = list(0), + maxFeatures = list("sqrt"), + maxLeafNodes = list(NULL), + minImpurityDecrease = list(10^-7), + classWeight = list(NULL), + seed = sample(1000000, 1) + ) + + plpModel <- fitPlp( + trainData = tinyTrainData, + modelSettings = modelSettings, + analysisId = "DecisionTree", + analysisPath = tempdir() + ) + + # load model json and save as pickle with joblib + model <- sklearnFromJson(file.path(plpModel$model, "model.json")) + joblib <- reticulate::import("joblib") + joblib$dump(model, file.path(plpModel$model, "model.pkl")) + + # extract covariateMap from plpModel + covariateMap <- plpModel$covariateImportance %>% 
dplyr::select(columnId, covariateId) + + existingModel <- createSklearnModel( + modelLocation = file.path(plpModel$model), + covariateMap = covariateMap, + covariateSettings = plpModel$modelDesign$covariateSettings, + populationSettings = plpModel$modelDesign$populationSettings + ) + + prediction <- predictPlp(plpModel, testData, testData$labels) + predictionNew <- predictPlp(existingModel, testData, testData$labels) + + expect_correct_predictions(prediction, testData) + expect_equal(prediction$value, predictionNew$value) +}) + +test_that("Externally trained sklearn model works", { + skip_if_not_installed("reticulate") + skip_on_cran() + # change map to be some random order + covariateIds <- tinyTrainData$covariateData$covariates %>% + dplyr::pull(.data$covariateId) %>% + unique() + map <- data.frame( + columnId = sample(1:20, length(covariateIds)), + covariateId = sample(covariateIds, length(covariateIds)) + ) + matrixData <- toSparseM(tinyTrainData, map = map) + matrix <- matrixData$dataMatrix %>% + Matrix::as.matrix() + + # fit with sklearn + xMatrix <- reticulate::r_to_py(matrix) + y <- reticulate::r_to_py(tinyTrainData$labels$outcomeCount) + + sklearn <- reticulate::import("sklearn") + classifier <- sklearn$tree$DecisionTreeClassifier() + classifier <- classifier$fit(xMatrix, y) + + testMatrix <- toSparseM(testData, map = matrixData$covariateMap) + xTest <- reticulate::r_to_py(testMatrix$dataMatrix %>% Matrix::as.matrix()) + yTest <- reticulate::r_to_py(testData$labels$outcomeCount) + externalPredictions <- classifier$predict_proba(xTest)[, 2] + auc <- sklearn$metrics$roc_auc_score(yTest, externalPredictions) + + joblib <- reticulate::import("joblib") + path <- tempfile() + createDir(path) + joblib$dump(classifier, file.path(path, "model.pkl")) + plpModel <- createSklearnModel( + modelLocation = path, + covariateMap = matrixData$covariateMap, + covariateSettings = FeatureExtraction::createCovariateSettings( + useDemographicsAge = TRUE + ), + populationSettings = 
populationSettings + ) + prediction <- predictPlp(plpModel, testData, testData$labels) + + expect_equal(mean(prediction$value), mean(externalPredictions)) + expect_correct_predictions(prediction, testData) +}) diff --git a/tests/testthat/test-paramchecks.R b/tests/testthat/test-paramchecks.R index eae484f2..d0174fdd 100644 --- a/tests/testthat/test-paramchecks.R +++ b/tests/testthat/test-paramchecks.R @@ -17,9 +17,6 @@ library("testthat") context("ParamChecks") -# Test unit for the creation of the study externalValidatePlp - - test_that("checkBoolean", { testthat::expect_error(checkBoolean(1)) testthat::expect_error(checkBoolean("tertet")) @@ -84,3 +81,46 @@ test_that("checkInStringVector", { testthat::expect_equal(checkInStringVector("dsdsds", c("dsdsds", "double")), TRUE) }) + +test_that("createDir", { + dir <- tempfile() + createDir(dir) + testthat::expect_equal(file.exists(dir), TRUE) + unlink(dir) +}) + +test_that("checkFileExists", { + file <- tempfile() + testthat::expect_error(checkFileExists(file)) + file.create(file) + testthat::expect_equal(checkFileExists(file), TRUE) + unlink(file) +}) + +test_that("checkDataframe", { + expect_error(checkDataframe( + data.frame(a = 1:2, b = 1:2), c("a", "c"), + c("numeric", "numeric") + )) + expect_error(checkDataframe( + data.frame(a = 1:2, b = 1:2), + c("a", "b"), c("numeric", "character") + )) + expect_error(checkDataframe( + data.frame(a = 1:2, b = 1:2), + c("a", "b"), c("numeric", "numeric", "numeric") + )) + expect_true(checkDataframe( + data.frame(a = 1:2, b = 1:2), + c("a", "b"), c("integer", "integer") + )) + # allow both numeric and integer in a + expect_true(checkDataframe( + data.frame(a = as.numeric(1:2), b = 1:2), + c("a", "b"), list(c("numeric", "integer"), "integer") + )) + expect_true(checkDataframe( + data.frame(a = 1:2, b = 1:2), + c("a", "b"), list(c("numeric", "integer"), "integer") + )) +})