From 1a9925d0c23f137340b553bce6921bc2681c1f42 Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Sat, 25 Apr 2020 03:48:47 -0600 Subject: [PATCH] Special case empty covec-diagonal-vec product (#35557) Co-Authored-By: Takafumi Arakaki --- stdlib/LinearAlgebra/src/diagonal.jl | 17 +++++++++++------ stdlib/LinearAlgebra/test/diagonal.jl | 6 ++++++ 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index d67c0caab0c0f..f3b4ac17eec78 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -659,14 +659,19 @@ end # disambiguation methods: * of Diagonal and Adj/Trans AbsVec *(x::Adjoint{<:Any,<:AbstractVector}, D::Diagonal) = Adjoint(map((t,s) -> t'*s, D.diag, parent(x))) *(x::Transpose{<:Any,<:AbstractVector}, D::Diagonal) = Transpose(map((t,s) -> transpose(t)*s, D.diag, parent(x))) -*(x::Adjoint{<:Any,<:AbstractVector}, D::Diagonal, y::AbstractVector) = - mapreduce(t -> t[1]*t[2]*t[3], +, zip(x, D.diag, y)) -*(x::Transpose{<:Any,<:AbstractVector}, D::Diagonal, y::AbstractVector) = - mapreduce(t -> t[1]*t[2]*t[3], +, zip(x, D.diag, y)) -function dot(x::AbstractVector, D::Diagonal, y::AbstractVector) - mapreduce(t -> dot(t[1], t[2], t[3]), +, zip(x, D.diag, y)) +*(x::Adjoint{<:Any,<:AbstractVector}, D::Diagonal, y::AbstractVector) = _mapreduce_prod(*, x, D, y) +*(x::Transpose{<:Any,<:AbstractVector}, D::Diagonal, y::AbstractVector) = _mapreduce_prod(*, x, D, y) +dot(x::AbstractVector, D::Diagonal, y::AbstractVector) = _mapreduce_prod(dot, x, D, y) + +function _mapreduce_prod(f, x, D::Diagonal, y) + if isempty(x) && isempty(D) && isempty(y) + return zero(Base.promote_op(f, eltype(x), eltype(D), eltype(y))) + else + return mapreduce(t -> f(t[1], t[2], t[3]), +, zip(x, D.diag, y)) + end end + function cholesky!(A::Diagonal, ::Val{false} = Val(false); check::Bool = true) info = 0 for (i, di) in enumerate(A.diag) diff --git a/stdlib/LinearAlgebra/test/diagonal.jl b/stdlib/LinearAlgebra/test/diagonal.jl index fdb070dd70aab..98b2fca354fd6 100644 --- a/stdlib/LinearAlgebra/test/diagonal.jl +++ b/stdlib/LinearAlgebra/test/diagonal.jl @@ -711,4 +711,10 @@ end @test s1 == prod(sign, d) end +@testset "Empty (#35424)" begin + @test zeros(0)'*Diagonal(zeros(0))*zeros(0) === 0.0 + @test transpose(zeros(0))*Diagonal(zeros(Complex{Int}, 0))*zeros(0) === 0.0 + 0.0im + @test dot(zeros(Int32, 0), Diagonal(zeros(Int, 0)), zeros(Int16, 0)) === 0 +end + end # module TestDiagonal