Skip to content

Commit

Permalink
Fix #2216.
Browse files Browse the repository at this point in the history
This is by introducing a new builtin function for conditional values
(essentially the C ternary operator), which we then use in the
definition of the partial derivative of 2216.
  • Loading branch information
athas committed Feb 3, 2025
1 parent f00faf0 commit 6036a75
Show file tree
Hide file tree
Showing 7 changed files with 104 additions and 3 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.

* An overzealous floating-point simplification rule.

* Corrected AD of `x**y` where `x==0` (#2216).

## [0.25.26]

### Fixed
Expand Down
12 changes: 12 additions & 0 deletions rts/c/scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -3048,4 +3048,16 @@ SCALAR_FUN_ATTR double fpconv_f64_f64(double x) {

#endif

#define futrts_cond_f16(x,y,z) ((x) ? (y) : (z))
#define futrts_cond_f32(x,y,z) ((x) ? (y) : (z))
#define futrts_cond_f64(x,y,z) ((x) ? (y) : (z))

#define futrts_cond_i8(x,y,z) ((x) ? (y) : (z))
#define futrts_cond_i16(x,y,z) ((x) ? (y) : (z))
#define futrts_cond_i32(x,y,z) ((x) ? (y) : (z))
#define futrts_cond_i64(x,y,z) ((x) ? (y) : (z))

#define futrts_cond_bool(x,y,z) ((x) ? (y) : (z))
#define futrts_cond_unit(x,y,z) ((x) ? (y) : (z))

// End of scalar.h.
26 changes: 24 additions & 2 deletions src/Futhark/AD/Derivatives.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ where

import Data.Bifunctor (bimap)
import Futhark.Analysis.PrimExp.Convert
import Futhark.IR.Syntax.Core (Name, VName)
import Futhark.IR.Syntax.Core (Name, VName, nameToText)
import Futhark.Util.IntegralExp
import Prelude hiding (quot)

Expand Down Expand Up @@ -134,7 +134,14 @@ pdBinOp (FDiv ft) a b =
pdBinOp (FPow ft) a b =
floatBinOp derivs derivs derivs ft a b
where
derivs x y = (y * (x ** (y - 1)), (x ** y) * log x)
derivs x y =
( y * (x ** (y - 1)),
TPrimExp $
FunExp
(condFun (FloatType ft))
[untyped (x .==. 0), fConst ft 0, untyped $ (x ** y) * log x]
(FloatType ft)
)
pdBinOp (FMax ft) a b =
floatBinOp derivs derivs derivs ft a b
where
Expand Down Expand Up @@ -375,6 +382,21 @@ pdBuiltin "copysign32" [_x, y] =
Just [untyped $ 1 * isF32 (UnOpExp (FSignum Float32) y), fConst Float32 0]
pdBuiltin "copysign64" [_x, y] =
Just [untyped $ 1 * isF64 (UnOpExp (FSignum Float64) y), fConst Float64 0]
pdBuiltin h [x, _y, _z]
| Just t <- isCondFun $ nameToText h =
Just
[ boolToT t false,
boolToT t $ isBool x,
boolToT t $ bNot $ isBool x
]
where
boolToT t = case t of
IntType it ->
ConvOpExp (BToI it) . untyped
FloatType ft ->
ConvOpExp (SIToFP Int32 ft) . ConvOpExp (BToI Int32) . untyped
Bool -> untyped
Unit -> const $ ValueExp UnitValue
-- More problematic derivatives follow below.
pdBuiltin "umul_hi8" [x, y] = Just [y, x]
pdBuiltin "umul_hi16" [x, y] = Just [y, x]
Expand Down
12 changes: 12 additions & 0 deletions src/Futhark/Analysis/PrimExp.hs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ module Futhark.Analysis.PrimExp
fMax16,
fMax32,
fMax64,
condExp,

-- * Untyped construction
(~*~),
Expand Down Expand Up @@ -701,6 +702,17 @@ fMax32 x y = isF32 $ BinOpExp (FMax Float32) (untyped x) (untyped y)
fMax64 :: TPrimExp Double v -> TPrimExp Double v -> TPrimExp Double v
fMax64 x y = isF64 $ BinOpExp (FMax Float64) (untyped x) (untyped y)

-- | Conditional expression.
condExp :: TPrimExp Bool v -> TPrimExp t v -> TPrimExp t v -> TPrimExp t v
condExp x y z =
TPrimExp $
FunExp
(condFun t)
[untyped x, untyped y, untyped z]
t
where
t = primExpType $ untyped y

-- | Convert result of some integer expression to have the same type
-- as another, using sign extension.
sExtAs ::
Expand Down
27 changes: 26 additions & 1 deletion src/Language/Futhark/Primitive.hs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ module Language.Futhark.Primitive

-- * Primitive functions
primFuns,
condFun,
isCondFun,

-- * Utility
zeroIsh,
Expand Down Expand Up @@ -128,6 +130,7 @@ import Data.Bits
)
import Data.Fixed (mod') -- Weird location.
import Data.Int (Int16, Int32, Int64, Int8)
import Data.List qualified as L
import Data.Map qualified as M
import Data.Text qualified as T
import Data.Word (Word16, Word32, Word64, Word8)
Expand Down Expand Up @@ -1181,6 +1184,17 @@ doubleToWord = G.runGet G.getWord64le . P.runPut . P.putDoublele
wordToDouble :: Word64 -> Double
wordToDouble = G.runGet G.getDoublele . P.runPut . P.putWord64le

-- | @condFun t@ is the name of the ternary conditional function that
-- accepts operands of type @[Bool, t, t]@, and returns either the
-- first or second @t@ based on the truth value of the @Bool@.
condFun :: PrimType -> T.Text
condFun t = "cond_" <> prettyText t

-- | Is this the name of a condition function as per 'condFun', and
-- for which type?
isCondFun :: T.Text -> Maybe PrimType
isCondFun v = L.find (\t -> condFun t == v) allPrimTypes

-- | A mapping from names of primitive functions to their parameter
-- types, their result type, and a function for evaluating them.
primFuns ::
Expand All @@ -1191,7 +1205,7 @@ primFuns ::
[PrimValue] -> Maybe PrimValue
)
primFuns =
M.fromList
M.fromList $
[ f16 "sqrt16" sqrt,
f32 "sqrt32" sqrt,
f64 "sqrt64" sqrt,
Expand Down Expand Up @@ -1529,6 +1543,17 @@ primFuns =
f32_3 "fma32" (\a b c -> a * b + c),
f64_3 "fma64" (\a b c -> a * b + c)
]
<> [ ( condFun t,
( [Bool, t, t],
t,
\case
[BoolValue b, tv, fv] ->
Just $ if b then tv else fv
_ -> Nothing
)
)
| t <- allPrimTypes
]
where
i8 s f = (s, ([IntType Int8], IntType Int32, i8PrimFun f))
i16 s f = (s, ([IntType Int16], IntType Int32, i16PrimFun f))
Expand Down
13 changes: 13 additions & 0 deletions tests/ad/pow0.fut
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
-- The power function has a dangerous kink for x==0.

-- ==
-- entry: fwd
-- input { 0.0 1.0 } output { 1.0 }

-- ==
-- entry: rev
-- input { 0.0 1.0 } output { 1.0 0.0 }

entry fwd x y : f64 = jvp (\(x, y) -> x ** y) (x, y) (1, 1)

entry rev x y = vjp (\(x, y) -> x ** y) (x, y) 1f64
15 changes: 15 additions & 0 deletions tests/issue2216.fut
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
-- ==
-- input { [0.0, 0.0, 0.0] }
-- output { [[0.0, 0.0, 0.0],
-- [0.0, 2.0, 0.0],
-- [0.0, 0.0, 2.0]] }

def identity_mat n = tabulate_2d n n (\i j -> f64.bool (i == j))

def Jacobi [n] (f: [n]f64 -> [n]f64) (x: [n]f64) : [n][n]f64 =
map (\i -> jvp f x i) (identity_mat n)

def Hessian [n] (f: [n]f64 -> f64) (x: [n]f64) : [n][n]f64 =
Jacobi (\x -> vjp f x 1) x

entry main (x: [3]f64) = Hessian (\x -> x[1] ** 2 + x[2] ** 2) x

0 comments on commit 6036a75

Please sign in to comment.