diff --git a/Cargo.lock b/Cargo.lock index ca0d002e..4683c3c3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -514,6 +514,7 @@ dependencies = [ "tracing-test", "url", "utils", + "xet_threadpool", ] [[package]] @@ -926,6 +927,7 @@ dependencies = [ "toml", "tracing", "utils", + "xet_threadpool", ] [[package]] @@ -3765,6 +3767,7 @@ dependencies = [ "thiserror 2.0.9", "tokio", "tracing", + "xet_threadpool", ] [[package]] @@ -4181,6 +4184,16 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "xet_threadpool" +version = "0.1.0" +dependencies = [ + "lazy_static", + "thiserror 2.0.9", + "tokio", + "tracing", +] + [[package]] name = "zerocopy" version = "0.7.35" diff --git a/Cargo.toml b/Cargo.toml index 311f1562..4f256786 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,7 @@ members = [ "utils", "cas_object", "cas_types", - "chunk_cache", + "chunk_cache", "xet_threadpool", ] exclude = ["hf_xet", "chunk_cache_bench"] diff --git a/cas_client/Cargo.toml b/cas_client/Cargo.toml index 4fcd8c13..b56b83b6 100644 --- a/cas_client/Cargo.toml +++ b/cas_client/Cargo.toml @@ -18,6 +18,7 @@ utils = { path = "../utils" } merkledb = { path = "../merkledb" } mdb_shard = { path = "../mdb_shard" } merklehash = { path = "../merklehash" } +xet_threadpool = { path = "../xet_threadpool" } thiserror = "2.0" tokio = { version = "1.41", features = ["full"] } async-trait = "0.1.9" diff --git a/cas_client/src/remote_client.rs b/cas_client/src/remote_client.rs index d362c41a..c6a33ef2 100644 --- a/cas_client/src/remote_client.rs +++ b/cas_client/src/remote_client.rs @@ -21,7 +21,7 @@ use tracing::{debug, error, trace}; use utils::auth::AuthConfig; use utils::progress::ProgressUpdater; use utils::singleflight::Group; -use utils::ThreadPool; +use xet_threadpool::ThreadPool; use crate::error::Result; use crate::interface::*; diff --git a/data/Cargo.toml b/data/Cargo.toml index 31786cc0..bc9abb43 100644 --- a/data/Cargo.toml +++ b/data/Cargo.toml @@ -21,6 +21,7 @@ utils = { path = "../utils" } parutils = { path = "../parutils" } file_utils = { path = "../file_utils" } error_printer = { path = "../error_printer" } +xet_threadpool = { path = "../xet_threadpool" } thiserror = "2.0" tokio = { version = "1.36", features = ["full"] } anyhow = "1" diff --git a/data/src/bin/example.rs b/data/src/bin/example.rs index 90ef47dc..7c463e04 100644 --- a/data/src/bin/example.rs +++ b/data/src/bin/example.rs @@ -10,7 +10,7 @@ use cas_client::CacheConfig; use clap::{Args, Parser, Subcommand}; use data::configurations::*; use data::{PointerFile, PointerFileTranslator, SMALL_FILE_THRESHOLD}; -use utils::ThreadPool; +use xet_threadpool::ThreadPool; #[derive(Parser)] struct XCommand { diff --git a/data/src/cas_interface.rs b/data/src/cas_interface.rs index fb279a81..73c165e1 100644 --- a/data/src/cas_interface.rs +++ b/data/src/cas_interface.rs @@ -7,7 +7,7 @@ use cas_client::{CacheConfig, RemoteClient}; use mdb_shard::ShardFileManager; use tracing::info; use utils::auth::AuthConfig; -use utils::ThreadPool; +use xet_threadpool::ThreadPool; use crate::configurations::*; use crate::errors::Result; diff --git a/data/src/chunking.rs b/data/src/chunking.rs index 6f36f338..5dd853fb 100644 --- a/data/src/chunking.rs +++ b/data/src/chunking.rs @@ -7,7 +7,7 @@ use merklehash::compute_data_hash; use tokio::sync::mpsc::{Receiver, Sender}; use tokio::sync::Mutex; use tokio::task::JoinHandle; -use utils::ThreadPool; +use xet_threadpool::ThreadPool; use super::clean::BufferItem; use crate::errors::{DataProcessingError, Result}; diff --git a/data/src/clean.rs b/data/src/clean.rs index ce39f6ac..01a91d7a 100644 --- a/data/src/clean.rs +++ b/data/src/clean.rs @@ -23,7 +23,7 @@ use tokio::sync::Mutex; use tokio::task::{JoinHandle, JoinSet}; use tracing::{debug, info, warn}; use utils::progress::ProgressUpdater; -use utils::ThreadPool; +use xet_threadpool::ThreadPool; use crate::chunking::{chunk_target_default, ChunkYieldType}; use crate::constants::MIN_SPACING_BETWEEN_GLOBAL_DEDUP_QUERIES; diff --git a/data/src/data_client.rs b/data/src/data_client.rs index e8e8b0fb..2bc42a2f 100644 --- a/data/src/data_client.rs +++ b/data/src/data_client.rs @@ -14,7 +14,7 @@ use parutils::{tokio_par_for_each, ParallelError}; use tempfile::{tempdir_in, TempDir}; use utils::auth::{AuthConfig, TokenRefresher}; use utils::progress::ProgressUpdater; -use utils::ThreadPool; +use xet_threadpool::ThreadPool; use crate::configurations::*; use crate::errors::DataProcessingError; diff --git a/data/src/data_processing.rs b/data/src/data_processing.rs index 4676f900..a2357d1b 100644 --- a/data/src/data_processing.rs +++ b/data/src/data_processing.rs @@ -13,7 +13,7 @@ use mdb_shard::ShardFileManager; use merklehash::MerkleHash; use tokio::sync::{Mutex, Semaphore}; use utils::progress::ProgressUpdater; -use utils::ThreadPool; +use xet_threadpool::ThreadPool; use crate::cas_interface::create_cas_client; use crate::clean::Cleaner; diff --git a/data/src/parallel_xorb_uploader.rs b/data/src/parallel_xorb_uploader.rs index 02ab050a..bf86860c 100644 --- a/data/src/parallel_xorb_uploader.rs +++ b/data/src/parallel_xorb_uploader.rs @@ -9,7 +9,7 @@ use merklehash::MerkleHash; use tokio::sync::{Mutex, Semaphore}; use tokio::task::JoinSet; use utils::progress::ProgressUpdater; -use utils::ThreadPool; +use xet_threadpool::ThreadPool; use crate::data_processing::CASDataAggregator; use crate::errors::DataProcessingError::*; diff --git a/data/src/remote_shard_interface.rs b/data/src/remote_shard_interface.rs index b25661fc..2fc6909b 100644 --- a/data/src/remote_shard_interface.rs +++ b/data/src/remote_shard_interface.rs @@ -15,7 +15,7 @@ use merklehash::MerkleHash; use parutils::tokio_par_for_each; use tokio::task::JoinHandle; use tracing::{debug, info}; -use utils::ThreadPool; +use xet_threadpool::ThreadPool; use super::configurations::{FileQueryPolicy, StorageConfig}; use super::errors::{DataProcessingError, Result}; diff --git a/hf_xet/Cargo.lock b/hf_xet/Cargo.lock index 2d905f20..488dab21 100644 --- a/hf_xet/Cargo.lock +++ b/hf_xet/Cargo.lock @@ -267,6 +267,7 @@ dependencies = [ "tracing", "url", "utils", + "xet_threadpool", ] [[package]] @@ -541,6 +542,7 @@ dependencies = [ "toml", "tracing", "utils", + "xet_threadpool", ] [[package]] @@ -1046,6 +1048,7 @@ dependencies = [ "tracing", "tracing-subscriber", "utils", + "xet_threadpool", ] [[package]] @@ -3423,6 +3426,7 @@ dependencies = [ "thiserror 2.0.11", "tokio", "tracing", + "xet_threadpool", ] [[package]] @@ -3845,6 +3849,16 @@ version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51" +[[package]] +name = "xet_threadpool" +version = "0.1.0" +dependencies = [ + "lazy_static", + "thiserror 2.0.11", + "tokio", + "tracing", +] + [[package]] name = "yoke" version = "0.7.4" diff --git a/hf_xet/Cargo.toml b/hf_xet/Cargo.toml index 5b06e927..d703e663 100644 --- a/hf_xet/Cargo.toml +++ b/hf_xet/Cargo.toml @@ -19,6 +19,7 @@ pyo3 = { version = "0.23.3", features = [ error_printer = { path = "../error_printer" } data = { path = "../data" } utils = { path = "../utils" } +xet_threadpool = { path = "../xet_threadpool" } tokio = { version = "1.36", features = ["full"] } parutils = { path = "../parutils" } tracing = "0.1.*" diff --git a/hf_xet/src/log.rs b/hf_xet/src/log.rs index c5dfda12..ff058b16 100644 --- a/hf_xet/src/log.rs +++ b/hf_xet/src/log.rs @@ -6,7 +6,7 @@ use tracing_subscriber::filter::FilterFn; use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::util::SubscriberInitExt; use tracing_subscriber::{EnvFilter, Layer}; -use utils::ThreadPool; +use xet_threadpool::ThreadPool; use crate::log_buffer::{get_telemetry_task, LogBufferLayer, TelemetryTaskInfo, TELEMETRY_PRE_ALLOC_BYTES}; diff --git a/hf_xet/src/runtime.rs b/hf_xet/src/runtime.rs index 81b215ed..93ff3cf2 100644 --- a/hf_xet/src/runtime.rs +++ b/hf_xet/src/runtime.rs @@ -5,7 +5,8 @@ use std::time::Duration; use lazy_static::lazy_static; use pyo3::exceptions::{PyKeyboardInterrupt, PyRuntimeError}; use pyo3::prelude::*; -use utils::threadpool::{MultithreadedRuntimeError, ThreadPool}; +use xet_threadpool::errors::MultithreadedRuntimeError; +use xet_threadpool::ThreadPool; use crate::log; diff --git a/utils/Cargo.toml b/utils/Cargo.toml index 3532a108..7baacee9 100644 --- a/utils/Cargo.toml +++ b/utils/Cargo.toml @@ -9,6 +9,7 @@ path = "src/lib.rs" [dependencies] merklehash = { path = "../merklehash" } +xet_threadpool = { path = "../xet_threadpool" } thiserror = "2.0" futures = "0.3.28" diff --git a/utils/src/lib.rs b/utils/src/lib.rs index 08a6d2ab..cafb5554 100644 --- a/utils/src/lib.rs +++ b/utils/src/lib.rs @@ -4,8 +4,6 @@ pub mod auth; pub mod errors; pub mod serialization_utils; pub mod singleflight; -pub mod threadpool; -pub use threadpool::ThreadPool; mod async_read; mod output_bytes; diff --git a/utils/src/singleflight.rs b/utils/src/singleflight.rs index 92c64cf8..3aeed804 100644 --- a/utils/src/singleflight.rs +++ b/utils/src/singleflight.rs @@ -10,6 +10,7 @@ //! //! use futures::future::join_all; //! use utils::singleflight::Group; +//! use xet_threadpool; //! //! const RES: usize = 7; //! @@ -20,7 +21,7 @@ //! //! #[tokio::main] //! async fn main() { -//! let threadpool = Arc::new(utils::ThreadPool::new().unwrap()); +//! let threadpool = Arc::new(xet_threadpool::ThreadPool::new().unwrap()); //! let g = Arc::new(Group::<_, ()>::new(threadpool.clone())); //! let mut handlers = Vec::new(); //! for _ in 0..10 { @@ -50,9 +51,9 @@ use parking_lot::RwLock; use pin_project::{pin_project, pinned_drop}; use tokio::sync::{Mutex, Notify}; use tracing::debug; +use xet_threadpool::ThreadPool; pub use crate::errors::SingleflightError; -use crate::ThreadPool; type SingleflightResult = Result>; type CallMap = HashMap>>; @@ -366,11 +367,11 @@ mod tests { use tokio::sync::{Mutex, Notify}; use tokio::task::JoinHandle; use tokio::time::timeout; + use xet_threadpool::ThreadPool; use super::Group; use crate::errors::SingleflightError; use crate::singleflight::{Call, OwnerTask}; - use crate::ThreadPool; /// A period of time for waiters to wait for a notification from the owner /// task. This is expected to be sufficient time for the test futures to diff --git a/xet_threadpool/Cargo.toml b/xet_threadpool/Cargo.toml new file mode 100644 index 00000000..663ef0ef --- /dev/null +++ b/xet_threadpool/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "xet_threadpool" +version = "0.1.0" +edition = "2024" + +[dependencies] +tokio = { version = "1.41", features = ["full"] } +thiserror = "2.0" +tracing = "0.1.31" +lazy_static = "1" diff --git a/xet_threadpool/src/errors.rs b/xet_threadpool/src/errors.rs new file mode 100644 index 00000000..e45fc302 --- /dev/null +++ b/xet_threadpool/src/errors.rs @@ -0,0 +1,18 @@ +use thiserror::Error; + +/// Define an error time for spawning external threads. +#[derive(Debug, Error)] +#[non_exhaustive] +pub enum MultithreadedRuntimeError { + #[error("Error Initializing Multithreaded Runtime: {0:?}")] + RuntimeInitializationError(std::io::Error), + + #[error("Task Panic: {0:?}.")] + TaskPanic(tokio::task::JoinError), + + #[error("Task cancelled; possible runtime shutdown in progress ({0}).")] + TaskCanceled(String), + + #[error("Unknown task runtime error: {0}")] + Other(String), +} diff --git a/xet_threadpool/src/lib.rs b/xet_threadpool/src/lib.rs new file mode 100644 index 00000000..52d5e84b --- /dev/null +++ b/xet_threadpool/src/lib.rs @@ -0,0 +1,4 @@ +pub mod errors; +pub mod threadpool; + +pub use threadpool::ThreadPool; diff --git a/utils/src/threadpool.rs b/xet_threadpool/src/threadpool.rs similarity index 93% rename from utils/src/threadpool.rs rename to xet_threadpool/src/threadpool.rs index 861b770c..743ed325 100644 --- a/utils/src/threadpool.rs +++ b/xet_threadpool/src/threadpool.rs @@ -3,7 +3,6 @@ use std::future::Future; use std::sync::atomic::Ordering::SeqCst; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; -use thiserror::Error; /// This module provides a simple wrapper around Tokio's runtime to create a thread pool /// with some default settings. It is intended to be used as a singleton thread pool for /// the entire application. @@ -14,7 +13,7 @@ use thiserror::Error; /// # Example /// /// ```rust -/// use utils::ThreadPool; +/// use xet_threadpool::ThreadPool; /// /// let pool = ThreadPool::new().expect("Error initializing runtime."); /// @@ -51,31 +50,16 @@ use thiserror::Error; /// /// - `new_threadpool`: Creates a new Tokio runtime with the specified settings. use tokio; -use tokio::task::{JoinError, JoinHandle}; +use tokio::task::JoinHandle; use tracing::{debug, error}; +use crate::errors::MultithreadedRuntimeError; + const THREADPOOL_NUM_WORKER_THREADS: usize = 4; // 4 active threads const THREADPOOL_THREAD_ID_PREFIX: &str = "hf-xet"; // thread names will be hf-xet-0, hf-xet-1, etc. const THREADPOOL_STACK_SIZE: usize = 8_000_000; // 8MB stack size const THREADPOOL_MAX_BLOCKING_THREADS: usize = 100; // max 100 threads can block IO -/// Define an error time for spawning external threads. -#[derive(Debug, Error)] -#[non_exhaustive] -pub enum MultithreadedRuntimeError { - #[error("Error Initializing Multithreaded Runtime: {0:?}")] - RuntimeInitializationError(std::io::Error), - - #[error("Task Panic: {0:?}.")] - TaskPanic(JoinError), - - #[error("Task cancelled; possible runtime shutdown in progress ({0}).")] - TaskCanceled(String), - - #[error("Unknown task runtime error: {0}")] - Other(String), -} - #[derive(Debug)] pub struct ThreadPool { // This has to allow for exclusive access to enable shutdown when