Skip to content

Commit

Permalink
Merge pull request #11 from lemastero/more-modular-tested
Browse files Browse the repository at this point in the history
More modular tested
  • Loading branch information
lemastero authored May 3, 2024
2 parents 520d17f + 4b76a54 commit 93523db
Show file tree
Hide file tree
Showing 12 changed files with 201 additions and 130 deletions.
15 changes: 7 additions & 8 deletions .github/workflows/haskell.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,17 @@ jobs:
- name: Build
run: cabal build all

- name: Run tests
- name: Run Haskell tests
run: cabal test all

# compile example Agda files into Scala
- name: Run tests
run: cabal run -- agda2scala ./test/adts.agda
- name: Compile example Agda code to Scala
run: cabal run -- agda2scala ./examples/adts.agda

# compile result Scala code
- uses: actions/setup-java@v4
- name: Set up JVM including SBT
uses: actions/setup-java@v4
with:
distribution: 'adopt'
java-version: '21'

- name: Run tests
run: sbt clean compile
- name: compile result Scala code
run: cd examples && sbt clean compile
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,10 @@
* Save result in scala file
* Generate package declaration
* Generate ADT with only case objects
* CI compile example Agda code

## 0.1.0.2

* Refactor: split grammar, pretty printing, extracting grammar
* Generate ADT with case classes and case objects
* CI compile generated Scala code
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ cabal build all

```sh
cabal run -- agda2scala --help
cabal run -- agda2scala ./test/adts.agda
cabal run -- agda2scala ./examples/adts.agda
```

* Run tests
Expand Down
7 changes: 5 additions & 2 deletions agda2scala.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ name: agda2scala
-- PVP summary: +-+------- breaking API changes
-- | | +----- non-breaking API additions
-- | | | +--- code changes with no API change
version: 0.1.0.1
version: 0.1.0.2
description: Allows to export Scala source files from formal specification in Agda
license: MIT
license-file: LICENSE
Expand All @@ -28,6 +28,9 @@ common warnings
library
hs-source-dirs: src
exposed-modules: Agda.Compiler.Scala.Backend
Agda.Compiler.Scala.ScalaExpr
Agda.Compiler.Scala.AgdaToScalaExpr
Agda.Compiler.Scala.PrintScalaExpr
Paths_agda2scala
autogen-modules: Paths_agda2scala
build-depends: base >= 4.10 && < 4.20,
Expand All @@ -54,7 +57,7 @@ test-suite agda2scala-test
default-language: Haskell2010
type: exitcode-stdio-1.0
hs-source-dirs: test
main-is: ScalaBackendTest.hs
main-is: Main.hs
build-depends: base >=4.10 && < 4.20,
Agda >= 2.6.4 && < 2.6.5,
HUnit >= 1.6.2.0,
Expand Down
4 changes: 2 additions & 2 deletions test/adts.agda → examples/adts.agda
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
module test.adts where
module examples.adts where

