Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Offset matrix multiplication via generic_matmatmul! #270

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ version = "1.10.8"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[compat]
Adapt = "2, 3"
Expand All @@ -24,9 +25,8 @@ DistributedArrays = "aaf54ef3-cdf8-58ed-94cc-d582ad619b94"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "CatIndices", "DistributedArrays", "DelimitedFiles", "Documenter", "Test", "LinearAlgebra", "EllipsisNotation", "StaticArrays", "FillArrays"]
test = ["Aqua", "CatIndices", "DistributedArrays", "DelimitedFiles", "Documenter", "Test", "EllipsisNotation", "StaticArrays", "FillArrays"]
3 changes: 3 additions & 0 deletions src/OffsetArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,7 @@ if isdefined(Base, :IdentityUnitRange)
no_offset_view(a::Base.Slice) = Base.Slice(UnitRange(a))
no_offset_view(S::SubArray) = view(parent(S), map(no_offset_view, parentindices(S))...)
end
# A `PermutedDimsArray` is rebuilt with the same permutation `perm` around the
# result of calling `no_offset_view` on its parent.
function no_offset_view(A::PermutedDimsArray{T,N,perm,iperm,P}) where {T,N,perm,iperm,P}
    unwrapped = no_offset_view(parent(A))
    return PermutedDimsArray(unwrapped, perm)
end

# An `Array` is returned unchanged.
function no_offset_view(a::Array)
    return a
end

# Numbers are returned unchanged.
function no_offset_view(i::Number)
    return i
end

# Generic fallback: hand the array and its axes to `_no_offset_view`.
function no_offset_view(A::AbstractArray)
    return _no_offset_view(axes(A), A)
end
Expand Down Expand Up @@ -853,6 +854,8 @@ end
import Adapt
# Adapt.jl hook for `OffsetArray`: passes `parent_call` a function that runs
# `Adapt.adapt(to, ·)` on whatever it is given, together with `O`.
# NOTE(review): presumably `parent_call` applies that function to the parent array and
# re-wraps the result with the same offsets — confirm against its definition.
function Adapt.adapt_structure(to, O::OffsetArray)
    return parent_call(O) do p
        Adapt.adapt(to, p)
    end
end

include("linearalgebra.jl")

if Base.VERSION >= v"1.4.2"
include("precompile.jl")
_precompile_()
Expand Down
113 changes: 113 additions & 0 deletions src/linearalgebra.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
using LinearAlgebra
using LinearAlgebra: MulAddMul, mul!
"""
    lapack_axes(t::AbstractChar, M::AbstractVecOrMat)

Return a 2-tuple of axes of `M` as seen through the LAPACK-style flag `t`.

For `t == 'N'` the result is `(axes(M, 1), axes(M, 2))`. For any other character the
two axes are swapped. No validation of `t` is performed here.
"""
function lapack_axes(t::AbstractChar, M::AbstractVecOrMat)
    first_axis = axes(M, 1)
    second_axis = axes(M, 2)
    if t == 'N'
        return (first_axis, second_axis)
    else
        return (second_axis, first_axis)
    end
end

# The signatures of these methods differ from LinearAlgebra's *only* in the type of `C`,
# which is restricted to an `OffsetVector`. Both forward to `unwrap_matvecmul!`.

# Variant that receives the scalars packed in a `MulAddMul`.
function LinearAlgebra.generic_matvecmul!(C::OffsetVector, tA, A::AbstractVecOrMat,
                                          B::AbstractVector, _add::MulAddMul)
    return unwrap_matvecmul!(C, tA, A, B, _add.alpha, _add.beta)
end

# Variant that receives `alpha` and `beta` as separate arguments.
function LinearAlgebra.generic_matvecmul!(C::OffsetVector, tA, A::AbstractVecOrMat,
                                          B::AbstractVector, alpha, beta)
    return unwrap_matvecmul!(C, tA, A, B, alpha, beta)
end

