Skip to content

Commit

Permalink
Allow reinterpreting singleton types (#43500)
Browse files Browse the repository at this point in the history
  • Loading branch information
Liozou committed Jan 14, 2022
1 parent 14154fc commit 7b1cc4b
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 9 deletions.
42 changes: 37 additions & 5 deletions base/reinterpretarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ struct ReinterpretArray{T,N,S,A<:AbstractArray{S},IsReshaped} <: AbstractArray{T
@noinline
throw(ArgumentError("cannot reinterpret a zero-dimensional `$(S)` array to `$(T)` which is of a $msg size"))
end
function throwsingleton(S::Type, T::Type, kind)
@noinline
throw(ArgumentError("cannot reinterpret $kind `$(S)` array to `$(T)` which is a singleton type"))
end

global reinterpret
function reinterpret(::Type{T}, a::A) where {T,N,S,A<:AbstractArray{S, N}}
Expand All @@ -39,7 +43,11 @@ struct ReinterpretArray{T,N,S,A<:AbstractArray{S},IsReshaped} <: AbstractArray{T
if N != 0 && sizeof(S) != sizeof(T)
ax1 = axes(a)[1]
dim = length(ax1)
rem(dim*sizeof(S),sizeof(T)) == 0 || thrownonint(S, T, dim)
if Base.issingletontype(T)
dim == 0 || throwsingleton(S, T, "a non-empty")
else
rem(dim*sizeof(S),sizeof(T)) == 0 || thrownonint(S, T, dim)
end
first(ax1) == 1 || throwaxes1(S, T, ax1)
end
readable = array_subpadding(T, S)
Expand All @@ -58,14 +66,20 @@ struct ReinterpretArray{T,N,S,A<:AbstractArray{S},IsReshaped} <: AbstractArray{T
@noinline
throw(ArgumentError("`reinterpret(reshape, $T, a)` where `eltype(a)` is $(eltype(a)) requires that `axes(a, 1)` (got $(axes(a, 1))) be equal to 1:$(sizeof(T) ÷ sizeof(eltype(a))) (from the ratio of element sizes)"))
end
function throwfromsingleton(S, T)
@noinline
throw(ArgumentError("`reinterpret(reshape, $T, a)` where `eltype(a)` is $S requires that $T be a singleton type, since $S is one"))
end
isbitstype(T) || throwbits(S, T, T)
isbitstype(S) || throwbits(S, T, S)
if sizeof(S) == sizeof(T)
N = ndims(a)
elseif sizeof(S) > sizeof(T)
Base.issingletontype(T) && throwsingleton(S, T, "with reshape a")
rem(sizeof(S), sizeof(T)) == 0 || throwintmult(S, T)
N = ndims(a) + 1
else
Base.issingletontype(S) && throwfromsingleton(S, T)
rem(sizeof(T), sizeof(S)) == 0 || throwintmult(S, T)
N = ndims(a) - 1
N > -1 || throwsize0(S, T, "larger")
Expand Down Expand Up @@ -286,7 +300,7 @@ unaliascopy(a::ReshapedReinterpretArray{T}) where {T} = reinterpret(reshape, T,

function size(a::NonReshapedReinterpretArray{T,N,S} where {N}) where {T,S}
psize = size(a.parent)
size1 = div(psize[1]*sizeof(S), sizeof(T))
size1 = Base.issingletontype(T) ? psize[1] : div(psize[1]*sizeof(S), sizeof(T))
tuple(size1, tail(psize)...)
end
function size(a::ReshapedReinterpretArray{T,N,S} where {N}) where {T,S}
Expand All @@ -300,7 +314,7 @@ size(a::NonReshapedReinterpretArray{T,0}) where {T} = ()
function axes(a::NonReshapedReinterpretArray{T,N,S} where {N}) where {T,S}
paxs = axes(a.parent)
f, l = first(paxs[1]), length(paxs[1])
size1 = div(l*sizeof(S), sizeof(T))
size1 = Base.issingletontype(T) ? l : div(l*sizeof(S), sizeof(T))
tuple(oftype(paxs[1], f:f+size1-1), tail(paxs)...)
end
function axes(a::ReshapedReinterpretArray{T,N,S} where {N}) where {T,S}
Expand Down Expand Up @@ -351,6 +365,10 @@ end
@inline @propagate_inbounds function _getindex_ra(a::NonReshapedReinterpretArray{T,N,S}, i1::Int, tailinds::TT) where {T,N,S,TT}
# Make sure to match the scalar reinterpret if that is applicable
if sizeof(T) == sizeof(S) && (fieldcount(T) + fieldcount(S)) == 0
if Base.issingletontype(T) # singleton types
@boundscheck checkbounds(a, i1, tailinds...)
return T.instance
end
return reinterpret(T, a.parent[i1, tailinds...])
else
@boundscheck checkbounds(a, i1, tailinds...)
Expand Down Expand Up @@ -395,6 +413,10 @@ end
@inline @propagate_inbounds function _getindex_ra(a::ReshapedReinterpretArray{T,N,S}, i1::Int, tailinds::TT) where {T,N,S,TT}
# Make sure to match the scalar reinterpret if that is applicable
if sizeof(T) == sizeof(S) && (fieldcount(T) + fieldcount(S)) == 0
if Base.issingletontype(T) # singleton types
@boundscheck checkbounds(a, i1, tailinds...)
return T.instance
end
return reinterpret(T, a.parent[i1, tailinds...])
end
@boundscheck checkbounds(a, i1, tailinds...)
Expand Down Expand Up @@ -475,7 +497,12 @@ end
v = convert(T, v)::T
# Make sure to match the scalar reinterpret if that is applicable
if sizeof(T) == sizeof(S) && (fieldcount(T) + fieldcount(S)) == 0
return setindex!(a.parent, reinterpret(S, v), i1, tailinds...)
if Base.issingletontype(T) # singleton types
@boundscheck checkbounds(a, i1, tailinds...)
# setindex! is a noop except for the index check
else
setindex!(a.parent, reinterpret(S, v), i1, tailinds...)
end
else
@boundscheck checkbounds(a, i1, tailinds...)
ind_start, sidx = divrem((i1-1)*sizeof(T), sizeof(S))
Expand Down Expand Up @@ -536,7 +563,12 @@ end
v = convert(T, v)::T
# Make sure to match the scalar reinterpret if that is applicable
if sizeof(T) == sizeof(S) && (fieldcount(T) + fieldcount(S)) == 0
return setindex!(a.parent, reinterpret(S, v), i1, tailinds...)
if Base.issingletontype(T) # singleton types
@boundscheck checkbounds(a, i1, tailinds...)
# setindex! is a noop except for the index check
else
setindex!(a.parent, reinterpret(S, v), i1, tailinds...)
end
end
@boundscheck checkbounds(a, i1, tailinds...)
t = Ref{T}(v)
Expand Down
77 changes: 73 additions & 4 deletions test/reinterpretarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,18 +66,19 @@ for (_A, Ar, _B) in ((A, Ars, B), (As, Arss, Bs))
@test Arsc == [1 -1; 2 -2]
reinterpret(NTuple{3, Int64}, Bc)[2] = (4,5,6)
@test Bc == Complex{Int64}[5+6im, 7+4im, 5+6im]
reinterpret(NTuple{3, Int64}, Bc)[1] = (1,2,3)
B2 = reinterpret(NTuple{3, Int64}, Bc)
@test setindex!(B2, (1,2,3), 1) == B2
@test Bc == Complex{Int64}[1+2im, 3+4im, 5+6im]
Bc = copy(_B)
Brrs = reinterpret(reshape, Int64, Bc)
Brrs[2, 3] = -5
@test setindex!(Brrs, -5, 2, 3) == Brrs
@test Bc == Complex{Int64}[5+6im, 7+8im, 9-5im]
Brrs[last(eachindex(Brrs))] = 22
@test Bc == Complex{Int64}[5+6im, 7+8im, 9+22im]

A1 = reinterpret(Float64, _A)
A2 = reinterpret(ComplexF64, _A)
A1[1] = 1.0
@test setindex!(A1, 1.0, 1) == A1
@test real(A2[1]) == 1.0
A1 = reinterpret(reshape, Float64, _A)
A1[1] = 2.5
Expand All @@ -88,7 +89,7 @@ for (_A, Ar, _B) in ((A, Ars, B), (As, Arss, Bs))
@test real(A2rs[1]) == 1.0
A1rs = reinterpret(reshape, Float64, Ar)
A2rs = reinterpret(reshape, ComplexF64, Ar)
A1rs[1, 1] = 2.5
@test setindex!(A1rs, 2.5, 1, 1) == A1rs
@test real(A2rs[1]) == 2.5
end
end
Expand Down Expand Up @@ -376,3 +377,71 @@ end
a = reinterpret(reshape, NTuple{4,Float64}, rand(Float64, 4, 4))
@test typeof(Base.unaliascopy(a)) === typeof(a)
end


@testset "singleton types" begin
mutable struct NotASingleton end # not a singleton because it is mutable
struct SomeSingleton
# A singleton type that does not have the internal constructor SomeSingleton()
SomeSingleton(x) = new()
end

@test_throws ErrorException reinterpret(Int, nothing)
@test_throws ErrorException reinterpret(Missing, 3)
@test_throws ErrorException reinterpret(Missing, NotASingleton())
@test_throws ErrorException reinterpret(NotASingleton, ())

@test_throws ArgumentError reinterpret(NotASingleton, fill(nothing, ()))
@test_throws ArgumentError reinterpret(reshape, NotASingleton, fill(missing, 3))
@test_throws ArgumentError reinterpret(Tuple{}, fill(NotASingleton(), 2))
@test_throws ArgumentError reinterpret(reshape, Nothing, fill(NotASingleton(), ()))

t = fill(nothing, 3, 5)
@test reinterpret(SomeSingleton, t) == reinterpret(reshape, SomeSingleton, t)
@test reinterpret(SomeSingleton, t) == [SomeSingleton(i*j) for i in 1:3, j in 1:5]
@test reinterpret(Int, t) == fill(17, 0, 5)
@test_throws ArgumentError reinterpret(reshape, Float64, t)
@test_throws ArgumentError reinterpret(Nothing, 1:6)
@test_throws ArgumentError reinterpret(reshape, Missing, [0.0])

# reintepret of empty array with reshape
@test reinterpret(reshape, Nothing, fill(missing, (0,0,0))) == fill(nothing, (0,0,0))
@test_throws ArgumentError reinterpret(reshape, Nothing, fill(3.2, (0,0)))
@test_throws ArgumentError reinterpret(reshape, Float64, fill(nothing, 0))

# reinterpret of 0-dimensional array
z = reinterpret(Tuple{}, fill(missing, ()))
@test z == fill((), ())
@test z == reinterpret(reshape, Tuple{}, fill(nothing, ()))
@test_throws BoundsError z[2]
@test_throws BoundsError z[3] = ()
@test_throws ArgumentError reinterpret(UInt8, fill(nothing, ()))
@test_throws ArgumentError reinterpret(Missing, fill(1f0, ()))
@test_throws ArgumentError reinterpret(reshape, Float64, fill(nothing, ()))
@test_throws ArgumentError reinterpret(reshape, Nothing, fill(17, ()))


@test @inferred(ndims(reinterpret(reshape, SomeSingleton, t))) == 2
@test @inferred(axes(reinterpret(reshape, Tuple{}, t))) == (Base.OneTo(3),Base.OneTo(5))
@test @inferred(size(reinterpret(reshape, Missing, t))) == (3,5)

x = reinterpret(Tuple{}, t)
@test x == reinterpret(reshape, Tuple{}, t)
@test x[3,5] === ()
x1 = fill((), 3, 5)
@test setindex!(x, (), 1, 1) == x1
@test_throws BoundsError x[17]
@test_throws BoundsError x[4,2]
@test_throws BoundsError x[1,2,3]
@test_throws BoundsError x[18] = ()
@test_throws MethodError x[1,3] = missing
@test x == fill((), (3, 5))
x = reinterpret(reshape, SomeSingleton, t)
@test_throws BoundsError x[19]
@test_throws BoundsError x[2,6] = SomeSingleton(0xa)
@test x[2,3] === SomeSingleton(:x)
x2 = fill(SomeSingleton(0.7), 3, 5)
@test x == x2
@test setindex!(x, SomeSingleton(:), 3, 5) == x2
@test_throws MethodError x[2,4] = nothing
end

8 comments on commit 7b1cc4b

@vtjnash
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nanosoldier runbenchmarks(ALL, vs="#v1.7.1" )

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something went wrong when running your job:

NanosoldierError: failed to run benchmarks against primary commit: IOError: chown("/nanosoldier/workdir/jl_dBfwUr/environment", -1, 26084): no such file or directory (ENOENT)

Logs and partial data can be found here

@vtjnash
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nanosoldier runbenchmarks("scalar", vs="#v1.7.1" )

@vtjnash
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nanosoldier runbenchmarks("scalar", vs="#v1.7.1" )

@vtjnash
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nanosoldier runbenchmarks("scalar", vs="#v1.7.1" )

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something went wrong when running your job:

CompositeException(Any[TaskFailedException(Task (failed) @0x00007fff94d9b340)])

@vtjnash
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nanosoldier runbenchmarks("scalar", vs="#v1.7.1" )

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something went wrong when running your job:

CompositeException(Any[TaskFailedException(Task (failed) @0x00007fff96ed60e0)])

Please sign in to comment.