-- simple sum type no arguments - sealed trait + case objects
data Rgb : Set where
Expand All @@ -7,7 +7,7 @@ data Rgb : Set where
Blue : Rgb
{-# COMPILE AGDA2SCALA Rgb #-}

-- TODO simple sum type with arguments
-- simple sum type with arguments - sealed trait + case class

data Color : Set where
Light : Rgb -> Color
Expand Down
33 changes: 33 additions & 0 deletions src/Agda/Compiler/Scala/AgdaToScalaExpr.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
module Agda.Compiler.Scala.AgdaToScalaExpr (
compileDefn
) where

import Agda.Compiler.Backend ( funCompiled, funClauses, Defn(..), RecordData(..))
import Agda.Syntax.Abstract.Name ( QName )
import Agda.Syntax.Common.Pretty ( prettyShow )
import Agda.Syntax.Common ( Arg(..), ArgName, Named(..) )
import Agda.Syntax.Internal (
Clause(..), DeBruijnPattern, DBPatVar(..), Dom(..), unDom, PatternInfo(..), Pattern'(..),
qnameName, qnameModule, Telescope, Tele(..), Term(..), Type, Type''(..) )
import Agda.TypeChecking.Monad.Base ( Definition(..) )
import Agda.TypeChecking.Monad
import Agda.TypeChecking.CompiledClause ( CompiledClauses(..), CompiledClauses'(..) )

import Agda.Compiler.Scala.ScalaExpr ( ScalaName, ScalaExpr(..) )

compileDefn :: QName -> Defn -> ScalaExpr
compileDefn defName theDef = case theDef of
Datatype{dataCons = dataCons} ->
compileDataType defName dataCons
Function{funCompiled = funDef, funClauses = fc} ->
Unhandled "compileDefn Function" (show defName ++ "\n = \n" ++ show theDef)
RecordDefn(RecordData{_recFields = recFields, _recTel = recTel}) ->
Unhandled "compileDefn RecordDefn" (show defName ++ "\n = \n" ++ show theDef)
other ->
Unhandled "compileDefn other" (show defName ++ "\n = \n" ++ show theDef)

compileDataType :: QName -> [QName] -> ScalaExpr
compileDataType defName fields = SeAdt (showName defName) (map showName fields)

showName :: QName -> ScalaName
showName = prettyShow . qnameName
130 changes: 28 additions & 102 deletions src/Agda/Compiler/Scala/Backend.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
{-# LANGUAGE LambdaCase, RecordWildCards #-}

module Agda.Compiler.Scala.Backend (
runScalaBackend
, scalaBackend
Expand All @@ -8,8 +6,16 @@ module Agda.Compiler.Scala.Backend (
) where

import Control.DeepSeq ( NFData(..) )
import Control.Monad ( unless )
import Control.Monad.IO.Class ( MonadIO(liftIO) )
import qualified Data.List.NonEmpty as Nel
import Data.Maybe ( fromMaybe )
import Data.Map ( Map )
import qualified Data.Text.IO as T
import Data.Version ( showVersion )
import System.Console.GetOpt ( OptDescr(Option), ArgDescr(ReqArg) )

import Paths_agda2scala ( version )

import Agda.Main ( runAgda )
import Agda.Compiler.Backend (
Expand All @@ -21,29 +27,18 @@ import Agda.Compiler.Backend (
, Recompile(..)
, TCM )
import Agda.Interaction.Options ( OptDescr )
import Agda.Syntax.TopLevelModuleName ( TopLevelModuleName )

import Control.Monad ( unless )
import Control.Monad.IO.Class ( MonadIO(liftIO) )
import qualified Data.List.NonEmpty as Nel
import Data.Maybe ( fromMaybe )
import Data.Version ( showVersion )
import System.Console.GetOpt ( OptDescr(Option), ArgDescr(ReqArg) )

import Paths_agda2scala ( version )

import Agda.Compiler.Common ( curIF, compileDir )
import Agda.Compiler.Backend ( IsMain, Defn(..) )
import Agda.Compiler.Backend ( IsMain )
import Agda.Syntax.Abstract.Name ( QName )
import Agda.Syntax.Common.Pretty ( prettyShow )
import Agda.Syntax.Common ( Arg(..), ArgName, Named(..), moduleNameParts )
import Agda.Syntax.Internal (
Clause(..), DeBruijnPattern, DBPatVar(..), Dom(..), unDom, PatternInfo(..), Pattern'(..),
qnameName, qnameModule, Telescope, Tele(..), Term(..), Type, Type''(..) )
import Agda.Syntax.Common ( moduleNameParts )
import Agda.Syntax.Internal ( qnameModule )
import Agda.Syntax.TopLevelModuleName ( TopLevelModuleName, moduleNameToFileName )
import Agda.TypeChecking.Monad.Base ( Definition(..) )
import Agda.TypeChecking.Monad
import Agda.TypeChecking.CompiledClause ( CompiledClauses(..), CompiledClauses'(..) )

import Agda.Compiler.Scala.ScalaExpr ( ScalaName, ScalaExpr(..), unHandled )
import Agda.Compiler.Scala.AgdaToScalaExpr ( compileDefn )
import Agda.Compiler.Scala.PrintScalaExpr ( printScalaExpr )

runScalaBackend :: IO ()
runScalaBackend = runAgda [scalaBackend]
Expand All @@ -62,18 +57,6 @@ type ScalaModuleEnv = ()
type ScalaModule = ()
type ScalaDefinition = ScalaExpr

type ScalaName = String

data ScalaExpr
= SePackage ScalaName [ScalaExpr]
| SeAdt ScalaName [ScalaName]
| Unhandled ScalaName String
deriving ( Show )

unHandled :: ScalaExpr -> Bool
unHandled (Unhandled _ _) = True
unHandled _ = False

{- Backend contains implementations of hooks called around compilation of Agda code -}
scalaBackend' :: Backend' ScalaFlags ScalaEnv ScalaModuleEnv ScalaModule ScalaDefinition
scalaBackend' = Backend'
Expand All @@ -82,7 +65,7 @@ scalaBackend' = Backend'
, options = defaultOptions
, commandLineFlags = scalaCmdLineFlags
, isEnabled = const True
, preCompile = scalaPreCompile
, preCompile = return
, compileDef = scalaCompileDef
, postCompile = scalaPostCompile
, preModule = scalaPreModule
Expand All @@ -91,7 +74,6 @@ scalaBackend' = Backend'
, mayEraseType = const $ return True
}

-- TODO get version from cabal definition, perhaps git hash too
scalaBackendVersion :: Maybe String
scalaBackendVersion = Just (showVersion version)

Expand All @@ -103,48 +85,27 @@ defaultOptions = Options{ optOutDir = Nothing }
-- TODO perhaps add option to use annotations from siddhartha-gadgil/ProvingGround library
scalaCmdLineFlags :: [OptDescr (Flag ScalaFlags)]
scalaCmdLineFlags = [
Option ['o'] ["out-dir"] (ReqArg outdirOpt "DIR")
Option ['o'] ["out-dir"] (ReqArg outDirOpt "DIR")
"Write output files to DIR. (default: project root)"
]

outdirOpt :: Monad m => FilePath -> Options -> m Options
outdirOpt dir opts = return opts{ optOutDir = Just dir }

scalaPreCompile :: ScalaFlags -> TCM ScalaEnv
scalaPreCompile = return
outDirOpt :: Monad m => FilePath -> Options -> m Options
outDirOpt dir opts = return opts{ optOutDir = Just dir }

-- TODO perhaps transform definitions here, ATM just pass it with extra information is it main
-- Rust backend perform transformation to Higher IR
-- Scheme pass as is (like here)
scalaCompileDef :: ScalaEnv
-> ScalaModuleEnv
-> IsMain
-> Definition
-> TCM ScalaDefinition
scalaCompileDef _ _ isMain Defn{..}
scalaCompileDef _ _ isMain Defn{theDef = theDef, defName = defName}
= withCurrentModule (qnameModule defName)
$ getUniqueCompilerPragma "AGDA2SCALA" defName >>= \case
Nothing -> return $ Unhandled "compile" ""
Just (CompilerPragma _ _) ->
return $ compileDefn defName theDef

compileDefn :: QName -> Defn -> ScalaDefinition
compileDefn defName theDef = case theDef of
Datatype{dataCons = fields} ->
compileDataType defName fields
Function{funCompiled = funDef, funClauses = fc} ->
Unhandled "compileDefn Function" (show defName ++ "\n = \n" ++ show theDef)
RecordDefn(RecordData{_recFields = recFields, _recTel = recTel}) ->
Unhandled "compileDefn RecordDefn" (show defName ++ "\n = \n" ++ show theDef)
other ->
Unhandled "compileDefn other" (show defName ++ "\n = \n" ++ show theDef)

compileDataType :: QName -> [QName] -> ScalaDefinition
compileDataType defName fields = SeAdt (showName defName) (map showName fields)
-- Unhandled "compileDefn Datatype" (show defName ++ "\n = \n" ++ show theDef)

showName :: QName -> ScalaName
showName = prettyShow . qnameName
$ getUniqueCompilerPragma "AGDA2SCALA" defName >>= handlePragma defName theDef

handlePragma :: QName -> Defn -> Maybe CompilerPragma -> TCMT IO ScalaDefinition
handlePragma defName theDef pragma = case pragma of
Nothing -> return $ Unhandled "" ""
Just (CompilerPragma _ _) ->
return $ compileDefn defName theDef

scalaPostCompile :: ScalaEnv
-> IsMain
Expand All @@ -161,7 +122,6 @@ scalaPreModule _ _ _ _ = do
setScope . iInsideScope =<< curIF
return $ Recompile ()

-- TODO implement translation here
scalaPostModule :: ScalaEnv
-> ScalaModuleEnv
-> IsMain
Expand All @@ -173,7 +133,7 @@ scalaPostModule env modEnv isMain mName cdefs = do
compileLog $ "compiling " <> (outFile outDir)
unless (all unHandled cdefs) $ liftIO
$ writeFile (outFile outDir)
$ prettyPrintScalaExpr (compileModule mName cdefs)
$ printScalaExpr (compileModule mName cdefs)
where
fileName = scalaFileName mName
dirName outDir = fromMaybe outDir (optOutDir env)
Expand All @@ -191,37 +151,3 @@ moduleName n = prettyShow (Nel.last (moduleNameParts n))

compileLog :: String -> TCMT IO ()
compileLog msg = liftIO $ putStrLn msg

prettyPrintScalaExpr :: ScalaDefinition -> String
prettyPrintScalaExpr def = case def of
(SePackage mName defs) ->
moduleHeader mName
<> defsSeparator <> (
defsSeparator -- empty line before first definition in package
<> combineLines (map prettyPrintScalaExpr defs))
<> defsSeparator
(SeAdt adtName adtCases) -> "sealed trait" <> exprSeparator <> adtName <> defsSeparator <> unlines (map (prettyPrintCaseObject adtName) adtCases)
-- TODO not sure why I get this
-- (Unhandled name payload) -> "TODO " ++ (show name) ++ " " ++ (show payload)
(Unhandled name payload) -> ""
-- XXX at the end there should be no Unhandled expression
-- other -> "unsupported prettyPrintScalaExpr " ++ (show other)


prettyPrintCaseObject :: ScalaName -> ScalaName -> String
prettyPrintCaseObject superName xs = "case object" <> exprSeparator <> xs <> exprSeparator <> "extends" <> exprSeparator <> superName

moduleHeader :: String -> String
moduleHeader mName = "package" <> exprSeparator <> mName <> exprSeparator

bracket :: String -> String
bracket str = "{\n" <> str <> "\n}"

defsSeparator :: String
defsSeparator = "\n"

exprSeparator :: String
exprSeparator = " "

combineLines :: [String] -> String
combineLines xs = unlines (filter (not . null) xs)
49 changes: 49 additions & 0 deletions src/Agda/Compiler/Scala/PrintScalaExpr.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
module Agda.Compiler.Scala.PrintScalaExpr ( printScalaExpr
, printCaseObject
, printSealedTrait
, printPackage
) where

import Agda.Compiler.Scala.ScalaExpr ( ScalaName, ScalaExpr(..) )

printScalaExpr :: ScalaExpr -> String
printScalaExpr def = case def of
(SePackage pName defs) ->
printPackage pName <> defsSeparator
<> (
blankLine -- between package declaration and first definition
<> combineLines (map printScalaExpr defs)
)
<> defsSeparator
(SeAdt adtName adtCases) ->
printSealedTrait adtName
<> defsSeparator
<> unlines (map (printCaseObject adtName) adtCases)
(Unhandled name payload) -> "" -- for development comment out this and uncomment below
-- (Unhandled name payload) -> "TODO " ++ (show name) ++ " " ++ (show payload)
-- other -> "unsupported printScalaExpr " ++ (show other)

printSealedTrait :: ScalaName -> String
printSealedTrait adtName = "sealed trait" <> exprSeparator <> adtName

printCaseObject :: ScalaName -> ScalaName -> String
printCaseObject superName caseName =
"case object" <> exprSeparator <> caseName <> exprSeparator <> "extends" <> exprSeparator <> superName

printPackage :: ScalaName -> String
printPackage pName = "package" <> exprSeparator <> pName

bracket :: String -> String
bracket str = "{\n" <> str <> "\n}"

defsSeparator :: String
defsSeparator = "\n"

blankLine :: String
blankLine = "\n"

exprSeparator :: String
exprSeparator = " "

combineLines :: [String] -> String
combineLines xs = unlines (filter (not . null) xs)
Loading

0 comments on commit 93523db

Please sign in to comment.