"""
    unwrap_matvecmul!(C::OffsetVector, tA, A, B, alpha, beta)

Matrix-vector worker behind the `generic_matvecmul!` methods for an `OffsetVector`
destination.

The axes of `A` (read through the flag `tA` by `lapack_axes`), `B` and `C` are checked
for compatibility. Then `no_offset_view` is applied to all three arrays and the 5-argument
`mul!` is called on the results, wrapping `A` in `transpose`/`adjoint` as `tA` requires.

# Arguments
- `tA`: `'N'` (use `A` as is), `'T'` (transpose) or `'C'` (adjoint).
- `alpha`, `beta`: scalars forwarded unchanged to `mul!`.

# Returns
The original `C`.

# Throws
- `DimensionMismatch` if the contracted axes of `A` and `B` differ, or if the remaining
  axis of `A` differs from the axis of `C`.
- `ErrorException("illegal char")` if `tA` is none of `'N'`, `'T'`, `'C'`. The axis checks
  run first, so a `DimensionMismatch` may be raised instead.
"""
function unwrap_matvecmul!(C::OffsetVector, tA, A::AbstractVecOrMat, B::AbstractVector,
                           alpha, beta)
    out_axis, inner_axis = lapack_axes(tA, A)
    b_axis = Base.axes1(B)
    c_axis = Base.axes1(C)

    b_axis == inner_axis || throw(DimensionMismatch("mul! can't contract axis $(UnitRange(inner_axis)) from A with axes(B) == ($(UnitRange(b_axis)),)"))
    out_axis == c_axis || throw(DimensionMismatch("mul! got axes(C) == ($(UnitRange(c_axis)),), expected $(UnitRange(out_axis))"))

    # All three arrays are passed through `no_offset_view` before dispatching on the flag.
    Cp = no_offset_view(C)
    Ap = no_offset_view(A)
    Bp = no_offset_view(B)

    if tA == 'N'
        mul!(Cp, Ap, Bp, alpha, beta)
    elseif tA == 'T'
        mul!(Cp, transpose(Ap), Bp, alpha, beta)
    elseif tA == 'C'
        mul!(Cp, adjoint(Ap), Bp, alpha, beta)
    else
        error("illegal char")
    end

    return C
end

# The signatures of these methods differ from LinearAlgebra's *only* in the type of `C`.
# Every method forwards to `unwrap_matmatmul!`.

# Old path: the scalars arrive packed in a `MulAddMul`.
function LinearAlgebra.generic_matmatmul!(C::OffsetMatrix, tA, tB, A::AbstractMatrix,
                                          B::AbstractMatrix, _add::MulAddMul)
    return unwrap_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
end

function LinearAlgebra.generic_matmatmul!(C::Union{OffsetMatrix, OffsetVector}, tA, tB,
                                          A::AbstractVecOrMat, B::AbstractVecOrMat,
                                          _add::MulAddMul)
    return unwrap_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
end

# New path: `alpha` and `beta` arrive as separate arguments.
function LinearAlgebra.generic_matmatmul!(C::OffsetMatrix, tA, tB, A::AbstractMatrix,
                                          B::AbstractMatrix, alpha, beta)
    return unwrap_matmatmul!(C, tA, tB, A, B, alpha, beta)
end

function LinearAlgebra.generic_matmatmul!(C::Union{OffsetMatrix, OffsetVector}, tA, tB,
                                          A::AbstractVecOrMat, B::AbstractVecOrMat,
                                          alpha, beta)
    return unwrap_matmatmul!(C, tA, tB, A, B, alpha, beta)
end

