Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix resource leak for HttpClient-based backends on cancellation #2413

Merged
merged 14 commits into from
Jan 24, 2025
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import sttp.client4.internal.NoStreams
import sttp.client4.wrappers.FollowRedirectsBackend
import sttp.client4.{wrappers, Backend, BackendOptions}
import sttp.monad.MonadAsyncError
import cats.effect.ExitCase

private final class ArmeriaCatsBackend[F[_]: Concurrent](client: WebClient, closeFactory: Boolean)
extends AbstractArmeriaBackend[F, Nothing](client, closeFactory, new CatsMonadAsyncError) {
Expand All @@ -31,6 +32,12 @@ private final class ArmeriaCatsBackend[F[_]: Concurrent](client: WebClient, clos

override protected def streamToPublisher(stream: Nothing): Publisher[HttpData] =
throw new UnsupportedOperationException("This backend does not support streaming")

override protected def ensureOnAbnormal[T](effect: F[T])(finalizer: => F[Unit]): F[T] =
Concurrent[F].guaranteeCase(effect) { exit =>
if (exit == ExitCase.Completed) Concurrent[F].unit
else Concurrent[F].recoverWith(finalizer) { case t => Concurrent[F].delay(t.printStackTrace()) }
}
}

object ArmeriaCatsBackend {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ private final class ArmeriaCatsBackend[F[_]: Async](client: WebClient, closeFact

override protected def streamToPublisher(stream: Nothing): Publisher[HttpData] =
throw new UnsupportedOperationException("This backend does not support streaming")

override protected def ensureOnAbnormal[T](effect: F[T])(finalizer: => F[Unit]): F[T] =
Async[F].guaranteeCase(effect) { outcome =>
if (outcome.isSuccess) Async[F].unit
else Async[F].onError(finalizer) { case t => Async[F].delay(t.printStackTrace()) }
}
}

object ArmeriaCatsBackend {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ import sttp.capabilities.fs2.Fs2Streams
import sttp.client4.armeria.ArmeriaWebClient.newClient
import sttp.client4.armeria.{AbstractArmeriaBackend, BodyFromStreamMessage}
import sttp.client4.impl.cats.CatsMonadAsyncError
import sttp.client4.wrappers.FollowRedirectsBackend
import sttp.client4.{wrappers, BackendOptions, StreamBackend}
import sttp.monad.MonadAsyncError
import cats.effect.ExitCase

private final class ArmeriaFs2Backend[F[_]: ConcurrentEffect](client: WebClient, closeFactory: Boolean)
extends AbstractArmeriaBackend[F, Fs2Streams[F]](client, closeFactory, new CatsMonadAsyncError) {
Expand All @@ -36,6 +36,12 @@ private final class ArmeriaFs2Backend[F[_]: ConcurrentEffect](client: WebClient,
val bytes = chunk.toBytes
HttpData.wrap(bytes.values, bytes.offset, bytes.length)
}.toUnicastPublisher

override protected def ensureOnAbnormal[T](effect: F[T])(finalizer: => F[Unit]): F[T] =
ConcurrentEffect[F].guaranteeCase(effect) { exitCase =>
if (exitCase == ExitCase.Completed) ConcurrentEffect[F].unit
else ConcurrentEffect[F].onError(finalizer) { case t => ConcurrentEffect[F].delay(t.printStackTrace()) }
}
}

object ArmeriaFs2Backend {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ private final class ArmeriaFs2Backend[F[_]: Async](client: WebClient, closeFacto

override protected def compressors: List[Compressor[R]] =
List(new GZipFs2Compressor[F, R](), new DeflateFs2Compressor[F, R]())

override protected def ensureOnAbnormal[T](effect: F[T])(finalizer: => F[Unit]): F[T] =
Async[F].guaranteeCase(effect) { outcome =>
if (outcome.isSuccess) Async[F].unit
else Async[F].onError(finalizer) { case t => Async[F].delay(t.printStackTrace()) }
}
}

object ArmeriaFs2Backend {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import sttp.client4.impl.monix.TaskMonadAsyncError
import sttp.client4.wrappers.FollowRedirectsBackend
import sttp.client4.{wrappers, BackendOptions, StreamBackend}
import sttp.monad.MonadAsyncError
import cats.effect.ExitCase

private final class ArmeriaMonixBackend(client: WebClient, closeFactory: Boolean)(implicit scheduler: Scheduler)
extends AbstractArmeriaBackend[Task, MonixStreams](client, closeFactory, TaskMonadAsyncError) {
Expand All @@ -33,6 +34,11 @@ private final class ArmeriaMonixBackend(client: WebClient, closeFactory: Boolean

override protected def streamToPublisher(stream: Observable[Array[Byte]]): Publisher[HttpData] =
stream.map(HttpData.wrap).toReactivePublisher

override protected def ensureOnAbnormal[T](effect: Task[T])(finalizer: => Task[Unit]): Task[T] =
effect.guaranteeCase { exit =>
if (exit == ExitCase.Completed) Task.unit else finalizer.onErrorHandleWith(t => Task.eval(t.printStackTrace()))
}
}

object ArmeriaMonixBackend {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ private final class ArmeriaScalazBackend(client: WebClient, closeFactory: Boolea

override protected def streamToPublisher(stream: Nothing): Publisher[HttpData] =
throw new UnsupportedOperationException("This backend does not support streaming")

override protected def ensureOnAbnormal[T](effect: Task[T])(finalizer: => Task[Unit]): Task[T] =
effect.handleWith { case e =>
finalizer.handleWith { case e2 => Task(e.addSuppressed(e2)) }.flatMap(_ => Task.fail(e))
}
}

object ArmeriaScalazBackend {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,26 @@ abstract class AbstractArmeriaBackend[F[_], S <: Streams[S]](

protected def compressors: List[Compressor[R]] = Compressor.default[R]

override def send[T](request: GenericRequest[T, R]): F[Response[T]] =
monad.suspend(adjustExceptions(request)(execute(request)))
// #1987: see the comments in HttpClientAsyncBackend
protected def ensureOnAbnormal[T](effect: F[T])(finalizer: => F[Unit]): F[T]

private def execute[T](request: GenericRequest[T, R]): F[Response[T]] = {
override def send[T](request: GenericRequest[T, R]): F[Response[T]] = {
// #1987: see the comments in HttpClientAsyncBackend
val armeriaCtx = new AtomicReference[ClientRequestContext]()
ensureOnAbnormal {
monad.suspend(adjustExceptions(request)(execute(request, armeriaCtx)))
} {
monad.eval {
val ctx = armeriaCtx.get()
if (ctx != null) ctx.cancel()
}
}
}

private def execute[T](
request: GenericRequest[T, R],
armeriaCtx: AtomicReference[ClientRequestContext]
): F[Response[T]] = {
val captor = Clients.newContextCaptor()
try {
val armeriaRes = requestToArmeria(request).execute()
Expand All @@ -84,6 +100,7 @@ abstract class AbstractArmeriaBackend[F[_], S <: Streams[S]](
noopCanceler
}
case Success(ctx) =>
armeriaCtx.set(ctx)
fromArmeriaResponse(request, armeriaRes, ctx)
}
} catch {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ private final class ArmeriaFutureBackend(client: WebClient, closeFactory: Boolea

override protected def streamToPublisher(stream: streams.BinaryStream): Publisher[HttpData] =
throw new UnsupportedOperationException("This backend does not support streaming")

override protected def ensureOnAbnormal[T](effect: Future[T])(finalizer: => Future[Unit]): Future[T] =
effect.recoverWith { case e =>
finalizer.recoverWith { case e2 => e.addSuppressed(e2); Future.failed(e) }.flatMap(_ => Future.failed(e))
}
}

object ArmeriaFutureBackend {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ private final class ArmeriaZioBackend(runtime: Runtime[Any], client: WebClient,
}

override protected def compressors: List[Compressor[R]] = List(GZipZioCompressor, DeflateZioCompressor)

override protected def ensureOnAbnormal[T](effect: Task[T])(finalizer: => Task[Unit]): Task[T] = effect.onExit {
exit =>
if (exit.isSuccess) ZIO.unit else finalizer.catchAll(t => ZIO.logErrorCause("Error in finalizer", Cause.fail(t)))
}.resurrect
}

object ArmeriaZioBackend {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ private final class ArmeriaZioBackend(runtime: Runtime[Any], client: WebClient,

override protected def streamToPublisher(stream: Stream[Throwable, Byte]): Publisher[HttpData] =
runtime.unsafeRun(stream.mapChunks(c => Chunk.single(HttpData.wrap(c.toArray))).toPublisher)

override protected def ensureOnAbnormal[T](effect: Task[T])(finalizer: => Task[Unit]): Task[T] = effect.onExit {
exit =>
if (exit.succeeded) ZIO.unit else finalizer.catchAll(t => ZIO.effect(t.printStackTrace()).orDie)
}.resurrect
}

object ArmeriaZioBackend {
Expand Down
Loading
Loading