Skip to content

Commit

Permalink
Stream chunks without olding them in memory when downloading via the …
Browse files Browse the repository at this point in the history
…bridge
  • Loading branch information
co42 committed Dec 20, 2024
1 parent 286e663 commit 973c95a
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 23 deletions.
5 changes: 2 additions & 3 deletions cas_object/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,5 @@ bytes = "1.7.2"
rand = "0.8.5"
blake3 = "1.5.4"
futures = { version = "0.3.31" }
tokio-util = {version = "0.7.12", features = ["io"]}
tokio = {version = "1.41.1" }

tokio-util = { version = "0.7.12", features = ["io", "io-util"] }
tokio = { version = "1.41.1" }
18 changes: 7 additions & 11 deletions cas_object/src/cas_chunk_format.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::io::{copy, Cursor, Read, Write};
use std::io::{copy, Read, Write};
use std::mem::size_of;

use anyhow::anyhow;
Expand Down Expand Up @@ -145,10 +145,9 @@ pub fn deserialize_chunk_to_writer<R: Read, W: Write>(
writer: &mut W,
) -> Result<(usize, u32), CasObjectError> {
let header = deserialize_chunk_header(reader)?;
let mut compressed_buf = vec![0u8; header.get_compressed_length() as usize];
reader.read_exact(&mut compressed_buf)?;
let mut reader = reader.take(header.get_compressed_length() as u64);

let uncompressed_len = decompress_chunk_to_writer(header, &mut compressed_buf, writer)?;
let uncompressed_len = decompress_chunk_to_writer(header, &mut reader, writer)?;

if uncompressed_len != header.get_uncompressed_length() {
return Err(CasObjectError::FormatError(anyhow!(
Expand All @@ -159,18 +158,15 @@ pub fn deserialize_chunk_to_writer<R: Read, W: Write>(
Ok((header.get_compressed_length() as usize + CAS_CHUNK_HEADER_LENGTH, uncompressed_len))
}

pub(crate) fn decompress_chunk_to_writer<W: Write>(
pub(crate) fn decompress_chunk_to_writer<R: Read, W: Write>(
header: CASChunkHeader,
compressed_buf: &mut Vec<u8>,
reader: &mut R,
writer: &mut W,
) -> Result<u32, CasObjectError> {
let result = match header.get_compression_scheme()? {
CompressionScheme::None => {
writer.write_all(compressed_buf)?;
compressed_buf.len() as u32
},
CompressionScheme::None => copy(reader, writer)? as u32,
CompressionScheme::LZ4 => {
let mut dec = FrameDecoder::new(Cursor::new(compressed_buf));
let mut dec = FrameDecoder::new(reader);
copy(&mut dec, writer)? as u32
},
};
Expand Down
74 changes: 66 additions & 8 deletions cas_object/src/cas_chunk_format/deserialize_async.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use std::io::Write;
use std::io::{Cursor, Read, Write};
use std::mem::size_of;
use std::slice;

use anyhow::anyhow;
use bytes::Buf;
use futures::Stream;
use tokio::io::{AsyncRead, AsyncReadExt};
use tokio_util::io::StreamReader;
use tokio::task::{spawn_blocking, JoinHandle};
use tokio_util::io::{StreamReader, SyncIoBridge};

use crate::error::CasObjectError;
use crate::{CASChunkHeader, CAS_CHUNK_HEADER_LENGTH};
Expand All @@ -30,7 +31,7 @@ pub async fn deserialize_chunk_to_writer<R: AsyncRead + Unpin, W: Write>(
let mut compressed_buf = vec![0u8; header.get_compressed_length() as usize];
reader.read_exact(&mut compressed_buf).await?;

let uncompressed_len = super::decompress_chunk_to_writer(header, &mut compressed_buf, writer)?;
let uncompressed_len = super::decompress_chunk_to_writer(header, &mut Cursor::new(compressed_buf), writer)?;

if uncompressed_len != header.get_uncompressed_length() {
return Err(CasObjectError::FormatError(anyhow!(
Expand Down Expand Up @@ -73,12 +74,69 @@ pub async fn deserialize_chunks_to_writer_from_async_read<R: AsyncRead + Unpin,
Ok((num_compressed_written, chunk_byte_indices))
}

pub async fn deserialize_chunks_from_async_read<R: AsyncRead + Unpin>(
pub fn deserialize_chunk_header_stream<R: Read>(reader: &mut R) -> Result<CASChunkHeader, CasObjectError> {
let mut result = CASChunkHeader::default();
unsafe {
let buf = slice::from_raw_parts_mut(&mut result as *mut _ as *mut u8, size_of::<CASChunkHeader>());
reader.read_exact(buf)?;
}
result.validate()?;
Ok(result)
}

/// Returns the compressed chunk size along with the uncompressed chunk size as a tuple, (compressed, uncompressed)
pub fn deserialize_chunk_to_writer_stream<R: Read, W: Write>(
reader: &mut R,
) -> Result<(Vec<u8>, Vec<u32>), CasObjectError> {
let mut buf = Vec::new();
let (_, chunk_byte_indices) = deserialize_chunks_to_writer_from_async_read(reader, &mut buf).await?;
Ok((buf, chunk_byte_indices))
writer: &mut W,
) -> Result<(usize, u32), CasObjectError> {
let header = deserialize_chunk_header_stream(reader)?;
let mut compressed_buf = vec![0u8; header.get_compressed_length() as usize];
reader.read_exact(&mut compressed_buf)?;

let mut uncompressed_buf = Vec::new();
let uncompressed_len = super::decompress_chunk_to_writer(
header,
&mut Cursor::new(compressed_buf),
&mut Cursor::new(&mut uncompressed_buf),
)?;
writer.write_all(&uncompressed_buf)?;

if uncompressed_len != header.get_uncompressed_length() {
return Err(CasObjectError::FormatError(anyhow!(
"chunk is corrupted, uncompressed bytes len doesn't agree with chunk header"
)));
}

Ok((header.get_compressed_length() as usize + CAS_CHUNK_HEADER_LENGTH, uncompressed_len))
}

pub fn deserialize_chunks_to_writer_from_async_read_stream<
R: AsyncRead + Unpin + Send + 'static,
W: Write + Send + 'static,
>(
reader: R,
mut writer: W,
) -> JoinHandle<Result<(), CasObjectError>> {
let mut reader = SyncIoBridge::new(reader);
// The deserialization of the chunks is done synchronously and thus needs to happen on a blocking thread
// Moreover we expect to return from this function right away to be able to read the other end of the stream
spawn_blocking(move || -> Result<(), CasObjectError> {
loop {
match deserialize_chunk_to_writer_stream(&mut reader, &mut writer) {
Ok(_) => {},
Err(CasObjectError::InternalIOError(e)) => {
if e.kind() == std::io::ErrorKind::UnexpectedEof {
break;
}
return Err(CasObjectError::InternalIOError(e));
},
Err(e) => {
return Err(e);
},
}
}
Ok(())
})
}

pub async fn deserialize_chunks_to_writer_from_stream<B, E, S, W>(
Expand Down
3 changes: 2 additions & 1 deletion cas_object/src/validate_xorb_stream.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::io::Cursor;
use std::mem::size_of;

use anyhow::anyhow;
Expand Down Expand Up @@ -64,7 +65,7 @@ async fn _validate_cas_object_from_async_read<R: AsyncRead + Unpin>(

let chunk_uncompressed_expected_len = chunk_header.get_uncompressed_length() as usize;
let mut uncompressed_chunk_data = Vec::with_capacity(chunk_uncompressed_expected_len);
decompress_chunk_to_writer(chunk_header, &mut compressed_chunk_data, &mut uncompressed_chunk_data)
decompress_chunk_to_writer(chunk_header, &mut Cursor::new(compressed_chunk_data), &mut uncompressed_chunk_data)
.log_error(format!("failed to decompress chunk at index {}, xorb {hash}", hash_chunks.len()))?;

if chunk_uncompressed_expected_len != uncompressed_chunk_data.len() {
Expand Down

0 comments on commit 973c95a

Please sign in to comment.