diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index eaf6562..5ca3f77 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -125,15 +125,19 @@ jobs: docker: false - store: mysql_store - features: diesel-store,diesel/mysql + features: diesel-store,diesel/mysql,diesel-r2d2 docker: true - store: postgres_store - features: diesel-store,diesel/postgres + features: diesel-store,diesel/postgres,diesel-r2d2 docker: true - store: diesel_store - features: diesel-store + features: diesel-store,diesel-r2d2 + docker: false + + - store: diesel_store_deadpool + features: diesel-store,diesel-deadpool docker: false steps: diff --git a/Cargo.toml b/Cargo.toml index e967b19..4d6ec92 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,8 @@ postgres-store = ["sqlx/postgres", "sqlx-store"] mysql-store = ["sqlx/mysql", "sqlx-store"] moka-store = ["moka"] diesel-store = ["dep:diesel", "tokio/rt", "rmp-serde"] +diesel-r2d2 = ["diesel/r2d2"] +diesel-deadpool = ["dep:deadpool-diesel", "dep:deadpool"] [dependencies] async-trait = "0.1.73" @@ -58,9 +60,10 @@ sqlx = { optional = true, version = "0.7.2", features = [ tokio = { optional = true, version = "1.32.0", default-features = false } moka = { optional = true, version = "0.12.0", features = ["future"] } diesel = { optional = true, default-features = false, version = "2.1", features = [ - "r2d2", "time", ] } +deadpool-diesel = { version = "0.5", optional = true } +deadpool = { version ="0.10", optional = true } [dev-dependencies] axum = "0.6.20" @@ -132,6 +135,7 @@ required-features = [ "diesel-store", "diesel/sqlite", "continuously-delete-expired", + "diesel-r2d2", ] @@ -142,4 +146,15 @@ required-features = [ "diesel-store", "diesel/sqlite", "continuously-delete-expired", + "diesel-r2d2", +] + +[[example]] +name = "diesel-deadpool" +required-features = [ + "axum-core", + "diesel-store", + "diesel/sqlite", + "continuously-delete-expired", + "diesel-deadpool", ] diff --git a/examples/diesel-deadpool.rs b/examples/diesel-deadpool.rs new file mode 100644 index 0000000..63ebeae --- /dev/null +++ b/examples/diesel-deadpool.rs @@ -0,0 +1,69 @@ +use std::net::SocketAddr; + +use axum::{ + error_handling::HandleErrorLayer, response::IntoResponse, routing::get, BoxError, Router, +}; +use deadpool_diesel::{Manager, Pool}; +use diesel::prelude::*; +use http::StatusCode; +use serde::{Deserialize, Serialize}; +use time::Duration; +use tower::ServiceBuilder; +use tower_sessions::{ + diesel_store::DieselStore, session_store::ExpiredDeletion, Session, SessionManagerLayer, +}; + +const COUNTER_KEY: &str = "counter"; + +#[derive(Serialize, Deserialize, Default)] +struct Counter(usize); + +#[tokio::main] +async fn main() -> Result<(), Box> { + let manager = Manager::::new(":memory:", deadpool_diesel::Runtime::Tokio1); + let pool = Pool::builder(manager).max_size(1).build()?; + let session_store = DieselStore::new(pool); + session_store.migrate().await?; + + let deletion_task = tokio::task::spawn( + session_store + .clone() + .continuously_delete_expired(tokio::time::Duration::from_secs(60)), + ); + + let session_service = ServiceBuilder::new() + .layer(HandleErrorLayer::new(|_: BoxError| async { + StatusCode::BAD_REQUEST + })) + .layer( + SessionManagerLayer::new(session_store) + .with_secure(false) + .with_max_age(Duration::seconds(10)), + ); + + let app = Router::new() + .route("/", get(handler)) + .layer(session_service); + + let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); + axum::Server::bind(&addr) + .serve(app.into_make_service()) + .await?; + + deletion_task.await??; + + Ok(()) +} + +async fn handler(session: Session) -> impl IntoResponse { + let counter: Counter = session + .get(COUNTER_KEY) + .expect("Could not deserialize.") + .unwrap_or_default(); + + session + .insert(COUNTER_KEY, counter.0 + 1) + .expect("Could not serialize."); + + format!("Current count: {}", counter.0) +} diff --git a/src/diesel_store.rs b/src/diesel_store.rs index b2bc344..db76956 100644 --- a/src/diesel_store.rs +++ b/src/diesel_store.rs @@ -10,12 +10,11 @@ use diesel::{ expression::{AsExpression, ValidGrouping}, expression_methods::ExpressionMethods, helper_types::{Eq, Filter, Gt, IntoBoxed, SqlTypeOf}, - prelude::{BoolExpressionMethods, Insertable, Queryable}, + prelude::{BoolExpressionMethods, Connection, Insertable, Queryable}, query_builder::{ AsQuery, DeleteStatement, InsertStatement, IntoUpdateTarget, QueryBuilder, QueryFragment, }, query_dsl::methods::{BoxedDsl, ExecuteDsl, FilterDsl, LimitDsl, LoadQuery}, - r2d2::{ConnectionManager, ManageConnection, Pool, R2D2Connection}, sql_types::{Binary, Bool, Nullable, SingleValue, SqlType, Text, Timestamp}, BoxableExpression, Column, Expression, OptionalExtension, QueryDsl, RunQueryDsl, SelectableExpression, Table, @@ -25,36 +24,37 @@ use crate::{session_store::ExpiredDeletion, SessionStore}; /// An error type for diesel stores #[derive(thiserror::Error, Debug)] +#[non_exhaustive] pub enum DieselStoreError { + #[cfg(feature = "diesel-r2d2")] /// A pool related error #[error("Pool Error: {0}")] R2D2Error(#[from] diesel::r2d2::PoolError), /// A diesel related error #[error("Diesel Error: {0}")] DieselError(#[from] diesel::result::Error), + #[cfg(feature = "diesel-r2d2")] /// Failed to join a blocking tokio task #[error("Failed to join task: {0}")] - TokioJoinERror(#[from] tokio::task::JoinError), + TokioJoinError(#[from] tokio::task::JoinError), /// A variant to map `rmp_serde` encode errors. #[error("Failed to serialize session data: {0}")] SerializationError(#[from] rmp_serde::encode::Error), + #[cfg(feature = "diesel-deadpool")] + #[error("Failed to interact with deadpool: {0}")] + /// A variant that indicates that we cannot interact with deadpool + InteractError(String), + #[cfg(feature = "diesel-deadpool")] + #[error("Failed to get a connection from deadpool: {0}")] + /// A variant that indicates that we cannot get a connection from deadpool + DeadpoolError(#[from] deadpool_diesel::PoolError), } /// A Diesel session store -#[derive(Debug)] -pub struct DieselStore { +#[derive(Debug, Clone)] +pub struct DieselStore { p: PhantomData, - pool: diesel::r2d2::Pool>, -} - -// custom impl as we don't want to have `Clone bounds on the types -impl Clone for DieselStore { - fn clone(&self) -> Self { - Self { - p: self.p, - pool: self.pool.clone(), - } - } + pool: P, } diesel::table! { @@ -76,7 +76,7 @@ diesel::table! { /// definition pub trait SessionTable: Copy + Send + Sync + AsQuery + HasTable + Table where - C: R2D2Connection, + C: Connection, { /// the `expiration_time` column of your table type ExpirationTime: Column> @@ -97,10 +97,83 @@ where } } +/// A helper trait to abstract over different pooling solutions for diesel +#[async_trait::async_trait] +pub trait DieselPool: Clone + Sync + Send + 'static { + /// The underlying diesel connection type used by this pool + type Connection: Connection + 'static; + + /// Interact with a connection from that pool + async fn interact(&self, c: F) -> Result + where + R: Send + 'static, + E: Send + 'static, + F: FnOnce(&mut Self::Connection) -> Result + Send + 'static, + DieselStoreError: From; +} + +#[cfg(feature = "diesel-r2d2")] +#[async_trait::async_trait] +impl DieselPool for diesel::r2d2::Pool> +where + C: diesel::r2d2::R2D2Connection + 'static, +{ + type Connection = C; + + async fn interact(&self, c: F) -> Result + where + F: FnOnce(&mut Self::Connection) -> Result + Send + 'static, + DieselStoreError: From, + R: Send + 'static, + E: Send + 'static, + { + let pool = self.clone(); + Ok(tokio::task::spawn_blocking(move || { + let mut conn = pool.get()?; + let r = c(&mut *conn)?; + Ok::<_, DieselStoreError>(r) + }) + .await??) + } +} + +#[cfg(feature = "diesel-deadpool")] +#[async_trait::async_trait] +impl DieselPool for deadpool_diesel::Pool> +where + C: Connection + 'static, + deadpool_diesel::Manager: deadpool::managed::Manager>, + as deadpool::managed::Manager>::Type: Send + Sync, + as deadpool::managed::Manager>::Error: std::fmt::Debug, + DieselStoreError: From< + deadpool::managed::PoolError< + as deadpool::managed::Manager>::Error, + >, + >, +{ + type Connection = C; + + async fn interact(&self, c: F) -> Result + where + F: FnOnce(&mut Self::Connection) -> Result + Send + 'static, + DieselStoreError: From, + R: Send + 'static, + E: Send + 'static, + { + let conn = self.get().await?; + let r = conn + .as_ref() + .interact(c) + .await + .map_err(|e| DieselStoreError::InteractError(e.to_string()))?; + r.map_err(Into::into) + } +} + impl SessionTable for self::sessions::table where ::QueryBuilder: Default, - C: diesel::r2d2::R2D2Connection, + C: diesel::Connection, InsertStatement< Self, <( @@ -166,9 +239,9 @@ where } } -impl DieselStore +impl

DieselStore

where - C: R2D2Connection, + P: DieselPool, { /// Create a new diesel store with a provided connection pool. /// @@ -186,7 +259,7 @@ where /// .unwrap(); /// let session_store = DieselStore::new(pool); /// ``` - pub fn new(pool: diesel::r2d2::Pool>) -> Self { + pub fn new(pool: P) -> Self { Self { pool, p: PhantomData, @@ -194,27 +267,29 @@ where } } -impl DieselStore +impl DieselStore where - C: R2D2Connection, - T: SessionTable, + P: DieselPool, + T: SessionTable, { /// Create a new diesel store with a provided connection pool. /// /// # Examples /// /// ```rust,no_run - /// use tower_sessions::diesel_store::{DieselStore}; - /// use diesel::prelude::*; - /// use diesel::r2d2::{Pool, ConnectionManager}; + /// use diesel::{ + /// prelude::*, + /// r2d2::{ConnectionManager, Pool}, + /// }; + /// use tower_sessions::diesel_store::DieselStore; /// - /// let pool = Pool::builder().build(ConnectionManager::::new(":memory:")).unwrap(); - /// let session_store = DieselStore::with_table(tower_sessions::diesel_store::sessions::table, pool); + /// let pool = Pool::builder() + /// .build(ConnectionManager::::new(":memory:")) + /// .unwrap(); + /// let session_store = + /// DieselStore::with_table(tower_sessions::diesel_store::sessions::table, pool); /// ``` - pub fn with_table( - _table: T, - pool: diesel::r2d2::Pool>, - ) -> Self { + pub fn with_table(_table: T, pool: P) -> Self { Self { pool, p: PhantomData, @@ -223,13 +298,12 @@ where /// Migrate the session schema. pub async fn migrate(&self) -> Result<(), DieselStoreError> { - let pool = self.pool.clone(); - tokio::task::spawn_blocking(move || { - let mut conn = pool.get()?; - T::migrate(&mut conn)?; - Ok::<_, DieselStoreError>(()) - }) - .await??; + self.pool + .interact(|conn| { + T::migrate(conn)?; + Ok::<_, DieselStoreError>(()) + }) + .await?; Ok(()) } } @@ -255,45 +329,42 @@ where } #[async_trait::async_trait] -impl SessionStore for DieselStore +impl SessionStore for DieselStore where - T: SessionTable + 'static, + P: DieselPool, + T: SessionTable + 'static, String: AsExpression>, T::PrimaryKey: Default, ::SqlType: SqlType + SingleValue, - DeleteStatement>: ExecuteDsl, - T: FilterDsl> + BoxedDsl<'static, C::Backend>, - IntoBoxed<'static, T, C::Backend>: LimitDsl>, - IntoBoxed<'static, T, C::Backend>: FilterDsl< + DeleteStatement>: ExecuteDsl, + T: FilterDsl> + + BoxedDsl<'static, ::Backend>, + IntoBoxed<'static, T, ::Backend>: + LimitDsl::Backend>>, + IntoBoxed<'static, T, ::Backend>: FilterDsl< And< Eq, Or, Gt, Nullable>, Nullable, >, - Output = IntoBoxed<'static, T, C::Backend>, + Output = IntoBoxed<'static, T, ::Backend>, >, Filter>: IntoUpdateTarget, DeleteStatement< > as HasTable>::Table, > as IntoUpdateTarget>::WhereClause, - >: ExecuteDsl, + >: ExecuteDsl, Eq: BoolExpressionMethods, - for<'a> IntoBoxed<'static, T, C::Backend>: LoadQuery<'a, C, crate::Session>, - Pool>: Clone, - ConnectionManager: ManageConnection, - C: R2D2Connection, + for<'a> IntoBoxed<'static, T, ::Backend>: + LoadQuery<'a, P::Connection, crate::Session>, { type Error = DieselStoreError; async fn save(&self, session_record: &crate::SessionRecord) -> Result<(), Self::Error> { - let pool = self.pool.clone(); let record = session_record.clone(); - tokio::task::spawn_blocking(move || { - let conn: &mut diesel::r2d2::PooledConnection> = - &mut pool.get()?; - T::insert(conn, &record) - }) - .await??; + self.pool + .interact(move |conn| T::insert(conn, &record)) + .await?; Ok(()) } @@ -302,69 +373,101 @@ where session_id: &crate::session::SessionId, ) -> Result, Self::Error> { let session_id = session_id.to_string(); - let pool = self.pool.clone(); - let res = tokio::task::spawn_blocking(move || { - let conn: &mut diesel::r2d2::PooledConnection> = - &mut pool.get()?; - - let q = T::table() - .into_boxed() - .limit(1) - .filter( - T::PrimaryKey::default().eq(session_id.to_string()).and( - T::ExpirationTime::default() - .is_null() - .or(T::ExpirationTime::default().gt(diesel::dsl::now)), - ), - ) - .get_result(conn) - .optional()?; - Ok::<_, DieselStoreError>(q) - }) - .await??; - + let res = self + .pool + .interact(move |conn| { + let q = T::table() + .into_boxed() + .limit(1) + .filter( + T::PrimaryKey::default().eq(session_id.to_string()).and( + T::ExpirationTime::default() + .is_null() + .or(T::ExpirationTime::default().gt(diesel::dsl::now)), + ), + ) + .get_result(conn) + .optional()?; + Ok::<_, DieselStoreError>(q) + }) + .await?; Ok(res) } async fn delete(&self, session_id: &crate::session::SessionId) -> Result<(), Self::Error> { let session_id = session_id.to_string(); - let pool = self.pool.clone(); - tokio::task::spawn_blocking(move || { - let conn: &mut diesel::r2d2::PooledConnection> = - &mut pool.get()?; - diesel::delete(T::table().filter(T::PrimaryKey::default().eq(session_id.to_string()))) + self.pool + .interact(move |conn| { + diesel::delete( + T::table().filter(T::PrimaryKey::default().eq(session_id.to_string())), + ) .execute(conn)?; - Ok::<_, DieselStoreError>(()) - }) - .await??; + Ok::<_, DieselStoreError>(()) + }) + .await?; Ok(()) } } #[async_trait] -impl ExpiredDeletion for DieselStore +impl ExpiredDeletion for DieselStore where Self: SessionStore, - C: R2D2Connection, - T: SessionTable, - T: FilterDsl>>>, - Filter>>>: IntoUpdateTarget, + P: DieselPool, + T: SessionTable, + T: FilterDsl< + Box< + dyn BoxableExpression< + T, + ::Backend, + SqlType = Nullable, + >, + >, + >, + Filter< + T, + Box< + dyn BoxableExpression< + T, + ::Backend, + SqlType = Nullable, + >, + >, + >: IntoUpdateTarget, DeleteStatement< - >>> as HasTable>::Table, - >>> as IntoUpdateTarget>::WhereClause, - >: ExecuteDsl, - Lt: QueryFragment + SelectableExpression + Expression>, + ::Backend, + SqlType = Nullable, + >, + >, + > as HasTable>::Table, + ::Backend, + SqlType = Nullable, + >, + >, + > as IntoUpdateTarget>::WhereClause, + >: ExecuteDsl, + Lt: QueryFragment<::Backend> + + SelectableExpression + + Expression>, { async fn delete_expired(&self) -> Result<(), Self::Error> { - let pool = self.pool.clone(); - tokio::task::spawn_blocking(move || { - let mut conn = pool.get()?; - let filter: Box>> = Box::new(T::ExpirationTime::default().lt(diesel::dsl::now)) as Box<_>; - diesel::delete(T::table().filter(filter)) - .execute(&mut conn)?; - Ok::<_, DieselStoreError>(()) - }) - .await??; + self.pool + .interact(|conn| { + let filter = Box::new(T::ExpirationTime::default().lt(diesel::dsl::now)) as Box<_>; + diesel::delete(T::table().filter(filter)).execute(conn)?; + Ok::<_, DieselStoreError>(()) + }) + .await?; Ok(()) } } diff --git a/tests/integration-tests.rs b/tests/integration-tests.rs index 835784e..ea0ea6f 100644 --- a/tests/integration-tests.rs +++ b/tests/integration-tests.rs @@ -118,13 +118,19 @@ mod mysql_store_tests { route_tests!(app); } -#[cfg(all(test, feature = "axum-core", feature = "diesel-store"))] +#[cfg(all( + test, + feature = "axum-core", + feature = "diesel-store", + feature = "diesel-r2d2" +))] mod diesel_sqlite_store_tests { use axum::Router; - use diesel::prelude::*; - use diesel::r2d2::{ConnectionManager, Pool}; - use tower_sessions::diesel_store::DieselStore; - use tower_sessions::SessionManagerLayer; + use diesel::{ + prelude::*, + r2d2::{ConnectionManager, Pool}, + }; + use tower_sessions::{diesel_store::DieselStore, SessionManagerLayer}; use crate::common::build_app; @@ -147,14 +153,16 @@ mod diesel_sqlite_store_tests { test, feature = "axum-core", feature = "diesel-store", - feature = "__diesel_postgres" + feature = "__diesel_postgres", + feature = "diesel-r2d2" ))] mod diesel_pg_store_tests { use axum::Router; - use diesel::prelude::*; - use diesel::r2d2::{ConnectionManager, Pool}; - use tower_sessions::diesel_store::DieselStore; - use tower_sessions::SessionManagerLayer; + use diesel::{ + prelude::*, + r2d2::{ConnectionManager, Pool}, + }; + use tower_sessions::{diesel_store::DieselStore, SessionManagerLayer}; use crate::common::build_app; @@ -178,14 +186,16 @@ mod diesel_pg_store_tests { test, feature = "axum-core", feature = "diesel-store", - feature = "__diesel_mysql" + feature = "__diesel_mysql", + feature = "diesel-r2d2" ))] mod diesel_mysql_store_tests { use axum::Router; - use diesel::prelude::*; - use diesel::r2d2::{ConnectionManager, Pool}; - use tower_sessions::diesel_store::DieselStore; - use tower_sessions::SessionManagerLayer; + use diesel::{ + prelude::*, + r2d2::{ConnectionManager, Pool}, + }; + use tower_sessions::{diesel_store::DieselStore, SessionManagerLayer}; use crate::common::build_app; @@ -205,6 +215,34 @@ mod diesel_mysql_store_tests { route_tests!(app); } +#[cfg(all( + test, + feature = "axum-core", + feature = "diesel-store", + feature = "diesel-deadpool" +))] +mod diesel_sqlite_store_deadpool_tests { + use axum::Router; + use deadpool_diesel::{Manager, Pool}; + use diesel::prelude::*; + use tower_sessions::{diesel_store::DieselStore, SessionManagerLayer}; + + use crate::common::build_app; + + async fn app(max_age: Option) -> Router { + let manager = + Manager::::new(":memory:", deadpool_diesel::Runtime::Tokio1); + let pool = Pool::builder(manager).max_size(1).build()?; + let session_store = DieselStore::new(pool); + session_store.migrate().await.unwrap(); + let session_manager = SessionManagerLayer::new(session_store).with_secure(true); + + build_app(session_manager, max_age) + } + + route_tests!(app); +} + #[cfg(all(test, feature = "axum-core", feature = "mongodb-store"))] mod mongodb_store_tests { use axum::Router;