# Worker shared by the `generic_matmatmul!` methods above.
#
# Checks that the axes of `A` and `B` (read through the flags `tA` / `tB` by
# `lapack_axes`) are compatible with each other and with `C`, applies `no_offset_view`
# to all three arrays, and calls the 5-argument `mul!` on the results. Returns the
# original `C`.
#
# `tA` and `tB` are each expected to be 'N' (use as is), 'T' (transpose) or 'C' (adjoint);
# any other value reaches `error("illegal char")`. Because `lapack_axes` treats every
# non-'N' flag as "swapped", the axis checks run first and may throw a
# `DimensionMismatch` before the illegal flag is reported.
# NOTE(review): presumably `no_offset_view(C)` aliases `C`'s storage, so the result written
# into `C1` is visible through `C` — confirm against `no_offset_view`.
@inline function unwrap_matmatmul!(C::Union{OffsetMatrix, OffsetVector}, tA, tB, A::AbstractVecOrMat, B::AbstractVecOrMat,
alpha, beta)

# (rows, columns) of each operand after its flag has been taken into account.
mA_axis, nA_axis = lapack_axes(tA, A)
mB_axis, nB_axis = lapack_axes(tB, B)

# Contracted axes must agree, and the outer axes must match those of `C`.
if nA_axis != mB_axis
throw(DimensionMismatch("mul! can't contract axis $(UnitRange(nA_axis)) from A with $(UnitRange(mB_axis)) from B"))
elseif mA_axis != axes(C,1)
throw(DimensionMismatch("mul! got axes(C,1) == $(UnitRange(axes(C,1))), expected $(UnitRange(mA_axis)) from A"))
elseif nB_axis != axes(C,2)
throw(DimensionMismatch("mul! got axes(C,2) == $(UnitRange(axes(C,2))), expected $(UnitRange(nB_axis)) from B"))
end

# `mul!` below is called on these unwrapped arrays rather than on the originals.
C1 = no_offset_view(C)
A1 = no_offset_view(A)
B1 = no_offset_view(B)

# Translate the character flags back into lazy `transpose` / `adjoint` wrappers.
if tA == 'N'
if tB == 'N'
mul!(C1, A1, B1, alpha, beta)
elseif tB == 'T'
mul!(C1, A1, transpose(B1), alpha, beta)
elseif tB == 'C'
mul!(C1, A1, adjoint(B1), alpha, beta)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A bigger change to LinearAlgebra.generic_matmatmul! would be to make it keep adjoint longer, before introducing 'C', etc. Then this nest of conditions could be removed.

It seems an odd design that MulAddMul pushes α,β into the type domain (partly) at the same time that it moves transpose/adjoint to values from types. Perhaps JuliaLang/julia#43552 could fix both at the same time.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

816b928 (and JuliaLang/julia@aad0522) changes this. It removes the extra allocation above.

The downside is that these methods will not be called for mul! with a version of Julia older than JuliaLang/julia#43552 .

else
error("illegal char")
end
elseif tA == 'T'
if tB == 'N'
mul!(C1, transpose(A1), B1, alpha, beta)
elseif tB == 'T'
mul!(C1, transpose(A1), transpose(B1), alpha, beta)
elseif tB == 'C'
mul!(C1, transpose(A1), adjoint(B1), alpha, beta)
else
error("illegal char")
end
elseif tA == 'C'
if tB == 'N'
mul!(C1, adjoint(A1), B1, alpha, beta)
elseif tB == 'T'
mul!(C1, adjoint(A1), transpose(B1), alpha, beta)
elseif tB == 'C'
mul!(C1, adjoint(A1), adjoint(B1), alpha, beta)
else
error("illegal char")
end
else
error("illegal char")
end

# Return the caller's array, not the unwrapped `C1`.
C
end

# Lazy LinearAlgebra wrappers: apply `no_offset_view` to the wrapped parent and rebuild
# the same kind of wrapper around the result.
function no_offset_view(A::Adjoint)
    inner = no_offset_view(parent(A))
    return Adjoint(inner)
end

function no_offset_view(A::Transpose)
    inner = no_offset_view(parent(A))
    return Transpose(inner)
end

function no_offset_view(D::Diagonal)
    inner = no_offset_view(parent(D))
    return Diagonal(inner)
end