diff --git a/CHANGELOG.md b/CHANGELOG.md index 853d1dfbc7..341a305249 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. * An overzealous floating-point simplification rule. +* Corrected AD of `x**y` where `x==0` (#2216). + ## [0.25.26] ### Fixed diff --git a/rts/c/scalar.h b/rts/c/scalar.h index 3ac9e0ae34..7435374a97 100644 --- a/rts/c/scalar.h +++ b/rts/c/scalar.h @@ -3048,4 +3048,16 @@ SCALAR_FUN_ATTR double fpconv_f64_f64(double x) { #endif +#define futrts_cond_f16(x,y,z) ((x) ? (y) : (z)) +#define futrts_cond_f32(x,y,z) ((x) ? (y) : (z)) +#define futrts_cond_f64(x,y,z) ((x) ? (y) : (z)) + +#define futrts_cond_i8(x,y,z) ((x) ? (y) : (z)) +#define futrts_cond_i16(x,y,z) ((x) ? (y) : (z)) +#define futrts_cond_i32(x,y,z) ((x) ? (y) : (z)) +#define futrts_cond_i64(x,y,z) ((x) ? (y) : (z)) + +#define futrts_cond_bool(x,y,z) ((x) ? (y) : (z)) +#define futrts_cond_unit(x,y,z) ((x) ? (y) : (z)) + // End of scalar.h. diff --git a/src/Futhark/AD/Derivatives.hs b/src/Futhark/AD/Derivatives.hs index 3596c60a96..c035f1ca11 100644 --- a/src/Futhark/AD/Derivatives.hs +++ b/src/Futhark/AD/Derivatives.hs @@ -8,7 +8,7 @@ where import Data.Bifunctor (bimap) import Futhark.Analysis.PrimExp.Convert -import Futhark.IR.Syntax.Core (Name, VName) +import Futhark.IR.Syntax.Core (Name, VName, nameToText) import Futhark.Util.IntegralExp import Prelude hiding (quot) @@ -134,7 +134,14 @@ pdBinOp (FDiv ft) a b = pdBinOp (FPow ft) a b = floatBinOp derivs derivs derivs ft a b where - derivs x y = (y * (x ** (y - 1)), (x ** y) * log x) + derivs x y = + ( y * (x ** (y - 1)), + TPrimExp $ + FunExp + (condFun (FloatType ft)) + [untyped (x .==. 0), fConst ft 0, untyped $ (x ** y) * log x] + (FloatType ft) + ) pdBinOp (FMax ft) a b = floatBinOp derivs derivs derivs ft a b where @@ -375,6 +382,21 @@ pdBuiltin "copysign32" [_x, y] = Just [untyped $ 1 * isF32 (UnOpExp (FSignum Float32) y), fConst Float32 0] pdBuiltin "copysign64" [_x, y] = Just [untyped $ 1 * isF64 (UnOpExp (FSignum Float64) y), fConst Float64 0] +pdBuiltin h [x, _y, _z] + | Just t <- isCondFun $ nameToText h = + Just + [ boolToT t false, + boolToT t $ isBool x, + boolToT t $ bNot $ isBool x + ] + where + boolToT t = case t of + IntType it -> + ConvOpExp (BToI it) . untyped + FloatType ft -> + ConvOpExp (SIToFP Int32 ft) . ConvOpExp (BToI Int32) . untyped + Bool -> untyped + Unit -> const $ ValueExp UnitValue -- More problematic derivatives follow below. pdBuiltin "umul_hi8" [x, y] = Just [y, x] pdBuiltin "umul_hi16" [x, y] = Just [y, x] diff --git a/src/Futhark/Analysis/PrimExp.hs b/src/Futhark/Analysis/PrimExp.hs index e70de707b8..3f9a0cb8c2 100644 --- a/src/Futhark/Analysis/PrimExp.hs +++ b/src/Futhark/Analysis/PrimExp.hs @@ -59,6 +59,7 @@ module Futhark.Analysis.PrimExp fMax16, fMax32, fMax64, + condExp, -- * Untyped construction (~*~), @@ -701,6 +702,17 @@ fMax32 x y = isF32 $ BinOpExp (FMax Float32) (untyped x) (untyped y) fMax64 :: TPrimExp Double v -> TPrimExp Double v -> TPrimExp Double v fMax64 x y = isF64 $ BinOpExp (FMax Float64) (untyped x) (untyped y) +-- | Conditional expression. +condExp :: TPrimExp Bool v -> TPrimExp t v -> TPrimExp t v -> TPrimExp t v +condExp x y z = + TPrimExp $ + FunExp + (condFun t) + [untyped x, untyped y, untyped z] + t + where + t = primExpType $ untyped y + -- | Convert result of some integer expression to have the same type -- as another, using sign extension. sExtAs :: diff --git a/src/Language/Futhark/Primitive.hs b/src/Language/Futhark/Primitive.hs index 593332a98a..a6740c75f2 100644 --- a/src/Language/Futhark/Primitive.hs +++ b/src/Language/Futhark/Primitive.hs @@ -92,6 +92,8 @@ module Language.Futhark.Primitive -- * Primitive functions primFuns, + condFun, + isCondFun, -- * Utility zeroIsh, @@ -128,6 +130,7 @@ import Data.Bits ) import Data.Fixed (mod') -- Weird location. import Data.Int (Int16, Int32, Int64, Int8) +import Data.List qualified as L import Data.Map qualified as M import Data.Text qualified as T import Data.Word (Word16, Word32, Word64, Word8) @@ -1181,6 +1184,17 @@ doubleToWord = G.runGet G.getWord64le . P.runPut . P.putDoublele wordToDouble :: Word64 -> Double wordToDouble = G.runGet G.getDoublele . P.runPut . P.putWord64le +-- | @condFun t@ is the name of the ternary conditional function that +-- accepts operands of type @[Bool, t, t]@, and returns either the +-- first or second @t@ based on the truth value of the @Bool@. +condFun :: PrimType -> T.Text +condFun t = "cond_" <> prettyText t + +-- | Is this the name of a condition function as per 'condFun', and +-- for which type? +isCondFun :: T.Text -> Maybe PrimType +isCondFun v = L.find (\t -> condFun t == v) allPrimTypes + -- | A mapping from names of primitive functions to their parameter -- types, their result type, and a function for evaluating them. primFuns :: @@ -1191,7 +1205,7 @@ primFuns :: [PrimValue] -> Maybe PrimValue ) primFuns = - M.fromList + M.fromList $ [ f16 "sqrt16" sqrt, f32 "sqrt32" sqrt, f64 "sqrt64" sqrt, @@ -1529,6 +1543,17 @@ primFuns = f32_3 "fma32" (\a b c -> a * b + c), f64_3 "fma64" (\a b c -> a * b + c) ] + <> [ ( condFun t, + ( [Bool, t, t], + t, + \case + [BoolValue b, tv, fv] -> + Just $ if b then tv else fv + _ -> Nothing + ) + ) + | t <- allPrimTypes + ] where i8 s f = (s, ([IntType Int8], IntType Int32, i8PrimFun f)) i16 s f = (s, ([IntType Int16], IntType Int32, i16PrimFun f)) diff --git a/tests/ad/pow0.fut b/tests/ad/pow0.fut new file mode 100644 index 0000000000..24c170c3d7 --- /dev/null +++ b/tests/ad/pow0.fut @@ -0,0 +1,13 @@ +-- The power function has a dangerous kink for x==0. + +-- == +-- entry: fwd +-- input { 0.0 1.0 } output { 1.0 } + +-- == +-- entry: rev +-- input { 0.0 1.0 } output { 1.0 0.0 } + +entry fwd x y : f64 = jvp (\(x, y) -> x ** y) (x, y) (1, 1) + +entry rev x y = vjp (\(x, y) -> x ** y) (x, y) 1f64 diff --git a/tests/issue2216.fut b/tests/issue2216.fut new file mode 100644 index 0000000000..3e6451a475 --- /dev/null +++ b/tests/issue2216.fut @@ -0,0 +1,15 @@ +-- == +-- input { [0.0, 0.0, 0.0] } +-- output { [[0.0, 0.0, 0.0], +-- [0.0, 2.0, 0.0], +-- [0.0, 0.0, 2.0]] } + +def identity_mat n = tabulate_2d n n (\i j -> f64.bool (i == j)) + +def Jacobi [n] (f: [n]f64 -> [n]f64) (x: [n]f64) : [n][n]f64 = + map (\i -> jvp f x i) (identity_mat n) + +def Hessian [n] (f: [n]f64 -> f64) (x: [n]f64) : [n][n]f64 = + Jacobi (\x -> vjp f x 1) x + +entry main (x: [3]f64) = Hessian (\x -> x[1] ** 2 + x[2] ** 2) x