diff --git a/inferno-vc/CHANGELOG.md b/inferno-vc/CHANGELOG.md
index dc98625f..0e8a00f4 100644
--- a/inferno-vc/CHANGELOG.md
+++ b/inferno-vc/CHANGELOG.md
@@ -1,6 +1,11 @@
 # Revision History for inferno-vc
 *Note*: we use https://pvp.haskell.org/ (MAJOR.MAJOR.MINOR.PATCH)
 
+## 0.3.7.0 -- 2024-08-19
+* Cached client now serializes requests to server for the same script ids in
+  order to avoid DOSing the server when the same script is requested many times
+  simultaneously
+
 ## 0.3.6.0 -- 2024-03-18
 * HLint everything
 
diff --git a/inferno-vc/inferno-vc.cabal b/inferno-vc/inferno-vc.cabal
index 174922a1..6f6cc87e 100644
--- a/inferno-vc/inferno-vc.cabal
+++ b/inferno-vc/inferno-vc.cabal
@@ -1,6 +1,6 @@
 cabal-version: >=1.10
 name: inferno-vc
-version: 0.3.6.0
+version: 0.3.7.0
 synopsis: Version control server for Inferno
 description: A version control server for Inferno scripts
 category: DSL,Scripting
@@ -67,6 +67,7 @@ library
     , atomic-write >= 0.2 && < 0.3
     , hspec
     , QuickCheck
+    , stm
 
   default-language: Haskell2010
   default-extensions:
