diff --git a/src/lib/array.jl b/src/lib/array.jl index 8577852ad..489ee2fab 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -329,19 +329,6 @@ end end # Reductions -#= -@adjoint function sum(xs::AbstractArray; dims = :) - if dims === (:) - sum(xs), Δ -> (Fill(Δ, size(xs)),) - else - sum(xs, dims = dims), Δ -> (similar(xs) .= Δ,) - end -end -=# - -@adjoint function sum(xs::AbstractArray{Bool}; dims = :) - sum(xs, dims = dims), Δ -> (nothing,) -end function _pullback(cx::AContext, ::typeof(prod), f, xs::AbstractArray) return _pullback(cx, (f, xs) -> prod(f.(xs)), f, xs) diff --git a/test/lib/array.jl b/test/lib/array.jl index a3b73aff9..9afe43673 100644 --- a/test/lib/array.jl +++ b/test/lib/array.jl @@ -50,8 +50,8 @@ end @testset "dictionary comprehension" begin d = Dict(1 => 5, 2 => 6) g = gradient(d -> sum([v^2 for (_,v) in d]), d)[1] - @test g isa Dict{Int, Int} - @test g == Dict(1 => 10, 2 => 12) + @test g isa Dict{Int, Float64} + @test g == Dict(1 => 10.0, 2 => 12.0) w = randn(5)