Skip to content

Commit

Permalink
Ensure that chunks are no larger than chunk size (#33)
Browse files · Browse the repository at this point in the history
  • Loading branch information
axman6 authored Jan 10, 2023
1 parent cf09e0f commit 4d1aa55
Showing 1 changed file with 79 additions and 63 deletions.
142 changes: 79 additions & 63 deletions src/Amazonka/S3/StreamingUpload.hs
Original file line number Diff line number Diff line change
Expand Up @@ -49,27 +49,27 @@ import Data.Conduit ( ConduitT, Void, await, handleC, yield, (.|) )
import Data.Conduit.Combinators ( sinkList )
import Data.Conduit.Combinators qualified as CC

import Data.ByteString ( ByteString )
import Data.ByteString qualified as BS
import Data.ByteString.Builder ( Builder, stringUtf8 )
import Data.ByteString.Builder.Extra ( Next(..), byteStringCopy, runBuilder )
import Data.List ( unfoldr )
import Data.List.NonEmpty ( fromList, nonEmpty )
import Data.Text ( Text )
import Data.ByteString.Builder ( stringUtf8 )
import Data.ByteString.Builder.Extra ( byteStringCopy, runBuilder )
import Data.ByteString.Internal ( ByteString(PS) )

import Data.List ( unfoldr )
import Data.List.NonEmpty ( fromList, nonEmpty )
import Data.Text ( Text )

import Control.Concurrent ( newQSem, signalQSem, waitQSem )
import Control.Concurrent.Async ( forConcurrently )
import Control.Exception.Base ( SomeException, bracket_ )

import Data.ByteString qualified as B
import Data.ByteString.Internal ( ByteString(PS), toForeignPtr )
import Foreign.ForeignPtr ( mallocForeignPtrBytes )
import Foreign.ForeignPtr ( ForeignPtr, mallocForeignPtrBytes, plusForeignPtr )
import Foreign.ForeignPtr.Unsafe ( unsafeForeignPtrToPtr )
import GHC.ForeignPtr ( finalizeForeignPtr )

import Control.DeepSeq ( rwhnf, (<$!!>) )
import Control.DeepSeq ( rwhnf )
import Data.Foldable ( for_, traverse_ )
import Data.Typeable ( Typeable )
import Data.Word ( Word8 )
import Control.Monad ((>=>))


type ChunkSize = Int
Expand Down Expand Up @@ -102,17 +102,21 @@ uploads - it is important to abort multipart uploads because you will
be charged for storage of the parts until it is completed or aborted.
See the AWS documentation for more details.
Internally, a single @chunkSize@d buffer will be allocated and reused between
requests to avoid holding onto incoming @ByteString@s.
May throw 'Amazonka.Error'
-}
streamUpload :: forall m. (MonadUnliftIO m, MonadResource m)
=> Env
-> Maybe ChunkSize -- ^ Optional chunk size
-> CreateMultipartUpload -- ^ Upload location
-> ConduitT ByteString Void m (Either (AbortMultipartUploadResponse, SomeException) CompleteMultipartUploadResponse)
streamUpload env mChunkSize multiPartUploadDesc@CreateMultipartUpload'{bucket = buck, key = k} =
processAndChunkOutputRaw chunkSize
.| enumerateConduit
.| startUpload
streamUpload env mChunkSize multiPartUploadDesc@CreateMultipartUpload'{bucket = buck, key = k} = do
buffer <- liftIO $ allocBuffer chunkSize
unsafeWriteChunksToBuffer buffer
.| enumerateConduit
.| startUpload buffer
where
chunkSize :: ChunkSize
chunkSize = maybe minimumChunkSize (max minimumChunkSize) mChunkSize
Expand All @@ -121,26 +125,26 @@ streamUpload env mChunkSize multiPartUploadDesc@CreateMultipartUpload'{bucket =
logStr msg = do
liftIO $ logger env Debug $ stringUtf8 msg

startUpload :: ConduitT (Int, S) Void m
startUpload :: Buffer
-> ConduitT (Int, BufferResult) Void m
(Either (AbortMultipartUploadResponse, SomeException)
CompleteMultipartUploadResponse)
startUpload = do
startUpload buffer = do
CreateMultipartUploadResponse'{uploadId = upId} <- lift $ send env multiPartUploadDesc
lift $ logStr "\n**** Created upload\n"

handleC (cancelMultiUploadConduit upId) $
CC.mapM (multiUpload upId)
CC.mapM (multiUpload buffer upId)
.| finishMultiUploadConduit upId

