diff --git a/Changelog.md b/Changelog.md index 2d62e02..7926ddb 100644 --- a/Changelog.md +++ b/Changelog.md @@ -2,6 +2,12 @@ ## next -- *TBA* +## 2.1.6.0.100 + + * Updates `Data.Parameterized.TH.GADT.structuralEquality` to add type + assertions to cover all type parameters. This change may require the + addition of the `ScopedTypeVariables` pragma to modules importing this code. + ## 2.1.6.0 -- *2022 Dec 18* * Added `FinMap`: an integer map with a statically-known maximum size. diff --git a/parameterized-utils.cabal b/parameterized-utils.cabal index cd1be6e..886c3af 100644 --- a/parameterized-utils.cabal +++ b/parameterized-utils.cabal @@ -1,6 +1,6 @@ Cabal-version: 2.2 Name: parameterized-utils -Version: 2.1.6.0.99 +Version: 2.1.6.0.100 Author: Galois Inc. Maintainer: kquick@galois.com stability: stable diff --git a/src/Data/Parameterized/Classes.hs b/src/Data/Parameterized/Classes.hs index df1072a..45a387a 100644 --- a/src/Data/Parameterized/Classes.hs +++ b/src/Data/Parameterized/Classes.hs @@ -69,7 +69,7 @@ import Data.Type.Equality as Equality import Data.Parameterized.Compose () -- We define these type alias here to avoid importing Control.Lens --- modules, as this apparently causes problems with the safe Hasekll +-- modules, as this apparently causes problems with the safe Haskell -- checking. type Lens' s a = forall f. Functor f => (a -> f a) -> s -> f s type Traversal' s a = forall f. Applicative f => (a -> f a) -> s -> f s diff --git a/src/Data/Parameterized/TH/GADT.hs b/src/Data/Parameterized/TH/GADT.hs index 56573f6..c3c00b6 100644 --- a/src/Data/Parameterized/TH/GADT.hs +++ b/src/Data/Parameterized/TH/GADT.hs @@ -10,11 +10,12 @@ ------------------------------------------------------------------------ {-# LANGUAGE CPP #-} {-# LANGUAGE DoAndIfThenElse #-} +{-# LANGUAGE EmptyCase #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeOperators #-} -{-# LANGUAGE EmptyCase #-} module Data.Parameterized.TH.GADT ( -- * Instance generators -- $typePatterns @@ -40,15 +41,16 @@ module Data.Parameterized.TH.GADT , assocTypePats ) where -import Control.Monad -import Data.Maybe -import Data.Set (Set) +import Control.Monad +import Data.Function ( on ) +import Data.Maybe +import Data.Set (Set) import qualified Data.Set as Set -import Language.Haskell.TH -import Language.Haskell.TH.Datatype +import Language.Haskell.TH +import Language.Haskell.TH.Datatype -import Data.Parameterized.Classes +import Data.Parameterized.Classes ------------------------------------------------------------------------ -- Template Haskell utilities @@ -133,10 +135,72 @@ typeVars :: TypeSubstitution a => a -> Set Name typeVars = Set.fromList . freeVariables --- | @structuralEquality@ declares a structural equality predicate. +-- | @structuralEquality@ declares a structural equality predicate for a GADT. structuralEquality :: TypeQ -> [(TypePat,ExpQ)] -> ExpQ -structuralEquality tpq pats = - [| \x y -> isJust ($(structuralTypeEquality tpq pats) x y) |] +structuralEquality tpq pats = do + d <- reifyDatatype =<< asTypeCon "structuralEquality" =<< tpq + + -- tpq is some type of GADT: data X p1 p2 ... where ... + -- + -- The general approach is to generate a structural type equality such that the + -- result is a Maybe (e :+: f) is Just Refl and then verify it is a Just value + -- to assert equality by generating (via template haskell): + -- + -- \ x y -> isJust $(structuralTypeEquality ... x y) + -- + -- However, that result presumes a `TestEquality f where testEquality :: f a -> + -- f b -> Maybe (a :~: b)`. If the GADT has a single type parameter, those + -- types align and there is no problem. If the GADT has multiple type + -- variables, GHC is unsure of which we are making the TestEquality assertion + -- about and we need to help. We actually want to make that assertion over + -- _all_ of the parameters, so given: + -- + -- data D p1 p2 p3 where ... + -- + -- the template haskell here should generate: + -- + -- \ (x :: D xt1 xt2 xt3) (y :: D yt1 yt2 yt3) -> + -- isJust ( ($(structuralTypeEquality ... x y)) + -- :: Maybe ( '(xt1, xt2, xt3) :~: '(yt1, yt2, yt3) ) + -- ) + -- + -- This will perform the equality check in a way that obtains proof of equality + -- for all of the type parameters. This will require the ScopedTypeVariables + -- pragma, but GHC will happily suggest that if it's missing. + -- + -- This is also useful for the equality test on the single parameter case: + -- + -- data D p1 where ... + -- + -- instance Eq (D a) where + -- (==) = $(structuralEquality [t|D|] [] + -- + -- Again, this will fail without the template haskell assertion of the target + -- types matching the argument types. + + gadtParams <- return $ datatypeInstTypes d + arg1Params <- fmap varT <$> newNames "xTy" (length gadtParams) + arg2Params <- fmap varT <$> newNames "yTy" (length gadtParams) + let arg1Ty = foldl appT (conT $ datatypeName d) arg1Params + let arg2Ty = foldl appT (conT $ datatypeName d) arg2Params +#if MIN_VERSION_base(4,14,0) + let mkSuperTy tyList = foldl appT (promotedTupleT (length tyList)) tyList +#else + let mkSuperTy tyList = + if length tyList < 2 + then if length tyList == 0 + then error "Expected at least one type in structuralEquality" + else head tyList + else foldl appT (promotedTupleT (length tyList)) tyList +#endif + let arg1AllParamTy = mkSuperTy arg1Params + let arg2AllParamTy = mkSuperTy arg2Params + + [| \(x :: $(arg1Ty)) (y :: $(arg2Ty)) -> + isJust ($(structuralTypeEquality_ True tpq pats) x y + :: Maybe ($(arg1AllParamTy) :~: $(arg2AllParamTy)) + ) + |] joinEqMaybe :: Name -> Name -> ExpQ -> ExpQ joinEqMaybe x y r = do @@ -181,26 +245,44 @@ matchEqArguments _ _ _ _ _ _ [] = error "Unexpected end of names." mkSimpleEqF :: [Type] -- ^ Data declaration types -> Set Name -> [(TypePat,ExpQ)] -- ^ Patterns for matching arguments - -> ConstructorInfo + -> ConstructorInfo -- ^ The constructor we are concerned with -> [Name] -> ExpQ - -> Bool -- ^ wildcard case required - -> ExpQ -mkSimpleEqF dTypes bnd pats con xv yQ multipleCases = do + -> [ConstructorInfo] -- ^ All constructors (for determining if wildcard case required) + -> Bool -- ^ True if the equality arguments are the same type + -> ExpQ +mkSimpleEqF dTypes bnd pats con xv yQ multipleCases argsSameType = do -- Get argument types for constructor. let nm = constructorName con (yp,yv) <- conPat con "y" let rv = matchEqArguments dTypes pats nm bnd (constructorFields con) xv yv + let otherMatchingCons = + -- Determine the other constructors that should be matched relative to + -- `con`. If this is supplying code for `testEquality`, the input + -- signature is `f a -> f b -> ...` and will admit different types, so + -- all constructors should be checked, but if this is supplying code for + -- `Eq` or similar where the input signature is `a -> a -> ...` + -- (i.e. `argsSameType` is `True`), then only constructors that have the + -- same resulting type should be checked, otherwise GHC will emit + -- warnings/errors about "pattern not reached" for the case statement + -- being generated here. + let sameContext = (==) `on` constructorContext + in if argsSameType + then filter (sameContext con) multipleCases + else multipleCases caseE yQ $ match (pure yp) (normalB rv) [] - : [ match wildP (normalB [| Nothing |]) [] | multipleCases ] + : [ match wildP (normalB [| Nothing |]) [] + | 1 < length otherMatchingCons + ] -- | Match equational form. mkEqF :: DatatypeInfo -- ^ Data declaration. -> [(TypePat,ExpQ)] - -> ConstructorInfo + -> ConstructorInfo -- ^ Constructor for which equality is to be determined -> [Name] -> ExpQ - -> Bool -- ^ wildcard case required + -> [ConstructorInfo] -- ^ All constructors (for determining if wildcard case required) + -> Bool -- ^ True if the equality arguments are the same type -> ExpQ mkEqF d pats con = let dVars = dataParamTypes d -- the type arguments for the constructor @@ -216,12 +298,18 @@ mkEqF d pats con = -- forall x y . f x -> f y -> Maybe (x :~: y) -- @ structuralTypeEquality :: TypeQ -> [(TypePat,ExpQ)] -> ExpQ -structuralTypeEquality tpq pats = do +structuralTypeEquality = structuralTypeEquality_ False + +structuralTypeEquality_ :: Bool -> TypeQ -> [(TypePat,ExpQ)] -> ExpQ +structuralTypeEquality_ argsSameType tpq pats = do d <- reifyDatatype =<< asTypeCon "structuralTypeEquality" =<< tpq - let multipleCons = not (null (drop 1 (datatypeCons d))) + let multipleCons = datatypeCons d trueEqs yQ = [ do (xp,xv) <- conPat con "x" - match (pure xp) (normalB (mkEqF d pats con xv yQ multipleCons)) [] + match (pure xp) + (normalB + (mkEqF d pats con xv yQ multipleCons argsSameType)) + [] | con <- datatypeCons d ] diff --git a/test/Test/TH.hs b/test/Test/TH.hs index 9f4f12a..da51aa2 100644 --- a/test/Test/TH.hs +++ b/test/Test/TH.hs @@ -4,6 +4,7 @@ {-# LANGUAGE KindSignatures #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE PolyKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeApplications #-} @@ -13,14 +14,15 @@ module Test.TH ) where -import Test.Tasty -import Test.Tasty.HUnit +import Test.Tasty +import Test.Tasty.HUnit -import Control.Monad (when) -import Data.Parameterized.Classes -import Data.Parameterized.NatRepr -import Data.Parameterized.TH.GADT -import GHC.TypeNats +import Control.Monad (when) +import Data.Parameterized.Classes +import Data.Parameterized.NatRepr +import Data.Parameterized.SymbolRepr +import Data.Parameterized.TH.GADT +import GHC.TypeNats data T1 = A | B | C $(mkRepr ''T1) @@ -39,13 +41,41 @@ instance TestEquality T2Repr where [ (AnyType, [|testEquality|]) ]) deriving instance Show (T2Repr t) +data T3 (is_a :: Symbol) where + T3_Int :: Int -> T3 "int" + T3_Bool :: Bool -> T3 "bool" +$(return []) +instance TestEquality T3 where + testEquality = $(structuralTypeEquality [t|T3|] []) +instance Eq (T3 s) where + (==) = $(structuralEquality [t|T3|] []) +deriving instance Show (T3 s) + +data T4 b (is_a :: Symbol) where + T4_Int :: Int -> T4 b "int" + T4_Bool :: Bool -> T4 b "bool" +$(return []) +instance TestEquality (T4 b) where + testEquality = $(structuralTypeEquality [t|T4|] []) +instance Eq (T4 b s) where + (==) = $(structuralEquality [t|T4|] []) +deriving instance Show (T4 b s) + eqTest :: (TestEquality f, Show (f a), Show (f b)) => f a -> f b -> IO () eqTest a b = when (not (isJust (testEquality a b))) $ assertFailure $ show a ++ " /= " ++ show b neqTest :: (TestEquality f, Show (f a), Show (f b)) => f a -> f b -> IO () neqTest a b = - when (isJust (testEquality a b)) $ assertFailure $ show a ++ " == " ++ show b + when (isJust (testEquality a b)) + $ assertFailure + $ show a <> " == " <> show b <> " but should not be!" + +assertNotEqual :: (Eq a, Show a) => String -> a -> a -> IO () +assertNotEqual msg a b = + when (a == b) + $ assertFailure + $ msg <> " " <> show a <> " == " <> show b <> " but should not be!" thTests :: IO TestTree thTests = testGroup "TH" <$> return @@ -62,6 +92,28 @@ thTests = testGroup "TH" <$> return T2_2Repr (knownNat @5) `neqTest` T2_2Repr (knownNat @9) T2_1Repr BRepr `neqTest` T2_2Repr (knownNat @4) + , testCase "Instance tests" $ do + assertEqual "T3_Int values" (T3_Int 5) (T3_Int 5) + assertNotEqual "T3_Int values" (T3_Int 5) (T3_Int 54) + assertEqual "T3_Bool values" (T3_Bool True) (T3_Bool True) + assertNotEqual "T3_Bool values" (T3_Bool True) (T3_Bool False) + + -- n.b. the following is not possible: 'T3 "int"' is not a 'T3 "bool"' + -- assertEqual "T3_Int/T3_Bool values" (T3_Int 1) (T3_Bool True) + + T3_Int 1 `eqTest` T3_Int 1 + T3_Int 1 `neqTest` T3_Int 3 + T3_Int 1 `neqTest` T3_Bool True + T3_Bool False `neqTest` T3_Bool True + T3_Bool True `eqTest` T3_Bool True + + assertEqual "T4_Int values" (T4_Int @String 5) (T4_Int @String 5) + assertNotEqual "T4_Int values" (T4_Int @String 5) (T4_Int @String 54) + + T4_Int @String 1 `eqTest` T4_Int @String 1 + T4_Int @String 1 `neqTest` T4_Int @String 2 + + , testCase "KnownRepr test" $ do -- T1 let aRepr = knownRepr :: T1Repr 'A