diff --git a/README.md b/README.md index 021d422..39b5fc0 100644 --- a/README.md +++ b/README.md @@ -40,7 +40,7 @@ async fn main() { let a = FileAdapter::new("examples/rbac_with_pattern_policy.csv"); - let casbin_hoop = CasbinHoop::new(Enforcer::new(m, a).await.unwrap(), |_req, _depot| { + let casbin_hoop = CasbinHoop::new(Enforcer::new(m, a).await.unwrap(), false, |_req, _depot| { Ok(Some(CasbinVals { subject: String::from("alice"), domain: None, diff --git a/src/hoop.rs b/src/hoop.rs index 1328a08..da8b7c5 100644 --- a/src/hoop.rs +++ b/src/hoop.rs @@ -15,6 +15,7 @@ pub struct CasbinVals { #[derive(Clone)] pub struct CasbinHoop { enforcer: Arc>, + use_enforcer_mut: bool, get_casbin_vals: F, } impl Deref for CasbinHoop { @@ -40,14 +41,15 @@ where + Sync + 'static, { - pub fn new(enforcer: E, get_casbin_vals: F) -> Self { + pub fn new(enforcer: E, use_enforcer_mut: bool, get_casbin_vals: F) -> Self { CasbinHoop { enforcer: Arc::new(RwLock::new(enforcer)), + use_enforcer_mut, get_casbin_vals, } } - pub fn get_enforcer(&mut self) -> Arc> { + pub fn get_enforcer(&self) -> Arc> { self.enforcer.clone() } @@ -75,40 +77,29 @@ where let path = req.uri().path().to_string(); let action = req.method().as_str().to_string(); - if !vals.subject.is_empty() { - if let Some(domain) = vals.domain { - let mut lock = self.enforcer.write().await; - match lock.enforce_mut(vec![vals.subject, domain, path, action]) { - Ok(true) => { - drop(lock); - } - Ok(false) => { - drop(lock); - res.render(StatusError::forbidden()); - } - Err(_) => { - drop(lock); - res.render(StatusError::bad_gateway()); - } - } - } else { - let mut lock = self.enforcer.write().await; - match lock.enforce_mut(vec![vals.subject, path, action]) { - Ok(true) => { - drop(lock); - } - Ok(false) => { - drop(lock); - res.render(StatusError::forbidden()); - } - Err(_) => { - drop(lock); - res.render(StatusError::bad_gateway()); - } - } - } - } else { + if vals.subject.is_empty() { res.render(StatusError::unauthorized()); + return; + } + + let rvals = if let Some(domain) = vals.domain { + vec![vals.subject, domain, path, action] + } else { + vec![vals.subject, path, action] + }; + let r = if self.use_enforcer_mut { + self.enforcer.write().await.enforce_mut(rvals) + } else { + self.enforcer.read().await.enforce(rvals) + }; + match r { + Ok(true) => {} + Ok(false) => { + res.render(StatusError::forbidden()); + } + Err(_) => { + res.render(StatusError::bad_gateway()); + } } } } diff --git a/tests/test_hoop.rs b/tests/test_hoop.rs index 0db7e34..eeda3c7 100644 --- a/tests/test_hoop.rs +++ b/tests/test_hoop.rs @@ -18,7 +18,7 @@ async fn test_hoop() { let a = FileAdapter::new("examples/rbac_with_pattern_policy.csv"); - let casbin_hoop = CasbinHoop::new(Enforcer::new(m, a).await.unwrap(), |_req, _depot| { + let casbin_hoop = CasbinHoop::new(Enforcer::new(m, a).await.unwrap(), false, |_req, _depot| { Ok(Some(CasbinVals { subject: String::from("alice"), domain: None, diff --git a/tests/test_hoop_domain.rs b/tests/test_hoop_domain.rs index 9572cf7..93871e2 100644 --- a/tests/test_hoop_domain.rs +++ b/tests/test_hoop_domain.rs @@ -16,7 +16,7 @@ async fn test_hoop_domain() { .unwrap(); let a = FileAdapter::new("examples/rbac_with_domains_policy.csv"); - let casbin_hoop = CasbinHoop::new(Enforcer::new(m, a).await.unwrap(), |_req, _depot| { + let casbin_hoop = CasbinHoop::new(Enforcer::new(m, a).await.unwrap(), false, |_req, _depot| { Ok(Some(CasbinVals { subject: String::from("alice"), domain: Some(String::from("domain1")), diff --git a/tests/test_set_enforcer.rs b/tests/test_set_enforcer.rs index ab599f0..11e5676 100644 --- a/tests/test_set_enforcer.rs +++ b/tests/test_set_enforcer.rs @@ -17,7 +17,7 @@ async fn test_set_enforcer() { .unwrap(); let a = FileAdapter::new("examples/rbac_with_pattern_policy.csv"); - let casbin_hoop = CasbinHoop::new(Enforcer::new(m, a).await.unwrap(), |_req, _depot| { + let casbin_hoop = CasbinHoop::new(Enforcer::new(m, a).await.unwrap(), false,|_req, _depot| { Ok(Some(CasbinVals { subject: String::from("alice"), domain: None,