Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Define partitionKeys: fused version of restrictKeys and withoutKeys #975

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions containers-tests/benchmarks/Map.hs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ main = do
let m = M.fromAscList elems :: M.Map Int Int
m_even = M.fromAscList elems_even :: M.Map Int Int
m_odd = M.fromAscList elems_odd :: M.Map Int Int
m_odd_keys = M.keysSet m_odd
evaluate $ rnf [m, m_even, m_odd]
evaluate $ rnf elems_rev
evaluate $ rnf m_odd_keys
defaultMain
[ bench "lookup absent" $ whnf (lookup evens) m_odd
, bench "lookup present" $ whnf (lookup evens) m_even
Expand Down Expand Up @@ -95,8 +97,13 @@ main = do
, bench "fromDistinctDescList" $ whnf M.fromDistinctDescList elems_rev
, bench "fromDistinctDescList:fusion" $ whnf (\n -> M.fromDistinctDescList [(i,i) | i <- [n,n-1..1]]) bound
, bench "minView" $ whnf (\m' -> case M.minViewWithKey m' of {Nothing -> 0; Just ((k,v),m'') -> k+v+M.size m''}) (M.fromAscList $ zip [1..10::Int] [100..110::Int])

, bench "eq" $ whnf (\m' -> m' == m') m -- worst case, compares everything
, bench "compare" $ whnf (\m' -> compare m' m') m -- worst case, compares everything

, bench "restrictKeys" $ whnf (M.restrictKeys m) m_odd_keys
, bench "withoutKeys" $ whnf (M.withoutKeys m) m_odd_keys
, bench "partitionKeys" $ whnf (M.partitionKeys m) m_odd_keys
]
where
bound = 2^12
Expand Down
1 change: 1 addition & 0 deletions containers-tests/containers-tests.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ library
Utils.Containers.Internal.State
Utils.Containers.Internal.StrictMaybe
Utils.Containers.Internal.EqOrdUtil
Utils.Containers.Internal.StrictTriple

if impl(ghc)
other-modules:
Expand Down
7 changes: 7 additions & 0 deletions containers-tests/tests/map-properties.hs
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ main = defaultMain $ testGroup "map-properties"
, testProperty "withoutKeys" prop_withoutKeys
, testProperty "intersection" prop_intersection
, testProperty "restrictKeys" prop_restrictKeys
, testProperty "partitionKeys" prop_partitionKeys
, testProperty "intersection model" prop_intersectionModel
, testProperty "intersectionWith" prop_intersectionWith
, testProperty "intersectionWithModel" prop_intersectionWithModel
Expand Down Expand Up @@ -1140,6 +1141,12 @@ prop_withoutKeys m s0 = valid reduced .&&. (m `withoutKeys` s === filterWithKey
s = keysSet s0
reduced = withoutKeys m s

prop_partitionKeys :: IMap -> IMap -> Property
prop_partitionKeys m s0 = valid with .&&. valid without .&&. (m `partitionKeys` s === (m `restrictKeys` s, m `withoutKeys` s))
where
s = keysSet s0
(with, without) = partitionKeys m s

prop_intersection :: IMap -> IMap -> Bool
prop_intersection t1 t2 = valid (intersection t1 t2)

Expand Down
1 change: 1 addition & 0 deletions containers/containers.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ Library
Utils.Containers.Internal.PtrEquality
Utils.Containers.Internal.Coercions
Utils.Containers.Internal.EqOrdUtil
Utils.Containers.Internal.StrictTriple
if impl(ghc)
other-modules:
Utils.Containers.Internal.TypeError
Expand Down
49 changes: 47 additions & 2 deletions containers/src/Data/Map/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ module Data.Map.Internal (

, restrictKeys
, withoutKeys
, partitionKeys
, partition
, partitionWithKey

Expand Down Expand Up @@ -398,6 +399,7 @@ import qualified Data.Set.Internal as Set
import Data.Set.Internal (Set)
import Utils.Containers.Internal.PtrEquality (ptrEq)
import Utils.Containers.Internal.StrictPair
import Utils.Containers.Internal.StrictTriple
import Utils.Containers.Internal.StrictMaybe
import Utils.Containers.Internal.BitQueue
import Utils.Containers.Internal.EqOrdUtil (EqM(..), OrdM(..))
Expand Down Expand Up @@ -1966,6 +1968,51 @@ withoutKeys m (Set.Bin _ k ls rs) = case splitMember k m of
{-# INLINABLE withoutKeys #-}
#endif

-- | \(O\bigl(m \log\bigl(\frac{n}{m}+1\bigr)\bigr), \; 0 < m \leq n\). Partition the map according to a set.
-- The first map contains the input 'Map' restricted to those keys found in the 'Set',
-- the second map contains the input 'Map' without all keys in the 'Set'.
-- This is more efficient than using 'restrictKeys' and 'withoutKeys' together.
--
-- @
-- m \`partitionKeys\` s = (m ``restrictKeys`` s, m ``withoutKeys`` s)
-- @
partitionKeys :: Ord k => Map k a -> Set k -> (Map k a, Map k a)
partitionKeys xs ys =
case partitionKeysWorker xs ys of
xs' :*: ys' -> (xs', ys')
#if __GLASGOW_HASKELL__
{-# INLINABLE partitionKeys #-}
#endif

partitionKeysWorker :: Ord k => Map k a -> Set k -> StrictPair (Map k a) (Map k a)
partitionKeysWorker Tip _ = Tip :*: Tip
partitionKeysWorker m Set.Tip = Tip :*: m
partitionKeysWorker m@(Bin _ k x lm rm) [email protected]{} =
case b of
True -> with :*: without
where
with =
if lmWith `ptrEq` lm && rmWith `ptrEq` rm
then m
else link k x lmWith rmWith
without =
link2 lmWithout rmWithout
False -> with :*: without
where
with = link2 lmWith rmWith
without =
if lmWithout `ptrEq` lm && rmWithout `ptrEq` rm
then m
else link k x lmWithout rmWithout
where
!(lmWith :*: lmWithout) = partitionKeysWorker lm ls'
!(rmWith :*: rmWithout) = partitionKeysWorker rm rs'

!(!ls', b, !rs') = Set.splitMember k s
#if __GLASGOW_HASKELL__
{-# INLINABLE partitionKeysWorker #-}
#endif

-- | \(O(n+m)\). Difference with a combining function.
-- When two equal keys are
-- encountered, the combining function is applied to the values of these keys.
Expand Down Expand Up @@ -4004,8 +4051,6 @@ splitMember k0 m = case go k0 m of
{-# INLINABLE splitMember #-}
#endif

data StrictTriple a b c = StrictTriple !a !b !c

{--------------------------------------------------------------------
Utility functions that maintain the balance properties of the tree.
All constructors assume that all values in [l] < [k] and all values
Expand Down
1 change: 1 addition & 0 deletions containers/src/Data/Map/Lazy.hs
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ module Data.Map.Lazy (
, filterWithKey
, restrictKeys
, withoutKeys
, partitionKeys
, partition
, partitionWithKey
, takeWhileAntitone
Expand Down
1 change: 1 addition & 0 deletions containers/src/Data/Map/Strict.hs
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ module Data.Map.Strict
, filterWithKey
, restrictKeys
, withoutKeys
, partitionKeys
, partition
, partitionWithKey

Expand Down
4 changes: 3 additions & 1 deletion containers/src/Data/Map/Strict/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ module Data.Map.Strict.Internal
, filterWithKey
, restrictKeys
, withoutKeys
, partitionKeys
, partition
, partitionWithKey
, takeWhileAntitone
Expand Down Expand Up @@ -418,7 +419,8 @@ import Data.Map.Internal
, toDescList
, union
, unions
, withoutKeys )
, withoutKeys
, partitionKeys )

import Data.Map.Internal.Debug (valid)

Expand Down
27 changes: 17 additions & 10 deletions containers/src/Data/Set/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ import Data.List.NonEmpty (NonEmpty(..))
#endif

import Utils.Containers.Internal.StrictPair
import Utils.Containers.Internal.StrictTriple
import Utils.Containers.Internal.PtrEquality
import Utils.Containers.Internal.EqOrdUtil (EqM(..), OrdM(..))

Expand Down Expand Up @@ -1430,19 +1431,25 @@ splitS x (Bin _ y l r)
EQ -> (l :*: r)
{-# INLINABLE splitS #-}

splitMemberS :: Ord a => a -> Set a -> StrictTriple (Set a) Bool (Set a)
splitMemberS x = go
where
go Tip = StrictTriple Tip False Tip
go (Bin _ y l r) = case compare x y of
LT -> let StrictTriple lt found gt = splitMemberS x l
in StrictTriple lt found (link y gt r)
GT -> let StrictTriple lt found gt = splitMemberS x r
in StrictTriple (link y l lt) found gt
EQ -> StrictTriple l True r
#if __GLASGOW_HASKELL__
{-# INLINABLE splitMemberS #-}
#endif

-- | \(O(\log n)\). Performs a 'split' but also returns whether the pivot
-- element was found in the original set.
splitMember :: Ord a => a -> Set a -> (Set a,Bool,Set a)
splitMember _ Tip = (Tip, False, Tip)
splitMember x (Bin _ y l r)
= case compare x y of
LT -> let (lt, found, gt) = splitMember x l
!gt' = link y gt r
in (lt, found, gt')
GT -> let (lt, found, gt) = splitMember x r
!lt' = link y l lt
in (lt', found, gt)
EQ -> (l, True, r)
splitMember k0 s = case splitMemberS k0 s of
StrictTriple l b r -> (l, b, r)
#if __GLASGOW_HASKELL__
{-# INLINABLE splitMember #-}
#endif
Expand Down
15 changes: 15 additions & 0 deletions containers/src/Utils/Containers/Internal/StrictTriple.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
{-# LANGUAGE CPP #-}
#if !defined(TESTING) && defined(__GLASGOW_HASKELL__)
{-# LANGUAGE Safe #-}
#endif

-- | A strict triple

module Utils.Containers.Internal.StrictTriple (StrictTriple(..)) where

-- | The same as a regular Haskell tuple, but
--
-- @
-- StrictTriple x y _|_ = StrictTriple x _|_ z = StrictTriple _|_ y z = _|_
-- @
data StrictTriple a b c = StrictTriple !a !b !c