diff --git a/examples/examples/http_proxy_middleware.rs b/examples/examples/http_proxy_middleware.rs index 245dcea6ea..bcbc19e698 100644 --- a/examples/examples/http_proxy_middleware.rs +++ b/examples/examples/http_proxy_middleware.rs @@ -87,7 +87,7 @@ async fn run_server() -> anyhow::Result { // Custom tower service to handle the RPC requests let service_builder = tower::ServiceBuilder::new() // Proxy `GET /health` requests to internal `system_health` method. - .layer(ProxyGetRequestLayer::new("/health", "system_health")?) + .layer(ProxyGetRequestLayer::new([("/health", "system_health")])?) .timeout(Duration::from_secs(2)); let server = diff --git a/server/Cargo.toml b/server/Cargo.toml index 0d620b7ecd..a924b09f1d 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -15,27 +15,28 @@ publish = true [dependencies] futures-util = { version = "0.3.14", default-features = false, features = ["io", "async-await-macro"] } -jsonrpsee-types = { workspace = true } +http = "1" +http-body = "1" +http-body-util = "0.1.0" +hyper = { version = "1.3", features = ["server", "http1", "http2"] } +hyper-util = { version = "0.1", features = ["tokio", "service", "tokio", "server-auto"] } jsonrpsee-core = { workspace = true, features = ["server", "http-helpers"] } -tracing = "0.1.34" +jsonrpsee-types = { workspace = true } +pin-project = "1.1.3" +route-recognizer = "0.3.1" +rustc-hash = "2.0.0" serde = "1" serde_json = { version = "1", features = ["raw_value"] } soketto = { version = "0.8", features = ["http"] } +thiserror = "2" tokio = { version = "1.23.1", features = ["net", "rt-multi-thread", "macros", "time"] } tokio-util = { version = "0.7", features = ["compat"] } tokio-stream = { version = "0.1.7", features = ["sync"] } -hyper = { version = "1.3", features = ["server", "http1", "http2"] } -hyper-util = { version = "0.1", features = ["tokio", "service", "tokio", "server-auto"] } -http = "1" -http-body = "1" -http-body-util = "0.1.0" tower = { workspace = true, features = ["util"] } -thiserror = "2" -route-recognizer = "0.3.1" -pin-project = "1.1.3" +tracing = "0.1.34" [dev-dependencies] jsonrpsee-test-utils = { path = "../test-utils" } -tracing-subscriber = { version = "0.3.3", features = ["env-filter"] } -tower = { workspace = true, features = ["timeout"] } socket2 = "0.5.1" +tower = { workspace = true, features = ["timeout"] } +tracing-subscriber = { version = "0.3.3", features = ["env-filter"] } diff --git a/server/src/middleware/http/proxy_get_request.rs b/server/src/middleware/http/proxy_get_request.rs index 9efdcbecab..0b74012c50 100644 --- a/server/src/middleware/http/proxy_get_request.rs +++ b/server/src/middleware/http/proxy_get_request.rs @@ -26,9 +26,10 @@ //! Middleware that proxies requests at a specified URI to internal //! RPC method calls. + use crate::transport::http; use crate::{HttpBody, HttpRequest, HttpResponse}; - +use futures_util::{FutureExt, TryFutureExt}; use http_body_util::BodyExt; use hyper::body::Bytes; use hyper::header::{ACCEPT, CONTENT_TYPE}; @@ -36,6 +37,7 @@ use hyper::http::HeaderValue; use hyper::{Method, Uri}; use jsonrpsee_core::BoxError; use jsonrpsee_types::{Id, RequestSer}; +use rustc_hash::FxHashMap; use std::future::Future; use std::pin::Pin; use std::sync::Arc; @@ -53,29 +55,40 @@ pub struct InvalidPath(String); /// See [`ProxyGetRequest`] for more details. #[derive(Debug, Clone)] pub struct ProxyGetRequestLayer { - path: String, - method: String, + // path => method mapping + methods: Arc>, } impl ProxyGetRequestLayer { /// Creates a new [`ProxyGetRequestLayer`]. /// - /// See [`ProxyGetRequest`] for more details. - pub fn new(path: impl Into, method: impl Into) -> Result { - let path = path.into(); - if !path.starts_with('/') { - return Err(InvalidPath(path)); + /// The request `GET /path` is redirected to the provided method. + /// Fails if the path does not start with `/`. + pub fn new(methods: impl IntoIterator) -> Result + where + P: Into, + M: Into, + { + let methods = methods + .into_iter() + .map(|(path, method)| (path.into(), method.into())) + .collect::>(); + + for path in methods.keys() { + if !path.starts_with('/') { + return Err(InvalidPath(path.clone())); + } } - Ok(Self { path, method: method.into() }) + Ok(Self { methods: Arc::new(methods) }) } } + impl Layer for ProxyGetRequestLayer { type Service = ProxyGetRequest; fn layer(&self, inner: S) -> Self::Service { - ProxyGetRequest::new(inner, &self.path, &self.method) - .expect("Path already validated in ProxyGetRequestLayer; qed") + ProxyGetRequest { inner, methods: self.methods.clone() } } } @@ -94,22 +107,8 @@ impl Layer for ProxyGetRequestLayer { #[derive(Debug, Clone)] pub struct ProxyGetRequest { inner: S, - path: Arc, - method: Arc, -} - -impl ProxyGetRequest { - /// Creates a new [`ProxyGetRequest`]. - /// - /// The request `GET /path` is redirected to the provided method. - /// Fails if the path does not start with `/`. - pub fn new(inner: S, path: &str, method: &str) -> Result { - if !path.starts_with('/') { - return Err(InvalidPath(path.to_string())); - } - - Ok(Self { inner, path: Arc::from(path), method: Arc::from(method) }) - } + // path => method mapping + methods: Arc>, } impl Service> for ProxyGetRequest @@ -132,65 +131,60 @@ where } fn call(&mut self, mut req: HttpRequest) -> Self::Future { - let modify = self.path.as_ref() == req.uri() && req.method() == Method::GET; - - // Proxy the request to the appropriate method call. - let req = if modify { - // RPC methods are accessed with `POST`. - *req.method_mut() = Method::POST; - // Precautionary remove the URI. - *req.uri_mut() = Uri::from_static("/"); - - // Requests must have the following headers: - req.headers_mut().insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); - req.headers_mut().insert(ACCEPT, HeaderValue::from_static("application/json")); - - // Adjust the body to reflect the method call. - let bytes = serde_json::to_vec(&RequestSer::borrowed(&Id::Number(0), &self.method, None)) - .expect("Valid request; qed"); - - let body = HttpBody::from(bytes); - - req.map(|_| body) - } else { - req.map(HttpBody::new) - }; - - // Call the inner service and get a future that resolves to the response. - let fut = self.inner.call(req); - - // Adjust the response if needed. - let res_fut = async move { - let res = fut.await.map_err(|err| err.into())?; - - // Nothing to modify: return the response as is. - if !modify { - return Ok(res); + let path = req.uri().path(); + let method = self.methods.get(path); + + match (method, req.method()) { + // Proxy the `GET /path` request to the appropriate method call. + (Some(method), &Method::GET) => { + // RPC methods are accessed with `POST`. + *req.method_mut() = Method::POST; + // Precautionary remove the URI. + *req.uri_mut() = Uri::from_static("/"); + // Requests must have the following headers: + req.headers_mut().insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); + req.headers_mut().insert(ACCEPT, HeaderValue::from_static("application/json")); + + // Adjust the body to reflect the method call. + let bytes = serde_json::to_vec(&RequestSer::borrowed(&Id::Number(0), method, None)) + .expect("Valid request; qed"); + let req = req.map(|_| HttpBody::from(bytes)); + + // Call the inner service and get a future that resolves to the response. + let fut = self.inner.call(req); + + async move { + let res = fut.await.map_err(Into::into)?; + + let mut body = http_body_util::BodyStream::new(res.into_body()); + let mut bytes = Vec::new(); + + while let Some(frame) = body.frame().await { + let data = frame?.into_data().map_err(|e| format!("{e:?}"))?; + bytes.extend(data); + } + + #[derive(serde::Deserialize)] + struct RpcPayload<'a> { + #[serde(borrow)] + result: &'a serde_json::value::RawValue, + } + + let response = if let Ok(payload) = serde_json::from_slice::(&bytes) { + http::response::ok_response(payload.result.to_string()) + } else { + http::response::internal_error() + }; + + Ok(response) + } + .boxed() } - - let mut body = http_body_util::BodyStream::new(res.into_body()); - let mut bytes = Vec::new(); - - while let Some(frame) = body.frame().await { - let data = frame?.into_data().map_err(|e| format!("{e:?}"))?; - bytes.extend(data); - } - - #[derive(serde::Deserialize, Debug)] - struct RpcPayload<'a> { - #[serde(borrow)] - result: &'a serde_json::value::RawValue, + // Call the inner service and get a future that resolves to the response. + _ => { + let req = req.map(HttpBody::new); + self.inner.call(req).map_err(Into::into).boxed() } - - let response = if let Ok(payload) = serde_json::from_slice::(&bytes) { - http::response::ok_response(payload.result.to_string()) - } else { - http::response::internal_error() - }; - - Ok(response) - }; - - Box::pin(res_fut) + } } } diff --git a/server/src/server.rs b/server/src/server.rs index 77ffcedf50..e06e837fd3 100644 --- a/server/src/server.rs +++ b/server/src/server.rs @@ -746,8 +746,8 @@ impl Builder { /// fn run_server() -> ServerHandle { /// let (stop_handle, server_handle) = stop_channel(); /// let svc_builder = jsonrpsee_server::Server::builder() - /// .set_config(ServerConfig::builder().max_connections(33).build()) - /// .to_service_builder(); + /// .set_config(ServerConfig::builder().max_connections(33).build()) + /// .to_service_builder(); /// let methods = Methods::new(); /// let stop_handle = stop_handle.clone(); /// diff --git a/tests/tests/helpers.rs b/tests/tests/helpers.rs index 40cd17baf7..86481a7ab7 100644 --- a/tests/tests/helpers.rs +++ b/tests/tests/helpers.rs @@ -279,7 +279,7 @@ pub async fn server_with_health_api() -> (SocketAddr, ServerHandle) { pub async fn server_with_cors(cors: CorsLayer) -> (SocketAddr, ServerHandle) { let middleware = tower::ServiceBuilder::new() // Proxy `GET /health` requests to internal `system_health` method. - .layer(ProxyGetRequestLayer::new("/health", "system_health").unwrap()) + .layer(ProxyGetRequestLayer::new([("/health", "system_health")]).unwrap()) // Add `CORS` layer. .layer(cors);