Skip to content

Commit

Permalink
adapt to delayed MulAddMul
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Dec 26, 2021
1 parent 12a9922 commit 6878a14
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 19 deletions.
1 change: 1 addition & 0 deletions src/OffsetArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,7 @@ if isdefined(Base, :IdentityUnitRange)
no_offset_view(a::Base.Slice) = Base.Slice(UnitRange(a))
no_offset_view(S::SubArray) = view(parent(S), map(no_offset_view, parentindices(S))...)
end
no_offset_view(A::PermutedDimsArray{T,N,perm,iperm,P}) where {T,N,perm,iperm,P} = PermutedDimsArray(no_offset_view(parent(A)), perm)
no_offset_view(a::Array) = a
no_offset_view(i::Number) = i
no_offset_view(A::AbstractArray) = _no_offset_view(axes(A), A)
Expand Down
57 changes: 38 additions & 19 deletions src/linearalgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@ using LinearAlgebra
using LinearAlgebra: MulAddMul, mul!
lapack_axes(t::AbstractChar, M::AbstractVecOrMat) = (axes(M, t=='N' ? 1 : 2), axes(M, t=='N' ? 2 : 1))

# The signature of this differs from LinearAlgebra's only on C
function LinearAlgebra.generic_matvecmul!(C::OffsetVector, tA, A::AbstractVecOrMat, B::AbstractVector,
_add::MulAddMul = MulAddMul())
# The signatures of these differs from LinearAlgebra's *only* on C.
LinearAlgebra.generic_matvecmul!(C::OffsetVector, tA, A::AbstractVecOrMat, B::AbstractVector,
_add::MulAddMul) = unwrap_matvecmul!(C, tA, A, B, _add.alpha, _add.beta)
LinearAlgebra.generic_matvecmul!(C::OffsetVector, tA, A::AbstractVecOrMat, B::AbstractVector,
alpha, beta) = unwrap_matvecmul!(C, tA, A, B, alpha, beta)

function unwrap_matvecmul!(C::OffsetVector, tA, A::AbstractVecOrMat, B::AbstractVector,
alpha, beta)

mB_axis = Base.axes1(B)
mA_axis, nA_axis = lapack_axes(tA, A)
Expand All @@ -21,25 +26,34 @@ function LinearAlgebra.generic_matvecmul!(C::OffsetVector, tA, A::AbstractVecOrM
B1 = no_offset_view(B)

if tA == 'T'
mul!(C1, transpose(A1), B1, _add.alpha, _add.beta)
mul!(C1, transpose(A1), B1, alpha, beta)
elseif tA == 'C'
mul!(C1, adjoint(A1), B1, _add.alpha, _add.beta)
mul!(C1, adjoint(A1), B1, alpha, beta)
elseif tA == 'N'
mul!(C1, A1, B1, _add.alpha, _add.beta)
mul!(C1, A1, B1, alpha, beta)
else
error("illegal char")
end

C
end

# The signatures of these differs from LinearAlgebra's *only* on C:
# Old path
LinearAlgebra.generic_matmatmul!(C::OffsetMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix,
_add::MulAddMul) = unwrap_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
LinearAlgebra.generic_matmatmul!(C::Union{OffsetMatrix, OffsetVector}, tA, tB, A::AbstractVecOrMat, B::AbstractVecOrMat,
_add::MulAddMul) = unwrap_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)

# New path
LinearAlgebra.generic_matmatmul!(C::OffsetMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix,
_add::MulAddMul) = unwrap_matmatmul!(C, tA, tB, A, B, _add)
alpha, beta) = unwrap_matmatmul!(C, tA, tB, A, B, alpha, beta)
LinearAlgebra.generic_matmatmul!(C::Union{OffsetMatrix, OffsetVector}, tA, tB, A::AbstractVecOrMat, B::AbstractVecOrMat,
_add::MulAddMul) = unwrap_matmatmul!(C, tA, tB, A, B, _add)
alpha, beta) = unwrap_matmatmul!(C, tA, tB, A, B, alpha, beta)

function unwrap_matmatmul!(C::Union{OffsetMatrix, OffsetVector}, tA, tB, A::AbstractVecOrMat, B::AbstractVecOrMat,
_add::MulAddMul)
# Worker
@inline function unwrap_matmatmul!(C::Union{OffsetMatrix, OffsetVector}, tA, tB, A::AbstractVecOrMat, B::AbstractVecOrMat,
alpha, beta)

mA_axis, nA_axis = lapack_axes(tA, A)
mB_axis, nB_axis = lapack_axes(tB, B)
Expand All @@ -58,31 +72,31 @@ function unwrap_matmatmul!(C::Union{OffsetMatrix, OffsetVector}, tA, tB, A::Abst

if tA == 'N'
if tB == 'N'
mul!(C1, A1, B1, _add.alpha, _add.beta)
mul!(C1, A1, B1, alpha, beta)
elseif tB == 'T'
mul!(C1, A1, transpose(B1), _add.alpha, _add.beta)
mul!(C1, A1, transpose(B1), alpha, beta)
elseif tB == 'C'
mul!(C1, A1, adjoint(B1), _add.alpha, _add.beta)
mul!(C1, A1, adjoint(B1), alpha, beta)
else
error("illegal char")
end
elseif tA == 'T'
if tB == 'N'
mul!(C1, transpose(A1), B1, _add.alpha, _add.beta)
mul!(C1, transpose(A1), B1, alpha, beta)
elseif tB == 'T'
mul!(C1, transpose(A1), transpose(B1), _add.alpha, _add.beta)
mul!(C1, transpose(A1), transpose(B1), alpha, beta)
elseif tB == 'C'
mul!(C1, transpose(A1), adjoint(B1), _add.alpha, _add.beta)
mul!(C1, transpose(A1), adjoint(B1), alpha, beta)
else
error("illegal char")
end
elseif tA == 'C'
if tB == 'N'
mul!(C1, adjoint(A1), B1, _add.alpha, _add.beta)
mul!(C1, adjoint(A1), B1, alpha, beta)
elseif tB == 'T'
mul!(C1, adjoint(A1), transpose(B1), _add.alpha, _add.beta)
mul!(C1, adjoint(A1), transpose(B1), alpha, beta)
elseif tB == 'C'
mul!(C1, adjoint(A1), adjoint(B1), _add.alpha, _add.beta)
mul!(C1, adjoint(A1), adjoint(B1), alpha, beta)
else
error("illegal char")
end
Expand All @@ -92,3 +106,8 @@ function unwrap_matmatmul!(C::Union{OffsetMatrix, OffsetVector}, tA, tB, A::Abst

C
end

no_offset_view(A::Adjoint) = Adjoint(no_offset_view(parent(A)))
no_offset_view(A::Transpose) = Transpose(no_offset_view(parent(A)))
no_offset_view(D::Diagonal) = Diagonal(no_offset_view(parent(D)))

0 comments on commit 6878a14

Please sign in to comment.