diff --git a/io-sim/src/Control/Monad/IOSim/Internal.hs b/io-sim/src/Control/Monad/IOSim/Internal.hs index 71f8f87d..2ae72dd6 100644 --- a/io-sim/src/Control/Monad/IOSim/Internal.hs +++ b/io-sim/src/Control/Monad/IOSim/Internal.hs @@ -926,13 +926,33 @@ execAtomically !time !tid !tlbl !nextVid0 action0 k0 = ThrowStm e -> {-# SCC "execAtomically.go.ThrowStm" #-} do - -- Revert all the TVar writes + -- Rollback `TVar`s written since catch handler was installed !_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written - k0 $ StmTxAborted [] (toException e) + case ctl of + AtomicallyFrame -> do + k0 $ StmTxAborted (Map.elems read) (toException e) + + BranchFrame (CatchStmA h) k writtenOuter writtenOuterSeq createdOuterSeq ctl' -> + {-# SCC "execAtomically.go.BranchFrame" #-} do + -- Execute the left side in a new frame with an empty written set + let ctl'' = BranchFrame NoOpStmA k writtenOuter writtenOuterSeq createdOuterSeq ctl' + go ctl'' read Map.empty [] [] nextVid (h e) + -- + BranchFrame _ _k writtenOuter writtenOuterSeq createdOuterSeq ctl' -> + {-# SCC "execAtomically.go.BranchFrame" #-} do + go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid (ThrowStm e) + + CatchStm a h k -> + {-# SCC "execAtomically.go.ThrowStm" #-} do + -- Execute the catch handler with an empty written set. + -- but preserve ones that were set prior to it, as specified in the + -- [stm](https://hackage.haskell.org/package/stm/docs/Control-Monad-STM.html#v:catchSTM) package. + let ctl' = BranchFrame (CatchStmA h) k written writtenSeq createdSeq ctl + go ctl' read Map.empty [] [] nextVid a + Retry -> - {-# SCC "execAtomically.go.Retry" #-} - do + {-# SCC "execAtomically.go.Retry" #-} do -- Always revert all the TVar writes for the retry !_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written case ctl of diff --git a/io-sim/src/Control/Monad/IOSim/Types.hs b/io-sim/src/Control/Monad/IOSim/Types.hs index f56539db..916ce0cb 100644 --- a/io-sim/src/Control/Monad/IOSim/Types.hs +++ b/io-sim/src/Control/Monad/IOSim/Types.hs @@ -184,6 +184,7 @@ runSTM (STM k) = k ReturnStm data StmA s a where ReturnStm :: a -> StmA s a ThrowStm :: SomeException -> StmA s a + CatchStm :: StmA s a -> (SomeException -> StmA s a) -> (a -> StmA s b) -> StmA s b NewTVar :: Maybe String -> x -> (TVar s x -> StmA s b) -> StmA s b LabelTVar :: String -> TVar s a -> StmA s b -> StmA s b @@ -322,6 +323,29 @@ instance MonadThrow (STM s) where instance Exceptions.MonadThrow (STM s) where throwM = MonadThrow.throwIO + +instance MonadCatch (STM s) where + + catch action handler = STM $ oneShot $ \k -> CatchStm (runSTM action) (runSTM . fromHandler handler) k + where + -- Get a total handler from the given handler + fromHandler :: Exception e => (e -> STM s a) -> SomeException -> STM s a + fromHandler h e = case fromException e of + Nothing -> throwIO e -- Rethrow the exception if handler does not handle it. + Just e' -> h e' + + -- No need to consider masking for STM + generalBracket acquire release use = do + resource <- acquire + b <- use resource `catch` \e -> do + _ <- release resource (ExitCaseException e) + throwIO e + c <- release resource (ExitCaseSuccess b) + return (b, c) + +instance Exceptions.MonadCatch (STM s) where + catch = MonadThrow.catch + instance MonadCatch (IOSim s) where catch action handler = IOSim $ oneShot $ \k -> Catch (runIOSim action) (runIOSim . handler) k @@ -857,9 +881,11 @@ data StmTxResult s a = | StmTxAborted [SomeTVar s] SomeException --- | OrElse/Catch give rise to an alternate right hand side branch. A right branch --- can be a NoOp -data BranchStmA s a = OrElseStmA (StmA s a) | NoOpStmA +-- | OrElse/Catch give rise to an alternate branch. +-- A branch of a branch is an empty one. +data BranchStmA s a = OrElseStmA (StmA s a) + | CatchStmA (SomeException -> StmA s a) + | NoOpStmA data StmStack s b a where -- | Executing in the context of a top level 'atomically'. diff --git a/io-sim/src/Control/Monad/IOSimPOR/Internal.hs b/io-sim/src/Control/Monad/IOSimPOR/Internal.hs index 7f244371..c913bda1 100644 --- a/io-sim/src/Control/Monad/IOSimPOR/Internal.hs +++ b/io-sim/src/Control/Monad/IOSimPOR/Internal.hs @@ -1174,32 +1174,47 @@ execAtomically time tid tlbl nextVid0 action0 k0 = {-# SCC "execAtomically.go.ThrowStm" #-} do -- Revert all the TVar writes !_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written - k0 $ StmTxAborted (Map.elems read) (toException e) + case ctl of + AtomicallyFrame -> do + k0 $ StmTxAborted (Map.elems read) (toException e) + + BranchFrame (CatchStmA h) k writtenOuter writtenOuterSeq createdOuterSeq ctl' -> + {-# SCC "execAtomically.go.BranchFrame" #-} do + let ctl'' = BranchFrame NoOpStmA k writtenOuter writtenOuterSeq createdOuterSeq ctl' + go ctl'' read Map.empty [] [] nextVid (h e) + -- + BranchFrame _ _k writtenOuter writtenOuterSeq createdOuterSeq ctl' -> + {-# SCC "execAtomically.go.BranchFrame" #-} do + go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid (ThrowStm e) + + CatchStm a h k -> + {-# SCC "execAtomically.go.ThrowStm" #-} do + -- Execute the left side in a new frame with an empty written set + let ctl' = BranchFrame (CatchStmA h) k written writtenSeq createdSeq ctl + go ctl' read Map.empty [] [] nextVid a Retry -> - {-# SCC "execAtomically.go.Retry" #-} - do - -- Always revert all the TVar writes for the retry - !_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written - case ctl of - AtomicallyFrame -> do - -- Return vars read, so the thread can block on them - k0 $! StmTxBlocked $! Map.elems read - - BranchFrame (OrElseStmA b) k writtenOuter writtenOuterSeq createdOuterSeq ctl' -> - {-# SCC "execAtomically.go.BranchFrame.OrElseStmA" #-} do - !_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written - -- Execute the orElse right hand with an empty written set - let ctl'' = BranchFrame NoOpStmA k writtenOuter writtenOuterSeq createdOuterSeq ctl' - go ctl'' read Map.empty [] [] nextVid b - - BranchFrame _ _k writtenOuter writtenOuterSeq createdOuterSeq ctl' -> - {-# SCC "execAtomically.go.BranchFrame" #-} do - -- Retry makes sense only within a OrElse context. If it is a branch other than - -- OrElse left side, then bubble up the `retry` to the frame above. - -- Skip the continuation and propagate the retry into the outer frame - -- using the written set for the outer frame - go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid Retry + {-# SCC "execAtomically.go.Retry" #-} do + -- Always revert all the TVar writes for the retry + !_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written + case ctl of + AtomicallyFrame -> do + -- Return vars read, so the thread can block on them + k0 $! StmTxBlocked $! Map.elems read + + BranchFrame (OrElseStmA b) k writtenOuter writtenOuterSeq createdOuterSeq ctl' -> + {-# SCC "execAtomically.go.BranchFrame.OrElseStmA" #-} do + -- Execute the orElse right hand with an empty written set + let ctl'' = BranchFrame NoOpStmA k writtenOuter writtenOuterSeq createdOuterSeq ctl' + go ctl'' read Map.empty [] [] nextVid b + + BranchFrame _ _k writtenOuter writtenOuterSeq createdOuterSeq ctl' -> + {-# SCC "execAtomically.go.BranchFrame" #-} do + -- Retry makes sense only within a OrElse context. If it is a branch other than + -- OrElse left side, then bubble up the `retry` to the frame above. + -- Skip the continuation and propagate the retry into the outer frame + -- using the written set for the outer frame + go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid Retry OrElse a b k -> {-# SCC "execAtomically.go.OrElse" #-} do diff --git a/io-sim/test/Test/IOSim.hs b/io-sim/test/Test/IOSim.hs index 4ae7658a..b617868f 100644 --- a/io-sim/test/Test/IOSim.hs +++ b/io-sim/test/Test/IOSim.hs @@ -1221,7 +1221,7 @@ prop_stm_referenceSim t = -- | Compare the behaviour of the STM reference operational semantics with -- the behaviour of any 'MonadSTM' STM implementation. -- -prop_stm_referenceM :: (MonadSTM m, MonadThrow (STM m), MonadCatch m) +prop_stm_referenceM :: (MonadSTM m, MonadCatch (STM m), MonadCatch m) => SomeTerm -> m Property prop_stm_referenceM (SomeTerm _tyrep t) = do let (r1, _heap) = evalAtomically t diff --git a/io-sim/test/Test/STM.hs b/io-sim/test/Test/STM.hs index d16ce71b..b9afd6df 100644 --- a/io-sim/test/Test/STM.hs +++ b/io-sim/test/Test/STM.hs @@ -24,7 +24,7 @@ module Test.STM where import Data.Map.Strict (Map) import qualified Data.Map.Strict as Map -import Data.Maybe (fromMaybe, maybeToList) +import Data.Maybe import Data.Set (Set) import qualified Data.Set as Set import Data.Type.Equality @@ -68,6 +68,7 @@ data Term (t :: Type) where Return :: Expr t -> Term t Throw :: Expr a -> Term t + Catch :: Term t -> SomeException -> Term t -> Term t Retry :: Term t ReadTVar :: Name (TyVar t) -> Term t @@ -267,6 +268,7 @@ data ImmValue where ImmValVar :: ImmValue -> ImmValue deriving (Eq, Show) + -- | In the execution in real STM transactions are aborted by throwing an -- exception. -- @@ -297,7 +299,7 @@ deriving instance Show (NfTerm t) -- | The STM transition rules. They reduce a 'Term' to a normal-form 'NfTerm'. -- -- Compare the implementation of this against the operational semantics in --- Figure 4 in the paper. Note that @catch@ is not included. +-- Figure 4 in the paper including the `Catch` semantics from the Appendix A. -- evalTerm :: Env -> Heap -> Allocs -> Term t -> (NfTerm t, Heap, Allocs) evalTerm !env !heap !allocs term = case term of @@ -310,6 +312,37 @@ evalTerm !env !heap !allocs term = case term of where e' = evalExpr env e + -- Exception semantics are detailed in "Appendix A Exception semantics" p 12-13 of + -- + Catch t1 exc t2 -> + let (nf1, heap', allocs') = evalTerm env heap mempty t1 in case nf1 of + + -- Rule XSTM1 + -- M; heap, {} => return P; heap', allocs' + -- -------------------------------------------------------- + -- S[catch M N]; heap, allocs => S[return P]; heap', allocs' + NfReturn v -> (NfReturn v, heap', allocs <> allocs') + + -- Rule XSTM2 + -- M; heap, {} => throw P; heap', allocs' + -- -------------------------------------------------------- + -- S[catch M N]; heap, allocs => S[N P]; heap U allocs', allocs U allocs' + NfThrow v -> + -- v should be compared to exception + case fromException exc of + -- TODO: Add eqValue for value + Just (ImmValInt 0) -> + evalTerm env (heap <> allocs') (allocs <> allocs') t2 + -- Exception is not handled, bubble it up + _otherwise -> (NfThrow v, heap, allocs) + + -- Rule XSTM3 + -- M; heap, {} => retry; heap', allocs' + -- -------------------------------------------------------- + -- S[catch M N]; heap, allocs => S[retry]; heap, allocs + NfRetry -> (NfRetry, heap, allocs) + + Retry -> (NfRetry, heap, allocs) -- Rule READ @@ -438,7 +471,7 @@ extendExecEnv (Name name _tyrep) v (ExecEnv env) = -- | Execute an STM 'Term' in the 'STM' monad. -- -execTerm :: (MonadSTM m, MonadThrow (STM m)) +execTerm :: (MonadSTM m, MonadCatch (STM m)) => ExecEnv m -> Term t -> STM m (ExecValue m t) @@ -452,6 +485,8 @@ execTerm env t = let e' = execExpr env e throwSTM =<< snapshotExecValue e' + Catch t1 _ t2 -> execTerm env t1 `catch` \(_ :: ImmValue) -> execTerm env t2 + Retry -> retry ReadTVar n -> do @@ -492,7 +527,7 @@ snapshotExecValue (ExecValInt x) = return (ImmValInt x) snapshotExecValue (ExecValVar v _) = fmap ImmValVar (snapshotExecValue =<< readTVar v) -execAtomically :: forall m t. (MonadSTM m, MonadThrow (STM m), MonadCatch m) +execAtomically :: forall m t. (MonadSTM m, MonadCatch (STM m), MonadCatch m) => Term t -> m TxResult execAtomically t = toTxResult <$> try (atomically action') @@ -658,7 +693,7 @@ genTerm env tyrep = Nothing) ] - binTerm = frequency [ (2, bindTerm), (1, orElseTerm)] + binTerm = frequency [ (2, bindTerm), (1, orElseTerm), (1, catchTerm)] bindTerm = sized $ \sz -> do @@ -672,10 +707,16 @@ genTerm env tyrep = return (Bind t1 name t2) orElseTerm = - sized $ \sz -> resize (sz `div` 2) $ + scale (`div` 2) $ OrElse <$> genTerm env tyrep <*> genTerm env tyrep + catchTerm = + scale (`div` 2) $ + Catch <$> genTerm env tyrep + <*> pure (toException $ ImmValInt 0) -- TODO: 0 is treated as an exception value, generalize it later + <*> genTerm env tyrep + genSomeExpr :: GenEnv -> Gen SomeExpr genSomeExpr env = oneof' @@ -714,6 +755,8 @@ shrinkTerm t = case t of Return e -> [Return e' | e' <- shrinkExpr e] Throw e -> [Throw e' | e' <- shrinkExpr e] + Catch t1 exc t2 -> [t1, t2] + ++ [Catch t1' exc t2' | (t1', t2') <- liftShrink2 shrinkTerm shrinkTerm (t1, t2)] Retry -> [] ReadTVar _ -> [] @@ -722,12 +765,10 @@ shrinkTerm t = NewTVar e -> [NewTVar e' | e' <- shrinkExpr e] Bind t1 n t2 -> [ t2 | nameId n `Set.notMember` freeNamesTerm t2 ] - ++ [ Bind t1' n t2 | t1' <- shrinkTerm t1 ] - ++ [ Bind t1 n t2' | t2' <- shrinkTerm t2 ] + ++ [ Bind t1' n t2' | (t1', t2') <- liftShrink2 shrinkTerm shrinkTerm (t1, t2) ] OrElse t1 t2 -> [t1, t2] - ++ [ OrElse t1' t2 | t1' <- shrinkTerm t1 ] - ++ [ OrElse t1 t2' | t2' <- shrinkTerm t2 ] + ++ [ OrElse t1' t2' | (t1', t2') <- liftShrink2 shrinkTerm shrinkTerm (t1, t2) ] shrinkExpr :: Expr t -> [Expr t] shrinkExpr ExprUnit = [] @@ -739,6 +780,7 @@ shrinkExpr (ExprName (Name _ (TyRepVar _))) = [] freeNamesTerm :: Term t -> Set NameId freeNamesTerm (Return e) = freeNamesExpr e freeNamesTerm (Throw e) = freeNamesExpr e +freeNamesTerm (Catch t1 exc t2) = freeNamesTerm t1 <> freeNamesTerm t2 freeNamesTerm Retry = Set.empty freeNamesTerm (ReadTVar n) = Set.singleton (nameId n) freeNamesTerm (WriteTVar n e) = Set.singleton (nameId n) <> freeNamesExpr e @@ -769,6 +811,7 @@ prop_genSomeTerm (SomeTerm tyrep term) = termSize :: Term a -> Int termSize Return{} = 1 termSize Throw{} = 1 +termSize (Catch a _ b) = 1 + termSize a + termSize b termSize Retry{} = 1 termSize ReadTVar{} = 1 termSize WriteTVar{} = 1 @@ -779,6 +822,7 @@ termSize (OrElse a b) = 1 + termSize a + termSize b termDepth :: Term a -> Int termDepth Return{} = 1 termDepth Throw{} = 1 +termDepth (Catch a _ b) = 1 + max (termDepth a) (termDepth b) termDepth Retry{} = 1 termDepth ReadTVar{} = 1 termDepth WriteTVar{} = 1 @@ -791,6 +835,9 @@ showTerm p (Return e) = showParen (p > 10) $ showString "return " . showExpr 11 e showTerm p (Throw e) = showParen (p > 10) $ showString "throwSTM " . showExpr 11 e +showTerm p (Catch t1 _ t2) = showParen (p > 9) $ + showTerm 10 t1 . showString " `catch` " + . showTerm 10 t2 showTerm _ Retry = showString "retry" showTerm p (ReadTVar n) = showParen (p > 10) $ showString "readTVar " . showName n