Skip to content

Commit

Permalink
airframe-http-netty: Fixes #2938 Handle RPC status and exception at R…
Browse files Browse the repository at this point in the history
…PCResponseFilter (#2944)
  • Loading branch information
xerial authored May 7, 2023
1 parent 09ca7fc commit 5d86f11
Show file tree
Hide file tree
Showing 10 changed files with 173 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import io.netty.buffer.Unpooled
import io.netty.channel.{ChannelFutureListener, ChannelHandlerContext, SimpleChannelInboundHandler}
import io.netty.handler.codec.http._
import wvlet.airframe.http.HttpMessage.{Request, Response}
import wvlet.airframe.http.internal.RPCResponseFilter
import wvlet.airframe.http.{
Http,
HttpHeader,
Expand Down Expand Up @@ -88,35 +89,13 @@ class NetthRequestHandler(config: NettyServerConfig, dispatcher: NettyBackend.Fi
val nettyResponse = toNettyResponse(v.asInstanceOf[Response])
writeResponse(msg, ctx, nettyResponse)
case OnError(ex) =>
val resp = ex match {
case ex: HttpServerException =>
toNettyResponse(ex.toResponse)
case e: RPCException =>
toNettyResponse(rpcExceptionResponse(e))
case other =>
val ex = RPCStatus.INTERNAL_ERROR_I0.newException(other.getMessage, other)
toNettyResponse(rpcExceptionResponse(ex))
}
writeResponse(msg, ctx, resp)
val resp = RPCStatus.INTERNAL_ERROR_I0.newException(ex.getMessage, ex).toResponse
val nettyResponse = toNettyResponse(resp)
writeResponse(msg, ctx, nettyResponse)
case OnCompletion =>
}
}

private def rpcExceptionResponse(e: RPCException): Response = {
var resp = Http
.response(e.status.httpStatus)
.addHeader(HttpHeader.xAirframeRPCStatus, e.status.code.toString)
try {
// Embed RPCError into the response body
resp = resp.withJson(e.toJson)
} catch {
case ex: Throwable =>
// Show warning
logger.warn(s"Failed to serialize RPCException: ${e}", ex)
}
resp
}

private def writeResponse(req: HttpRequest, ctx: ChannelHandlerContext, resp: DefaultHttpResponse): Unit = {
val keepAlive = HttpStatus.ofCode(resp.status().code()).isSuccessful && HttpUtil.isKeepAlive(req)
if (keepAlive) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import wvlet.airframe.control.ThreadUtil
import wvlet.airframe.http.HttpMessage.Response
import wvlet.airframe.http.{HttpMessage, _}
import wvlet.airframe.http.client.{AsyncClient, SyncClient}
import wvlet.airframe.http.internal.{HttpServerLoggingFilter, LogRotationHttpLogger}
import wvlet.airframe.http.internal.{RPCLoggingFilter, LogRotationHttpLogger, RPCResponseFilter}
import wvlet.airframe.http.router.{ControllerProvider, HttpRequestDispatcher}
import wvlet.airframe.rx.Rx
import wvlet.airframe.{Design, Session}
Expand All @@ -50,7 +50,7 @@ case class NettyServerConfig(
httpLoggerProvider: HttpLoggerConfig => HttpLogger = { (config: HttpLoggerConfig) =>
new LogRotationHttpLogger(config)
},
loggingFilter: HttpLogger => RxHttpFilter = { new HttpServerLoggingFilter(_) }
loggingFilter: HttpLogger => RxHttpFilter = { new RPCLoggingFilter(_) }
) {
lazy val port = serverPort.getOrElse(IOUtil.unusedPort)

Expand Down Expand Up @@ -204,7 +204,11 @@ class NettyServer(config: NettyServerConfig, session: Session) extends AutoClose

private val dispatcher = {
NettyBackend
.rxFilterAdapter(attachContextFilter.andThen(loggingFilter))
.rxFilterAdapter(
attachContextFilter
.andThen(loggingFilter)
.andThen(RPCResponseFilter)
)
.andThen(
HttpRequestDispatcher.newDispatcher(
session = session,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,17 @@
package wvlet.airframe.http.netty

import wvlet.airframe.http.client.SyncClient
import wvlet.airframe.http.{Http, HttpMessage, RPC, RPCException, RPCStatus, RxHttpEndpoint, RxHttpFilter, RxRouter}
import wvlet.airframe.http.{
Http,
HttpClientException,
HttpMessage,
RPC,
RPCException,
RPCStatus,
RxHttpEndpoint,
RxHttpFilter,
RxRouter
}
import wvlet.airframe.rx.Rx
import wvlet.airspec.AirSpec

Expand Down Expand Up @@ -50,13 +60,15 @@ object NettyRxFilterTest extends AirSpec {
test("Run server with auth filter", design = _.add(Netty.server.withRouter(router1).designWithSyncClient)) {
(client: SyncClient) =>
test("when no auth header") {
val ex = intercept[RPCException] {
val e = intercept[HttpClientException] {
client.send(
Http.POST("/wvlet.airframe.http.netty.NettyRxFilterTest.MyRPC/hello").withJson("""{"msg":"Netty"}""")
)
}
ex.status shouldBe RPCStatus.UNAUTHENTICATED_U13
ex.message shouldBe "authentication failed"
e.getCause shouldMatch { case ex: RPCException =>
ex.status shouldBe RPCStatus.UNAUTHENTICATED_U13
ex.message shouldBe "authentication failed"
}
}

test("with auth header") {
Expand All @@ -72,12 +84,14 @@ object NettyRxFilterTest extends AirSpec {

test("throw RPCException in a filter", design = _.add(Netty.server.withRouter(router2).designWithSyncClient)) {
(client: SyncClient) =>
val ex = intercept[RPCException] {
val e = intercept[HttpClientException] {
client.send(
Http.POST("/wvlet.airframe.http.netty.NettyRxFilterTest.MyRPC/hello").withJson("""{"msg":"Netty"}""")
)
}
ex.status shouldBe RPCStatus.UNAUTHENTICATED_U13
ex.message shouldBe "authentication failed"
e.getCause shouldMatch { case ex: RPCException =>
ex.status shouldBe RPCStatus.UNAUTHENTICATED_U13
ex.message shouldBe "authentication failed"
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
package wvlet.airframe.http.netty

import wvlet.airframe.Design
import wvlet.airframe.http.{Http, RPC, RxRouter}
import wvlet.airframe.http.{Http, HttpHeader, RPC, RPCStatus, RxRouter}
import wvlet.airframe.http.client.SyncClient
import wvlet.airspec.AirSpec

Expand All @@ -39,10 +39,12 @@ class NettyRxRPCServerTest extends AirSpec {
Http.POST("/wvlet.airframe.http.netty.NettyRxRPCServerTest.MyRPC/helloNetty").withJson("""{"msg":"Netty"}""")
)
resp.message.toContentString shouldBe "Hello Netty!"
resp.getHeader(HttpHeader.xAirframeRPCStatus) shouldBe Some(RPCStatus.SUCCESS_S0.code.toString)

val resp2 = client.send(
Http.POST("/wvlet.airframe.http.netty.NettyRxRPCServerTest.MyRPC/helloNetty2").withJson("""{"msg":"Netty"}""")
)
resp2.message.toContentString shouldBe "Hello Netty2!"
resp2.getHeader(HttpHeader.xAirframeRPCStatus) shouldBe Some(RPCStatus.SUCCESS_S0.code.toString)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ object HttpClientException extends LogSupport {
val status = adapter.statusOf(response)
val isRPCException: Boolean = adapter.headerOf(response).get(HttpHeader.xAirframeRPCStatus).isDefined
if (isRPCException) {
val cause = HttpClients.parseRPCException(adapter.httpResponseOf(response))
val cause = RPCException.fromResponse(adapter.httpResponseOf(response))
new HttpClientException(adapter.wrap(response), status, cause)
} else {
val content = adapter.contentStringOf(response)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,11 +164,13 @@ object HttpMessage {
}

case class StringMessage(content: String) extends Message {
override def isEmpty: Boolean = content.isEmpty
override def toString: String = content
override def toContentString: String = content
override def toContentBytes: Array[Byte] = content.getBytes(StandardCharsets.UTF_8)
}
case class ByteArrayMessage(content: Array[Byte]) extends Message {
override def isEmpty: Boolean = content.isEmpty
override def toString: String = toContentString
override def toContentString: String = {
new String(content, StandardCharsets.UTF_8)
Expand All @@ -179,6 +181,7 @@ object HttpMessage {
class LazyByteArrayMessage(contentReader: => Array[Byte]) extends Message {
// Use lazy evaluation of content body to avoid unnecessary data copy
private lazy val content: Array[Byte] = contentReader
override def isEmpty: Boolean = content.isEmpty
override def toString: String = toContentString
override def toContentString: String = {
new String(content, StandardCharsets.UTF_8)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,14 @@
package wvlet.airframe.http

import wvlet.airframe.codec.{GenericException, GenericStackTraceElement, MessageCodec}
import wvlet.airframe.http.HttpMessage.Response
import wvlet.airframe.http.RPCException.rpcErrorMessageCodec
import wvlet.airframe.http.internal.HttpResponseBodyCodec
import wvlet.airframe.json.Json
import wvlet.airframe.msgpack.spi.MsgPack
import wvlet.log.LogSupport

import scala.util.Try

/**
* RPCException provides a backend-independent (e.g., Finagle or gRPC) RPC error reporting mechanism. Create this
Expand All @@ -36,7 +41,8 @@ case class RPCException(
appErrorCode: Option[Int] = None,
// [optional] Application-specific metadata
metadata: Map[String, Any] = Map.empty
) extends Exception(s"[${status}] ${message}", cause.getOrElse(null)) {
) extends Exception(s"[${status}] ${message}", cause.getOrElse(null))
with LogSupport {

private var _includeStackTrace: Option[Boolean] = None

Expand Down Expand Up @@ -75,6 +81,25 @@ case class RPCException(
def toMsgPack: MsgPack = {
rpcErrorMessageCodec.toMsgPack(toMessage)
}

/**
* Convert this exception to an HTTP response
*/
def toResponse: HttpMessage.Response = {
var resp = Http
.response(status.httpStatus)
.addHeader(HttpHeader.xAirframeRPCStatus, status.code.toString)

try {
// Embed RPCError into the response body
resp = resp.withJson(toJson)
} catch {
case ex: Throwable =>
// Show warning
warn(s"Failed to serialize RPCException: ${this}", ex)
}
resp
}
}

/**
Expand Down Expand Up @@ -121,4 +146,28 @@ object RPCException {
val m = rpcErrorMessageCodec.fromMsgPack(msgpack)
fromRPCErrorMessage(m)
}

def fromResponse(response: HttpMessage.Response): RPCException = {
val responseBodyCodec = new HttpResponseBodyCodec[Response]

response
.getHeader(HttpHeader.xAirframeRPCStatus)
.flatMap(x => Try(x.toInt).toOption) match {
case Some(rpcStatus) =>
try {
if (response.message.isEmpty) {
val status = RPCStatus.ofCode(rpcStatus)
status.newException(status.name)
} else {
val msgpack = responseBodyCodec.toMsgPack(response)
RPCException.fromMsgPack(msgpack)
}
} catch {
case e: Throwable =>
RPCStatus.ofCode(rpcStatus).newException(s"Failed to parse the RPC error details: ${e.getMessage}", e)
}
case None =>
RPCStatus.DATA_LOSS_I8.newException(s"Invalid RPC response: ${response}")
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,9 @@ trait SyncClient extends SyncClientCompat with HttpClientFactory[SyncClient] wit
* @param request
* @tparam Req
* @return
*
* @throws RPCException
* when RPC request fails
*/
def rpc[Req, Resp](method: RPCMethod, requestContent: Req): Resp = {
val request: Request =
Expand All @@ -154,7 +157,7 @@ trait SyncClient extends SyncClientCompat with HttpClientFactory[SyncClient] wit
ret.asInstanceOf[Resp]
} else {
// Parse the RPC error message
throw HttpClients.parseRPCException(response)
throw RPCException.fromResponse(response)
}
}
}
Expand Down Expand Up @@ -240,6 +243,16 @@ trait AsyncClient extends AsyncClientCompat with HttpClientFactory[AsyncClient]
}
}

/**
* @param method
* @param requestContent
* @tparam Req
* @tparam Resp
* @return
*
* @throws RPCException
* when RPC request fails
*/
def rpc[Req, Resp](
method: RPCMethod,
requestContent: Req
Expand All @@ -258,7 +271,7 @@ trait AsyncClient extends AsyncClientCompat with HttpClientFactory[AsyncClient]
val ret = HttpClients.parseRPCResponse(config, response, method.responseSurface)
ret.asInstanceOf[Resp]
} else {
throw HttpClients.parseRPCException(response)
throw RPCException.fromResponse(response)
}
}
}
Expand All @@ -276,7 +289,8 @@ object HttpClients extends LogSupport {
resp.getHeader(HttpHeader.xAirframeRPCStatus) match {
case Some(status) =>
// Throw RPCException if RPCStatus code is given
throw parseRPCException(e.response.toHttpResponse)
val ex = RPCException.fromResponse(e.response.toHttpResponse)
throw new HttpClientException(resp, ex.status.httpStatus, ex.message, ex)
case None =>
// Throw as is for known client exception
throw e
Expand Down Expand Up @@ -417,21 +431,4 @@ object HttpClients extends LogSupport {
}
}

private[http] def parseRPCException(response: Response): RPCException = {
response
.getHeader(HttpHeader.xAirframeRPCStatus)
.flatMap(x => Try(x.toInt).toOption) match {
case Some(rpcStatus) =>
try {
val msgpack = responseBodyCodec.toMsgPack(response)
RPCException.fromMsgPack(msgpack)
} catch {
case e: Throwable =>
RPCStatus.ofCode(rpcStatus).newException(s"Failed to parse the RPC error details: ${e.getMessage}", e)
}
case None =>
RPCStatus.DATA_LOSS_I8.newException(s"Invalid RPC response: ${response}")
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ import wvlet.airframe.http.{HttpLogger, HttpMessage, HttpMultiMap, RPCContext, R
import wvlet.airframe.rx.Rx
import wvlet.log.LogSupport

class HttpServerLoggingFilter(httpLogger: HttpLogger) extends RxHttpFilter with LogSupport {
/**
* Report HTTP/RPC request/response logs to the given logger
*/
class RPCLoggingFilter(httpLogger: HttpLogger) extends RxHttpFilter with LogSupport {
private val excludeHeaders = HttpMultiMap.fromHeaderNames(httpLogger.config.excludeHeaders)

override def apply(request: HttpMessage.Request, next: RxHttpEndpoint): Rx[HttpMessage.Response] = {
Expand Down
Loading

0 comments on commit 5d86f11

Please sign in to comment.