diff --git a/src/future.rs b/src/future.rs index 08d3c61..c61ffad 100644 --- a/src/future.rs +++ b/src/future.rs @@ -6,6 +6,7 @@ use std::{ }; use futures::Stream; +use tracing::error; /// Used to track query futures #[derive(Debug, Clone)] @@ -23,6 +24,7 @@ pub struct SubscriptionStream { ndb: Ndb, sub_id: Subscription, max_notes: u32, + unsubscribe_on_drop: bool, } impl SubscriptionStream { @@ -30,9 +32,11 @@ impl SubscriptionStream { // Most of the time we only want to fetch a few things. If expecting // lots of data, use `set_max_notes_per_await` let max_notes = 32; + let unsubscribe_on_drop = true; SubscriptionStream { ndb, sub_id, + unsubscribe_on_drop, max_notes, } } @@ -42,6 +46,13 @@ impl SubscriptionStream { self } + /// Unsubscribe the subscription when this stream goes out of scope. On + /// by default. Recommended unless you want subscription leaks. + pub fn unsubscribe_on_drop(mut self, yes: bool) -> Self { + self.unsubscribe_on_drop = yes; + self + } + pub fn sub_id(&self) -> Subscription { self.sub_id } @@ -50,8 +61,17 @@ impl SubscriptionStream { impl Drop for SubscriptionStream { fn drop(&mut self) { // Perform cleanup here, like removing the subscription from the global map - let mut map = self.ndb.subs.lock().unwrap(); - map.remove(&self.sub_id); + { + let mut map = self.ndb.subs.lock().unwrap(); + map.remove(&self.sub_id); + } + // unsubscribe + if let Err(err) = self.ndb.unsubscribe(self.sub_id) { + error!( + "Error unsubscribing from {} in SubscriptionStream Drop: {err}", + self.sub_id.id() + ); + } } } diff --git a/src/ndb.rs b/src/ndb.rs index 5e3f462..eba8173 100644 --- a/src/ndb.rs +++ b/src/ndb.rs @@ -555,9 +555,42 @@ mod tests { test_util::cleanup_db(&db); } + #[tokio::test] + async fn test_unsub_on_drop() { + let db = "target/testdbs/test_unsub_on_drop"; + test_util::cleanup_db(&db); + + { + let mut ndb = Ndb::new(db, &Config::new()).expect("ndb"); + let sub_id = { + let filter = Filter::new().kinds(vec![1]).build(); + let filters = vec![filter]; + + let sub_id = ndb.subscribe(&filters).expect("sub_id"); + let mut sub = sub_id.stream(&ndb).notes_per_await(1); + + let res = sub.next(); + + ndb.process_event(r#"["EVENT","b",{"id": "702555e52e82cc24ad517ba78c21879f6e47a7c0692b9b20df147916ae8731a3","pubkey": "32bf915904bfde2d136ba45dde32c88f4aca863783999faea2e847a8fafd2f15","created_at": 1702675561,"kind": 1,"tags": [],"content": "hello, world","sig": "2275c5f5417abfd644b7bc74f0388d70feb5d08b6f90fa18655dda5c95d013bfbc5258ea77c05b7e40e0ee51d8a2efa931dc7a0ec1db4c0a94519762c6625675"}]"#).expect("process ok"); + + let res = res.await.expect("await ok"); + assert_eq!(res, vec![NoteKey::new(1)]); + + assert!(ndb.subs.lock().unwrap().contains_key(&sub_id)); + sub_id + }; + + // ensure subscription state is removed after stream is dropped + assert!(!ndb.subs.lock().unwrap().contains_key(&sub_id)); + assert_eq!(ndb.subscription_count(), 0); + } + + test_util::cleanup_db(&db); + } + #[tokio::test] async fn test_stream() { - let db = "target/testdbs/test_callback"; + let db = "target/testdbs/test_stream"; test_util::cleanup_db(&db); {