diff --git a/inferno-vc/src/Inferno/VersionControl/Client/Cached.hs b/inferno-vc/src/Inferno/VersionControl/Client/Cached.hs
index 429013be..7f049a8c 100644
--- a/inferno-vc/src/Inferno/VersionControl/Client/Cached.hs
+++ b/inferno-vc/src/Inferno/VersionControl/Client/Cached.hs
@@ -1,18 +1,29 @@
+{-# LANGUAGE NamedFieldPuns #-}
+
 module Inferno.VersionControl.Client.Cached
-  ( VCCachePath (..),
+  ( VCCacheEnv,
     CachedVCClientError (..),
     fetchVCObjectClosure,
     initVCCachedClient,
   )
 where
 
+import Control.Concurrent.STM
+  ( TVar,
+    atomically,
+    newTVarIO,
+    readTVar,
+    retry,
+    writeTVar,
+  )
 import Control.Monad (forM, forM_)
+import Control.Monad.Catch (MonadMask, bracket_)
 import Control.Monad.Error.Lens (throwing)
 import Control.Monad.Except (MonadError (..))
 import Control.Monad.IO.Class (MonadIO (..))
 import Control.Monad.Reader (MonadReader (..), asks)
 import Crypto.Hash (digestFromByteString)
-import Data.Aeson (FromJSON, ToJSON, eitherDecode, encode)
+import Data.Aeson (FromJSON, ToJSON, eitherDecodeStrict, encode)
 import qualified Data.ByteString as B
 import qualified Data.ByteString.Base64.URL as Base64
 import qualified Data.ByteString.Char8 as Char8
@@ -20,7 +31,9 @@ import qualified Data.ByteString.Lazy as BL
 import Data.Either (partitionEithers)
 import Data.Generics.Product (HasType, getTyped)
 import Data.Generics.Sum (AsType (..))
+import Data.List (foldl')
 import qualified Data.Map as Map
+import qualified Data.Set as Set
 import GHC.Generics (Generic)
 import qualified Inferno.VersionControl.Client as VCClient
 import Inferno.VersionControl.Operations.Error (VCStoreError (..))
@@ -37,7 +50,26 @@ import System.AtomicWrite.Writer.LazyByteString (atomicWriteFile)
 import System.Directory (createDirectoryIfMissing, doesFileExist)
 import System.FilePath.Posix ((</>))
 
-newtype VCCachePath = VCCachePath FilePath deriving (Generic)
+data VCCacheEnv = VCCacheEnv
+  { cachePath :: FilePath,
+    cacheInFlight :: TVar (Set.Set VCObjectHash)
+  }
+  deriving (Generic)
+
+-- | Runs an action while holding an exclusive claim on the given hashes, so that
+-- only one thread at a time fetches them. Blocks until none of the hashes is
+-- claimed; all are claimed in one STM transaction and released by 'bracket_'.
+withInFlight :: (MonadMask m, MonadIO m) => VCCacheEnv -> [VCObjectHash] -> m a -> m a
+withInFlight VCCacheEnv {cacheInFlight} hashes = bracket_ acquire release
+  where
+    acquire = liftIO $ atomically $ do
+      inFlight <- readTVar cacheInFlight
+      if any (`Set.member` inFlight) hashes
+        then retry
+        else writeTVar cacheInFlight $! foldl' (flip Set.insert) inFlight hashes
+    release = liftIO $ atomically $ do
+      inFlight <- readTVar cacheInFlight
+      writeTVar cacheInFlight $! foldl' (flip Set.delete) inFlight hashes
 
 data CachedVCClientError
   = ClientVCStoreError VCServerError
@@ -45,19 +77,22 @@ data CachedVCClientError
   | LocalVCStoreError VCStoreError
   deriving (Show, Generic)
 
-initVCCachedClient :: VCCachePath -> IO ()
-initVCCachedClient (VCCachePath storePath) =
-  createDirectoryIfMissing True $ storePath </> "deps"
+initVCCachedClient :: FilePath -> IO VCCacheEnv
+initVCCachedClient cachePath = do
+  createDirectoryIfMissing True $ cachePath </> "deps"
+  cacheInFlight <- newTVarIO mempty
+  pure VCCacheEnv {cachePath, cacheInFlight}
 
 fetchVCObjectClosure ::
   ( MonadError err m,
-    HasType VCCachePath env,
+    HasType VCCacheEnv env,
     HasType ClientEnv env,
     AsType VCServerError err,
     AsType ClientError err,
     AsType VCStoreError err,
     MonadReader env m,
     MonadIO m,
+    MonadMask m,
     FromJSON a,
     FromJSON g,
     ToJSON a,
@@ -68,51 +103,54 @@ fetchVCObjectClosure ::
   VCObjectHash ->
   m (Map.Map VCObjectHash (VCMeta a g VCObject))
 fetchVCObjectClosure fetchVCObjects remoteFetchVCObjectClosureHashes objHash = do
-  VCCachePath storePath <- asks getTyped
+  env@VCCacheEnv {cachePath} <- asks getTyped
   deps <-
-    liftIO (doesFileExist $ storePath </> "deps" </> show objHash) >>= \case
-      False -> do
-        deps <- liftServantClient $ remoteFetchVCObjectClosureHashes objHash
-        liftIO
-          $ atomicWriteFile
-            (storePath </> "deps" </> show objHash)
-          $ BL.concat [BL.fromStrict (vcObjectHashToByteString h) <> "\n" | h <- deps]
-        pure deps
-      True -> fetchVCObjectClosureHashes objHash
-  (nonLocalHashes, localHashes) <-
-    partitionEithers
-      <$> forM
-        (objHash : deps)
-        ( \depHash -> do
-            liftIO (doesFileExist $ storePath </> show depHash) >>= \case
-              True -> pure $ Right depHash
-              False -> pure $ Left depHash
-        )
-  localObjs <-
-    Map.fromList
-      <$> forM
-        localHashes
-        ( \h ->
-            (h,) <$> fetchVCObjectUnsafe h
-        )
+    withInFlight env [objHash] $
+      liftIO (doesFileExist $ cachePath </> "deps" </> show objHash) >>= \case
+        False -> do
+          deps <- liftServantClient $ remoteFetchVCObjectClosureHashes objHash
+          liftIO
+            $ atomicWriteFile
+              (cachePath </> "deps" </> show objHash)
+            $ BL.concat [BL.fromStrict (vcObjectHashToByteString h) <> "\n" | h <- deps]
+          pure deps
+        True -> fetchVCObjectClosureHashes objHash
+  -- objHash itself is fetched below, so claim it together with its dependencies
+  withInFlight env (objHash : deps) $ do
+    (nonLocalHashes, localHashes) <-
+      partitionEithers
+        <$> forM
+          (objHash : deps)
+          ( \depHash -> do
+              liftIO (doesFileExist $ cachePath </> show depHash) >>= \case
+                True -> pure $ Right depHash
+                False -> pure $ Left depHash
+          )
+    localObjs <-
+      Map.fromList
+        <$> forM
+          localHashes
+          ( \h ->
+              (h,) <$> fetchVCObjectUnsafe h
+          )
 
-  nonLocalObjs <- liftServantClient $ fetchVCObjects nonLocalHashes
-  forM_ (Map.toList nonLocalObjs) $ \(h, o) ->
-    liftIO $ atomicWriteFile (storePath </> show h) $ encode o
-  pure $ localObjs `Map.union` nonLocalObjs
+    nonLocalObjs <- liftServantClient $ fetchVCObjects nonLocalHashes
+    forM_ (Map.toList nonLocalObjs) $ \(h, o) ->
+      liftIO $ atomicWriteFile (cachePath </> show h) $ encode o
+    pure $ localObjs `Map.union` nonLocalObjs
 
 fetchVCObjectClosureHashes ::
   ( MonadError err m,
     MonadIO m,
     MonadReader env m,
     AsType VCStoreError err,
-    HasType VCCachePath env
+    HasType VCCacheEnv env
   ) =>
   VCObjectHash ->
   m [VCObjectHash]
 fetchVCObjectClosureHashes h = do
-  VCCachePath storePath <- asks getTyped
-  let fp = storePath </> "deps" </> show h
+  VCCacheEnv {cachePath} <- asks getTyped
+  let fp = cachePath </> "deps" </> show h
   readVCObjectHashTxt fp
 
 readVCObjectHashTxt ::
@@ -130,7 +168,7 @@ readVCObjectHashTxt fp = do
 
 fetchVCObjectUnsafe ::
   ( MonadReader r m,
-    HasType VCCachePath r,
+    HasType VCCacheEnv r,
     MonadError e m,
     AsType VCStoreError e,
     MonadIO m,
@@ -139,9 +177,9 @@ fetchVCObjectUnsafe ::
   VCObjectHash ->
   m b
 fetchVCObjectUnsafe h = do
-  VCCachePath storePath <- asks getTyped
-  let fp = storePath </> show h
-  either (throwing _Typed . CouldNotDecodeObject h) pure =<< liftIO (eitherDecode <$> BL.readFile fp)
+  VCCacheEnv {cachePath} <- asks getTyped
+  let fp = cachePath </> show h
+  either (throwing _Typed . CouldNotDecodeObject h) pure =<< liftIO (eitherDecodeStrict <$> Char8.readFile fp)
 
 liftServantClient ::
   ( MonadError e m,