Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add inlining pragma #254

Merged
merged 3 commits into from
Dec 21, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/Agda2Hs/Compile.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@ import qualified Data.Map as M
import Agda.Compiler.Backend
import Agda.Syntax.TopLevelModuleName ( TopLevelModuleName )
import Agda.TypeChecking.Pretty
import Agda.TypeChecking.Monad.Signature ( isInlineFun )
import Agda.Utils.Null
import Agda.Utils.Monad ( whenM )

import qualified Language.Haskell.Exts.Extension as Hs

import Agda2Hs.Compile.ClassInstance ( compileInstance )
import Agda2Hs.Compile.Data ( compileData )
import Agda2Hs.Compile.Function ( compileFun, checkTransparentPragma )
import Agda2Hs.Compile.Function ( compileFun, checkTransparentPragma, checkInlinePragma )
import Agda2Hs.Compile.Postulate ( compilePostulate )
import Agda2Hs.Compile.Record ( compileRecord, checkUnboxPragma )
import Agda2Hs.Compile.Types
Expand Down Expand Up @@ -91,6 +92,8 @@ compile opts tlm _ def = withCurrentModule (qnameModule $ defName def) $ runC tl
tag <$> compileFun True def
(DefaultPragma ds, _, Record{}) ->
tag . single <$> compileRecord (ToRecord ds) def
(InlinePragma, _, Function{}) -> do
checkInlinePragma def >> return []
_ ->
genericDocError =<< do
text "Don't know how to compile" <+> prettyTCM (defName def)
Expand Down
27 changes: 25 additions & 2 deletions src/Agda2Hs/Compile/Function.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE OverloadedStrings, ViewPatterns, NamedFieldPuns #-}
module Agda2Hs.Compile.Function where

import Control.Monad ( (>=>), filterM, forM_ )
Expand Down Expand Up @@ -26,7 +26,7 @@ import Agda.TypeChecking.Substitute
import Agda.TypeChecking.Telescope ( telView )
import Agda.TypeChecking.Sort ( ifIsSort )

import Agda.Utils.Functor ( (<&>) )
import Agda.Utils.Functor ( (<&>), dget)
import Agda.Utils.Impossible ( __IMPOSSIBLE__ )
import Agda.Utils.List
import Agda.Utils.Maybe
Expand Down Expand Up @@ -294,3 +294,26 @@ checkTransparentPragma def = compileFun False def >>= \case
errNotTransparent = genericDocError =<<
"Cannot make function" <+> prettyTCM (defName def) <+> "transparent." <+>
"A transparent function must have exactly one non-erased argument and return it unchanged."

