From c9e64f043a141e38e1fa9a01e6846f91b9088af5 Mon Sep 17 00:00:00 2001
From: al8n
Date: Tue, 17 Sep 2024 19:17:39 +0800
Subject: [PATCH] WIP

---
 src/buffer.rs       |  11 +-
 src/error.rs        |  57 ++++++++-
 src/lib.rs          |  39 +++---
 src/swmr/generic.rs |   8 +-
 src/tests.rs        |  97 ++++++++++++-
 src/utils.rs        |  16 ++-
 src/wal.rs          |  26 +++-
 src/wal/sealed.rs   | 291 ++++++++++++++++++++++++++++++++++----------
 8 files changed, 436 insertions(+), 109 deletions(-)

diff --git a/src/buffer.rs b/src/buffer.rs
index 946a532d..2a3e2021 100644
--- a/src/buffer.rs
+++ b/src/buffer.rs
@@ -14,13 +14,22 @@ macro_rules! builder {
     impl<F> [< $name Builder >]<F> {
       #[doc = "Creates a new `" [<$name Builder>] "` with the given size and builder closure."]
       #[inline]
-      pub const fn new<E>(size: $size, f: F) -> Self
+      pub const fn once<E>(size: $size, f: F) -> Self
       where
        F: for<'a> FnOnce(&mut VacantBuffer<'a>) -> Result<(), E>,
       {
        Self { size, f }
       }

+      #[doc = "Creates a new `" [<$name Builder>] "` with the given size and a reusable (`Fn`) builder closure."]
+      #[inline]
+      pub const fn new<E>(size: $size, f: F) -> Self
+      where
+        F: for<'a> Fn(&mut VacantBuffer<'a>) -> Result<(), E>,
+      {
+        Self { size, f }
+      }
+
       #[doc = "Returns the required" [< $name: snake>] "size."]
       #[inline]
       pub const fn size(&self) -> $size {
diff --git a/src/error.rs b/src/error.rs
index 0c04f85e..3169d91b 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -1,3 +1,19 @@
+/// The batch error type.
+#[derive(Debug, thiserror::Error)]
+pub enum BatchError {
+  /// Returned when the expected batch encoding size does not match the actual size.
+  #[error("the expected batch encoding size ({expected}) does not match the actual size {actual}")]
+  EncodedSizeMismatch {
+    /// The expected size.
+    expected: u32,
+    /// The actual size.
+    actual: u32,
+  },
+  /// Returned when the actual encoding size is larger than the expected batch encoding size.
+  #[error("encoding size is larger than the expected batch encoding size {0}")]
+  LargerEncodedSize(u32),
+}
+
 /// The error type.
 #[derive(Debug, thiserror::Error)]
 pub enum Error {
@@ -13,7 +29,7 @@ pub enum Error {
   #[error("the key size is {size} larger than the maximum key size {maximum_key_size}")]
   KeyTooLarge {
     /// The size of the key.
-    size: u32,
+    size: u64,
     /// The maximum key size.
     maximum_key_size: u32,
   },
@@ -21,7 +37,7 @@ pub enum Error {
   #[error("the value size is {size} larger than the maximum value size {maximum_value_size}")]
   ValueTooLarge {
     /// The size of the value.
-    size: u32,
+    size: u64,
     /// The maximum value size.
     maximum_value_size: u32,
   },
@@ -33,6 +49,9 @@ pub enum Error {
     /// The maximum entry size.
     maximum_entry_size: u64,
   },
+  /// A batch error.
+  #[error(transparent)]
+  Batch(#[from] BatchError),
   /// I/O error.
   #[error("{0}")]
   IO(#[from] std::io::Error),
@@ -51,7 +70,7 @@ impl Error {
   }

   /// Create a new `Error::KeyTooLarge` instance.
-  pub(crate) const fn key_too_large(size: u32, maximum_key_size: u32) -> Self {
+  pub(crate) const fn key_too_large(size: u64, maximum_key_size: u32) -> Self {
     Self::KeyTooLarge {
       size,
       maximum_key_size,
     }
   }

   /// Create a new `Error::ValueTooLarge` instance.
-  pub(crate) const fn value_too_large(size: u32, maximum_value_size: u32) -> Self {
+  pub(crate) const fn value_too_large(size: u64, maximum_value_size: u32) -> Self {
     Self::ValueTooLarge {
       size,
       maximum_value_size,
     }
   }
@@ -87,13 +106,39 @@ impl Error {
   /// Create a new corrupted error.
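+  ///
+  /// Illustrative usage (added example; the `"checksum mismatch"` message
+  /// simply mirrors how the replay path below reports a bad checksum):
+  ///
+  /// ```ignore
+  /// return Err(Error::corrupted("checksum mismatch"));
+  /// ```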
   #[inline]
-  pub(crate) fn corrupted() -> Error {
+  pub(crate) fn corrupted<E>(e: E) -> Error
+  where
+    E: Into<Box<dyn std::error::Error + Send + Sync>>,
+  {
+    #[derive(Debug)]
+    struct Corrupted(Box<dyn std::error::Error + Send + Sync>);
+
+    impl std::fmt::Display for Corrupted {
+      fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "corrupted write-ahead log: {}", self.0)
+      }
+    }
+
+    impl std::error::Error for Corrupted {}
+
     Self::IO(std::io::Error::new(
       std::io::ErrorKind::InvalidData,
-      "corrupted write-ahead log",
+      Corrupted(e.into()),
     ))
   }

+  /// Create a new batch size mismatch error.
+  #[inline]
+  pub(crate) const fn batch_size_mismatch(expected: u32, actual: u32) -> Self {
+    Self::Batch(BatchError::EncodedSizeMismatch { expected, actual })
+  }
+
+  /// Create a new larger batch size error.
+  #[inline]
+  pub(crate) const fn larger_batch_size(size: u32) -> Self {
+    Self::Batch(BatchError::LargerEncodedSize(size))
+  }
+
   /// Create a read-only error.
   pub(crate) const fn read_only() -> Self {
     Self::ReadOnly
diff --git a/src/lib.rs b/src/lib.rs
index 4a23386f..f3ecc463 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -43,24 +43,25 @@
 const MAGIC_TEXT_SIZE: usize = MAGIC_TEXT.len();
 const MAGIC_VERSION_SIZE: usize = mem::size_of::<u16>();
 const HEADER_SIZE: usize = MAGIC_TEXT_SIZE + MAGIC_VERSION_SIZE;

-#[cfg(all(
-  test,
-  any(
-    all_tests,
-    test_unsync_constructor,
-    test_unsync_insert,
-    test_unsync_get,
-    test_unsync_iters,
-    test_swmr_constructor,
-    test_swmr_insert,
-    test_swmr_get,
-    test_swmr_iters,
-    test_swmr_generic_constructor,
-    test_swmr_generic_insert,
-    test_swmr_generic_get,
-    test_swmr_generic_iters,
-  )
-))]
+// #[cfg(all(
+//   test,
+//   any(
+//     all_tests,
+//     test_unsync_constructor,
+//     test_unsync_insert,
+//     test_unsync_get,
+//     test_unsync_iters,
+//     test_swmr_constructor,
+//     test_swmr_insert,
+//     test_swmr_get,
+//     test_swmr_iters,
+//     test_swmr_generic_constructor,
+//     test_swmr_generic_insert,
+//     test_swmr_generic_get,
+//     test_swmr_generic_iters,
+//   )
+// ))]
+#[cfg(test)]
 #[macro_use]
 mod tests;

@@ -76,7 +77,7 @@ use utils::*;
 mod wal;
 pub use wal::{
-  Batch, BatchWithBuilders, BatchWithKeyBuilder, BatchWithValueBuilder, ImmutableWal, Wal,
+  Batch, BatchWithBuilders, BatchWithKeyBuilder, BatchWithValueBuilder, Builder, ImmutableWal, Wal,
 };

 mod options;
diff --git a/src/swmr/generic.rs b/src/swmr/generic.rs
index 4f1a073d..c862f6d9 100644
--- a/src/swmr/generic.rs
+++ b/src/swmr/generic.rs
@@ -428,11 +428,11 @@ where
       let header = arena.get_u8(cursor).unwrap();
       let flag = Flags::from_bits_unchecked(header);

-      let (kvsize, encoded_len) = arena.get_u64_varint(cursor + STATUS_SIZE).map_err(|_e| {
+      let (kvsize, encoded_len) = arena.get_u64_varint(cursor + STATUS_SIZE).map_err(|e| {
         #[cfg(feature = "tracing")]
-        tracing::error!(err=%_e);
+        tracing::error!(err=%e);

-        Error::corrupted()
+        Error::corrupted(e)
       })?;

       let (key_len, value_len) = split_lengths(encoded_len);
       let key_len = key_len as usize;
       let value_len = value_len as usize;
@@ -452,7 +452,7 @@ where
       let cks = arena.get_u64_le(cursor + cks_offset).unwrap();

       if cks != checksumer.checksum_one(arena.get_bytes(cursor, cks_offset)) {
-        return Err(Error::corrupted());
+        return Err(Error::corrupted("checksum mismatch"));
       }

       // If the entry is not committed, we should not rewind
diff --git a/src/tests.rs b/src/tests.rs
index 8aa7e91a..37fc0d60 100644
--- a/src/tests.rs
+++ b/src/tests.rs
@@ -669,7 +669,7 @@ pub(crate) fn insert_with_key_builder<W: Wal<Ascend, Crc32>>(wal: &mut W) {
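+  // `KeyBuilder::once` wraps a `FnOnce` closure, which is what the
+  // single-entry insert path consumes; `KeyBuilder::new` now requires `Fn`
+  // and is what the batch tests below use, presumably because the batch path
+  // invokes the builder through a shared reference.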
   for i in 0..100u32 {
     wal
       .insert_with_key_builder::<()>(
-        KeyBuilder::<_>::new(4, |buf| {
+        KeyBuilder::<_>::once(4, |buf| {
           let _ = buf.put_u32_be(i);
           Ok(())
         }),
@@ -693,7 +693,7 @@ pub(crate) fn insert_with_value_builder<W: Wal<Ascend, Crc32>>(wal: &mut W) {
     wal
       .insert_with_value_builder::<()>(
         &i.to_be_bytes(),
-        ValueBuilder::<_>::new(4, |buf| {
+        ValueBuilder::<_>::once(4, |buf| {
           let _ = buf.put_u32_be(i);
           Ok(())
         }),
@@ -715,11 +715,11 @@ pub(crate) fn insert_with_builders<W: Wal<Ascend, Crc32>>(wal: &mut W) {
   for i in 0..100u32 {
     wal
       .insert_with_builders::<(), ()>(
-        KeyBuilder::<_>::new(4, |buf| {
+        KeyBuilder::<_>::once(4, |buf| {
           let _ = buf.put_u32_be(i);
           Ok(())
         }),
-        ValueBuilder::<_>::new(4, |buf| {
+        ValueBuilder::<_>::once(4, |buf| {
           let _ = buf.put_u32_be(i);
           Ok(())
         }),
@@ -1023,7 +1023,7 @@ pub(crate) fn get_or_insert_with_value_builder<W: Wal<Ascend, Crc32>>(wal: &mut W) {
     wal
       .get_or_insert_with_value_builder::<()>(
         &i.to_be_bytes(),
-        ValueBuilder::<_>::new(4, |buf| {
+        ValueBuilder::<_>::once(4, |buf| {
           let _ = buf.put_u32_be(i);
           Ok(())
         }),
@@ -1035,7 +1035,7 @@ pub(crate) fn get_or_insert_with_value_builder<W: Wal<Ascend, Crc32>>(wal: &mut W) {
     wal
       .get_or_insert_with_value_builder::<()>(
         &i.to_be_bytes(),
-        ValueBuilder::<_>::new(4, |buf| {
+        ValueBuilder::<_>::once(4, |buf| {
           let _ = buf.put_u32_be(i * 2);
           Ok(())
         }),
@@ -1053,6 +1053,91 @@ pub(crate) fn get_or_insert_with_value_builder<W: Wal<Ascend, Crc32>>(wal: &mut W) {
   }
 }

+pub(crate) fn insert_batch<W: Wal<Ascend, Crc32>>(wal: &mut W) {
+  let mut batch = vec![];
+
+  for i in 0..100u32 {
+    batch.push(Entry::new(i.to_be_bytes(), i.to_be_bytes()));
+  }
+
+  wal.insert_batch(&mut batch).unwrap();
+
+  for i in 0..100u32 {
+    assert_eq!(wal.get(&i.to_be_bytes()).unwrap(), i.to_be_bytes());
+  }
+
+  let wal = wal.reader();
+  for i in 0..100u32 {
+    assert_eq!(wal.get(&i.to_be_bytes()).unwrap(), i.to_be_bytes());
+  }
+}
+
+pub(crate) fn insert_batch_with_key_builder<W: Wal<Ascend, Crc32>>(wal: &mut W) {
+  let mut batch = vec![];
+
+  for i in 0..100u32 {
+    // Keys must be written big-endian so that the `get(&i.to_be_bytes())`
+    // lookups below can find them.
+    batch.push(EntryWithKeyBuilder::new(
+      KeyBuilder::new(4, move |buf| buf.put_u32_be(i)),
+      i.to_be_bytes(),
+    ));
+  }
+
+  wal.insert_batch_with_key_builder(&mut batch).unwrap();
+
+  for i in 0..100u32 {
+    assert_eq!(wal.get(&i.to_be_bytes()).unwrap(), i.to_be_bytes());
+  }
+
+  let wal = wal.reader();
+  for i in 0..100u32 {
+    assert_eq!(wal.get(&i.to_be_bytes()).unwrap(), i.to_be_bytes());
+  }
+}
+
+pub(crate) fn insert_batch_with_value_builder<W: Wal<Ascend, Crc32>>(wal: &mut W) {
+  let mut batch = vec![];
+
+  for i in 0..100u32 {
+    batch.push(EntryWithValueBuilder::new(
+      i.to_be_bytes(),
+      ValueBuilder::new(4, move |buf| buf.put_u32_be(i)),
+    ));
+  }
+
+  wal.insert_batch_with_value_builder(&mut batch).unwrap();
+
+  for i in 0..100u32 {
+    assert_eq!(wal.get(&i.to_be_bytes()).unwrap(), i.to_be_bytes());
+  }
+
+  let wal = wal.reader();
+  for i in 0..100u32 {
+    assert_eq!(wal.get(&i.to_be_bytes()).unwrap(), i.to_be_bytes());
+  }
+}
+
+pub(crate) fn insert_batch_with_builders<W: Wal<Ascend, Crc32>>(wal: &mut W) {
+  let mut batch = vec![];
+
+  for i in 0..100u32 {
+    batch.push(EntryWithBuilders::new(
+      KeyBuilder::new(4, move |buf| buf.put_u32_be(i)),
+      ValueBuilder::new(4, move |buf| buf.put_u32_be(i)),
+    ));
+  }
+
+  wal.insert_batch_with_builders(&mut batch).unwrap();
+
+  for i in 0..100u32 {
+    assert_eq!(wal.get(&i.to_be_bytes()).unwrap(), i.to_be_bytes());
+  }
+
+  let wal = wal.reader();
+  for i in 0..100u32 {
+    assert_eq!(wal.get(&i.to_be_bytes()).unwrap(), i.to_be_bytes());
+  }
+}
+
 pub(crate) fn zero_reserved<W: Wal<Ascend, Crc32>>(wal: &mut W) {
   unsafe {
     assert_eq!(wal.reserved_slice(), &[]);
diff --git a/src/utils.rs b/src/utils.rs
index 10f6f572..5258794d 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -2,11 +2,19 @@
 pub use dbutils::leb128::*;

 use super::*;

+/// Merge two `u32` into a `u64`: the high 32 bits hold `a`, the low 32 bits
+/// hold `b`.
 #[inline]
-pub(crate) const fn merge_lengths(klen: u32, vlen: u32) -> u64 {
-  (klen as u64) << 32 | vlen as u64
+pub(crate) const fn merge_lengths(a: u32, b: u32) -> u64 {
+  (a as u64) << 32 | b as u64
 }

+/// Split a `u64` into two `u32`: the first `u32` comes from the high 32 bits,
+/// the second from the low 32 bits.
 #[inline]
 pub(crate) const fn split_lengths(len: u64) -> (u32, u32) {
   ((len >> 32) as u32, len as u32)
 }
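+
+// Illustrative sanity check (names are ours, not part of the original
+// change): merging two lengths and splitting the result returns the inputs.
+#[cfg(test)]
+mod merge_split_tests {
+  use super::*;
+
+  #[test]
+  fn roundtrip() {
+    assert_eq!(split_lengths(merge_lengths(3, 7)), (3, 7));
+    assert_eq!(split_lengths(merge_lengths(u32::MAX, 1)), (u32::MAX, 1));
+  }
+}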
@@ -60,11 +68,11 @@ pub(crate) const fn check(
   let max_vsize = min_u64(max_value_size as u64, u32::MAX as u64);

   if max_ksize < klen as u64 {
-    return Err(error::Error::key_too_large(klen as u32, max_key_size));
+    return Err(error::Error::key_too_large(klen as u64, max_key_size));
   }

   if max_vsize < vlen as u64 {
-    return Err(error::Error::value_too_large(vlen as u32, max_value_size));
+    return Err(error::Error::value_too_large(vlen as u64, max_value_size));
   }

   let (_, _, elen) = entry_size(klen as u32, vlen as u32);
diff --git a/src/wal.rs b/src/wal.rs
index 2e16b4d9..d403fbd9 100644
--- a/src/wal.rs
+++ b/src/wal.rs
@@ -573,7 +573,7 @@ pub trait Wal<C, S>: sealed::Sealed<C, S> + ImmutableWal<C, S> {
     self
       .get_or_insert_with_value_builder::<()>(
         key,
-        ValueBuilder::new(value.len() as u32, |buf| {
+        ValueBuilder::once(value.len() as u32, |buf| {
           buf.put_slice(value).unwrap();
           Ok(())
         }),
@@ -617,7 +617,7 @@ pub trait Wal<C, S>: sealed::Sealed<C, S> + ImmutableWal<C, S> {
     self
       .insert_with_in::<E, ()>(
         kb,
-        ValueBuilder::new(value.len() as u32, |buf| {
+        ValueBuilder::once(value.len() as u32, |buf| {
           buf.put_slice(value).unwrap();
           Ok(())
         }),
@@ -651,7 +651,7 @@ pub trait Wal<C, S>: sealed::Sealed<C, S> + ImmutableWal<C, S> {

     self
       .insert_with_in::<(), E>(
-        KeyBuilder::new(key.len() as u32, |buf| {
+        KeyBuilder::once(key.len() as u32, |buf| {
           buf.put_slice(key).unwrap();
           Ok(())
         }),
@@ -696,6 +696,10 @@ pub trait Wal<C, S>: sealed::Sealed<C, S> + ImmutableWal<C, S> {
     C: Comparator + CheapClone,
     S: BuildChecksumer,
   {
+    if self.read_only() {
+      return Err(Either::Right(Error::read_only()));
+    }
+
     self
       .insert_batch_with_key_builder_in(batch)
       .map(|_| self.insert_pointers(batch.iter_mut().map(|ent| ent.pointer.take().unwrap())))
@@ -710,6 +714,10 @@ pub trait Wal<C, S>: sealed::Sealed<C, S> + ImmutableWal<C, S> {
     C: Comparator + CheapClone,
     S: BuildChecksumer,
   {
+    if self.read_only() {
+      return Err(Either::Right(Error::read_only()));
+    }
+
     self
       .insert_batch_with_value_builder_in(batch)
       .map(|_| self.insert_pointers(batch.iter_mut().map(|ent| ent.pointer.take().unwrap())))
@@ -724,6 +732,10 @@ pub trait Wal<C, S>: sealed::Sealed<C, S> + ImmutableWal<C, S> {
     C: Comparator + CheapClone,
     S: BuildChecksumer,
   {
+    if self.read_only() {
+      return Err(Among::Right(Error::read_only()));
+    }
+
     self
       .insert_batch_with_builders_in(batch)
       .map(|_| self.insert_pointers(batch.iter_mut().map(|ent| ent.pointer.take().unwrap())))
@@ -735,6 +747,10 @@ pub trait Wal<C, S>: sealed::Sealed<C, S> + ImmutableWal<C, S> {
     C: Comparator + CheapClone,
     S: BuildChecksumer,
   {
+    if self.read_only() {
+      return Err(Error::read_only());
+    }
+
     self
       .insert_batch_in(batch)
       .map(|_| self.insert_pointers(batch.iter_mut().map(|ent| ent.pointer.take().unwrap())))
@@ -756,11 +772,11 @@ pub trait Wal<C, S>: sealed::Sealed<C, S> + ImmutableWal<C, S> {

     self
       .insert_with_in::<(), ()>(
-        KeyBuilder::new(key.len() as u32, |buf| {
+        KeyBuilder::once(key.len() as u32, |buf: &mut VacantBuffer<'_>| {
           buf.put_slice(key).unwrap();
           Ok(())
         }),
-        ValueBuilder::new(value.len() as u32, |buf| {
+        ValueBuilder::once(value.len() as u32, |buf: &mut VacantBuffer<'_>| {
           buf.put_slice(value).unwrap();
           Ok(())
         }),
diff --git a/src/wal/sealed.rs b/src/wal/sealed.rs
index e8478ad3..e90a9aa9 100644
--- a/src/wal/sealed.rs
+++ b/src/wal/sealed.rs
@@ -30,6 +30,23 @@ pub trait Sealed<C, S>: Constructor<C, S> {
     crate::check(klen, vlen, max_key_size, max_value_size, ro)
   }

+  #[inline]
+  fn check_batch_entry(&self, klen: usize, vlen: usize) -> Result<(), Error> {
+    let opts = self.options();
+    let max_key_size = opts.maximum_key_size();
+    let max_value_size = opts.maximum_value_size();
+
+    if klen > max_key_size as usize {
+      return Err(Error::key_too_large(klen as u64, max_key_size));
+    }
+
+    if vlen > max_value_size as usize {
+      return Err(Error::value_too_large(vlen as u64, max_value_size));
+    }
+
+    Ok(())
+  }
+
   fn hasher(&self) -> &S;

   fn options(&self) -> &Options;
@@ -52,11 +69,32 @@ pub trait Sealed<C, S>: Constructor<C, S> {
     C: Comparator + CheapClone,
     S: BuildChecksumer,
   {
-    let batch_encoded_size = batch.iter_mut().fold(0u64, |acc, ent| {
-      acc + ent.kb.size() as u64 + ent.value.borrow().len() as u64
-    });
-    let total_size = STATUS_SIZE as u64 + batch_encoded_size + CHECKSUM_SIZE as u64;
-    if total_size > self.options().capacity() as u64 {
+    let mut batch_encoded_size = 0u64;
+
+    let mut num_entries = 0u32;
+    for ent in batch.iter_mut() {
+      let klen = ent.kb.size() as usize;
+      let vlen = ent.value.borrow().len();
+      self.check_batch_entry(klen, vlen).map_err(Either::Right)?;
+      let merged_len = merge_lengths(klen as u32, vlen as u32);
+      batch_encoded_size += klen as u64 + vlen as u64 + encoded_u64_varint_len(merged_len) as u64;
+      num_entries += 1;
+    }
+
+    let cap = self.options().capacity() as u64;
+    if batch_encoded_size > cap {
+      return Err(Either::Right(Error::insufficient_space(
+        batch_encoded_size,
+        self.allocator().remaining() as u32,
+      )));
+    }
+
+    // Safe to cast batch_encoded_size to u32 below: we just checked that it
+    // is less than the capacity, which itself fits in a u32.
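+    // Informal sketch of the batch record written below (derived from the
+    // code that follows):
+    //
+    //   flag: 1 byte (BATCHING; the committed variant is what gets checksummed)
+    //   batch meta varint: merge_lengths(num_entries, batch_encoded_size)
+    //   per entry: varint(merge_lengths(klen, vlen)) | key | value
+    //   checksum: 8 bytes, little-endian, over the committed flag byte plus
+    //   everything after the flag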
+    let batch_meta = merge_lengths(num_entries, batch_encoded_size as u32);
+    let batch_meta_size = encoded_u64_varint_len(batch_meta);
+    let total_size =
+      STATUS_SIZE as u64 + batch_meta_size as u64 + batch_encoded_size + CHECKSUM_SIZE as u64;
+    if total_size > cap {
       return Err(Either::Right(Error::insufficient_space(
         total_size,
         self.allocator().remaining() as u32,
       )));
@@ -74,24 +112,41 @@ pub trait Sealed<C, S>: Constructor<C, S> {
     let mut cks = self.hasher().build_checksumer();
     let flag = Flags::BATCHING;
     buf.put_u8_unchecked(flag.bits);
+    buf.put_u64_varint_unchecked(batch_meta);
     let cmp = self.comparator();
-    let mut cursor = 1;
+    let mut cursor = 1 + batch_meta_size;

     for ent in batch.iter_mut() {
-      let ptr = buf.as_mut_ptr().add(cursor as usize);
       let klen = ent.kb.size() as usize;
-      buf.set_len(cursor as usize + klen);
-      let f = ent.kb.builder();
-      f(&mut VacantBuffer::new(klen, NonNull::new_unchecked(ptr))).map_err(Either::Left)?;
       let value = ent.value.borrow();
       let vlen = value.len();
-      cursor += klen as u64;
-      cursor += vlen as u64;
+      let merged_kv_len = merge_lengths(klen as u32, vlen as u32);
+      let merged_kv_len_size = encoded_u64_varint_len(merged_kv_len);
+      let remaining = buf.remaining();
+      if remaining < merged_kv_len_size + klen + vlen {
+        return Err(Either::Right(Error::larger_batch_size(total_size as u32)));
+      }
+
+      let ent_len_size = buf.put_u64_varint_unchecked(merged_kv_len);
+      let ptr = buf.as_mut_ptr().add(cursor as usize + ent_len_size);
+      buf.set_len(cursor as usize + ent_len_size + klen);
+      let f = ent.kb.builder();
+      f(&mut VacantBuffer::new(klen, NonNull::new_unchecked(ptr))).map_err(Either::Left)?;
+
+      cursor += ent_len_size + klen;
+      cursor += vlen;
       buf.put_slice_unchecked(value);
       ent.pointer = Some(Pointer::new(klen, vlen, ptr, cmp.cheap_clone()));
     }

+    // `cursor` does not include the trailing checksum, so compare against
+    // `total_size` with CHECKSUM_SIZE added back.
+    if (cursor + CHECKSUM_SIZE) as u64 != total_size {
+      return Err(Either::Right(Error::batch_size_mismatch(
+        total_size as u32 - CHECKSUM_SIZE as u32,
+        cursor as u32,
+      )));
+    }
+
     cks.update(&[committed_flag.bits]);
     cks.update(&buf[1..]);
     buf.put_u64_le_unchecked(cks.digest());
@@ -118,11 +173,32 @@ pub trait Sealed<C, S>: Constructor<C, S> {
     C: Comparator + CheapClone,
     S: BuildChecksumer,
   {
-    let batch_encoded_size = batch.iter_mut().fold(0u64, |acc, ent| {
-      acc + ent.key.borrow().len() as u64 + ent.vb.size() as u64
-    });
-    let total_size = STATUS_SIZE as u64 + batch_encoded_size + CHECKSUM_SIZE as u64;
-    if total_size > self.options().capacity() as u64 {
+    let mut batch_encoded_size = 0u64;
+
+    let mut num_entries = 0u32;
+    for ent in batch.iter_mut() {
+      let klen = ent.key.borrow().len();
+      let vlen = ent.vb.size() as usize;
+      self.check_batch_entry(klen, vlen).map_err(Either::Right)?;
+      let merged_len = merge_lengths(klen as u32, vlen as u32);
+      batch_encoded_size += klen as u64 + vlen as u64 + encoded_u64_varint_len(merged_len) as u64;
+      num_entries += 1;
+    }
+
+    let cap = self.options().capacity() as u64;
+    if batch_encoded_size > cap {
+      return Err(Either::Right(Error::insufficient_space(
+        batch_encoded_size,
+        self.allocator().remaining() as u32,
+      )));
+    }
+
+    // Safe to cast batch_encoded_size to u32 below: we just checked that it
+    // is less than the capacity, which itself fits in a u32.
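+    // Same framing as insert_batch_with_key_builder_in above; the difference
+    // is that here each key is copied verbatim while the value bytes are
+    // produced by the entry's value builder.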
+    let batch_meta = merge_lengths(num_entries, batch_encoded_size as u32);
+    let batch_meta_size = encoded_u64_varint_len(batch_meta);
+    let total_size =
+      STATUS_SIZE as u64 + batch_meta_size as u64 + batch_encoded_size + CHECKSUM_SIZE as u64;
+    if total_size > cap {
       return Err(Either::Right(Error::insufficient_space(
         total_size,
         self.allocator().remaining() as u32,
       )));
@@ -140,27 +216,41 @@ pub trait Sealed<C, S>: Constructor<C, S> {
     let mut cks = self.hasher().build_checksumer();
     let flag = Flags::BATCHING;
     buf.put_u8_unchecked(flag.bits);
+    buf.put_u64_varint_unchecked(batch_meta);
     let cmp = self.comparator();
-    let mut cursor = 1;
+    let mut cursor = 1 + batch_meta_size;

     for ent in batch.iter_mut() {
-      let ptr = buf.as_mut_ptr().add(cursor as usize);
       let key = ent.key.borrow();
       let klen = key.len();
       let vlen = ent.vb.size() as usize;
-      cursor += klen as u64;
+      let merged_kv_len = merge_lengths(klen as u32, vlen as u32);
+      let merged_kv_len_size = encoded_u64_varint_len(merged_kv_len);
+      let remaining = buf.remaining();
+      if remaining < merged_kv_len_size + klen + vlen {
+        return Err(Either::Right(Error::larger_batch_size(total_size as u32)));
+      }
+
+      let ent_len_size = buf.put_u64_varint_unchecked(merged_kv_len);
+      let ptr = buf.as_mut_ptr().add(cursor as usize + ent_len_size);
+      cursor += klen + ent_len_size;
       buf.put_slice_unchecked(key);
-      buf.set_len(cursor as usize + vlen);
+      buf.set_len(cursor + vlen);
       let f = ent.vb.builder();
-      f(&mut VacantBuffer::new(
-        klen,
-        NonNull::new_unchecked(ptr.add(klen)),
-      ))
-      .map_err(Either::Left)?;
-      cursor += vlen as u64;
+      // The vacant buffer's capacity is the value length, positioned right
+      // after the key bytes.
+      let mut vacant_buffer = VacantBuffer::new(vlen, NonNull::new_unchecked(ptr.add(klen)));
+      f(&mut vacant_buffer).map_err(Either::Left)?;
+
+      cursor += vlen;
       ent.pointer = Some(Pointer::new(klen, vlen, ptr, cmp.cheap_clone()));
     }

+    // As above, `cursor` excludes the trailing checksum.
+    if (cursor + CHECKSUM_SIZE) as u64 != total_size {
+      return Err(Either::Right(Error::batch_size_mismatch(
+        total_size as u32 - CHECKSUM_SIZE as u32,
+        cursor as u32,
+      )));
+    }
+
     cks.update(&[committed_flag.bits]);
     cks.update(&buf[1..]);
     buf.put_u64_le_unchecked(cks.digest());
@@ -187,9 +277,15 @@ pub trait Sealed<C, S>: Constructor<C, S> {
     C: Comparator + CheapClone,
     S: BuildChecksumer,
   {
-    let batch_encoded_size = batch.iter_mut().fold(0u64, |acc, ent| {
-      acc + ent.kb.size() as u64 + ent.vb.size() as u64
-    });
+    let mut batch_encoded_size = 0u64;
+
+    for ent in batch.iter_mut() {
+      let klen = ent.kb.size() as usize;
+      let vlen = ent.vb.size() as usize;
+      self.check_batch_entry(klen, vlen).map_err(Among::Right)?;
+      batch_encoded_size += klen as u64 + vlen as u64;
+    }
+
     let total_size = STATUS_SIZE as u64 + batch_encoded_size + CHECKSUM_SIZE as u64;
     if total_size > self.options().capacity() as u64 {
       return Err(Among::Right(Error::insufficient_space(
@@ -446,50 +542,117 @@ pub trait Constructor<C, S>: Sized {
         let header = arena.get_u8(cursor).unwrap();
         let flag = Flags::from_bits_unchecked(header);

-        let (kvsize, encoded_len) = arena.get_u64_varint(cursor + STATUS_SIZE).map_err(|_e| {
-          #[cfg(feature = "tracing")]
-          tracing::error!(err=%_e);
+        if !flag.contains(Flags::BATCHING) {
+          let (readed, encoded_len) = arena.get_u64_varint(cursor + STATUS_SIZE).map_err(|e| {
+            #[cfg(feature = "tracing")]
+            tracing::error!(err=%e);
+
+            Error::corrupted(e)
+          })?;
+          let (key_len, value_len) = split_lengths(encoded_len);
+          let key_len = key_len as usize;
+          let value_len = value_len as usize;
+          // Same as above: if we have reached the end of the arena, discard the remaining bytes.
+          let cks_offset = STATUS_SIZE + readed + key_len + value_len;
+          if cks_offset + CHECKSUM_SIZE > allocated {
+            // If the entry claims to be committed, the file must have been
+            // truncated, so report corruption.
+            if flag.contains(Flags::COMMITTED) {
+              return Err(Error::corrupted("file is truncated"));
+            }
+
+            if !ro {
+              arena.rewind(ArenaPosition::Start(cursor as u32));
+              arena.flush()?;
+            }
+
+            break;
+          }

-          Error::corrupted()
-        })?;
+          let cks = arena.get_u64_le(cursor + cks_offset).unwrap();

-        let (key_len, value_len) = split_lengths(encoded_len);
-        let key_len = key_len as usize;
-        let value_len = value_len as usize;
-        // Same as above, if we reached the end of the arena, we should discard the remaining.
-        let cks_offset = STATUS_SIZE + kvsize + key_len + value_len;
-        if cks_offset + CHECKSUM_SIZE > allocated {
-          if !ro {
-            arena.rewind(ArenaPosition::Start(cursor as u32));
-            arena.flush()?;
-          }
-
-          break;
-        }
+          if cks != checksumer.checksum_one(arena.get_bytes(cursor, cks_offset)) {
+            return Err(Error::corrupted("checksum mismatch"));
+          }

-        let cks = arena.get_u64_le(cursor + cks_offset).unwrap();
+          // If the entry is not committed, we should not rewind
+          if !flag.contains(Flags::COMMITTED) {
+            if !ro {
+              arena.rewind(ArenaPosition::Start(cursor as u32));
+              arena.flush()?;
+            }

-        if cks != checksumer.checksum_one(arena.get_bytes(cursor, cks_offset)) {
-          return Err(Error::corrupted());
-        }
+            break;
+          }

+          set.insert(Pointer::new(
+            key_len,
+            value_len,
+            arena.get_pointer(cursor + STATUS_SIZE + readed),
+            cmp.cheap_clone(),
+          ));
+          cursor += cks_offset + CHECKSUM_SIZE;
+        } else {
+          let (readed, encoded_len) = arena.get_u64_varint(cursor + STATUS_SIZE).map_err(|e| {
+            #[cfg(feature = "tracing")]
+            tracing::error!(err=%e);
+
+            Error::corrupted(e)
+          })?;
+
+          let (num_entries, encoded_data_len) = split_lengths(encoded_len);
+
+          // Same as above: if we have reached the end of the arena, discard the remaining bytes.
+          let cks_offset = STATUS_SIZE + readed + encoded_data_len as usize;
+          if cks_offset + CHECKSUM_SIZE > allocated {
+            // If the entry claims to be committed, the file must have been
+            // truncated, so report corruption.
+            if flag.contains(Flags::COMMITTED) {
+              return Err(Error::corrupted("file is truncated"));
+            }
+
+            if !ro {
+              arena.rewind(ArenaPosition::Start(cursor as u32));
+              arena.flush()?;
+            }
+
+            break;
+          }

-        // If the entry is not committed, we should not rewind
-        if !flag.contains(Flags::COMMITTED) {
-          if !ro {
-            arena.rewind(ArenaPosition::Start(cursor as u32));
-            arena.flush()?;
-          }
-
-          break;
-        }
-
-        set.insert(Pointer::new(
-          key_len,
-          value_len,
-          arena.get_pointer(cursor + STATUS_SIZE + kvsize),
-          cmp.cheap_clone(),
-        ));
-        cursor += cks_offset + CHECKSUM_SIZE;
+          let cks = arena.get_u64_le(cursor + cks_offset).unwrap();
+          let data_buf = arena.get_bytes(cursor, cks_offset);
+          if cks != checksumer.checksum_one(data_buf) {
+            return Err(Error::corrupted("checksum mismatch"));
+          }
+
+          // Skip the status byte and the batch meta varint; what remains in
+          // `batch_data_buf` are the per-entry frames.
+          let mut batch_data_buf = &data_buf[STATUS_SIZE + readed..];
+          let mut sub_cursor = 0;
+
+          for _ in 0..num_entries {
+            let (kvlen, ent_len) =
+              dbutils::leb128::decode_u64_varint(batch_data_buf).map_err(|e| {
+                #[cfg(feature = "tracing")]
+                tracing::error!(err=%e);
+
+                Error::corrupted(e)
+              })?;
+
+            let (klen, vlen) = split_lengths(ent_len);
+            let klen = klen as usize;
+            let vlen = vlen as usize;
+
+            set.insert(Pointer::new(
+              klen,
+              vlen,
+              arena.get_pointer(cursor + STATUS_SIZE + readed + sub_cursor + kvlen),
+              cmp.cheap_clone(),
+            ));
+            // Advance by this entry's frame only; `sub_cursor` accumulates
+            // the running total so the next pointer offset stays correct.
+            sub_cursor += kvlen + klen + vlen;
+            batch_data_buf = &batch_data_buf[kvlen + klen + vlen..];
+          }
+
+          debug_assert_eq!(
+            sub_cursor, encoded_data_len as usize,
+            "decoded batch data size does not match the size recorded in the batch meta"
+          );
+
+          // Move past the whole batch record, including its checksum.
+          cursor += cks_offset + CHECKSUM_SIZE;
+        }
       }
     }