Skip to content

Commit

Permalink
Introduce SegmentQueueSynchronizer abstraction for synchronization …
Browse files Browse the repository at this point in the history
…primitives and `ReadWriteMutex`

Signed-off-by: Nikita Koval <[email protected]>
  • Loading branch information
ndkoval committed Feb 13, 2023
1 parent 2f8744c commit 43b6be5
Show file tree
Hide file tree
Showing 6 changed files with 250 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package kotlinx.coroutines.internal

import kotlinx.atomicfu.*
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.*
import kotlinx.coroutines.internal.SegmentQueueSynchronizer.*
import kotlinx.coroutines.internal.SegmentQueueSynchronizer.CancellationMode.*
import kotlinx.coroutines.internal.SegmentQueueSynchronizer.ResumeMode.*
Expand Down Expand Up @@ -208,6 +209,42 @@ internal abstract class SegmentQueueSynchronizer<T : Any> {
returnValue(value)
}

internal fun suspendCancelled(): T? {
// Increment `suspendIdx` and find the segment
// with the corresponding id. It is guaranteed
// that this segment is not removed since at
// least the cell for this `suspend` invocation
// is not in the `CANCELLED` state.
val curSuspendSegm = this.suspendSegment.value
val suspendIdx = suspendIdx.getAndIncrement()
val segment = this.suspendSegment.findSegmentAndMoveForward(id = suspendIdx / SEGMENT_SIZE, startFrom = curSuspendSegm,
createNewSegment = ::createSegment).segment
assert { segment.id == suspendIdx / SEGMENT_SIZE }
// Try to install the waiter into the cell - this is the regular path.
val i = (suspendIdx % SEGMENT_SIZE).toInt()
if (segment.cas(i, null, CANCELLED)) {
// The continuation is successfully installed, and
// `resume` cannot break the cell now, so this
// suspension is successful.
// Add a cancellation handler if required and finish.
return null
}
// The continuation installation has failed. This happened because a concurrent
// `resume` came earlier to this cell and put its value into it. Remember that
// in the `SYNC` resumption mode this concurrent `resume` can mark the cell as broken.
//
// Try to grab the value if the cell is not in the `BROKEN` state.
val value = segment.get(i)
if (value !== BROKEN && segment.cas(i, value, TAKEN)) {
// The elimination is performed successfully,
// complete with the value stored in the cell.
return value as T
}
// The cell is broken, this can happen only in the `SYNC` resumption mode.
assert { resumeMode == SYNC && segment.get(i) === BROKEN }
return null
}

