diff --git a/oak_functions_containers_app/src/native_handler.rs b/oak_functions_containers_app/src/native_handler.rs index 376ab04a201..5c5be127df7 100644 --- a/oak_functions_containers_app/src/native_handler.rs +++ b/oak_functions_containers_app/src/native_handler.rs @@ -30,7 +30,7 @@ use tempfile::{tempdir, TempDir}; struct RequestContext { request: Vec, response: Vec, - lookup_data: LookupData, + lookup_data: LookupData<16>, } thread_local! { @@ -129,7 +129,7 @@ struct SharedLibrary { /// Variant of a Handler that dynamically loads a `.so` file and invokes native /// code to handle requests from there. pub struct NativeHandler { - lookup_data_manager: Arc, + lookup_data_manager: Arc>, #[allow(dead_code)] observer: Option>, @@ -198,7 +198,7 @@ impl Handler for NativeHandler { fn new_handler( _config: (), module_bytes: &[u8], - lookup_data_manager: Arc, + lookup_data_manager: Arc>, observer: Option>, ) -> anyhow::Result { let directory = tempdir().context("could not create temporary directory")?; diff --git a/oak_functions_containers_app/tests/native_test.rs b/oak_functions_containers_app/tests/native_test.rs index b736583fd1e..b955ed66399 100644 --- a/oak_functions_containers_app/tests/native_test.rs +++ b/oak_functions_containers_app/tests/native_test.rs @@ -41,7 +41,7 @@ async fn test_native_handler() { .expect("failed to read test library"); let logger = Arc::new(StandaloneLogger); - let lookup_data_manager = Arc::new(LookupDataManager::new_empty(logger)); + let lookup_data_manager = Arc::new(LookupDataManager::<1>::new_empty(logger)); lookup_data_manager .extend_next_lookup_data([("key_0".as_bytes(), "value_0".as_bytes())].into_iter()); diff --git a/oak_functions_sdk/tests/integration_test.rs b/oak_functions_sdk/tests/integration_test.rs index 3d2b9d3b2ca..92704ccd11b 100644 --- a/oak_functions_sdk/tests/integration_test.rs +++ b/oak_functions_sdk/tests/integration_test.rs @@ -51,7 +51,8 @@ lazy_static! { #[tokio::test] async fn test_read_write() { let logger = Arc::new(StandaloneLogger); - let lookup_data_manager = Arc::new(LookupDataManager::for_test(Vec::default(), logger.clone())); + let lookup_data_manager = + Arc::new(LookupDataManager::<1>::for_test(Vec::default(), logger.clone())); let api_factory = StdWasmApiFactory { lookup_data_manager }; let wasm_handler = @@ -66,7 +67,8 @@ async fn test_read_write() { #[tokio::test] async fn test_double_read() { let logger = Arc::new(StandaloneLogger); - let lookup_data_manager = Arc::new(LookupDataManager::for_test(Vec::default(), logger.clone())); + let lookup_data_manager = + Arc::new(LookupDataManager::<1>::for_test(Vec::default(), logger.clone())); let api_factory = StdWasmApiFactory { lookup_data_manager }; let wasm_handler = @@ -81,7 +83,8 @@ async fn test_double_read() { #[tokio::test] async fn test_double_write() { let logger = Arc::new(StandaloneLogger); - let lookup_data_manager = Arc::new(LookupDataManager::for_test(Vec::default(), logger.clone())); + let lookup_data_manager = + Arc::new(LookupDataManager::<1>::for_test(Vec::default(), logger.clone())); let api_factory = StdWasmApiFactory { lookup_data_manager }; let wasm_handler = @@ -96,7 +99,8 @@ async fn test_double_write() { #[tokio::test] async fn test_write_log() { let logger = Arc::new(StandaloneLogger); - let lookup_data_manager = Arc::new(LookupDataManager::for_test(Vec::default(), logger.clone())); + let lookup_data_manager = + Arc::new(LookupDataManager::<1>::for_test(Vec::default(), logger.clone())); let api_factory = StdWasmApiFactory { lookup_data_manager }; let wasm_handler = @@ -113,7 +117,7 @@ async fn test_storage_get_item() { let entries = Vec::from_iter([(b"StorageGet".to_vec(), b"StorageGetResponse".to_vec())]); let logger = Arc::new(StandaloneLogger); - let lookup_data_manager = Arc::new(LookupDataManager::for_test(entries, logger.clone())); + let lookup_data_manager = Arc::new(LookupDataManager::<1>::for_test(entries, logger.clone())); let api_factory = StdWasmApiFactory { lookup_data_manager }; let wasm_handler = @@ -131,7 +135,7 @@ async fn test_storage_get_item_not_found() { let entries = Vec::default(); let logger = Arc::new(StandaloneLogger); - let lookup_data_manager = Arc::new(LookupDataManager::for_test(entries, logger.clone())); + let lookup_data_manager = Arc::new(LookupDataManager::<1>::for_test(entries, logger.clone())); let api_factory = StdWasmApiFactory { lookup_data_manager }; let wasm_handler = @@ -150,7 +154,7 @@ async fn test_storage_get_item_huge_key() { let entries = Vec::from_iter([(bytes.clone(), bytes.clone())]); let logger = Arc::new(StandaloneLogger); - let lookup_data_manager = Arc::new(LookupDataManager::for_test(entries, logger.clone())); + let lookup_data_manager = Arc::new(LookupDataManager::<1>::for_test(entries, logger.clone())); let api_factory = StdWasmApiFactory { lookup_data_manager }; let wasm_handler = @@ -168,7 +172,8 @@ async fn test_echo() { let logger = Arc::new(StandaloneLogger); let message_to_echo = "ECHO"; - let lookup_data_manager = Arc::new(LookupDataManager::for_test(Vec::default(), logger.clone())); + let lookup_data_manager = + Arc::new(LookupDataManager::<1>::for_test(Vec::default(), logger.clone())); let api_factory = StdWasmApiFactory { lookup_data_manager }; let wasm_handler = @@ -189,7 +194,8 @@ async fn test_blackhole() { let logger = Arc::new(StandaloneLogger); let message_to_blackhole = "BLACKHOLE"; - let lookup_data_manager = Arc::new(LookupDataManager::for_test(Vec::default(), logger.clone())); + let lookup_data_manager = + Arc::new(LookupDataManager::<1>::for_test(Vec::default(), logger.clone())); let api_factory = StdWasmApiFactory { lookup_data_manager }; let wasm_handler = @@ -210,7 +216,8 @@ async fn test_huge_response() { let logger = Arc::new(StandaloneLogger); - let lookup_data_manager = Arc::new(LookupDataManager::for_test(Vec::default(), logger.clone())); + let lookup_data_manager = + Arc::new(LookupDataManager::<1>::for_test(Vec::default(), logger.clone())); let api_factory = StdWasmApiFactory { lookup_data_manager }; let wasm_handler = diff --git a/oak_functions_service/benches/wasm_benchmark.rs b/oak_functions_service/benches/wasm_benchmark.rs index 917a93c1987..c74999b65da 100644 --- a/oak_functions_service/benches/wasm_benchmark.rs +++ b/oak_functions_service/benches/wasm_benchmark.rs @@ -223,7 +223,7 @@ fn create_test_data(start: i32, end: i32) -> HashMap, Vec> { struct TestState { wasm_handler: H::HandlerType, - lookup_data_manager: Arc, + lookup_data_manager: Arc>, } fn create_test_state_with_wasm_module_name(wasm_module_name: &str) -> TestState { diff --git a/oak_functions_service/src/instance.rs b/oak_functions_service/src/instance.rs index 94b2628aae1..1a3c8d99a30 100644 --- a/oak_functions_service/src/instance.rs +++ b/oak_functions_service/src/instance.rs @@ -33,7 +33,7 @@ use crate::{ }; pub struct OakFunctionsInstance { - lookup_data_manager: Arc, + lookup_data_manager: Arc>, wasm_handler: H::HandlerType, } diff --git a/oak_functions_service/src/lib.rs b/oak_functions_service/src/lib.rs index 54bb6826daf..e74fb8f702a 100644 --- a/oak_functions_service/src/lib.rs +++ b/oak_functions_service/src/lib.rs @@ -63,7 +63,7 @@ pub trait Handler { fn new_handler( config: Self::HandlerConfig, wasm_module_bytes: &[u8], - lookup_data_manager: Arc, + lookup_data_manager: Arc>, observer: Option>, ) -> anyhow::Result; diff --git a/oak_functions_service/src/lookup.rs b/oak_functions_service/src/lookup.rs index 9c5032987fb..105f0c44754 100644 --- a/oak_functions_service/src/lookup.rs +++ b/oak_functions_service/src/lookup.rs @@ -55,6 +55,7 @@ impl DataBuilder { /// /// Note, if new data contains a key already present in the existing data, /// calling extend overwrites the value. + #[allow(unused)] fn extend<'a, T: IntoIterator>(&mut self, new_data: T) { self.state = BuilderState::Extending; self.data.extend(new_data) @@ -72,25 +73,12 @@ impl DataBuilder { #[cfg(feature = "std")] mod mutexes { - pub use parking_lot::{Mutex, MutexGuard, RwLock}; + pub use parking_lot::{Mutex, RwLock}; } #[cfg(not(feature = "std"))] mod mutexes { - pub use spinning_top::{ - guard::SpinlockGuard as MutexGuard, RwSpinlock as RwLock, Spinlock as Mutex, - }; -} - -/// RAII data structure that holds an exclusive lock for multiple insertions. -pub struct LookupDataInserter<'a> { - lock: mutexes::MutexGuard<'a, DataBuilder>, -} - -impl<'a> LookupDataInserter<'a> { - pub fn insert(&mut self, key: &[u8], val: &[u8]) { - self.lock.insert(key, val); - } + pub use spinning_top::{RwSpinlock as RwLock, Spinlock as Mutex}; } /// Utility for managing lookup data. @@ -109,22 +97,29 @@ impl<'a> LookupDataInserter<'a> { /// /// In the future we may replace both the mutex and the hash map with something /// like RCU. -pub struct LookupDataManager { - data: mutexes::RwLock>, - // Behind a lock, because we have multiple references to LookupDataManager and need to mutate - // data builder. - data_builder: mutexes::Mutex, +pub struct LookupDataManager { + data: mutexes::RwLock<[Arc; S]>, + // The outer RwLock guards the DataBuilder-s themselves; while inserting data you need a read + // lock on the outer RwLock, but when finalizing lookup data you need to grab a write lock. + // The inner lock guards the contents of the DataBuilder, ensuring that we add data from only + // one thread at a time. + data_builder: mutexes::RwLock<[mutexes::Mutex; S]>, logger: Arc, } -impl LookupDataManager { +impl LookupDataManager { /// Creates a new instance with empty backing data. pub fn new_empty(logger: Arc) -> Self { + if S > 1 { + info!("Splitting lookup data hashmap into {}.", S); + } Self { - data: mutexes::RwLock::new(Arc::new(Data::default())), + data: mutexes::RwLock::new([(); S].map(|()| Arc::new(Data::default()))), // Incrementally builds the backing data that will be used by new `LookupData` // instances when finished. - data_builder: mutexes::Mutex::new(DataBuilder::default()), + data_builder: mutexes::RwLock::new( + [(); S].map(|()| mutexes::Mutex::new(DataBuilder::default())), + ), logger, } } @@ -139,17 +134,15 @@ impl LookupDataManager { } pub fn reserve(&self, additional_entries: u64) -> anyhow::Result<()> { - let mut data_builder = self.data_builder.lock(); - data_builder.reserve(additional_entries as usize); + // We're assuming uniform distribution here. + let entries_per_shard = additional_entries as usize / S; + self.data_builder.read().iter().for_each(|db| db.lock().reserve(entries_per_shard)); Ok(()) } - pub fn inserter(&self) -> LookupDataInserter<'_> { - LookupDataInserter { lock: self.data_builder.lock() } - } - pub fn insert(&self, key: &[u8], val: &[u8]) { - self.data_builder.lock().insert(key, val); + let index = crate::lookup_htbl::hash(key, 0) as usize % S; + self.data_builder.read()[index].lock().insert(key, val); } pub fn extend_next_lookup_data<'a, T: IntoIterator>( @@ -157,9 +150,10 @@ impl LookupDataManager { new_data: T, ) { info!("Start extending next lookup data"); - { - let mut data_builder = self.data_builder.lock(); - data_builder.extend(new_data); + let builder = self.data_builder.read(); + for (k, v) in new_data { + let index = crate::lookup_htbl::hash(k, 0) as usize % S; + builder[index].lock().insert(k, v); } info!("Finish extending next lookup data"); } @@ -167,16 +161,16 @@ impl LookupDataManager { // Finish building the next lookup data and replace the current lookup data in // place. pub fn finish_next_lookup_data(&self) { - let data_len; - let next_data_len; + let data_len: usize; + let next_data_len: usize; info!("Start replacing lookup data by next lookup data"); { - let mut data_builder = self.data_builder.lock(); - let next_data = data_builder.build(); - next_data_len = next_data.len(); + let mut data_builder = self.data_builder.write(); + let next_data = data_builder.each_mut().map(|builder| builder.lock().build()); + next_data_len = next_data.iter().map(|htbl| htbl.len()).sum(); let mut data = self.data.write(); - data_len = data.len(); - *data = Arc::new(next_data); + data_len = data.iter().map(|htbl| htbl.len()).sum(); + *data = next_data.map(Arc::new); } info!( "Finished replacing lookup data with len {} by next lookup data with len {}", @@ -187,20 +181,20 @@ impl LookupDataManager { pub fn abort_next_lookup_data(&self) { info!("Start aborting next lookup data"); { - let mut data_builder = self.data_builder.lock(); + let mut data_builder = self.data_builder.write(); // Clear the builder throwing away the intermediate result. - let _ = data_builder.build(); + let _ = data_builder.each_mut().map(|builder| builder.lock().build()); } info!("Finish aborting next lookup data"); } /// Creates a new `LookupData` instance with a reference to the current /// backing data. - pub fn create_lookup_data(&self) -> LookupData { - let keys; + pub fn create_lookup_data(&self) -> LookupData { + let keys: usize; let data = { let data = self.data.read().clone(); - keys = data.len(); + keys = data.iter().map(|data| data.len()).sum(); LookupData::new(data, self.logger.clone()) }; info!("Created lookup data with len: {}", keys); @@ -210,29 +204,30 @@ impl LookupDataManager { /// Provides access to shared lookup data. #[derive(Clone)] -pub struct LookupData { - data: Arc, +pub struct LookupData { + data: [Arc; S], logger: Arc, } -impl LookupData { - fn new(data: Arc, logger: Arc) -> Self { +impl LookupData { + fn new(data: [Arc; S], logger: Arc) -> Self { Self { data, logger } } /// Gets an individual entry from the backing data. pub fn get(&self, key: &[u8]) -> Option<&[u8]> { - self.data.get(key) + let index = crate::lookup_htbl::hash(key, 0) as usize % S; + self.data[index].get(key) } /// Gets the number of entries in the backing data. pub fn len(&self) -> usize { - self.data.len() + self.data.iter().map(|data| data.len()).sum() } /// Whether the backing data is empty. pub fn is_empty(&self) -> bool { - self.data.is_empty() + self.data.iter().all(|data| data.is_empty()) } /// Logs an error message. @@ -278,7 +273,7 @@ mod tests { fn test_lookup_data_instance_consistency() { // Ensure that the data for a specific lookup data instance remains consistent // even if the data in the manager has been updated. - let manager = LookupDataManager::new_empty(Arc::new(TestLogger)); + let manager = LookupDataManager::<1>::new_empty(Arc::new(TestLogger)); let lookup_data_0 = manager.create_lookup_data(); assert_eq!(lookup_data_0.len(), 0); @@ -298,7 +293,7 @@ mod tests { #[test] fn test_update_lookup_data_one_chunk() { - let manager = LookupDataManager::new_empty(Arc::new(TestLogger)); + let manager = LookupDataManager::<1>::new_empty(Arc::new(TestLogger)); reserve_and_extend_test_data(&manager, 0, 2); let lookup_data = manager.create_lookup_data(); assert_eq!(lookup_data.len(), 2); @@ -306,7 +301,7 @@ mod tests { #[test] fn test_update_lookup_data_two_chunks() { - let manager = LookupDataManager::new_empty(Arc::new(TestLogger)); + let manager = LookupDataManager::<1>::new_empty(Arc::new(TestLogger)); let lookup_data_0 = manager.create_lookup_data(); manager.reserve(4).unwrap(); @@ -328,7 +323,7 @@ mod tests { #[test] fn test_update_lookup_four_chunks() { - let manager = LookupDataManager::new_empty(Arc::new(TestLogger)); + let manager = LookupDataManager::<1>::new_empty(Arc::new(TestLogger)); manager.reserve(7).unwrap(); manager.extend_next_lookup_data( @@ -353,7 +348,7 @@ mod tests { #[test] fn test_update_lookup_data_abort_by_sender() { - let manager = LookupDataManager::new_empty(Arc::new(TestLogger)); + let manager = LookupDataManager::<1>::new_empty(Arc::new(TestLogger)); let lookup_data_0 = manager.create_lookup_data(); manager.reserve(2).unwrap(); @@ -394,7 +389,11 @@ mod tests { vec } - fn reserve_and_extend_test_data(manager: &LookupDataManager, start: i32, end: i32) { + fn reserve_and_extend_test_data( + manager: &LookupDataManager, + start: i32, + end: i32, + ) { manager.reserve((end - start) as u64).unwrap(); manager.extend_next_lookup_data( create_test_data(start, end).iter().map(|(k, v)| (k.as_ref(), v.as_ref())), diff --git a/oak_functions_service/src/lookup_htbl.rs b/oak_functions_service/src/lookup_htbl.rs index 2c15a6648c6..91b06e99dd6 100644 --- a/oak_functions_service/src/lookup_htbl.rs +++ b/oak_functions_service/src/lookup_htbl.rs @@ -454,7 +454,7 @@ fn hash_u64(v: u64, hash_secret: u64) -> u64 { } #[inline] -fn hash(v: &[u8], hash_secret: u64) -> u64 { +pub fn hash(v: &[u8], hash_secret: u64) -> u64 { let mut i = 0usize; let mut val = 0u64; let mut bytes = [0u8; 8]; diff --git a/oak_functions_service/src/wasm/api.rs b/oak_functions_service/src/wasm/api.rs index 457be2ae552..ef7bec2c8ee 100644 --- a/oak_functions_service/src/wasm/api.rs +++ b/oak_functions_service/src/wasm/api.rs @@ -33,11 +33,11 @@ use crate::{ /// The main purpose of this factory is to allow creating a new instance of the /// [`StdWasmApiImpl`] for each incoming gRPC request, with an immutable /// snapshot of the current lookup data. -pub struct StdWasmApiFactory { - pub lookup_data_manager: Arc, +pub struct StdWasmApiFactory { + pub lookup_data_manager: Arc>, } -impl WasmApiFactory for StdWasmApiFactory { +impl WasmApiFactory for StdWasmApiFactory { fn create_wasm_api(&self, request: Vec, response: Rc>>) -> Box { Box::new(StdWasmApiImpl { lookup_data: self.lookup_data_manager.create_lookup_data(), @@ -53,8 +53,8 @@ impl WasmApiFactory for StdWasmApiFactory { /// There are probably more locks than necessary here, it should be possible to /// reduce them in the future. #[derive(Clone)] -pub struct StdWasmApiImpl { - lookup_data: LookupData, +pub struct StdWasmApiImpl { + lookup_data: LookupData, logger: Rc, /// Current request, as received from the client. request: Vec, @@ -62,7 +62,7 @@ pub struct StdWasmApiImpl { response: Rc>>, } -impl StdWasmApi for StdWasmApiImpl { +impl StdWasmApi for StdWasmApiImpl { fn read_request( &mut self, _: ReadRequestRequest, @@ -157,7 +157,7 @@ impl StdWasmApi for StdWasmApiImpl { } } -impl WasmApi for StdWasmApiImpl { +impl WasmApi for StdWasmApiImpl { fn transport(&mut self) -> Box> { Box::new(StdWasmApiServer::new(self.clone())) } diff --git a/oak_functions_service/src/wasm/mod.rs b/oak_functions_service/src/wasm/mod.rs index ea711c8aa74..59e31ebf1f3 100644 --- a/oak_functions_service/src/wasm/mod.rs +++ b/oak_functions_service/src/wasm/mod.rs @@ -443,7 +443,7 @@ impl Handler for WasmHandler { fn new_handler( _config: WasmConfig, wasm_module_bytes: &[u8], - lookup_data_manager: Arc, + lookup_data_manager: Arc>, observer: Option>, ) -> anyhow::Result { let logger = Arc::new(StandaloneLogger); diff --git a/oak_functions_service/src/wasm/tests.rs b/oak_functions_service/src/wasm/tests.rs index e21b951a988..73b83de99f7 100644 --- a/oak_functions_service/src/wasm/tests.rs +++ b/oak_functions_service/src/wasm/tests.rs @@ -144,7 +144,8 @@ struct TestState { fn create_test_state() -> TestState { let logger = Arc::new(StandaloneLogger); - let lookup_data_manager = Arc::new(LookupDataManager::for_test(Vec::default(), logger.clone())); + let lookup_data_manager = + Arc::new(LookupDataManager::<1>::for_test(Vec::default(), logger.clone())); let api_factory = Arc::new(StdWasmApiFactory { lookup_data_manager: lookup_data_manager.clone() }); diff --git a/oak_functions_service/src/wasm/wasmtime.rs b/oak_functions_service/src/wasm/wasmtime.rs index 6c1abf21ed0..2908f37ed51 100644 --- a/oak_functions_service/src/wasm/wasmtime.rs +++ b/oak_functions_service/src/wasm/wasmtime.rs @@ -484,7 +484,7 @@ impl Handler for WasmtimeHandler { fn new_handler( config: WasmtimeConfig, wasm_module_bytes: &[u8], - lookup_data_manager: Arc, + lookup_data_manager: Arc>, observer: Option>, ) -> anyhow::Result { let logger = Box::new(StandaloneLogger);