From d7fcb2d7d157efc8b5bc6b6ca599592b7f9525be Mon Sep 17 00:00:00 2001 From: Luis Soares <80909424+luishsr@users.noreply.github.com> Date: Wed, 11 Oct 2023 09:01:10 -0300 Subject: [PATCH] Added Dynamic Service Registry --- server/src/main.rs | 153 ++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 143 insertions(+), 10 deletions(-) diff --git a/server/src/main.rs b/server/src/main.rs index 117456d..894f72a 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -1,16 +1,77 @@ -use hyper::{Body, Request, Response, Server, StatusCode}; -use hyper::service::{make_service_fn, service_fn}; -use hyper::client::HttpConnector; use hyper_tls::HttpsConnector; use std::collections::HashMap; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, Mutex, RwLock}; use std::net::SocketAddr; +use std::time::Duration; +use hyper::{Body, Request, Response, Server, StatusCode}; +use hyper::client::HttpConnector; use hyper::server::conn::AddrStream; +use hyper::service::{make_service_fn, service_fn}; use serde_json::json; use jsonwebtoken::{decode, DecodingKey, Validation, Algorithm, errors::ErrorKind}; +use serde::{Deserialize, Serialize}; const SECRET_KEY: &'static str = "secret_key"; // Use a stronger secret in a real-world scenario +#[derive(Debug, Serialize, Deserialize)] +struct ServiceConfig { + name: String, + address: String, +} + +struct ServiceRegistry { + services: Arc>>, // Service Name -> Service Address (URL/URI) +} + +impl ServiceRegistry { + fn new() -> Self { + ServiceRegistry { + services: Arc::new(RwLock::new(HashMap::new())), + } + } + + fn register(&self, name: String, address: String) { + let mut services = self.services.write().unwrap(); + services.insert(name, address); + } + + fn deregister(&self, name: &str) { + let mut services = self.services.write().unwrap(); + services.remove(name); + } + + fn get_address(&self, name: &str) -> Option { + let services = self.services.read().unwrap(); + services.get(name).cloned() + } +} + +async fn register_service(req: Request, registry: Arc) -> Result, hyper::Error> { + let body_bytes = hyper::body::to_bytes(req.into_body()).await?; + let body_str = String::from_utf8_lossy(&body_bytes); + let parts: Vec<&str> = body_str.split(',').collect(); + + if parts.len() != 2 { + return Ok(Response::new(Body::from("Invalid format. Expecting 'name,address'"))); + } + + let name = parts[0].to_string(); + let address = parts[1].to_string(); + + registry.register(name, address); + + Ok(Response::new(Body::from("Service registered successfully"))) +} + +async fn deregister_service(req: Request, registry: Arc) -> Result, hyper::Error> { + let body_bytes = hyper::body::to_bytes(req.into_body()).await?; + let name = String::from_utf8_lossy(&body_bytes).to_string(); + + registry.deregister(&name); + + Ok(Response::new(Body::from("Service deregistered successfully"))) +} + struct RateLimiter { visitors: Arc>>, } @@ -44,6 +105,7 @@ fn authenticate(token: &str) -> bool { match decode::(&token, &DecodingKey::from_secret(SECRET_KEY.as_ref()), &validation) { Ok(_data) => true, Err(err) => { + eprintln!("JWT Decoding error: {:?}", err); match *err.kind() { ErrorKind::InvalidToken => false, // token is invalid _ => false @@ -52,7 +114,7 @@ fn authenticate(token: &str) -> bool { } } -async fn service_handler(req: Request, client: &hyper::Client>) -> Result, hyper::Error> { +async fn service_handler(req: Request, client: &hyper::Client>) -> Result, hyper::Error>{ // Example of request transformation: Adding a custom header let req = Request::builder() .method(req.method()) @@ -62,6 +124,7 @@ async fn service_handler(req: Request, client: &hyper::Client, client: &hyper::Client, remote_addr: SocketAddr, rate_limiter: Arc, client: Arc>>) -> Result, hyper::Error> { +/*async fn handle_request(req: Request, rate_limiter: Arc, client: Arc>>, service_registry: &ServiceRegistry) -> Result, hyper::Error> {*/ +async fn handle_request( + mut req: Request, + remote_addr: SocketAddr, + rate_limiter: Arc, + client: Arc>>, + registry: Arc, +) -> Result, hyper::Error> { if !rate_limiter.allow(remote_addr) { return Ok(Response::builder() @@ -112,10 +182,61 @@ async fn handle_request(req: Request, remote_addr: SocketAddr, rate_limite } } - // Send the request to the service handler - service_handler(req, &client).await + let path = req.uri().path(); + + // Let's assume the first path segment is the service name. + let parts: Vec<&str> = path.split('/').collect(); + if parts.len() < 2 { + return Ok(Response::new(Body::from("Invalid request URI"))); + } + + let service_name = parts[1]; + + match registry.get_address(service_name) { + Some(address) => { + // Here, use the address to forward the request. + + // Create a new URI based on the resolved address + let mut address = address; + if !address.starts_with("http://") && !address.starts_with("https://") { + address = format!("http://{}", address); + } + let forward_uri = format!("{}{}", address, req.uri().path_and_query().map_or("", |x| x.as_str())); + + if let Ok(uri) = forward_uri.parse() { + *req.uri_mut() = uri; + } else { + return Ok(Response::new(Body::from("Invalid service URI"))); + } + + // Send the request to the service handler + service_handler(req, &client).await + }, + None => return Ok(Response::new(Body::from("Service not found"))), + } + } +async fn router( + req: Request, + remote_addr: SocketAddr, + rate_limiter: Arc, + client: Arc>>, + registry: Arc, +) -> Result, hyper::Error> { + let path = req.uri().path(); + + if path == "/register_service" { + return register_service(req, Arc::clone(®istry)).await; + } + + if path == "/deregister_service" { + return deregister_service(req, Arc::clone(®istry)).await; + } + + // Handle other requests using the previously defined handler + handle_request(req, remote_addr, rate_limiter, client, registry).await +} #[tokio::main] async fn main() { let rate_limiter = Arc::new(RateLimiter::new()); @@ -123,17 +244,29 @@ async fn main() { let client = hyper::Client::builder().build::<_, hyper::Body>(https); let client = Arc::new(client); + let registry = Arc::new(ServiceRegistry::new()); + + // Handle Requests let make_svc = make_service_fn(move |conn: &AddrStream| { let remote_addr = conn.remote_addr(); let rate_limiter = Arc::clone(&rate_limiter); let client = Arc::clone(&client); + let registry_clone = Arc::clone(®istry); + + let service = service_fn(move |req| { + router(req, remote_addr, Arc::clone(&rate_limiter), Arc::clone(&client), Arc::clone(®istry_clone)) + }); - let service = service_fn(move |req| handle_request(req, remote_addr, Arc::clone(&rate_limiter), Arc::clone(&client))); async { Ok::<_, hyper::Error>(service) } }); + let addr = ([127, 0, 0, 1], 8080).into(); - let server = Server::bind(&addr).serve(make_svc); + + let server = Server::bind(&addr) + .http1_keepalive(true) + .http2_keep_alive_timeout(Duration::from_secs(120)) + .serve(make_svc); println!("API Gateway running on http://{}", addr);