Skip to content

Commit

Permalink
progress on adding boolean support
Browse files Browse the repository at this point in the history
  • Loading branch information
ablack3 committed Aug 21, 2024
1 parent b327541 commit f1f7a49
Show file tree
Hide file tree
Showing 9 changed files with 80 additions and 219 deletions.
21 changes: 4 additions & 17 deletions R/InsertTable.R
Original file line number Diff line number Diff line change
Expand Up @@ -251,15 +251,13 @@ insertTable.default <- function(connection,
data <- as.data.frame(data)
}
}
data <- convertLogicalFields(data)
isSqlReservedWord(c(tableName, colnames(data)), warn = TRUE)
useBulkLoad <- (bulkLoad && dbms %in% c("hive", "redshift") && createTable) ||
(bulkLoad && dbms %in% c("pdw", "postgresql") && !tempTable)
useCtasHack <- dbms %in% c("pdw", "redshift", "bigquery", "hive") && createTable && nrow(data) > 0 && !useBulkLoad
if (dbms == "bigquery" && useCtasHack && is.null(tempEmulationSchema)) {
abort("tempEmulationSchema is required to use insertTable with bigquery when inserting into a new table")
}

sqlDataTypes <- sapply(data, getSqlDataTypes)
sqlTableDefinition <- paste(.sql.qescape(names(data), TRUE), sqlDataTypes, collapse = ", ")
sqlTableName <- .sql.qescape(tableName, TRUE, quote = "")
Expand All @@ -277,7 +275,7 @@ insertTable.default <- function(connection,
)
}

