Skip to content

Commit

Permalink
use transpose not T etc for generic_matmul
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Dec 28, 2021
1 parent a4fe82f commit aad0522
Showing 1 changed file with 27 additions and 23 deletions.
50 changes: 27 additions & 23 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ for elty in (Float32,Float64)
end
@inline mul!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector,
alpha::Number, beta::Number) =
generic_matvecmul!(y, 'N', A, x, alpha, beta)
generic_matvecmul!(y, identity, A, x, alpha, beta)

function *(tA::Transpose{<:Any,<:StridedMatrix{T}}, x::StridedVector{S}) where {T<:BlasFloat,S}
TS = promote_op(matprod, T, S)
Expand All @@ -88,7 +88,7 @@ end
gemv!(y, 'T', tA.parent, x, alpha, beta)
@inline mul!(y::AbstractVector, tA::Transpose{<:Any,<:AbstractVecOrMat}, x::AbstractVector,
alpha::Number, beta::Number) =
generic_matvecmul!(y, 'T', tA.parent, x, alpha, beta)
generic_matvecmul!(y, transpose, tA.parent, x, alpha, beta)

function *(adjA::Adjoint{<:Any,<:StridedMatrix{T}}, x::StridedVector{S}) where {T<:BlasFloat,S}
TS = promote_op(matprod, T, S)
Expand All @@ -107,7 +107,7 @@ end
gemv!(y, 'C', adjA.parent, x, alpha, beta)
@inline mul!(y::AbstractVector, adjA::Adjoint{<:Any,<:AbstractVecOrMat}, x::AbstractVector,
alpha::Number, beta::Number) =
generic_matvecmul!(y, 'C', adjA.parent, x, alpha, beta)
generic_matvecmul!(y, adjoint, adjA.parent, x, alpha, beta)

# Vector-Matrix multiplication
(*)(x::AdjointAbsVec, A::AbstractMatrix) = (A'*x')'
Expand Down Expand Up @@ -280,7 +280,7 @@ julia> C
"""
@inline mul!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat,
alpha::Number, beta::Number) =
generic_matmatmul!(C, 'N', 'N', A, B, alpha, beta)
generic_matmatmul!(C, identity, identity, A, B, alpha, beta)

"""
rmul!(A, B)
Expand Down Expand Up @@ -359,7 +359,7 @@ lmul!(A, B)
end
@inline mul!(C::AbstractMatrix, tA::Transpose{<:Any,<:AbstractVecOrMat}, B::AbstractVecOrMat,
alpha::Number, beta::Number) =
generic_matmatmul!(C, 'T', 'N', tA.parent, B, alpha, beta)
generic_matmatmul!(C, transpose, identity, tA.parent, B, alpha, beta)

@inline function mul!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, tB::Transpose{<:Any,<:StridedVecOrMat{T}},
alpha::Number, beta::Number) where {T<:BlasFloat}
Expand All @@ -385,24 +385,24 @@ end
# collapsing the following two defs with C::AbstractVecOrMat yields ambiguities
@inline mul!(C::AbstractVector, A::AbstractVecOrMat, tB::Transpose{<:Any,<:AbstractVecOrMat},
alpha::Number, beta::Number) =
generic_matmatmul!(C, 'N', 'T', A, tB.parent, alpha, beta)
generic_matmatmul!(C, identity, transpose, A, tB.parent, alpha, beta)
@inline mul!(C::AbstractMatrix, A::AbstractVecOrMat, tB::Transpose{<:Any,<:AbstractVecOrMat},
alpha::Number, beta::Number) =
generic_matmatmul!(C, 'N', 'T', A, tB.parent, alpha, beta)
generic_matmatmul!(C, identity, transpose, A, tB.parent, alpha, beta)

@inline mul!(C::StridedMatrix{T}, tA::Transpose{<:Any,<:StridedVecOrMat{T}}, tB::Transpose{<:Any,<:StridedVecOrMat{T}},
alpha::Number, beta::Number) where {T<:BlasFloat} =
gemm_wrapper!(C, 'T', 'T', tA.parent, tB.parent, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, tA::Transpose{<:Any,<:AbstractVecOrMat}, tB::Transpose{<:Any,<:AbstractVecOrMat},
alpha::Number, beta::Number) =
generic_matmatmul!(C, 'T', 'T', tA.parent, tB.parent, alpha, beta)
generic_matmatmul!(C, transpose, transpose, tA.parent, tB.parent, alpha, beta)

