Skip to content

Commit

Permalink
fix: EXPOSED-719 H2 upsert operation converts arrays to string
Browse files Browse the repository at this point in the history
  • Loading branch information
obabichevjb committed Feb 5, 2025
1 parent 200bc77 commit eb32ccc
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -372,9 +372,13 @@ open class ColumnWithTransform<Unwrapped, Wrapped>(
override fun setParameter(stmt: PreparedStatementApi, index: Int, value: Any?) {
return delegate.setParameter(stmt, index, value)
}

override fun parameterMarker(value: Wrapped?): String {
return delegate.parameterMarker(value?.let { transformer.unwrap(it) })
}
}

internal fun <T : Expression<*>>unwrapColumnValues(values: Map<T, Any?>): Map<T, Any?> = values.mapValues { (col, value) ->
internal fun <T : Expression<*>> unwrapColumnValues(values: Map<T, Any?>): Map<T, Any?> = values.mapValues { (col, value) ->
if (col !is ExpressionWithColumnType<*>) return@mapValues value

value?.let { (col.columnType as? ColumnWithTransform<Any, Any>)?.unwrapRecursive(it) } ?: value
Expand Down Expand Up @@ -1352,6 +1356,28 @@ class ArrayColumnType<T, R : List<Any?>>(
else -> "ARRAY"
}
}

private fun castH2ParameterMarker(columnType: IColumnType<*>): String? {
return when (columnType) {
// Here is the list of types that could be resolved by `resolveColumnType()`.
// In the common case it must not work for all the possible types. It also does not work with BigDecimal.
// This cast is needed for array types inside upsert(merge statement), otherwise statement causes "Data conversion error converting" error.
is ByteColumnType, is UByteColumnType, is BooleanColumnType, is ShortColumnType, is UShortColumnType,
is IntegerColumnType, is UIntegerColumnType, is LongColumnType, is ULongColumnType, is FloatColumnType,
is DoubleColumnType, is StringColumnType, is CharacterColumnType, is BasicBinaryColumnType, is UUIDColumnType ->
"cast(? as ${columnType.sqlType()} array)"
else -> null
}
}

override fun parameterMarker(value: R?): String {
if (currentDialect is H2Dialect) {
val columnType = if (delegate is ColumnWithTransform<*, *>) delegate.originalColumnType else delegate
return castH2ParameterMarker(columnType) ?: super.parameterMarker(value)
}

return super.parameterMarker(value)
}
}

private fun isArrayOfByteArrays(value: Array<*>) =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@ import org.jetbrains.exposed.sql.statements.UpdateStatement
import org.jetbrains.exposed.sql.statements.UpsertBuilder
import org.jetbrains.exposed.sql.tests.DatabaseTestsBase
import org.jetbrains.exposed.sql.tests.TestDB
import org.jetbrains.exposed.sql.tests.shared.assertEqualLists
import org.jetbrains.exposed.sql.tests.shared.assertEquals
import org.jetbrains.exposed.sql.tests.shared.expectException
import org.jetbrains.exposed.sql.transactions.transaction
import org.junit.Test
import java.lang.Integer.parseInt
import java.util.*
import kotlin.properties.Delegates
import kotlin.test.assertNotEquals
Expand Down Expand Up @@ -811,4 +813,69 @@ class UpsertTests : DatabaseTestsBase() {
}
}
}

@Test
fun testUpsertWithArrayValue() {
val tester = object : IntIdTable("tester") {
val key = text("key").uniqueIndex()
val stringArray = array<String>("stringArray")
val integerArray = array<Int>("integerArray")
val booleanArray = array<Boolean>("booleanArray")
val transformedArray = array<Int>("transformedArray")
.transform(
wrap = { array -> array.map { it.toString() } },
unwrap = { array -> array.map { parseInt(it) } }
)
val byteArray = array<Byte>("byteArray")
val uByteArray = array<UByte>("uByteArray")
val shortArray = array<Short>("shortArray")
val uShortArray = array<UShort>("uShortArray")
val uIntArray = array<UInt>("uIntArray")
val longArray = array<Long>("longArray")
val uLongArray = array<ULong>("uLongArray")
val floatArray = array<Float>("floatArray")
val doubleArray = array<Double>("doubleArray")
val charArray = array<Char>("charArray")
val uuidArray = array<UUID>("uuidArray")
}

withTables(excludeSettings = TestDB.ALL - TestDB.H2_V2, tester) {
val uuidList = listOf(UUID.randomUUID(), UUID.randomUUID())
tester.upsert(tester.key) {
it[tester.key] = "key1"
it[tester.stringArray] = listOf("a", "b", "c")
it[tester.integerArray] = listOf(1, 2, 3)
it[tester.booleanArray] = listOf(true, true, false)
it[tester.transformedArray] = listOf("4", "5", "6")
it[tester.byteArray] = listOf(1.toByte(), 2.toByte())
it[tester.uByteArray] = listOf(1.toUByte(), 2.toUByte())
it[tester.shortArray] = listOf(1.toShort(), 2.toShort())
it[tester.uShortArray] = listOf(1.toUShort(), 2.toUShort())
it[tester.uIntArray] = listOf(1u, 2u)
it[tester.longArray] = listOf(1L, 2L)
it[tester.uLongArray] = listOf(1UL, 2UL)
it[tester.floatArray] = listOf(1.1f, 2.2f)
it[tester.doubleArray] = listOf(1.1, 2.2)
it[tester.charArray] = listOf('a', 'b')
it[tester.uuidArray] = uuidList
}

val value = tester.selectAll().first()
assertEqualLists(listOf("a", "b", "c"), value[tester.stringArray])
assertEqualLists(listOf(1, 2, 3), value[tester.integerArray])
assertEqualLists(listOf(true, true, false), value[tester.booleanArray])
assertEqualLists(listOf("4", "5", "6"), value[tester.transformedArray])
assertEqualLists(listOf(1.toByte(), 2.toByte()), value[tester.byteArray])
assertEqualLists(listOf(1.toUByte(), 2.toUByte()), value[tester.uByteArray])
assertEqualLists(listOf(1.toShort(), 2.toShort()), value[tester.shortArray])
assertEqualLists(listOf(1.toUShort(), 2.toUShort()), value[tester.uShortArray])
assertEqualLists(listOf(1u, 2u), value[tester.uIntArray])
assertEqualLists(listOf(1L, 2L), value[tester.longArray])
assertEqualLists(listOf(1UL, 2UL), value[tester.uLongArray])
assertEqualLists(listOf(1.1f, 2.2f), value[tester.floatArray])
assertEqualLists(listOf(1.1, 2.2), value[tester.doubleArray])
assertEqualLists(listOf('a', 'b'), value[tester.charArray])
assertEqualLists(uuidList, value[tester.uuidArray])
}
}
}

0 comments on commit eb32ccc

Please sign in to comment.