Skip to content

Commit

Permalink
Used Thread.interrupt() to implement cancellation
Browse files Browse the repository at this point in the history
  • Loading branch information
rcardin committed May 23, 2024
1 parent f2a9631 commit ec87580
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 62 deletions.
72 changes: 12 additions & 60 deletions core/src/main/scala/in/rcard/sus4s/sus4s.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package in.rcard.sus4s

import java.util.concurrent.StructuredTaskScope.Subtask
import java.util.concurrent.{CompletableFuture, ExecutionException, StructuredTaskScope}
import scala.compiletime.uninitialized
import java.util.concurrent.{CompletableFuture, StructuredTaskScope}

object sus4s {

Expand All @@ -22,18 +20,14 @@ object sus4s {
* @tparam A
* The type of the value returned by the job
*/
class Job[A] private[sus4s] (private val cf: CompletableFuture[A]) {
class Job[A] private[sus4s] (
private val cf: CompletableFuture[A],
private val executingThread: CompletableFuture[Thread]
) {
def value: A = cf.get()
}

class CancellableJob[A] private[sus4s] (val cf: CompletableFuture[A]) extends Job[A](cf) {
var scope: StructuredTaskScope[Any] = uninitialized
def cancel(): Unit = {
executingThread.get().interrupt()
cf.completeExceptionally(new InterruptedException("Job cancelled"))
try
cf.get()
catch
case e =>
}
}

Expand Down Expand Up @@ -122,60 +116,18 @@ object sus4s {
* [[structured]]
*/
def fork[A](block: Suspend ?=> A): Suspend ?=> Job[A] = {
val result = new CompletableFuture[A]()
val result = new CompletableFuture[A]()
val executingThread = new CompletableFuture[Thread]()
summon[Suspend].scope.fork(() => {
executingThread.complete(Thread.currentThread())
try result.complete(block)
catch
case _: InterruptedException =>
result.completeExceptionally(new InterruptedException("Job cancelled"))
case throwable: Throwable =>
result.completeExceptionally(throwable)
throw throwable;
})
Job(result)
}

def forkCancellable[A](block: Suspend ?=> A): Suspend ?=> CancellableJob[A] = {
val result = new CompletableFuture[A]()
val innerResult = new CompletableFuture[A]()
val cancellableJob = CancellableJob(innerResult)
summon[Suspend].scope.fork(() => {
val innerScope = new StructuredTaskScope.ShutdownOnFailure()
cancellableJob.scope = innerScope

given innerSuspended: Suspend = new Suspend {
override val scope: StructuredTaskScope[Any] = innerScope
}
try {
val subtask = innerScope.fork(() => {
val result = block(using innerSuspended)
innerResult.complete(result)
result
})
try
innerResult.get()
catch
case e: Throwable =>
innerScope.shutdown()
innerScope.join().throwIfFailed(identity)
if (subtask.state() == Subtask.State.UNAVAILABLE) {
// TODO Handle all cases
result.completeExceptionally(new InterruptedException("Job cancelled"))
} else {
result.complete(subtask.get())
}
} catch
case exex: ExecutionException =>
exex.getCause match
case ie: InterruptedException => innerScope.shutdown()
case e: Throwable =>
result.completeExceptionally(e)
throw e
case e: Throwable =>
result.completeExceptionally(e)
throw e
finally {
innerScope.close()
}
})
cancellableJob
Job(result, executingThread)
}
}
25 changes: 23 additions & 2 deletions core/src/test/scala/StructuredSpec.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import in.rcard.sus4s.sus4s
import in.rcard.sus4s.sus4s.{fork, forkCancellable, structured}
import in.rcard.sus4s.sus4s.{fork, structured}
import org.scalatest.TryValues.*
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
Expand Down Expand Up @@ -100,17 +100,38 @@ class StructuredSpec extends AnyFlatSpec with Matchers {
result shouldBe 85
}

it should "wait for children jobs to finish" in {
val results = structured {
val queue = new ConcurrentLinkedQueue[String]()
val job1 = fork {
fork {
Thread.sleep(1000)
queue.add("1")
}
fork {
Thread.sleep(500)
queue.add("2")
}
queue.add("3")
}
queue
}

results.toArray should contain theSameElementsInOrderAs List("3", "2", "1")
}

it should "cancel at the first suspending point" in {
val queue = new ConcurrentLinkedQueue[String]()
val result = structured {
val cancellable = forkCancellable {
val cancellable = fork {
while (true) {
Thread.sleep(2000)
println("cancellable job")
queue.add("cancellable")
}
}
val job = fork {
Thread.sleep(500)
cancellable.cancel()
queue.add("job2")
43
Expand Down

0 comments on commit ec87580

Please sign in to comment.