Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ability to specify the number of permits to acquire and release #1553

Closed
wants to merge 10 commits into from
Closed
9 changes: 9 additions & 0 deletions kotlinx-coroutines-core/api/kotlinx-coroutines-core.api
Original file line number Diff line number Diff line change
Expand Up @@ -1147,10 +1147,19 @@ public final class kotlinx/coroutines/sync/MutexKt {
}

public abstract interface class kotlinx/coroutines/sync/Semaphore {
public abstract fun acquire (ILkotlin/coroutines/Continuation;)Ljava/lang/Object;
public abstract fun acquire (Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public abstract fun getAvailablePermits ()I
public abstract fun release ()V
public abstract fun release (I)V
public abstract fun tryAcquire ()Z
public abstract fun tryAcquire (I)Z
}

public final class kotlinx/coroutines/sync/Semaphore$DefaultImpls {
public static synthetic fun acquire$default (Lkotlinx/coroutines/sync/Semaphore;ILkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object;
public static synthetic fun release$default (Lkotlinx/coroutines/sync/Semaphore;IILjava/lang/Object;)V
public static synthetic fun tryAcquire$default (Lkotlinx/coroutines/sync/Semaphore;IILjava/lang/Object;)Z
}

public final class kotlinx/coroutines/sync/SemaphoreKt {
Expand Down
277 changes: 210 additions & 67 deletions kotlinx-coroutines-core/common/src/sync/Semaphore.kt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import kotlinx.atomicfu.*
import kotlinx.coroutines.*
import kotlinx.coroutines.internal.*
import kotlin.coroutines.*
import kotlin.jvm.*
import kotlin.math.*
import kotlin.native.concurrent.*

Expand Down Expand Up @@ -45,19 +44,63 @@ public interface Semaphore {
*/
public suspend fun acquire()

/**
* Acquires the given number of permits from this semaphore, suspending until ones are available.
* All suspending acquirers are processed in first-in-first-out (FIFO) order.
*
* This suspending function is cancellable. If the [Job] of the current coroutine is cancelled or completed while this
* function is suspended, this function immediately resumes with [CancellationException].
*
* *Cancellation of suspended semaphore acquisition is atomic* -- when this function
* throws [CancellationException] it means that the semaphore was not acquired.
*
* Note, that this function does not check for cancellation when it does not suspend.
* Use [CoroutineScope.isActive] or [CoroutineScope.ensureActive] to periodically
* check for cancellation in tight loops if needed.
*
* Use [tryAcquire] to try acquire the given number of permits of this semaphore without suspension.
*
* @param permits the number of permits to acquire
*
* @throws [IllegalArgumentException] if [permits] is less than or equal to zero.
*/
public suspend fun acquire(permits: Int)

/**
* Tries to acquire a permit from this semaphore without suspension.
*
* @return `true` if a permit was acquired, `false` otherwise.
*/
public fun tryAcquire(): Boolean

/**
* Tries to acquire the given number of permits from this semaphore without suspension.
*
* @param permits the number of permits to acquire
* @return `true` if all permits were acquired, `false` otherwise.
*
* @throws [IllegalArgumentException] if [permits] is less than or equal to zero.
*/
public fun tryAcquire(permits: Int): Boolean

/**
* Releases a permit, returning it into this semaphore. Resumes the first
* suspending acquirer if there is one at the point of invocation.
* Throws [IllegalStateException] if the number of [release] invocations is greater than the number of preceding [acquire].
* suspending acquirer if there is one at the point of invocation and the requested number of permits is available.
*
* @throws [IllegalStateException] if the number of [release] invocations is greater than the number of preceding [acquire].
*/
public fun release()

/**
* Releases the given number of permits, returning them into this semaphore. Resumes the first
* suspending acquirer if there is one at the point of invocation and the requested number of permits is available.
*
* @param permits the number of permits to release
*
* @throws [IllegalArgumentException] if [permits] is less than or equal to zero.
* @throws [IllegalStateException] if the number of [release] invocations is greater than the number of preceding [acquire].
*/
public fun release(permits: Int)
}

/**
Expand Down Expand Up @@ -101,8 +144,8 @@ private class SemaphoreImpl(
* and the maximum number of waiting acquirers cannot be greater than 2^31 in any
* real application.
*/
private val _availablePermits = atomic(permits - acquiredPermits)
override val availablePermits: Int get() = max(_availablePermits.value, 0)
private val permitsBalance = atomic(permits - acquiredPermits)
override val availablePermits: Int get() = max(permitsBalance.value, 0)

// The queue of waiting acquirers is essentially an infinite array based on `SegmentQueue`;
// each segment contains a fixed number of slots. To determine a slot for each enqueue
Expand All @@ -112,105 +155,205 @@ private class SemaphoreImpl(
private val enqIdx = atomic(0L)
private val deqIdx = atomic(0L)

/**
* The remaining permits from release operations, which could not be spent, because the next slot was not defined
*/
internal val accumulator = atomic(0)

override fun tryAcquire(): Boolean {
_availablePermits.loop { p ->
if (p <= 0) return false
if (_availablePermits.compareAndSet(p, p - 1)) return true
return tryAcquire(1)
}

override fun tryAcquire(permits: Int): Boolean {
require(permits > 0) { "The number of acquired permits must be greater than 0" }
permitsBalance.loop { p ->
if (p < permits) return false
if (permitsBalance.compareAndSet(p, p - permits)) return true
}
}

override suspend fun acquire() {
val p = _availablePermits.getAndDecrement()
if (p > 0) return // permit acquired
addToQueueAndSuspend()
return acquire(1)
}

override suspend fun acquire(permits: Int) {
require(permits > 0) { "The number of acquired permits must be greater than 0" }
val p = permitsBalance.getAndAdd(-permits)
if (p >= permits) return // permits are acquired
tryToAddToQueue(permits)
}

override fun release() {
val p = incPermits()
release(1)
}

override fun release(permits: Int) {
require(permits > 0) { "The number of released permits must be greater than 0" }
val p = incPermits(permits)
if (p >= 0) return // no waiters
resumeNextFromQueue()
tryToResumeFromQueue(permits)
}

fun incPermits() = _availablePermits.getAndUpdate { cur ->
check(cur < permits) { "The number of released permits cannot be greater than $permits" }
cur + 1
internal fun incPermits(delta: Int = 1) = permitsBalance.getAndUpdate { cur ->
assert { delta >= 1 }
check(cur + delta <= permits) { "The number of released permits cannot be greater than $permits" }
cur + delta
}

private suspend fun addToQueueAndSuspend() = suspendAtomicCancellableCoroutineReusable<Unit> sc@ { cont ->
private suspend fun tryToAddToQueue(permits: Int) = suspendAtomicCancellableCoroutine<Unit> sc@{ cont ->
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see how permits is used by this method. I'm very surprised that tests pass. Seems like some tests are missing.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The previous commit didn't contain significant changes, sorry. Please, recheck.

val last = this.tail
val enqIdx = enqIdx.getAndIncrement()
val segment = getSegment(last, enqIdx / SEGMENT_SIZE)
val i = (enqIdx % SEGMENT_SIZE).toInt()
if (segment === null || segment.get(i) === RESUMED || !segment.cas(i, null, cont)) {
// already resumed
val enqueueId = enqIdx.getAndIncrement()
val segmentId = enqueueId / SEGMENT_SIZE
val segment = getSegment(last, segmentId)
if (segment == null) {
// The segment is already removed
// Probably, this is the unreachable case
cont.resume(Unit)
return@sc
} else {
val slotId = (enqueueId % SEGMENT_SIZE).toInt()
val prevCont = segment.continuations[slotId].getAndSet(cont)
// It is safe to set continuation, because this slot is not defined yet, so another threads can not use it
assert { prevCont == null }
val prevSlot = segment.slots[slotId].getAndSet(permits)
// The assertion is true, cause [RESUMED] can be set up only after [SUSPENDED]
// and [CANCELLED] can be set up only in the handler, which will be added next
assert { prevSlot == null }
cont.invokeOnCancellation(CancelSemaphoreAcquisitionHandler(this, segment, slotId, permits).asHandler)
}
cont.invokeOnCancellation(CancelSemaphoreAcquisitionHandler(this, segment, i).asHandler)
// Help to resume slots, if accumulator has permits
tryToResumeFromQueue(0)
}

@Suppress("UNCHECKED_CAST")
internal fun resumeNextFromQueue() {
try_again@while (true) {
val first = this.head
val deqIdx = deqIdx.getAndIncrement()
val segment = getSegmentAndMoveHead(first, deqIdx / SEGMENT_SIZE) ?: continue@try_again
val i = (deqIdx % SEGMENT_SIZE).toInt()
val cont = segment.getAndSet(i, RESUMED)
if (cont === null) return // just resumed
if (cont === CANCELLED) continue@try_again
(cont as CancellableContinuation<Unit>).resume(Unit)
internal fun tryToResumeFromQueue(permits: Int) {
val accumulated = accumulator.getAndSet(0) // try to take possession of all the accumulated permits at the moment
var remain = permits + accumulated
if (remain == 0) {
// The accumulator had not any permits or the another thread stole permits. Also this method called with zero permits.
return
}
try_again@ while (true) {
val first = this.head
val dequeueId = deqIdx.value
val segmentId = dequeueId / SEGMENT_SIZE
val segment = getSegmentAndMoveHead(first, segmentId)
if (segment == null) {
// The segment is already removed
// Try to help to increment [deqIdx] once, because multiple threads can increment the [deqIdx] in parallel otherwise
deqIdx.compareAndSet(dequeueId, dequeueId + 1)
continue@try_again
}
val slotId = (dequeueId % SEGMENT_SIZE).toInt()
val slot = segment.slots[slotId].value
if (slot == null) {
// If the slot is not defined yet we can't spent permits for it, so return [remain] to [accumulator]
accumulator.addAndGet(remain)
return
}
if (slot == CANCELLED) {
// The slot was cancelled in the another thread
// Try to help to increment [deqIdx] once, because multiple threads can increment the [deqIdx] in parallel otherwise
if (deqIdx.compareAndSet(dequeueId, dequeueId + 1)) {
removeSegmentIfNeeded(segment, dequeueId + 1)
}
continue@try_again
}
if (slot == RESUMED) {
// The slot was updated in the another thread
// The another thread was supposed to increment [deqIdx]
continue@try_again
}
val diff = min(slot, remain) // How many permits we can spent for the slot at most
val newSlot = slot - diff
if (!segment.slots[slotId].compareAndSet(slot, newSlot)) {
// The slot was updated in another thread, let's try again
continue@try_again
}
// Here we successfully updated the slot
remain -= diff // remove spent permits
if (newSlot == RESUMED) {
segment.continuations[slotId].value!!.resume(Unit)
removeSegmentIfNeeded(segment, deqIdx.incrementAndGet())
}
if (remain == 0) {
// We spent all available permits, so let's finish
return
}
// We still have permits, so we continue to spent them
}
}

/**
* Remove the segment if needed. The method checks, that all segment's slots were processed
*
* @param segment the segment to validation
* @param dequeueId the current dequeue operation ID
*/
internal fun removeSegmentIfNeeded(segment: SemaphoreSegment, dequeueId: Long) {
val slotId = (dequeueId % SEGMENT_SIZE).toInt()
if (slotId == SEGMENT_SIZE) {
segment.remove()
}
}

override fun toString(): String {
return "Semaphore=(balance=${permitsBalance.value}, accumulator=${accumulator.value})"
}
}

/**
* Cleans the acquirer slot located by the specified index and removes this segment physically if all slots are cleaned.
*/
private class CancelSemaphoreAcquisitionHandler(
private val semaphore: SemaphoreImpl,
private val segment: SemaphoreSegment,
private val index: Int
private val semaphore: SemaphoreImpl,
private val segment: SemaphoreSegment,
private val slotId: Int,
private val permits: Int
) : CancelHandler() {
override fun invoke(cause: Throwable?) {
val p = semaphore.incPermits()
// Don't wait and use [prevSlot] to handle permits, because it starts races with release (see StressTest)
val p = semaphore.incPermits(permits)
if (p >= 0) return
if (segment.cancel(index)) return
semaphore.resumeNextFromQueue()
// Copy [slotId] to local variable to prevent exception:
// "Complex data flow is not allowed for calculation of an array element index at the point of loading the reference to this element."
val temp = slotId
val prevSlot = segment.slots[temp].getAndSet(CANCELLED)
// The assertion is true, cause the slot has [SUSPENDED] state at least
assert { prevSlot != null }

// Remove this segment if needed
if (segment.cancelledSlots.incrementAndGet() == SEGMENT_SIZE) {
segment.remove()
}
if (prevSlot == RESUMED) {
// The slot has already resumed, so return free permits to the semaphore
semaphore.tryToResumeFromQueue(prevSlot)
}
}

override fun toString() = "CancelSemaphoreAcquisitionHandler[$semaphore, $segment, $index]"
override fun toString() = "CancelSemaphoreAcquisitionHandler[$semaphore, $segment, $slotId]"
}

private class SemaphoreSegment(id: Long, prev: SemaphoreSegment?): Segment<SemaphoreSegment>(id, prev) {
val acquirers = atomicArrayOfNulls<Any?>(SEGMENT_SIZE)
private val cancelledSlots = atomic(0)
private class SemaphoreSegment(id: Long, prev: SemaphoreSegment?) : Segment<SemaphoreSegment>(id, prev) {
val continuations = atomicArrayOfNulls<CancellableContinuation<Unit>>(SEGMENT_SIZE)
/**
* Each slot can contain one of following values:
* 1. A number greater than zero. It is [SUSPENDED] state;
* 2. Zero. It is [RESUMED] state;
* 3. "-1". It is [CANCELLED] state;
* 4. "null". The slot is not defined yet.
*/
val slots = atomicArrayOfNulls<Int>(SEGMENT_SIZE)
val cancelledSlots = atomic(0)
override val removed get() = cancelledSlots.value == SEGMENT_SIZE

@Suppress("NOTHING_TO_INLINE")
inline fun get(index: Int): Any? = acquirers[index].value

@Suppress("NOTHING_TO_INLINE")
inline fun cas(index: Int, expected: Any?, value: Any?): Boolean = acquirers[index].compareAndSet(expected, value)

@Suppress("NOTHING_TO_INLINE")
inline fun getAndSet(index: Int, value: Any?) = acquirers[index].getAndSet(value)

// Cleans the acquirer slot located by the specified index
// and removes this segment physically if all slots are cleaned.
fun cancel(index: Int): Boolean {
// Try to cancel the slot
val cancelled = getAndSet(index, CANCELLED) !== RESUMED
// Remove this segment if needed
if (cancelledSlots.incrementAndGet() == SEGMENT_SIZE)
remove()
return cancelled
override fun toString(): String {
return "SemaphoreSegment(id=$id)"
}

override fun toString() = "SemaphoreSegment[id=$id, hashCode=${hashCode()}]"
}

@SharedImmutable
private val RESUMED = Symbol("RESUMED")
private val RESUMED = 0
@SharedImmutable
private val CANCELLED = Symbol("CANCELLED")
private val CANCELLED = -1
@SharedImmutable
private val SEGMENT_SIZE = systemProp("kotlinx.coroutines.semaphore.segmentSize", 16)
Loading