diff --git a/.github/workflows/haskell.yml b/.github/workflows/haskell.yml index 71047fd..cddf334 100644 --- a/.github/workflows/haskell.yml +++ b/.github/workflows/haskell.yml @@ -49,18 +49,17 @@ jobs: - name: Build run: cabal build all - - name: Run tests + - name: Run Haskell tests run: cabal test all - # compile example Agda files into Scala - - name: Run tests - run: cabal run -- agda2scala ./test/adts.agda + - name: Compile example Agda code to Scala + run: cabal run -- agda2scala ./examples/adts.agda - # compile result Scala code - - uses: actions/setup-java@v4 + - name: Set up JVM including SBT + uses: actions/setup-java@v4 with: distribution: 'adopt' java-version: '21' - - name: Run tests - run: sbt clean compile + - name: compile result Scala code + run: cd examples && sbt clean compile diff --git a/CHANGELOG.md b/CHANGELOG.md index 6971422..982205c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,3 +5,10 @@ * Save result in scala file * Generate package declaration * Generate ADT with only case objects +* CI compile example Agda code + +## 0.1.0.2 + +* Refactor: split grammar, pretty printing, extracting grammar +* Generate ADT with case classes and case objects +* CI compile generated Scala code diff --git a/README.md b/README.md index a2a6c86..a89e45e 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ cabal build all ```sh cabal run -- agda2scala --help -cabal run -- agda2scala ./test/adts.agda +cabal run -- agda2scala ./examples/adts.agda ``` * Run tests diff --git a/agda2scala.cabal b/agda2scala.cabal index b039f07..46ebd9a 100644 --- a/agda2scala.cabal +++ b/agda2scala.cabal @@ -6,7 +6,7 @@ name: agda2scala -- PVP summary: +-+------- breaking API changes -- | | +----- non-breaking API additions -- | | | +--- code changes with no API change -version: 0.1.0.1 +version: 0.1.0.2 description: Allows to export Scala source files from formal specification in Agda license: MIT license-file: LICENSE @@ -28,6 +28,9 @@ common warnings library hs-source-dirs: src exposed-modules: Agda.Compiler.Scala.Backend + Agda.Compiler.Scala.ScalaExpr + Agda.Compiler.Scala.AgdaToScalaExpr + Agda.Compiler.Scala.PrintScalaExpr Paths_agda2scala autogen-modules: Paths_agda2scala build-depends: base >= 4.10 && < 4.20, @@ -54,7 +57,7 @@ test-suite agda2scala-test default-language: Haskell2010 type: exitcode-stdio-1.0 hs-source-dirs: test - main-is: ScalaBackendTest.hs + main-is: Main.hs build-depends: base >=4.10 && < 4.20, Agda >= 2.6.4 && < 2.6.5, HUnit >= 1.6.2.0, diff --git a/test/adts.agda b/examples/adts.agda similarity index 74% rename from test/adts.agda rename to examples/adts.agda index d910f0d..5cc78bf 100644 --- a/test/adts.agda +++ b/examples/adts.agda @@ -1,4 +1,4 @@ -module test.adts where +module examples.adts where -- simple sum type no arguments - sealed trait + case objects data Rgb : Set where @@ -7,7 +7,7 @@ data Rgb : Set where Blue : Rgb {-# COMPILE AGDA2SCALA Rgb #-} --- TODO simple sum type with arguments +-- simple sum type with arguments - sealed trait + case class data Color : Set where Light : Rgb -> Color diff --git a/src/Agda/Compiler/Scala/AgdaToScalaExpr.hs b/src/Agda/Compiler/Scala/AgdaToScalaExpr.hs new file mode 100644 index 0000000..285e015 --- /dev/null +++ b/src/Agda/Compiler/Scala/AgdaToScalaExpr.hs @@ -0,0 +1,33 @@ +module Agda.Compiler.Scala.AgdaToScalaExpr ( + compileDefn + ) where + +import Agda.Compiler.Backend ( funCompiled, funClauses, Defn(..), RecordData(..)) +import Agda.Syntax.Abstract.Name ( QName ) +import Agda.Syntax.Common.Pretty ( prettyShow ) +import Agda.Syntax.Common ( Arg(..), ArgName, Named(..) ) +import Agda.Syntax.Internal ( + Clause(..), DeBruijnPattern, DBPatVar(..), Dom(..), unDom, PatternInfo(..), Pattern'(..), + qnameName, qnameModule, Telescope, Tele(..), Term(..), Type, Type''(..) ) +import Agda.TypeChecking.Monad.Base ( Definition(..) ) +import Agda.TypeChecking.Monad +import Agda.TypeChecking.CompiledClause ( CompiledClauses(..), CompiledClauses'(..) ) + +import Agda.Compiler.Scala.ScalaExpr ( ScalaName, ScalaExpr(..) ) + +compileDefn :: QName -> Defn -> ScalaExpr +compileDefn defName theDef = case theDef of + Datatype{dataCons = dataCons} -> + compileDataType defName dataCons + Function{funCompiled = funDef, funClauses = fc} -> + Unhandled "compileDefn Function" (show defName ++ "\n = \n" ++ show theDef) + RecordDefn(RecordData{_recFields = recFields, _recTel = recTel}) -> + Unhandled "compileDefn RecordDefn" (show defName ++ "\n = \n" ++ show theDef) + other -> + Unhandled "compileDefn other" (show defName ++ "\n = \n" ++ show theDef) + +compileDataType :: QName -> [QName] -> ScalaExpr +compileDataType defName fields = SeAdt (showName defName) (map showName fields) + +showName :: QName -> ScalaName +showName = prettyShow . qnameName diff --git a/src/Agda/Compiler/Scala/Backend.hs b/src/Agda/Compiler/Scala/Backend.hs index 12b02f2..2799831 100644 --- a/src/Agda/Compiler/Scala/Backend.hs +++ b/src/Agda/Compiler/Scala/Backend.hs @@ -1,5 +1,3 @@ -{-# LANGUAGE LambdaCase, RecordWildCards #-} - module Agda.Compiler.Scala.Backend ( runScalaBackend , scalaBackend @@ -8,8 +6,16 @@ module Agda.Compiler.Scala.Backend ( ) where import Control.DeepSeq ( NFData(..) ) +import Control.Monad ( unless ) +import Control.Monad.IO.Class ( MonadIO(liftIO) ) +import qualified Data.List.NonEmpty as Nel +import Data.Maybe ( fromMaybe ) import Data.Map ( Map ) import qualified Data.Text.IO as T +import Data.Version ( showVersion ) +import System.Console.GetOpt ( OptDescr(Option), ArgDescr(ReqArg) ) + +import Paths_agda2scala ( version ) import Agda.Main ( runAgda ) import Agda.Compiler.Backend ( @@ -21,29 +27,18 @@ import Agda.Compiler.Backend ( , Recompile(..) , TCM ) import Agda.Interaction.Options ( OptDescr ) -import Agda.Syntax.TopLevelModuleName ( TopLevelModuleName ) - -import Control.Monad ( unless ) -import Control.Monad.IO.Class ( MonadIO(liftIO) ) -import qualified Data.List.NonEmpty as Nel -import Data.Maybe ( fromMaybe ) -import Data.Version ( showVersion ) -import System.Console.GetOpt ( OptDescr(Option), ArgDescr(ReqArg) ) - -import Paths_agda2scala ( version ) - import Agda.Compiler.Common ( curIF, compileDir ) -import Agda.Compiler.Backend ( IsMain, Defn(..) ) +import Agda.Compiler.Backend ( IsMain ) import Agda.Syntax.Abstract.Name ( QName ) import Agda.Syntax.Common.Pretty ( prettyShow ) -import Agda.Syntax.Common ( Arg(..), ArgName, Named(..), moduleNameParts ) -import Agda.Syntax.Internal ( - Clause(..), DeBruijnPattern, DBPatVar(..), Dom(..), unDom, PatternInfo(..), Pattern'(..), - qnameName, qnameModule, Telescope, Tele(..), Term(..), Type, Type''(..) ) +import Agda.Syntax.Common ( moduleNameParts ) +import Agda.Syntax.Internal ( qnameModule ) import Agda.Syntax.TopLevelModuleName ( TopLevelModuleName, moduleNameToFileName ) -import Agda.TypeChecking.Monad.Base ( Definition(..) ) import Agda.TypeChecking.Monad -import Agda.TypeChecking.CompiledClause ( CompiledClauses(..), CompiledClauses'(..) ) + +import Agda.Compiler.Scala.ScalaExpr ( ScalaName, ScalaExpr(..), unHandled ) +import Agda.Compiler.Scala.AgdaToScalaExpr ( compileDefn ) +import Agda.Compiler.Scala.PrintScalaExpr ( printScalaExpr ) runScalaBackend :: IO () runScalaBackend = runAgda [scalaBackend] @@ -62,18 +57,6 @@ type ScalaModuleEnv = () type ScalaModule = () type ScalaDefinition = ScalaExpr -type ScalaName = String - -data ScalaExpr - = SePackage ScalaName [ScalaExpr] - | SeAdt ScalaName [ScalaName] - | Unhandled ScalaName String - deriving ( Show ) - -unHandled :: ScalaExpr -> Bool -unHandled (Unhandled _ _) = True -unHandled _ = False - {- Backend contains implementations of hooks called around compilation of Agda code -} scalaBackend' :: Backend' ScalaFlags ScalaEnv ScalaModuleEnv ScalaModule ScalaDefinition scalaBackend' = Backend' @@ -82,7 +65,7 @@ scalaBackend' = Backend' , options = defaultOptions , commandLineFlags = scalaCmdLineFlags , isEnabled = const True - , preCompile = scalaPreCompile + , preCompile = return , compileDef = scalaCompileDef , postCompile = scalaPostCompile , preModule = scalaPreModule @@ -91,7 +74,6 @@ scalaBackend' = Backend' , mayEraseType = const $ return True } --- TODO get version from cabal definition, perhaps git hash too scalaBackendVersion :: Maybe String scalaBackendVersion = Just (showVersion version) @@ -103,48 +85,27 @@ defaultOptions = Options{ optOutDir = Nothing } -- TODO perhaps add option to use annotations from siddhartha-gadgil/ProvingGround library scalaCmdLineFlags :: [OptDescr (Flag ScalaFlags)] scalaCmdLineFlags = [ - Option ['o'] ["out-dir"] (ReqArg outdirOpt "DIR") + Option ['o'] ["out-dir"] (ReqArg outDirOpt "DIR") "Write output files to DIR. (default: project root)" ] -outdirOpt :: Monad m => FilePath -> Options -> m Options -outdirOpt dir opts = return opts{ optOutDir = Just dir } - -scalaPreCompile :: ScalaFlags -> TCM ScalaEnv -scalaPreCompile = return +outDirOpt :: Monad m => FilePath -> Options -> m Options +outDirOpt dir opts = return opts{ optOutDir = Just dir } --- TODO perhaps transform definitions here, ATM just pass it with extra information is it main --- Rust backend perform transformation to Higher IR --- Scheme pass as is (like here) scalaCompileDef :: ScalaEnv -> ScalaModuleEnv -> IsMain -> Definition -> TCM ScalaDefinition -scalaCompileDef _ _ isMain Defn{..} +scalaCompileDef _ _ isMain Defn{theDef = theDef, defName = defName} = withCurrentModule (qnameModule defName) - $ getUniqueCompilerPragma "AGDA2SCALA" defName >>= \case - Nothing -> return $ Unhandled "compile" "" - Just (CompilerPragma _ _) -> - return $ compileDefn defName theDef - -compileDefn :: QName -> Defn -> ScalaDefinition -compileDefn defName theDef = case theDef of - Datatype{dataCons = fields} -> - compileDataType defName fields - Function{funCompiled = funDef, funClauses = fc} -> - Unhandled "compileDefn Function" (show defName ++ "\n = \n" ++ show theDef) - RecordDefn(RecordData{_recFields = recFields, _recTel = recTel}) -> - Unhandled "compileDefn RecordDefn" (show defName ++ "\n = \n" ++ show theDef) - other -> - Unhandled "compileDefn other" (show defName ++ "\n = \n" ++ show theDef) - -compileDataType :: QName -> [QName] -> ScalaDefinition -compileDataType defName fields = SeAdt (showName defName) (map showName fields) --- Unhandled "compileDefn Datatype" (show defName ++ "\n = \n" ++ show theDef) - -showName :: QName -> ScalaName -showName = prettyShow . qnameName + $ getUniqueCompilerPragma "AGDA2SCALA" defName >>= handlePragma defName theDef + +handlePragma :: QName -> Defn -> Maybe CompilerPragma -> TCMT IO ScalaDefinition +handlePragma defName theDef pragma = case pragma of + Nothing -> return $ Unhandled "" "" + Just (CompilerPragma _ _) -> + return $ compileDefn defName theDef scalaPostCompile :: ScalaEnv -> IsMain @@ -161,7 +122,6 @@ scalaPreModule _ _ _ _ = do setScope . iInsideScope =<< curIF return $ Recompile () --- TODO implement translation here scalaPostModule :: ScalaEnv -> ScalaModuleEnv -> IsMain @@ -173,7 +133,7 @@ scalaPostModule env modEnv isMain mName cdefs = do compileLog $ "compiling " <> (outFile outDir) unless (all unHandled cdefs) $ liftIO $ writeFile (outFile outDir) - $ prettyPrintScalaExpr (compileModule mName cdefs) + $ printScalaExpr (compileModule mName cdefs) where fileName = scalaFileName mName dirName outDir = fromMaybe outDir (optOutDir env) @@ -191,37 +151,3 @@ moduleName n = prettyShow (Nel.last (moduleNameParts n)) compileLog :: String -> TCMT IO () compileLog msg = liftIO $ putStrLn msg - -prettyPrintScalaExpr :: ScalaDefinition -> String -prettyPrintScalaExpr def = case def of - (SePackage mName defs) -> - moduleHeader mName - <> defsSeparator <> ( - defsSeparator -- empty line before first definition in package - <> combineLines (map prettyPrintScalaExpr defs)) - <> defsSeparator - (SeAdt adtName adtCases) -> "sealed trait" <> exprSeparator <> adtName <> defsSeparator <> unlines (map (prettyPrintCaseObject adtName) adtCases) - -- TODO not sure why I get this - -- (Unhandled name payload) -> "TODO " ++ (show name) ++ " " ++ (show payload) - (Unhandled name payload) -> "" - -- XXX at the end there should be no Unhandled expression - -- other -> "unsupported prettyPrintScalaExpr " ++ (show other) - - -prettyPrintCaseObject :: ScalaName -> ScalaName -> String -prettyPrintCaseObject superName xs = "case object" <> exprSeparator <> xs <> exprSeparator <> "extends" <> exprSeparator <> superName - -moduleHeader :: String -> String -moduleHeader mName = "package" <> exprSeparator <> mName <> exprSeparator - -bracket :: String -> String -bracket str = "{\n" <> str <> "\n}" - -defsSeparator :: String -defsSeparator = "\n" - -exprSeparator :: String -exprSeparator = " " - -combineLines :: [String] -> String -combineLines xs = unlines (filter (not . null) xs) diff --git a/src/Agda/Compiler/Scala/PrintScalaExpr.hs b/src/Agda/Compiler/Scala/PrintScalaExpr.hs new file mode 100644 index 0000000..f30ab38 --- /dev/null +++ b/src/Agda/Compiler/Scala/PrintScalaExpr.hs @@ -0,0 +1,49 @@ +module Agda.Compiler.Scala.PrintScalaExpr ( printScalaExpr + , printCaseObject + , printSealedTrait + , printPackage + ) where + +import Agda.Compiler.Scala.ScalaExpr ( ScalaName, ScalaExpr(..) ) + +printScalaExpr :: ScalaExpr -> String +printScalaExpr def = case def of + (SePackage pName defs) -> + printPackage pName <> defsSeparator + <> ( + blankLine -- between package declaration and first definition + <> combineLines (map printScalaExpr defs) + ) + <> defsSeparator + (SeAdt adtName adtCases) -> + printSealedTrait adtName + <> defsSeparator + <> unlines (map (printCaseObject adtName) adtCases) + (Unhandled name payload) -> "" -- for development comment out this and uncomment below + -- (Unhandled name payload) -> "TODO " ++ (show name) ++ " " ++ (show payload) + -- other -> "unsupported printScalaExpr " ++ (show other) + +printSealedTrait :: ScalaName -> String +printSealedTrait adtName = "sealed trait" <> exprSeparator <> adtName + +printCaseObject :: ScalaName -> ScalaName -> String +printCaseObject superName caseName = + "case object" <> exprSeparator <> caseName <> exprSeparator <> "extends" <> exprSeparator <> superName + +printPackage :: ScalaName -> String +printPackage pName = "package" <> exprSeparator <> pName + +bracket :: String -> String +bracket str = "{\n" <> str <> "\n}" + +defsSeparator :: String +defsSeparator = "\n" + +blankLine :: String +blankLine = "\n" + +exprSeparator :: String +exprSeparator = " " + +combineLines :: [String] -> String +combineLines xs = unlines (filter (not . null) xs) diff --git a/src/Agda/Compiler/Scala/ScalaExpr.hs b/src/Agda/Compiler/Scala/ScalaExpr.hs new file mode 100644 index 0000000..fe96d09 --- /dev/null +++ b/src/Agda/Compiler/Scala/ScalaExpr.hs @@ -0,0 +1,18 @@ +module Agda.Compiler.Scala.ScalaExpr ( + ScalaName, + ScalaExpr(..), + unHandled + ) where + +type ScalaName = String + +{- Represent Scala language extracted from Agda compiler representation -} +data ScalaExpr + = SePackage ScalaName [ScalaExpr] + | SeAdt ScalaName [ScalaName] + | Unhandled ScalaName String + deriving ( Show ) + +unHandled :: ScalaExpr -> Bool +unHandled (Unhandled _ _) = True +unHandled _ = False diff --git a/test/Main.hs b/test/Main.hs new file mode 100644 index 0000000..710475d --- /dev/null +++ b/test/Main.hs @@ -0,0 +1,19 @@ +module Main (main) where + +import Test.HUnit ( + Test(..) + , failures + , runTestTT + ) +import System.Exit ( exitFailure , exitSuccess ) + +import PrintScalaExprTest ( printScalaTests ) +import ScalaBackendTest ( backendTests ) + +allTests :: Test +allTests = TestList [ backendTests , printScalaTests ] + +main :: IO () +main = do + result <- runTestTT allTests + if (failures result) > 0 then exitFailure else exitSuccess diff --git a/test/PrintScalaExprTest.hs b/test/PrintScalaExprTest.hs new file mode 100644 index 0000000..30acfc7 --- /dev/null +++ b/test/PrintScalaExprTest.hs @@ -0,0 +1,27 @@ +module PrintScalaExprTest ( printScalaTests ) where + +import Test.HUnit ( Test(..), assertEqual ) +import Agda.Compiler.Scala.PrintScalaExpr ( + printSealedTrait + , printCaseObject + , printPackage + ) + +testPrintCaseObject :: Test +testPrintCaseObject = TestCase + (assertEqual "printCaseObject" (printCaseObject "Color" "Light") "case object Light extends Color") + +testPrintSealedTrait :: Test +testPrintSealedTrait = TestCase + (assertEqual "printSealedTrait" (printSealedTrait "Color") "sealed trait Color") + +testPrintPackage :: Test +testPrintPackage = TestCase + (assertEqual "printPackage" (printPackage "adts") "package adts") + +printScalaTests :: Test +printScalaTests = TestList [ + TestLabel "printCaseObject" testPrintCaseObject + , TestLabel "printSealedTrait" testPrintSealedTrait + , TestLabel "printPackage" testPrintPackage + ] diff --git a/test/ScalaBackendTest.hs b/test/ScalaBackendTest.hs index 80d110b..28ff0be 100644 --- a/test/ScalaBackendTest.hs +++ b/test/ScalaBackendTest.hs @@ -1,22 +1,12 @@ -module Main (main) where +module ScalaBackendTest ( backendTests ) where -import Agda.Compiler.Scala.Backend ( scalaBackend', defaultOptions ) -import Test.HUnit ( - Test(..) - , assertEqual - , failures - , runTestTT) -import System.Exit ( exitFailure , exitSuccess ) +import Test.HUnit ( Test(..), assertEqual ) import Agda.Compiler.Backend ( isEnabled ) +import Agda.Compiler.Scala.Backend ( scalaBackend', defaultOptions ) testIsEnabled :: Test testIsEnabled = TestCase (assertEqual "isEnabled" (isEnabled scalaBackend' defaultOptions) True) -tests :: Test -tests = TestList [TestLabel "scalaBackend" testIsEnabled] - -main :: IO () -main = do - result <- runTestTT tests - if failures result > 0 then exitFailure else exitSuccess +backendTests :: Test +backendTests = TestList [TestLabel "scalaBackend" testIsEnabled]