diff --git a/Project.toml b/Project.toml index 595c7b5d..9886a19e 100644 --- a/Project.toml +++ b/Project.toml @@ -13,7 +13,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Aqua = "0.6" ArrayLayouts = "1" Documenter = "0.27" -FillArrays = "1.3" +FillArrays = "1.0.1" PrecompileTools = "1" julia = "1.6" diff --git a/src/interfaceimpl.jl b/src/interfaceimpl.jl index bd90628d..c93aab6a 100644 --- a/src/interfaceimpl.jl +++ b/src/interfaceimpl.jl @@ -67,6 +67,27 @@ function rot180(A::AbstractBandedMatrix) _BandedMatrix(bandeddata(A)[end:-1:1,end:-1:1], m, u+sh,l-sh) end -for MT in (:Diagonal, :SymTridiagonal, :Tridiagonal, :Bidiagonal) - @eval getindex(D::$MT, b::Band) = diag(D, b.i) +function getindex(D::Diagonal{T,V}, b::Band) where {T,V} + iszero(b.i) && return copy(D.diag) + convert(V, Zeros{T}(size(D,1)-abs(b.i))) +end + +function getindex(D::Tridiagonal{T,V}, b::Band) where {T,V} + b.i == -1 && return copy(D.dl) + iszero(b.i) && return copy(D.d) + b.i == 1 && return copy(D.du) + convert(V, Zeros{T}(size(D,1)-abs(b.i))) +end + +function getindex(D::SymTridiagonal{T,V}, b::Band) where {T,V} + iszero(b.i) && return copy(D.dv) + abs(b.i) == 1 && return copy(D.ev) + convert(V, Zeros{T}(size(D,1)-abs(b.i))) +end + +function getindex(D::Bidiagonal{T,V}, b::Band) where {T,V} + iszero(b.i) && return copy(D.dv) + D.uplo == 'L' && b.i == -1 && return copy(D.ev) + D.uplo == 'U' && b.i == 1 && return copy(D.ev) + convert(V, Zeros{T}(size(D,1)-abs(b.i))) end diff --git a/test/test_interface.jl b/test/test_interface.jl index 5b747626..47c9211e 100644 --- a/test/test_interface.jl +++ b/test/test_interface.jl @@ -58,11 +58,6 @@ LinearAlgebra.fill!(A::PseudoBandedMatrix, v) = fill!(A.data,v) @test B * Eye(5) == B @test muladd!(2.0, Eye(5), B, 0.0, zeros(5,5)) == 2B @test muladd!(2.0, B, Eye(5), 0.0, zeros(5,5)) == 2B - - E = Eye(4) - @test (@inferred E[band(0)]) == Ones(4) - @test (@inferred E[band(1)]) == Zeros(3) - @test (@inferred E[band(-1)]) == Zeros(3) end @testset "Diagonal" begin @@ -84,9 +79,9 @@ LinearAlgebra.fill!(A::PseudoBandedMatrix, v) = fill!(A.data,v) @test A[band(0)] == [2; ones(4)] B = Diagonal(Fill(1,5)) - @test (@inferred B[band(0)]) == Fill(1,5) - @test (@inferred B[band(1)]) == B[band(-1)] == Fill(0,4) - @test (@inferred B[band(2)]) == B[band(-2)] == Fill(0,3) + @test B[band(0)] ≡ Fill(1,5) + @test B[band(1)] ≡ B[band(-1)] ≡ Fill(0,4) + @test B[band(2)] ≡ B[band(-2)] ≡ Fill(0,3) end @testset "SymTridiagonal" begin @@ -98,32 +93,32 @@ LinearAlgebra.fill!(A::PseudoBandedMatrix, v) = fill!(A.data,v) @test A[1,1] == 2 B = SymTridiagonal(Fill(1,5), Fill(2,4)) - @test (@inferred B[band(0)]) == Fill(1,5) - @test (@inferred B[band(1)]) == B[band(-1)] == Fill(2,4) - @test (@inferred B[band(2)]) == B[band(-2)] == Fill(0,3) + @test B[band(0)] ≡ Fill(1,5) + @test B[band(1)] ≡ B[band(-1)] ≡ Fill(2,4) + @test B[band(2)] ≡ B[band(-2)] ≡ Fill(0,3) end @testset "Tridiagonal" begin B = Tridiagonal(Fill(1,4), Fill(2,5), Fill(3,4)) - @test (@inferred B[band(0)]) == Fill(2,5) - @test (@inferred B[band(1)]) == Fill(3,4) - @test (@inferred B[band(-1)]) == Fill(1,4) - @test B[band(2)] == B[band(-2)] == Fill(0,3) + @test B[band(0)] ≡ Fill(2,5) + @test B[band(1)] ≡ Fill(3,4) + @test B[band(-1)] ≡ Fill(1,4) + @test B[band(2)] ≡ B[band(-2)] ≡ Fill(0,3) end @testset "Bidiagonal" begin L = Bidiagonal(Fill(2,5), Fill(1,4), :L) - @test (@inferred L[band(0)]) == Fill(2,5) - @test (@inferred L[band(1)]) == Fill(0,4) - @test (@inferred L[band(-1)]) == Fill(1,4) - @test (@inferred L[band(2)]) == L[band(-2)] == Fill(0,3) + @test L[band(0)] ≡ Fill(2,5) + @test L[band(1)] ≡ Fill(0,4) + @test L[band(-1)] ≡ Fill(1,4) + @test L[band(2)] ≡ L[band(-2)] ≡ Fill(0,3) @test BandedMatrix(L) == L U = Bidiagonal(Fill(2,5), Fill(1,4), :U) - @test (@inferred U[band(0)]) == Fill(2,5) - @test (@inferred U[band(1)]) == Fill(1,4) - @test (@inferred U[band(-1)]) == Fill(0,4) - @test (@inferred U[band(2)]) == U[band(-2)] == Fill(0,3) + @test U[band(0)] ≡ Fill(2,5) + @test U[band(1)] ≡ Fill(1,4) + @test U[band(-1)] ≡ Fill(0,4) + @test U[band(2)] ≡ U[band(-2)] ≡ Fill(0,3) @test BandedMatrix(U) == U end