Skip to content

Commit

Permalink
fix(model-server): deadlock caused by non-existing lock ordering
Browse files Browse the repository at this point in the history
  • Loading branch information
slisson committed Dec 12, 2024
1 parent 9d36c70 commit 2670d49
Show file tree
Hide file tree
Showing 27 changed files with 1,036 additions and 392 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@ class HealthApiImpl(

private fun isHealthy(): Boolean {
val store = stores.getGlobalStoreClient()
val value = toLong(store[HEALTH_KEY]) + 1
store.put(HEALTH_KEY, java.lang.Long.toString(value))
return toLong(store[HEALTH_KEY]) >= value
return store.getTransactionManager().runWrite {
val value = toLong(store[HEALTH_KEY]) + 1
store.put(HEALTH_KEY, java.lang.Long.toString(value))
toLong(store[HEALTH_KEY]) >= value
}
}

private fun toLong(value: String?): Long {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import org.modelix.model.lazy.CLVersion
import org.modelix.model.lazy.IDeserializingKeyValueStore
import org.modelix.model.lazy.RepositoryId
import org.modelix.model.server.store.IStoreClient
import org.modelix.model.server.store.ITransactionManager
import org.modelix.model.server.store.StoreManager

interface IRepositoriesManager {
Expand All @@ -18,25 +19,23 @@ interface IRepositoriesManager {
* If the server ID was created previously but is only stored under a legacy database key,
* it also gets stored under the current and all legacy database keys.
*/
suspend fun maybeInitAndGetSeverId(): String
fun maybeInitAndGetSeverId(): String
fun getRepositories(): Set<RepositoryId>
suspend fun createRepository(repositoryId: RepositoryId, userName: String?, useRoleIds: Boolean = true, legacyGlobalStorage: Boolean = false): CLVersion
suspend fun removeRepository(repository: RepositoryId): Boolean
fun createRepository(repositoryId: RepositoryId, userName: String?, useRoleIds: Boolean = true, legacyGlobalStorage: Boolean = false): CLVersion
fun removeRepository(repository: RepositoryId): Boolean

fun getBranches(repositoryId: RepositoryId): Set<BranchReference>

suspend fun removeBranches(repository: RepositoryId, branchNames: Set<String>)

/**
* Same as [removeBranches] but blocking.
* Caller is expected to execute it outside the request thread.
*/
fun removeBranchesBlocking(repository: RepositoryId, branchNames: Set<String>)
suspend fun getVersion(branch: BranchReference): CLVersion?
suspend fun getVersion(repository: RepositoryId, versionHash: String): CLVersion?
suspend fun getVersionHash(branch: BranchReference): String?
fun removeBranches(repository: RepositoryId, branchNames: Set<String>)
fun getVersion(branch: BranchReference): CLVersion?
fun getVersion(repository: RepositoryId, versionHash: String): CLVersion?
fun getVersionHash(branch: BranchReference): String?
suspend fun pollVersionHash(branch: BranchReference, lastKnown: String?): String
suspend fun mergeChanges(branch: BranchReference, newVersionHash: String): String
fun mergeChanges(branch: BranchReference, newVersionHash: String): String

/**
* Same as [mergeChanges] but blocking.
Expand All @@ -54,14 +53,15 @@ interface IRepositoriesManager {
fun isIsolated(repository: RepositoryId): Boolean?

fun getStoreManager(): StoreManager
fun getTransactionManager(): ITransactionManager
}

fun IRepositoriesManager.getBranchNames(repositoryId: RepositoryId): Set<String> {
return getBranches(repositoryId).map { it.branchName }.toSet()
}

fun IRepositoriesManager.getStoreClient(repository: RepositoryId?): IStoreClient {
return getStoreManager().getStoreClient(repository?.takeIf { isIsolated(it) ?: false })
fun IRepositoriesManager.getStoreClient(repository: RepositoryId?, immutable: Boolean): IStoreClient {
return getStoreManager().getStoreClient(repository?.takeIf { isIsolated(it) ?: false }, immutable)
}

fun IRepositoriesManager.getAsyncStore(repository: RepositoryId?): IAsyncObjectStore {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import io.ktor.server.resources.put
import io.ktor.server.response.respondText
import io.ktor.server.routing.routing
import io.ktor.util.pipeline.PipelineContext
import kotlinx.coroutines.runBlocking
import kotlinx.html.br
import kotlinx.html.div
import kotlinx.html.h1
Expand All @@ -31,7 +30,8 @@ import org.modelix.model.server.ModelServerPermissionSchema
import org.modelix.model.server.store.ObjectInRepository
import org.modelix.model.server.store.StoreManager
import org.modelix.model.server.store.pollEntry
import org.modelix.model.server.store.runTransactionSuspendable
import org.modelix.model.server.store.runReadIO
import org.modelix.model.server.store.runWriteIO
import org.modelix.model.server.templates.PageWithMenuBar
import java.io.IOException
import java.util.*
Expand Down Expand Up @@ -61,7 +61,7 @@ class KeyValueLikeModelServer(
// request to initialize it lazily, would make the code less robust.
// Each change in the logic of RepositoriesManager#maybeInitAndGetSeverId would need
// the special conditions in the affected requests to be updated.
runBlocking { repositoriesManager.maybeInitAndGetSeverId() }
repositoriesManager.getTransactionManager().runWrite { repositoriesManager.maybeInitAndGetSeverId() }
application.apply {
modelServerModule()
}
Expand Down Expand Up @@ -89,7 +89,7 @@ class KeyValueLikeModelServer(
get<Paths.getKeyGet> {
val key = call.parameters["key"]!!
checkKeyPermission(key, EPermissionType.READ)
val value = stores.getGlobalKeyValueStore()[key]
val value = runRead { stores.getGlobalStoreClient()[key] }
respondValue(key, value)
}
get<Paths.pollKeyGet> {
Expand All @@ -106,21 +106,23 @@ class KeyValueLikeModelServer(
post<Paths.counterKeyPost> {
val key = call.parameters["key"]!!
checkKeyPermission(key, EPermissionType.WRITE)
val value = stores.getGlobalStoreClient().generateId(key)
val value = stores.getGlobalStoreClient(false).generateId(key)
call.respondText(text = value.toString())
}

get<Paths.getRecursivelyKeyGet> {
val key = call.parameters["key"]!!
checkKeyPermission(key, EPermissionType.READ)
call.respondText(collect(key, this).toString(2), contentType = ContentType.Application.Json)
call.respondText(runRead { collect(key, this) }.toString(2), contentType = ContentType.Application.Json)
}

put<Paths.putKeyPut> {
val key = call.parameters["key"]!!
val value = call.receiveText()
try {
putEntries(mapOf(key to value))
runWrite {
putEntries(mapOf(key to value))
}
call.respondText("OK")
} catch (e: NotFoundException) {
throw HttpException(HttpStatusCode.NotFound, title = "Not found", details = e.message, cause = e)
Expand All @@ -139,7 +141,9 @@ class KeyValueLikeModelServer(
}
entries = sortByDependency(entries)
try {
putEntries(entries)
runWrite {
putEntries(entries)
}
call.respondText(entries.size.toString() + " entries written")
} catch (e: NotFoundException) {
throw HttpException(HttpStatusCode.NotFound, title = "Not found", details = e.message, cause = e)
Expand All @@ -158,7 +162,7 @@ class KeyValueLikeModelServer(
checkKeyPermission(key, EPermissionType.READ)
keys.add(key)
}
val values = stores.getGlobalStoreClient().getAll(keys)
val values = runRead { stores.getGlobalStoreClient(false).getAll(keys) }
for (i in keys.indices) {
val respEntry = JSONObject()
respEntry.put("key", keys[i])
Expand Down Expand Up @@ -210,7 +214,7 @@ class KeyValueLikeModelServer(
if (callContext != null) {
keys.forEach { callContext.checkKeyPermission(it, EPermissionType.READ) }
}
val values = stores.getGlobalStoreClient().getAll(keys)
val values = stores.getGlobalStoreClient(false).getAll(keys)
for (i in keys.indices) {
val key = keys[i]
val value = values[i]
Expand Down Expand Up @@ -240,7 +244,7 @@ class KeyValueLikeModelServer(
return result
}

private suspend fun CallContext.putEntries(newEntries: Map<String, String?>) {
private fun CallContext.putEntries(newEntries: Map<String, String?>) {
val referencedKeys: MutableSet<String> = HashSet()
for ((key, value) in newEntries) {
checkKeyPermission(key, EPermissionType.WRITE)
Expand Down Expand Up @@ -300,17 +304,15 @@ class KeyValueLikeModelServer(
// We could try to move the objects later, but since this API is deprecated, it's not worth the effort.
}

stores.getGlobalStoreClient().runTransactionSuspendable {
stores.genericStore.putAll(hashedObjects.mapKeys { ObjectInRepository.global(it.key) })
stores.genericStore.putAll(userDefinedEntries.mapKeys { ObjectInRepository.global(it.key) })
for ((branch, value) in branchChanges) {
if (value == null) {
checkPermission(ModelServerPermissionSchema.branch(branch).delete)
repositoriesManager.removeBranchesBlocking(branch.repositoryId, setOf(branch.branchName))
} else {
checkPermission(ModelServerPermissionSchema.branch(branch).push)
repositoriesManager.mergeChangesBlocking(branch, value)
}
stores.genericStore.putAll(hashedObjects.mapKeys { ObjectInRepository.global(it.key) })
stores.genericStore.putAll(userDefinedEntries.mapKeys { ObjectInRepository.global(it.key) })
for ((branch, value) in branchChanges) {
if (value == null) {
checkPermission(ModelServerPermissionSchema.branch(branch).delete)
repositoriesManager.removeBranches(branch.repositoryId, setOf(branch.branchName))
} else {
checkPermission(ModelServerPermissionSchema.branch(branch).push)
repositoriesManager.mergeChangesBlocking(branch, value)
}
}
}
Expand Down Expand Up @@ -363,4 +365,12 @@ class KeyValueLikeModelServer(
else -> unknown()
}
}

private suspend fun <R> runRead(body: () -> R): R {
return repositoriesManager.getTransactionManager().runReadIO(body)
}

private suspend fun <R> runWrite(body: () -> R): R {
return repositoriesManager.getTransactionManager().runWriteIO(body)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ import org.modelix.model.server.api.v2.VersionDelta
import org.modelix.model.server.api.v2.VersionDeltaStream
import org.modelix.model.server.api.v2.VersionDeltaStreamV2
import org.modelix.model.server.store.StoreManager
import org.modelix.model.server.store.runReadIO
import org.modelix.model.server.store.runWriteIO
import org.modelix.modelql.core.IMemoizationPersistence
import org.modelix.modelql.core.IStepOutput
import org.modelix.modelql.core.MonoUnboundQuery
Expand Down Expand Up @@ -89,16 +91,15 @@ class ModelReplicationServer(

override suspend fun PipelineContext<Unit, ApplicationCall>.getRepositories() {
call.respondText(
repositoriesManager.getRepositories()
runRead { repositoriesManager.getRepositories() }
.filter { call.hasPermission(ModelServerPermissionSchema.repository(it).list) }
.joinToString("\n") { it.id },
)
}

override suspend fun PipelineContext<Unit, ApplicationCall>.getRepositoryBranches(repository: String) {
call.respondText(
repositoriesManager
.getBranchNames(repositoryId(repository))
runRead { repositoriesManager.getBranchNames(repositoryId(repository)) }
.filter { call.hasPermission(ModelServerPermissionSchema.repository(repository).branch(it).list) }
.joinToString("\n"),
)
Expand All @@ -111,7 +112,9 @@ class ModelReplicationServer(
) {
checkPermission(ModelServerPermissionSchema.repository(repository).branch(branch).pull)
val branchRef = repositoryId(repository).getBranchReference(branch)
val versionHash = repositoriesManager.getVersionHash(branchRef) ?: throw BranchNotFoundException(branchRef)
val versionHash = runRead {
repositoriesManager.getVersionHash(branchRef) ?: throw BranchNotFoundException(branchRef)
}
call.respondDelta(RepositoryId(repository), versionHash, lastKnown)
}

Expand All @@ -122,7 +125,7 @@ class ModelReplicationServer(
) {
checkPermission(ModelServerPermissionSchema.repository(repository).branch(branch).pull)
val branchRef = repositoryId(repository).getBranchReference(branch)
val versionHash = repositoriesManager.getVersionHash(branchRef) ?: throw BranchNotFoundException(branchRef)
val versionHash = runRead { repositoriesManager.getVersionHash(branchRef) ?: throw BranchNotFoundException(branchRef) }
call.respond(BranchV1(branch, versionHash))
}

Expand All @@ -138,11 +141,13 @@ class ModelReplicationServer(

checkPermission(ModelServerPermissionSchema.repository(repositoryId).branch(branch).delete)

if (!repositoriesManager.getBranchNames(repositoryId).contains(branch)) {
throw BranchNotFoundException(branch, repositoryId.id)
}
runWrite {
if (!repositoriesManager.getBranchNames(repositoryId).contains(branch)) {
throw BranchNotFoundException(branch, repositoryId.id)
}

repositoriesManager.removeBranches(repositoryId, setOf(branch))
repositoriesManager.removeBranches(repositoryId, setOf(branch))
}

call.respond(HttpStatusCode.NoContent)
}
Expand All @@ -153,7 +158,7 @@ class ModelReplicationServer(
) {
checkPermission(ModelServerPermissionSchema.repository(repository).branch(branch).pull)
val branchRef = repositoryId(repository).getBranchReference(branch)
val versionHash = repositoriesManager.getVersionHash(branchRef) ?: throw BranchNotFoundException(branchRef)
val versionHash = runRead { repositoriesManager.getVersionHash(branchRef) ?: throw BranchNotFoundException(branchRef) }
call.respondText(versionHash)
}

Expand All @@ -163,19 +168,23 @@ class ModelReplicationServer(
legacyGlobalStorage: Boolean?,
) {
checkPermission(ModelServerPermissionSchema.repository(repository).create)
val initialVersion = repositoriesManager.createRepository(
repositoryId(repository),
call.getUserName(),
useRoleIds ?: true,
legacyGlobalStorage ?: false,
)
val initialVersion = runWrite {
repositoriesManager.createRepository(
repositoryId(repository),
call.getUserName(),
useRoleIds ?: true,
legacyGlobalStorage ?: false,
)
}
call.respondDelta(RepositoryId(repository), initialVersion.getContentHash(), null)
}

override suspend fun PipelineContext<Unit, ApplicationCall>.deleteRepository(repository: String) {
checkPermission(ModelServerPermissionSchema.repository(repository).delete)

val foundAndDeleted = repositoriesManager.removeRepository(repositoryId(repository))
val foundAndDeleted = runWrite {
repositoriesManager.removeRepository(repositoryId(repository))
}
if (foundAndDeleted) {
call.respond(HttpStatusCode.NoContent)
} else {
Expand All @@ -191,8 +200,10 @@ class ModelReplicationServer(
val branchRef = repositoryId(repository).getBranchReference(branch)
val deltaFromClient = call.receive<VersionDelta>()
deltaFromClient.checkObjectHashes()
repositoriesManager.getStoreClient(RepositoryId(repository)).putAll(deltaFromClient.getAllObjects())
val mergedHash = repositoriesManager.mergeChanges(branchRef, deltaFromClient.versionHash)
repositoriesManager.getStoreClient(RepositoryId(repository), true).putAll(deltaFromClient.getAllObjects())
val mergedHash = runWrite {
repositoriesManager.mergeChanges(branchRef, deltaFromClient.versionHash)
}
call.respondDelta(RepositoryId(repository), mergedHash, deltaFromClient.versionHash)
}

Expand All @@ -217,7 +228,7 @@ class ModelReplicationServer(
}

val objects = withContext(Dispatchers.IO) {
repositoriesManager.getStoreClient(RepositoryId(repository)).getAll(keys)
repositoriesManager.getStoreClient(RepositoryId(repository), true).getAll(keys)
}

for (entry in objects) {
Expand Down Expand Up @@ -260,7 +271,7 @@ class ModelReplicationServer(
) {
val branchRef = repositoryId(repository).getBranchReference(branchName)
checkPermission(ModelServerPermissionSchema.branch(branchRef).query)
val version = repositoriesManager.getVersion(branchRef) ?: throw BranchNotFoundException(branchRef)
val version = runRead { repositoriesManager.getVersion(branchRef) ?: throw BranchNotFoundException(branchRef) }
LOG.trace("Running query on {} @ {}", branchRef, version)
val initialTree = version.getTree()

Expand Down Expand Up @@ -290,7 +301,9 @@ class ModelReplicationServer(
baseVersion = version,
operations = ops.map { it.getOriginalOp() }.toTypedArray(),
)
repositoriesManager.mergeChanges(branchRef, newVersion.getContentHash())
runWrite {
repositoriesManager.mergeChanges(branchRef, newVersion.getContentHash())
}
}
}
})
Expand Down Expand Up @@ -330,7 +343,7 @@ class ModelReplicationServer(
}

withContext(Dispatchers.IO) {
repositoriesManager.getStoreClient(RepositoryId(repository)).putAll(entries, true)
repositoriesManager.getStoreClient(RepositoryId(repository), true).putAll(entries, true)
}
call.respondText("${entries.size} objects received")
}
Expand All @@ -341,7 +354,7 @@ class ModelReplicationServer(
lastKnown: String?,
) {
checkPermission(ModelServerPermissionSchema.legacyGlobalObjects.read)
if (stores.getGlobalStoreClient()[versionHash] == null) {
if (runRead { stores.getGlobalStoreClient()[versionHash] } == null) {
throw VersionNotFoundException(versionHash)
}
call.respondDelta(null, versionHash, lastKnown)
Expand Down Expand Up @@ -402,6 +415,14 @@ class ModelReplicationServer(
}
}
}

private suspend fun <R> runRead(body: () -> R): R {
return repositoriesManager.getTransactionManager().runReadIO(body)
}

private suspend fun <R> runWrite(body: () -> R): R {
return repositoriesManager.getTransactionManager().runWriteIO(body)
}
}

/**
Expand Down
Loading

0 comments on commit 2670d49

Please sign in to comment.