From 00ad0aa290c716cfb6b312180becae93641dcb5d Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Sun, 3 Dec 2023 14:11:58 +0000 Subject: [PATCH] Support RectDiagonal (#411) --- Project.toml | 2 +- src/BandedMatrices.jl | 2 +- src/banded/BandedMatrix.jl | 8 ++++---- src/interfaceimpl.jl | 6 +++--- test/test_interface.jl | 4 ++++ 5 files changed, 13 insertions(+), 9 deletions(-) diff --git a/Project.toml b/Project.toml index 1a28e519..a29a3f75 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "BandedMatrices" uuid = "aae01518-5342-5314-be14-df237901396f" -version = "1.2.1" +version = "1.3" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" diff --git a/src/BandedMatrices.jl b/src/BandedMatrices.jl index 914b97de..5ce905d2 100644 --- a/src/BandedMatrices.jl +++ b/src/BandedMatrices.jl @@ -34,7 +34,7 @@ import ArrayLayouts: MemoryLayout, transposelayout, triangulardata, _qr!, _qr, _lu!, _lu, _factorize, AbstractTridiagonalLayout, TridiagonalLayout, BidiagonalLayout, bidiagonaluplo, diagonaldata, supdiagonaldata, subdiagonaldata, copymutable_oftype_layout, dualadjoint -import FillArrays: AbstractFill, getindex_value, _broadcasted_zeros, unique_value, OneElement +import FillArrays: AbstractFill, getindex_value, _broadcasted_zeros, unique_value, OneElement, RectDiagonal const libblas = LinearAlgebra.BLAS.libblas const liblapack = LinearAlgebra.BLAS.liblapack diff --git a/src/banded/BandedMatrix.jl b/src/banded/BandedMatrix.jl index 98aa5b78..ca247744 100644 --- a/src/banded/BandedMatrix.jl +++ b/src/banded/BandedMatrix.jl @@ -226,10 +226,10 @@ BandedMatrix{V,C,Base.OneTo{Int}}(Z::Zeros{T,2}, bnds::NTuple{2,Integer}) where BandedMatrix{V}(Z::Zeros{T,2}, bnds::NTuple{2,Integer}) where {T,V} = _BandedMatrix(zeros(V,max(0,sum(bnds)+1),size(Z,2)),size(Z,1),bnds...) -BandedMatrix(E::Eye{T}, bnds::NTuple{2,Integer}) where T = BandedMatrix{T}(E, bnds) -function BandedMatrix{T}(E::Eye, bnds::NTuple{2,Integer}) where T - ret=BandedMatrix(Zeros{T}(E), bnds) - ret[band(0)] .= one(T) +BandedMatrix(E::RectDiagonal{T}, bnds::NTuple{2,Integer}) where T = BandedMatrix{T}(E, bnds) +function BandedMatrix{T}(E::RectDiagonal, bnds::NTuple{2,Integer}) where T + ret = BandedMatrix(Zeros{T}(E), bnds) + ret[band(0)] .= E.diag ret end diff --git a/src/interfaceimpl.jl b/src/interfaceimpl.jl index ab791a44..ce22de65 100644 --- a/src/interfaceimpl.jl +++ b/src/interfaceimpl.jl @@ -18,9 +18,9 @@ isbanded(::Zeros) = true bandwidths(::Zeros) = (-40320,-40320) # 40320 == prod(1:8), used for special cases involving gcd inbands_getindex(::Zeros{T}, k::Integer, j::Integer) where T = zero(T) -isbanded(::Eye) = true -bandwidths(::Eye) = (0,0) -inbands_getindex(::Eye{T}, k::Integer, j::Integer) where T = one(T) +isbanded(::RectDiagonal) = true +bandwidths(::RectDiagonal) = (0,0) +inbands_getindex(E::RectDiagonal, k::Integer, j::Integer) = E.diag[k] isbanded(::Diagonal) = true bandwidths(::Diagonal) = (0,0) diff --git a/test/test_interface.jl b/test/test_interface.jl index 244734d5..7173eaca 100644 --- a/test/test_interface.jl +++ b/test/test_interface.jl @@ -59,6 +59,10 @@ LinearAlgebra.fill!(A::PseudoBandedMatrix, v) = fill!(A.data,v) @test B * Eye(5) == B @test muladd!(2.0, Eye(5), B, 0.0, zeros(5,5)) == 2B @test muladd!(2.0, B, Eye(5), 0.0, zeros(5,5)) == 2B + + @test isbanded(2Eye(5,6)) + @test bandwidths(2Eye(5,6)) == (0,0) + @test BandedMatrices.inbands_getindex(2Eye(5,6), 1,1) == 2 end @testset "Diagonal" begin