fix threading issues in inference iterator
bastiscode committed Dec 17, 2024
1 parent a6291cb commit a1c8872
Showing 6 changed files with 70 additions and 52 deletions.
6 changes: 0 additions & 6 deletions Cargo.toml
@@ -26,7 +26,6 @@ unicode-segmentation = "1.10"
unicode-normalization = "0.1"
num = "0.4"
anyhow = "1.0"
num_cpus = "1.14"
numpy = "0.23"
clap = { version = "4", features = ["derive"] }
tokenizers = "0.21"
@@ -38,11 +37,6 @@ criterion = "0.5"
[features]
benchmark-utils = []

[profile.release]
lto = true
codegen-units = 1
strip = true

[[bench]]
required-features = ["benchmark-utils"]
name = "benchmark"
15 changes: 8 additions & 7 deletions python/text_utils/api/processor.py
@@ -77,12 +77,12 @@ def from_pretrained(
):
if model is None:
default = cls.default_model()
assert default is not None, "no default model available"
assert default is not None, "No default model available"
model = default.name

assert model is not None
assert any(model == m.name for m in cls.available_models()), (
f"model {model} does not match any of the available models:\n"
f"Model {model} does not match any of the available models:\n"
f"{pprint.pformat(cls.available_models())}"
)

@@ -104,7 +104,7 @@ def from_pretrained(
)
sub_dirs = os.listdir(zip_dir)
assert len(sub_dirs) == 1, (
f"expected extracted zip for model {model} to contain "
f"Expected extracted zip for model {model} to contain "
f"one subdirectory, but got {len(sub_dirs)}:\n{pprint.pformat(sub_dirs)}"
)
# mark processor as pretrained
@@ -151,7 +151,7 @@ def __init__(
) -> None:
self.cfg = cfg
self.logger = logging.get_logger(self._task_upper())
self.logger.debug(f"got config:\n{self.cfg}")
self.logger.debug(f"Got config:\n{self.cfg}")

torch.set_num_threads(len(os.sched_getaffinity(0)))
torch.use_deterministic_algorithms(False)
@@ -215,6 +215,7 @@ def _process(
for batch in loader:
with torch.inference_mode():
outputs = inference_fn(batch)

for item, output in zip(batch.items(), outputs):
if item.item_idx not in results:
results[item.item_idx] = {}
@@ -245,6 +246,7 @@ def _process(
for batch in loader:
with torch.inference_mode():
outputs = inference_fn(batch)

for item, output in zip(batch.items(), outputs):
if item.item_idx == prev_item_idx:
window_items.append(item)
@@ -254,14 +256,13 @@
yield postprocessing_fn(window_items, window_outputs)
if progress_unit == "byte":
pbar.update(sum(item.window_bytes() for item in window_items))
elif progress_unit == "it":
pbar.update(1)

prev_item_idx = item.item_idx
window_items = [item]
window_outputs = [output]

if progress_unit == "it":
pbar.update(1)

# don't forget to yield the final item
yield postprocessing_fn(window_items, window_outputs)

29 changes: 22 additions & 7 deletions src/data/loading.rs
@@ -311,25 +311,35 @@ where
}
}

enum PipeInner<O> {
Unthreaded(Box<dyn Iterator<Item = O> + Send>),
Threaded(Receiver<O>),
}

pub struct Pipe<O>
where
O: Send + 'static,
{
rx: Receiver<O>,
inner: PipeInner<O>,
}

impl<O> Pipe<O>
where
O: Send + 'static,
{
fn new<I: Send + 'static>(
inner: impl Iterator<Item = I> + Send + 'static,
iter: impl Iterator<Item = I> + Send + 'static,
pipeline: Pipeline<I, O>,
num_threads: u8,
) -> Self {
let num_threads = num_threads.clamp(1, num_cpus::get() as u8);
let inner = Arc::new(Mutex::new(inner.enumerate()));
let (tx, rx) = sync_channel(num_threads as usize);
if num_threads == 0 {
return Pipe {
inner: PipeInner::Unthreaded(Box::new(iter.map(move |item| pipeline(item)))),
};
}
let num_threads = num_threads as usize;
let inner = Arc::new(Mutex::new(iter.enumerate()));
let (tx, rx) = sync_channel(num_threads);
let sent_counter = Arc::new(AtomicUsize::new(0));
panic::set_hook(Box::new(move |info| {
warn!("Thread panicked: {info}");
@@ -364,7 +374,9 @@ where
})
.unwrap_or_else(|_| panic!("failed building worker thread {thread}"));
}
Pipe { rx }
Pipe {
inner: PipeInner::Threaded(rx),
}
}
}

@@ -375,7 +387,10 @@
type Item = O;

fn next(&mut self) -> Option<Self::Item> {
self.rx.recv().ok()
match self.inner {
PipeInner::Unthreaded(ref mut iter) => iter.next(),
PipeInner::Threaded(ref rx) => rx.recv().ok(),
}
}
}

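For context: the new PipeInner either wraps the source iterator directly (num_threads == 0) or receives results from worker threads over a bounded channel. The code below is a minimal, self-contained sketch of that threaded-vs-unthreaded dispatch under assumed names (SketchPipe, Inner, pipeline); unlike the real Pipe it does not preserve input order, install a panic hook, or clamp the thread count.

// Sketch only: illustrative names, no order preservation.
use std::sync::mpsc::{sync_channel, Receiver};
use std::sync::{Arc, Mutex};
use std::thread;

enum Inner<O> {
    // run the pipeline lazily on the calling thread
    Unthreaded(Box<dyn Iterator<Item = O> + Send>),
    // drain results produced by worker threads
    Threaded(Receiver<O>),
}

pub struct SketchPipe<O: Send + 'static> {
    inner: Inner<O>,
}

impl<O: Send + 'static> SketchPipe<O> {
    pub fn new<I: Send + 'static>(
        iter: impl Iterator<Item = I> + Send + 'static,
        pipeline: impl Fn(I) -> O + Send + Sync + 'static,
        num_threads: usize,
    ) -> Self {
        if num_threads == 0 {
            // no worker threads: just map the source iterator in place
            return Self {
                inner: Inner::Unthreaded(Box::new(iter.map(move |item| pipeline(item)))),
            };
        }
        let iter = Arc::new(Mutex::new(iter));
        let pipeline = Arc::new(pipeline);
        // bounded channel so workers cannot run arbitrarily far ahead
        let (tx, rx) = sync_channel(num_threads);
        for _ in 0..num_threads {
            let iter = Arc::clone(&iter);
            let pipeline = Arc::clone(&pipeline);
            let tx = tx.clone();
            thread::spawn(move || loop {
                // take one input while holding the lock, release it before working
                let item = iter.lock().unwrap().next();
                let Some(item) = item else { break };
                if tx.send(pipeline(item)).is_err() {
                    break; // receiver dropped, stop this worker
                }
            });
        }
        Self {
            inner: Inner::Threaded(rx),
        }
    }
}

impl<O: Send + 'static> Iterator for SketchPipe<O> {
    type Item = O;

    fn next(&mut self) -> Option<O> {
        match self.inner {
            Inner::Unthreaded(ref mut iter) => iter.next(),
            // recv() returns Err once all senders are dropped, ending iteration
            Inner::Threaded(ref rx) => rx.recv().ok(),
        }
    }
}
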
68 changes: 38 additions & 30 deletions src/data/mod.rs
@@ -794,13 +794,14 @@ impl Iterator for InferenceIterator {
)))
}
};
match item.extract(py) {
Ok(None) => None,
Ok(Some(item)) => Some(Ok(item)),
Err(e) => Some(Err(anyhow!(
let item = match item.extract(py) {
Ok(None) => return None,
Ok(Some(item)) => Ok(item),
Err(e) => Err(anyhow!(
"error extracting item from inference iterator: {e}"
))),
}
)),
};
Some(item)
})
}
}
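
The restructured next above keeps the usual shape for a fallible iterator: items are Option<Result<T>>, so an exhausted source (None) stays distinct from a failed extraction (Some(Err(..))). A tiny sketch of that shape, using a hypothetical next_item helper:

// Hypothetical helper showing the Option<Result<T>> early-return shape.
fn next_item<T, E>(raw: Result<Option<T>, E>) -> Option<Result<T, E>> {
    let item = match raw {
        Ok(None) => return None,    // source is exhausted: stop iterating
        Ok(Some(item)) => Ok(item), // normal item
        Err(e) => Err(e),           // surface the error instead of ending silently
    };
    Some(item)
}
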
@@ -814,8 +815,8 @@ impl InferenceLoader {
tokenizer,
window,
ignore_special_tokens = false,
num_threads = num_cpus::get() as u8,
buffer_size = 128,
num_threads = 0,
buffer_size = 16,
batch_limit = 16,
batch_limit_type = BatchLimitType::BatchSize,
prefetch_factor = 1,
@@ -851,28 +852,35 @@ impl InferenceLoader {
slf
}

fn __next__(&self) -> anyhow::Result<Option<InferenceBatch>> {
let mut iter = self
.iter
.lock()
.map_err(|e| anyhow!("error locking inference iterator: {e}"))?;
if let Some(batch) = iter.next() {
Ok(Some(InferenceBatch {
len: batch.len(),
batch: Some(batch),
}))
} else {
// check if batch is None because iterator is stopped,
// or because an error was encountered
let err = self
.iter_err
fn __next__(&self, py: Python<'_>) -> anyhow::Result<Option<InferenceBatch>> {
// allow threads is required here because the underlying iterator
// is actually a python iterator that also requests the GIL when calling
// next on it, so we need to release the GIL here, call next on the python
// iterator, and then re-acquire the GIL afterwards
py.allow_threads(move || {
let mut iter = self
.iter
.lock()
.map_err(|e| anyhow!("error locking inference iterator error: {e}"))?;
match err.as_ref() {
Some(e) => Err(anyhow!("error in inference iterator: {e}")),
None => Ok(None),
.map_err(|e| anyhow!("error locking inference iterator: {e}"))?;

if let Some(batch) = iter.next() {
Ok(Some(InferenceBatch {
len: batch.len(),
batch: Some(batch),
}))
} else {
// check if batch is None because iterator is stopped,
// or because an error was encountered
let err = self
.iter_err
.lock()
.map_err(|e| anyhow!("error locking inference iterator error: {e}"))?;
match err.as_ref() {
Some(e) => Err(anyhow!("error in inference iterator: {e}")),
None => Ok(None),
}
}
}
})
}
}

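The comment in the new __next__ is the heart of the fix: the loader's mutex-guarded iterator is ultimately driven by Python code, so blocking on it while holding the GIL can deadlock. Below is a minimal pyo3 sketch of the release-then-reacquire pattern; Loader and its Vec<u64> placeholder are hypothetical stand-ins for the real iterator state, not the crate's API.

use std::sync::{Arc, Mutex};

use pyo3::prelude::*;

// Hypothetical stand-in for the real loader state.
#[pyclass]
struct Loader {
    items: Arc<Mutex<Vec<u64>>>, // placeholder for the locked batch iterator
}

#[pymethods]
impl Loader {
    fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
        slf
    }

    fn __next__(&self, py: Python<'_>) -> Option<u64> {
        // Release the GIL while blocking on the lock (and, in the real code,
        // on the underlying Python-backed iterator); it is re-acquired
        // automatically when the closure returns. Returning None raises
        // StopIteration on the Python side.
        py.allow_threads(|| self.items.lock().ok()?.pop())
    }
}
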
@@ -1018,8 +1026,8 @@ impl TrainLoader {
files,
pipeline,
strategy = GenerationStrategy::Sequential,
num_threads = num_cpus::get() as u8,
buffer_size = 128,
num_threads = 0,
buffer_size = 16,
batch_limit = 16,
batch_limit_type = BatchLimitType::BatchSize,
max_length = 512,
2 changes: 1 addition & 1 deletion src/dictionary.rs
@@ -229,7 +229,7 @@ impl Dictionary {
#[staticmethod]
#[pyo3(
name = "create",
signature = (files, max_size = None, max_sequences = None, num_threads=num_cpus::get() as u8, use_characters = false, char_grams=1, show_progress=false),
signature = (files, max_size = None, max_sequences = None, num_threads=0, use_characters = false, char_grams=1, show_progress=false),
)]
fn create_py(
files: Vec<PyBackedStr>,
2 changes: 1 addition & 1 deletion src/tokenization.rs
@@ -1774,7 +1774,7 @@ fn update_stats(
out_file,
max_lines_per_file = None,
normalization = Normalization::NFKC,
num_threads = num_cpus::get() as u8,
num_threads = 0,
progress = true,
)
)]