diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 519618e4be..15fb997f75 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -162,9 +162,13 @@ function (c::Conv)(x::AbstractArray) b = reshape(c.bias, map(_->1, c.stride)..., :, 1) σ = NNlib.fast_act(c.σ, x) cdims = DenseConvDims(x, c.weight; stride = c.stride, padding = c.pad, dilation = c.dilation, groups = c.groups) - σ.(conv(x, c.weight, cdims) .+ b) + _conv_bias_act(x, c.weight, cdims, b, σ) end +_conv_bias_act(x, w, cdims, b, σ) = NNlib.conv_bias_act(x, w, cdims, b, σ) +_conv_bias_act(x::CuArray, w::CuArray, cdims, b::Zeros, σ) = + _conv_bias_act(x, w, cdims, CUDA.zeros(size(b)...), σ) + _channels_in(l ::Conv) = size(l.weight, ndims(l.weight)-1) * l.groups _channels_out(l::Conv) = size(l.weight, ndims(l.weight))