diff --git a/armeria-backend/cats-ce2/src/main/scala/sttp/client4/armeria/cats/ArmeriaCatsBackend.scala b/armeria-backend/cats-ce2/src/main/scala/sttp/client4/armeria/cats/ArmeriaCatsBackend.scala index 49b45f581e..fa3b6c7e92 100644 --- a/armeria-backend/cats-ce2/src/main/scala/sttp/client4/armeria/cats/ArmeriaCatsBackend.scala +++ b/armeria-backend/cats-ce2/src/main/scala/sttp/client4/armeria/cats/ArmeriaCatsBackend.scala @@ -12,6 +12,7 @@ import sttp.client4.internal.NoStreams import sttp.client4.wrappers.FollowRedirectsBackend import sttp.client4.{wrappers, Backend, BackendOptions} import sttp.monad.MonadAsyncError +import cats.effect.ExitCase private final class ArmeriaCatsBackend[F[_]: Concurrent](client: WebClient, closeFactory: Boolean) extends AbstractArmeriaBackend[F, Nothing](client, closeFactory, new CatsMonadAsyncError) { @@ -31,6 +32,12 @@ private final class ArmeriaCatsBackend[F[_]: Concurrent](client: WebClient, clos override protected def streamToPublisher(stream: Nothing): Publisher[HttpData] = throw new UnsupportedOperationException("This backend does not support streaming") + + override protected def ensureOnAbnormal[T](effect: F[T])(finalizer: => F[Unit]): F[T] = + Concurrent[F].guaranteeCase(effect) { exit => + if (exit == ExitCase.Completed) Concurrent[F].unit + else Concurrent[F].recoverWith(finalizer) { case t => Concurrent[F].delay(t.printStackTrace()) } + } } object ArmeriaCatsBackend { diff --git a/armeria-backend/cats/src/main/scala/sttp/client4/armeria/cats/ArmeriaCatsBackend.scala b/armeria-backend/cats/src/main/scala/sttp/client4/armeria/cats/ArmeriaCatsBackend.scala index d926a678b1..e6926edfd1 100644 --- a/armeria-backend/cats/src/main/scala/sttp/client4/armeria/cats/ArmeriaCatsBackend.scala +++ b/armeria-backend/cats/src/main/scala/sttp/client4/armeria/cats/ArmeriaCatsBackend.scala @@ -31,6 +31,12 @@ private final class ArmeriaCatsBackend[F[_]: Async](client: WebClient, closeFact override protected def streamToPublisher(stream: Nothing): Publisher[HttpData] = throw new UnsupportedOperationException("This backend does not support streaming") + + override protected def ensureOnAbnormal[T](effect: F[T])(finalizer: => F[Unit]): F[T] = + Async[F].guaranteeCase(effect) { outcome => + if (outcome.isSuccess) Async[F].unit + else Async[F].onError(finalizer) { case t => Async[F].delay(t.printStackTrace()) } + } } object ArmeriaCatsBackend { diff --git a/armeria-backend/fs2-ce2/src/main/scala/sttp/client4/armeria/fs2/ArmeriaFs2Backend.scala b/armeria-backend/fs2-ce2/src/main/scala/sttp/client4/armeria/fs2/ArmeriaFs2Backend.scala index 7d8ccef6c5..2d3b665bec 100644 --- a/armeria-backend/fs2-ce2/src/main/scala/sttp/client4/armeria/fs2/ArmeriaFs2Backend.scala +++ b/armeria-backend/fs2-ce2/src/main/scala/sttp/client4/armeria/fs2/ArmeriaFs2Backend.scala @@ -11,9 +11,9 @@ import sttp.capabilities.fs2.Fs2Streams import sttp.client4.armeria.ArmeriaWebClient.newClient import sttp.client4.armeria.{AbstractArmeriaBackend, BodyFromStreamMessage} import sttp.client4.impl.cats.CatsMonadAsyncError -import sttp.client4.wrappers.FollowRedirectsBackend import sttp.client4.{wrappers, BackendOptions, StreamBackend} import sttp.monad.MonadAsyncError +import cats.effect.ExitCase private final class ArmeriaFs2Backend[F[_]: ConcurrentEffect](client: WebClient, closeFactory: Boolean) extends AbstractArmeriaBackend[F, Fs2Streams[F]](client, closeFactory, new CatsMonadAsyncError) { @@ -36,6 +36,12 @@ private final class ArmeriaFs2Backend[F[_]: ConcurrentEffect](client: WebClient, val bytes = chunk.toBytes HttpData.wrap(bytes.values, bytes.offset, bytes.length) }.toUnicastPublisher + + override protected def ensureOnAbnormal[T](effect: F[T])(finalizer: => F[Unit]): F[T] = + ConcurrentEffect[F].guaranteeCase(effect) { exitCase => + if (exitCase == ExitCase.Completed) ConcurrentEffect[F].unit + else ConcurrentEffect[F].onError(finalizer) { case t => ConcurrentEffect[F].delay(t.printStackTrace()) } + } } object ArmeriaFs2Backend { diff --git a/armeria-backend/fs2/src/main/scala/sttp/client4/armeria/fs2/ArmeriaFs2Backend.scala b/armeria-backend/fs2/src/main/scala/sttp/client4/armeria/fs2/ArmeriaFs2Backend.scala index 3796966974..df4fcc7921 100644 --- a/armeria-backend/fs2/src/main/scala/sttp/client4/armeria/fs2/ArmeriaFs2Backend.scala +++ b/armeria-backend/fs2/src/main/scala/sttp/client4/armeria/fs2/ArmeriaFs2Backend.scala @@ -46,6 +46,12 @@ private final class ArmeriaFs2Backend[F[_]: Async](client: WebClient, closeFacto override protected def compressors: List[Compressor[R]] = List(new GZipFs2Compressor[F, R](), new DeflateFs2Compressor[F, R]()) + + override protected def ensureOnAbnormal[T](effect: F[T])(finalizer: => F[Unit]): F[T] = + Async[F].guaranteeCase(effect) { outcome => + if (outcome.isSuccess) Async[F].unit + else Async[F].onError(finalizer) { case t => Async[F].delay(t.printStackTrace()) } + } } object ArmeriaFs2Backend { diff --git a/armeria-backend/monix/src/main/scala/sttp/client4/armeria/monix/ArmeriaMonixBackend.scala b/armeria-backend/monix/src/main/scala/sttp/client4/armeria/monix/ArmeriaMonixBackend.scala index 1005c9311a..3a4c90cb7b 100644 --- a/armeria-backend/monix/src/main/scala/sttp/client4/armeria/monix/ArmeriaMonixBackend.scala +++ b/armeria-backend/monix/src/main/scala/sttp/client4/armeria/monix/ArmeriaMonixBackend.scala @@ -14,6 +14,7 @@ import sttp.client4.impl.monix.TaskMonadAsyncError import sttp.client4.wrappers.FollowRedirectsBackend import sttp.client4.{wrappers, BackendOptions, StreamBackend} import sttp.monad.MonadAsyncError +import cats.effect.ExitCase private final class ArmeriaMonixBackend(client: WebClient, closeFactory: Boolean)(implicit scheduler: Scheduler) extends AbstractArmeriaBackend[Task, MonixStreams](client, closeFactory, TaskMonadAsyncError) { @@ -33,6 +34,11 @@ private final class ArmeriaMonixBackend(client: WebClient, closeFactory: Boolean override protected def streamToPublisher(stream: Observable[Array[Byte]]): Publisher[HttpData] = stream.map(HttpData.wrap).toReactivePublisher + + override protected def ensureOnAbnormal[T](effect: Task[T])(finalizer: => Task[Unit]): Task[T] = + effect.guaranteeCase { exit => + if (exit == ExitCase.Completed) Task.unit else finalizer.onErrorHandleWith(t => Task.eval(t.printStackTrace())) + } } object ArmeriaMonixBackend { diff --git a/armeria-backend/scalaz/src/main/scala/sttp/client4/armeria/scalaz/ArmeriaScalazBackend.scala b/armeria-backend/scalaz/src/main/scala/sttp/client4/armeria/scalaz/ArmeriaScalazBackend.scala index e46f19ab02..a312cbb5fa 100644 --- a/armeria-backend/scalaz/src/main/scala/sttp/client4/armeria/scalaz/ArmeriaScalazBackend.scala +++ b/armeria-backend/scalaz/src/main/scala/sttp/client4/armeria/scalaz/ArmeriaScalazBackend.scala @@ -31,6 +31,11 @@ private final class ArmeriaScalazBackend(client: WebClient, closeFactory: Boolea override protected def streamToPublisher(stream: Nothing): Publisher[HttpData] = throw new UnsupportedOperationException("This backend does not support streaming") + + override protected def ensureOnAbnormal[T](effect: Task[T])(finalizer: => Task[Unit]): Task[T] = + effect.handleWith { case e => + finalizer.handleWith { case e2 => Task(e.addSuppressed(e2)) }.flatMap(_ => Task.fail(e)) + } } object ArmeriaScalazBackend { diff --git a/armeria-backend/src/main/scala/sttp/client4/armeria/AbstractArmeriaBackend.scala b/armeria-backend/src/main/scala/sttp/client4/armeria/AbstractArmeriaBackend.scala index da847a8430..40bcde96bb 100644 --- a/armeria-backend/src/main/scala/sttp/client4/armeria/AbstractArmeriaBackend.scala +++ b/armeria-backend/src/main/scala/sttp/client4/armeria/AbstractArmeriaBackend.scala @@ -58,10 +58,26 @@ abstract class AbstractArmeriaBackend[F[_], S <: Streams[S]]( protected def compressors: List[Compressor[R]] = Compressor.default[R] - override def send[T](request: GenericRequest[T, R]): F[Response[T]] = - monad.suspend(adjustExceptions(request)(execute(request))) + // #1987: see the comments in HttpClientAsyncBackend + protected def ensureOnAbnormal[T](effect: F[T])(finalizer: => F[Unit]): F[T] - private def execute[T](request: GenericRequest[T, R]): F[Response[T]] = { + override def send[T](request: GenericRequest[T, R]): F[Response[T]] = { + // #1987: see the comments in HttpClientAsyncBackend + val armeriaCtx = new AtomicReference[ClientRequestContext]() + ensureOnAbnormal { + monad.suspend(adjustExceptions(request)(execute(request, armeriaCtx))) + } { + monad.eval { + val ctx = armeriaCtx.get() + if (ctx != null) ctx.cancel() + } + } + } + + private def execute[T]( + request: GenericRequest[T, R], + armeriaCtx: AtomicReference[ClientRequestContext] + ): F[Response[T]] = { val captor = Clients.newContextCaptor() try { val armeriaRes = requestToArmeria(request).execute() @@ -84,6 +100,7 @@ abstract class AbstractArmeriaBackend[F[_], S <: Streams[S]]( noopCanceler } case Success(ctx) => + armeriaCtx.set(ctx) fromArmeriaResponse(request, armeriaRes, ctx) } } catch { diff --git a/armeria-backend/src/main/scala/sttp/client4/armeria/future/ArmeriaFutureBackend.scala b/armeria-backend/src/main/scala/sttp/client4/armeria/future/ArmeriaFutureBackend.scala index 3abfd526a6..b13d9f73da 100644 --- a/armeria-backend/src/main/scala/sttp/client4/armeria/future/ArmeriaFutureBackend.scala +++ b/armeria-backend/src/main/scala/sttp/client4/armeria/future/ArmeriaFutureBackend.scala @@ -32,6 +32,11 @@ private final class ArmeriaFutureBackend(client: WebClient, closeFactory: Boolea override protected def streamToPublisher(stream: streams.BinaryStream): Publisher[HttpData] = throw new UnsupportedOperationException("This backend does not support streaming") + + override protected def ensureOnAbnormal[T](effect: Future[T])(finalizer: => Future[Unit]): Future[T] = + effect.recoverWith { case e => + finalizer.recoverWith { case e2 => e.addSuppressed(e2); Future.failed(e) }.flatMap(_ => Future.failed(e)) + } } object ArmeriaFutureBackend { diff --git a/armeria-backend/zio/src/main/scala/sttp/client4/armeria/zio/ArmeriaZioBackend.scala b/armeria-backend/zio/src/main/scala/sttp/client4/armeria/zio/ArmeriaZioBackend.scala index ea4806f2c5..5d7e124cab 100644 --- a/armeria-backend/zio/src/main/scala/sttp/client4/armeria/zio/ArmeriaZioBackend.scala +++ b/armeria-backend/zio/src/main/scala/sttp/client4/armeria/zio/ArmeriaZioBackend.scala @@ -44,6 +44,11 @@ private final class ArmeriaZioBackend(runtime: Runtime[Any], client: WebClient, } override protected def compressors: List[Compressor[R]] = List(GZipZioCompressor, DeflateZioCompressor) + + override protected def ensureOnAbnormal[T](effect: Task[T])(finalizer: => Task[Unit]): Task[T] = effect.onExit { + exit => + if (exit.isSuccess) ZIO.unit else finalizer.catchAll(t => ZIO.logErrorCause("Error in finalizer", Cause.fail(t))) + }.resurrect } object ArmeriaZioBackend { diff --git a/armeria-backend/zio1/src/main/scala/sttp/client4/armeria/zio/ArmeriaZioBackend.scala b/armeria-backend/zio1/src/main/scala/sttp/client4/armeria/zio/ArmeriaZioBackend.scala index 512f2e9853..93d92a2694 100644 --- a/armeria-backend/zio1/src/main/scala/sttp/client4/armeria/zio/ArmeriaZioBackend.scala +++ b/armeria-backend/zio1/src/main/scala/sttp/client4/armeria/zio/ArmeriaZioBackend.scala @@ -37,6 +37,11 @@ private final class ArmeriaZioBackend(runtime: Runtime[Any], client: WebClient, override protected def streamToPublisher(stream: Stream[Throwable, Byte]): Publisher[HttpData] = runtime.unsafeRun(stream.mapChunks(c => Chunk.single(HttpData.wrap(c.toArray))).toPublisher) + + override protected def ensureOnAbnormal[T](effect: Task[T])(finalizer: => Task[Unit]): Task[T] = effect.onExit { + exit => + if (exit.succeeded) ZIO.unit else finalizer.catchAll(t => ZIO.effect(t.printStackTrace()).orDie) + }.resurrect } object ArmeriaZioBackend { diff --git a/core/src/main/scalajvm/sttp/client4/httpclient/HttpClientAsyncBackend.scala b/core/src/main/scalajvm/sttp/client4/httpclient/HttpClientAsyncBackend.scala index fd0fc6d527..aad7bd7efb 100644 --- a/core/src/main/scalajvm/sttp/client4/httpclient/HttpClientAsyncBackend.scala +++ b/core/src/main/scalajvm/sttp/client4/httpclient/HttpClientAsyncBackend.scala @@ -1,18 +1,28 @@ package sttp.client4.httpclient -import sttp.capabilities.{Streams, WebSockets} -import sttp.client4.internal.SttpToJavaConverters.{toJavaBiConsumer, toJavaFunction} -import sttp.client4.internal.httpclient.{AddToQueueListener, DelegatingWebSocketListener, Sequencer, WebSocketImpl} -import sttp.client4.internal.ws.{SimpleQueue, WebSocketEvent} -import sttp.client4.{GenericRequest, Response, WebSocketBackend} +import sttp.capabilities.Streams +import sttp.capabilities.WebSockets +import sttp.client4.GenericRequest +import sttp.client4.Response +import sttp.client4.WebSocketBackend import sttp.client4.compression.CompressionHandlers +import sttp.client4.internal.SttpToJavaConverters.toJavaBiConsumer +import sttp.client4.internal.SttpToJavaConverters.toJavaFunction +import sttp.client4.internal.httpclient.AddToQueueListener +import sttp.client4.internal.httpclient.DelegatingWebSocketListener +import sttp.client4.internal.httpclient.Sequencer +import sttp.client4.internal.httpclient.WebSocketImpl +import sttp.client4.internal.ws.SimpleQueue +import sttp.client4.internal.ws.WebSocketEvent import sttp.model.StatusCode +import sttp.monad.Canceler +import sttp.monad.MonadAsyncError import sttp.monad.syntax._ -import sttp.monad.{Canceler, MonadAsyncError} import java.net.http._ import java.util.concurrent.CompletionException import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.AtomicReference import java.util.function.BiConsumer /** @tparam F @@ -36,46 +46,82 @@ abstract class HttpClientAsyncBackend[F[_], S <: Streams[S], BH, B]( protected def createSimpleQueue[T]: F[SimpleQueue[F, T]] protected def createSequencer: F[Sequencer[F]] protected def createBodyHandler: HttpResponse.BodyHandler[BH] - protected def bodyHandlerBodyToBody(p: BH): B + protected def lowLevelBodyToBody(p: BH): B + protected def cancelLowLevelBody(p: BH): Unit protected def bodyToLimitedBody(b: B, limit: Long): B protected def emptyBody(): B - override def sendRegular[T](request: GenericRequest[T, R]): F[Response[T]] = - monad.flatMap(convertRequest(request)) { convertedRequest => - val jRequest = customizeRequest(convertedRequest) + /** A variant of [[MonadAsyncError.ensure]] which runs the finalizer only when the effect finished abnormally + * (exception thrown, failed effect, cancellation/interruption). + * + * This is used to release any resources allocated by HttpClient after the request is sent, but before the response + * is consumed. This is done only in case of failure, as in case of success the body is either fully consumed as + * specified in the response description, or when an `...Unsafe` response description is used, it's up to the user to + * consume it. + * + * Any exceptions that occur while running `finalizer` should be added as suppressed or logged, not impacting the + * outcome of `effect`. If possible, `finalizer` should not be run in a cancellable way. + */ + protected def ensureOnAbnormal[T](effect: F[T])(finalizer: => F[Unit]): F[T] - monad.flatten(monad.async[F[Response[T]]] { cb => - def success(r: F[Response[T]]): Unit = cb(Right(r)) - def error(t: Throwable): Unit = cb(Left(t)) - var cf = client.sendAsync(jRequest, createBodyHandler) + override def sendRegular[T](request: GenericRequest[T, R]): F[Response[T]] = { + monad + .flatMap(convertRequest(request)) { convertedRequest => + val jRequest = customizeRequest(convertedRequest) - val consumer = toJavaBiConsumer { (t: HttpResponse[BH], u: Throwable) => - if (t != null) { - // sometimes body returned by HttpClient can be null, we handle this by returning empty body to prevent NPE - val body = Option(t.body()) - .map(bodyHandlerBodyToBody) - .getOrElse(emptyBody()) + // #1987: whenever the low-level body is acquired, we need to ensure any resources associated with it are + // released. This includes proper handling of cancellation of the effect right after sending, but before + // consuming the body (which happens only when the effect returned by `readResponse` is evaluated). - val limitedBody = request.options.maxResponseBodyLength.fold(body)(bodyToLimitedBody(body, _)) + // storing the low-level body (usually an `InputStream` or a `Publisher`) so that it can be cancelled; + // cancellation might happen during request sending, so it's not always set + val lowLevelBody = new AtomicReference[BH]() + ensureOnAbnormal { + monad + .async[HttpResponse[BH]] { cb => + def success(r: HttpResponse[BH]): Unit = cb(Right(r)) + def error(t: Throwable): Unit = cb(Left(t)) + val cf = client.sendAsync(jRequest, createBodyHandler) - try success(readResponse(t, Left(limitedBody), request)) - catch { - case e: Exception => error(e) + cf.whenComplete(toJavaBiConsumer { (t: HttpResponse[BH], u: Throwable) => + if (t != null) { + lowLevelBody.set(t.body()) + success(t) + } + if (u != null) { + error(u) + } + }) + + // contrary to what the JavaDoc says, this actually cancels the request, even if it's in progress + // however, the request will be cancelled asynchronously, and there's no way of waiting for cancellation to + // complete; that's not ideal (both ZIO and cats-effect contracts require that the effect completes only + // when cancellation is complete), but it's the best we can do + // see: https://bugs.openjdk.org/browse/JDK-8245462 + Canceler { () => + val _ = cf.cancel(true) + } } - } - if (u != null) { - error(u) - } - } + .flatMap { jResponse => + // sometimes body returned by HttpClient can be null, we handle this by returning empty body to prevent NPE + val body = Option(jResponse.body()) + .map(lowLevelBodyToBody) + .getOrElse(emptyBody()) - cf = client.executor().orElse(null) match { - case null => cf.whenComplete(consumer) - case e => cf.whenCompleteAsync(consumer, e) // using the provided executor to further process the body - } + val limitedBody = request.options.maxResponseBodyLength.fold(body)(bodyToLimitedBody(body, _)) - Canceler(() => cf.cancel(true)) - }) - } + readResponse(jResponse, Left(limitedBody), request) + } + } { + monad.eval { + // the request might have been interrupted during sending (no publisher is available then), or any time + // after that, including right after the sending effect completed, but before the response was read + val llb = lowLevelBody.get() + if (llb != null) cancelLowLevelBody(llb) + } + } + } + } override def sendWebSocket[T](request: GenericRequest[T, R]): F[Response[T]] = (for { @@ -97,47 +143,57 @@ abstract class HttpClientAsyncBackend[F[_], S <: Streams[S], BH, B]( sequencer: Sequencer[F] ): F[Response[T]] = { val isOpen: AtomicBoolean = new AtomicBoolean(false) - monad.flatten(monad.async[F[Response[T]]] { cb => - def success(r: F[Response[T]]): Unit = cb(Right(r)) - def error(t: Throwable): Unit = cb(Left(t)) - - val listener = new DelegatingWebSocketListener( - new AddToQueueListener(queue, isOpen), - ws => { - val webSocket = new WebSocketImpl[F]( - ws, - queue, - isOpen, - sequencer, - monad, - cf => - monad.async { cb => - cf.whenComplete(new BiConsumer[WebSocket, Throwable] { - override def accept(t: WebSocket, error: Throwable): Unit = - if (error != null) { - cb(Left(error)) - } else { - cb(Right(())) - } - }) - Canceler { () => - cf.cancel(true) - () + + // see sendRegular for explanation + val lowLevelWS = new AtomicReference[WebSocket]() + ensureOnAbnormal { + monad.flatten(monad.async[F[Response[T]]] { cb => + def success(r: F[Response[T]]): Unit = cb(Right(r)) + def error(t: Throwable): Unit = cb(Left(t)) + + val listener = new DelegatingWebSocketListener( + new AddToQueueListener(queue, isOpen), + ws => { + lowLevelWS.set(ws) + val webSocket = new WebSocketImpl[F]( + ws, + queue, + isOpen, + sequencer, + monad, + cf => + monad.async { cb => + cf.whenComplete(new BiConsumer[WebSocket, Throwable] { + override def accept(t: WebSocket, error: Throwable): Unit = + if (error != null) { + cb(Left(error)) + } else { + cb(Right(())) + } + }) + Canceler { () => + val _ = cf.cancel(true) + } } - } - ) - val baseResponse = Response((), StatusCode.SwitchingProtocols, "", Nil, Nil, request.onlyMetadata) - val body = bodyFromHttpClient(Right(webSocket), request.response, baseResponse) - success(body.map(b => baseResponse.copy(body = b))) - }, - error - ) - - val cf = prepareWebSocketBuilder(request, client) - .buildAsync(request.uri.toJavaUri, listener) - .thenApply[Unit](toJavaFunction((_: WebSocket) => ())) - .exceptionally(toJavaFunction((t: Throwable) => cb(Left(t)))) - Canceler(() => cf.cancel(true)) - }) + ) + val baseResponse = Response((), StatusCode.SwitchingProtocols, "", Nil, Nil, request.onlyMetadata) + val body = bodyFromHttpClient(Right(webSocket), request.response, baseResponse) + success(body.map(b => baseResponse.copy(body = b))) + }, + error + ) + + val cf = prepareWebSocketBuilder(request, client) + .buildAsync(request.uri.toJavaUri, listener) + .thenApply[Unit](toJavaFunction((_: WebSocket) => ())) + .exceptionally(toJavaFunction((t: Throwable) => cb(Left(t)))) + Canceler { () => + val _ = cf.cancel(true) + } + }) + } { + val llws = lowLevelWS.get() + if (llws != null) monad.eval(llws.abort()) else monad.unit(()) + } } } diff --git a/core/src/main/scalajvm/sttp/client4/httpclient/HttpClientFutureBackend.scala b/core/src/main/scalajvm/sttp/client4/httpclient/HttpClientFutureBackend.scala index bac950441e..7b576dcac7 100644 --- a/core/src/main/scalajvm/sttp/client4/httpclient/HttpClientFutureBackend.scala +++ b/core/src/main/scalajvm/sttp/client4/httpclient/HttpClientFutureBackend.scala @@ -1,23 +1,32 @@ package sttp.client4.httpclient +import sttp.client4.BackendOptions +import sttp.client4.WebSocketBackend +import sttp.client4.compression.CompressionHandlers +import sttp.client4.compression.Compressor +import sttp.client4.compression.Decompressor +import sttp.client4.internal.NoStreams +import sttp.client4.internal.emptyInputStream import sttp.client4.internal.httpclient._ -import sttp.client4.internal.ws.{FutureSimpleQueue, SimpleQueue} -import sttp.client4.internal.{emptyInputStream, NoStreams} +import sttp.client4.internal.ws.FutureSimpleQueue +import sttp.client4.internal.ws.SimpleQueue import sttp.client4.testing.WebSocketBackendStub -import sttp.client4.{wrappers, BackendOptions, WebSocketBackend} -import sttp.monad.{FutureMonad, MonadError} -import sttp.ws.{WebSocket, WebSocketFrame} +import sttp.client4.wrappers +import sttp.monad.FutureMonad +import sttp.monad.MonadError +import sttp.tapir.server.jdkhttp.internal.FailingLimitedInputStream +import sttp.ws.WebSocket +import sttp.ws.WebSocketFrame import java.io.InputStream +import java.net.http.HttpClient +import java.net.http.HttpRequest import java.net.http.HttpRequest.BodyPublisher +import java.net.http.HttpResponse import java.net.http.HttpResponse.BodyHandlers -import java.net.http.{HttpClient, HttpRequest, HttpResponse} import java.util.concurrent.Executor -import scala.concurrent.{ExecutionContext, Future} -import sttp.client4.compression.Compressor -import sttp.client4.compression.CompressionHandlers -import sttp.client4.compression.Decompressor -import sttp.tapir.server.jdkhttp.internal.FailingLimitedInputStream +import scala.concurrent.ExecutionContext +import scala.concurrent.Future class HttpClientFutureBackend private ( client: HttpClient, @@ -61,12 +70,19 @@ class HttpClientFutureBackend private ( override protected def createBodyHandler: HttpResponse.BodyHandler[InputStream] = BodyHandlers.ofInputStream() - override protected def bodyHandlerBodyToBody(p: InputStream): InputStream = p + override protected def lowLevelBodyToBody(p: InputStream): InputStream = p + + override protected def cancelLowLevelBody(p: InputStream): Unit = p.close() override protected def emptyBody(): InputStream = emptyInputStream() override protected def bodyToLimitedBody(b: InputStream, limit: Long): InputStream = new FailingLimitedInputStream(b, limit) + + override protected def ensureOnAbnormal[T](effect: Future[T])(finalizer: => Future[Unit]): Future[T] = + effect.recoverWith { case e => + finalizer.recoverWith { case e2 => e.addSuppressed(e2); Future.failed(e) }.flatMap(_ => Future.failed(e)) + } } object HttpClientFutureBackend { diff --git a/core/src/main/scalajvm/sttp/client4/httpclient/HttpClientSyncBackend.scala b/core/src/main/scalajvm/sttp/client4/httpclient/HttpClientSyncBackend.scala index f76045e8a8..52eb78415d 100644 --- a/core/src/main/scalajvm/sttp/client4/httpclient/HttpClientSyncBackend.scala +++ b/core/src/main/scalajvm/sttp/client4/httpclient/HttpClientSyncBackend.scala @@ -21,6 +21,7 @@ import sttp.client4.compression.Compressor import sttp.client4.compression.CompressionHandlers import sttp.client4.compression.Decompressor import sttp.tapir.server.jdkhttp.internal.FailingLimitedInputStream +import java.util.concurrent.atomic.AtomicReference class HttpClientSyncBackend private ( client: HttpClient, @@ -39,15 +40,30 @@ class HttpClientSyncBackend private ( override protected def sendRegular[T](request: GenericRequest[T, R]): Response[T] = { val jRequest = customizeRequest(convertRequest(request)) val response = client.send(jRequest, BodyHandlers.ofInputStream()) - val body = response.body() - val limitedBody = request.options.maxResponseBodyLength.fold(body)(new FailingLimitedInputStream(body, _)) - readResponse(response, Left(limitedBody), request) + try { + val body = response.body() + val limitedBody = request.options.maxResponseBodyLength.fold(body)(new FailingLimitedInputStream(body, _)) + readResponse(response, Left(limitedBody), request) + } catch { + case e: Throwable => + // ensuring that the response body always gets closed + // in case of success the body is either already consumed, or an `...Unsafe` response description is used and + // it's up to the user to consume it + try { + response.body().close() + } catch { + case e2: Throwable => e.addSuppressed(e2) + } + throw e + } } override protected def sendWebSocket[T](request: GenericRequest[T, R]): Response[T] = { val queue = new SyncQueue[WebSocketEvent](None) val sequencer = new IdSequencer - try sendWebSocket(request, queue, sequencer) + // see HttpClientAsyncBackend.sendRegular for explanation + val lowLevelWS = new AtomicReference[java.net.http.WebSocket]() + try sendWebSocket(request, queue, sequencer, lowLevelWS) catch { case e: CompletionException if e.getCause.isInstanceOf[WebSocketHandshakeException] => readResponse( @@ -55,13 +71,22 @@ class HttpClientSyncBackend private ( Left(emptyInputStream()), request ) + case e: Throwable => + try { + val llws = lowLevelWS.get() + if (llws != null) llws.abort() + } catch { + case e2: Throwable => e.addSuppressed(e2) + } + throw e } } private def sendWebSocket[T]( request: GenericRequest[T, R], queue: SimpleQueue[Identity, WebSocketEvent], - sequencer: Sequencer[Identity] + sequencer: Sequencer[Identity], + lowLevelWS: AtomicReference[java.net.http.WebSocket] ): Response[T] = { val isOpen: AtomicBoolean = new AtomicBoolean(false) val responseCell = new ArrayBlockingQueue[Either[Throwable, () => Response[T]]](1) @@ -72,6 +97,7 @@ class HttpClientSyncBackend private ( val listener = new DelegatingWebSocketListener( new AddToQueueListener(queue, isOpen), ws => { + lowLevelWS.set(ws) val webSocket = new WebSocketImpl[Identity](ws, queue, isOpen, sequencer, monad, cf => { val _ = cf.get() }) val baseResponse = Response((), StatusCode.SwitchingProtocols, "", Nil, Nil, request.onlyMetadata) val body = () => bodyFromHttpClient(Right(webSocket), request.response, baseResponse) diff --git a/core/src/main/scalajvm/sttp/client4/internal/httpclient/package.scala b/core/src/main/scalajvm/sttp/client4/internal/httpclient/package.scala new file mode 100644 index 0000000000..fb8dc2ac5a --- /dev/null +++ b/core/src/main/scalajvm/sttp/client4/internal/httpclient/package.scala @@ -0,0 +1,13 @@ +package sttp.client4.internal + +import java.util.concurrent.Flow.Publisher + +package object httpclient { + private[client4] def cancelPublisher[T](p: Publisher[T]): Unit = + p.subscribe(new java.util.concurrent.Flow.Subscriber[T] { + override def onSubscribe(s: java.util.concurrent.Flow.Subscription): Unit = s.cancel() + override def onNext(t: T): Unit = () + override def onError(t: Throwable): Unit = () + override def onComplete(): Unit = () + }) +} diff --git a/core/src/test/scala/sttp/client4/testing/HttpTest.scala b/core/src/test/scala/sttp/client4/testing/HttpTest.scala index 2e356ec7fb..8b690bee37 100644 --- a/core/src/test/scala/sttp/client4/testing/HttpTest.scala +++ b/core/src/test/scala/sttp/client4/testing/HttpTest.scala @@ -716,7 +716,7 @@ trait HttpTest[F[_]] if (supportsCancellation) { "cancel" - { - "a request in progress" in { + "a request before any response is received" in { implicit val monad: MonadError[F] = backend.monad import sttp.monad.syntax._ @@ -737,6 +737,27 @@ trait HttpTest[F[_]] } ) } + + "a request when the response is produced slowly" in { + implicit val monad: MonadError[F] = backend.monad + import sttp.monad.syntax._ + + val req = basicRequest + .get(uri"$endpoint/streaming/slow") + .response(asString) + + val now = monad.eval(System.currentTimeMillis()) + + convertToFuture.toFuture( + now.flatMap { start => + timeoutToNone(req.send(backend), 100) + .map { r => + (System.currentTimeMillis() - start) should be < 2000L + r shouldBe None + } + } + ) + } } } diff --git a/effects/cats/src/main/scalajvm/sttp/client4/httpclient/cats/HttpClientCatsBackend.scala b/effects/cats/src/main/scalajvm/sttp/client4/httpclient/cats/HttpClientCatsBackend.scala index b62e4426ed..274f85a528 100644 --- a/effects/cats/src/main/scalajvm/sttp/client4/httpclient/cats/HttpClientCatsBackend.scala +++ b/effects/cats/src/main/scalajvm/sttp/client4/httpclient/cats/HttpClientCatsBackend.scala @@ -1,26 +1,37 @@ package sttp.client4.httpclient.cats -import cats.effect.kernel.{Async, Resource, Sync} -import cats.effect.std.{Dispatcher, Queue} -import cats.implicits.{toFlatMapOps, toFunctorOps} -import sttp.client4.httpclient.{HttpClientAsyncBackend, HttpClientBackend} +import cats.effect.kernel.Async +import cats.effect.kernel.Resource +import cats.effect.kernel.Sync +import cats.effect.std.Dispatcher +import cats.effect.std.Queue +import cats.implicits.toFlatMapOps +import cats.implicits.toFunctorOps +import sttp.client4.BackendOptions +import sttp.client4.WebSocketBackend +import sttp.client4.compression.CompressionHandlers +import sttp.client4.compression.Compressor +import sttp.client4.compression.Decompressor +import sttp.client4.httpclient.HttpClientAsyncBackend +import sttp.client4.httpclient.HttpClientBackend import sttp.client4.impl.cats.CatsMonadAsyncError +import sttp.client4.internal.NoStreams +import sttp.client4.internal.emptyInputStream import sttp.client4.internal.httpclient._ import sttp.client4.internal.ws.SimpleQueue -import sttp.client4.internal.{emptyInputStream, NoStreams} import sttp.client4.testing.WebSocketBackendStub -import sttp.client4.{wrappers, BackendOptions, WebSocketBackend} +import sttp.client4.wrappers import sttp.monad.MonadError -import sttp.ws.{WebSocket, WebSocketFrame} +import sttp.tapir.server.jdkhttp.internal.FailingLimitedInputStream +import sttp.ws.WebSocket +import sttp.ws.WebSocketFrame import java.io.InputStream +import java.net.http.HttpClient +import java.net.http.HttpRequest import java.net.http.HttpRequest.BodyPublisher +import java.net.http.HttpResponse import java.net.http.HttpResponse.BodyHandlers -import java.net.http.{HttpClient, HttpRequest, HttpResponse} -import sttp.client4.compression.CompressionHandlers -import sttp.client4.compression.Compressor -import sttp.client4.compression.Decompressor -import sttp.tapir.server.jdkhttp.internal.FailingLimitedInputStream class HttpClientCatsBackend[F[_]: Async] private ( client: HttpClient, @@ -66,7 +77,15 @@ class HttpClientCatsBackend[F[_]: Async] private ( override protected def createBodyHandler: HttpResponse.BodyHandler[InputStream] = BodyHandlers.ofInputStream() - override protected def bodyHandlerBodyToBody(p: InputStream): InputStream = p + override protected def lowLevelBodyToBody(p: InputStream): InputStream = p + + override protected def cancelLowLevelBody(p: InputStream): Unit = p.close() + + override protected def ensureOnAbnormal[T](effect: F[T])(finalizer: => F[Unit]): F[T] = + Async[F].guaranteeCase(effect) { outcome => + if (outcome.isSuccess) Async[F].unit + else Async[F].onError(finalizer) { case t => Async[F].delay(t.printStackTrace()) } + } override protected def emptyBody(): InputStream = emptyInputStream() diff --git a/effects/fs2-ce2/src/main/scalajvm/sttp/client4/httpclient/fs2/HttpClientFs2Backend.scala b/effects/fs2-ce2/src/main/scalajvm/sttp/client4/httpclient/fs2/HttpClientFs2Backend.scala index 0c9d73709a..923158fcee 100644 --- a/effects/fs2-ce2/src/main/scalajvm/sttp/client4/httpclient/fs2/HttpClientFs2Backend.scala +++ b/effects/fs2-ce2/src/main/scalajvm/sttp/client4/httpclient/fs2/HttpClientFs2Backend.scala @@ -1,35 +1,43 @@ package sttp.client4.httpclient.fs2 -import java.net.http.HttpRequest.BodyPublishers -import java.net.http.{HttpClient, HttpRequest, HttpResponse} -import java.nio.ByteBuffer -import java.{util => ju} import cats.effect._ import cats.effect.implicits._ import cats.implicits._ -import fs2.{Chunk, Stream, Pull} +import fs2.Chunk +import fs2.Pull +import fs2.Stream import fs2.concurrent.InspectableQueue import fs2.interop.reactivestreams._ import org.reactivestreams.FlowAdapters +import sttp.capabilities.StreamMaxLengthExceededException import sttp.capabilities.fs2.Fs2Streams -import sttp.client4.internal.httpclient.{BodyFromHttpClient, BodyToHttpClient, Sequencer} +import sttp.client4._ +import sttp.client4.compression.CompressionHandlers +import sttp.client4.compression.Compressor +import sttp.client4.httpclient.HttpClientAsyncBackend +import sttp.client4.httpclient.HttpClientBackend import sttp.client4.impl.cats.implicits._ +import sttp.client4.impl.fs2.DeflateFs2Decompressor import sttp.client4.impl.fs2.Fs2SimpleQueue +import sttp.client4.impl.fs2.GZipFs2Decompressor +import sttp.client4.internal.httpclient.BodyFromHttpClient +import sttp.client4.internal.httpclient.BodyToHttpClient +import sttp.client4.internal.httpclient.Sequencer +import sttp.client4.internal.httpclient.cancelPublisher import sttp.client4.internal.ws.SimpleQueue import sttp.client4.testing.WebSocketStreamBackendStub -import sttp.client4._ -import sttp.client4.httpclient.{HttpClientAsyncBackend, HttpClientBackend} import sttp.client4.wrappers.FollowRedirectsBackend import sttp.monad.MonadError +import java.net.http.HttpClient +import java.net.http.HttpRequest +import java.net.http.HttpRequest.BodyPublishers +import java.net.http.HttpResponse import java.net.http.HttpResponse.BodyHandlers +import java.nio.ByteBuffer +import java.{util => ju} import java.util.concurrent.Flow.Publisher import scala.collection.JavaConverters._ -import sttp.client4.compression.CompressionHandlers -import sttp.client4.compression.Compressor -import sttp.client4.impl.fs2.GZipFs2Decompressor -import sttp.client4.impl.fs2.DeflateFs2Decompressor -import sttp.capabilities.StreamMaxLengthExceededException class HttpClientFs2Backend[F[_]: ConcurrentEffect: ContextShift] private ( client: HttpClient, @@ -75,12 +83,20 @@ class HttpClientFs2Backend[F[_]: ConcurrentEffect: ContextShift] private ( override protected def createSequencer: F[Sequencer[F]] = Fs2Sequencer.create - override protected def bodyHandlerBodyToBody(p: Publisher[ju.List[ByteBuffer]]): Stream[F, Byte] = + override protected def lowLevelBodyToBody(p: Publisher[ju.List[ByteBuffer]]): Stream[F, Byte] = FlowAdapters .toPublisher(p) .toStream[F] .flatMap(data => Stream.emits(data.asScala.map(Chunk.byteBuffer)).flatMap(Stream.chunk)) + override protected def cancelLowLevelBody(p: Publisher[ju.List[ByteBuffer]]): Unit = cancelPublisher(p) + + override protected def ensureOnAbnormal[T](effect: F[T])(finalizer: => F[Unit]): F[T] = + ConcurrentEffect[F].guaranteeCase(effect) { exitCase => + if (exitCase == ExitCase.Completed) ConcurrentEffect[F].unit + else ConcurrentEffect[F].onError(finalizer) { case t => ConcurrentEffect[F].delay(t.printStackTrace()) } + } + override protected def emptyBody(): Stream[F, Byte] = Stream.empty override protected def bodyToLimitedBody(b: Stream[F, Byte], limit: Long): Stream[F, Byte] = limitBytes(b, limit) diff --git a/effects/fs2/src/main/scalajvm/sttp/client4/httpclient/fs2/HttpClientFs2Backend.scala b/effects/fs2/src/main/scalajvm/sttp/client4/httpclient/fs2/HttpClientFs2Backend.scala index 8eecddd2a3..d5d0b5ccc2 100644 --- a/effects/fs2/src/main/scalajvm/sttp/client4/httpclient/fs2/HttpClientFs2Backend.scala +++ b/effects/fs2/src/main/scalajvm/sttp/client4/httpclient/fs2/HttpClientFs2Backend.scala @@ -1,34 +1,46 @@ package sttp.client4.httpclient.fs2 -import java.net.http.HttpRequest.BodyPublishers -import java.net.http.{HttpClient, HttpRequest, HttpResponse} -import java.nio.ByteBuffer -import java.util import cats.effect.kernel._ -import cats.effect.std.{Dispatcher, Queue} +import cats.effect.std.Dispatcher +import cats.effect.std.Queue import cats.implicits._ -import fs2.interop.reactivestreams.{PublisherOps, StreamUnicastPublisher} -import fs2.{Chunk, Stream} +import fs2.Chunk +import fs2.Stream +import fs2.compression.Compression +import fs2.interop.reactivestreams.PublisherOps +import fs2.interop.reactivestreams.StreamUnicastPublisher import org.reactivestreams.FlowAdapters import sttp.capabilities.fs2.Fs2Streams -import sttp.client4.internal.httpclient.{BodyFromHttpClient, BodyToHttpClient, Sequencer} +import sttp.client4._ +import sttp.client4.compression.CompressionHandlers +import sttp.client4.compression.Compressor +import sttp.client4.httpclient.HttpClientAsyncBackend +import sttp.client4.httpclient.HttpClientBackend import sttp.client4.impl.cats.implicits._ +import sttp.client4.impl.fs2.DeflateFs2Compressor +import sttp.client4.impl.fs2.DeflateFs2Decompressor import sttp.client4.impl.fs2.Fs2SimpleQueue +import sttp.client4.impl.fs2.GZipFs2Compressor +import sttp.client4.impl.fs2.GZipFs2Decompressor +import sttp.client4.internal.httpclient.BodyFromHttpClient +import sttp.client4.internal.httpclient.BodyToHttpClient +import sttp.client4.internal.httpclient.Sequencer +import sttp.client4.internal.httpclient.cancelPublisher import sttp.client4.internal.ws.SimpleQueue import sttp.client4.testing.WebSocketStreamBackendStub -import sttp.client4._ -import sttp.client4.httpclient.{HttpClientAsyncBackend, HttpClientBackend} import sttp.client4.wrappers.FollowRedirectsBackend import sttp.monad.MonadError +import java.net.http.HttpClient +import java.net.http.HttpRequest +import java.net.http.HttpRequest.BodyPublishers +import java.net.http.HttpResponse import java.net.http.HttpResponse.BodyHandlers -import java.util.concurrent.Flow.Publisher +import java.nio.ByteBuffer +import java.util import java.{util => ju} +import java.util.concurrent.Flow.Publisher import scala.collection.JavaConverters._ -import sttp.client4.compression.Compressor -import sttp.client4.impl.fs2.{DeflateFs2Compressor, DeflateFs2Decompressor, GZipFs2Compressor, GZipFs2Decompressor} -import sttp.client4.compression.CompressionHandlers -import fs2.compression.Compression class HttpClientFs2Backend[F[_]: Async] private ( client: HttpClient, @@ -75,12 +87,20 @@ class HttpClientFs2Backend[F[_]: Async] private ( override protected def createSequencer: F[Sequencer[F]] = Fs2Sequencer.create - override protected def bodyHandlerBodyToBody(p: Publisher[util.List[ByteBuffer]]): Stream[F, Byte] = + override protected def lowLevelBodyToBody(p: Publisher[util.List[ByteBuffer]]): Stream[F, Byte] = FlowAdapters .toPublisher(p) .toStream[F] .flatMap(data => Stream.emits(data.asScala.map(Chunk.byteBuffer)).flatMap(Stream.chunk)) + override protected def cancelLowLevelBody(p: Publisher[ju.List[ByteBuffer]]): Unit = cancelPublisher(p) + + override protected def ensureOnAbnormal[T](effect: F[T])(finalizer: => F[Unit]): F[T] = + Async[F].guaranteeCase(effect) { outcome => + if (outcome.isSuccess) Async[F].unit + else Async[F].onError(finalizer) { case t => Async[F].delay(t.printStackTrace()) } + } + override protected def emptyBody(): Stream[F, Byte] = Stream.empty override protected def bodyToLimitedBody(b: Stream[F, Byte], limit: Long): Stream[F, Byte] = diff --git a/effects/monix/src/main/scalajvm/sttp/client4/httpclient/monix/HttpClientMonixBackend.scala b/effects/monix/src/main/scalajvm/sttp/client4/httpclient/monix/HttpClientMonixBackend.scala index 39ab0ab5ee..666fc2ea9c 100644 --- a/effects/monix/src/main/scalajvm/sttp/client4/httpclient/monix/HttpClientMonixBackend.scala +++ b/effects/monix/src/main/scalajvm/sttp/client4/httpclient/monix/HttpClientMonixBackend.scala @@ -1,30 +1,40 @@ package sttp.client4.httpclient.monix +import cats.effect.ExitCase import cats.effect.Resource import monix.eval.Task import monix.execution.Scheduler import monix.reactive.Observable import org.reactivestreams.FlowAdapters +import sttp.capabilities.StreamMaxLengthExceededException import sttp.capabilities.monix.MonixStreams -import sttp.client4.httpclient.{HttpClientAsyncBackend, HttpClientBackend} -import sttp.client4.impl.monix.{MonixSimpleQueue, TaskMonadAsyncError} +import sttp.client4.BackendOptions +import sttp.client4.WebSocketStreamBackend +import sttp.client4.compression.CompressionHandlers +import sttp.client4.compression.Compressor +import sttp.client4.httpclient.HttpClientAsyncBackend +import sttp.client4.httpclient.HttpClientBackend +import sttp.client4.impl.monix.MonixSimpleQueue +import sttp.client4.impl.monix.TaskMonadAsyncError import sttp.client4.internal._ -import sttp.client4.internal.httpclient.{BodyFromHttpClient, BodyToHttpClient, Sequencer} +import sttp.client4.internal.httpclient.BodyFromHttpClient +import sttp.client4.internal.httpclient.BodyToHttpClient +import sttp.client4.internal.httpclient.Sequencer +import sttp.client4.internal.httpclient.cancelPublisher import sttp.client4.internal.ws.SimpleQueue import sttp.client4.testing.WebSocketStreamBackendStub -import sttp.client4.{wrappers, BackendOptions, WebSocketStreamBackend} +import sttp.client4.wrappers import sttp.monad.MonadError +import java.net.http.HttpClient +import java.net.http.HttpRequest import java.net.http.HttpRequest.BodyPublishers +import java.net.http.HttpResponse import java.net.http.HttpResponse.BodyHandlers -import java.net.http.{HttpClient, HttpRequest, HttpResponse} import java.nio.ByteBuffer -import java.util.concurrent.Flow.Publisher import java.{util => ju} +import java.util.concurrent.Flow.Publisher import scala.collection.JavaConverters._ -import sttp.client4.compression.CompressionHandlers -import sttp.client4.compression.Compressor -import sttp.capabilities.StreamMaxLengthExceededException class HttpClientMonixBackend private ( client: HttpClient, @@ -68,12 +78,19 @@ class HttpClientMonixBackend private ( override protected def createBodyHandler: HttpResponse.BodyHandler[Publisher[ju.List[ByteBuffer]]] = BodyHandlers.ofPublisher() - override protected def bodyHandlerBodyToBody(p: Publisher[ju.List[ByteBuffer]]): Observable[Array[Byte]] = + override protected def lowLevelBodyToBody(p: Publisher[ju.List[ByteBuffer]]): Observable[Array[Byte]] = Observable .fromReactivePublisher(FlowAdapters.toPublisher(p)) .flatMapIterable(_.asScala.toList) .map(_.safeRead()) + override protected def cancelLowLevelBody(p: Publisher[ju.List[ByteBuffer]]): Unit = cancelPublisher(p) + + override protected def ensureOnAbnormal[T](effect: Task[T])(finalizer: => Task[Unit]): Task[T] = + effect.guaranteeCase { exit => + if (exit == ExitCase.Completed) Task.unit else finalizer.onErrorHandleWith(t => Task.eval(t.printStackTrace())) + } + override protected def emptyBody(): Observable[Array[Byte]] = Observable.empty override protected def bodyToLimitedBody(b: Observable[Array[Byte]], limit: Long): Observable[Array[Byte]] = { diff --git a/effects/zio/src/main/scalajvm/sttp/client4/httpclient/zio/HttpClientZioBackend.scala b/effects/zio/src/main/scalajvm/sttp/client4/httpclient/zio/HttpClientZioBackend.scala index e32de7007f..0d5e0a3616 100644 --- a/effects/zio/src/main/scalajvm/sttp/client4/httpclient/zio/HttpClientZioBackend.scala +++ b/effects/zio/src/main/scalajvm/sttp/client4/httpclient/zio/HttpClientZioBackend.scala @@ -3,28 +3,43 @@ package sttp.client4.httpclient.zio import _root_.zio.interop.reactivestreams._ import org.reactivestreams.FlowAdapters import sttp.capabilities.zio.ZioStreams -import sttp.client4.httpclient.{HttpClientAsyncBackend, HttpClientBackend} -import sttp.client4.impl.zio.{RIOMonadAsyncError, ZioSimpleQueue} +import sttp.client4.BackendOptions +import sttp.client4.GenericRequest +import sttp.client4.Response +import sttp.client4.WebSocketStreamBackend +import sttp.client4.compression.CompressionHandlers +import sttp.client4.compression.Compressor +import sttp.client4.httpclient.HttpClientAsyncBackend +import sttp.client4.httpclient.HttpClientBackend +import sttp.client4.impl.zio.DeflateZioCompressor +import sttp.client4.impl.zio.DeflateZioDecompressor +import sttp.client4.impl.zio.GZipZioCompressor +import sttp.client4.impl.zio.GZipZioDecompressor +import sttp.client4.impl.zio.RIOMonadAsyncError +import sttp.client4.impl.zio.ZioSimpleQueue import sttp.client4.internal._ -import sttp.client4.internal.httpclient.{BodyFromHttpClient, BodyToHttpClient, Sequencer} +import sttp.client4.internal.httpclient.BodyFromHttpClient +import sttp.client4.internal.httpclient.BodyToHttpClient +import sttp.client4.internal.httpclient.Sequencer +import sttp.client4.internal.httpclient.cancelPublisher import sttp.client4.internal.ws.SimpleQueue import sttp.client4.testing.WebSocketStreamBackendStub -import sttp.client4.{wrappers, BackendOptions, GenericRequest, Response, WebSocketStreamBackend} +import sttp.client4.wrappers import sttp.monad.MonadError -import zio.Chunk.ByteArray import zio._ +import zio.Chunk.ByteArray import zio.stream.ZStream -import java.net.http.HttpRequest.{BodyPublisher, BodyPublishers} +import java.net.http.HttpClient +import java.net.http.HttpRequest +import java.net.http.HttpRequest.BodyPublisher +import java.net.http.HttpRequest.BodyPublishers +import java.net.http.HttpResponse import java.net.http.HttpResponse.BodyHandlers -import java.net.http.{HttpClient, HttpRequest, HttpResponse} import java.nio.ByteBuffer import java.util -import java.util.concurrent.Flow.Publisher import java.{util => ju} -import sttp.client4.compression.Compressor -import sttp.client4.impl.zio.{DeflateZioCompressor, DeflateZioDecompressor, GZipZioCompressor, GZipZioDecompressor} -import sttp.client4.compression.CompressionHandlers +import java.util.concurrent.Flow.Publisher class HttpClientZioBackend private ( client: HttpClient, @@ -52,12 +67,19 @@ class HttpClientZioBackend private ( override protected def emptyBody(): ZStream[Any, Throwable, Byte] = ZStream.empty - override protected def bodyHandlerBodyToBody(p: Publisher[util.List[ByteBuffer]]): ZStream[Any, Throwable, Byte] = + override protected def lowLevelBodyToBody(p: Publisher[util.List[ByteBuffer]]): ZStream[Any, Throwable, Byte] = FlowAdapters.toPublisher(p).toZIOStream().mapConcatChunk { list => val a = Chunk.fromJavaIterable(list).flatMap(_.safeRead()).toArray ByteArray(a, 0, a.length) } + override protected def cancelLowLevelBody(p: Publisher[ju.List[ByteBuffer]]): Unit = cancelPublisher(p) + + override protected def ensureOnAbnormal[T](effect: Task[T])(finalizer: => Task[Unit]): Task[T] = effect.onExit { + exit => + if (exit.isSuccess) ZIO.unit else finalizer.catchAll(t => ZIO.logErrorCause("Error in finalizer", Cause.fail(t))) + }.resurrect + override protected val bodyToHttpClient: BodyToHttpClient[Task, ZioStreams, R] = new BodyToHttpClient[Task, ZioStreams, R] { override val streams: ZioStreams = ZioStreams diff --git a/effects/zio1/src/main/scalajvm/sttp/client4/httpclient/zio/HttpClientZioBackend.scala b/effects/zio1/src/main/scalajvm/sttp/client4/httpclient/zio/HttpClientZioBackend.scala index a51257aa3a..ddd0443e09 100644 --- a/effects/zio1/src/main/scalajvm/sttp/client4/httpclient/zio/HttpClientZioBackend.scala +++ b/effects/zio1/src/main/scalajvm/sttp/client4/httpclient/zio/HttpClientZioBackend.scala @@ -2,30 +2,42 @@ package sttp.client4.httpclient.zio import _root_.zio.interop.reactivestreams._ import org.reactivestreams.FlowAdapters +import sttp.capabilities.StreamMaxLengthExceededException import sttp.capabilities.zio.ZioStreams -import sttp.client4.httpclient.{HttpClientAsyncBackend, HttpClientBackend} -import sttp.client4.impl.zio.{RIOMonadAsyncError, ZioSimpleQueue} +import sttp.client4.BackendOptions +import sttp.client4.GenericRequest +import sttp.client4.Response +import sttp.client4.WebSocketStreamBackend +import sttp.client4.compression.CompressionHandlers +import sttp.client4.compression.Compressor +import sttp.client4.httpclient.HttpClientAsyncBackend +import sttp.client4.httpclient.HttpClientBackend +import sttp.client4.impl.zio.RIOMonadAsyncError +import sttp.client4.impl.zio.ZioSimpleQueue import sttp.client4.internal._ -import sttp.client4.internal.httpclient.{BodyFromHttpClient, BodyToHttpClient, Sequencer} +import sttp.client4.internal.httpclient.BodyFromHttpClient +import sttp.client4.internal.httpclient.BodyToHttpClient +import sttp.client4.internal.httpclient.Sequencer +import sttp.client4.internal.httpclient.cancelPublisher import sttp.client4.internal.ws.SimpleQueue import sttp.client4.testing.WebSocketStreamBackendStub -import sttp.client4.{wrappers, BackendOptions, GenericRequest, Response, WebSocketStreamBackend} +import sttp.client4.wrappers import sttp.monad.MonadError -import zio.Chunk.ByteArray import zio._ +import zio.Chunk.ByteArray import zio.stream.ZStream -import java.net.http.HttpRequest.{BodyPublisher, BodyPublishers} +import java.net.http.HttpClient +import java.net.http.HttpRequest +import java.net.http.HttpRequest.BodyPublisher +import java.net.http.HttpRequest.BodyPublishers +import java.net.http.HttpResponse import java.net.http.HttpResponse.BodyHandlers -import java.net.http.{HttpClient, HttpRequest, HttpResponse} import java.nio.ByteBuffer import java.util -import java.util.concurrent.Flow.Publisher import java.{util => ju} +import java.util.concurrent.Flow.Publisher import scala.collection.JavaConverters._ -import sttp.client4.compression.CompressionHandlers -import sttp.client4.compression.Compressor -import sttp.capabilities.StreamMaxLengthExceededException class HttpClientZioBackend private ( client: HttpClient, @@ -53,12 +65,19 @@ class HttpClientZioBackend private ( override protected def emptyBody(): ZStream[Any, Throwable, Byte] = ZStream.empty - override protected def bodyHandlerBodyToBody(p: Publisher[util.List[ByteBuffer]]): ZStream[Any, Throwable, Byte] = + override protected def lowLevelBodyToBody(p: Publisher[util.List[ByteBuffer]]): ZStream[Any, Throwable, Byte] = FlowAdapters .toPublisher(p) .toStream() .mapConcatChunk(list => ByteArray(list.asScala.toList.flatMap(_.safeRead()).toArray)) + override protected def cancelLowLevelBody(p: Publisher[ju.List[ByteBuffer]]): Unit = cancelPublisher(p) + + override protected def ensureOnAbnormal[T](effect: Task[T])(finalizer: => Task[Unit]): Task[T] = effect.onExit { + exit => + if (exit.succeeded) ZIO.unit else finalizer.catchAll(t => ZIO.effect(t.printStackTrace()).orDie) + }.resurrect + override protected val bodyToHttpClient: BodyToHttpClient[Task, ZioStreams, R] = new BodyToHttpClient[Task, ZioStreams, R] { override val streams: ZioStreams = ZioStreams diff --git a/okhttp-backend/monix/src/main/scala/sttp/client4/okhttp/monix/OkHttpMonixBackend.scala b/okhttp-backend/monix/src/main/scala/sttp/client4/okhttp/monix/OkHttpMonixBackend.scala index 7af1837595..7607dff7a4 100644 --- a/okhttp-backend/monix/src/main/scala/sttp/client4/okhttp/monix/OkHttpMonixBackend.scala +++ b/okhttp-backend/monix/src/main/scala/sttp/client4/okhttp/monix/OkHttpMonixBackend.scala @@ -25,6 +25,7 @@ import scala.concurrent.Future import sttp.client4.compression.CompressionHandlers import sttp.client4.compression.Compressor import sttp.client4.compression.Decompressor +import cats.effect.ExitCase class OkHttpMonixBackend private ( client: OkHttpClient, @@ -112,6 +113,11 @@ class OkHttpMonixBackend private ( override protected def createSimpleQueue[T]: Task[SimpleQueue[Task, T]] = Task.eval(new MonixSimpleQueue[T](webSocketBufferCapacity)) + + override protected def ensureOnAbnormal[T](effect: Task[T])(finalizer: => Task[Unit]): Task[T] = + effect.guaranteeCase { exit => + if (exit == ExitCase.Completed) Task.unit else finalizer.onErrorHandleWith(t => Task.eval(t.printStackTrace())) + } } object OkHttpMonixBackend { diff --git a/okhttp-backend/src/main/scala/sttp/client4/okhttp/OkHttpAsyncBackend.scala b/okhttp-backend/src/main/scala/sttp/client4/okhttp/OkHttpAsyncBackend.scala index 39993d9291..80f86d11ff 100644 --- a/okhttp-backend/src/main/scala/sttp/client4/okhttp/OkHttpAsyncBackend.scala +++ b/okhttp-backend/src/main/scala/sttp/client4/okhttp/OkHttpAsyncBackend.scala @@ -11,6 +11,7 @@ import sttp.client4.{ignore, GenericRequest, Response} import sttp.monad.{Canceler, MonadAsyncError} import sttp.client4.compression.CompressionHandlers import java.io.InputStream +import java.util.concurrent.atomic.AtomicReference abstract class OkHttpAsyncBackend[F[_], S <: Streams[S], P]( client: OkHttpClient, @@ -19,51 +20,72 @@ abstract class OkHttpAsyncBackend[F[_], S <: Streams[S], P]( compressionHandlers: CompressionHandlers[P, InputStream] ) extends OkHttpBackend[F, S, P](client, closeClient, compressionHandlers) { + // #1987: see the comments in HttpClientAsyncBackend + protected def ensureOnAbnormal[T](effect: F[T])(finalizer: => F[Unit]): F[T] + override protected def sendRegular[T](request: GenericRequest[T, R]): F[Response[T]] = { val nativeRequest = convertRequest(request) - monad.flatten(monad.async[F[Response[T]]] { cb => - def success(r: F[Response[T]]): Unit = cb(Right(r)) + val okHttpResponse = new AtomicReference[OkHttpResponse]() + ensureOnAbnormal { + monad.flatten(monad.async[F[Response[T]]] { cb => + def success(r: F[Response[T]]): Unit = cb(Right(r)) - def error(t: Throwable): Unit = cb(Left(t)) + def error(t: Throwable): Unit = cb(Left(t)) - val call = OkHttpBackend - .updateClientIfCustomReadTimeout(request, client) - .newCall(nativeRequest) + val call = OkHttpBackend + .updateClientIfCustomReadTimeout(request, client) + .newCall(nativeRequest) - call.enqueue(new Callback { - override def onFailure(call: Call, e: IOException): Unit = - error(e) + call.enqueue(new Callback { + override def onFailure(call: Call, e: IOException): Unit = + error(e) - override def onResponse(call: Call, response: OkHttpResponse): Unit = - try success(readResponse(response, request, request.response)) - catch { - case e: Exception => - response.close() - error(e) + override def onResponse(call: Call, response: OkHttpResponse): Unit = { + okHttpResponse.set(response) + try success(readResponse(response, request, request.response)) + catch { + case e: Exception => + try response.close() + finally error(e) + } } - }) + }) - Canceler(() => call.cancel()) - }) + Canceler(() => call.cancel()) + }) + } { + monad.eval { + val response = okHttpResponse.get() + if (response != null) response.close() + } + } } override protected def sendWebSocket[T]( request: GenericRequest[T, R] ): F[Response[T]] = { val nativeRequest = convertRequest(request) - monad.flatten( - createSimpleQueue[WebSocketEvent] - .flatMap { queue => - monad.async[F[Response[T]]] { cb => - val listener = createListener(queue, cb, request) - val ws = OkHttpBackend - .updateClientIfCustomReadTimeout(request, client) - .newWebSocket(nativeRequest, listener) - - Canceler(() => ws.cancel()) + val okHttpWS = new AtomicReference[okhttp3.WebSocket]() + ensureOnAbnormal { + monad.flatten( + createSimpleQueue[WebSocketEvent] + .flatMap { queue => + monad.async[F[Response[T]]] { cb => + val listener = createListener(queue, cb, request) + val ws = OkHttpBackend + .updateClientIfCustomReadTimeout(request, client) + .newWebSocket(nativeRequest, listener) + + Canceler(() => ws.cancel()) + } } - } - ) + ) + } { + monad.eval { + val ws = okHttpWS.get() + if (ws != null) ws.cancel() + } + } } private def createListener[T]( diff --git a/okhttp-backend/src/main/scala/sttp/client4/okhttp/OkHttpFutureBackend.scala b/okhttp-backend/src/main/scala/sttp/client4/okhttp/OkHttpFutureBackend.scala index 04d56559c5..cc4d9011fb 100644 --- a/okhttp-backend/src/main/scala/sttp/client4/okhttp/OkHttpFutureBackend.scala +++ b/okhttp-backend/src/main/scala/sttp/client4/okhttp/OkHttpFutureBackend.scala @@ -42,6 +42,11 @@ class OkHttpFutureBackend private ( override val streams: NoStreams = NoStreams override def streamToRequestBody(stream: Nothing, mt: MediaType, cl: Option[Long]): OkHttpRequestBody = stream } + + override protected def ensureOnAbnormal[T](effect: Future[T])(finalizer: => Future[Unit]): Future[T] = + effect.recoverWith { case e => + finalizer.recoverWith { case e2 => e.addSuppressed(e2); Future.failed(e) }.flatMap(_ => Future.failed(e)) + } } object OkHttpFutureBackend { diff --git a/testing/server/src/main/scala/sttp/client4/testing/server/HttpServer.scala b/testing/server/src/main/scala/sttp/client4/testing/server/HttpServer.scala index 16fa31d22a..a1716f3c18 100644 --- a/testing/server/src/main/scala/sttp/client4/testing/server/HttpServer.scala +++ b/testing/server/src/main/scala/sttp/client4/testing/server/HttpServer.scala @@ -131,6 +131,23 @@ private class HttpServer(port: Int, info: String => Unit) extends AutoCloseable discardEntity(complete(isChunked)) } } + } ~ + path("slow") { + get { + complete { + val source = akka.stream.scaladsl.Source + .repeat("a") + .throttle(1, 100.millis) + .take(20) // producing the entire response will take 2s + + HttpResponse( + entity = HttpEntity.Chunked.fromData( + ContentTypes.`text/plain(UTF-8)`, + source.map(str => ByteString(str)) + ) + ) + } + } } } ~ pathPrefix("sse") { path("echo3") {