diff --git a/Cargo.lock b/Cargo.lock index 4bd56bacf..553309cc8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1679,6 +1679,7 @@ dependencies = [ "hostname", "http 0.2.12", "http 1.1.0", + "http-body-util", "humantime", "hyper 1.4.1", "indoc", diff --git a/Cargo.toml b/Cargo.toml index 5c0d7a02f..900b1975a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -113,10 +113,12 @@ procfs = "0.15.1" criterion = "0.5.1" kuchikiki = "0.8" http02 = { version = "0.2.11", package = "http"} +http-body-util = "0.1.0" rand = "0.8" mockito = "1.0.2" test-case = "3.0.0" reqwest = { version = "0.12", features = ["blocking", "json"] } +tower = { version = "0.5.1", features = ["util"] } aws-smithy-types = "1.0.1" aws-smithy-runtime = {version = "1.0.1", features = ["client", "test-util"]} aws-smithy-http = "0.60.0" diff --git a/src/bin/cratesfyi.rs b/src/bin/cratesfyi.rs index 1e22c29f8..ffceddafb 100644 --- a/src/bin/cratesfyi.rs +++ b/src/bin/cratesfyi.rs @@ -914,6 +914,10 @@ impl Context for BinContext { }; } + async fn async_pool(&self) -> Result { + self.pool() + } + fn pool(&self) -> Result { Ok(self .pool diff --git a/src/context.rs b/src/context.rs index aa21c6f1a..948415af6 100644 --- a/src/context.rs +++ b/src/context.rs @@ -19,6 +19,7 @@ pub trait Context { async fn async_storage(&self) -> Result>; async fn cdn(&self) -> Result>; fn pool(&self) -> Result; + async fn async_pool(&self) -> Result; fn service_metrics(&self) -> Result>; fn instance_metrics(&self) -> Result>; fn index(&self) -> Result>; diff --git a/src/test/mod.rs b/src/test/mod.rs index 505f453da..89b7d4854 100644 --- a/src/test/mod.rs +++ b/src/test/mod.rs @@ -12,9 +12,10 @@ use crate::{ ServiceMetrics, }; use anyhow::Context as _; -use axum::async_trait; +use axum::{async_trait, body::Body, http::Request, response::Response as AxumResponse, Router}; use fn_error_context::context; use futures_util::{stream::TryStreamExt, FutureExt}; +use http_body_util::BodyExt; // for `collect` use once_cell::sync::OnceCell; use reqwest::{ blocking::{Client, ClientBuilder, RequestBuilder, Response}, @@ -27,6 +28,7 @@ use std::{ }; use tokio::runtime::{Builder, Runtime}; use tokio::sync::oneshot::Sender; +use tower::ServiceExt; use tracing::{debug, error, instrument, trace}; #[track_caller] @@ -126,7 +128,6 @@ pub(crate) fn assert_success(path: &str, web: &TestFrontend) -> Result<()> { assert!(status.is_success(), "failed to GET {path}: {status}"); Ok(()) } - /// Make sure that a URL returns a status code between 200-299, /// also check the cache-control headers. pub(crate) fn assert_success_cached( @@ -259,6 +260,96 @@ pub(crate) fn assert_redirect_cached( Ok(redirect_response) } +pub(crate) trait AxumResponseTestExt { + async fn text(self) -> String; +} + +impl AxumResponseTestExt for axum::response::Response { + async fn text(self) -> String { + String::from_utf8_lossy(&self.into_body().collect().await.unwrap().to_bytes()).to_string() + } +} + +pub(crate) trait AxumRouterTestExt { + async fn assert_success(&self, path: &str) -> Result<()>; + async fn get(&self, path: &str) -> Result; + async fn assert_redirect_common( + &self, + path: &str, + expected_target: &str, + ) -> Result; + async fn assert_redirect(&self, path: &str, expected_target: &str) -> Result; +} + +impl AxumRouterTestExt for axum::Router { + /// Make sure that a URL returns a status code between 200-299 + async fn assert_success(&self, path: &str) -> Result<()> { + let response = self + .clone() + .oneshot(Request::builder().uri(path).body(Body::empty()).unwrap()) + .await?; + + let status = response.status(); + assert!(status.is_success(), "failed to GET {path}: {status}"); + Ok(()) + } + /// simple `get` method + async fn get(&self, path: &str) -> Result { + Ok(self + .clone() + .oneshot(Request::builder().uri(path).body(Body::empty()).unwrap()) + .await?) + } + + async fn assert_redirect_common( + &self, + path: &str, + expected_target: &str, + ) -> Result { + let response = self.get(path).await?; + let status = response.status(); + if !status.is_redirection() { + anyhow::bail!("non-redirect from GET {path}: {status}"); + } + + let redirect_target = response + .headers() + .get("Location") + .context("missing 'Location' header")? + .to_str() + .context("non-ASCII redirect")?; + + // FIXME: not sure we need this + // if !expected_target.starts_with("http") { + // // TODO: Should be able to use Url::make_relative, + // // but https://github.com/servo/rust-url/issues/766 + // let base = format!("http://{}", web.server_addr()); + // redirect_target = redirect_target + // .strip_prefix(&base) + // .unwrap_or(redirect_target); + // } + + if redirect_target != expected_target { + anyhow::bail!("got redirect to {redirect_target}"); + } + + Ok(response) + } + + #[context("expected redirect from {path} to {expected_target}")] + async fn assert_redirect(&self, path: &str, expected_target: &str) -> Result { + let redirect_response = self.assert_redirect_common(path, expected_target).await?; + + let response = self.get(expected_target).await?; + let status = response.status(); + if !status.is_success() { + anyhow::bail!("failed to GET {expected_target}: {status}"); + } + + Ok(redirect_response) + } +} + pub(crate) struct TestEnvironment { build_queue: OnceCell>, async_build_queue: tokio::sync::OnceCell>, @@ -534,6 +625,13 @@ impl TestEnvironment { self.runtime().block_on(self.async_fake_release()) } + pub(crate) async fn web_app(&self) -> Router { + let template_data = Arc::new(TemplateData::new(1).unwrap()); + build_axum_app(self, template_data) + .await + .expect("could not build axum app") + } + pub(crate) async fn async_fake_release(&self) -> fakes::FakeRelease { fakes::FakeRelease::new( self.async_db().await, @@ -569,6 +667,10 @@ impl Context for TestEnvironment { Ok(TestEnvironment::cdn(self).await) } + async fn async_pool(&self) -> Result { + Ok(self.async_db().await.pool()) + } + fn pool(&self) -> Result { Ok(self.db().pool()) } @@ -734,10 +836,12 @@ impl TestFrontend { let (tx, rx) = tokio::sync::oneshot::channel::<()>(); debug!("building axum app"); - let axum_app = build_axum_app(context, template_data).expect("could not build axum app"); + let runtime = context.runtime().unwrap(); + let axum_app = runtime + .block_on(build_axum_app(context, template_data)) + .expect("could not build axum app"); let handle = thread::spawn({ - let runtime = context.runtime().unwrap(); move || { runtime.block_on(async { axum::serve(axum_listener, axum_app.into_make_service()) diff --git a/src/web/mod.rs b/src/web/mod.rs index d8eba0ad3..3c59f9853 100644 --- a/src/web/mod.rs +++ b/src/web/mod.rs @@ -393,16 +393,16 @@ async fn set_sentry_transaction_name_from_axum_route( next.run(request).await } -fn apply_middleware( +async fn apply_middleware( router: AxumRouter, context: &dyn Context, template_data: Option>, ) -> Result { let config = context.config()?; let has_templates = template_data.is_some(); - let runtime = context.runtime()?; - let async_storage = runtime.block_on(context.async_storage())?; - let build_queue = runtime.block_on(context.async_build_queue())?; + + let async_storage = context.async_storage().await?; + let build_queue = context.async_build_queue().await?; Ok(router.layer( ServiceBuilder::new() @@ -419,12 +419,11 @@ fn apply_middleware( .then_some(middleware::from_fn(log_timeouts_to_sentry)), )) .layer(option_layer(config.request_timeout.map(TimeoutLayer::new))) - .layer(Extension(context.pool()?)) + .layer(Extension(context.async_pool().await?)) .layer(Extension(build_queue)) .layer(Extension(context.service_metrics()?)) .layer(Extension(context.instance_metrics()?)) .layer(Extension(context.config()?)) - .layer(Extension(context.storage()?)) .layer(Extension(async_storage)) .layer(option_layer(template_data.map(Extension))) .layer(middleware::from_fn(csp::csp_middleware)) @@ -435,15 +434,15 @@ fn apply_middleware( )) } -pub(crate) fn build_axum_app( +pub(crate) async fn build_axum_app( context: &dyn Context, template_data: Arc, ) -> Result { - apply_middleware(routes::build_axum_routes(), context, Some(template_data)) + apply_middleware(routes::build_axum_routes(), context, Some(template_data)).await } -pub(crate) fn build_metrics_axum_app(context: &dyn Context) -> Result { - apply_middleware(routes::build_metric_routes(), context, None) +pub(crate) async fn build_metrics_axum_app(context: &dyn Context) -> Result { + apply_middleware(routes::build_metric_routes(), context, None).await } pub fn start_background_metrics_webserver( @@ -458,8 +457,10 @@ pub fn start_background_metrics_webserver( axum_addr.port() ); - let metrics_axum_app = build_metrics_axum_app(context)?.into_make_service(); let runtime = context.runtime()?; + let metrics_axum_app = runtime + .block_on(build_metrics_axum_app(context))? + .into_make_service(); runtime.spawn(async move { match tokio::net::TcpListener::bind(axum_addr) @@ -501,8 +502,10 @@ pub fn start_web_server(addr: Option, context: &dyn Context) -> Resu context.storage()?; context.repository_stats_updater()?; - let app = build_axum_app(context, template_data)?.into_make_service(); context.runtime()?.block_on(async { + let app = build_axum_app(context, template_data) + .await? + .into_make_service(); let listener = tokio::net::TcpListener::bind(axum_addr) .await .context("error binding socket for metrics web server")?; diff --git a/src/web/sitemap.rs b/src/web/sitemap.rs index 5a06d5676..ec82e2cd1 100644 --- a/src/web/sitemap.rs +++ b/src/web/sitemap.rs @@ -203,28 +203,28 @@ pub(crate) async fn about_handler(subpage: Option>) -> AxumResult