if (createTable && !useCtasHack && !(bulkLoad && dbms == "hive")) {
if (createTable && !useCtasHack) {
sql <- paste("CREATE TABLE ", sqlTableName, " (", sqlTableDefinition, ");", sep = "")
renderTranslateExecuteSql(
connection = connection,
Expand Down Expand Up @@ -358,6 +356,9 @@ insertTable.default <- function(connection,
} else if (is(column, "Date")) {
rJava::.jcall(batchedInsert, "V", "setDate", i, as.character(column))
} else if (is.logical(column)) {
# encode column as -1 (NA), 1 (TRUE), 0 (FALSE) to pass logical NAs into Java
column <- vapply(as.integer(column), FUN = function(x) ifelse(is.na(x), -1L, x), FUN.VALUE = integer(1L))
print(class(column))
rJava::.jcall(batchedInsert, "V", "setBoolean", i, column)
} else {
rJava::.jcall(batchedInsert, "V", "setString", i, as.character(column))
Expand Down Expand Up @@ -429,7 +430,6 @@ insertTable.DatabaseConnectorDbiConnection <- function(connection,
}

}
data <- convertLogicalFields(data)

logTrace(sprintf("Inserting %d rows into table '%s' ", nrow(data), tableName))
if (!is.null(databaseSchema)) {
Expand Down Expand Up @@ -457,16 +457,3 @@ insertTable.DatabaseConnectorDbiConnection <- function(connection,
inform(paste("Inserting data took", signif(delta, 3), attr(delta, "units")))
invisible(NULL)
}

# Return `data` unchanged, leaving logical columns as logical.
#
# This helper used to warn about each logical column and coerce it to integer
# (1 = TRUE, 0 = FALSE). That coercion is disabled, so the function is now a
# pass-through kept only so existing call sites keep working. It must not
# print anything: it runs on every insert and library code should not write
# debug output to the console.
#
# Args:
#   data: A data frame about to be inserted.
#
# Returns:
#   `data`, unmodified.
convertLogicalFields <- function(data) {
  data
}
4 changes: 4 additions & 0 deletions R/Sql.R
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,9 @@ parseJdbcColumnData <- function(batchedQuery,
} else if (columnTypes[i] == 4) {
column <- rJava::.jcall(batchedQuery, "[D", "getNumeric", as.integer(i))
column <- as.POSIXct(column, origin = "1970-01-01")
} else if (columnTypes[i] == 7) {
column <- rJava::.jcall(batchedQuery, "[I", "getBoolean", as.integer(i))
column <- vapply(column, FUN = function(x) ifelse(x == -1L, NA, as.logical(x)), FUN.VALUE = logical(1))
} else {
column <- rJava::.jcall(batchedQuery, "[Ljava/lang/String;", "getString", i)
if (!datesAsString) {
Expand All @@ -131,6 +134,7 @@ parseJdbcColumnData <- function(batchedQuery,
columns[[i]] <- column
}
names(columns) <- rJava::.jcall(batchedQuery, "[Ljava/lang/String;", "getColumnNames")

# More efficient than as.data.frame, as it avoids converting row.names to character:
columns <- structure(columns, class = "data.frame", row.names = seq_len(length(columns[[1]])))
return(columns)
Expand Down
2 changes: 1 addition & 1 deletion inst/csv/jarChecksum.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
41e7d71b717d9b51cbeca5dd10f85cd9a882163f595ce52cb56048f08ff0aef1
28a381e03c756f10e3d48bf2010b67ed27870b4835e714b8215198abbac0b84e
Binary file modified inst/java/DatabaseConnector.jar
Binary file not shown.
25 changes: 17 additions & 8 deletions java/org/ohdsi/databaseConnector/BatchedInsert.java
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ private void checkColumns() {
if (((int[]) columns[i]).length != rowCount)
throw new RuntimeException("Column " + (i + 1) + " data not of correct length");
} else if (columnTypes[i] == BOOLEAN) {
if (((Boolean[]) columns[i]).length != rowCount)
if (((int[]) columns[i]).length != rowCount)
throw new RuntimeException("Column " + (i + 1) + " data not of correct length");
} else if (columnTypes[i] == NUMERIC) {
if (((double[]) columns[i]).length != rowCount)
Expand All @@ -86,11 +86,16 @@ private void setValue(PreparedStatement statement, int statementIndex, int rowIn
else
statement.setInt(statementIndex, value);
} else if (columnTypes[columnIndex] == BOOLEAN) {
Boolean value = ((Boolean[]) columns[columnIndex])[rowIndex];
if (value == null)
int value = ((int[]) columns[columnIndex])[rowIndex];
if (value == -1) {
statement.setObject(statementIndex, null);
else
statement.setBoolean(statementIndex, value);
} else if (value == 1) {
statement.setBoolean(statementIndex, true);
} else if (value == 0) {
statement.setBoolean(statementIndex, false);
} else {
throw new RuntimeException("Boolean values must be encoded as 1 (true) 0 (false) or -1 (NA) and not " + value);
}
} else if (columnTypes[columnIndex] == NUMERIC) {
double value = ((double[]) columns[columnIndex])[rowIndex];
if (Double.isNaN(value))
Expand Down Expand Up @@ -222,7 +227,11 @@ public void setInteger(int columnIndex, int[] column) {
rowCount = column.length;
}

public void setBoolean(int columnIndex, Boolean[] column) {
public void setBoolean(int columnIndex, int[] column) {
// represent boolean as int 1 for true, 0 for false, -1 for NA
// should we use byte type instead of integer? I also tried the Boolean wrapper class but
// could not get rJava to pass the boolean type to java as Boolean[]
// seems better to pass int type to and from R
columns[columnIndex - 1] = column;
columnTypes[columnIndex - 1] = BOOLEAN;
rowCount = column.length;
Expand Down Expand Up @@ -262,8 +271,8 @@ public void setInteger(int columnIndex, int column) {
setInteger(columnIndex, new int[] { column });
}

public void setBoolean(int columnIndex, Boolean column) {
setBoolean(columnIndex, new Boolean[] { column });
public void setBoolean(int columnIndex, int column) {
setBoolean(columnIndex, new int[] { column });
}

public void setNumeric(int columnIndex, double column) {
Expand Down
57 changes: 54 additions & 3 deletions java/org/ohdsi/databaseConnector/BatchedQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@ public class BatchedQuery {
public static int DATETIME = 4;
public static int INTEGER64 = 5;
public static int INTEGER = 6;
public static int BOOLEAN = 7;
public static int FETCH_SIZE = 2048;
public static double MAX_BATCH_SIZE = 1000000;
public static long CHECK_MEM_ROWS = 10000;
private static String SPARK = "spark";
public static double NA_DOUBLE = Double.longBitsToDouble(0x7ff00000000007a2L);
public static int NA_INTEGER = Integer.MIN_VALUE;
public static long NA_LONG = Long.MIN_VALUE;
public static final Boolean NA_BOOLEAN = null;

private Object[] columns;
private int[] columnTypes;
Expand Down Expand Up @@ -82,6 +84,8 @@ else if (columnTypes[columnIndex] == DATE)
bytesPerRow += 4;
else if (columnTypes[columnIndex] == DATETIME)
bytesPerRow += 8;
else if (columnTypes[columnIndex] == BOOLEAN)
bytesPerRow += 8; // not sure if this is correct
else // String
bytesPerRow += 512;
batchSize = (int) Math.min(MAX_BATCH_SIZE, Math.round((availableMemoryAtStart / 10d) / (double) bytesPerRow));
Expand All @@ -100,6 +104,8 @@ else if (columnTypes[columnIndex] == DATE)
columns[columnIndex] = new int[batchSize];
else if (columnTypes[columnIndex] == DATETIME)
columns[columnIndex] = new double[batchSize];
else if (columnTypes[columnIndex] == BOOLEAN)
columns[columnIndex] = new Boolean[batchSize];
else
columns[columnIndex] = new String[batchSize];
byteBuffer = ByteBuffer.allocate(8 * batchSize);
Expand Down Expand Up @@ -134,12 +140,22 @@ public BatchedQuery(Connection connection, String query, String dbms) throws SQL
resultSet = statement.executeQuery(query);
resultSet.setFetchSize(FETCH_SIZE);
ResultSetMetaData metaData = resultSet.getMetaData();

columnTypes = new int[metaData.getColumnCount()];
columnSqlTypes = new String[metaData.getColumnCount()];
for (int columnIndex = 0; columnIndex < metaData.getColumnCount(); columnIndex++) {
columnSqlTypes[columnIndex] = metaData.getColumnTypeName(columnIndex + 1);
int type = metaData.getColumnType(columnIndex + 1);
String className = metaData.getColumnClassName(columnIndex + 1);

//System.out.println("======================== debug ====================");
//System.out.println("type= " + type);
//System.out.println("className= " + className);
//System.out.println("columnSqlTypes[columnIndex]= " + columnSqlTypes[columnIndex]);
//System.out.println("Types.BOOLEAN=" + Types.BOOLEAN);


//Types.BOOLEAN is 16 but for a boolean datatype in the database type is -7.
int precision = metaData.getPrecision(columnIndex + 1);
int scale = metaData.getScale(columnIndex + 1);
if (type == Types.INTEGER || type == Types.SMALLINT || type == Types.TINYINT
Expand All @@ -154,6 +170,10 @@ else if (type == Types.DATE)
columnTypes[columnIndex] = DATE;
else if (type == Types.TIMESTAMP)
columnTypes[columnIndex] = DATETIME;
else if (type == Types.BOOLEAN || className.equals("java.lang.Boolean") || columnSqlTypes[columnIndex] == "bool") {
System.out.println("Setting boolean type.");
columnTypes[columnIndex] = BOOLEAN;
}
else
columnTypes[columnIndex] = STRING;
}
Expand Down Expand Up @@ -183,14 +203,18 @@ public void fetchBatch() throws SQLException {
((int[]) columns[columnIndex])[rowCount] = resultSet.getInt(columnIndex + 1);
if (resultSet.wasNull())
((int[]) columns[columnIndex])[rowCount] = NA_INTEGER;
} else if (columnTypes[columnIndex] == STRING)
} else if (columnTypes[columnIndex] == STRING) {
((String[]) columns[columnIndex])[rowCount] = resultSet.getString(columnIndex + 1);
else if (columnTypes[columnIndex] == DATE) {
} else if (columnTypes[columnIndex] == DATE) {
Date date = resultSet.getDate(columnIndex + 1);
if (date == null)
((int[]) columns[columnIndex])[rowCount] = NA_INTEGER;
else
((int[]) columns[columnIndex])[rowCount] = (int)date.toLocalDate().toEpochDay();
} else if (columnTypes[columnIndex] == BOOLEAN) {
((Boolean[]) columns[columnIndex])[rowCount] = resultSet.getBoolean(columnIndex + 1);
if (resultSet.wasNull())
((Boolean[]) columns[columnIndex])[rowCount] = NA_BOOLEAN;
} else {
Timestamp timestamp = resultSet.getTimestamp(columnIndex + 1);
if (timestamp == null)
Expand Down Expand Up @@ -246,7 +270,7 @@ public String[] getString(int columnIndex) {
} else
return column;
}
public int[] getInteger(int columnIndex) {
int[] column = ((int[]) columns[columnIndex - 1]);
if (column.length > rowCount) {
Expand All @@ -257,6 +281,33 @@ public int[] getInteger(int columnIndex) {
return column;
}

/**
 * Encodes an array of boxed booleans as ints.
 *
 * @param booleanArray values to encode; individual elements may be null
 * @return a new array of the same length holding -1 for null, 1 for true
 *         and 0 for false
 */
private int[] mapBooleanToInt(Boolean[] booleanArray) {
    int[] encoded = new int[booleanArray.length];
    int position = 0;
    for (Boolean flag : booleanArray) {
        // The null check comes first so the Boolean is never unboxed when absent.
        encoded[position++] = (flag == null) ? -1 : (flag ? 1 : 0);
    }
    return encoded;
}
/**
 * Returns the boolean column at the given 1-based index, encoded as ints by
 * {@code mapBooleanToInt}. Ints are handed to R because they are easier to
 * pass than boolean types.
 *
 * When the stored column is longer than {@code rowCount}, only its first
 * {@code rowCount} entries are encoded and returned.
 *
 * @param columnIndex 1-based index of the column to read
 * @return the int-encoded column values
 */
public int[] getBoolean(int columnIndex) {
    Boolean[] stored = (Boolean[]) columns[columnIndex - 1];

    if (stored.length <= rowCount) {
        return mapBooleanToInt(stored);
    }

    // The buffer is larger than the number of fetched rows: trim before encoding.
    Boolean[] trimmed = new Boolean[rowCount];
    System.arraycopy(stored, 0, trimmed, 0, rowCount);
    return mapBooleanToInt(trimmed);
}

public double[] getInteger64(int columnIndex) {
long[] column = ((long[]) columns[columnIndex - 1]);
if (column.length > rowCount) {
Expand Down
23 changes: 0 additions & 23 deletions java/org/ohdsi/databaseConnector/DebugRtoSqlTranslation.java

This file was deleted.

37 changes: 0 additions & 37 deletions java/org/ohdsi/databaseConnector/RFunctionToTranslate.java

This file was deleted.

Loading

0 comments on commit f1f7a49

Please sign in to comment.