Skip to content

Commit

Permalink
LinAlg: fzeropreserving unit triangular broadcast preserves structure (
Browse files Browse the repository at this point in the history
…#53648)

On master
```julia
julia> UU = UnitUpperTriangular(reshape([1:9;],3,3))
3×3 UnitUpperTriangular{Int64, Matrix{Int64}}:
 1  4  7
 ⋅  1  8
 ⋅  ⋅  1

julia> UU .* 2
3×3 Matrix{Int64}:
 2  8  14
 0  2  16
 0  0   2
```
This PR
```julia
julia> UU .* 2
3×3 UpperTriangular{Int64, Matrix{Int64}}:
 2  8  14
 ⋅  2  16
 ⋅  ⋅   2
```
This also improves performance, as the `materialize` skips the
structured zeros.
```julia
julia> UU = UnitUpperTriangular(rand(100, 100));

julia> @Btime $UU .* 2;
  12.788 μs (3 allocations: 78.20 KiB) # master
  7.821 μs (3 allocations: 78.20 KiB) # PR
```
  • Loading branch information
jishnub committed Mar 11, 2024
1 parent 6e3044d commit 1ba83f0
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 1 deletion.
7 changes: 6 additions & 1 deletion stdlib/LinearAlgebra/src/structuredbroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,13 @@ end

function Base.similar(bc::Broadcasted{StructuredMatrixStyle{T}}, ::Type{ElType}) where {T,ElType}
inds = axes(bc)
if isstructurepreserving(bc) || (fzeropreserving(bc) && !(T <: Union{SymTridiagonal,UnitLowerTriangular,UnitUpperTriangular}))
fzerobc = fzeropreserving(bc)
if isstructurepreserving(bc) || (fzerobc && !(T <: Union{SymTridiagonal,UnitLowerTriangular,UnitUpperTriangular}))
return structured_broadcast_alloc(bc, T, ElType, length(inds[1]))
elseif fzerobc && T <: UnitLowerTriangular
return similar(convert(Broadcasted{StructuredMatrixStyle{LowerTriangular}}, bc), ElType)
elseif fzerobc && T <: UnitUpperTriangular
return similar(convert(Broadcasted{StructuredMatrixStyle{UpperTriangular}}, bc), ElType)
end
return similar(convert(Broadcasted{DefaultArrayStyle{ndims(bc)}}, bc), ElType)
end
Expand Down
37 changes: 37 additions & 0 deletions stdlib/LinearAlgebra/test/structuredbroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,43 @@ using Test, LinearAlgebra
@test broadcast!(*, Z, X, Y) == broadcast(*, fX, fY)
end
end
UU = UnitUpperTriangular(rand(N,N))
UL = UnitLowerTriangular(rand(N,N))
unittriangulars = (UU, UL)
Ttris = typeof.((UpperTriangular(parent(UU)), LowerTriangular(parent(UU))))
funittriangulars = map(Array, unittriangulars)
for (X, fX, Ttri) in zip(unittriangulars, funittriangulars, Ttris)
@test (Q = broadcast(sin, X); typeof(Q) == Ttri && Q == broadcast(sin, fX))
@test broadcast!(sin, Z, X) == broadcast(sin, fX)
@test (Q = broadcast(cos, X); Q isa Matrix && Q == broadcast(cos, fX))
@test broadcast!(cos, Z, X) == broadcast(cos, fX)
@test (Q = broadcast(*, s, X); typeof(Q) == Ttri && Q == broadcast(*, s, fX))
@test broadcast!(*, Z, s, X) == broadcast(*, s, fX)
@test (Q = broadcast(+, fV, fA, X); Q isa Matrix && Q == broadcast(+, fV, fA, fX))
@test broadcast!(+, Z, fV, fA, X) == broadcast(+, fV, fA, fX)
@test (Q = broadcast(*, s, fV, fA, X); Q isa Matrix && Q == broadcast(*, s, fV, fA, fX))
@test broadcast!(*, Z, s, fV, fA, X) == broadcast(*, s, fV, fA, fX)

@test X .* 2.0 == X .* (2.0,) == fX .* 2.0
@test X .* 2.0 isa Ttri
@test X .* (2.0,) isa Ttri
@test isequal(X .* Inf, fX .* Inf)

two = 2
@test X .^ 2 == X .^ (2,) == fX .^ 2 == X .^ two
@test X .^ 2 isa typeof(X) # special cased, as isstructurepreserving
@test X .^ (2,) isa Ttri
@test X .^ two isa Ttri
@test X .^ 0 == fX .^ 0
@test X .^ -1 == fX .^ -1

for (Y, fY) in zip(unittriangulars, funittriangulars)
@test broadcast(+, X, Y) == broadcast(+, fX, fY)
@test broadcast!(+, Z, X, Y) == broadcast(+, fX, fY)
@test broadcast(*, X, Y) == broadcast(*, fX, fY)
@test broadcast!(*, Z, X, Y) == broadcast(*, fX, fY)
end
end
end

@testset "broadcast! where the destination is a structured matrix" begin
Expand Down

0 comments on commit 1ba83f0

Please sign in to comment.