@inline mul!(C::StridedMatrix{T}, tA::Transpose{<:Any,<:StridedVecOrMat{T}}, adjB::Adjoint{<:Any,<:StridedVecOrMat{T}},
alpha::Number, beta::Number) where {T<:BlasFloat} =
gemm_wrapper!(C, 'T', 'C', tA.parent, adjB.parent, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, tA::Transpose{<:Any,<:AbstractVecOrMat}, tB::Adjoint{<:Any,<:AbstractVecOrMat},
alpha::Number, beta::Number) =
generic_matmatmul!(C, 'T', 'C', tA.parent, tB.parent, alpha, beta)
generic_matmatmul!(C, transpose, adjoint, tA.parent, tB.parent, alpha, beta)

@inline mul!(C::StridedMatrix{T}, adjA::Adjoint{<:Any,<:StridedVecOrMat{T}}, B::StridedVecOrMat{T},
alpha::Real, beta::Real) where {T<:BlasReal} =
Expand All @@ -418,7 +418,7 @@ end
end
@inline mul!(C::AbstractMatrix, adjA::Adjoint{<:Any,<:AbstractVecOrMat}, B::AbstractVecOrMat,
alpha::Number, beta::Number) =
generic_matmatmul!(C, 'C', 'N', adjA.parent, B, alpha, beta)
generic_matmatmul!(C, adjoint, identity, adjA.parent, B, alpha, beta)

@inline mul!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, adjB::Adjoint{<:Any,<:StridedVecOrMat{<:BlasReal}},
alpha::Number, beta::Number) where {T<:BlasFloat} =
Expand All @@ -434,21 +434,21 @@ end
end
@inline mul!(C::AbstractMatrix, A::AbstractVecOrMat, adjB::Adjoint{<:Any,<:AbstractVecOrMat},
alpha::Number, beta::Number) =
generic_matmatmul!(C, 'N', 'C', A, adjB.parent, alpha, beta)
generic_matmatmul!(C, identity, adjoint, A, adjB.parent, alpha, beta)

@inline mul!(C::StridedMatrix{T}, adjA::Adjoint{<:Any,<:StridedVecOrMat{T}}, adjB::Adjoint{<:Any,<:StridedVecOrMat{T}},
alpha::Number, beta::Number) where {T<:BlasFloat} =
gemm_wrapper!(C, 'C', 'C', adjA.parent, adjB.parent, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, adjA::Adjoint{<:Any,<:AbstractVecOrMat}, adjB::Adjoint{<:Any,<:AbstractVecOrMat},
alpha::Number, beta::Number) =
generic_matmatmul!(C, 'C', 'C', adjA.parent, adjB.parent, alpha, beta)
generic_matmatmul!(C, adjoint, adjoint, adjA.parent, adjB.parent, alpha, beta)

@inline mul!(C::StridedMatrix{T}, adjA::Adjoint{<:Any,<:StridedVecOrMat{T}}, tB::Transpose{<:Any,<:StridedVecOrMat{T}},
alpha::Number, beta::Number) where {T<:BlasFloat} =
gemm_wrapper!(C, 'C', 'T', adjA.parent, tB.parent, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, adjA::Adjoint{<:Any,<:AbstractVecOrMat}, tB::Transpose{<:Any,<:AbstractVecOrMat},
alpha::Number, beta::Number) =
generic_matmatmul!(C, 'C', 'T', adjA.parent, tB.parent, alpha, beta)
generic_matmatmul!(C, adjoint, transpose, adjA.parent, tB.parent, alpha, beta)

# Supporting functions for matrix multiplication

Expand Down Expand Up @@ -625,6 +625,9 @@ end

# Size of `M` as seen through the operation character `t`:
# `(size(M, 1), size(M, 2))` when `t == 'N'`, and the two dimensions swapped for
# any other character. For a vector `M` the second dimension has size 1.
function lapack_size(t::AbstractChar, M::AbstractVecOrMat)
    nrows, ncols = size(M, 1), size(M, 2)
    return t == 'N' ? (nrows, ncols) : (ncols, nrows)
end
# Axes of `M` as seen through the operation character `t`:
# `(axes(M, 1), axes(M, 2))` when `t == 'N'`, and the two axes swapped for
# any other character.
function lapack_axes(t::AbstractChar, M::AbstractVecOrMat)
    first_ax, second_ax = axes(M, 1), axes(M, 2)
    return t == 'N' ? (first_ax, second_ax) : (second_ax, first_ax)
end
# Translate the function that marks how an operand is wrapped into the
# single-character operation code used by the character-based methods:
# `identity` -> 'N', `transpose` -> 'T', `adjoint` -> 'C'.
# Any other function has no method here and raises a `MethodError`.
function lapack_char(::typeof(identity))
    return 'N'
