Skip to content

Commit

Permalink
Accomodate for rectangular matrices in copytrito! (#54587)
Browse files Browse the repository at this point in the history
  • Loading branch information
dkarrasch committed May 29, 2024
1 parent dc63ab2 commit fc54be6
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 19 deletions.
25 changes: 15 additions & 10 deletions stdlib/LinearAlgebra/src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2014,19 +2014,24 @@ function copytrito!(B::AbstractMatrix, A::AbstractMatrix, uplo::AbstractChar)
BLAS.chkuplo(uplo)
m,n = size(A)
m1,n1 = size(B)
(m1 < m || n1 < n) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least the same number of rows and columns than A of size ($m,$n)"))
A = Base.unalias(B, A)
if uplo == 'U'
for j=1:n
for i=1:min(j,m)
@inbounds B[i,j] = A[i,j]
end
if n < m
(m1 < n || n1 < n) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least size ($n,$n)"))
else
(m1 < m || n1 < n) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least size ($m,$n)"))
end
else # uplo == 'L'
for j=1:n
for i=j:m
@inbounds B[i,j] = A[i,j]
end
for j in 1:n, i in 1:min(j,m)
@inbounds B[i,j] = A[i,j]
end
else # uplo == 'L'
if m < n
(m1 < m || n1 < m) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least size ($m,$m)"))
else
(m1 < m || n1 < n) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least size ($m,$n)"))
end
for j in 1:n, i in j:m
@inbounds B[i,j] = A[i,j]
end
end
return B
Expand Down
20 changes: 17 additions & 3 deletions stdlib/LinearAlgebra/src/lapack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7157,9 +7157,23 @@ for (fn, elty) in ((:dlacpy_, :Float64),
function lacpy!(B::AbstractMatrix{$elty}, A::AbstractMatrix{$elty}, uplo::AbstractChar)
require_one_based_indexing(A, B)
chkstride1(A, B)
m,n = size(A)
m1,n1 = size(B)
(m1 < m || n1 < n) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least the same number of rows and columns than A of size ($m,$n)"))
m, n = size(A)
m1, n1 = size(B)
if uplo == 'U'
if n < m
(m1 < n || n1 < n) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least size ($n,$n)"))
else
(m1 < m || n1 < n) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least size ($m,$n)"))
end
elseif uplo == 'L'
if m < n
(m1 < m || n1 < m) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least size ($m,$m)"))
else
(m1 < m || n1 < n) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least size ($m,$n)"))
end
else
(m1 < m || n1 < n) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least size ($m,$n)"))
end
lda = max(1, stride(A, 2))
ldb = max(1, stride(B, 2))
ccall((@blasfunc($fn), libblastrampoline), Cvoid,
Expand Down
52 changes: 47 additions & 5 deletions stdlib/LinearAlgebra/test/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -654,12 +654,54 @@ end

@testset "copytrito!" begin
n = 10
for A in (rand(n, n), rand(Int8, n, n)), uplo in ('L', 'U')
for AA in (A, view(A, reverse.(axes(A))...))
for B in (zeros(n, n), zeros(n+1, n+2))
copytrito!(B, AA, uplo)
@testset "square" begin
for A in (rand(n, n), rand(Int8, n, n)), uplo in ('L', 'U')
for AA in (A, view(A, reverse.(axes(A))...))
C = uplo == 'L' ? tril(AA) : triu(AA)
@test view(B, 1:n, 1:n) == C
for B in (zeros(n, n), zeros(n+1, n+2))
copytrito!(B, AA, uplo)
@test view(B, 1:n, 1:n) == C
end
end
end
end
@testset "wide" begin
for A in (rand(n, 2n), rand(Int8, n, 2n))
for AA in (A, view(A, reverse.(axes(A))...))
C = tril(AA)
for (M, N) in ((n, n), (n+1, n), (n, n+1), (n+1, n+1))
B = zeros(M, N)
copytrito!(B, AA, 'L')
@test view(B, 1:n, 1:n) == view(C, 1:n, 1:n)
end
@test_throws DimensionMismatch copytrito!(zeros(n-1, 2n), AA, 'L')
C = triu(AA)
for (M, N) in ((n, 2n), (n+1, 2n), (n, 2n+1), (n+1, 2n+1))
B = zeros(M, N)
copytrito!(B, AA, 'U')
@test view(B, 1:n, 1:2n) == view(C, 1:n, 1:2n)
end
@test_throws DimensionMismatch copytrito!(zeros(n+1, 2n-1), AA, 'U')
end
end
end
@testset "tall" begin
for A in (rand(2n, n), rand(Int8, 2n, n))
for AA in (A, view(A, reverse.(axes(A))...))
C = triu(AA)
for (M, N) in ((n, n), (n+1, n), (n, n+1), (n+1, n+1))
B = zeros(M, N)
copytrito!(B, AA, 'U')
@test view(B, 1:n, 1:n) == view(C, 1:n, 1:n)
end
@test_throws DimensionMismatch copytrito!(zeros(n-1, n+1), AA, 'U')
C = tril(AA)
for (M, N) in ((2n, n), (2n, n+1), (2n+1, n), (2n+1, n+1))
B = zeros(M, N)
copytrito!(B, AA, 'L')
@test view(B, 1:2n, 1:n) == view(C, 1:2n, 1:n)
end
@test_throws DimensionMismatch copytrito!(zeros(n-1, n+1), AA, 'L')
end
end
end
Expand Down
20 changes: 19 additions & 1 deletion stdlib/LinearAlgebra/test/lapack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -805,8 +805,26 @@ end
B = zeros(elty, n, n)
LinearAlgebra.LAPACK.lacpy!(B, A, uplo)
C = uplo == 'L' ? tril(A) : (uplo == 'U' ? triu(A) : A)
@test B C
@test B == C
B = zeros(elty, n+1, n+1)
LinearAlgebra.LAPACK.lacpy!(B, A, uplo)
C = uplo == 'L' ? tril(A) : (uplo == 'U' ? triu(A) : A)
@test view(B, 1:n, 1:n) == C
end
A = rand(elty, n, n+1)
B = zeros(elty, n, n)
LinearAlgebra.LAPACK.lacpy!(B, A, 'L')
@test B == view(tril(A), 1:n, 1:n)
B = zeros(elty, n, n+1)
LinearAlgebra.LAPACK.lacpy!(B, A, 'U')
@test B == triu(A)
A = rand(elty, n+1, n)
B = zeros(elty, n, n)
LinearAlgebra.LAPACK.lacpy!(B, A, 'U')
@test B == view(triu(A), 1:n, 1:n)
B = zeros(elty, n+1, n)
LinearAlgebra.LAPACK.lacpy!(B, A, 'L')
@test B == tril(A)
end
end

Expand Down

0 comments on commit fc54be6

Please sign in to comment.