@Suppress("UNCHECKED_CAST")
internal fun suspend(waiter: Waiter): Boolean {
// Increment `suspendIdx` and find the segment
Expand Down Expand Up @@ -359,17 +396,30 @@ internal abstract class SegmentQueueSynchronizer<T : Any> {
return TRY_RESUME_SUCCESS
}
// Does the cell store a cancellable continuation?
cellState is CancellableContinuation<*> -> {
cellState is Waiter -> {
// Change the cell state to `RESUMED`, so
// the cancellation handler cannot be invoked
// even if the continuation becomes cancelled.
if (!segment.cas(i, cellState, RESUMED)) continue@modify_cell
// Try to resume the continuation.
val token = (cellState as CancellableContinuation<T>).tryResume(value, null, { returnValue(value) })
if (token != null) {
// Hooray, the continuation is successfully resumed!
cellState.completeResume(token)
} else {
val resumed = when(cellState) {
is CancellableContinuation<*> -> {
(cellState as CancellableContinuation<T>)
val token = cellState.tryResume(value, null, { returnValue(value) })
if (token != null) {
// Hooray, the continuation is successfully resumed!
cellState.completeResume(token)
true
} else {
false
}
}
is SelectInstance<*> -> {
cellState.trySelect(this@SegmentQueueSynchronizer, value)
}
else -> error("unexpected")
}
if (!resumed) {
// Unfortunately, the continuation resumption has failed.
// Fail the current `resume` if the simple cancellation mode is used.
if (cancellationMode === SIMPLE)
Expand Down Expand Up @@ -552,7 +602,7 @@ internal abstract class SegmentQueueSynchronizer<T : Any> {
val cellState = get(index)
when {
cellState === RESUMED -> return false
cellState is CancellableContinuation<*> -> {
cellState is Waiter -> {
if (cas(index, cellState, CANCELLING)) return true
}
else -> {
Expand Down
151 changes: 124 additions & 27 deletions kotlinx-coroutines-core/common/src/sync/Mutex.kt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import kotlinx.coroutines.*
import kotlinx.coroutines.internal.*
import kotlinx.coroutines.selects.*
import kotlin.contracts.*
import kotlin.coroutines.*
import kotlin.jvm.*

/**
Expand Down Expand Up @@ -131,7 +132,7 @@ public suspend inline fun <T> Mutex.withLock(owner: Any? = null, action: () -> T
}


internal open class MutexImpl(locked: Boolean) : SemaphoreImpl(1, if (locked) 1 else 0), Mutex {
internal open class MutexImpl(locked: Boolean) : SegmentQueueSynchronizer<Unit>(), Mutex {
/**
* After the lock is acquired, the corresponding owner is stored in this field.
* The [unlock] operation checks the owner and either re-sets it to [NO_OWNER],
Expand All @@ -140,13 +141,15 @@ internal open class MutexImpl(locked: Boolean) : SemaphoreImpl(1, if (locked) 1
*/
private val owner = atomic<Any?>(if (locked) null else NO_OWNER)

private val availablePermits = atomic(if (locked) 0 else 1)

private val onSelectCancellationUnlockConstructor: OnCancellationConstructor =
{ _: SelectInstance<*>, owner: Any?, _: Any? ->
{ unlock(owner) }
}

override val isLocked: Boolean get() =
availablePermits == 0
availablePermits.value <= 0

override fun holdsLock(owner: Any): Boolean {
while (true) {
Expand All @@ -161,13 +164,84 @@ internal open class MutexImpl(locked: Boolean) : SemaphoreImpl(1, if (locked) 1
}

override suspend fun lock(owner: Any?) {
if (tryLock(owner)) return
// if (tryLock(owner)) return
lockSuspend(owner)
}

private suspend fun lockSuspend(owner: Any?) = suspendCancellableCoroutineReusable { cont ->
private suspend fun lockSuspend(owner: Any?) = suspendCancellableCoroutineReusable<Unit> { cont ->
cont as CancellableContinuationImpl<Unit>
val contWithOwner = CancellableContinuationWithOwner(cont, owner)
acquire(contWithOwner)
lockImpl(contWithOwner, owner)
}

private fun lockImpl(waiter: Waiter, owner: Any?) {
xxx@ while (true) {
// Get the current number of available permits.
val p = availablePermits.getAndDecrement()
// Try to decrement the number of available
// permits if it is greater than zero.
if (p <= 0) {
// The semaphore permit acquisition has failed.
// However, we need to check that this mutex is not
// locked by our owner.
if (owner != null) {
// Is this mutex locked by our owner?
var curOwner = this.owner.value

if (curOwner === owner) {
if (suspendCancelled() != null) release()
when (waiter) {
is CancellableContinuation<*> -> {
waiter.resumeWithException(IllegalStateException("ERROR"))
}
is SelectInstance<*> -> {
waiter.selectInRegistrationPhase(ON_LOCK_ALREADY_LOCKED_BY_OWNER)
}
}
return
}

while (curOwner === NO_OWNER) {
curOwner = this.owner.value
if (!isLocked) {
if (suspendCancelled() != null) release()
continue@xxx
}
}
if (curOwner === owner) {
if (suspendCancelled() != null) release()
when (waiter) {
is CancellableContinuation<*> -> {
waiter.resumeWithException(IllegalStateException("ERROR"))
}
is SelectInstance<*> -> {
waiter.selectInRegistrationPhase(ON_LOCK_ALREADY_LOCKED_BY_OWNER)
}
}
return
}
// This mutex is either locked by another owner or unlocked.
// In the latter case, it is possible that it WAS locked by
// our owner when the semaphore permit acquisition has failed.
// To preserve linearizability, the operation restarts in this case.
// if (!isLocked) continuex
}
if (suspend(waiter)) return
} else {
assert { p == 1 }
assert { this.owner.value === NO_OWNER }
when (waiter) {
is CancellableContinuation<*> -> {
waiter as CancellableContinuation<Unit>
waiter.resume(Unit, null)
}
is SelectInstance<*> -> {
waiter.selectInRegistrationPhase(Unit)
}
}
return
}
}
}

override fun tryLock(owner: Any?): Boolean = when (tryLockImpl(owner)) {
Expand All @@ -179,25 +253,27 @@ internal open class MutexImpl(locked: Boolean) : SemaphoreImpl(1, if (locked) 1

private fun tryLockImpl(owner: Any?): Int {
while (true) {
if (tryAcquire()) {
assert { this.owner.value === NO_OWNER }
this.owner.value = owner
return TRY_LOCK_SUCCESS
} else {
// Get the current number of available permits.
val p = availablePermits.value
// Try to decrement the number of available
// permits if it is greater than zero.
if (p <= 0) {
// The semaphore permit acquisition has failed.
// However, we need to check that this mutex is not
// locked by our owner.
if (owner != null) {
// Is this mutex locked by our owner?
if (holdsLock(owner)) return TRY_LOCK_ALREADY_LOCKED_BY_OWNER
// This mutex is either locked by another owner or unlocked.
// In the latter case, it is possible that it WAS locked by
// our owner when the semaphore permit acquisition has failed.
// To preserve linearizability, the operation restarts in this case.
if (!isLocked) continue
val curOwner = this.owner.value
if (curOwner === NO_OWNER) continue
if (curOwner === owner) return TRY_LOCK_ALREADY_LOCKED_BY_OWNER
}
return TRY_LOCK_FAILED
}
if (availablePermits.compareAndSet(p, p - 1)) {
assert { this.owner.value === NO_OWNER }
this.owner.value = owner
return TRY_LOCK_SUCCESS
}
}
}

Expand All @@ -218,6 +294,27 @@ internal open class MutexImpl(locked: Boolean) : SemaphoreImpl(1, if (locked) 1
}
}

fun release() {
while (true) {
// Increment the number of available permits.
val p = availablePermits.value
// Is this `release` call correct and does not
// exceed the maximal number of permits?
if (p >= 1) {
error("This mutex is not locked")
}
if (availablePermits.compareAndSet(p, p + 1)) {
// Is there a waiter that should be resumed?
if (p == 0) return
// Try to resume the first waiter, and
// restart the operation if either this
// first waiter is cancelled or
// due to `SYNC` resumption mode.
if (resume(Unit)) return
}
}
}

@Suppress("UNCHECKED_CAST", "OverridingDeprecatedMember", "OVERRIDE_DEPRECATION")
override val onLock: SelectClause2<Any?, Mutex> get() = SelectClause2Impl(
clauseObject = this,
Expand All @@ -227,11 +324,7 @@ internal open class MutexImpl(locked: Boolean) : SemaphoreImpl(1, if (locked) 1
)

protected open fun onLockRegFunction(select: SelectInstance<*>, owner: Any?) {
if (owner != null && holdsLock(owner)) {
select.selectInRegistrationPhase(ON_LOCK_ALREADY_LOCKED_BY_OWNER)
} else {
onAcquireRegFunction(SelectInstanceWithOwner(select, owner), owner)
}
lockImpl(SelectInstanceWithOwner(select as SelectInstanceInternal<*>, owner), owner)
}

protected open fun onLockProcessResult(owner: Any?, result: Any?): Any? {
Expand All @@ -243,10 +336,10 @@ internal open class MutexImpl(locked: Boolean) : SemaphoreImpl(1, if (locked) 1

private inner class CancellableContinuationWithOwner(
@JvmField
val cont: CancellableContinuation<Unit>,
val cont: CancellableContinuationImpl<Unit>,
@JvmField
val owner: Any?
) : CancellableContinuation<Unit> by cont {
) : CancellableContinuation<Unit> by cont, Waiter by cont {
override fun tryResume(value: Unit, idempotent: Any?, onCancellation: ((cause: Throwable) -> Unit)?): Any? {
assert { this@MutexImpl.owner.value === NO_OWNER }
val token = cont.tryResume(value, idempotent) {
Expand All @@ -270,10 +363,10 @@ internal open class MutexImpl(locked: Boolean) : SemaphoreImpl(1, if (locked) 1

private inner class SelectInstanceWithOwner<Q>(
@JvmField
val select: SelectInstance<Q>,
val select: SelectInstanceInternal<Q>,
@JvmField
val owner: Any?
) : SelectInstanceInternal<Q> by select as SelectInstanceInternal<Q> {
) : SelectInstanceInternal<Q> by select {
override fun trySelect(clauseObject: Any, result: Any?): Boolean {
assert { this@MutexImpl.owner.value === NO_OWNER }
return select.trySelect(clauseObject, result).also { success ->
Expand All @@ -282,12 +375,16 @@ internal open class MutexImpl(locked: Boolean) : SemaphoreImpl(1, if (locked) 1
}

override fun selectInRegistrationPhase(internalResult: Any?) {
assert { this@MutexImpl.owner.value === NO_OWNER }
this@MutexImpl.owner.value = owner
if (internalResult !== ON_LOCK_ALREADY_LOCKED_BY_OWNER) {
assert { this@MutexImpl.owner.value === NO_OWNER }
this@MutexImpl.owner.value = owner
}
select.selectInRegistrationPhase(internalResult)
}
}

internal val debugStateRepresentation: String get() = "p=${availablePermits.value},owner=${owner.value},SQS=${super.toString()}"

override fun toString() = "Mutex@${hexAddress}[isLocked=$isLocked,owner=${owner.value}]"
}

Expand Down
6 changes: 3 additions & 3 deletions kotlinx-coroutines-core/common/src/sync/ReadWriteMutex.kt
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ internal class ReadWriteMutexImpl : ReadWriteMutex, Mutex {
// The number of waiting readers was incremented
// correctly, wait for a reader lock in `sqsReaders`.
suspendCancellableCoroutineReusable<Unit> { cont ->
sqsReaders.suspend(cont)
sqsReaders.suspend(cont as Waiter)
}
return
} else {
Expand All @@ -286,7 +286,7 @@ internal class ReadWriteMutexImpl : ReadWriteMutex, Mutex {
// when this concurrent `write.unlock()` completes.
if (wr == 0) {
suspendCancellableCoroutineReusable<Unit> { cont ->
sqsReaders.suspend(cont)
sqsReaders.suspend(cont as Waiter)
}
return
}
Expand Down Expand Up @@ -399,7 +399,7 @@ internal class ReadWriteMutexImpl : ReadWriteMutex, Mutex {
// Try to increment the number of waiting writers and suspend in `sqsWriters`.
if (state.compareAndSet(s, state(s.ar, s.wla, s.ww + 1, s.rwr))) {
suspendCancellableCoroutineReusable<Unit> { cont ->
sqsWriters.suspend(cont)
sqsWriters.suspend(cont as Waiter)
}
return
}
Expand Down
6 changes: 3 additions & 3 deletions kotlinx-coroutines-core/common/src/sync/Semaphore.kt
Original file line number Diff line number Diff line change
Expand Up @@ -156,18 +156,18 @@ internal open class SemaphoreImpl(
}
}

private suspend fun acquireSlowPath() = suspendCancellableCoroutineReusable<Unit> sc@{ cont ->
private suspend fun acquireSlowPath() = suspendCancellableCoroutineReusable sc@{ cont ->
while (true) {
// Try to suspend.
if (suspend(cont)) return@sc
if (suspend(cont as Waiter)) return@sc
// The suspension has been failed
// due to the synchronous resumption mode.
// Restart the whole `acquire`, and decrement
// the number of available permits at first.
val p = decPermits()
// Is the permit acquired?
if (p > 0) {
cont.resume(Unit)
cont.resume(Unit) { release() }
return@sc
}
// Permit has not been acquired, go to
Expand Down
Loading

0 comments on commit 43b6be5

Please sign in to comment.