end
function lapack_char(::typeof(transpose))
    return 'T'
end
function lapack_char(::typeof(adjoint))
    return 'C'
end

function copyto!(B::AbstractVecOrMat, ir_dest::AbstractUnitRange{Int}, jr_dest::AbstractUnitRange{Int}, tM::AbstractChar, M::AbstractVecOrMat, ir_src::AbstractUnitRange{Int}, jr_src::AbstractUnitRange{Int})
if tM == 'N'
Expand All @@ -646,17 +649,18 @@ function copy_transpose!(B::AbstractMatrix, ir_dest::AbstractUnitRange{Int}, jr_
B
end

# Entry point that receives the operand wrapper as a function rather than a character.
# It converts `fA` to its operation character via `lapack_char`, packs `alpha` and `beta`
# into a `MulAddMul`, and forwards to the character-based method.
# This method helps e.g. OffsetArrays to dispatch on C only, unwrap & call mul! again.
function generic_matvecmul!(C::AbstractVector, fA::Function, A::AbstractVecOrMat,
                            B::AbstractVector, alpha=true, beta=false)
    opchar = lapack_char(fA)
    scaling = MulAddMul(alpha, beta)
    return generic_matvecmul!(C, opchar, A, B, scaling)
end

# TODO: It will be faster for large matrices to convert to float,
# call BLAS, and convert back to required type.

# NOTE: the generic version is also called as fallback for
# strides != 1 cases

generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector,
alpha=true, beta=false) = generic_matvecmul!(C, tA, A, B, MulAddMul(alpha, beta))

function generic_matvecmul!(C::AbstractVector{R}, tA, A::AbstractVecOrMat, B::AbstractVector,
_add::MulAddMul) where R
_add::MulAddMul = MulAddMul()) where R
if has_offset_axes(C, A, B)
return generic_offset_matvecmul!(C, tA, A, B, _add) # avoids linear indexing
end
Expand Down Expand Up @@ -795,10 +799,10 @@ end

const tilebufsize = 10800 # Approximately 32k/3

# This method without MulAddMul helps e.g. OffsetArrays to dispatch on C only, unwrap & call mul! again
generic_matmatmul!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix,
# This method helps e.g. OffsetArrays to dispatch on C only, unwrap & call mul! again
generic_matmatmul!(C::AbstractMatrix, fA::Function, fB::Function, A::AbstractMatrix, B::AbstractMatrix,
alpha=true, beta=false) =
generic_matmatmul!(C, tA, tB, A, B, MulAddMul(alpha, beta))
generic_matmatmul!(C, lapack_char(fA), lapack_char(fB), A, B, MulAddMul(alpha, beta))
function generic_matmatmul!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix,
_add::MulAddMul)
has_offset_axes(A,B,C) && @info "generic_matmatmul! with offsets" axes(C)
Expand All @@ -819,9 +823,9 @@ function generic_matmatmul!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::Abs
end

# Nx1 matrices may be mixed up with vectors, and cannot be 3x3
generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::AbstractVecOrMat, B::AbstractVecOrMat, alpha=true, beta=false) =
_generic_matmatmul!(C, tA, tB, A, B, MulAddMul(alpha, beta))
generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::AbstractVecOrMat, B::AbstractVecOrMat, _add::MulAddMul) =
# Vector-or-matrix entry point that receives the operand wrappers as functions.
# Both wrappers are converted to operation characters with `lapack_char`, and
# `alpha`/`beta` are packed into a `MulAddMul`, before calling `_generic_matmatmul!`.
function generic_matmatmul!(C::AbstractVecOrMat, fA::Function, fB::Function,
                            A::AbstractVecOrMat, B::AbstractVecOrMat,
                            alpha=true, beta=false)
    opA = lapack_char(fA)
    opB = lapack_char(fB)
    return _generic_matmatmul!(C, opA, opB, A, B, MulAddMul(alpha, beta))
end
# Vector-or-matrix method taking ready-made operation markers `tA`, `tB` and an
# optional `MulAddMul` (a default-constructed one when omitted); it hands everything
# straight to `_generic_matmatmul!`.
function generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::AbstractVecOrMat,
                            B::AbstractVecOrMat, _add::MulAddMul = MulAddMul())
    return _generic_matmatmul!(C, tA, tB, A, B, _add)
end

function _generic_matmatmul!(C::AbstractVecOrMat{R}, tA, tB, A::AbstractVecOrMat{T}, B::AbstractVecOrMat{S},
Expand Down

0 comments on commit aad0522

Please sign in to comment.