diff --git a/Cargo.lock b/Cargo.lock index 85a692f..301282a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -170,12 +170,102 @@ dependencies = [ "parking_lot_core", ] +[[package]] +name = "futures" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" +dependencies = [ + "futures-core", + "futures-sink", +] + +[[package]] +name = "futures-core" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" + +[[package]] +name = "futures-executor" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" + +[[package]] +name = "futures-macro" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.61", +] + +[[package]] +name = "futures-sink" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" + +[[package]] +name = "futures-task" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" + +[[package]] +name = "futures-util" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "pin-utils", + "slab", +] + [[package]] name = "gcra" version = "0.6.0" dependencies = [ "chrono", "dashmap", + "futures", "rustc-hash", "thiserror", "tokio", @@ -358,6 +448,12 @@ version = "0.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02" +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + [[package]] name = "proc-macro2" version = "1.0.82" @@ -418,6 +514,15 @@ dependencies = [ "libc", ] +[[package]] +name = "slab" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67" +dependencies = [ + "autocfg", +] + [[package]] name = "smallvec" version = "1.10.0" diff --git a/Cargo.toml b/Cargo.toml index 4a62026..848f7d7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,7 @@ thiserror = "1.0.60" [dev-dependencies] chrono = "0.4.38" tokio = { version = "1.37.0", features = ["full"] } +futures = "0.3.30" [[example]] name = "rate_limiter" diff --git a/README.md b/README.md index 9dfa8ab..fdd1fef 100644 --- a/README.md +++ b/README.md @@ -34,17 +34,18 @@ fn check_rate_limit() { ### With `rate-limiter` ```rust +use std::sync::Arc; use gcra::{GcraError, RateLimit, RateLimiter}; #[tokio::main] async fn main() -> Result<(), GcraError> { let rate_limit = RateLimit::per_sec(2); - let rl = RateLimiter::new(4); + let rate_limiter = Arc::new(RateLimiter::new(4)); - rl.check("key", rate_limit.clone(), 1).await?; - rl.check("key", rate_limit.clone(), 1).await?; + rate_limiter.check("key", &rate_limit, 1).await?; + rate_limiter.check("key", &rate_limit, 1).await?; - match rl.check("key", rate_limit.clone(), 1).await { + match rate_limiter.check("key", rate_limit.clone(), 1).await { Err(GcraError::DeniedUntil { next_allowed_at }) => { print!("Denied: Request next at {:?}", next_allowed_at); Ok(()) diff --git a/examples/rate_limiter.rs b/examples/rate_limiter.rs index 63fe612..f9b7526 100644 --- a/examples/rate_limiter.rs +++ b/examples/rate_limiter.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use gcra::{GcraError, RateLimit, RateLimiter}; const CACHE_CAPACITY: usize = 4; @@ -6,12 +8,12 @@ const WORKER_SHARD_COUNT: usize = 2; #[tokio::main] async fn main() -> Result<(), GcraError> { let rate_limit = RateLimit::per_sec(2); - let rl = RateLimiter::with_shards(CACHE_CAPACITY, WORKER_SHARD_COUNT); + let rate_limiter = Arc::new(RateLimiter::with_shards(CACHE_CAPACITY, WORKER_SHARD_COUNT)); - rl.check("key", &rate_limit, 1).await?; - rl.check("key", &rate_limit, 1).await?; + rate_limiter.check("key", &rate_limit, 1).await?; + rate_limiter.check("key", &rate_limit, 1).await?; - match rl.check("key", &rate_limit, 1).await { + match rate_limiter.check("key", &rate_limit, 1).await { Err(GcraError::DeniedUntil { next_allowed_at }) => { print!("Denied: Request next at {:?}", next_allowed_at); Ok(()) diff --git a/src/lib.rs b/src/lib.rs index cd64abb..d15d1ab 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -29,17 +29,18 @@ //! ## With `rate-limiter` //! //! ```rust +//! use std::sync::Arc; //! use gcra::{GcraError, RateLimit, RateLimiter}; //! //! #[tokio::main] //! async fn main() -> Result<(), GcraError> { //! let rate_limit = RateLimit::per_sec(2); -//! let rl = RateLimiter::new(4); +//! let rate_limiter = Arc::new(RateLimiter::new(4)); //! -//! rl.check("key", &rate_limit, 1).await?; -//! rl.check("key", &rate_limit, 1).await?; +//! rate_limiter.check("key", &rate_limit, 1).await?; +//! rate_limiter.check("key", &rate_limit, 1).await?; //! -//! match rl.check("key", &rate_limit, 1).await { +//! match rate_limiter.check("key", &rate_limit, 1).await { //! Err(GcraError::DeniedUntil { next_allowed_at }) => { //! print!("Denied: Request next at {:?}", next_allowed_at); //! Ok(()) diff --git a/src/rate_limiter/rate_limiter.rs b/src/rate_limiter/rate_limiter.rs index ba566df..3ce5c67 100644 --- a/src/rate_limiter/rate_limiter.rs +++ b/src/rate_limiter/rate_limiter.rs @@ -134,9 +134,14 @@ where #[cfg(test)] mod tests { + use futures::stream::{self, StreamExt}; + use crate::clock::tests::FakeClock; use core::panic; - use std::time::{Duration, Instant}; + use std::{ + sync::Arc, + time::{Duration, Instant}, + }; use super::*; @@ -161,6 +166,33 @@ mod tests { } } + #[tokio::test] + async fn rate_limiter_run_until_denied_concurrent_access() { + let rate_limit = RateLimit::new(3, Duration::from_secs(3)); + let rate_limiter = Arc::new(RateLimiter::with_shards(4, 2)); + + let all_checked = stream::iter(0..rate_limit.resource_limit) + .then(|_| async { + let rate_limiter = rate_limiter.clone(); + rate_limiter.check("key", &rate_limit, 1).await + }) + .all(|result| async move { result.is_ok() }) + .await; + + assert!( + all_checked, + "All checks should have passed and not rate limited" + ); + + match rate_limiter.check("key", &rate_limit, 1).await { + Ok(_) => panic!("We should be rate limited"), + Err(GcraError::DeniedUntil { next_allowed_at }) => { + assert!(next_allowed_at > Instant::now()) + } + Err(_) => panic!("Unexpected error"), + } + } + #[tokio::test] async fn rate_limiter_indefinitly_denied() { let rate_limit = RateLimit::new(3, Duration::from_secs(3));