Skip to content

Commit

Permalink
handle multiple named function arguments - named arguments, single cl…
Browse files Browse the repository at this point in the history
…ause
  • Loading branch information
lemastero committed May 5, 2024
1 parent bed63b6 commit c868641
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 49 deletions.
30 changes: 24 additions & 6 deletions examples/adts.agda
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
module examples.adts where

-- simple product type no arguments - sealed trait + case objects

-- simple sum type no arguments - sealed trait + case objects
data Rgb : Set where
Red : Rgb
Green : Rgb
Expand All @@ -13,11 +12,12 @@ data Bool : Set where
False : Bool
{-# COMPILE AGDA2SCALA Bool #-}

-- trivial function with single argument
-- simple sum type with arguments - sealed trait + case class

idRgb : Rgb -> Rgb
idRgb x = x
{-# COMPILE AGDA2SCALA idRgb #-}
data Color : Set where
Light : Rgb -> Color
Dark : Rgb -> Color
-- TODO {-# COMPILE AGDA2SCALA Color #-}

-- simple sum type - case class

Expand All @@ -27,3 +27,21 @@ record RgbPair : Set where
fst : Rgb
snd : Bool
{-# COMPILE AGDA2SCALA RgbPair #-}

-- trivial function with single argument

idRgb : Rgb -> Rgb
idRgb theArg = theArg
{-# COMPILE AGDA2SCALA idRgb #-}

-- const function with one named argument

rgbConstTrue1 : (rgb : Rgb) Bool
rgbConstTrue1 foo = True
{-# COMPILE AGDA2SCALA rgbConstTrue1 #-}

-- function with multiple named arguments

and0 : (rgbPairArg : RgbPair) -> (rgbArg : Rgb) -> Bool
and0 a b = False
{-# COMPILE AGDA2SCALA and0 #-}
8 changes: 6 additions & 2 deletions examples/adts.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@ sealed trait Bool
case object True extends Bool
case object False extends Bool

def idRgb(x: Rgb): Rgb = x

final case class RgbPair(snd: Bool, fst: Rgb)

def idRgb(theArg: Rgb): Rgb = theArg

def rgbConstTrue1(rgb: Rgb): Bool = foo

def and0(rgbArg: Rgb, rgbPairArg: RgbPair): Bool = a
}
95 changes: 56 additions & 39 deletions src/Agda/Compiler/Scala/AgdaToScalaExpr.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
module Agda.Compiler.Scala.AgdaToScalaExpr (
compileDefn
) where
module Agda.Compiler.Scala.AgdaToScalaExpr ( compileDefn ) where

import Agda.Compiler.Backend ( funCompiled, funClauses, Defn(..), RecordData(..))
import Agda.Syntax.Abstract.Name ( QName )
Expand All @@ -22,8 +20,8 @@ compileDefn :: QName -> Defn -> ScalaExpr
compileDefn defName theDef = case theDef of
Datatype{dataCons = dataCons} ->
compileDataType defName dataCons
Function{funCompiled = funDef, funClauses = fc} ->
compileFunction defName funDef fc
Function{funCompiled = funCompiled, funClauses = funClauses} ->
compileFunction defName funCompiled funClauses
RecordDefn(RecordData{_recFields = recFields, _recTel = recTel}) ->
compileRecord defName recFields recTel
other ->
Expand All @@ -42,32 +40,52 @@ compileFunction :: QName
-> Maybe CompiledClauses
-> [Clause]
-> ScalaExpr
compileFunction defName funDef fc =
compileFunction defName funCompiled funClauses =
SeFun
(fromQName defName)
[SeVar (compileFunctionArgument fc) (compileFunctionArgType fc)] -- TODO many function arguments
(compileFunctionResultType fc)
(compileFunctionBody funDef)

compileFunctionArgument :: [Clause] -> ScalaName
compileFunctionArgument [] = ""
compileFunctionArgument [fc] = fromDeBruijnPattern (namedThing (unArg (head (namedClausePats fc))))
compileFunctionArgument xs = error "unsupported compileFunctionArgument" ++ (show xs) -- show xs

compileFunctionArgType :: [Clause] -> ScalaType
compileFunctionArgType [ Clause{clauseTel = ct} ] = fromTelescope ct
compileFunctionArgType xs = error "unsupported compileFunctionArgType" ++ (show xs)

fromTelescope :: Telescope -> ScalaName -- TODO PP probably parent should be different, use fold on telescope above
fromTelescope tel = case tel of
ExtendTel a _ -> fromDom a
other -> error ("unhandled fromType" ++ show other)
(fromQName defName) -- ++ "\n FULL FUNCTION DEFINITION \n[\n" ++ (show theDef) ++ "\n]\n")
(funArgs funClauses)
(compileFunctionResultType funClauses)
-- you can get body of the function using:
-- - FunctionData _funCompiled
-- - FunctionData _funClauses Clause clauseBody
-- see:
-- https://hackage.haskell.org/package/Agda-2.6.4.3/docs/Agda-TypeChecking-Monad-Base.html#t:FunctionData
-- https://hackage.haskell.org/package/Agda/docs/Agda-Syntax-Internal.html#t:Clause
-- at this point both contain the same info (at least in simple cases)
(compileFunctionBody funCompiled)

funArgs :: [Clause] -> [SeVar]
funArgs [] = []
funArgs (c : cs) = funArgsFromClause c

funArgsFromClause :: Clause -> [SeVar]
funArgsFromClause c@Clause{clauseTel = clauseTel} = case parsedArgs of
[(SeVar "" varType)] -> [SeVar (hackyFunArgNameFromClause c) varType]
args -> args
where
parsedArgs = foldl varsFromTelescope [] clauseTel

-- this is extremely hacky way to get function argument name
-- for identity function
-- I apparently do not understand enough how this works
-- or perhaps this is bug in Agda compiler :)
hackyFunArgNameFromClause :: Clause -> ScalaName
hackyFunArgNameFromClause fc = hackyFunArgNameFromDeBruijnPattern (namedThing (unArg
(head -- TODO perhaps iterate here

Check warning on line 74 in src/Agda/Compiler/Scala/AgdaToScalaExpr.hs

View workflow job for this annotation

GitHub Actions / agda2scala

In the use of ‘head’
(namedClausePats fc))))

hackyFunArgNameFromDeBruijnPattern :: DeBruijnPattern -> ScalaName
hackyFunArgNameFromDeBruijnPattern d = case d of
VarP a b -> (dbPatVarName b)
a@(ConP x y z) -> "\n hackyFunArgNameFromDeBruijnPattern \n[\n" ++ show a ++ "\n]\n"
other -> error ("hackyFunArgNameFromDeBruijnPattern " ++ show other)

nameFromDom :: Dom Type -> ScalaName
nameFromDom dt = case (domName dt) of
Nothing -> error ("nameFromDom" ++ show dt)
Nothing -> ""
Just a -> namedNameToStr a

-- https://hackage.haskell.org/package/Agda-2.6.4.3/docs/Agda-Syntax-Common.html#t:NamedName
namedNameToStr :: NamedName -> ScalaName
namedNameToStr n = rangedThing (woThing n)

Expand All @@ -76,39 +94,38 @@ fromDom x = fromType (unDom x)

compileFunctionResultType :: [Clause] -> ScalaType
compileFunctionResultType [Clause{clauseType = ct}] = fromMaybeType ct
compileFunctionResultType other = error ("unhandled compileFunctionResultType" ++ show other)
compileFunctionResultType (Clause{clauseType = ct} : xs) = fromMaybeType ct
compileFunctionResultType other = error "Fatal error - function has not clause."

fromMaybeType :: Maybe (Arg Type) -> ScalaName
fromMaybeType (Just argType) = fromArgType argType
fromMaybeType other = error ("unhandled fromMaybeType" ++ show other)
fromMaybeType other = error ("\nunhandled fromMaybeType \n[" ++ show other ++ "]\n")

fromArgType :: Arg Type -> ScalaName
fromArgType arg = fromType (unArg arg)

fromType :: Type -> ScalaName
fromType t = case t of
a@(El _ ue) -> fromTerm ue
other -> error ("unhandled fromType" ++ show other)
El _ ue -> fromTerm ue
other -> error ("unhandled fromType [" ++ show other ++ "]")

Check warning on line 110 in src/Agda/Compiler/Scala/AgdaToScalaExpr.hs

View workflow job for this annotation

GitHub Actions / agda2scala

Pattern match is redundant

-- https://hackage.haskell.org/package/Agda-2.6.4.3/docs/Agda-Syntax-Internal.html#t:Term
fromTerm :: Term -> ScalaName
fromTerm t = case t of
Def qname el -> fromQName qname
other -> error ("unhandled fromTerm" ++ show other)

fromDeBruijnPattern :: DeBruijnPattern -> ScalaName
fromDeBruijnPattern d = case d of
VarP a b -> (dbPatVarName b)
a@(ConP x y z) -> show a
other -> error ("unhandled fromDeBruijnPattern" ++ show other)
Def qName elims -> fromQName qName
Var n elims -> "\nunhandled fromTerm Var \n[" ++ show t ++ "]\n"
other -> error ("\nunhandled fromTerm [" ++ show other ++ "]\n")

compileFunctionBody :: Maybe CompiledClauses -> FunBody
compileFunctionBody (Just funDef) = fromCompiledClauses funDef
compileFunctionBody funDef = error ("unhandled compileFunctionBody " ++ show funDef)
compileFunctionBody funDef = error "Fatal error - function body is not compiled."

-- https://hackage.haskell.org/package/Agda/docs/Agda-TypeChecking-CompiledClause.html#t:CompiledClauses
fromCompiledClauses :: CompiledClauses -> FunBody
fromCompiledClauses cc = case cc of
(Case argInt caseCompiledClauseTerm) -> "WIP" --"\nCase fromCompiledClauses\n[\n" ++ (show cc) ++ "\n]\n"
(Done (x:xs) term) -> fromArgName x
other -> error ("unhandled fromCompiledClauses " ++ show other)
other -> "\nunhandled fromCompiledClauses \n\n[" ++ show other ++ "]\n"

fromArgName :: Arg ArgName -> FunBody
fromArgName = unArg
Expand Down
4 changes: 2 additions & 2 deletions src/Agda/Compiler/Scala/PrintScalaExpr.hs
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ printScalaExpr def = case def of
<> defsSeparator
(SeFun fName args resType funBody) ->
"def" <> exprSeparator <> fName
<> "(" <> combineLines (map printVar args) <> ")"
<> "(" <> combineThem (map printVar args) <> ")"
<> ":" <> exprSeparator <> resType <> exprSeparator
<> "=" <> exprSeparator <> funBody
<> defsSeparator
(SeProd name args) -> printCaseClass name args
(SeProd name args) -> printCaseClass name args <> defsSeparator
(Unhandled "" payload) -> ""
(Unhandled name payload) -> "TODO " ++ (show name) ++ " " ++ (show payload)
other -> "unsupported printScalaExpr " ++ (show other)

Check warning on line 35 in src/Agda/Compiler/Scala/PrintScalaExpr.hs

View workflow job for this annotation

GitHub Actions / agda2scala

Pattern match is redundant
Expand Down

0 comments on commit c868641

Please sign in to comment.