Skip to content

Commit

Permalink
add inlining pragma
Browse files Browse the repository at this point in the history
  • Loading branch information
flupe committed Dec 15, 2023
1 parent e4bca58 commit fa97dca
Show file tree
Hide file tree
Showing 10 changed files with 132 additions and 24 deletions.
5 changes: 4 additions & 1 deletion src/Agda2Hs/Compile.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@ import qualified Data.Map as M
import Agda.Compiler.Backend
import Agda.Syntax.TopLevelModuleName ( TopLevelModuleName )
import Agda.TypeChecking.Pretty
import Agda.TypeChecking.Monad.Signature ( isInlineFun )
import Agda.Utils.Null
import Agda.Utils.Monad ( whenM )

import qualified Language.Haskell.Exts.Extension as Hs

import Agda2Hs.Compile.ClassInstance ( compileInstance )
import Agda2Hs.Compile.Data ( compileData )
import Agda2Hs.Compile.Function ( compileFun, checkTransparentPragma )
import Agda2Hs.Compile.Function ( compileFun, checkTransparentPragma, checkInlinePragma )
import Agda2Hs.Compile.Postulate ( compilePostulate )
import Agda2Hs.Compile.Record ( compileRecord, checkUnboxPragma )
import Agda2Hs.Compile.Types
Expand Down Expand Up @@ -91,6 +92,8 @@ compile opts tlm _ def = withCurrentModule (qnameModule $ defName def) $ runC tl
tag <$> compileFun True def
(DefaultPragma ds, _, Record{}) ->
tag . single <$> compileRecord (ToRecord ds) def
(InlinePragma, _, Function{}) -> do
checkInlinePragma def >> return []
_ ->
genericDocError =<< do
text "Don't know how to compile" <+> prettyTCM (defName def)
Expand Down
17 changes: 16 additions & 1 deletion src/Agda2Hs/Compile/Function.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE OverloadedStrings, ViewPatterns, NamedFieldPuns #-}
module Agda2Hs.Compile.Function where

import Control.Monad ( (>=>), filterM, forM_ )
Expand Down Expand Up @@ -287,3 +287,18 @@ checkTransparentPragma def = compileFun False def >>= \case
errNotTransparent = genericDocError =<<
"Cannot make function" <+> prettyTCM (defName def) <+> "transparent." <+>
"A transparent function must have exactly one non-erased argument and return it unchanged."

