Skip to content

Commit

Permalink
Use more convenient types (#8)
Browse files Browse the repository at this point in the history
A move away from the bare GHC.Exts style towards the primitive style.
This should make the functions a little more convenient to use.
  • Loading branch information
meooow25 authored May 3, 2024
1 parent f81f8c7 commit 4f670e6
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 63 deletions.
15 changes: 7 additions & 8 deletions compare/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
{-# LANGUAGE MagicHash #-}

import Control.DeepSeq (NFData(..), rwhnf)
import Control.Monad.Primitive (PrimMonad(..), RealWorld, primitive_)
import Control.Monad.Primitive (PrimMonad(..), RealWorld, stToPrim)
import Control.Monad.ST (stToIO)
import Data.Bits ((.&.))
import Data.Primitive.Array
Expand All @@ -26,7 +26,6 @@ import qualified Data.Vector.Algorithms.Merge as Merge
import qualified Data.Vector.Algorithms.Tim as Tim
import qualified Data.Vector.Primitive.Mutable as MPV
import qualified Data.Vector.Mutable as MV
import GHC.Exts (Int(..), sizeofMutableArray#)

import Criterion.Main (Benchmark, defaultMain, bench, bgroup, perRunEnv)

Expand Down Expand Up @@ -61,7 +60,7 @@ bgroupN f n = bgroup (show n) $

bgroupIOA :: Ord a => String -> IO (IOArray a) -> Benchmark
bgroupIOA name mkma = bgroup name
[ bench "samsort sortArrayBy#" $
[ bench "samsort sortArrayBy" $
perRunEnv (fmap WHNF mkma) $ \(WHNF ma) -> samSort ma
, bench "vector-algorithms Intro" $
perRunEnv mkmv $ \(WHNF mv) -> Intro.sort mv
Expand Down Expand Up @@ -92,7 +91,7 @@ bgroupPN n = bgroup (show n) $

bgroupIOPA :: String -> IO (IOPrimArray Int) -> Benchmark
bgroupIOPA name mkma = bgroup name
[ bench "samsort sortIntArrayBy#" $
[ bench "samsort sortIntArrayBy" $
perRunEnv (fmap WHNF mkma) $ \(WHNF ma) -> samSortInts ma
, bench "vector-algorithms Intro" $
perRunEnv mkmv $ \(WHNF mv) -> Intro.sort mv
Expand All @@ -111,13 +110,13 @@ bgroupIOPA name mkma = bgroup name
pure (WHNF (MPV.MVector 0 sz (MutableByteArray ma#)))

samSort :: (PrimMonad m, Ord a) => MutableArray (PrimState m) a -> m ()
samSort (MutableArray ma#) =
primitive_ $ Sam.sortArrayBy# compare ma# 0# (sizeofMutableArray# ma#)
samSort ma@(MutableArray ma#) =
stToPrim $ Sam.sortArrayBy compare ma# 0 (sizeofMutableArray ma)

samSortInts :: PrimMonad m => MutablePrimArray (PrimState m) Int -> m ()
samSortInts ma@(MutablePrimArray ma#) = do
I# sz# <- getSizeofMutablePrimArray ma
primitive_ $ Sam.sortIntArrayBy# (\x# y# -> compare (I# x#) (I# y#)) ma# 0# sz#
sz <- getSizeofMutablePrimArray ma
stToPrim $ Sam.sortIntArrayBy compare ma# 0 sz

---------
-- Data
Expand Down
65 changes: 29 additions & 36 deletions src/Data/SamSort.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
-- https://arxiv.org/abs/1801.04641
--
module Data.SamSort
( sortArrayBy#
, sortIntArrayBy#
( sortArrayBy
, sortIntArrayBy
) where

import Control.Monad (when)
Expand All @@ -29,7 +29,6 @@ import GHC.Exts
, Int(..)
, MutableArray#
, MutableByteArray#
, State#
, (*#)
, copyMutableArray#
, copyMutableByteArray#
Expand All @@ -56,27 +55,24 @@ import GHC.Exts
-- known comparison functions. To avoid code duplication, create a wrapping
-- definition and reuse it as necessary.
--
sortArrayBy#
sortArrayBy
:: (a -> a -> Ordering) -- ^ comparison
-> MutableArray# s a
-> Int# -- ^ offset
-> Int# -- ^ length
-> State# s
-> State# s
sortArrayBy# cmp = -- Inline with 1 arg
\ma# off# len# s ->
case sortArrayByST cmp (MA ma#) (I# off#) (I# len#) of
ST f -> case f s of (# s1, _ #) -> s1
{-# INLINE sortArrayBy# #-}

sortArrayByST
-> Int -- ^ offset
-> Int -- ^ length
-> ST s ()
sortArrayBy cmp = -- Inline with 1 arg
\ma# !off !len -> sortArrayBy' cmp (MA ma#) off len
{-# INLINE sortArrayBy #-}

sortArrayBy'
:: (a -> a -> Ordering)
-> MA s a
-> Int
-> Int
-> ST s ()
sortArrayByST _ !_ !_ len | len < 2 = pure ()
sortArrayByST cmp ma off len = do
sortArrayBy' _ !_ !_ len | len < 2 = pure ()
sortArrayBy' cmp ma off len = do
-- See Note [Algorithm overview]

!swp <- newA (len `shiftR` 1) errorElement
Expand Down Expand Up @@ -191,7 +187,7 @@ sortArrayByST cmp ma off len = do
!end = off + len

getRun = mkGetRun lt (readA ma) (writeA ma) (reverseA ma) end
{-# INLINE sortArrayByST #-}
{-# INLINE sortArrayBy' #-}

-- | \(O(n \log n)\). Sort a slice of a @MutableByteArray#@ interpreted as an
-- array of @Int#@s using a comparison function.
Expand All @@ -208,27 +204,24 @@ sortArrayByST cmp ma off len = do
-- known comparison functions. To avoid code duplication, create a wrapping
-- definition and reuse it as necessary.
--
sortIntArrayBy#
:: (Int# -> Int# -> Ordering) -- ^ comparison
sortIntArrayBy
:: (Int -> Int -> Ordering) -- ^ comparison
-> MutableByteArray# s
-> Int# -- ^ offset in @Int#@s
-> Int# -- ^ length in @Int#@s
-> State# s
-> State# s
sortIntArrayBy# cmp = -- Inline with 1 arg
\ma# off# len# s ->
case sortIntArrayByST cmp (MIA ma#) (I# off#) (I# len#) of
ST f -> case f s of (# s1, _ #) -> s1
{-# INLINE sortIntArrayBy# #-}

sortIntArrayByST
:: (Int# -> Int# -> Ordering)
-> Int -- ^ offset in @Int#@s
-> Int -- ^ length in @Int#@s
-> ST s ()
sortIntArrayBy cmp = -- Inline with 1 arg
\ma# !off !len -> sortIntArrayBy' cmp (MIA ma#) off len
{-# INLINE sortIntArrayBy #-}

sortIntArrayBy'
:: (Int -> Int -> Ordering)
-> MIA s
-> Int
-> Int
-> ST s ()
sortIntArrayByST _ !_ !_ len | len < 2 = pure ()
sortIntArrayByST cmp ma off len = do
sortIntArrayBy' _ !_ !_ len | len < 2 = pure ()
sortIntArrayBy' cmp ma off len = do
-- See Note [Algorithm overview]

!swp <- newI (len `shiftR` 1)
Expand Down Expand Up @@ -306,7 +299,7 @@ sortIntArrayByST cmp ma off len = do
mergeStrategy merge getRun stk off end

where
lt (I# x#) (I# y#) = case cmp x# y# of LT -> True; _ -> False
lt !x !y = case cmp x y of LT -> True; _ -> False
{-# INLINE lt #-}
-- Note: Use lt instead of gt. Why? Because `compare` for types like Int and
-- Word are defined in a way that needs one `<` op for LT but two (`<`,`==`)
Expand All @@ -315,7 +308,7 @@ sortIntArrayByST cmp ma off len = do
!end = off + len

getRun = mkGetRun lt (readI ma) (writeI ma) (reverseI ma) end
{-# INLINE sortIntArrayByST #-}
{-# INLINE sortIntArrayBy' #-}

mkGetRun
:: (a -> a -> Bool) -- comparison
Expand Down
26 changes: 7 additions & 19 deletions test/Main.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE UnboxedTuples #-}

import qualified Data.Foldable as F
import qualified Data.List as L
Expand All @@ -15,22 +13,20 @@ import Data.Primitive.PrimArray
, sizeofPrimArray
, thawPrimArray
)
import GHC.ST (ST(..))
import GHC.Exts (Int(..), Int#)

import Test.Tasty (defaultMain, localOption, testGroup)
import Test.Tasty.QuickCheck (QuickCheckTests(..), Fun, applyFun, testProperty, (===))
import Test.QuickCheck.Poly (A, OrdA)

import Data.SamSort (sortArrayBy#, sortIntArrayBy#)
import Data.SamSort (sortArrayBy, sortIntArrayBy)

main :: IO ()
main = defaultMain $ localOption (QuickCheckTests 5000) $ testGroup "Tests"
[ testProperty "sortArrayBy#" $ \xs ys zs ->
[ testProperty "sortArrayBy" $ \xs ys zs ->
sortViaMutableArray (comparing fst) (xs,ys,zs)
===
((xs :: [(OrdA, A)]) ++ L.sortBy (comparing fst) ys ++ zs)
, testProperty "sortIntArrayBy#" $ \f xs ys zs ->
, testProperty "sortIntArrayBy" $ \f xs ys zs ->
sortViaMutableIntArray
(comparing (applyFun (f :: Fun Int OrdA)))
(xs,ys,zs)
Expand All @@ -45,7 +41,8 @@ sortViaMutableArray
sortViaMutableArray cmp (xs,ys,zs) = F.toList $ runArray $ do
let a = arrayFromList (xs ++ ys ++ zs)
ma@(MutableArray ma#) <- thawArray a 0 (sizeofArray a)
ST $ \s -> case sortArrayBy# cmp ma# (len# xs) (len# ys) s of s1 -> (# s1, ma #)
sortArrayBy cmp ma# (length xs) (length ys)
pure ma

sortViaMutableIntArray
:: (Int -> Int -> Ordering)
Expand All @@ -54,14 +51,5 @@ sortViaMutableIntArray
sortViaMutableIntArray cmp (xs,ys,zs) = primArrayToList $ runPrimArray $ do
let a = primArrayFromList (xs ++ ys ++ zs)
ma@(MutablePrimArray ma#) <- thawPrimArray a 0 (sizeofPrimArray a)
ST $ \s ->
case sortIntArrayBy#
(\x# y# -> cmp (I# x#) (I# y#))
ma#
(len# xs)
(len# ys)
s of
s1 -> (# s1, ma #)

len# :: [a] -> Int#
len# as = case length as of I# n# -> n#
sortIntArrayBy cmp ma# (length xs) (length ys)
pure ma

0 comments on commit 4f670e6

Please sign in to comment.