Skip to content

Commit

Permalink
Request body callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
adamw committed Jan 20, 2025
1 parent 518afd2 commit d741313
Show file tree
Hide file tree
Showing 5 changed files with 169 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package sttp.client4.httpclient

import sttp.attributes.AttributeKey

import java.nio.ByteBuffer

/** Defines a callback to be invoked when subsequent parts of the request body to be sent are created, just before they
* are sent over the network.
*
* When a request is sent, `onInit` is invoked exactly once with the content length (if it is known). This is followed
* by arbitrary number of `onNext` calls. Finally, either `onComplete` or `onError` are called exactly once.
*
* All of the methods should be non-blocking and complete as fast as possible, so as not to obstruct sending data over
* the network.
*
* To register a callback, set the [[RequestBodyCallback.Attribute]] on a request, using the
* [[sttp.client4.Request.attribute]] method.
*/
trait RequestBodyCallback {
def onInit(contentLength: Option[Long]): Unit

def onNext(b: ByteBuffer): Unit

def onComplete(): Unit
def onError(t: Throwable): Unit
}

object RequestBodyCallback {

/** The key of the attribute that should be set on a request, to receive callbacks when the request body is sent. */
val Attribute = AttributeKey[RequestBodyCallback]
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ import java.nio.{Buffer, ByteBuffer}
import java.util.concurrent.Flow
import java.util.function.Supplier
import scala.collection.JavaConverters._
import java.util.concurrent.Flow.Subscription
import sttp.client4.httpclient.RequestBodyCallback

private[client4] trait BodyToHttpClient[F[_], S, R] {
val streams: Streams[S]
Expand Down Expand Up @@ -44,10 +46,15 @@ private[client4] trait BodyToHttpClient[F[_], S, R] {
multipartBodyPublisher.build().unit
}

contentLength match {
val bodyWithContentLength = contentLength match {
case None => body
case Some(cl) => body.map(b => withKnownContentLength(b, cl))
}

request.attribute(RequestBodyCallback.Attribute) match {
case None => bodyWithContentLength
case Some(callback) => bodyWithContentLength.map(withCallback(_, callback))
}
}

def streamToPublisher(stream: streams.BinaryStream): F[BodyPublisher]
Expand Down Expand Up @@ -90,6 +97,46 @@ private[client4] trait BodyToHttpClient[F[_], S, R] {
override def subscribe(subscriber: Flow.Subscriber[_ >: ByteBuffer]): Unit = delegate.subscribe(subscriber)
}

private def withCallback(
delegate: HttpRequest.BodyPublisher,
callback: RequestBodyCallback
): HttpRequest.BodyPublisher =
new HttpRequest.BodyPublisher {
override def contentLength(): Long = delegate.contentLength()
override def subscribe(subscriber: Flow.Subscriber[_ >: ByteBuffer]): Unit = {
delegate.subscribe(new Flow.Subscriber[ByteBuffer] {
override def onSubscribe(subscription: Subscription): Unit = {
runCallbackSafe {
val cl = contentLength()
callback.onInit(if (cl < 0) None else Some(cl))
}
subscriber.onSubscribe(subscription)
}

override def onNext(item: ByteBuffer): Unit = {
runCallbackSafe(callback.onNext(item))
subscriber.onNext(item)
}

override def onComplete(): Unit = {
runCallbackSafe(callback.onComplete())
subscriber.onComplete()
}
override def onError(throwable: Throwable): Unit = {
runCallbackSafe(callback.onError(throwable))
subscriber.onError(throwable)
}

private def runCallbackSafe(f: => Unit): Unit =
try f
catch {
case e: Exception =>
System.getLogger(this.getClass().getName()).log(System.Logger.Level.ERROR, "Error in callback", e)
}
})
}
}

// https://stackoverflow.com/a/6603018/362531
private class ByteBufferBackedInputStream(buf: ByteBuffer) extends InputStream {
override def read: Int = {
Expand Down
46 changes: 45 additions & 1 deletion core/src/test/scalajvm/sttp/client4/HttpClientSyncHttpTest.scala
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
package sttp.client4

import sttp.client4.httpclient.HttpClientSyncBackend
import sttp.client4.testing.{ConvertToFuture, HttpTest}
import sttp.client4.httpclient.RequestBodyCallback
import sttp.client4.testing.ConvertToFuture
import sttp.client4.testing.HttpTest
import sttp.model.StatusCode
import sttp.shared.Identity

import java.nio.ByteBuffer
import java.util.concurrent.ConcurrentLinkedQueue
import scala.collection.JavaConverters._

class HttpClientSyncHttpTest extends HttpTest[Identity] {
override val backend: WebSocketSyncBackend = HttpClientSyncBackend()
override implicit val convertToFuture: ConvertToFuture[Identity] = ConvertToFuture.id
Expand All @@ -13,4 +20,41 @@ class HttpClientSyncHttpTest extends HttpTest[Identity] {
override def supportsDeflateWrapperChecking = false

override def timeoutToNone[T](t: Identity[T], timeoutMillis: Int): Identity[Option[T]] = Some(t)

"callback" - {
"should be invoked as described in the callback protocol" in {
val trail = new ConcurrentLinkedQueue[String]()
val callback = new RequestBodyCallback {

override def onInit(contentLength: Option[Long]): Unit = {
val _ = trail.add(s"init ${contentLength.getOrElse(-1)}")
}

override def onNext(b: ByteBuffer): Unit = {
val _ = trail.add(s"next ${b.remaining()}")
}

override def onComplete(): Unit = {
val _ = trail.add(s"complete")
}

override def onError(t: Throwable): Unit = {
val _ = trail.add(s"error")
}
}

val contentLength = 2048 * 100
val req = postEcho.body("x" * contentLength).attribute(RequestBodyCallback.Attribute, callback)

(req.send(backend): Identity[Response[Either[String, String]]]).toFuture().map { response =>
val t = trail.asScala

t.size should be >= 3
t.head shouldBe s"init $contentLength"
t.tail.init.foreach(_ should startWith("next "))
t.last shouldBe "complete"
response.code shouldBe StatusCode.Ok
}
}
}
}
1 change: 1 addition & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ Third party projects:
other/resilience
other/openapi
other/sse
other/body_callbacks
.. toctree::
:maxdepth: 2
Expand Down
43 changes: 43 additions & 0 deletions docs/other/body_callbacks.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Body callbacks

When using the `HttpClient`-based backends (which includes the `DefaultSyncBackend` and `DefaultFutureBackend` on the
JVM), it is possible to register body-related callbacks.

This feature is not available in other backends, and setting the attribute described below will have no effect.

## Request body callbacks

Defines a callback to be invoked when subsequent parts of the request body to be sent are created, just before they
are sent over the network. The callback is defined through an instance of the `RequestBodyCallback` trait.

When a request is sent, the `RequestBodyCallback.onInit` method is invoked exactly once with the content length (if it
is known). This is followed by arbitrary number of `onNext` calls. Finally, either `onComplete` or `onError` are called
exactly once.

All of the methods in the `RequestBodyCallback` implementation should be non-blocking and complete as fast as possible,
so as not to obstruct sending data over the network.

To register a callback, set the `RequestBodyCallback.Attribute` on a request. For example:

```scala mdoc:compile-only
import sttp.client4.*
import sttp.client4.httpclient.{HttpClientSyncBackend, RequestBodyCallback}
import java.nio.ByteBuffer
import java.io.File

val backend = HttpClientSyncBackend()

val fileToSend: File = ???
val callback = new RequestBodyCallback {
override def onInit(contentLength: Option[Long]): Unit = println(s"expected content length: $contentLength")
override def onNext(b: ByteBuffer): Unit = println(s"next, bytes: ${b.remaining()}")
override def onComplete(): Unit = println(s"complete")
override def onError(t: Throwable): Unit = println(s"error: ${t.getMessage}")
}

val response = basicRequest
.get(uri"http://example.com")
.body(fileToSend)
.attribute(RequestBodyCallback.Attribute, callback)
.send(backend)
```

0 comments on commit d741313

Please sign in to comment.