checkInlinePragma :: Definition -> C ()
checkInlinePragma def@(theDef -> Function{funClauses = [c]}) = do
let Clause{clauseTel,namedClausePats = pats} = c
cpats <- addContext (KeepNames clauseTel) $ compilePats pats
unless (all allowedPat cpats) $ genericDocError =<<
"Cannot make function" <+> prettyTCM (defName def) <+> "inlinable." <+>
"Inline functions can only use variable patterns, wildcard patterns, or transparent record constructor patterns."
where allowedPat :: Hs.Pat () -> Bool
allowedPat (Hs.PWildCard ()) = True
allowedPat (Hs.PVar () _) = True
allowedPat _ = False
checkInlinePragma (defName -> f) = genericDocError =<<
"Cannot make function" <+> prettyTCM f <+> "inlinable." <+>
"An inline function must have exactly one clause."
65 changes: 48 additions & 17 deletions src/Agda2Hs/Compile/Term.hs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
{-# LANGUAGE ViewPatterns #-}
module Agda2Hs.Compile.Term where

import Control.Arrow ( (>>>), (&&&) )
Expand All @@ -7,6 +8,7 @@ import Control.Monad.Reader
import Data.List ( isPrefixOf )
import Data.Maybe ( fromMaybe, isJust )
import qualified Data.Text as Text ( unpack )
import qualified Data.Set as Set (singleton)

import qualified Language.Haskell.Exts as Hs

Expand All @@ -18,8 +20,8 @@ import Agda.Syntax.Internal

import Agda.TypeChecking.Monad
import Agda.TypeChecking.Pretty
import Agda.TypeChecking.Reduce ( instantiate )
import Agda.TypeChecking.Substitute ( Apply(applyE) )
import Agda.TypeChecking.Reduce ( instantiate, reduce, reduceHead )
import Agda.TypeChecking.Substitute ( Apply(applyE), raise, mkAbs )

import Agda.Utils.Lens

Expand Down Expand Up @@ -233,16 +235,14 @@ compileTerm v = do
| Just semantics <- isSpecialTerm f -> do
reportSDoc "agda2hs.compile.term" 12 $ text "Compiling application of special function"
semantics f es
| otherwise -> isClassFunction f >>= \case
True -> compileClassFunApp f es
False -> (isJust <$> isUnboxProjection f) `or2M` isTransparentFunction f >>= \case
True -> compileErasedApp es
False -> do
reportSDoc "agda2hs.compile.term" 12 $ text "Compiling application of regular function"
-- Drop module parameters of local `where` functions
moduleArgs <- getDefFreeVars f
reportSDoc "agda2hs.compile.term" 15 $ text "Module arguments for" <+> (prettyTCM f <> text ":") <+> prettyTCM moduleArgs
(`app` drop moduleArgs es) . Hs.Var () =<< compileQName f
| otherwise -> ifM (isClassFunction f) (compileClassFunApp f es) $ do
ifM ((isJust <$> isUnboxProjection f) `or2M` isTransparentFunction f) (compileErasedApp es) $ do
ifM (isInlinedFunction f) (compileInlineFunctionApp f es) $ do
reportSDoc "agda2hs.compile.term" 12 $ text "Compiling application of regular function"
-- Drop module parameters of local `where` functions
moduleArgs <- getDefFreeVars f
reportSDoc "agda2hs.compile.term" 15 $ text "Module arguments for" <+> (prettyTCM f <> text ":") <+> prettyTCM moduleArgs
(`app` drop moduleArgs es) . Hs.Var () =<< compileQName f
Con h i es
| Just semantics <- isSpecialCon (conName h) -> semantics h i es
Con h i es -> isUnboxConstructor (conName h) >>= \case
Expand Down Expand Up @@ -273,14 +273,45 @@ compileTerm v = do
app :: Hs.Exp () -> Elims -> C (Hs.Exp ())
app hd es = eApp <$> pure hd <*> compileElims es

-- `compileErasedApp` compiles an application of an erased constructor
-- or projection.
-- `compileErasedApp` compiles an application of an unboxed constructor
-- or unboxed projection or transparent function.
-- Precondition is that at most one elim is preserved.
compileErasedApp :: Elims -> C (Hs.Exp ())
compileErasedApp es = do
reportSDoc "agda2hs.compile.term" 12 $ text "Compiling application of erased function"
reportSDoc "agda2hs.compile.term" 12 $ text "Compiling application of transparent function or erased unboxed constructor"
compileElims es >>= \case
[] -> return $ hsVar "id"
(v:vs) -> return $ v `eApp` vs
[] -> return $ hsVar "id"
[v] -> return v
_ -> __IMPOSSIBLE__

-- | Compile the application of a function definition marked as inlinable.
-- The provided arguments will get substituted in the function body, and the missing arguments
-- will get quantified with lambdas.
compileInlineFunctionApp :: QName -> Elims -> C (Hs.Exp ())
compileInlineFunctionApp f es = do
reportSDoc "agda2hs.compile.term" 12 $ text "Compiling application of inline function"
Function{funClauses = [namedClausePats -> pats]} <- theDef <$> getConstInfo f
etaExpand (drop (length es) pats) es >>= compileTerm
where
-- inline functions can only have transparent constructor patterns and variable patterns
extractPatName :: DeBruijnPattern -> ArgName
extractPatName (VarP _ v) = dbPatVarName v
extractPatName (ConP _ _ args) =
let arg = namedThing $ unArg $ head $ filter (usableModality `and2M` visible) args
in extractPatName arg
extractPatName _ = __IMPOSSIBLE__

extractName :: NamedArg DeBruijnPattern -> ArgName
extractName (unArg -> np)
| Just n <- nameOf np = rangedThing (woThing n)
| otherwise = extractPatName (namedThing np)

etaExpand :: NAPs -> Elims -> C Term
etaExpand [] es = locallyReduceDefs (OnlyReduceDefs $ Set.singleton f) $ reduce (Def f es)
etaExpand (p:ps) es =
let ai = argInfo p in
Lam ai . mkAbs (extractName p)
<$> etaExpand ps (raise 1 es ++ [ Apply $ Arg ai $ var 0 ])

-- `compileClassFunApp` is used when we have a record projection and we want to
-- drop the first visible arg (the record)
Expand Down
11 changes: 8 additions & 3 deletions src/Agda2Hs/Compile/Utils.hs
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,14 @@ isTransparentFunction :: QName -> C Bool
isTransparentFunction q = do
getConstInfo q >>= \case
Defn{defName = r, theDef = Function{}} ->
processPragma r <&> \case
TransparentPragma -> True
_ -> False
(TransparentPragma ==) <$> processPragma r
_ -> return False

isInlinedFunction :: QName -> C Bool
isInlinedFunction q = do
getConstInfo q >>= \case
Defn{defName = r, theDef = Function{}} ->
(InlinePragma ==) <$> processPragma r
_ -> return False

checkInstance :: Term -> C ()
Expand Down
2 changes: 1 addition & 1 deletion src/Agda2Hs/HsUtils.hs
Original file line number Diff line number Diff line change
Expand Up @@ -267,4 +267,4 @@ patToExp = \case
_ -> Nothing

data Strictness = Lazy | Strict
deriving Show
deriving (Eq, Show)
4 changes: 3 additions & 1 deletion src/Agda2Hs/Pragma.hs
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,15 @@ getForeignPragmas exts = do

data ParsedPragma
= NoPragma
| InlinePragma
| DefaultPragma [Hs.Deriving ()]
| ClassPragma [String]
| ExistingClassPragma
| UnboxPragma Strictness
| TransparentPragma
| NewTypePragma [Hs.Deriving ()]
| DerivePragma (Maybe (Hs.DerivStrategy ()))
deriving Show
deriving (Eq, Show)

derivePragma :: String
derivePragma = "derive"
Expand Down Expand Up @@ -85,6 +86,7 @@ processPragma qn = liftTCM (getUniqueCompilerPragma pragmaName qn) >>= \case
Nothing -> return NoPragma
Just (CompilerPragma _ s)
| "class" `isPrefixOf` s -> return $ ClassPragma (words $ drop 5 s)
| s == "inline" -> return InlinePragma
| s == "existing-class" -> return ExistingClassPragma
| s == "unboxed" -> return $ UnboxPragma Lazy
| s == "unboxed-strict" -> return $ UnboxPragma Strict
Expand Down
2 changes: 2 additions & 0 deletions test/AllTests.agda
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ import IOInput
import Issue200
import Issue169
import Issue210
import Inlining

{-# FOREIGN AGDA2HS
import Issue14
Expand Down Expand Up @@ -120,4 +121,5 @@ import IOInput
import Issue200
import Issue169
import Issue210
import Inlining
#-}
35 changes: 35 additions & 0 deletions test/Inlining.agda
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
module Inlining where

open import Haskell.Prelude

record Wrap (a : Set) : Set where
constructor Wrapped
field
unwrap : a
open Wrap public
{-# COMPILE AGDA2HS Wrap unboxed #-}

mapWrap : (f : a b) Wrap a Wrap b
mapWrap f (Wrapped x) = Wrapped (f x)
{-# COMPILE AGDA2HS mapWrap inline #-}

mapWrap2 : (f : a b c) Wrap a Wrap b Wrap c
mapWrap2 f (Wrapped x) (Wrapped y) = Wrapped (f x y)
{-# COMPILE AGDA2HS mapWrap2 inline #-}

test1 : Wrap Int Wrap Int
test1 x = mapWrap (1 +_) x
{-# COMPILE AGDA2HS test1 #-}

test2 : Wrap Int Wrap Int Wrap Int
test2 x y = mapWrap2 _+_ x y
{-# COMPILE AGDA2HS test2 #-}

-- partial application of inline function
test3 : Wrap Int Wrap Int Wrap Int
test3 x = mapWrap2 _+_ x
{-# COMPILE AGDA2HS test3 #-}

test4 : Wrap Int Wrap Int Wrap Int
test4 = mapWrap2 _+_
{-# COMPILE AGDA2HS test4 #-}
1 change: 1 addition & 0 deletions test/golden/AllTests.hs
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,5 @@ import IOInput
import Issue200
import Issue169
import Issue210
import Inlining

14 changes: 14 additions & 0 deletions test/golden/Inlining.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
module Inlining where

test1 :: Int -> Int
test1 x = 1 + x

test2 :: Int -> Int -> Int
test2 x y = x + y

test3 :: Int -> Int -> Int
test3 x = \ y -> x + y

test4 :: Int -> Int -> Int
test4 = \ x y -> x + y

0 comments on commit fa97dca

Please sign in to comment.