Skip to content

Commit

Permalink
Add field element type (#2659)
Browse files Browse the repository at this point in the history
* Closes #2571
* It is reasonable to finish this PR before tackling #2562, because the
field element type is the primary data type in Cairo.
* Depends on #2653

Checklist
---------

- [x] Add field type and operations to intermediate representations
(JuvixCore, JuvixTree, JuvixAsm, JuvixReg).
- [x] Add CLI option to choose field size.
- [x] Add frontend field builtins.
- [x] Automatic conversion of integer literals to field elements.
- [x] Juvix standard library support for fields.
- [x] Check if field size matches when loading a stored module.
- [x] Update the Cairo Assembly (CASM) interpreter to use the field type
instead of integer type.
- [x] Add field type to VampIR backend.
- [x] Tests

---------

Co-authored-by: Jan Mas Rovira <[email protected]>
  • Loading branch information
lukaszcz and janmasrovira authored Feb 27, 2024
1 parent a091a7f commit dcea0bb
Show file tree
Hide file tree
Showing 108 changed files with 1,196 additions and 160 deletions.
3 changes: 2 additions & 1 deletion app/Commands/Dev/Core/Eval.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ runCommand :: forall r. (Members '[EmbedIO, App] r) => CoreEvalOptions -> Sem r
runCommand opts = do
f :: Path Abs File <- fromAppPathFile b
s <- readFile f
gopts <- askGlobalOptions
case Core.runParser f defaultModuleId mempty s of
Left err -> exitJuvixError (JuvixError err)
Right (tab, Just node) -> do evalAndPrint opts tab node
Right (tab, Just node) -> do evalAndPrint gopts opts tab node
Right (_, Nothing) -> return ()
where
b :: AppPath File
Expand Down
4 changes: 2 additions & 2 deletions app/Commands/Dev/Core/FromConcrete.hs
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,14 @@ runCommand localOpts = do
newline

goEval :: Sem r ()
goEval = evalAndPrint localOpts tab' evalNode
goEval = evalAndPrint gopts localOpts tab' evalNode
where
evalNode :: Core.Node
| isJust (localOpts ^. coreFromConcreteSymbolName) = getNode' selInfo
| otherwise = getNode' mainInfo

goNormalize :: Sem r ()
goNormalize = normalizeAndPrint localOpts tab' evalNode
goNormalize = normalizeAndPrint gopts localOpts tab' evalNode
where
evalNode :: Core.Node
| isJust (localOpts ^. coreFromConcreteSymbolName) = getNode' selInfo
Expand Down
3 changes: 2 additions & 1 deletion app/Commands/Dev/Core/Normalize.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ runCommand :: forall r. (Members '[EmbedIO, App] r) => CoreNormalizeOptions -> S
runCommand opts = do
f :: Path Abs File <- fromAppPathFile b
s <- readFile f
gopts <- askGlobalOptions
case Core.runParser f defaultModuleId mempty s of
Left err -> exitJuvixError (JuvixError err)
Right (tab, Just node) -> do normalizeAndPrint opts tab node
Right (tab, Just node) -> do normalizeAndPrint gopts opts tab node
Right (_, Nothing) -> return ()
where
b :: AppPath File
Expand Down
10 changes: 5 additions & 5 deletions app/Commands/Dev/Core/Read.hs
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,21 @@ runCommand opts = do
embed (Scoper.scopeTrace tab')
unless (project opts ^. coreReadNoPrint) $ do
renderStdOut (Pretty.ppOut opts tab')
whenJust (tab' ^. Core.infoMain) $ \sym -> doEval tab' (fromJust $ tab' ^. Core.identContext . at sym)
whenJust (tab' ^. Core.infoMain) $ \sym -> doEval gopts tab' (fromJust $ tab' ^. Core.identContext . at sym)
where
doEval :: Core.InfoTable -> Core.Node -> Sem r ()
doEval tab' node =
doEval :: GlobalOptions -> Core.InfoTable -> Core.Node -> Sem r ()
doEval gopts tab' node =
if
| project opts ^. coreReadEval -> do
putStrLn "--------------------------------"
putStrLn "| Eval |"
putStrLn "--------------------------------"
Eval.evalAndPrint opts tab' node
Eval.evalAndPrint gopts opts tab' node
| project opts ^. coreReadNormalize -> do
putStrLn "--------------------------------"
putStrLn "| Normalize |"
putStrLn "--------------------------------"
Eval.normalizeAndPrint opts tab' node
Eval.normalizeAndPrint gopts opts tab' node
| otherwise -> return ()
sinputFile :: AppPath File
sinputFile = project opts ^. coreReadInputFile
5 changes: 3 additions & 2 deletions app/Commands/Dev/Core/Repl.hs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import Juvix.Compiler.Core.Pretty qualified as Core
import Juvix.Compiler.Core.Transformation.ComputeTypeInfo qualified as Core
import Juvix.Compiler.Core.Transformation.DisambiguateNames qualified as Core
import Juvix.Compiler.Core.Translation.FromSource qualified as Core
import Juvix.Data.Field
import Juvix.Extra.Paths

runCommand :: forall r. (Members '[EmbedIO, App] r) => CoreReplOptions -> Sem r ()
Expand Down Expand Up @@ -98,7 +99,7 @@ runRepl opts tab = do
where
replEval :: Bool -> Core.InfoTable -> Core.Node -> Sem r ()
replEval noIO tab' node = do
r <- Core.doEval noIO defaultLoc tab' node
r <- Core.doEval Nothing noIO defaultLoc tab' node
case r of
Left err -> do
printJuvixError (JuvixError err)
Expand All @@ -115,7 +116,7 @@ runRepl opts tab = do
replNormalize :: Core.InfoTable -> Core.Node -> Sem r ()
replNormalize tab' node =
let md' = Core.moduleFromInfoTable tab'
node' = normalize md' node
node' = normalize (maximum allowedFieldSizes) md' node
in if
| Info.member Info.kNoDisplayInfo (Core.getInfo node') ->
runRepl opts tab'
Expand Down
2 changes: 1 addition & 1 deletion app/Commands/Dev/Core/Strip.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ runCommand opts = do
run $
runReader (project gopts) $
runError @JuvixError (Core.toStripped' Core.Identity (Core.moduleFromInfoTable tab) :: Sem '[Error JuvixError, Reader Core.CoreOptions] Core.Module)
tab' <- getRight $ mapLeft JuvixError $ mapRight (Stripped.fromCore . Core.computeCombinedInfoTable) r
tab' <- getRight $ mapLeft JuvixError $ mapRight (Stripped.fromCore (project gopts ^. Core.optFieldSize) . Core.computeCombinedInfoTable) r
unless (project opts ^. coreStripNoPrint) $ do
renderStdOut (Core.ppOut opts tab')
where
Expand Down
2 changes: 1 addition & 1 deletion app/Commands/Eval.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ runCommand opts@EvalOptions {..} = do
| otherwise -> getNode tab (mainInfo tab)
case mevalNode of
Just evalNode ->
Eval.evalAndPrint opts tab evalNode
Eval.evalAndPrint gopts opts tab evalNode
Nothing -> do
let name = fromMaybe Str.main _evalSymbolName
printFailureExit ("function not found: " <> name)
Expand Down
9 changes: 5 additions & 4 deletions app/Commands/Repl.hs
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ replCommand opts input_ = catchAll $ do

eval :: Core.Node -> Repl Core.Node
eval n = do
gopts <- State.gets (^. replStateGlobalOptions)
ep <- getReplEntryPointFromPrepath (mkPrepath (toFilePath P.replPath))
let shouldDisambiguate :: Bool
shouldDisambiguate = not (opts ^. replNoDisambiguate)
Expand All @@ -182,12 +183,12 @@ replCommand opts input_ = catchAll $ do
. runState artif
. runTransformations shouldDisambiguate (opts ^. replTransformations)
$ n
liftIO (doEvalIO' artif' n') >>= replFromEither
liftIO (doEvalIO' (project gopts ^. Core.optFieldSize) artif' n') >>= replFromEither

doEvalIO' :: Artifacts -> Core.Node -> IO (Either JuvixError Core.Node)
doEvalIO' artif' n =
doEvalIO' :: Natural -> Artifacts -> Core.Node -> IO (Either JuvixError Core.Node)
doEvalIO' fsize artif' n =
mapLeft (JuvixError @Core.CoreError)
<$> Core.doEvalIO False replDefaultLoc (Core.computeCombinedInfoTable $ artif' ^. artifactCoreModule) n
<$> Core.doEvalIO (Just fsize) False replDefaultLoc (Core.computeCombinedInfoTable $ artif' ^. artifactCoreModule) n

compileString :: Repl (Maybe Core.Node)
compileString = do
Expand Down
18 changes: 18 additions & 0 deletions app/CommonOptions.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import Data.List.NonEmpty qualified as NonEmpty
import Juvix.Compiler.Core.Data.TransformationId.Parser qualified as Core
import Juvix.Compiler.Reg.Data.TransformationId.Parser qualified as Reg
import Juvix.Compiler.Tree.Data.TransformationId.Parser qualified as Tree
import Juvix.Data.Field
import Juvix.Data.FileExt
import Juvix.Prelude
import Options.Applicative
Expand Down Expand Up @@ -111,6 +112,23 @@ naturalNumberOpt = eitherReader aux
aux :: String -> Either String Word
aux s = maybe (Left $ s <> " is not a nonnegative number") Right (readMaybe s :: Maybe Word)

fieldSizeOpt :: ReadM (Maybe Natural)
fieldSizeOpt = eitherReader aux
where
aux :: String -> Either String (Maybe Natural)
aux s = case s of
"cairo" -> Right $ Just cairoFieldSize
"small" -> Right $ Just smallFieldSize
_ ->
mapRight Just $
either Left checkAllowed $
maybe (Left $ s <> " is not a valid field size") Right (readMaybe s :: Maybe Natural)

checkAllowed :: Natural -> Either String Natural
checkAllowed n
| n `elem` allowedFieldSizes = Right n
| otherwise = Left $ Prelude.show n <> " is not a recognized field size"

extCompleter :: FileExt -> Completer
extCompleter ext = mkCompleter $ \word -> do
let cmd = unwords ["compgen", "-o", "plusdirs", "-f", "-X", "!*" <> Prelude.show ext, "--", requote word]
Expand Down
18 changes: 10 additions & 8 deletions app/Evaluator.hs
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,16 @@ data EvalOptions = EvalOptions
makeLenses ''EvalOptions

evalAndPrint ::
forall r a.
(Members '[EmbedIO, App] r, CanonicalProjection a EvalOptions, CanonicalProjection a Core.Options) =>
forall r a b.
(Members '[EmbedIO, App] r, CanonicalProjection a EvalOptions, CanonicalProjection b Core.CoreOptions, CanonicalProjection a Core.Options) =>
b ->
a ->
Core.InfoTable ->
Core.Node ->
Sem r ()
evalAndPrint opts tab node = do
evalAndPrint gopts opts tab node = do
loc <- defaultLoc
r <- Core.doEval (project opts ^. evalNoIO) loc tab node
r <- Core.doEval (Just $ project gopts ^. Core.optFieldSize) (project opts ^. evalNoIO) loc tab node
case r of
Left err -> exitJuvixError (JuvixError err)
Right node'
Expand All @@ -50,14 +51,15 @@ evalAndPrint opts tab node = do
f = project opts ^. evalInputFile

normalizeAndPrint ::
forall r a.
(Members '[EmbedIO, App] r, CanonicalProjection a EvalOptions, CanonicalProjection a Core.Options) =>
forall r a b.
(Members '[EmbedIO, App] r, CanonicalProjection a EvalOptions, CanonicalProjection b Core.CoreOptions, CanonicalProjection a Core.Options) =>
b ->
a ->
Core.InfoTable ->
Core.Node ->
Sem r ()
normalizeAndPrint opts tab node =
let node' = normalize (Core.moduleFromInfoTable tab) node
normalizeAndPrint gopts opts tab node =
let node' = normalize (project gopts ^. Core.optFieldSize) (Core.moduleFromInfoTable tab) node
in if
| Info.member Info.kNoDisplayInfo (Core.getInfo node') ->
return ()
Expand Down
17 changes: 15 additions & 2 deletions app/GlobalOptions.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import Juvix.Compiler.Pipeline
import Juvix.Compiler.Pipeline.Root
import Juvix.Data.Effect.TaggedLock
import Juvix.Data.Error.GenericError qualified as E
import Juvix.Data.Field

data GlobalOptions = GlobalOptions
{ _globalNoColors :: Bool,
Expand All @@ -23,6 +24,7 @@ data GlobalOptions = GlobalOptions
_globalNoCoverage :: Bool,
_globalNoStdlib :: Bool,
_globalUnrollLimit :: Int,
_globalFieldSize :: Maybe Natural,
_globalOffline :: Bool
}
deriving stock (Eq, Show)
Expand All @@ -46,6 +48,7 @@ instance CanonicalProjection GlobalOptions Core.CoreOptions where
Core.CoreOptions
{ Core._optCheckCoverage = not _globalNoCoverage,
Core._optUnrollLimit = _globalUnrollLimit,
Core._optFieldSize = fromMaybe defaultFieldSize _globalFieldSize,
Core._optOptimizationLevel = defaultOptimizationLevel,
Core._optInliningDepth = defaultInliningDepth
}
Expand All @@ -63,6 +66,7 @@ defaultGlobalOptions =
_globalNoCoverage = False,
_globalNoStdlib = False,
_globalUnrollLimit = defaultUnrollLimit,
_globalFieldSize = Nothing,
_globalOffline = False
}

Expand Down Expand Up @@ -112,6 +116,13 @@ parseGlobalFlags = do
( long "no-stdlib"
<> help "Do not use the standard library"
)
_globalFieldSize <-
option
fieldSizeOpt
( long "field-size"
<> value Nothing
<> help "Field type size [cairo,small,11] (default: small)"
)
_globalUnrollLimit <-
option
(fromIntegral <$> naturalNumberOpt)
Expand Down Expand Up @@ -162,7 +173,8 @@ entryPointFromGlobalOptions root mainFile opts = do
_entryPointUnrollLimit = opts ^. globalUnrollLimit,
_entryPointGenericOptions = project opts,
_entryPointBuildDir = maybe (def ^. entryPointBuildDir) (CustomBuildDir . Abs) mabsBuildDir,
_entryPointOffline = opts ^. globalOffline
_entryPointOffline = opts ^. globalOffline,
_entryPointFieldSize = fromMaybe defaultFieldSize $ opts ^. globalFieldSize
}
where
optBuildDir :: Maybe (Prepath Dir)
Expand All @@ -184,7 +196,8 @@ entryPointFromGlobalOptionsNoFile root opts = do
_entryPointUnrollLimit = opts ^. globalUnrollLimit,
_entryPointGenericOptions = project opts,
_entryPointBuildDir = maybe (def ^. entryPointBuildDir) (CustomBuildDir . Abs) mabsBuildDir,
_entryPointOffline = opts ^. globalOffline
_entryPointOffline = opts ^. globalOffline,
_entryPointFieldSize = fromMaybe defaultFieldSize $ opts ^. globalFieldSize
}
where
optBuildDir :: Maybe (Prepath Dir)
Expand Down
10 changes: 2 additions & 8 deletions examples/milestone/Bank/Bank.juvix
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,6 @@ Address : Type := Nat;

bankAddress : Address := 1234;

--- Some field type.
axiom Field : Type;

--- Equality test for ;Field;.
axiom eqField : Field -> Field -> Bool;

module Token;
type Token :=
--- Arguments are: owner, gates, amount.
Expand Down Expand Up @@ -47,7 +41,7 @@ module Balances;
| f n nil := (f, n) :: nil
| f n ((b, bn) :: bs) :=
if
(eqField f b)
(f == b)
((b, bn + n) :: bs)
((b, bn) :: increment f n bs);

Expand All @@ -58,7 +52,7 @@ module Balances;
| _ _ nil := nil
| f n ((b, bn) :: bs) :=
if
(eqField f b)
(f == b)
((b, sub bn n) :: bs)
((b, bn) :: decrement f n bs);

Expand Down
2 changes: 0 additions & 2 deletions include/package-base/Juvix/Builtin/V1/Nat.juvix
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,5 @@ naturalNatI : Natural Nat :=
mkNatural@{
+ := (Nat.+);
* := (Nat.*);
div := Nat.div;
mod := Nat.mod;
fromNat (x : Nat) : Nat := x
};
2 changes: 0 additions & 2 deletions include/package-base/Juvix/Builtin/V1/Trait/Natural.juvix
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ type Natural A :=
+ : A -> A -> A;
syntax operator * multiplicative;
* : A -> A -> A;
div : A -> A -> A;
mod : A -> A -> A;
builtin from-nat
fromNat : Nat -> A
};
Expand Down
13 changes: 13 additions & 0 deletions runtime/src/vampir/stdlib.pir
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,19 @@ def mul (x, e1) (y, e2) = {
(x * y, e1 * e2 * range_check (x * y))
};

def fadd (x, e1) (y, e2) = {
(x + y, e1 * e2)
};
def fsub (x, e1) (y, e2) = {
(x - y, e1 * e2)
};
def fmul (x, e1) (y, e2) = {
(x * y, e1 * e2)
};
def fdiv (x, e1) (y, e2) = {
(x / y, e1 * e2)
};

def equal (x, e1) (y, e2) = (isZero (x - y), e1 * e2);

def if (b, e) (x, e1) (y, e2) = (b * x + (1 - b) * y, e * (b * e1 + (1 - b) * e2));
Expand Down
5 changes: 5 additions & 0 deletions runtime/src/vampir/stdlib_unsafe.pir
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ def add x y = x + y;
def sub x y = x - y;
def mul x y = x * y;

def fadd x y = x + y;
def fsub x y = x - y;
def fmul x y = x * y;
def fdiv x y = x / y;

def equal x y = isZero (x - y);

def if b x y = b * x + (1 - b) * y;
Expand Down
1 change: 1 addition & 0 deletions src/Juvix/Compiler/Asm/Extra/Memory.hs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ getDirectRefType dr mem = case dr of
getConstantType :: Constant -> Type
getConstantType = \case
ConstInt {} -> mkTypeInteger
ConstField {} -> TyField
ConstBool {} -> mkTypeBool
ConstString {} -> TyString
ConstUnit -> TyUnit
Expand Down
Loading

0 comments on commit dcea0bb

Please sign in to comment.