multiUpload :: Text -> (Int, S) -> m (Maybe CompletedPart)
multiUpload upId (partnum, s) = do
buffer <- liftIO $ finaliseS s
let (fptr,_,_) = toForeignPtr buffer
UploadPartResponse'{eTag} <- send env $! newUploadPart buck k partnum upId $! toBody $! (HashedBytes $! hash buffer) buffer
multiUpload :: Buffer -> Text -> (Int, BufferResult) -> m (Maybe CompletedPart)
multiUpload buffer upId (partnum, result) = do
let !bs = bufferToByteString buffer result
!bsHash = hash bs
UploadPartResponse'{eTag} <- send env $! newUploadPart buck k partnum upId $! toBody $! HashedBytes bsHash bs
let !_ = rwhnf eTag
liftIO $ finalizeForeignPtr fptr
logStr $ "\n**** Uploaded part " <> show partnum
return $! newCompletedPart partnum <$!!> eTag
return $! newCompletedPart partnum <$> eTag

-- collect all the parts
finishMultiUploadConduit :: Text
Expand Down Expand Up @@ -173,6 +177,56 @@ streamUpload env mChunkSize multiPartUploadDesc@CreateMultipartUpload'{bucket =
loop (i + 1)
{-# INLINE enumerateConduit #-}

-- The number of bytes remaining in a buffer, and the pointer that backs it.
-- NOTE: '_fptr' points at the next unwritten byte rather than the start of
-- the allocation — 'unsafeWriteBuffer' advances it with 'plusForeignPtr'.
data Buffer = Buffer {remaining :: !Int, _fptr :: !(ForeignPtr Word8)}

-- Outcome of writing one incoming 'ByteString' into a 'Buffer' ('putBuffer').
data PutResult
  = Ok Buffer       -- Didn't fill the buffer, updated buffer.
  | Full ByteString -- Buffer is full, the unwritten remaining string.

-- What a completed buffer contains: entirely full, or short by @n@ bytes
-- (the count of *unused* bytes, not of bytes written).
data BufferResult = FullBuffer | Incomplete Int

-- Accepts @ByteString@s and writes them into @Buffer@. When the buffer is full,
-- @FullBuffer@ is emitted. If there is no more input, @Incomplete@ is emitted with
-- the number of bytes remaining in the buffer.
--
-- The single allocation behind @buffer0@ is reused for every emitted chunk:
-- after yielding @FullBuffer@ the loops restart from @buffer0@ (full
-- capacity, pointer reset to the start).  NOTE(review): this is only safe
-- because the downstream consumer finishes with the bytes (uploads/hashes
-- the part) before demanding more input — confirm before reusing this
-- conduit elsewhere.
unsafeWriteChunksToBuffer :: MonadIO m => Buffer -> ConduitT ByteString BufferResult m ()
unsafeWriteChunksToBuffer buffer0 = awaitLoop buffer0 where
  -- Wait for the next upstream chunk; at end of input report how much of
  -- the buffer was left unused.
  awaitLoop buf =
    await >>= maybe (yield $ Incomplete $ remaining buf)
      (liftIO . putBuffer buf >=> \case
        Full next -> yield FullBuffer *> chunkLoop buffer0 next
        Ok buf'   -> awaitLoop buf'
      )
  -- Handle inputs which are larger than the chunkSize: keep carving the
  -- leftover into fresh (reused) buffers without awaiting more input.
  chunkLoop buf = liftIO . putBuffer buf >=> \case
    Full next -> yield FullBuffer *> chunkLoop buffer0 next
    Ok buf'   -> awaitLoop buf'

-- Zero-copy view of the filled prefix of the buffer as a 'ByteString'.
-- NOTE(review): the 'Buffer' given here must be the original full-capacity
-- one (its @remaining@ equals the chunk size), as in 'multiUpload'.
bufferToByteString :: Buffer -> BufferResult -> ByteString
bufferToByteString (Buffer bufSize fptr) result = PS fptr 0 filled
  where
    -- Bytes actually written: the whole buffer, or the capacity minus
    -- whatever was left unused.
    filled = case result of
      FullBuffer        -> bufSize
      Incomplete unused -> bufSize - unused

-- Allocate a fresh buffer of @size@ bytes with nothing written yet.
allocBuffer :: Int -> IO Buffer
allocBuffer size = do
  fptr <- mallocForeignPtrBytes size
  pure (Buffer size fptr)

-- Copy as much of the input as fits into the buffer.  Returns 'Ok' with the
-- advanced buffer when the whole input fit, or 'Full' with the unconsumed
-- suffix when the input was cut at the buffer boundary.
putBuffer :: Buffer -> ByteString -> IO PutResult
putBuffer buf input
  | BS.length input > remaining buf = do
      -- Overflow: write exactly what fits, hand back the rest to the caller.
      let (fits, overflow) = BS.splitAt (remaining buf) input
      _ <- unsafeWriteBuffer buf fits
      pure (Full overflow)
  | otherwise = Ok <$> unsafeWriteBuffer buf input

-- Copy @bs@ into the buffer and advance past the written bytes.
-- PRECONDITION: the length of the bytestring must be less than or equal to
-- the number of bytes remaining — the copy is unchecked and would scribble
-- past the allocation otherwise.
unsafeWriteBuffer :: Buffer -> ByteString -> IO Buffer
unsafeWriteBuffer (Buffer avail fptr) bs = do
  let n = BS.length bs
  -- 'byteStringCopy' is expected to finish in a single step when given
  -- enough room, so the builder's continuation result is ignored.
  _ <- runBuilder (byteStringCopy bs) (unsafeForeignPtrToPtr fptr) avail
  pure $ Buffer (avail - n) (fptr `plusForeignPtr` n)


-- | Specifies whether to upload a file or 'ByteString'.
data UploadLocation
= FP FilePath -- ^ A file to be uploaded
Expand Down Expand Up @@ -276,41 +330,3 @@ nothingWhen f = justWhen (not . f)

-- Split a ByteString into pieces of at most @n@ bytes; the last piece may be
-- shorter.  Empty input yields no pieces.
chunksOf :: Int -> BS.ByteString -> [BS.ByteString]
chunksOf n = unfoldr step
  where
    -- Stop when nothing is left, otherwise peel off the next piece.
    step bs
      | BS.null bs = Nothing
      | otherwise  = Just (BS.splitAt n bs)

-- | A bytestring `Builder` stored with the size of buffer it needs to be fully evaluated
-- (i.e. the total number of bytes the builder will write when run).
data S = S !Builder {-# UNPACK #-} !Int

-- | The empty accumulator: no bytes pending.
newS :: S
newS = S mempty 0

-- | An accumulator seeded with a single chunk.
newSFrom :: ByteString -> S
newSFrom = appendS newS

-- | Queue @bs@ behind the bytes already accumulated, tracking the combined
-- length.  'byteStringCopy' copies the chunk's bytes into the destination
-- buffer when the builder is eventually run.
appendS :: S -> ByteString -> S
appendS (S pending len) bs = S (pending <> byteStringCopy bs) (B.length bs + len)

-- | Materialise the accumulated builder into a single strict 'ByteString',
-- allocating exactly the number of bytes the accumulator tracked.
--
-- Calls 'error' if the builder writes a different number of bytes than the
-- tracked length, or does not finish in one pass — either would mean the
-- length bookkeeping in 'appendS' \/ 'newSFrom' went wrong.
finaliseS :: S -> IO ByteString
finaliseS (S builder builderLen) = do
  fptr <- mallocForeignPtrBytes builderLen
  let ptr = unsafeForeignPtrToPtr fptr
      bufWriter = runBuilder builder
  bufWriter ptr builderLen >>= \case
    (written, Done)
      | written == builderLen -> pure $! PS fptr 0 builderLen
      | otherwise ->
          error $ "finaliseS: bytes written didn't match, expected: " <> show builderLen <> " got: " <> show written
    (_written, _) -> error "Something went very wrong"

-- Right means the buffer needs more data to fill it
-- Left means the buffer is full
--
-- NOTE(review): the size check happens BEFORE appending, and 'appendS'
-- always appends the entire input — so an accumulator just under
-- @chunkSize@ can grow past it by up to one whole input chunk.  Emitted
-- chunks are therefore only "at least chunkSize", not "at most chunkSize"
-- (the defect this commit's Buffer-based rewrite addresses).
processChunk :: ChunkSize -> ByteString -> S -> IO (Either S S)
processChunk chunkSize input s@(S _ builderLen)
  | builderLen >= chunkSize = pure $! Left $! s
  | otherwise = pure $! Right $! appendS s input

-- Re-chunk the incoming stream into 'S' accumulators of roughly @chunkSize@
-- bytes, yielding the final (possibly short, possibly empty) accumulator at
-- end of input.
--
-- NOTE(review): because 'processChunk' only checks the size before
-- appending, a yielded 'S' can exceed @chunkSize@ by up to one input
-- chunk's length; the chunk that triggered 'Left' is replayed via
-- 'newSFrom'.
processAndChunkOutputRaw :: MonadIO m => ChunkSize -> ConduitT ByteString S m ()
processAndChunkOutputRaw chunkSize = loop newS where
  loop !s = await >>=
    maybe (yield s)
    (\bs -> liftIO (processChunk chunkSize bs s) >>= either (\s' -> yield s' >> loop (newSFrom bs)) loop)

0 comments on commit 4d1aa55

Please sign in to comment.