diff --git a/Cargo.toml b/Cargo.toml
index ed0bb7c..7fb1206 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -26,7 +26,6 @@ unicode-segmentation = "1.10"
 unicode-normalization = "0.1"
 num = "0.4"
 anyhow = "1.0"
-num_cpus = "1.14"
 numpy = "0.23"
 clap = { version = "4", features = ["derive"] }
 tokenizers = "0.21"
@@ -38,11 +37,6 @@ criterion = "0.5"
 [features]
 benchmark-utils = []
 
-[profile.release]
-lto = true
-codegen-units = 1
-strip = true
-
 [[bench]]
 required-features = ["benchmark-utils"]
 name = "benchmark"
diff --git a/python/text_utils/api/processor.py b/python/text_utils/api/processor.py
index 2b6e869..d54f857 100644
--- a/python/text_utils/api/processor.py
+++ b/python/text_utils/api/processor.py
@@ -77,12 +77,12 @@ def from_pretrained(
     ):
         if model is None:
             default = cls.default_model()
-            assert default is not None, "no default model available"
+            assert default is not None, "No default model available"
             model = default.name

         assert model is not None
         assert any(model == m.name for m in cls.available_models()), (
-            f"model {model} does not match any of the available models:\n"
+            f"Model {model} does not match any of the available models:\n"
            f"{pprint.pformat(cls.available_models())}"
         )

@@ -104,7 +104,7 @@ def from_pretrained(
         )
         sub_dirs = os.listdir(zip_dir)
         assert len(sub_dirs) == 1, (
-            f"expected extracted zip for model {model} to contain "
+            f"Expected extracted zip for model {model} to contain "
             f"one subdirectory, but got {len(sub_dirs)}:\n{pprint.pformat(sub_dirs)}"
         )
         # mark processor as pretrained
@@ -151,7 +151,7 @@ def __init__(
     ) -> None:
         self.cfg = cfg
         self.logger = logging.get_logger(self._task_upper())
-        self.logger.debug(f"got config:\n{self.cfg}")
+        self.logger.debug(f"Got config:\n{self.cfg}")

         torch.set_num_threads(len(os.sched_getaffinity(0)))
         torch.use_deterministic_algorithms(False)
@@ -215,6 +215,7 @@ def _process(
         for batch in loader:
             with torch.inference_mode():
                 outputs = inference_fn(batch)
+
             for item, output in zip(batch.items(), outputs):
                 if item.item_idx not in results:
                     results[item.item_idx] = {}
@@ -245,6 +246,7 @@ def _process(
         for batch in loader:
             with torch.inference_mode():
                 outputs = inference_fn(batch)
+
             for item, output in zip(batch.items(), outputs):
                 if item.item_idx == prev_item_idx:
                     window_items.append(item)
@@ -254,14 +256,13 @@ def _process(
                     yield postprocessing_fn(window_items, window_outputs)
                     if progress_unit == "byte":
                         pbar.update(sum(item.window_bytes() for item in window_items))
+                    elif progress_unit == "it":
+                        pbar.update(1)

                     prev_item_idx = item.item_idx
                     window_items = [item]
                     window_outputs = [output]

-                if progress_unit == "it":
-                    pbar.update(1)
-
         # dont forget to yield final item
         yield postprocessing_fn(window_items, window_outputs)

diff --git a/src/data/loading.rs b/src/data/loading.rs
index 61e90a2..7a8e4b6 100644
--- a/src/data/loading.rs
+++ b/src/data/loading.rs
@@ -311,11 +311,16 @@ where
     }
 }

+enum PipeInner<O> {
+    Unthreaded(Box<dyn Iterator<Item = O> + Send>),
+    Threaded(Receiver<O>),
+}
+
 pub struct Pipe<O>
 where
     O: Send + 'static,
 {
-    rx: Receiver<O>,
+    inner: PipeInner<O>,
 }

 impl<O> Pipe<O>
 where
     O: Send + 'static,
 {
     fn new(
-        inner: impl Iterator + Send + 'static,
+        iter: impl Iterator + Send + 'static,
         pipeline: Pipeline,
         num_threads: u8,
     ) -> Self {
-        let num_threads = num_threads.clamp(1, num_cpus::get() as u8);
-        let inner = Arc::new(Mutex::new(inner.enumerate()));
-        let (tx, rx) = sync_channel(num_threads as usize);
+        if num_threads == 0 {
+            return Pipe {
+                inner: PipeInner::Unthreaded(Box::new(iter.map(move |item| pipeline(item)))),
+            };
+        }
+        let num_threads = num_threads as usize;
+        let inner = Arc::new(Mutex::new(iter.enumerate()));
+        let (tx, rx) = sync_channel(num_threads);
         let sent_counter = Arc::new(AtomicUsize::new(0));
         panic::set_hook(Box::new(move |info| {
             warn!("Thread panicked: {info}");
@@ -364,7 +374,9 @@ where
                 })
                 .unwrap_or_else(|_| panic!("failed building worker thread {thread}"));
         }
-        Pipe { rx }
+        Pipe {
+            inner: PipeInner::Threaded(rx),
+        }
     }
 }

@@ -375,7 +387,10 @@ where
     type Item = O;

     fn next(&mut self) -> Option<Self::Item> {
-        self.rx.recv().ok()
+        match self.inner {
+            PipeInner::Unthreaded(ref mut iter) => iter.next(),
+            PipeInner::Threaded(ref rx) => rx.recv().ok(),
+        }
     }
 }

