Skip to content

Commit

Permalink
hack: try removing error normalization
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt committed Jan 15, 2025
1 parent ad5f6d4 commit 4d2bf73
Showing 1 changed file with 13 additions and 176 deletions.
189 changes: 13 additions & 176 deletions src/err/err_state.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,3 @@
use std::{
cell::UnsafeCell,
sync::{Mutex, Once},
thread::ThreadId,
};

use crate::{
exceptions::{PyBaseException, PyTypeError},
ffi,
Expand All @@ -12,14 +6,7 @@ use crate::{
Bound, Py, PyAny, PyErrArguments, PyObject, PyTypeInfo, Python,
};

pub(crate) struct PyErrState {
// Safety: can only hand out references when in the "normalized" state. Will never change
// after normalization.
normalized: Once,
// Guard against re-entrancy when normalizing the exception state.
normalizing_thread: Mutex<Option<ThreadId>>,
inner: UnsafeCell<Option<PyErrStateInner>>,
}
pub(crate) struct PyErrState(PyErrStateNormalized);

// Safety: The inner value is protected by locking to ensure that only the normalized state is
// handed out as a reference.
Expand All @@ -30,105 +17,32 @@ unsafe impl crate::marker::Ungil for PyErrState {}

impl PyErrState {
pub(crate) fn lazy(f: Box<PyErrStateLazyFn>) -> Self {
Self::from_inner(PyErrStateInner::Lazy(f))
Self(Python::with_gil(|py| {
PyErrStateInner::Lazy(f).normalize(py)
}))
}

pub(crate) fn lazy_arguments(ptype: Py<PyAny>, args: impl PyErrArguments + 'static) -> Self {
Self::from_inner(PyErrStateInner::Lazy(Box::new(move |py| {
PyErrStateLazyFnOutput {
Self(Python::with_gil(|py| {
PyErrStateInner::Lazy(Box::new(move |py| PyErrStateLazyFnOutput {
ptype,
pvalue: args.arguments(py),
}
})))
}))
.normalize(py)
}))
}

pub(crate) fn normalized(normalized: PyErrStateNormalized) -> Self {
let state = Self::from_inner(PyErrStateInner::Normalized(normalized));
// This state is already normalized, by completing the Once immediately we avoid
// reaching the `py.allow_threads` in `make_normalized` which is less efficient
// and introduces a GIL switch which could deadlock.
// See https://github.com/PyO3/pyo3/issues/4764
state.normalized.call_once(|| {});
state
Self(normalized)
}

pub(crate) fn restore(self, py: Python<'_>) {
self.inner
.into_inner()
.expect("PyErr state should never be invalid outside of normalization")
.restore(py)
}

fn from_inner(inner: PyErrStateInner) -> Self {
Self {
normalized: Once::new(),
normalizing_thread: Mutex::new(None),
inner: UnsafeCell::new(Some(inner)),
}
PyErrStateInner::Normalized(self.0).restore(py)
}

#[inline]
pub(crate) fn as_normalized(&self, py: Python<'_>) -> &PyErrStateNormalized {
if self.normalized.is_completed() {
match unsafe {
// Safety: self.inner will never be written again once normalized.
&*self.inner.get()
} {
Some(PyErrStateInner::Normalized(n)) => return n,
_ => unreachable!(),
}
}

self.make_normalized(py)
}

#[cold]
fn make_normalized(&self, py: Python<'_>) -> &PyErrStateNormalized {
// This process is safe because:
// - Access is guaranteed not to be concurrent thanks to `Python` GIL token
// - Write happens only once, and then never will change again.

// Guard against re-entrant normalization, because `Once` does not provide
// re-entrancy guarantees.
if let Some(thread) = self.normalizing_thread.lock().unwrap().as_ref() {
assert!(
!(*thread == std::thread::current().id()),
"Re-entrant normalization of PyErrState detected"
);
}

// avoid deadlock of `.call_once` with the GIL
py.allow_threads(|| {
self.normalized.call_once(|| {
self.normalizing_thread
.lock()
.unwrap()
.replace(std::thread::current().id());

// Safety: no other thread can access the inner value while we are normalizing it.
let state = unsafe {
(*self.inner.get())
.take()
.expect("Cannot normalize a PyErr while already normalizing it.")
};

let normalized_state =
Python::with_gil(|py| PyErrStateInner::Normalized(state.normalize(py)));

// Safety: no other thread can access the inner value while we are normalizing it.
unsafe {
*self.inner.get() = Some(normalized_state);
}
})
});

match unsafe {
// Safety: self.inner will never be written again once normalized.
&*self.inner.get()
} {
Some(PyErrStateInner::Normalized(n)) => n,
_ => unreachable!(),
}
pub(crate) fn as_normalized(&self, _py: Python<'_>) -> &PyErrStateNormalized {
&self.0
}
}

Expand Down Expand Up @@ -361,80 +275,3 @@ fn raise_lazy(py: Python<'_>, lazy: Box<PyErrStateLazyFn>) {
}
}
}

#[cfg(test)]
mod tests {

use crate::{
exceptions::PyValueError, sync::GILOnceCell, PyErr, PyErrArguments, PyObject, Python,
};

#[test]
#[should_panic(expected = "Re-entrant normalization of PyErrState detected")]
fn test_reentrant_normalization() {
static ERR: GILOnceCell<PyErr> = GILOnceCell::new();

struct RecursiveArgs;

impl PyErrArguments for RecursiveArgs {
fn arguments(self, py: Python<'_>) -> PyObject {
// .value(py) triggers normalization
ERR.get(py)
.expect("is set just below")
.value(py)
.clone()
.into()
}
}

Python::with_gil(|py| {
ERR.set(py, PyValueError::new_err(RecursiveArgs)).unwrap();
ERR.get(py).expect("is set just above").value(py);
})
}

#[test]
#[cfg(not(target_arch = "wasm32"))] // We are building wasm Python with pthreads disabled
fn test_no_deadlock_thread_switch() {
static ERR: GILOnceCell<PyErr> = GILOnceCell::new();

struct GILSwitchArgs;

impl PyErrArguments for GILSwitchArgs {
fn arguments(self, py: Python<'_>) -> PyObject {
// releasing the GIL potentially allows for other threads to deadlock
// with the normalization going on here
py.allow_threads(|| {
std::thread::sleep(std::time::Duration::from_millis(10));
});
py.None()
}
}

Python::with_gil(|py| ERR.set(py, PyValueError::new_err(GILSwitchArgs)).unwrap());

// Let many threads attempt to read the normalized value at the same time
let handles = (0..10)
.map(|_| {
std::thread::spawn(|| {
Python::with_gil(|py| {
ERR.get(py).expect("is set just above").value(py);
});
})
})
.collect::<Vec<_>>();

for handle in handles {
handle.join().unwrap();
}

// We should never have deadlocked, and should be able to run
// this assertion
Python::with_gil(|py| {
assert!(ERR
.get(py)
.expect("is set above")
.is_instance_of::<PyValueError>(py))
});
}
}

0 comments on commit 4d2bf73

Please sign in to comment.