checkInlinePragma :: Definition -> C ()
checkInlinePragma def@Defn{defName = f} = do
let Function{funClauses = cs} = theDef def
case filter (isJust . clauseBody) cs of
[c] -> do
let Clause{clauseTel,namedClausePats = naps} = c
unlessM (allM (dget . dget <$> naps) allowedPat) $ genericDocError =<<
"Cannot make function" <+> prettyTCM (defName def) <+> "inlinable." <+>
"Inline functions can only use variable patterns, dot patterns, or transparent record constructor patterns."
_ ->
genericDocError =<<
"Cannot make function" <+> prettyTCM f <+> "inlinable." <+>
"An inline function must have exactly one clause."
where allowedPat :: DeBruijnPattern -> C Bool
allowedPat VarP{} = pure True
allowedPat DotP{} = pure True
-- only allow matching on (unboxed) record constructors
allowedPat (ConP ch ci cargs) =
isUnboxConstructor (conName ch) >>= \case
Just _ -> allM cargs (allowedPat . dget . dget)
Nothing -> pure False
allowedPat _ = pure False
77 changes: 60 additions & 17 deletions src/Agda2Hs/Compile/Term.hs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
{-# LANGUAGE ViewPatterns, NamedFieldPuns #-}
module Agda2Hs.Compile.Term where

import Control.Arrow ( (>>>), (&&&) )
Expand All @@ -7,6 +8,7 @@ import Control.Monad.Reader
import Data.List ( isPrefixOf )
import Data.Maybe ( fromMaybe, isJust )
import qualified Data.Text as Text ( unpack )
import qualified Data.Set as Set ( singleton )

import qualified Language.Haskell.Exts as Hs

Expand All @@ -18,8 +20,8 @@ import Agda.Syntax.Internal

import Agda.TypeChecking.Monad
import Agda.TypeChecking.Pretty
import Agda.TypeChecking.Reduce ( instantiate )
import Agda.TypeChecking.Substitute ( Apply(applyE) )
import Agda.TypeChecking.Reduce ( unfoldDefinitionStep )
import Agda.TypeChecking.Substitute ( Apply(applyE), raise, mkAbs )

import Agda.Utils.Lens

Expand Down Expand Up @@ -228,16 +230,15 @@ compileTerm v = do
| Just semantics <- isSpecialTerm f -> do
reportSDoc "agda2hs.compile.term" 12 $ text "Compiling application of special function"
semantics f es
| otherwise -> isClassFunction f >>= \case
True -> compileClassFunApp f es
False -> (isJust <$> isUnboxProjection f) `or2M` isTransparentFunction f >>= \case
True -> compileErasedApp es
False -> do
reportSDoc "agda2hs.compile.term" 12 $ text "Compiling application of regular function"
-- Drop module parameters of local `where` functions
moduleArgs <- getDefFreeVars f
reportSDoc "agda2hs.compile.term" 15 $ text "Module arguments for" <+> (prettyTCM f <> text ":") <+> prettyTCM moduleArgs
(`app` drop moduleArgs es) . Hs.Var () =<< compileQName f
| otherwise ->
ifM (isClassFunction f) (compileClassFunApp f es) $
ifM ((isJust <$> isUnboxProjection f) `or2M` isTransparentFunction f) (compileErasedApp es) $
ifM (isInlinedFunction f) (compileInlineFunctionApp f es) $ do
reportSDoc "agda2hs.compile.term" 12 $ text "Compiling application of regular function"
-- Drop module parameters of local `where` functions
moduleArgs <- getDefFreeVars f
reportSDoc "agda2hs.compile.term" 15 $ text "Module arguments for" <+> (prettyTCM f <> text ":") <+> prettyTCM moduleArgs
(`app` drop moduleArgs es) . Hs.Var () =<< compileQName f
Con h i es -> do
reportSDoc "agda2hs.compile" 8 $ text "reached constructor:" <+> prettyTCM (conName h)
-- the constructor may be a copy introduced by module application,
Expand Down Expand Up @@ -281,14 +282,56 @@ compileTerm v = do
Just _ -> compileErasedApp es
Nothing -> (`app` es) . Hs.Con () =<< compileQName (conName h)

-- `compileErasedApp` compiles an application of an erased constructor
-- or projection.
-- `compileErasedApp` compiles an application of an unboxed constructor
-- or unboxed projection or transparent function.
-- Precondition is that at most one elim is preserved.
compileErasedApp :: Elims -> C (Hs.Exp ())
compileErasedApp es = do
reportSDoc "agda2hs.compile.term" 12 $ text "Compiling application of erased function"
reportSDoc "agda2hs.compile.term" 12 $ text "Compiling application of transparent function or erased unboxed constructor"
compileElims es >>= \case
[] -> return $ hsVar "id"
(v:vs) -> return $ v `eApp` vs
[] -> return $ hsVar "id"
[v] -> return v
_ -> __IMPOSSIBLE__

-- | Compile the application of a function definition marked as inlinable.
-- The provided arguments will get substituted in the function body, and the missing arguments
-- will get quantified with lambdas.
compileInlineFunctionApp :: QName -> Elims -> C (Hs.Exp ())
compileInlineFunctionApp f es = do
reportSDoc "agda2hs.compile.term" 12 $ text "Compiling application of inline function"
Function { funClauses = cs } <- theDef <$> getConstInfo f
let [ Clause { namedClausePats = pats
, clauseBody = Just body
, clauseTel
} ] = filter (isJust . clauseBody) cs
etaExpand (drop (length es) pats) es >>= compileTerm
where
-- inline functions can only have transparent constructor patterns and variable patterns
extractPatName :: DeBruijnPattern -> ArgName
extractPatName (VarP _ v) = dbPatVarName v
extractPatName (ConP _ _ args) =
let arg = namedThing $ unArg $ head $ filter (usableModality `and2M` visible) args
in extractPatName arg
extractPatName _ = __IMPOSSIBLE__

extractName :: NamedArg DeBruijnPattern -> ArgName
extractName (unArg -> np)
| Just n <- nameOf np = rangedThing (woThing n)
| otherwise = extractPatName (namedThing np)

etaExpand :: NAPs -> Elims -> C Term
etaExpand [] es = do
r <- liftReduce
$ locallyReduceDefs (OnlyReduceDefs $ Set.singleton f)
$ unfoldDefinitionStep (Def f es) f es
case r of
YesReduction _ t -> pure t
_ -> genericDocError =<< text "Could not reduce inline function" <+> prettyTCM f

etaExpand (p:ps) es =
let ai = argInfo p in
Lam ai . mkAbs (extractName p)
<$> etaExpand ps (raise 1 es ++ [ Apply $ Arg ai $ var 0 ])

-- `compileClassFunApp` is used when we have a record projection and we want to
-- drop the first visible arg (the record)
Expand Down
27 changes: 24 additions & 3 deletions src/Agda2Hs/Compile/Type.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeApplications, NamedFieldPuns #-}

module Agda2Hs.Compile.Type where

Expand All @@ -8,6 +8,7 @@ import Control.Monad.Trans ( lift )
import Control.Monad.Reader ( asks )
import Data.List ( find )
import Data.Maybe ( mapMaybe, isJust )
import qualified Data.Set as Set ( singleton )

import qualified Language.Haskell.Exts.Syntax as Hs
import qualified Language.Haskell.Exts.Extension as Hs
Expand All @@ -20,7 +21,7 @@ import Agda.Syntax.Internal
import Agda.Syntax.Common.Pretty ( prettyShow )

import Agda.TypeChecking.Pretty
import Agda.TypeChecking.Reduce ( reduce )
import Agda.TypeChecking.Reduce ( reduce, unfoldDefinitionStep )
import Agda.TypeChecking.Substitute
import Agda.TypeChecking.Telescope

Expand Down Expand Up @@ -150,7 +151,8 @@ compileType t = do
| Just semantics <- isSpecialType f -> setCurrentRange f $ semantics f es
| Just args <- allApplyElims es ->
ifJustM (isUnboxRecord f) (\_ -> compileUnboxType f args) $
ifM (isTransparentFunction f) (compileTransparentType args) $ do
ifM (isTransparentFunction f) (compileTransparentType args) $
ifM (isInlinedFunction f) (compileInlineType f es) $ do
vs <- compileTypeArgs args
f <- compileQName f
return $ tApp (Hs.TyCon () f) vs
Expand Down Expand Up @@ -181,6 +183,25 @@ compileTransparentType args = compileTypeArgs args >>= \case
[] -> __IMPOSSIBLE__
(v:vs) -> return $ v `tApp` vs

compileInlineType :: QName -> Elims -> C (Hs.Type ())
compileInlineType f args = do
Function { funClauses = cs } <- theDef <$> getConstInfo f

let [ Clause { namedClausePats = pats
, clauseBody = Just body
, clauseTel
} ] = filter (isJust . clauseBody) cs

when (length args < length pats) $ genericDocError =<<
text "Cannot compile inlinable type alias" <+> prettyTCM f <+> text "as it must be fully applied."

r <- liftReduce $ locallyReduceDefs (OnlyReduceDefs $ Set.singleton f)
$ unfoldDefinitionStep (Def f args) f args

case r of
YesReduction _ t -> compileType t
_ -> genericDocError =<< text "Could not reduce inline function" <+> prettyTCM f

compileDom :: ArgName -> Dom Type -> C CompiledDom
compileDom x a
| usableModality a = case getHiding a of
Expand Down
11 changes: 8 additions & 3 deletions src/Agda2Hs/Compile/Utils.hs
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,14 @@ isTransparentFunction :: QName -> C Bool
isTransparentFunction q = do
getConstInfo q >>= \case
Defn{defName = r, theDef = Function{}} ->
processPragma r <&> \case
TransparentPragma -> True
_ -> False
(TransparentPragma ==) <$> processPragma r
_ -> return False

isInlinedFunction :: QName -> C Bool
isInlinedFunction q = do
getConstInfo q >>= \case
Defn{defName = r, theDef = Function{}} ->
(InlinePragma ==) <$> processPragma r
_ -> return False

checkInstance :: Term -> C ()
Expand Down
2 changes: 1 addition & 1 deletion src/Agda2Hs/HsUtils.hs
Original file line number Diff line number Diff line change
Expand Up @@ -267,4 +267,4 @@ patToExp = \case
_ -> Nothing

data Strictness = Lazy | Strict
deriving Show
deriving (Eq, Show)
4 changes: 3 additions & 1 deletion src/Agda2Hs/Pragma.hs
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,15 @@ getForeignPragmas exts = do

data ParsedPragma
= NoPragma
| InlinePragma
| DefaultPragma [Hs.Deriving ()]
| ClassPragma [String]
| ExistingClassPragma
| UnboxPragma Strictness
| TransparentPragma
| NewTypePragma [Hs.Deriving ()]
| DerivePragma (Maybe (Hs.DerivStrategy ()))
deriving Show
deriving (Eq, Show)

derivePragma :: String
derivePragma = "derive"
Expand Down Expand Up @@ -85,6 +86,7 @@ processPragma qn = liftTCM (getUniqueCompilerPragma pragmaName qn) >>= \case
Nothing -> return NoPragma
Just (CompilerPragma _ s)
| "class" `isPrefixOf` s -> return $ ClassPragma (words $ drop 5 s)
| s == "inline" -> return InlinePragma
| s == "existing-class" -> return ExistingClassPragma
| s == "unboxed" -> return $ UnboxPragma Lazy
| s == "unboxed-strict" -> return $ UnboxPragma Strict
Expand Down
2 changes: 2 additions & 0 deletions test/AllTests.agda
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ import Issue210
import ModuleParameters
import ModuleParametersImports
import Coerce
import Inlining

{-# FOREIGN AGDA2HS
import Issue14
Expand Down Expand Up @@ -126,4 +127,5 @@ import Issue210
import ModuleParameters
import ModuleParametersImports
import Coerce
import Inlining
#-}
8 changes: 8 additions & 0 deletions test/Fail/Inline.agda
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
module Fail.Inline where

open import Haskell.Prelude

tail' : List a → List a
tail' (x ∷ xs) = xs
tail' [] = []
{-# COMPILE AGDA2HS tail' inline #-}
7 changes: 7 additions & 0 deletions test/Fail/Inline2.agda
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
module Fail.Inline2 where

open import Haskell.Prelude

tail' : (xs : List a) → @0 {{ NonEmpty xs }} → List a
tail' (x ∷ xs) = xs
{-# COMPILE AGDA2HS tail' inline #-}
43 changes: 43 additions & 0 deletions test/Inlining.agda
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
module Inlining where

open import Haskell.Prelude

Alias : Set
Alias = Bool
{-# COMPILE AGDA2HS Alias inline #-}

aliased : Alias
aliased = True
{-# COMPILE AGDA2HS aliased #-}

record Wrap (a : Set) : Set where
constructor Wrapped
field
unwrap : a
open Wrap public
{-# COMPILE AGDA2HS Wrap unboxed #-}

mapWrap : (f : a → b) → Wrap a → Wrap b
mapWrap f (Wrapped x) = Wrapped (f x)
{-# COMPILE AGDA2HS mapWrap inline #-}

mapWrap2 : (f : a → b → c) → Wrap a → Wrap b → Wrap c
mapWrap2 f (Wrapped x) (Wrapped y) = Wrapped (f x y)
{-# COMPILE AGDA2HS mapWrap2 inline #-}

test1 : Wrap Int → Wrap Int
test1 x = mapWrap (1 +_) x
{-# COMPILE AGDA2HS test1 #-}

test2 : Wrap Int → Wrap Int → Wrap Int
test2 x y = mapWrap2 _+_ x y
{-# COMPILE AGDA2HS test2 #-}

-- partial application of inline function
test3 : Wrap Int → Wrap Int → Wrap Int
test3 x = mapWrap2 _+_ x
{-# COMPILE AGDA2HS test3 #-}

test4 : Wrap Int → Wrap Int → Wrap Int
test4 = mapWrap2 _+_
{-# COMPILE AGDA2HS test4 #-}
1 change: 1 addition & 0 deletions test/golden/AllTests.hs
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,5 @@ import Issue210
import ModuleParameters
import ModuleParametersImports
import Coerce
import Inlining

4 changes: 0 additions & 4 deletions test/golden/Haskell/Extra/Dec.hs

This file was deleted.

2 changes: 2 additions & 0 deletions test/golden/Inline.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
test/Fail/Inline.agda:5,1-6
Cannot make function tail' inlinable. An inline function must have exactly one clause.
2 changes: 2 additions & 0 deletions test/golden/Inline2.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
test/Fail/Inline2.agda:5,1-6
Cannot make function tail' inlinable. Inline functions can only use variable patterns, dot patterns, or transparent record constructor patterns.
flupe marked this conversation as resolved.
Show resolved Hide resolved
17 changes: 17 additions & 0 deletions test/golden/Inlining.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
module Inlining where

aliased :: Bool
aliased = True

test1 :: Int -> Int
test1 x = 1 + x

test2 :: Int -> Int -> Int
test2 x y = x + y

test3 :: Int -> Int -> Int
test3 x = \ y -> x + y

test4 :: Int -> Int -> Int
test4 = \ x y -> x + y

Loading