diff --git a/kotlinx-coroutines-core/jvm/src/flow/internal/SafeCollector.kt b/kotlinx-coroutines-core/jvm/src/flow/internal/SafeCollector.kt index ea973287a7..cad3f1aea4 100644 --- a/kotlinx-coroutines-core/jvm/src/flow/internal/SafeCollector.kt +++ b/kotlinx-coroutines-core/jvm/src/flow/internal/SafeCollector.kt @@ -29,15 +29,22 @@ internal actual class SafeCollector actual constructor( @JvmField // Note, it is non-capturing lambda, so no extra allocation during init of SafeCollector internal actual val collectContextSize = collectContext.fold(0) { count, _ -> count + 1 } + + // Either context of the last emission or wrapper 'DownstreamExceptionContext' private var lastEmissionContext: CoroutineContext? = null + // Completion if we are currently suspended or within completion body or null otherwise private var completion: Continuation? = null - // ContinuationImpl + /* + * This property is accessed in two places: + * * ContinuationImpl invokes this in its `releaseIntercepted` as `context[ContinuationInterceptor]!!` + * * When we are within a callee, it is used to create its continuation object with this collector as completion + */ override val context: CoroutineContext - get() = completion?.context ?: EmptyCoroutineContext + get() = lastEmissionContext ?: EmptyCoroutineContext override fun invokeSuspend(result: Result): Any { - result.onFailure { lastEmissionContext = DownstreamExceptionElement(it) } + result.onFailure { lastEmissionContext = DownstreamExceptionContext(it, context) } completion?.resumeWith(result as Result) return COROUTINE_SUSPENDED } @@ -59,7 +66,9 @@ internal actual class SafeCollector actual constructor( emit(uCont, value) } catch (e: Throwable) { // Save the fact that exception from emit (or even check context) has been thrown - lastEmissionContext = DownstreamExceptionElement(e) + // Note, that this can the first emit and lastEmissionContext may not be saved yet, + // hence we use `uCont.context` here. + lastEmissionContext = DownstreamExceptionContext(e, uCont.context) throw e } } @@ -72,9 +81,18 @@ internal actual class SafeCollector actual constructor( val previousContext = lastEmissionContext if (previousContext !== currentContext) { checkContext(currentContext, previousContext, value) + lastEmissionContext = currentContext } completion = uCont - return emitFun(collector as FlowCollector, value, this as Continuation) + val result = emitFun(collector as FlowCollector, value, this as Continuation) + /* + * If the callee hasn't suspended, that means that it won't (it's forbidden) call 'resumeWith` (-> `invokeSuspend`) + * and we don't have to retain a strong reference to it to avoid memory leaks. + */ + if (result != COROUTINE_SUSPENDED) { + completion = null + } + return result } private fun checkContext( @@ -82,14 +100,13 @@ internal actual class SafeCollector actual constructor( previousContext: CoroutineContext?, value: T ) { - if (previousContext is DownstreamExceptionElement) { + if (previousContext is DownstreamExceptionContext) { exceptionTransparencyViolated(previousContext, value) } checkContext(currentContext) - lastEmissionContext = currentContext } - private fun exceptionTransparencyViolated(exception: DownstreamExceptionElement, value: Any?) { + private fun exceptionTransparencyViolated(exception: DownstreamExceptionContext, value: Any?) { /* * Exception transparency ensures that if a `collect` block or any intermediate operator * throws an exception, then no more values will be received by it. @@ -122,14 +139,12 @@ internal actual class SafeCollector actual constructor( For a more detailed explanation, please refer to Flow documentation. """.trimIndent()) } - } -internal class DownstreamExceptionElement(@JvmField val e: Throwable) : CoroutineContext.Element { - companion object Key : CoroutineContext.Key - - override val key: CoroutineContext.Key<*> = Key -} +internal class DownstreamExceptionContext( + @JvmField val e: Throwable, + originalContext: CoroutineContext +) : CoroutineContext by originalContext private object NoOpContinuation : Continuation { override val context: CoroutineContext = EmptyCoroutineContext diff --git a/kotlinx-coroutines-core/jvm/test/FieldWalker.kt b/kotlinx-coroutines-core/jvm/test/FieldWalker.kt index 52bcce3c69..7b2aaf63fc 100644 --- a/kotlinx-coroutines-core/jvm/test/FieldWalker.kt +++ b/kotlinx-coroutines-core/jvm/test/FieldWalker.kt @@ -9,6 +9,7 @@ import java.lang.reflect.* import java.text.* import java.util.* import java.util.Collections.* +import java.util.concurrent.* import java.util.concurrent.atomic.* import java.util.concurrent.locks.* import kotlin.test.* @@ -26,11 +27,11 @@ object FieldWalker { // excluded/terminal classes (don't walk them) fieldsCache += listOf( Any::class, String::class, Thread::class, Throwable::class, StackTraceElement::class, - WeakReference::class, ReferenceQueue::class, AbstractMap::class, - ReentrantReadWriteLock::class, SimpleDateFormat::class + WeakReference::class, ReferenceQueue::class, AbstractMap::class, Enum::class, + ReentrantLock::class, ReentrantReadWriteLock::class, SimpleDateFormat::class, ThreadPoolExecutor::class, ) .map { it.java } - .associateWith { emptyList() } + .associateWith { emptyList() } } /* @@ -159,6 +160,13 @@ object FieldWalker { && !(it.type.isArray && it.type.componentType.isPrimitive) && it.name != "previousOut" // System.out from TestBase that we store in a field to restore later } + check(fields.isEmpty() || !type.name.startsWith("java.")) { + """ + Trying to walk trough JDK's '$type' will get into illegal reflective access on JDK 9+. + Either modify your test to avoid usage of this class or update FieldWalker code to retrieve + the captured state of this class without going through reflection (see how collections are handled). + """.trimIndent() + } fields.forEach { it.isAccessible = true } // make them all accessible result.addAll(fields) type = type.superclass diff --git a/kotlinx-coroutines-core/jvm/test/flow/SafeCollectorMemoryLeakTest.kt b/kotlinx-coroutines-core/jvm/test/flow/SafeCollectorMemoryLeakTest.kt new file mode 100644 index 0000000000..b75ec60ed5 --- /dev/null +++ b/kotlinx-coroutines-core/jvm/test/flow/SafeCollectorMemoryLeakTest.kt @@ -0,0 +1,48 @@ +/* + * Copyright 2016-2022 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.coroutines.flow + +import kotlinx.coroutines.* +import org.junit.* + +class SafeCollectorMemoryLeakTest : TestBase() { + // custom List.forEach impl to avoid using iterator (FieldWalker cannot scan it) + private inline fun List.listForEach(action: (T) -> Unit) { + for (i in indices) action(get(i)) + } + + @Test + fun testCompletionIsProperlyCleanedUp() = runBlocking { + val job = flow { + emit(listOf(239)) + expect(2) + hang {} + }.transform { l -> l.listForEach { _ -> emit(42) } } + .onEach { expect(1) } + .launchIn(this) + yield() + expect(3) + FieldWalker.assertReachableCount(0, job) { it == 239 } + job.cancelAndJoin() + finish(4) + } + + @Test + fun testCompletionIsNotCleanedUp() = runBlocking { + val job = flow { + emit(listOf(239)) + hang {} + }.transform { l -> l.listForEach { _ -> emit(42) } } + .onEach { + expect(1) + hang { finish(3) } + } + .launchIn(this) + yield() + expect(2) + FieldWalker.assertReachableCount(1, job) { it == 239 } + job.cancelAndJoin() + } +}