Skip to content

Commit

Permalink
Better tracing of AD values in interpreter.
Browse files Browse the repository at this point in the history
  • Loading branch information
athas committed Nov 20, 2024
1 parent 10fcc49 commit f346bf4
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 10 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
* Sizes that go out of scope due to use of higher order functions will
now work in more cases by adding existentials. (#2193)

* Tracing inside AD operators with the interpreter now prints values
properly.

### Changed

## [0.25.24]
Expand Down
12 changes: 6 additions & 6 deletions src/Language/Futhark/Interpreter.hs
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,8 @@ asInteger :: Value -> Integer
asInteger (ValuePrim (SignedValue v)) = P.valueIntegral v
asInteger (ValuePrim (UnsignedValue v)) =
toInteger (P.valueIntegral (P.doZExt v Int64) :: Word64)
asInteger (ValueAD d v)
| P.IntValue v' <- AD.primitive $ AD.primal $ AD.Variable d v =
asInteger (ValueAD _ v)
| P.IntValue v' <- AD.varPrimal v =
P.valueIntegral v'
asInteger v = error $ "Unexpectedly not an integer: " <> show v

Expand All @@ -274,17 +274,17 @@ asInt = fromIntegral . asInteger

asSigned :: Value -> IntValue
asSigned (ValuePrim (SignedValue v)) = v
asSigned (ValueAD d v)
| P.IntValue v' <- AD.primitive $ AD.primal $ AD.Variable d v = v'
asSigned (ValueAD _ v)
| P.IntValue v' <- AD.varPrimal v = v'
asSigned v = error $ "Unexpectedly not a signed integer: " <> show v

asInt64 :: Value -> Int64
asInt64 = fromIntegral . asInteger

asBool :: Value -> Bool
asBool (ValuePrim (BoolValue x)) = x
asBool (ValueAD d v)
| P.BoolValue v' <- AD.primitive $ AD.primal $ AD.Variable d v = v'
asBool (ValueAD _ v)
| P.BoolValue v' <- AD.varPrimal v = v'
asBool v = error $ "Unexpectedly not a boolean: " <> show v

lookupInEnv ::
Expand Down
8 changes: 6 additions & 2 deletions src/Language/Futhark/Interpreter/AD.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ module Language.Futhark.Interpreter.AD
JVPValue (..),
doOp,
addFor,
primal,
tapePrimal,
primitive,
varPrimal,
deriveTape,
)
where
Expand Down Expand Up @@ -97,9 +97,13 @@ primal (Variable _ (JVP (JVPValue v _))) = primal v
primal (Constant v) = Constant v

primitive :: ADValue -> PrimValue
primitive v@(Variable _ _) = primitive $ primal v
primitive (Variable _ v) = varPrimal v
primitive (Constant v) = v

varPrimal :: ADVariable -> PrimValue
varPrimal (VJP (VJPValue t)) = primitive $ tapePrimal t
varPrimal (JVP (JVPValue v _)) = primitive $ primal v

-- Evaluates a PrimExp using doOp
evalPrimExp :: M.Map VName ADValue -> PrimExp VName -> Maybe ADValue
evalPrimExp m (LeafExp n _) = M.lookup n m
Expand Down
8 changes: 6 additions & 2 deletions src/Language/Futhark/Interpreter/Values.hs
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,15 @@ prettyValueWith pprPrim = pprPrec 0
pprPrec _ ValueAcc {} = "#<acc>"
pprPrec p (ValueSum _ n vs) =
parensIf (p > (0 :: Int)) $ "#" <> sep (pretty n : map (pprPrec 1) vs)
-- TODO: This could be prettier. Perhaps add pretty printing for ADVariable / ADValues
pprPrec _ (ValueAD d v) = pretty $ "d[" ++ show d ++ "]" ++ show v
pprPrec _ (ValueAD _ v) = pprPrim $ putV $ AD.varPrimal v
pprElem v@ValueArray {} = pprPrec 0 v
pprElem v = group $ pprPrec 0 v

putV (P.IntValue x) = SignedValue x
putV (P.FloatValue x) = FloatValue x
putV (P.BoolValue x) = BoolValue x
putV P.UnitValue = BoolValue True

-- | Prettyprint value.
prettyValue :: Value m -> Doc a
prettyValue = prettyValueWith pprPrim
Expand Down

0 comments on commit f346bf4

Please sign in to comment.