Skip to content

Commit

Permalink
Backport impl tower::Layer for Extension (#805)
Browse files Browse the repository at this point in the history
* Implement `tower::Layer` for `Extension` (#801)

* Implement `tower::Layer` for `Extension`

* changelog

* Deprecate `AddExtensionLayer`

* changelog

* Add missing #[allow(deprecated)]
  • Loading branch information
davidpdrsn authored Mar 1, 2022
1 parent 5cb6e6d commit ee387ae
Show file tree
Hide file tree
Showing 16 changed files with 67 additions and 43 deletions.
6 changes: 5 additions & 1 deletion axum/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

# Unreleased

- None.
- **added:** Implement `tower::Layer` for `Extension` ([#801])
- **changed:** Deprecate `AddExtensionLayer`. Use `Extension` instead ([#805])

[#801]: https://github.com/tokio-rs/axum/pull/801
[#805]: https://github.com/tokio-rs/axum/pull/805

# 0.4.6 (22. February, 2022)

Expand Down
15 changes: 13 additions & 2 deletions axum/src/add_extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,23 @@ use tower_service::Service;
///
/// [request extensions]: https://docs.rs/http/latest/http/struct.Extensions.html
#[derive(Clone, Copy, Debug)]
#[deprecated(
since = "0.4.7",
note = "Use `axum::Extension` instead. It implements `tower::Layer`"
)]
pub struct AddExtensionLayer<T> {
value: T,
}

#[allow(deprecated)]
impl<T> AddExtensionLayer<T> {
/// Create a new [`AddExtensionLayer`].
pub fn new(value: T) -> Self {
Self { value }
}
}

#[allow(deprecated)]
impl<S, T> Layer<S> for AddExtensionLayer<T>
where
T: Clone,
Expand All @@ -45,12 +51,17 @@ where
/// [request extensions]: https://docs.rs/http/latest/http/struct.Extensions.html
#[derive(Clone, Copy, Debug)]
pub struct AddExtension<S, T> {
inner: S,
value: T,
pub(crate) inner: S,
pub(crate) value: T,
}

impl<S, T> AddExtension<S, T> {
/// Create a new [`AddExtensionLayer`].
#[deprecated(
since = "0.4.7",
note = "Use `axum::Extension` instead. It implements `tower::Layer`"
)]
#[allow(deprecated)]
pub fn layer(value: T) -> AddExtensionLayer<T> {
AddExtensionLayer::new(value)
}
Expand Down
4 changes: 2 additions & 2 deletions axum/src/extract/connect_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
//! [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info
use super::{Extension, FromRequest, RequestParts};
use crate::{AddExtension, AddExtensionLayer};
use crate::AddExtension;
use async_trait::async_trait;
use hyper::server::conn::AddrStream;
use std::{
Expand Down Expand Up @@ -104,7 +104,7 @@ where

fn call(&mut self, target: T) -> Self::Future {
let connect_info = ConnectInfo(C::connect_info(target));
let svc = AddExtensionLayer::new(connect_info).layer(self.svc.clone());
let svc = Extension(connect_info).layer(self.svc.clone());
ResponseFuture::new(ready(Ok(svc)))
}
}
Expand Down
19 changes: 16 additions & 3 deletions axum/src/extract/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ use std::ops::Deref;
///
/// ```rust,no_run
/// use axum::{
/// AddExtensionLayer,
/// extract::Extension,
/// routing::get,
/// Router,
Expand All @@ -31,7 +30,7 @@ use std::ops::Deref;
/// let app = Router::new().route("/", get(handler))
/// // Add middleware that inserts the state into all incoming request's
/// // extensions.
/// .layer(AddExtensionLayer::new(state));
/// .layer(Extension(state));
/// # async {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
Expand All @@ -57,7 +56,7 @@ where
.get::<T>()
.ok_or_else(|| {
MissingExtension::from_err(format!(
"Extension of type `{}` was not found. Perhaps you forgot to add it?",
"Extension of type `{}` was not found. Perhaps you forgot to add it? See `axum::extract::Extension`.",
std::any::type_name::<T>()
))
})
Expand All @@ -74,3 +73,17 @@ impl<T> Deref for Extension<T> {
&self.0
}
}

impl<S, T> tower_layer::Layer<S> for Extension<T>
where
T: Clone + Send + Sync + 'static,
{
type Service = crate::AddExtension<S, T>;

fn layer(&self, inner: S) -> Self::Service {
crate::AddExtension {
inner,
value: self.0.clone(),
}
}
}
9 changes: 3 additions & 6 deletions axum/src/extract/request_parts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,10 @@ where
mod tests {
use crate::{
body::Body,
extract::Extension,
routing::{get, post},
test_helpers::*,
AddExtensionLayer, Router,
Router,
};
use http::{Method, Request, StatusCode};

Expand Down Expand Up @@ -253,11 +254,7 @@ mod tests {
parts.extensions.get::<Ext>().unwrap();
}

let client = TestClient::new(
Router::new()
.route("/", get(handler))
.layer(AddExtensionLayer::new(Ext)),
);
let client = TestClient::new(Router::new().route("/", get(handler)).layer(Extension(Ext)));

let res = client.get("/").header("x-foo", "123").send().await;
assert_eq!(res.status(), StatusCode::OK);
Expand Down
14 changes: 7 additions & 7 deletions axum/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,13 +173,11 @@
//!
//! ## Using request extensions
//!
//! The easiest way to extract state in handlers is using [`AddExtension`]
//! middleware (applied with [`AddExtensionLayer`]) and the
//! [`Extension`](crate::extract::Extension) extractor:
//! The easiest way to extract state in handlers is using [`Extension`](crate::extract::Extension)
//! as layer and extractor:
//!
//! ```rust,no_run
//! use axum::{
//! AddExtensionLayer,
//! extract::Extension,
//! routing::get,
//! Router,
Expand All @@ -194,7 +192,7 @@
//!
//! let app = Router::new()
//! .route("/", get(handler))
//! .layer(AddExtensionLayer::new(shared_state));
//! .layer(Extension(shared_state));
//!
//! async fn handler(
//! Extension(state): Extension<Arc<State>>,
Expand All @@ -217,7 +215,6 @@
//!
//! ```rust,no_run
//! use axum::{
//! AddExtensionLayer,
//! Json,
//! extract::{Extension, Path},
//! routing::{get, post},
Expand Down Expand Up @@ -408,7 +405,10 @@ pub mod routing;
#[cfg(test)]
mod test_helpers;

pub use add_extension::{AddExtension, AddExtensionLayer};
pub use add_extension::AddExtension;
#[allow(deprecated)]
pub use add_extension::AddExtensionLayer;

#[doc(no_inline)]
pub use async_trait::async_trait;
#[cfg(feature = "headers")]
Expand Down
4 changes: 2 additions & 2 deletions axum/src/middleware/from_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,11 @@ use tower_service::Service;
/// ```rust
/// use axum::{
/// Router,
/// extract::Extension,
/// http::{Request, StatusCode},
/// routing::get,
/// response::IntoResponse,
/// middleware::{self, Next},
/// AddExtensionLayer,
/// };
/// use tower::ServiceBuilder;
///
Expand All @@ -129,7 +129,7 @@ use tower_service::Service;
/// .route("/", get(|| async { /* ... */ }))
/// .layer(
/// ServiceBuilder::new()
/// .layer(AddExtensionLayer::new(state))
/// .layer(Extension(state))
/// .layer(middleware::from_fn(my_middleware)),
/// );
/// # let app: Router = app;
Expand Down
4 changes: 2 additions & 2 deletions examples/async-graphql/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use axum::{
extract::Extension,
response::{Html, IntoResponse},
routing::get,
AddExtensionLayer, Json, Router,
Json, Router,
};
use starwars::{QueryRoot, StarWars, StarWarsSchema};

Expand All @@ -28,7 +28,7 @@ async fn main() {

let app = Router::new()
.route("/", get(graphql_playground).post(graphql_handler))
.layer(AddExtensionLayer::new(schema));
.layer(Extension(schema));

println!("Playground: http://localhost:3000");

Expand Down
4 changes: 2 additions & 2 deletions examples/chat/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use axum::{
},
response::{Html, IntoResponse},
routing::get,
AddExtensionLayer, Router,
Router,
};
use futures::{sink::SinkExt, stream::StreamExt};
use std::{
Expand All @@ -39,7 +39,7 @@ async fn main() {
let app = Router::new()
.route("/", get(index))
.route("/websocket", get(websocket_handler))
.layer(AddExtensionLayer::new(app_state));
.layer(Extension(app_state));

let addr = SocketAddr::from(([127, 0, 0, 1], 3000));

Expand Down
4 changes: 2 additions & 2 deletions examples/error-handling-and-dependency-injection/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use axum::{
http::StatusCode,
response::{IntoResponse, Response},
routing::{get, post},
AddExtensionLayer, Json, Router,
Json, Router,
};
use serde::{Deserialize, Serialize};
use serde_json::json;
Expand Down Expand Up @@ -41,7 +41,7 @@ async fn main() {
.route("/users", post(users_create))
// Add our `user_repo` to all request's extensions so handlers can access
// it.
.layer(AddExtensionLayer::new(user_repo));
.layer(Extension(user_repo));

// Run our application
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
Expand Down
5 changes: 2 additions & 3 deletions examples/key-value-store/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ use std::{
};
use tower::{BoxError, ServiceBuilder};
use tower_http::{
add_extension::AddExtensionLayer, auth::RequireAuthorizationLayer,
compression::CompressionLayer, trace::TraceLayer,
auth::RequireAuthorizationLayer, compression::CompressionLayer, trace::TraceLayer,
};

#[tokio::main]
Expand Down Expand Up @@ -58,7 +57,7 @@ async fn main() {
.concurrency_limit(1024)
.timeout(Duration::from_secs(10))
.layer(TraceLayer::new_for_http())
.layer(AddExtensionLayer::new(SharedState::default()))
.layer(Extension(SharedState::default()))
.into_inner(),
);

Expand Down
6 changes: 3 additions & 3 deletions examples/oauth/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use axum::{
http::{header::SET_COOKIE, HeaderMap},
response::{IntoResponse, Redirect, Response},
routing::get,
AddExtensionLayer, Router,
Router,
};
use http::header;
use oauth2::{
Expand Down Expand Up @@ -49,8 +49,8 @@ async fn main() {
.route("/auth/authorized", get(login_authorized))
.route("/protected", get(protected))
.route("/logout", get(logout))
.layer(AddExtensionLayer::new(store))
.layer(AddExtensionLayer::new(oauth_client));
.layer(Extension(store))
.layer(Extension(oauth_client));

let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
tracing::debug!("listening on {}", addr);
Expand Down
4 changes: 2 additions & 2 deletions examples/reverse-proxy/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use axum::{
extract::Extension,
http::{uri::Uri, Request, Response},
routing::get,
AddExtensionLayer, Router,
Router,
};
use hyper::{client::HttpConnector, Body};
use std::{convert::TryFrom, net::SocketAddr};
Expand All @@ -26,7 +26,7 @@ async fn main() {

let app = Router::new()
.route("/", get(handler))
.layer(AddExtensionLayer::new(client));
.layer(Extension(client));

let addr = SocketAddr::from(([127, 0, 0, 1], 4000));
println!("reverse proxy listening on {}", addr);
Expand Down
4 changes: 2 additions & 2 deletions examples/sessions/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use axum::{
},
response::IntoResponse,
routing::get,
AddExtensionLayer, Router,
Router,
};
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
Expand All @@ -38,7 +38,7 @@ async fn main() {

let app = Router::new()
.route("/", get(handler))
.layer(AddExtensionLayer::new(store));
.layer(Extension(store));

let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
tracing::debug!("listening on {}", addr);
Expand Down
4 changes: 2 additions & 2 deletions examples/todos/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use std::{
time::Duration,
};
use tower::{BoxError, ServiceBuilder};
use tower_http::{add_extension::AddExtensionLayer, trace::TraceLayer};
use tower_http::trace::TraceLayer;
use uuid::Uuid;

#[tokio::main]
Expand Down Expand Up @@ -61,7 +61,7 @@ async fn main() {
}))
.timeout(Duration::from_secs(10))
.layer(TraceLayer::new_for_http())
.layer(AddExtensionLayer::new(db))
.layer(Extension(db))
.into_inner(),
);

Expand Down
4 changes: 2 additions & 2 deletions examples/tokio-postgres/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use axum::{
extract::{Extension, FromRequest, RequestParts},
http::StatusCode,
routing::get,
AddExtensionLayer, Router,
Router,
};
use bb8::{Pool, PooledConnection};
use bb8_postgres::PostgresConnectionManager;
Expand All @@ -36,7 +36,7 @@ async fn main() {
"/",
get(using_connection_pool_extractor).post(using_connection_extractor),
)
.layer(AddExtensionLayer::new(pool));
.layer(Extension(pool));

// run it with hyper
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
Expand Down

0 comments on commit ee387ae

Please sign in to comment.