Commit 8d67276: async

joshieDo committed Jan 14, 2025 (1 parent e7cdd92)
Showing 6 changed files with 240 additions and 88 deletions.
1 change: 1 addition & 0 deletions crates/stages/stages/src/stages/s3/downloader/error.rs
@@ -1,6 +1,7 @@
use alloy_primitives::B256;
use reth_fs_util::FsPathError;

/// Possible downloader error variants.
#[derive(Debug, thiserror::Error)]
pub enum DownloaderError {
/// Requires a valid `total_size` {0}
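The error type uses `thiserror` to derive the `Display` and `Error` impls. A minimal sketch of that pattern — the variants below are illustrative, since only the first lines of the real enum are visible in this hunk:

```rust
use alloy_primitives::B256;
use reth_fs_util::FsPathError;

/// Possible downloader error variants (variants shown here are illustrative).
#[derive(Debug, thiserror::Error)]
pub enum DownloaderError {
    /// The server did not report a usable `total_size`.
    #[error("requires a valid total_size, got {0:?}")]
    InvalidMetadataTotalSize(Option<usize>),
    /// Downloaded file did not match the expected blake3 hash.
    #[error("file hash mismatch, expected {0} got {1}")]
    InvalidFileHash(B256, B256),
    /// Filesystem error, converted automatically via `#[from]`.
    #[error(transparent)]
    FsPath(#[from] FsPathError),
}
```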
49 changes: 20 additions & 29 deletions crates/stages/stages/src/stages/s3/downloader/fetch.rs
@@ -1,18 +1,20 @@
+ use crate::stages::s3::downloader::worker::spawn_workers;

use super::{
error::DownloaderError,
meta::Metadata,
worker::{worker_fetch, WorkerRequest, WorkerResponse},
};
use alloy_primitives::B256;
- use reqwest::{blocking::Client, header::CONTENT_LENGTH};
+ use reqwest::{header::CONTENT_LENGTH, Client};
use std::{
collections::HashMap,
fs::{File, OpenOptions},
io::BufReader,
path::Path,
sync::mpsc::channel,
};
- use tracing::{debug, error};
+ use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver};
+ use tracing::{debug, error, info};

/// Downloads file from url to data file path.
///
@@ -32,7 +34,7 @@ use tracing::{debug, error};
/// * If `file_hash` is `Some`, verifies its blake3 hash.
/// * Deletes the metadata file
/// * Moves downloaded file to target directory.
- pub fn fetch(
+ pub async fn fetch(
filename: &str,
target_dir: &Path,
url: &str,
@@ -44,7 +46,7 @@ pub fn fetch(
reth_fs_util::create_dir_all(&download_dir)?;

let data_file = download_dir.join(filename);
- let mut metadata = metadata(&data_file, url)?;
+ let mut metadata = metadata(&data_file, url).await?;
if metadata.is_done() {
return Ok(())
}
@@ -59,45 +61,32 @@ pub fn fetch(
.open(&data_file)?;

if file.metadata()?.len() as usize != metadata.total_size {
- debug!(target: "sync::stages::s3::downloader", ?data_file, length = metadata.total_size, "Preallocating space.");
+ info!(target: "sync::stages::s3::downloader", ?filename, length = metadata.total_size, "Preallocating space.");
file.set_len(metadata.total_size as u64)?;
}
}

while !metadata.is_done() {
+ info!(target: "sync::stages::s3::downloader", ?filename, "Downloading.");

// Find the missing file chunks and the minimum number of workers required
let missing_chunks = metadata.needed_ranges();
concurrent = concurrent
.min(std::thread::available_parallelism()?.get() as u64)
.min(missing_chunks.len() as u64);

- // Create channels for communication between workers and orchestrator
- let (orchestrator_tx, orchestrator_rx) = channel();
-
- // Initiate workers
- for worker_id in 0..concurrent {
- let orchestrator_tx = orchestrator_tx.clone();
- let data_file = data_file.clone();
- let url = url.to_string();
- std::thread::spawn(move || {
- if let Err(error) = worker_fetch(worker_id, &orchestrator_tx, data_file, url) {
- let _ = orchestrator_tx.send(WorkerResponse::Err { worker_id, error });
- }
- });
- }
-
- // Drop the sender to allow the loop processing to exit once all workers are done
- drop(orchestrator_tx);
+ let mut orchestrator_rx = spawn_workers(url, concurrent, &data_file);

let mut workers = HashMap::new();
let mut missing_chunks = missing_chunks.into_iter();

// Distribute chunk ranges to workers when they free up
- while let Ok(worker_msg) = orchestrator_rx.recv() {
+ while let Some(worker_msg) = orchestrator_rx.recv().await {
debug!(target: "sync::stages::s3::downloader", ?worker_msg, "received message from worker");

let available_worker = match worker_msg {
WorkerResponse::Ready { worker_id, tx } => {
debug!(target: "sync::stages::s3::downloader", ?worker_id, "Worker ready.");
workers.insert(worker_id, tx);
worker_id
}
@@ -114,6 +103,7 @@ pub fn fetch(
let worker = workers.get(&available_worker).expect("should exist");
match missing_chunks.next() {
Some((chunk_index, (start, end))) => {
+ debug!(target: "sync::stages::s3::downloader", ?available_worker, start, end, "Worker download request.");
let _ = worker.send(WorkerRequest::Download { chunk_index, start, end });
}
None => {
@@ -124,6 +114,7 @@ pub fn fetch(
}

if let Some(file_hash) = file_hash {
+ info!(target: "sync::stages::s3::downloader", ?filename, "Checking file integrity.");
check_file_hash(&data_file, &file_hash)?;
}

@@ -136,14 +127,14 @@ pub fn fetch(

/// Creates a metadata file used to keep track of the downloaded chunks. Useful when resuming after
/// a shutdown.
- fn metadata(data_file: &Path, url: &str) -> Result<Metadata, DownloaderError> {
+ async fn metadata(data_file: &Path, url: &str) -> Result<Metadata, DownloaderError> {
if Metadata::file_path(data_file).exists() {
debug!(target: "sync::stages::s3::downloader", ?data_file, "Loading metadata ");
return Metadata::load(data_file)
}

let client = Client::new();
- let resp = client.head(url).send()?;
+ let resp = client.head(url).send().await?;
let total_length: usize = resp
.headers()
.get(CONTENT_LENGTH)
@@ -175,8 +166,8 @@ mod tests {
use super::*;
use alloy_primitives::b256;

- #[test]
- fn test_download() {
+ #[tokio::test]
+ async fn test_download() {
reth_tracing::init_test_tracing();

let b3sum = b256!("81a7318f69fc1d6bb0a58a24af302f3b978bc75a435e4ae5d075f999cd060cfd");
@@ -185,6 +176,6 @@
let file = tempfile::NamedTempFile::new().unwrap();
let filename = file.path().file_name().unwrap().to_str().unwrap();
let target_dir = file.path().parent().unwrap();
- fetch(filename, target_dir, url, 4, Some(b3sum)).unwrap();
+ fetch(filename, target_dir, url, 4, Some(b3sum)).await.unwrap();
}
}
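Since `fetch` is now async, callers need a Tokio runtime. A minimal usage sketch mirroring the test above — the URL, paths, and hash are placeholders, not values from this commit:

```rust
use alloy_primitives::b256;
use std::path::Path;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Placeholder URL and blake3 digest; substitute real values.
    let url = "https://example.com/snapshots/file.tar";
    let b3sum = b256!("81a7318f69fc1d6bb0a58a24af302f3b978bc75a435e4ae5d075f999cd060cfd");

    // Download with up to 4 concurrent range workers, then verify the hash.
    fetch("file.tar", Path::new("/tmp/snapshots"), url, 4, Some(b3sum)).await?;
    Ok(())
}
```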
4 changes: 3 additions & 1 deletion crates/stages/stages/src/stages/s3/downloader/meta.rs
@@ -15,7 +15,9 @@ pub struct Metadata {
pub downloaded: usize,
/// Download chunk size. Default 150MB.
pub chunk_size: usize,
- /// Each chunk remaining download range.
+ /// Remaining download ranges for each chunk.
+ /// - `Some((start, end))`: Start and end indices to be downloaded.
+ /// - `None`: Chunk fully downloaded.
chunks: Vec<Option<(usize, usize)>>,
/// Path with the stored metadata.
#[serde(skip)]
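The clarified doc comment pins down the bookkeeping: each chunk tracks the inclusive byte range still to be fetched, and flips to `None` once fully written. An illustrative sketch (not the crate's code) of how such ranges would partition a file of `total_size` bytes:

```rust
/// Illustrative only: split [0, total_size) into inclusive per-chunk ranges.
fn init_chunks(total_size: usize, chunk_size: usize) -> Vec<Option<(usize, usize)>> {
    (0..total_size)
        .step_by(chunk_size)
        .map(|start| Some((start, (start + chunk_size).min(total_size) - 1)))
        .collect()
}

fn main() {
    // A 1000-byte file with 400-byte chunks yields three ranges.
    assert_eq!(
        init_chunks(1_000, 400),
        vec![Some((0, 399)), Some((400, 799)), Some((800, 999))]
    );
}
```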
10 changes: 10 additions & 0 deletions crates/stages/stages/src/stages/s3/downloader/mod.rs
@@ -6,5 +6,15 @@ mod fetch;
mod meta;
mod worker;

+ pub(crate) use error::DownloaderError;
pub use fetch::fetch;
pub use meta::Metadata;

+ /// Response sent by the fetch task to [`S3Stage`] once it has downloaded all files of a block
+ /// range.
+ pub(crate) struct S3DownloaderResponse {
+ /// Downloaded range.
+ pub range: std::ops::RangeInclusive<u64>,
+ /// Whether the fetch task has downloaded the last requested block range.
+ pub is_done: bool,
+ }
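A hypothetical consumer loop for these responses — the channel type matches the unbounded channels used elsewhere in this commit, but the receiving side in `S3Stage` is not shown here:

```rust
use tokio::sync::mpsc::UnboundedReceiver;

// Hypothetical: drain fetch-task responses until the last block range arrives.
async fn drain_responses(mut rx: UnboundedReceiver<S3DownloaderResponse>) {
    while let Some(resp) = rx.recv().await {
        println!("downloaded blocks {:?}", resp.range);
        if resp.is_done {
            break;
        }
    }
}
```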
79 changes: 60 additions & 19 deletions crates/stages/stages/src/stages/s3/downloader/worker.rs
@@ -1,19 +1,18 @@
- use reqwest::{blocking::Client, header::RANGE};
- use std::{
+ use super::error::DownloaderError;
+ use reqwest::{header::RANGE, Client};
+ use std::path::{Path, PathBuf};
+ use tokio::{
fs::OpenOptions,
- io::{BufWriter, Read, Seek, SeekFrom},
- path::PathBuf,
- sync::mpsc::{channel, Sender},
+ io::{AsyncSeekExt, AsyncWriteExt, BufWriter},
+ sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
};
use tracing::debug;

- use super::error::DownloaderError;
-
/// Responses sent by a worker.
#[derive(Debug)]
pub(crate) enum WorkerResponse {
/// Worker has been spawned and is awaiting work.
- Ready { worker_id: u64, tx: Sender<WorkerRequest> },
+ Ready { worker_id: u64, tx: UnboundedSender<WorkerRequest> },
/// Worker has downloaded a chunk.
DownloadedChunk { worker_id: u64, chunk_index: usize, written_bytes: usize },
/// Worker has encountered an error.
@@ -30,30 +29,45 @@ pub(crate) enum WorkerRequest {
}

/// Downloads requested chunk ranges to the data file.
- pub(crate) fn worker_fetch(
+ pub(crate) async fn worker_fetch(
worker_id: u64,
- orchestrator_tx: &Sender<WorkerResponse>,
+ orchestrator_tx: &UnboundedSender<WorkerResponse>,
data_file: PathBuf,
url: String,
) -> Result<(), DownloaderError> {
let client = Client::new();
- let mut data_file = BufWriter::new(OpenOptions::new().write(true).open(data_file)?);
+ let mut data_file = BufWriter::new(OpenOptions::new().write(true).open(data_file).await?);

// Signals readiness to download
- let (tx, rx) = channel::<WorkerRequest>();
- let _ = orchestrator_tx.send(WorkerResponse::Ready { worker_id, tx });
+ let (tx, mut rx) = unbounded_channel::<WorkerRequest>();
+ orchestrator_tx.send(WorkerResponse::Ready { worker_id, tx }).unwrap_or_else(|_| {
+ debug!("Failed to notify orchestrator of readiness");
+ });

- while let Ok(req) = rx.recv() {
- debug!(target: "sync::stages::s3::downloader", worker_id, ?req, "received from orchestrator");
+ while let Some(req) = rx.recv().await {
+ debug!(
+ target: "sync::stages::s3::downloader",
+ worker_id,
+ ?req,
+ "received from orchestrator"
+ );

match req {
WorkerRequest::Download { chunk_index, start, end } => {
- data_file.seek(SeekFrom::Start(start as u64))?;
+ data_file.seek(tokio::io::SeekFrom::Start(start as u64)).await?;

- let mut response =
- client.get(&url).header(RANGE, format!("bytes={}-{}", start, end)).send()?;
+ let mut response = client
+ .get(&url)
+ .header(RANGE, format!("bytes={}-{}", start, end))
+ .send()
+ .await?;

- let written_bytes = std::io::copy(response.by_ref(), &mut data_file)? as usize;
+ let mut written_bytes = 0;
+ while let Some(chunk) = response.chunk().await? {
+ written_bytes += chunk.len();
+ data_file.write_all(&chunk).await?;
+ }
+ data_file.flush().await?;

let _ = orchestrator_tx.send(WorkerResponse::DownloadedChunk {
worker_id,
@@ -67,3 +81,30 @@ pub(crate) fn worker_fetch(

Ok(())
}

+ /// Spawns the requested number of workers and returns an [`UnboundedReceiver`] that all of them
+ /// will respond to.
+ pub(crate) fn spawn_workers(
+ url: &str,
+ number: u64,
+ data_file: &Path,
+ ) -> UnboundedReceiver<WorkerResponse> {
+ // Create channels for communication between workers and orchestrator
+ let (orchestrator_tx, orchestrator_rx) = unbounded_channel();
+
+ // Initiate workers
+ for worker_id in 0..number {
+ let orchestrator_tx = orchestrator_tx.clone();
+ let data_file = data_file.to_path_buf();
+ let url = url.to_string();
+ debug!(target: "sync::stages::s3::downloader", ?worker_id, "Spawning.");
+
+ tokio::spawn(async move {
+ if let Err(error) = worker_fetch(worker_id, &orchestrator_tx, data_file, url).await {
+ let _ = orchestrator_tx.send(WorkerResponse::Err { worker_id, error });
+ }
+ });
+ }
+
+ orchestrator_rx
+ }
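The worker's download path condenses to: seek to the chunk's offset, issue an HTTP `Range` request, and stream the body to disk. A self-contained sketch of that pattern, with a placeholder URL and path, and the error type boxed for brevity:

```rust
use reqwest::{header::RANGE, Client};
use tokio::{
    fs::OpenOptions,
    io::{AsyncSeekExt, AsyncWriteExt, BufWriter, SeekFrom},
};

async fn download_range(
    url: &str,
    path: &str,
    start: u64,
    end: u64,
) -> Result<usize, Box<dyn std::error::Error>> {
    // Open the preallocated file and position the writer at the chunk start.
    let file = OpenOptions::new().write(true).open(path).await?;
    let mut writer = BufWriter::new(file);
    writer.seek(SeekFrom::Start(start)).await?;

    // Request only the bytes of this chunk (inclusive range).
    let mut response =
        Client::new().get(url).header(RANGE, format!("bytes={start}-{end}")).send().await?;

    // Stream the body chunk by chunk instead of buffering it all in memory.
    let mut written = 0;
    while let Some(bytes) = response.chunk().await? {
        written += bytes.len();
        writer.write_all(&bytes).await?;
    }
    writer.flush().await?;
    Ok(written)
}
```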