Skip to content

Commit

Permalink
refactor(server): make ProxyGetRequestLayer http middleware support…
Browse files Browse the repository at this point in the history
… multiple path-method pairs

**BREAKING CHNAGES**:

- change `ProxyGetRequestLayer::new` method to support multiple path-method pairs
- remove `ProxyGetRequest::new` method, which is useless
  • Loading branch information
koushiro committed Nov 7, 2024
1 parent 2a6406b commit e5ada32
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 101 deletions.
2 changes: 1 addition & 1 deletion examples/examples/http_proxy_middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ async fn run_server() -> anyhow::Result<SocketAddr> {
// Custom tower service to handle the RPC requests
let service_builder = tower::ServiceBuilder::new()
// Proxy `GET /health` requests to internal `system_health` method.
.layer(ProxyGetRequestLayer::new("/health", "system_health")?)
.layer(ProxyGetRequestLayer::new([("/health", "system_health")])?)
.timeout(Duration::from_secs(2));

let server =
Expand Down
25 changes: 13 additions & 12 deletions server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,28 @@ publish = true

[dependencies]
futures-util = { version = "0.3.14", default-features = false, features = ["io", "async-await-macro"] }
jsonrpsee-types = { workspace = true }
http = "1"
http-body = "1"
http-body-util = "0.1.0"
hyper = { version = "1.3", features = ["server", "http1", "http2"] }
hyper-util = { version = "0.1", features = ["tokio", "service", "tokio", "server-auto"] }
jsonrpsee-core = { workspace = true, features = ["server", "http-helpers"] }
tracing = "0.1.34"
jsonrpsee-types = { workspace = true }
pin-project = "1.1.3"
route-recognizer = "0.3.1"
rustc-hash = "2.0.0"
serde = "1"
serde_json = { version = "1", features = ["raw_value"] }
soketto = { version = "0.8", features = ["http"] }
thiserror = "2"
tokio = { version = "1.23.1", features = ["net", "rt-multi-thread", "macros", "time"] }
tokio-util = { version = "0.7", features = ["compat"] }
tokio-stream = { version = "0.1.7", features = ["sync"] }
hyper = { version = "1.3", features = ["server", "http1", "http2"] }
hyper-util = { version = "0.1", features = ["tokio", "service", "tokio", "server-auto"] }
http = "1"
http-body = "1"
http-body-util = "0.1.0"
tower = { workspace = true, features = ["util"] }
thiserror = "2"
route-recognizer = "0.3.1"
pin-project = "1.1.3"
tracing = "0.1.34"

[dev-dependencies]
jsonrpsee-test-utils = { path = "../test-utils" }
tracing-subscriber = { version = "0.3.3", features = ["env-filter"] }
tower = { workspace = true, features = ["timeout"] }
socket2 = "0.5.1"
tower = { workspace = true, features = ["timeout"] }
tracing-subscriber = { version = "0.3.3", features = ["env-filter"] }
164 changes: 79 additions & 85 deletions server/src/middleware/http/proxy_get_request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,18 @@

//! Middleware that proxies requests at a specified URI to internal
//! RPC method calls.
use crate::transport::http;
use crate::{HttpBody, HttpRequest, HttpResponse};

use futures_util::{FutureExt, TryFutureExt};
use http_body_util::BodyExt;
use hyper::body::Bytes;
use hyper::header::{ACCEPT, CONTENT_TYPE};
use hyper::http::HeaderValue;
use hyper::{Method, Uri};
use jsonrpsee_core::BoxError;
use jsonrpsee_types::{Id, RequestSer};
use rustc_hash::FxHashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
Expand All @@ -53,29 +55,40 @@ pub struct InvalidPath(String);
/// See [`ProxyGetRequest`] for more details.
#[derive(Debug, Clone)]
pub struct ProxyGetRequestLayer {
path: String,
method: String,
// path => method mapping
methods: Arc<FxHashMap<String, String>>,
}

impl ProxyGetRequestLayer {
/// Creates a new [`ProxyGetRequestLayer`].
///
/// See [`ProxyGetRequest`] for more details.
pub fn new(path: impl Into<String>, method: impl Into<String>) -> Result<Self, InvalidPath> {
let path = path.into();
if !path.starts_with('/') {
return Err(InvalidPath(path));
/// The request `GET /path` is redirected to the provided method.
/// Fails if the path does not start with `/`.
pub fn new<P, M>(methods: impl IntoIterator<Item = (P, M)>) -> Result<Self, InvalidPath>
where
P: Into<String>,
M: Into<String>,
{
let methods = methods
.into_iter()
.map(|(path, method)| (path.into(), method.into()))
.collect::<FxHashMap<String, String>>();

for path in methods.keys() {
if !path.starts_with('/') {
return Err(InvalidPath(path.clone()));
}
}

Ok(Self { path, method: method.into() })
Ok(Self { methods: Arc::new(methods) })
}
}

impl<S> Layer<S> for ProxyGetRequestLayer {
type Service = ProxyGetRequest<S>;

fn layer(&self, inner: S) -> Self::Service {
ProxyGetRequest::new(inner, &self.path, &self.method)
.expect("Path already validated in ProxyGetRequestLayer; qed")
ProxyGetRequest { inner, methods: self.methods.clone() }
}
}

Expand All @@ -94,22 +107,8 @@ impl<S> Layer<S> for ProxyGetRequestLayer {
#[derive(Debug, Clone)]
pub struct ProxyGetRequest<S> {
inner: S,
path: Arc<str>,
method: Arc<str>,
}

impl<S> ProxyGetRequest<S> {
/// Creates a new [`ProxyGetRequest`].
///
/// The request `GET /path` is redirected to the provided method.
/// Fails if the path does not start with `/`.
pub fn new(inner: S, path: &str, method: &str) -> Result<Self, InvalidPath> {
if !path.starts_with('/') {
return Err(InvalidPath(path.to_string()));
}

Ok(Self { inner, path: Arc::from(path), method: Arc::from(method) })
}
// path => method mapping
methods: Arc<FxHashMap<String, String>>,
}

impl<S, B> Service<HttpRequest<B>> for ProxyGetRequest<S>
Expand All @@ -132,65 +131,60 @@ where
}

fn call(&mut self, mut req: HttpRequest<B>) -> Self::Future {
let modify = self.path.as_ref() == req.uri() && req.method() == Method::GET;

// Proxy the request to the appropriate method call.
let req = if modify {
// RPC methods are accessed with `POST`.
*req.method_mut() = Method::POST;
// Precautionary remove the URI.
*req.uri_mut() = Uri::from_static("/");

// Requests must have the following headers:
req.headers_mut().insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
req.headers_mut().insert(ACCEPT, HeaderValue::from_static("application/json"));

// Adjust the body to reflect the method call.
let bytes = serde_json::to_vec(&RequestSer::borrowed(&Id::Number(0), &self.method, None))
.expect("Valid request; qed");

let body = HttpBody::from(bytes);

req.map(|_| body)
} else {
req.map(HttpBody::new)
};

// Call the inner service and get a future that resolves to the response.
let fut = self.inner.call(req);

// Adjust the response if needed.
let res_fut = async move {
let res = fut.await.map_err(|err| err.into())?;

// Nothing to modify: return the response as is.
if !modify {
return Ok(res);
let path = req.uri().path();
let method = self.methods.get(path);

match (method, req.method()) {
// Proxy the `GET /path` request to the appropriate method call.
(Some(method), &Method::GET) => {
// RPC methods are accessed with `POST`.
*req.method_mut() = Method::POST;
// Precautionary remove the URI.
*req.uri_mut() = Uri::from_static("/");
// Requests must have the following headers:
req.headers_mut().insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
req.headers_mut().insert(ACCEPT, HeaderValue::from_static("application/json"));

// Adjust the body to reflect the method call.
let bytes = serde_json::to_vec(&RequestSer::borrowed(&Id::Number(0), method, None))
.expect("Valid request; qed");
let req = req.map(|_| HttpBody::from(bytes));

// Call the inner service and get a future that resolves to the response.
let fut = self.inner.call(req);

async move {
let res = fut.await.map_err(Into::into)?;

let mut body = http_body_util::BodyStream::new(res.into_body());
let mut bytes = Vec::new();

while let Some(frame) = body.frame().await {
let data = frame?.into_data().map_err(|e| format!("{e:?}"))?;
bytes.extend(data);
}

#[derive(serde::Deserialize)]
struct RpcPayload<'a> {
#[serde(borrow)]
result: &'a serde_json::value::RawValue,
}

let response = if let Ok(payload) = serde_json::from_slice::<RpcPayload>(&bytes) {
http::response::ok_response(payload.result.to_string())
} else {
http::response::internal_error()
};

Ok(response)
}
.boxed()
}

let mut body = http_body_util::BodyStream::new(res.into_body());
let mut bytes = Vec::new();

while let Some(frame) = body.frame().await {
let data = frame?.into_data().map_err(|e| format!("{e:?}"))?;
bytes.extend(data);
}

#[derive(serde::Deserialize, Debug)]
struct RpcPayload<'a> {
#[serde(borrow)]
result: &'a serde_json::value::RawValue,
// Call the inner service and get a future that resolves to the response.
_ => {
let req = req.map(HttpBody::new);
self.inner.call(req).map_err(Into::into).boxed()
}

let response = if let Ok(payload) = serde_json::from_slice::<RpcPayload>(&bytes) {
http::response::ok_response(payload.result.to_string())
} else {
http::response::internal_error()
};

Ok(response)
};

Box::pin(res_fut)
}
}
}
4 changes: 2 additions & 2 deletions server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -746,8 +746,8 @@ impl<HttpMiddleware, RpcMiddleware> Builder<HttpMiddleware, RpcMiddleware> {
/// fn run_server() -> ServerHandle {
/// let (stop_handle, server_handle) = stop_channel();
/// let svc_builder = jsonrpsee_server::Server::builder()
/// .set_config(ServerConfig::builder().max_connections(33).build())
/// .to_service_builder();
/// .set_config(ServerConfig::builder().max_connections(33).build())
/// .to_service_builder();
/// let methods = Methods::new();
/// let stop_handle = stop_handle.clone();
///
Expand Down
2 changes: 1 addition & 1 deletion tests/tests/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ pub async fn server_with_health_api() -> (SocketAddr, ServerHandle) {
pub async fn server_with_cors(cors: CorsLayer) -> (SocketAddr, ServerHandle) {
let middleware = tower::ServiceBuilder::new()
// Proxy `GET /health` requests to internal `system_health` method.
.layer(ProxyGetRequestLayer::new("/health", "system_health").unwrap())
.layer(ProxyGetRequestLayer::new([("/health", "system_health")]).unwrap())
// Add `CORS` layer.
.layer(cors);

Expand Down

0 comments on commit e5ada32

Please sign in to comment.