diff --git a/src/data/mod.rs b/src/data/mod.rs
index 80017e9..1613d9f 100644
--- a/src/data/mod.rs
+++ b/src/data/mod.rs
@@ -794,13 +794,14 @@ impl Iterator for InferenceIterator {
                     )))
                 }
             };
-            match item.extract(py) {
-                Ok(None) => None,
-                Ok(Some(item)) => Some(Ok(item)),
-                Err(e) => Some(Err(anyhow!(
+            let item = match item.extract(py) {
+                Ok(None) => return None,
+                Ok(Some(item)) => Ok(item),
+                Err(e) => Err(anyhow!(
                     "error extracting item from inference iterator: {e}"
-                ))),
-            }
+                )),
+            };
+            Some(item)
         })
     }
 }
@@ -814,8 +815,8 @@ impl InferenceLoader {
         tokenizer,
         window,
         ignore_special_tokens = false,
-        num_threads = num_cpus::get() as u8,
-        buffer_size = 128,
+        num_threads = 0,
+        buffer_size = 16,
         batch_limit = 16,
         batch_limit_type = BatchLimitType::BatchSize,
         prefetch_factor = 1,
@@ -851,28 +852,35 @@ impl InferenceLoader {
         slf
     }

-    fn __next__(&self) -> anyhow::Result<Option<InferenceBatch>> {
-        let mut iter = self
-            .iter
-            .lock()
-            .map_err(|e| anyhow!("error locking inference iterator: {e}"))?;
-        if let Some(batch) = iter.next() {
-            Ok(Some(InferenceBatch {
-                len: batch.len(),
-                batch: Some(batch),
-            }))
-        } else {
-            // check if batch is None because iterator is stopped,
-            // or because an error was encountered
-            let err = self
-                .iter_err
+    fn __next__(&self, py: Python<'_>) -> anyhow::Result<Option<InferenceBatch>> {
+        // allow_threads is required here because the underlying iterator
+        // is actually a Python iterator that also requests the GIL when calling
+        // next on it. So we need to release the GIL here, call next on the Python
+        // iterator, and then re-acquire the GIL afterwards
+        py.allow_threads(move || {
+            let mut iter = self
+                .iter
                 .lock()
-                .map_err(|e| anyhow!("error locking inference iterator error: {e}"))?;
-            match err.as_ref() {
-                Some(e) => Err(anyhow!("error in inference iterator: {e}")),
-                None => Ok(None),
+                .map_err(|e| anyhow!("error locking inference iterator: {e}"))?;
+
+            if let Some(batch) = iter.next() {
+                Ok(Some(InferenceBatch {
+                    len: batch.len(),
+                    batch: Some(batch),
+                }))
+            } else {
+                // check if batch is None because iterator is stopped,
+                // or because an error was encountered
+                let err = self
+                    .iter_err
+                    .lock()
+                    .map_err(|e| anyhow!("error locking inference iterator error: {e}"))?;
+                match err.as_ref() {
+                    Some(e) => Err(anyhow!("error in inference iterator: {e}")),
+                    None => Ok(None),
+                }
             }
-        }
+        })
     }
 }

@@ -1018,8 +1026,8 @@ impl TrainLoader {
         files,
         pipeline,
         strategy = GenerationStrategy::Sequential,
-        num_threads = num_cpus::get() as u8,
-        buffer_size = 128,
+        num_threads = 0,
+        buffer_size = 16,
         batch_limit = 16,
         batch_limit_type = BatchLimitType::BatchSize,
         max_length = 512,
diff --git a/src/dictionary.rs b/src/dictionary.rs
index b7e6acc..e7bfcee 100644
--- a/src/dictionary.rs
+++ b/src/dictionary.rs
@@ -229,7 +229,7 @@ impl Dictionary {
     #[staticmethod]
     #[pyo3(
         name = "create",
-        signature = (files, max_size = None, max_sequences = None, num_threads=num_cpus::get() as u8, use_characters = false, char_grams=1, show_progress=false),
+        signature = (files, max_size = None, max_sequences = None, num_threads=0, use_characters = false, char_grams=1, show_progress=false),
     )]
     fn create_py(
         files: Vec,
diff --git a/src/tokenization.rs b/src/tokenization.rs
index 0e837aa..f1f341b 100644
--- a/src/tokenization.rs
+++ b/src/tokenization.rs
@@ -1774,7 +1774,7 @@ fn update_stats(
         out_file,
         max_lines_per_file = None,
         normalization = Normalization::NFKC,
-        num_threads = num_cpus::get() as u8,
+        num_threads = 0,
         progress = true,
     )
 )]