Skip to content

Commit

Permalink
Change AsyncThread<A, R> to AsyncThread<R>.
Browse files Browse the repository at this point in the history
Push arguments in `Thread::into_async()` to the thread during the call instead of first poll.
The pushed arguments will be automatically used on resume.
Fixes #508 and relates to #500.
  • Loading branch information
khvzak committed Jan 11, 2025
1 parent cd4091f commit 10b9e37
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 83 deletions.
6 changes: 3 additions & 3 deletions src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,10 +161,10 @@ impl Function {
{
let lua = self.0.lua.lock();
let thread_res = unsafe {
lua.create_recycled_thread(self).map(|th| {
let mut th = th.into_async(args);
lua.create_recycled_thread(self).and_then(|th| {
let mut th = th.into_async(args)?;
th.set_recyclable(true);
th
Ok(th)
})
};
async move { thread_res?.await }
Expand Down
197 changes: 120 additions & 77 deletions src/thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,33 @@ pub enum ThreadStatus {
Error,
}

/// Internal representation of a Lua thread status.
///
/// The number in `New` and `Yielded` variants is the number of arguments pushed
/// to the thread stack.
#[derive(Clone, Copy)]
enum ThreadStatusInner {
New(c_int),
Running,
Yielded(c_int),
Finished,
Error,
}

impl ThreadStatusInner {
#[cfg(feature = "async")]
#[inline(always)]
fn is_resumable(self) -> bool {
matches!(self, ThreadStatusInner::New(_) | ThreadStatusInner::Yielded(_))
}

#[cfg(feature = "async")]
#[inline(always)]
fn is_yielded(self) -> bool {
matches!(self, ThreadStatusInner::Yielded(_))
}
}

/// Handle to an internal Lua thread (coroutine).
#[derive(Clone)]
pub struct Thread(pub(crate) ValueRef, pub(crate) *mut ffi::lua_State);
Expand All @@ -60,9 +87,8 @@ unsafe impl Sync for Thread {}
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct AsyncThread<A, R> {
pub struct AsyncThread<R> {
thread: Thread,
init_args: Option<A>,
ret: PhantomData<R>,
recycle: bool,
}
Expand Down Expand Up @@ -122,17 +148,25 @@ impl Thread {
R: FromLuaMulti,
{
let lua = self.0.lua.lock();
if self.status_inner(&lua) != ThreadStatus::Resumable {
return Err(Error::CoroutineUnresumable);
}
let mut pushed_nargs = match self.status_inner(&lua) {
ThreadStatusInner::New(nargs) | ThreadStatusInner::Yielded(nargs) => nargs,
_ => return Err(Error::CoroutineUnresumable),
};

let state = lua.state();
let thread_state = self.state();
unsafe {
let _sg = StackGuard::new(state);
let _thread_sg = StackGuard::with_top(thread_state, 0);

let nresults = self.resume_inner(&lua, args)?;
let nargs = args.push_into_stack_multi(&lua)?;
if nargs > 0 {
check_stack(thread_state, nargs)?;
ffi::lua_xmove(state, thread_state, nargs);
pushed_nargs += nargs;
}

let (_, nresults) = self.resume_inner(&lua, pushed_nargs)?;
check_stack(state, nresults + 1)?;
ffi::lua_xmove(thread_state, state, nresults);

Expand All @@ -143,50 +177,50 @@ impl Thread {
/// Resumes execution of this thread.
///
/// It's similar to `resume()` but leaves `nresults` values on the thread stack.
unsafe fn resume_inner(&self, lua: &RawLua, args: impl IntoLuaMulti) -> Result<c_int> {
unsafe fn resume_inner(&self, lua: &RawLua, nargs: c_int) -> Result<(ThreadStatusInner, c_int)> {
let state = lua.state();
let thread_state = self.state();

let nargs = args.push_into_stack_multi(lua)?;
if nargs > 0 {
check_stack(thread_state, nargs)?;
ffi::lua_xmove(state, thread_state, nargs);
}

let mut nresults = 0;
let ret = ffi::lua_resume(thread_state, state, nargs, &mut nresults as *mut c_int);
if ret != ffi::LUA_OK && ret != ffi::LUA_YIELD {
if ret == ffi::LUA_ERRMEM {
match ret {
ffi::LUA_OK => Ok((ThreadStatusInner::Finished, nresults)),
ffi::LUA_YIELD => Ok((ThreadStatusInner::Yielded(0), nresults)),
ffi::LUA_ERRMEM => {
// Don't call error handler for memory errors
return Err(pop_error(thread_state, ret));
Err(pop_error(thread_state, ret))
}
_ => {
check_stack(state, 3)?;
protect_lua!(state, 0, 1, |state| error_traceback_thread(state, thread_state))?;
Err(pop_error(state, ret))
}
check_stack(state, 3)?;
protect_lua!(state, 0, 1, |state| error_traceback_thread(state, thread_state))?;
return Err(pop_error(state, ret));
}

Ok(nresults)
}

/// Gets the status of the thread.
pub fn status(&self) -> ThreadStatus {
self.status_inner(&self.0.lua.lock())
match self.status_inner(&self.0.lua.lock()) {
ThreadStatusInner::New(_) | ThreadStatusInner::Yielded(_) => ThreadStatus::Resumable,
ThreadStatusInner::Running => ThreadStatus::Running,
ThreadStatusInner::Finished => ThreadStatus::Finished,
ThreadStatusInner::Error => ThreadStatus::Error,
}
}

/// Gets the status of the thread (internal implementation).
pub(crate) fn status_inner(&self, lua: &RawLua) -> ThreadStatus {
fn status_inner(&self, lua: &RawLua) -> ThreadStatusInner {
let thread_state = self.state();
if thread_state == lua.state() {
// The thread is currently running
return ThreadStatus::Running;
return ThreadStatusInner::Running;
}
let status = unsafe { ffi::lua_status(thread_state) };
if status != ffi::LUA_OK && status != ffi::LUA_YIELD {
ThreadStatus::Error
} else if status == ffi::LUA_YIELD || unsafe { ffi::lua_gettop(thread_state) > 0 } {
ThreadStatus::Resumable
} else {
ThreadStatus::Finished
let top = unsafe { ffi::lua_gettop(thread_state) };
match status {
ffi::LUA_YIELD => ThreadStatusInner::Yielded(top),
ffi::LUA_OK if top > 0 => ThreadStatusInner::New(top - 1),
ffi::LUA_OK => ThreadStatusInner::Finished,
_ => ThreadStatusInner::Error,
}
}

Expand Down Expand Up @@ -224,7 +258,7 @@ impl Thread {
#[cfg_attr(docsrs, doc(cfg(any(feature = "lua54", feature = "luau"))))]
pub fn reset(&self, func: crate::function::Function) -> Result<()> {
let lua = self.0.lua.lock();
if self.status_inner(&lua) == ThreadStatus::Running {
if matches!(self.status_inner(&lua), ThreadStatusInner::Running) {
return Err(Error::runtime("cannot reset a running thread"));
}

Expand Down Expand Up @@ -257,7 +291,9 @@ impl Thread {

/// Converts [`Thread`] to an [`AsyncThread`] which implements [`Future`] and [`Stream`] traits.
///
/// `args` are passed as arguments to the thread function for first call.
/// Only resumable threads can be converted to [`AsyncThread`].
///
/// `args` are pushed to the thread stack and will be used when the thread is resumed.
/// The object calls [`resume`] while polling and also allow to run Rust futures
/// to completion using an executor.
///
Expand Down Expand Up @@ -290,7 +326,7 @@ impl Thread {
/// end)
/// "#).eval()?;
///
/// let mut stream = thread.into_async::<i64>(1);
/// let mut stream = thread.into_async::<i64>(1)?;
/// let mut sum = 0;
/// while let Some(n) = stream.try_next().await? {
/// sum += n;
Expand All @@ -303,15 +339,31 @@ impl Thread {
/// ```
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
pub fn into_async<R>(self, args: impl IntoLuaMulti) -> AsyncThread<impl IntoLuaMulti, R>
pub fn into_async<R>(self, args: impl IntoLuaMulti) -> Result<AsyncThread<R>>
where
R: FromLuaMulti,
{
AsyncThread {
thread: self,
init_args: Some(args),
ret: PhantomData,
recycle: false,
let lua = self.0.lua.lock();
if !self.status_inner(&lua).is_resumable() {
return Err(Error::CoroutineUnresumable);
}

let state = lua.state();
let thread_state = self.state();
unsafe {
let _sg = StackGuard::new(state);

let nargs = args.push_into_stack_multi(&lua)?;
if nargs > 0 {
check_stack(thread_state, nargs)?;
ffi::lua_xmove(state, thread_state, nargs);
}

Ok(AsyncThread {
thread: self,
ret: PhantomData,
recycle: false,
})
}
}

Expand Down Expand Up @@ -392,7 +444,7 @@ impl LuaType for Thread {
}

#[cfg(feature = "async")]
impl<A, R> AsyncThread<A, R> {
impl<R> AsyncThread<R> {
#[inline]
pub(crate) fn set_recyclable(&mut self, recyclable: bool) {
self.recycle = recyclable;
Expand All @@ -401,15 +453,15 @@ impl<A, R> AsyncThread<A, R> {

#[cfg(feature = "async")]
#[cfg(any(feature = "lua54", feature = "luau"))]
impl<A, R> Drop for AsyncThread<A, R> {
impl<R> Drop for AsyncThread<R> {
fn drop(&mut self) {
if self.recycle {
if let Some(lua) = self.thread.0.lua.try_lock() {
unsafe {
// For Lua 5.4 this also closes all pending to-be-closed variables
if !lua.recycle_thread(&mut self.thread) {
#[cfg(feature = "lua54")]
if self.thread.status_inner(&lua) == ThreadStatus::Error {
if matches!(self.thread.status_inner(&lua), ThreadStatusInner::Error) {
#[cfg(not(feature = "vendored"))]
ffi::lua_resetthread(self.thread.state());
#[cfg(feature = "vendored")]
Expand All @@ -423,14 +475,15 @@ impl<A, R> Drop for AsyncThread<A, R> {
}

#[cfg(feature = "async")]
impl<A: IntoLuaMulti, R: FromLuaMulti> Stream for AsyncThread<A, R> {
impl<R: FromLuaMulti> Stream for AsyncThread<R> {
type Item = Result<R>;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let lua = self.thread.0.lua.lock();
if self.thread.status_inner(&lua) != ThreadStatus::Resumable {
return Poll::Ready(None);
}
let nargs = match self.thread.status_inner(&lua) {
ThreadStatusInner::New(nargs) | ThreadStatusInner::Yielded(nargs) => nargs,
_ => return Poll::Ready(None),
};

let state = lua.state();
let thread_state = self.thread.state();
Expand All @@ -439,36 +492,34 @@ impl<A: IntoLuaMulti, R: FromLuaMulti> Stream for AsyncThread<A, R> {
let _thread_sg = StackGuard::with_top(thread_state, 0);
let _wg = WakerGuard::new(&lua, cx.waker());

// This is safe as we are not moving the whole struct
let this = self.get_unchecked_mut();
let nresults = if let Some(args) = this.init_args.take() {
this.thread.resume_inner(&lua, args)?
} else {
this.thread.resume_inner(&lua, ())?
};
let (status, nresults) = (self.thread).resume_inner(&lua, nargs)?;

if nresults == 1 && is_poll_pending(thread_state) {
return Poll::Pending;
if status.is_yielded() {
if nresults == 1 && is_poll_pending(thread_state) {
return Poll::Pending;
}
// Continue polling
cx.waker().wake_by_ref();
}

check_stack(state, nresults + 1)?;
ffi::lua_xmove(thread_state, state, nresults);

cx.waker().wake_by_ref();
Poll::Ready(Some(R::from_stack_multi(nresults, &lua)))
}
}
}

#[cfg(feature = "async")]
impl<A: IntoLuaMulti, R: FromLuaMulti> Future for AsyncThread<A, R> {
impl<R: FromLuaMulti> Future for AsyncThread<R> {
type Output = Result<R>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let lua = self.thread.0.lua.lock();
if self.thread.status_inner(&lua) != ThreadStatus::Resumable {
return Poll::Ready(Err(Error::CoroutineUnresumable));
}
let nargs = match self.thread.status_inner(&lua) {
ThreadStatusInner::New(nargs) | ThreadStatusInner::Yielded(nargs) => nargs,
_ => return Poll::Ready(Err(Error::CoroutineUnresumable)),
};

let state = lua.state();
let thread_state = self.thread.state();
Expand All @@ -477,21 +528,13 @@ impl<A: IntoLuaMulti, R: FromLuaMulti> Future for AsyncThread<A, R> {
let _thread_sg = StackGuard::with_top(thread_state, 0);
let _wg = WakerGuard::new(&lua, cx.waker());

// This is safe as we are not moving the whole struct
let this = self.get_unchecked_mut();
let nresults = if let Some(args) = this.init_args.take() {
this.thread.resume_inner(&lua, args)?
} else {
this.thread.resume_inner(&lua, ())?
};

if nresults == 1 && is_poll_pending(thread_state) {
return Poll::Pending;
}
let (status, nresults) = self.thread.resume_inner(&lua, nargs)?;

if ffi::lua_status(thread_state) == ffi::LUA_YIELD {
// Ignore value returned via yield()
cx.waker().wake_by_ref();
if status.is_yielded() {
if !(nresults == 1 && is_poll_pending(thread_state)) {
// Ignore value returned via yield()
cx.waker().wake_by_ref();
}
return Poll::Pending;
}

Expand Down Expand Up @@ -545,7 +588,7 @@ mod assertions {
#[cfg(feature = "send")]
static_assertions::assert_impl_all!(Thread: Send, Sync);
#[cfg(all(feature = "async", not(feature = "send")))]
static_assertions::assert_not_impl_any!(AsyncThread<(), ()>: Send);
static_assertions::assert_not_impl_any!(AsyncThread<()>: Send);
#[cfg(all(feature = "async", feature = "send"))]
static_assertions::assert_impl_all!(AsyncThread<(), ()>: Send, Sync);
static_assertions::assert_impl_all!(AsyncThread<()>: Send, Sync);
}
Loading

0 comments on commit 10b9e37

Please sign in to comment.