diff --git a/lightning/Cargo.toml b/lightning/Cargo.toml index 20888261b31..f85c18d89b4 100644 --- a/lightning/Cargo.toml +++ b/lightning/Cargo.toml @@ -49,6 +49,7 @@ regex = { version = "1.5.6", optional = true } backtrace = { version = "0.3", optional = true } libm = { version = "0.2", default-features = false } +delegate = "0.12.0" [dev-dependencies] regex = "1.5.6" diff --git a/lightning/src/util/mod.rs b/lightning/src/util/mod.rs index bb138dd69d1..a1d8d53195c 100644 --- a/lightning/src/util/mod.rs +++ b/lightning/src/util/mod.rs @@ -15,6 +15,9 @@ pub(crate) mod fuzz_wrappers; #[macro_use] pub mod ser_macros; +#[cfg(feature = "std")] +pub mod mut_global; + pub mod errors; pub mod ser; pub mod message_signing; diff --git a/lightning/src/util/mut_global.rs b/lightning/src/util/mut_global.rs new file mode 100644 index 00000000000..f9ba1fe7011 --- /dev/null +++ b/lightning/src/util/mut_global.rs @@ -0,0 +1,67 @@ +//! A settable global variable. +//! +//! Used for testing purposes only. + +use std::sync::Mutex; + +/// A global variable that can be set exactly once. +pub struct MutGlobal { + value: Mutex>, + default_fn: fn() -> T, +} + +impl MutGlobal { + /// Create a new `MutGlobal` with no value set. + pub const fn new(default_fn: fn() -> T) -> Self { + Self { value: Mutex::new(None), default_fn } + } + + /// Set the value of the global variable. + /// + /// Ignores any attempt to set the value more than once. + pub fn set(&self, value: T) { + let mut lock = self.value.lock().unwrap(); + *lock = Some(value); + } + + /// Get the value of the global variable. + /// + /// # Panics + /// + /// Panics if the value has not been set. + pub fn get(&self) -> T { + let mut lock = self.value.lock().unwrap(); + if let Some(value) = &*lock { + value.clone() + } else { + let value = (self.default_fn)(); + *lock = Some(value.clone()); + value + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test() { + let v = MutGlobal::::new(|| 0); + assert_eq!(v.get(), 0); + v.set(42); + assert_eq!(v.get(), 42); + v.set(43); + assert_eq!(v.get(), 43); + } + + static G: MutGlobal = MutGlobal::new(|| 0); + + #[test] + fn test_global() { + G.set(42); + assert_eq!(G.get(), 42); + G.set(43); + assert_eq!(G.get(), 43); + } +}