Skip to content

Commit

Permalink
Add support for MonadCatch
Browse files Browse the repository at this point in the history
- Add support for Catch in IOSim and IOSimPOR
- Add support for Catch in Test/STM.hs

Co-authored-by: Marcin Szamotulski <[email protected]>
  • Loading branch information
yogeshsajanikar and coot committed Oct 13, 2022
1 parent 6b81d7c commit d073a5b
Show file tree
Hide file tree
Showing 5 changed files with 155 additions and 42 deletions.
33 changes: 28 additions & 5 deletions io-sim/src/Control/Monad/IOSim/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -926,19 +926,42 @@ execAtomically !time !tid !tlbl !nextVid0 action0 k0 =

ThrowStm e ->
{-# SCC "execAtomically.go.ThrowStm" #-} do
-- Revert all the TVar writes
-- Rollback `TVar`s written since catch handler was installed
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
k0 $ StmTxAborted [] (toException e)
case ctl of
AtomicallyFrame -> do
k0 $ StmTxAborted (Map.elems read) (toException e)

BranchFrame (CatchStmA h) k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
{-# SCC "execAtomically.go.BranchFrame" #-} do
-- Execute the left side in a new frame with an empty written set.
-- Rollback `TVar`s written since catch handler was installed,
-- but preserve ones that were set prior to it, as specified in the
-- [stm](https://hackage.haskell.org/package/stm/docs/Control-Monad-STM.html#v:catchSTM) package.
let ctl'' = BranchFrame NoOpStmA k writtenOuter writtenOuterSeq createdOuterSeq ctl'
go ctl'' read Map.empty [] [] nextVid (h e)
--
BranchFrame _ _k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
{-# SCC "execAtomically.go.BranchFrame" #-} do
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid (ThrowStm e)

CatchStm a h k ->
{-# SCC "execAtomically.go.ThrowStm" #-} do
-- Execute the catch handler with an empty written set.
-- but preserve ones that were set prior to it, as specified in the
-- [stm](https://hackage.haskell.org/package/stm/docs/Control-Monad-STM.html#v:catchSTM) package.
let ctl' = BranchFrame (CatchStmA h) k written writtenSeq createdSeq ctl
go ctl' read Map.empty [] [] nextVid a


Retry ->
{-# SCC "execAtomically.go.Retry" #-}
do
{-# SCC "execAtomically.go.Retry" #-} do
-- Always revert all the TVar writes for the retry
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
case ctl of
AtomicallyFrame -> do
-- Return vars read, so the thread can block on them
k0 $! StmTxBlocked $! (Map.elems read)
k0 $! StmTxBlocked $! Map.elems read

BranchFrame (OrElseStmA b) k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
{-# SCC "execAtomically.go.BranchFrame.OrElseStmA" #-} do
Expand Down
34 changes: 31 additions & 3 deletions io-sim/src/Control/Monad/IOSim/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ runSTM (STM k) = k ReturnStm
data StmA s a where
ReturnStm :: a -> StmA s a
ThrowStm :: SomeException -> StmA s a
CatchStm :: StmA s a -> (SomeException -> StmA s a) -> (a -> StmA s b) -> StmA s b

NewTVar :: Maybe String -> x -> (TVar s x -> StmA s b) -> StmA s b
LabelTVar :: String -> TVar s a -> StmA s b -> StmA s b
Expand Down Expand Up @@ -322,6 +323,32 @@ instance MonadThrow (STM s) where
instance Exceptions.MonadThrow (STM s) where
throwM = MonadThrow.throwIO


instance MonadCatch (STM s) where

catch action handler = STM $ oneShot $ \k -> CatchStm (runSTM action) (runSTM . fromHandler handler) k
where
-- Get a total handler from the given handler
fromHandler :: Exception e => (e -> STM s a) -> SomeException -> STM s a
fromHandler h e = case fromException e of
Nothing -> throwIO e -- Rethrow the exception if handler does not handle it.
Just e' -> h e'

-- STM actions are always run inside `execAtomically` and behave as if masked
-- Another point to note that the default implementation of `generalBracket` needs
-- mask, and is part of `MonadThrow`. For STM, we don't need masking because
-- async exceptions are handled outside of `execAtomically`.
generalBracket acquire release use = do
resource <- acquire
b <- use resource `catch` \e -> do
_ <- release resource (ExitCaseException e)
throwIO e
c <- release resource (ExitCaseSuccess b)
return (b, c)

instance Exceptions.MonadCatch (STM s) where
catch = MonadThrow.catch

instance MonadCatch (IOSim s) where
catch action handler =
IOSim $ oneShot $ \k -> Catch (runIOSim action) (runIOSim . handler) k
Expand Down Expand Up @@ -857,9 +884,10 @@ data StmTxResult s a =
| StmTxAborted [SomeTVar s] SomeException


-- | OrElse/Catch give rise to an alternate right hand side branch. A right branch
-- can be a NoOp
data BranchStmA s a = OrElseStmA (StmA s a) | NoOpStmA
-- | A branch is an alternative of a `OrElse` or a `CatchStm` statement
data BranchStmA s a = OrElseStmA (StmA s a)
| CatchStmA (SomeException -> StmA s a)
| NoOpStmA

data StmStack s b a where
-- | Executing in the context of a top level 'atomically'.
Expand Down
67 changes: 43 additions & 24 deletions io-sim/src/Control/Monad/IOSimPOR/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1174,32 +1174,51 @@ execAtomically time tid tlbl nextVid0 action0 k0 =
{-# SCC "execAtomically.go.ThrowStm" #-} do
-- Revert all the TVar writes
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
k0 $ StmTxAborted (Map.elems read) (toException e)
case ctl of
AtomicallyFrame -> do
k0 $ StmTxAborted (Map.elems read) (toException e)

BranchFrame (CatchStmA h) k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
{-# SCC "execAtomically.go.BranchFrame" #-} do
-- Execute the left side in a new frame with an empty written set.
-- Rollback `TVar`s written since catch handler was installed,
-- but preserve ones that were set prior to it, as specified in the
-- [stm](https://hackage.haskell.org/package/stm/docs/Control-Monad-STM.html#v:catchSTM) package.
let ctl'' = BranchFrame NoOpStmA k writtenOuter writtenOuterSeq createdOuterSeq ctl'
go ctl'' read Map.empty [] [] nextVid (h e)

BranchFrame _ _k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
{-# SCC "execAtomically.go.BranchFrame" #-} do
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid (ThrowStm e)

CatchStm a h k ->
{-# SCC "execAtomically.go.ThrowStm" #-} do
-- Execute the left side in a new frame with an empty written set
let ctl' = BranchFrame (CatchStmA h) k written writtenSeq createdSeq ctl
go ctl' read Map.empty [] [] nextVid a

Retry ->
{-# SCC "execAtomically.go.Retry" #-}
do
-- Always revert all the TVar writes for the retry
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
case ctl of
AtomicallyFrame -> do
-- Return vars read, so the thread can block on them
k0 $! StmTxBlocked $! Map.elems read

BranchFrame (OrElseStmA b) k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
{-# SCC "execAtomically.go.BranchFrame.OrElseStmA" #-} do
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
-- Execute the orElse right hand with an empty written set
let ctl'' = BranchFrame NoOpStmA k writtenOuter writtenOuterSeq createdOuterSeq ctl'
go ctl'' read Map.empty [] [] nextVid b

BranchFrame _ _k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
{-# SCC "execAtomically.go.BranchFrame" #-} do
-- Retry makes sense only within a OrElse context. If it is a branch other than
-- OrElse left side, then bubble up the `retry` to the frame above.
-- Skip the continuation and propagate the retry into the outer frame
-- using the written set for the outer frame
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid Retry
{-# SCC "execAtomically.go.Retry" #-} do
-- Always revert all the TVar writes for the retry
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
case ctl of
AtomicallyFrame -> do
-- Return vars read, so the thread can block on them
k0 $! StmTxBlocked $! Map.elems read

BranchFrame (OrElseStmA b) k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
{-# SCC "execAtomically.go.BranchFrame.OrElseStmA" #-} do
-- Execute the orElse right hand with an empty written set
let ctl'' = BranchFrame NoOpStmA k writtenOuter writtenOuterSeq createdOuterSeq ctl'
go ctl'' read Map.empty [] [] nextVid b

BranchFrame _ _k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
{-# SCC "execAtomically.go.BranchFrame" #-} do
-- Retry makes sense only within a OrElse context. If it is a branch other than
-- OrElse left side, then bubble up the `retry` to the frame above.
-- Skip the continuation and propagate the retry into the outer frame
-- using the written set for the outer frame
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid Retry

OrElse a b k ->
{-# SCC "execAtomically.go.OrElse" #-} do
Expand Down
2 changes: 1 addition & 1 deletion io-sim/test/Test/IOSim.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1221,7 +1221,7 @@ prop_stm_referenceSim t =
-- | Compare the behaviour of the STM reference operational semantics with
-- the behaviour of any 'MonadSTM' STM implementation.
--
prop_stm_referenceM :: (MonadSTM m, MonadThrow (STM m), MonadCatch m)
prop_stm_referenceM :: (MonadSTM m, MonadCatch (STM m), MonadCatch m)
=> SomeTerm -> m Property
prop_stm_referenceM (SomeTerm _tyrep t) = do
let (r1, _heap) = evalAtomically t
Expand Down
61 changes: 52 additions & 9 deletions io-sim/test/Test/STM.hs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ data Term (t :: Type) where

Return :: Expr t -> Term t
Throw :: Expr a -> Term t
Catch :: Term t -> Term t -> Term t
Retry :: Term t

ReadTVar :: Name (TyVar t) -> Term t
Expand Down Expand Up @@ -297,7 +298,7 @@ deriving instance Show (NfTerm t)
-- | The STM transition rules. They reduce a 'Term' to a normal-form 'NfTerm'.
--
-- Compare the implementation of this against the operational semantics in
-- Figure 4 in the paper. Note that @catch@ is not included.
-- Figure 4 in the paper including the `Catch` semantics from the Appendix A.
--
evalTerm :: Env -> Heap -> Allocs -> Term t -> (NfTerm t, Heap, Allocs)
evalTerm !env !heap !allocs term = case term of
Expand All @@ -310,6 +311,30 @@ evalTerm !env !heap !allocs term = case term of
where
e' = evalExpr env e

-- Exception semantics are detailed in "Appendix A Exception semantics" p 12-13 of
-- <https://research.microsoft.com/en-us/um/people/simonpj/papers/stm/stm.pdf>
Catch t1 t2 ->
let (nf1, heap', allocs') = evalTerm env heap mempty t1 in case nf1 of

-- Rule XSTM1
-- M; heap, {} => return P; heap', allocs'
-- --------------------------------------------------------
-- S[catch M N]; heap, allocs => S[return P]; heap', allocs'
NfReturn v -> (NfReturn v, heap', allocs <> allocs')

-- Rule XSTM2
-- M; heap, {} => throw P; heap', allocs'
-- --------------------------------------------------------
-- S[catch M N]; heap, allocs => S[N P]; heap U allocs', allocs U allocs'
NfThrow _ -> evalTerm env (heap <> allocs') (allocs <> allocs') t2

-- Rule XSTM3
-- M; heap, {} => retry; heap', allocs'
-- --------------------------------------------------------
-- S[catch M N]; heap, allocs => S[retry]; heap, allocs
NfRetry -> (NfRetry, heap, allocs)


Retry -> (NfRetry, heap, allocs)

-- Rule READ
Expand Down Expand Up @@ -438,7 +463,7 @@ extendExecEnv (Name name _tyrep) v (ExecEnv env) =

-- | Execute an STM 'Term' in the 'STM' monad.
--
execTerm :: (MonadSTM m, MonadThrow (STM m))
execTerm :: (MonadSTM m, MonadCatch (STM m))
=> ExecEnv m
-> Term t
-> STM m (ExecValue m t)
Expand All @@ -452,6 +477,8 @@ execTerm env t =
let e' = execExpr env e
throwSTM =<< snapshotExecValue e'

Catch t1 t2 -> execTerm env t1 `catch` \(_ :: ImmValue) -> execTerm env t2

Retry -> retry

ReadTVar n -> do
Expand Down Expand Up @@ -492,7 +519,7 @@ snapshotExecValue (ExecValInt x) = return (ImmValInt x)
snapshotExecValue (ExecValVar v _) = fmap ImmValVar
(snapshotExecValue =<< readTVar v)

execAtomically :: forall m t. (MonadSTM m, MonadThrow (STM m), MonadCatch m)
execAtomically :: forall m t. (MonadSTM m, MonadCatch (STM m), MonadCatch m)
=> Term t -> m TxResult
execAtomically t =
toTxResult <$> try (atomically action')
Expand Down Expand Up @@ -658,7 +685,7 @@ genTerm env tyrep =
Nothing)
]

binTerm = frequency [ (2, bindTerm), (1, orElseTerm)]
binTerm = frequency [ (2, bindTerm), (1, orElseTerm), (1, catchTerm)]

bindTerm =
sized $ \sz -> do
Expand All @@ -672,10 +699,15 @@ genTerm env tyrep =
return (Bind t1 name t2)

orElseTerm =
sized $ \sz -> resize (sz `div` 2) $
scale (`div` 2) $
OrElse <$> genTerm env tyrep
<*> genTerm env tyrep

catchTerm =
scale (`div` 2) $
Catch <$> genTerm env tyrep
<*> genTerm env tyrep

genSomeExpr :: GenEnv -> Gen SomeExpr
genSomeExpr env =
oneof'
Expand Down Expand Up @@ -714,6 +746,8 @@ shrinkTerm t =
case t of
Return e -> [Return e' | e' <- shrinkExpr e]
Throw e -> [Throw e' | e' <- shrinkExpr e]
Catch t1 t2 -> [t1, t2]
++ [Catch t1' t2' | (t1', t2') <- liftShrink2 shrinkTerm shrinkTerm (t1, t2)]
Retry -> []
ReadTVar _ -> []

Expand All @@ -722,12 +756,10 @@ shrinkTerm t =
NewTVar e -> [NewTVar e' | e' <- shrinkExpr e]

Bind t1 n t2 -> [ t2 | nameId n `Set.notMember` freeNamesTerm t2 ]
++ [ Bind t1' n t2 | t1' <- shrinkTerm t1 ]
++ [ Bind t1 n t2' | t2' <- shrinkTerm t2 ]
++ [ Bind t1' n t2' | (t1', t2') <- liftShrink2 shrinkTerm shrinkTerm (t1, t2) ]

OrElse t1 t2 -> [t1, t2]
++ [ OrElse t1' t2 | t1' <- shrinkTerm t1 ]
++ [ OrElse t1 t2' | t2' <- shrinkTerm t2 ]
++ [ OrElse t1' t2' | (t1', t2') <- liftShrink2 shrinkTerm shrinkTerm (t1, t2) ]

shrinkExpr :: Expr t -> [Expr t]
shrinkExpr ExprUnit = []
Expand All @@ -739,6 +771,12 @@ shrinkExpr (ExprName (Name _ (TyRepVar _))) = []
freeNamesTerm :: Term t -> Set NameId
freeNamesTerm (Return e) = freeNamesExpr e
freeNamesTerm (Throw e) = freeNamesExpr e
-- A catch handler should actually have an argument, and then the implementation
-- should handle it. But since current implementation of catch never binds the
-- variable, the following implementation is correct as of now. It needs to be
-- tackled once nested exceptions are handled.
-- TODO: Correctly handle free names when the handler also binds a variable.
freeNamesTerm (Catch t1 t2) = freeNamesTerm t1 <> freeNamesTerm t2
freeNamesTerm Retry = Set.empty
freeNamesTerm (ReadTVar n) = Set.singleton (nameId n)
freeNamesTerm (WriteTVar n e) = Set.singleton (nameId n) <> freeNamesExpr e
Expand Down Expand Up @@ -769,6 +807,7 @@ prop_genSomeTerm (SomeTerm tyrep term) =
termSize :: Term a -> Int
termSize Return{} = 1
termSize Throw{} = 1
termSize (Catch a b) = 1 + termSize a + termSize b
termSize Retry{} = 1
termSize ReadTVar{} = 1
termSize WriteTVar{} = 1
Expand All @@ -779,6 +818,7 @@ termSize (OrElse a b) = 1 + termSize a + termSize b
termDepth :: Term a -> Int
termDepth Return{} = 1
termDepth Throw{} = 1
termDepth (Catch a b) = 1 + max (termDepth a) (termDepth b)
termDepth Retry{} = 1
termDepth ReadTVar{} = 1
termDepth WriteTVar{} = 1
Expand All @@ -791,6 +831,9 @@ showTerm p (Return e) = showParen (p > 10) $
showString "return " . showExpr 11 e
showTerm p (Throw e) = showParen (p > 10) $
showString "throwSTM " . showExpr 11 e
showTerm p (Catch t1 t2) = showParen (p > 9) $
showTerm 10 t1 . showString " `catch` "
. showTerm 10 t2
showTerm _ Retry = showString "retry"
showTerm p (ReadTVar n) = showParen (p > 10) $
showString "readTVar " . showName n
Expand Down

0 comments on commit d073a5b

Please sign in to comment.