Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

hack: try removing error normalization #4859

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 13 additions & 176 deletions src/err/err_state.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,3 @@
use std::{
cell::UnsafeCell,
sync::{Mutex, Once},
thread::ThreadId,
};

use crate::{
exceptions::{PyBaseException, PyTypeError},
ffi,
Expand All @@ -12,14 +6,7 @@ use crate::{
Bound, Py, PyAny, PyErrArguments, PyObject, PyTypeInfo, Python,
};

pub(crate) struct PyErrState {
// Safety: can only hand out references when in the "normalized" state. Will never change
// after normalization.
normalized: Once,
// Guard against re-entrancy when normalizing the exception state.
normalizing_thread: Mutex<Option<ThreadId>>,
inner: UnsafeCell<Option<PyErrStateInner>>,
}
pub(crate) struct PyErrState(PyErrStateNormalized);

// Safety: The inner value is protected by locking to ensure that only the normalized state is
// handed out as a reference.
Expand All @@ -30,105 +17,32 @@ unsafe impl crate::marker::Ungil for PyErrState {}

impl PyErrState {
pub(crate) fn lazy(f: Box<PyErrStateLazyFn>) -> Self {
Self::from_inner(PyErrStateInner::Lazy(f))
Self(Python::with_gil(|py| {
PyErrStateInner::Lazy(f).normalize(py)
}))
}

pub(crate) fn lazy_arguments(ptype: Py<PyAny>, args: impl PyErrArguments + 'static) -> Self {
Self::from_inner(PyErrStateInner::Lazy(Box::new(move |py| {
PyErrStateLazyFnOutput {
Self(Python::with_gil(|py| {
PyErrStateInner::Lazy(Box::new(move |py| PyErrStateLazyFnOutput {
ptype,
pvalue: args.arguments(py),
}
})))
}))
.normalize(py)
}))
}

pub(crate) fn normalized(normalized: PyErrStateNormalized) -> Self {
let state = Self::from_inner(PyErrStateInner::Normalized(normalized));
// This state is already normalized, by completing the Once immediately we avoid
// reaching the `py.allow_threads` in `make_normalized` which is less efficient
// and introduces a GIL switch which could deadlock.
// See https://github.com/PyO3/pyo3/issues/4764
state.normalized.call_once(|| {});
state
Self(normalized)
}

pub(crate) fn restore(self, py: Python<'_>) {
self.inner
.into_inner()
.expect("PyErr state should never be invalid outside of normalization")
.restore(py)
}

fn from_inner(inner: PyErrStateInner) -> Self {
Self {
normalized: Once::new(),
normalizing_thread: Mutex::new(None),
inner: UnsafeCell::new(Some(inner)),
}
PyErrStateInner::Normalized(self.0).restore(py)
}

#[inline]
pub(crate) fn as_normalized(&self, py: Python<'_>) -> &PyErrStateNormalized {
if self.normalized.is_completed() {
match unsafe {
// Safety: self.inner will never be written again once normalized.
&*self.inner.get()
} {
Some(PyErrStateInner::Normalized(n)) => return n,
_ => unreachable!(),
}
}

self.make_normalized(py)
}

#[cold]
fn make_normalized(&self, py: Python<'_>) -> &PyErrStateNormalized {
// This process is safe because:
// - Access is guaranteed not to be concurrent thanks to `Python` GIL token
// - Write happens only once, and then never will change again.

// Guard against re-entrant normalization, because `Once` does not provide
// re-entrancy guarantees.
if let Some(thread) = self.normalizing_thread.lock().unwrap().as_ref() {
assert!(
!(*thread == std::thread::current().id()),
"Re-entrant normalization of PyErrState detected"
);
}

// avoid deadlock of `.call_once` with the GIL
py.allow_threads(|| {
self.normalized.call_once(|| {
self.normalizing_thread
.lock()
.unwrap()
.replace(std::thread::current().id());

// Safety: no other thread can access the inner value while we are normalizing it.
let state = unsafe {
(*self.inner.get())
.take()
.expect("Cannot normalize a PyErr while already normalizing it.")
};

let normalized_state =
Python::with_gil(|py| PyErrStateInner::Normalized(state.normalize(py)));

// Safety: no other thread can access the inner value while we are normalizing it.
unsafe {
*self.inner.get() = Some(normalized_state);
}
})
});

match unsafe {
// Safety: self.inner will never be written again once normalized.
&*self.inner.get()
} {
Some(PyErrStateInner::Normalized(n)) => n,
_ => unreachable!(),
}
pub(crate) fn as_normalized(&self, _py: Python<'_>) -> &PyErrStateNormalized {
&self.0
}
}

Expand Down Expand Up @@ -361,80 +275,3 @@ fn raise_lazy(py: Python<'_>, lazy: Box<PyErrStateLazyFn>) {
}
}
}

#[cfg(test)]
mod tests {

use crate::{
exceptions::PyValueError, sync::GILOnceCell, PyErr, PyErrArguments, PyObject, Python,
};

#[test]
#[should_panic(expected = "Re-entrant normalization of PyErrState detected")]
fn test_reentrant_normalization() {
static ERR: GILOnceCell<PyErr> = GILOnceCell::new();

struct RecursiveArgs;

impl PyErrArguments for RecursiveArgs {
fn arguments(self, py: Python<'_>) -> PyObject {
// .value(py) triggers normalization
ERR.get(py)
.expect("is set just below")
.value(py)
.clone()
.into()
}
}

Python::with_gil(|py| {
ERR.set(py, PyValueError::new_err(RecursiveArgs)).unwrap();
ERR.get(py).expect("is set just above").value(py);
})
}

#[test]
#[cfg(not(target_arch = "wasm32"))] // We are building wasm Python with pthreads disabled
fn test_no_deadlock_thread_switch() {
static ERR: GILOnceCell<PyErr> = GILOnceCell::new();

struct GILSwitchArgs;

impl PyErrArguments for GILSwitchArgs {
fn arguments(self, py: Python<'_>) -> PyObject {
// releasing the GIL potentially allows for other threads to deadlock
// with the normalization going on here
py.allow_threads(|| {
std::thread::sleep(std::time::Duration::from_millis(10));
});
py.None()
}
}

Python::with_gil(|py| ERR.set(py, PyValueError::new_err(GILSwitchArgs)).unwrap());

// Let many threads attempt to read the normalized value at the same time
let handles = (0..10)
.map(|_| {
std::thread::spawn(|| {
Python::with_gil(|py| {
ERR.get(py).expect("is set just above").value(py);
});
})
})
.collect::<Vec<_>>();

for handle in handles {
handle.join().unwrap();
}

// We should never have deadlocked, and should be able to run
// this assertion
Python::with_gil(|py| {
assert!(ERR
.get(py)
.expect("is set above")
.is_instance_of::<PyValueError>(py))
});
}
}
Loading