Skip to content

Commit

Permalink
Add an option to limit the response length for JVM backends (#2410)
Browse files Browse the repository at this point in the history
  • Loading branch information
adamw authored Jan 21, 2025
1 parent 7a6aea7 commit 371b40f
Show file tree
Hide file tree
Showing 28 changed files with 284 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,19 @@ class AkkaHttpBackend private (
val body = bodyFromAkka(
r.response,
responseMetadata,
wsFlow.map(Right(_)).getOrElse(Left(decodeAkkaResponse(hr, r.autoDecompressionEnabled)))
wsFlow
.map(Right(_))
.getOrElse(
Left(decodeAkkaResponse(limitPekkoResponseIfNeeded(hr, r.maxResponseBodyLength), r.autoDecompressionEnabled))
)
)

body.map(sttp.client4.Response(_, code, statusText, headers, Nil, r.onlyMetadata))
}

private def limitPekkoResponseIfNeeded(response: HttpResponse, limit: Option[Long]): HttpResponse =
limit.fold(response)(l => response.withEntity(response.entity.withSizeLimit(l)))

// http://doc.akka.io/docs/akka-http/10.0.7/scala/http/common/de-coding.html
private def decodeAkkaResponse(response: HttpResponse, autoDecompressionEnabled: Boolean): HttpResponse =
if (!response.status.allowsEntity() || !autoDecompressionEnabled) response
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package sttp.client4.akkahttp

import akka.http.scaladsl.model.HttpResponse
import akka.http.scaladsl.model.EntityStreamSizeException
import sttp.client4.{GenericRequest, SttpClientException}
import sttp.model.{Header, HeaderNames}

Expand All @@ -25,6 +26,7 @@ private[akkahttp] object FromAkka {
case _ => Some(new SttpClientException.ReadException(request, e))
}
case e: akka.stream.scaladsl.TcpIdleTimeoutException => Some(new SttpClientException.TimeoutException(request, e))
case e: EntityStreamSizeException => Some(new SttpClientException.ReadException(request, e))
case e: Exception => SttpClientException.defaultExceptionToSttpClientException(request, e)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ import sttp.model._
import sttp.monad.syntax._
import sttp.monad.{Canceler, MonadAsyncError}
import sttp.client4.compression.Compressor
import com.linecorp.armeria.common.ContentTooLargeException

abstract class AbstractArmeriaBackend[F[_], S <: Streams[S]](
client: WebClient = WebClient.of(),
Expand Down Expand Up @@ -122,7 +123,7 @@ abstract class AbstractArmeriaBackend[F[_], S <: Streams[S]](

val contentType = customContentType.getOrElse(ArmeriaMediaType.parse(request.body.defaultContentType.toString()))

body match {
val withBody = body match {
case NoBody => requestPreparation
case StringBody(s, encoding, _) =>
val charset =
Expand All @@ -149,6 +150,8 @@ abstract class AbstractArmeriaBackend[F[_], S <: Streams[S]](
case StreamBody(s) =>
requestPreparation.content(contentType, streamToPublisher(s.asInstanceOf[streams.BinaryStream]))
}

request.maxResponseBodyLength.fold(withBody)(l => withBody.maxResponseLength(l))
}

private def methodToArmeria(method: Method): HttpMethod =
Expand Down Expand Up @@ -203,10 +206,9 @@ abstract class AbstractArmeriaBackend[F[_], S <: Streams[S]](
case ex: UnprocessedRequestException =>
// The cause of an UnprocessedRequestException is always not null
Some(new ConnectException(request, ex.getCause.asInstanceOf[Exception]))
case ex: ResponseTimeoutException =>
Some(new TimeoutException(request, ex))
case ex: ClosedStreamException =>
Some(new ReadException(request, ex))
case ex: ResponseTimeoutException => Some(new TimeoutException(request, ex))
case ex: ClosedStreamException => Some(new ReadException(request, ex))
case ex: ContentTooLargeException => Some(new ReadException(request, ex))
case ex =>
SttpClientException.defaultExceptionToSttpClientException(request, ex)
}
Expand Down
7 changes: 6 additions & 1 deletion core/src/main/scala/sttp/client4/RequestOptions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ import sttp.client4.logging.LoggingOptions
* [[sttp.model.Encodings.Gzip]] and [[sttp.model.Encodings.Deflate]] encodings, but others might available as well;
* refer to the backend documentation for details. If an encoding is not supported, an exception is thrown / a failed
* effect returned, when sending the request.
* @param maxResponseBodyLength
* The maximum length of the response body (in bytes). When sending the request, if the response body is longer, an
* exception is thrown / a failed effect is returned. By default, when `None`, the is no limit on the response body's
* length.
*/
case class RequestOptions(
followRedirects: Boolean,
Expand All @@ -25,5 +29,6 @@ case class RequestOptions(
decompressResponseBody: Boolean,
compressRequestBody: Option[String],
httpVersion: Option[HttpVersion],
loggingOptions: LoggingOptions
loggingOptions: LoggingOptions,
maxResponseBodyLength: Option[Long]
)
3 changes: 2 additions & 1 deletion core/src/main/scala/sttp/client4/SttpApi.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ trait SttpApi extends SttpExtensions with UriInterpolator {
decompressResponseBody = true,
compressRequestBody = None,
httpVersion = None,
loggingOptions = LoggingOptions()
loggingOptions = LoggingOptions(),
maxResponseBodyLength = None
),
AttributeMap.Empty
)
Expand Down
77 changes: 77 additions & 0 deletions core/src/main/scala/sttp/client4/internal/LimitedInputStream.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package sttp.tapir.server.jdkhttp.internal

import sttp.capabilities.StreamMaxLengthExceededException
import java.io.FilterInputStream
import java.io.InputStream
import java.io.IOException

class FailingLimitedInputStream(in: InputStream, limit: Long) extends LimitedInputStream(in, limit) {
override def onLimit: Int = {
throw new StreamMaxLengthExceededException(limit)
}
}

/** Based on Guava's https://github.com/google/guava/blob/master/guava/src/com/google/common/io/ByteStreams.java */
class LimitedInputStream(in: InputStream, limit: Long) extends FilterInputStream(in) {
protected var left: Long = limit
private var mark: Long = -1L

override def available(): Int = Math.min(in.available(), left.toInt)

override def mark(readLimit: Int): Unit = {
in.mark(readLimit)
mark = left
}

override def read(): Int = {
if (left == 0) {
onLimit
} else {
val result = in.read()
if (result != -1) {
left -= 1
}
result
}
}

override def read(b: Array[Byte], off: Int, len: Int): Int = {
if (left == 0) {
// Temporarily perform a read to check if more bytes are available
val checkRead = in.read()
if (checkRead == -1) {
-1 // No more bytes available in the stream
} else {
onLimit
}
} else {
val adjustedLen = Math.min(len, left.toInt)
val result = in.read(b, off, adjustedLen)
if (result != -1) {
left -= result
}
result
}
}

override def reset(): Unit = {
if (!in.markSupported) {
throw new IOException("Mark not supported")
}
if (mark == -1) {
throw new IOException("Mark not set")
}

in.reset()
left = mark
}

override def skip(n: Long): Long = {
val toSkip = Math.min(n, left)
val skipped = in.skip(toSkip)
left -= skipped
skipped
}

protected def onLimit: Int = -1
}
8 changes: 8 additions & 0 deletions core/src/main/scala/sttp/client4/requestBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,14 @@ trait PartialRequestBuilder[+PR <: PartialRequestBuilder[PR, R], +R]
*/
def loggingOptions: LoggingOptions = options.loggingOptions

/** Set the maximum response body length. When sending the request, if the response body is longer, an exception is
* thrown / a failed effect is returned. By default, there's no limit on the response body's length.
*/
def maxResponseBodyLength(limit: Long): PR = withOptions(options.copy(maxResponseBodyLength = Some(limit)))

/** The maximum response body length, if any. */
def maxResponseBodyLength: Option[Long] = options.maxResponseBodyLength

/** Reads a per-request attribute for the given key, if present. */
def attribute[T](k: AttributeKey[T]): Option[T] = attributes.get(k)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import sttp.client4.ws.{GotAWebSocketException, NotAWebSocketException}

import scala.annotation.tailrec
import sttp.client4.SttpClientException.ResponseHandlingException
import sttp.capabilities.StreamMaxLengthExceededException

trait SttpClientExceptionExtensions {
@tailrec
Expand All @@ -25,6 +26,7 @@ trait SttpClientExceptionExtensions {
case e: java.io.IOException => Some(new ReadException(request, e))
case e: NotAWebSocketException => Some(new ReadException(request, e))
case e: GotAWebSocketException => Some(new ReadException(request, e))
case e: StreamMaxLengthExceededException => Some(new ReadException(request, e))
case e: ResponseException[_] => Some(new ResponseHandlingException(request, e))
case e if e.getCause != null && e.getCause.isInstanceOf[Exception] =>
defaultExceptionToSttpClientException(request, e.getCause.asInstanceOf[Exception])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ abstract class HttpClientAsyncBackend[F[_], S <: Streams[S], BH, B](
protected def createSequencer: F[Sequencer[F]]
protected def createBodyHandler: HttpResponse.BodyHandler[BH]
protected def bodyHandlerBodyToBody(p: BH): B
protected def bodyToLimitedBody(b: B, limit: Long): B
protected def emptyBody(): B

override def sendRegular[T](request: GenericRequest[T, R]): F[Response[T]] =
Expand All @@ -55,7 +56,9 @@ abstract class HttpClientAsyncBackend[F[_], S <: Streams[S], BH, B](
.map(bodyHandlerBodyToBody)
.getOrElse(emptyBody())

try success(readResponse(t, Left(body), request))
val limitedBody = request.options.maxResponseBodyLength.fold(body)(bodyToLimitedBody(body, _))

try success(readResponse(t, Left(limitedBody), request))
catch {
case e: Exception => error(e)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import scala.concurrent.{ExecutionContext, Future}
import sttp.client4.compression.Compressor
import sttp.client4.compression.CompressionHandlers
import sttp.client4.compression.Decompressor
import sttp.tapir.server.jdkhttp.internal.FailingLimitedInputStream

class HttpClientFutureBackend private (
client: HttpClient,
Expand Down Expand Up @@ -63,6 +64,9 @@ class HttpClientFutureBackend private (
override protected def bodyHandlerBodyToBody(p: InputStream): InputStream = p

override protected def emptyBody(): InputStream = emptyInputStream()

override protected def bodyToLimitedBody(b: InputStream, limit: Long): InputStream =
new FailingLimitedInputStream(b, limit)
}

object HttpClientFutureBackend {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import java.util.concurrent.{ArrayBlockingQueue, CompletionException}
import sttp.client4.compression.Compressor
import sttp.client4.compression.CompressionHandlers
import sttp.client4.compression.Decompressor
import sttp.tapir.server.jdkhttp.internal.FailingLimitedInputStream

class HttpClientSyncBackend private (
client: HttpClient,
Expand All @@ -38,7 +39,9 @@ class HttpClientSyncBackend private (
override protected def sendRegular[T](request: GenericRequest[T, R]): Response[T] = {
val jRequest = customizeRequest(convertRequest(request))
val response = client.send(jRequest, BodyHandlers.ofInputStream())
readResponse(response, Left(response.body()), request)
val body = response.body()
val limitedBody = request.options.maxResponseBodyLength.fold(body)(new FailingLimitedInputStream(body, _))
readResponse(response, Left(limitedBody), request)
}

override protected def sendWebSocket[T](request: GenericRequest[T, R]): Response[T] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ import scala.concurrent.duration.Duration
import sttp.client4.GenericRequestBody
import sttp.client4.compression.CompressionHandlers
import sttp.client4.compression.Decompressor
import sttp.tapir.server.jdkhttp.internal.FailingLimitedInputStream

class HttpURLConnectionBackend private (
opts: BackendOptions,
Expand Down Expand Up @@ -81,7 +82,8 @@ class HttpURLConnectionBackend private (

try {
val is = c.getInputStream
readResponse(c, is, r)
val limitedIs = r.options.maxResponseBodyLength.fold(is)(new FailingLimitedInputStream(is, _))
readResponse(c, limitedIs, r)
} catch {
case e: CharacterCodingException => throw e
case e: UnsupportedEncodingException => throw e
Expand Down
25 changes: 25 additions & 0 deletions core/src/test/scala/sttp/client4/testing/HttpTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ trait HttpTest[F[_]]
protected def supportsDeflateWrapperChecking = true
protected def supportsEmptyContentEncoding = true
protected def supportsNonAsciiHeaderValues = true
protected def supportsMaxResponseBodyLength = true

"request parsing" - {
"Inf timeout should not throw exception" in {
Expand Down Expand Up @@ -739,6 +740,30 @@ trait HttpTest[F[_]]
}
}

if (supportsMaxResponseBodyLength) {
"maxResponseBodyLength" - {
"should be enforced when set" in {
val req = postEchoExact
.body("01234567890123456789") // 20 bytes
.maxResponseBodyLength(10)

Future(req.send(backend)).flatMap(_.toFuture()).failed.map { e =>
e shouldBe a[SttpClientException.ReadException]
}
}

"should have no effect when the limit is not reached" in {
val req = postEchoExact
.body("0123456789")
.maxResponseBodyLength(10) // the limit is reached exactly

Future(req.send(backend)).flatMap(_.toFuture()).map { r =>
r.body shouldBe Right("0123456789")
}
}
}
}

override protected def afterAll(): Unit = {
backend.close().toFuture()
super.afterAll()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ abstract class AbstractFetchHttpTest[F[_], +P] extends HttpTest[F] {

override protected def supportsHostHeaderOverride = false

// not yet implemented
override protected def supportsMaxResponseBodyLength = false

override def supportsCancellation: Boolean = false

override def timeoutToNone[T](t: F[T], timeoutMillis: Int): F[Option[T]] = ???
Expand Down
2 changes: 1 addition & 1 deletion docs/requests/body.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Body
# Request body

## Text data

Expand Down
14 changes: 11 additions & 3 deletions docs/responses/body.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Response body descriptions
# Response body

By default, the received response body will be read as a `Either[String, String]`, using the encoding specified in the `Content-Type` response header (and if none is specified, using `UTF-8`). This is of course configurable: response bodies can be ignored, deserialized into custom types, received as a stream or saved to a file.

Expand Down Expand Up @@ -251,7 +251,7 @@ val response: Future[Response[Either[String, Source[ByteString, Any]]]] =

It's also possible to parse the received stream as server-sent events (SSE), using an implementation-specific mapping function. Refer to the documentation for particular backends for more details.

## Decompressing bodies (handling the Conent-Encoding header)
## Decompressing bodies (handling the Content-Encoding header)

If the response body is compressed using `gzip` or `deflate` algorithms, it will be decompressed if the `decompressResponseBody` request option is set. By default this is set to `true`, and can be disabled using the `request.disableAutoDecompression` method.

Expand All @@ -260,4 +260,12 @@ The encoding of the response body is determined by the encodings that are accept
If you'd like to use additional decompression algorithms, you'll need to:

* amend the `Accept-Encoding` header that's set on the request
* add a decompression algorithm to the backend; that can be done on backend creation time, by customising the `compressionHandlers` parameter, and adding a `Decompressor` implementation. Such an implementation has to specify the encoding, which it handles, as well as appropriate body transformation (which is backend-specific).
* add a decompression algorithm to the backend; that can be done on backend creation time, by customizing the `compressionHandlers` parameter, and adding a `Decompressor` implementation. Such an implementation has to specify the encoding, which it handles, as well as appropriate body transformation (which is backend-specific).

## Limiting the response body size

To limit the size of the response body, use the `maxResponseBodyLength` method on the request description. This modified the `RequestOption`s associated with the request. By default, there's no limit set.

When a limit is set and it is exceed, sending the request will fail with a `SttpClientException.ReadException`.

This feature is currently available only for JVM backends.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import java.net.http.{HttpClient, HttpRequest, HttpResponse}
import sttp.client4.compression.CompressionHandlers
import sttp.client4.compression.Compressor
import sttp.client4.compression.Decompressor
import sttp.tapir.server.jdkhttp.internal.FailingLimitedInputStream

class HttpClientCatsBackend[F[_]: Async] private (
client: HttpClient,
Expand Down Expand Up @@ -68,6 +69,9 @@ class HttpClientCatsBackend[F[_]: Async] private (
override protected def bodyHandlerBodyToBody(p: InputStream): InputStream = p

override protected def emptyBody(): InputStream = emptyInputStream()

override protected def bodyToLimitedBody(b: InputStream, limit: Long): InputStream =
new FailingLimitedInputStream(b, limit)
}

object HttpClientCatsBackend {
Expand Down
Loading

0 comments on commit 371b40f

Please sign in to comment.