Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make searchsorted*/findnext/findprev return values of keytype #32978

Merged
merged 13 commits into from
Apr 28, 2020
8 changes: 4 additions & 4 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1653,7 +1653,7 @@ CartesianIndex(2, 1)
"""
function findnext(A, start)
l = last(keys(A))
i = start
i = oftype(l, start)
i > l && return nothing
while true
A[i] && return i
Expand Down Expand Up @@ -1735,7 +1735,7 @@ CartesianIndex(1, 1)
"""
function findnext(testf::Function, A, start)
l = last(keys(A))
i = start
i = oftype(l, start)
i > l && return nothing
while true
testf(A[i]) && return i
Expand Down Expand Up @@ -1839,8 +1839,8 @@ CartesianIndex(2, 1)
```
"""
function findprev(A, start)
i = start
f = first(keys(A))
i = oftype(f, start)
i < f && return nothing
while true
A[i] && return i
Expand Down Expand Up @@ -1930,8 +1930,8 @@ CartesianIndex(2, 1)
```
"""
function findprev(testf::Function, A, start)
i = start
f = first(keys(A))
i = oftype(f, start)
i < f && return nothing
while true
testf(A[i]) && return i
Expand Down
14 changes: 8 additions & 6 deletions base/bitarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1397,7 +1397,7 @@ end
function findnext(B::BitArray, start::Integer)
start > 0 || throw(BoundsError(B, start))
start > length(B) && return nothing
unsafe_bitfindnext(B.chunks, start)
unsafe_bitfindnext(B.chunks, Int(start))
Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unsafe_bitfindnext (and prev, too) accepts a start::Integer. I was about to make a comment that we should just move this there and/or restrict its signature, but then I realized it's also used by BitSet. BitSet demands an Int64 result, whereas its use for BitArray demands an Int result. So this seems like the right way to go about it. 👍

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe it changed, but my impression is that BitSet calls unsafe_bitfindnext by feeding it an Int, not an Int64, so I would also favor restricting unsafe_bitfindnext to Int input, as this is an internal function. But the PR is also fine as is!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have now restricted unsafe_bitfindnext to Int.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, sorry for the nitpick, but please change also unsafe_bitfindprev accordingly too :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done :)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That was fast!

end

#findfirst(B::BitArray) = findnext(B, 1) ## defined in array.jl
Expand All @@ -1411,8 +1411,9 @@ function findnextnot(B::BitArray, start::Integer)
l = length(Bc)
l == 0 && return nothing

chunk_start = _div64(start-1)+1
within_chunk_start = _mod64(start-1)
st = Int(start)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As an aside, note that for this kind of thing, there is no problem writing start = Int(start), which saves you from having to find another name :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, then I will change it.

chunk_start = _div64(st-1)+1
within_chunk_start = _mod64(st-1)
mask = ~(_msk64 << within_chunk_start)

@inbounds if chunk_start < l
Expand Down Expand Up @@ -1480,7 +1481,7 @@ end
function findprev(B::BitArray, start::Integer)
start > 0 || return nothing
start > length(B) && throw(BoundsError(B, start))
unsafe_bitfindprev(B.chunks, start)
unsafe_bitfindprev(B.chunks, Int(start))
end

function findprevnot(B::BitArray, start::Integer)
Expand All @@ -1489,8 +1490,9 @@ function findprevnot(B::BitArray, start::Integer)

Bc = B.chunks

chunk_start = _div64(start-1)+1
mask = ~_msk_end(start)
st = Int(start)
chunk_start = _div64(st-1)+1
mask = ~_msk_end(st)

@inbounds begin
if Bc[chunk_start] | mask != _msk64
Expand Down
25 changes: 13 additions & 12 deletions base/sort.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ using .Base: copymutable, LinearIndices, length, (:),
AbstractVector, @inbounds, AbstractRange, @eval, @inline, Vector, @noinline,
AbstractMatrix, AbstractUnitRange, isless, identity, eltype, >, <, <=, >=, |, +, -, *, !,
extrema, sub_with_overflow, add_with_overflow, oneunit, div, getindex, setindex!,
length, resize!, fill, Missing, require_one_based_indexing
length, resize!, fill, Missing, require_one_based_indexing, keytype

using .Base: >>>, !==

Expand Down Expand Up @@ -174,7 +174,7 @@ midpoint(lo::Integer, hi::Integer) = midpoint(promote(lo, hi)...)

# index of the first value of vector a that is greater than or equal to x;
# returns length(v)+1 if x is greater than all values in v.
function searchsortedfirst(v::AbstractVector, x, lo::T, hi::T, o::Ordering) where T<:Integer
function searchsortedfirst(v::AbstractVector, x, lo::T, hi::T, o::Ordering)::keytype(v) where T<:Integer
u = T(1)
lo = lo - u
hi = hi + u
Expand All @@ -191,7 +191,7 @@ end

# index of the last value of vector a that is less than or equal to x;
# returns 0 if x is less than all values of v.
function searchsortedlast(v::AbstractVector, x, lo::T, hi::T, o::Ordering) where T<:Integer
function searchsortedlast(v::AbstractVector, x, lo::T, hi::T, o::Ordering)::keytype(v) where T<:Integer
u = T(1)
lo = lo - u
hi = hi + u
Expand All @@ -209,7 +209,7 @@ end
# returns the range of indices of v equal to x
# if v does not contain x, returns a 0-length range
# indicating the insertion point of x
function searchsorted(v::AbstractVector, x, ilo::T, ihi::T, o::Ordering) where T<:Integer
function searchsorted(v::AbstractVector, x, ilo::T, ihi::T, o::Ordering)::UnitRange{keytype(v)} where T<:Integer
u = T(1)
lo = ilo - u
hi = ihi + u
Expand All @@ -228,7 +228,7 @@ function searchsorted(v::AbstractVector, x, ilo::T, ihi::T, o::Ordering) where T
return (lo + 1) : (hi - 1)
end

function searchsortedlast(a::AbstractRange{<:Real}, x::Real, o::DirectOrdering)
function searchsortedlast(a::AbstractRange{<:Real}, x::Real, o::DirectOrdering)::keytype(a)
require_one_based_indexing(a)
if step(a) == 0
lt(o, x, first(a)) ? 0 : length(a)
Expand All @@ -238,7 +238,7 @@ function searchsortedlast(a::AbstractRange{<:Real}, x::Real, o::DirectOrdering)
end
end

function searchsortedfirst(a::AbstractRange{<:Real}, x::Real, o::DirectOrdering)
function searchsortedfirst(a::AbstractRange{<:Real}, x::Real, o::DirectOrdering)::keytype(a)
require_one_based_indexing(a)
if step(a) == 0
lt(o, first(a), x) ? length(a) + 1 : 1
Expand All @@ -248,7 +248,7 @@ function searchsortedfirst(a::AbstractRange{<:Real}, x::Real, o::DirectOrdering)
end
end

function searchsortedlast(a::AbstractRange{<:Integer}, x::Real, o::DirectOrdering)
function searchsortedlast(a::AbstractRange{<:Integer}, x::Real, o::DirectOrdering)::keytype(a)
require_one_based_indexing(a)
h = step(a)
if h == 0
Expand All @@ -270,7 +270,7 @@ function searchsortedlast(a::AbstractRange{<:Integer}, x::Real, o::DirectOrderin
end
end

function searchsortedfirst(a::AbstractRange{<:Integer}, x::Real, o::DirectOrdering)
function searchsortedfirst(a::AbstractRange{<:Integer}, x::Real, o::DirectOrdering)::keytype(a)
require_one_based_indexing(a)
h = step(a)
if h == 0
Expand All @@ -285,14 +285,15 @@ function searchsortedfirst(a::AbstractRange{<:Integer}, x::Real, o::DirectOrderi
lastindex(a) + 1
else
if o isa ForwardOrdering
-fld(floor(Integer, -x) + first(a), h) + 1
y = isa(x, Unsigned) ? floor(-Signed(x)) : floor(Integer, -x)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When x isa Unsigned, this method looks like it won't be called, but rather the next method below (I guess this split didn't exist when you put up this PR).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the method below is used for x::Unsigned. I have restored the original, except for using Signed(first(a)) instead of first(a) (which is needed for passing the tests added in this PR).

else
-fld(ceil(Integer, -x) + first(a), h) + 1
y = isa(x, Unsigned) ? ceil(-Signed(x)) : ceil(Integer, -x)
end
-fld(y + Signed(first(a)), h) + 1
end
end

function searchsortedfirst(a::AbstractRange{<:Integer}, x::Unsigned, o::DirectOrdering)
function searchsortedfirst(a::AbstractRange{<:Integer}, x::Unsigned, o::DirectOrdering)::keytype(a)
require_one_based_indexing(a)
if lt(o, first(a), x)
if step(a) == 0
Expand All @@ -305,7 +306,7 @@ function searchsortedfirst(a::AbstractRange{<:Integer}, x::Unsigned, o::DirectOr
end
end

function searchsortedlast(a::AbstractRange{<:Integer}, x::Unsigned, o::DirectOrdering)
function searchsortedlast(a::AbstractRange{<:Integer}, x::Unsigned, o::DirectOrdering)::keytype(a)
require_one_based_indexing(a)
if lt(o, x, first(a))
0
Expand Down
6 changes: 3 additions & 3 deletions base/strings/search.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ function findnext(testf::Function, s::AbstractString, i::Integer)
@inbounds i == z || isvalid(s, i) || string_index_err(s, i)
for (j, d) in pairs(SubString(s, i))
if testf(d)
return i + j - 1
return Int(i + j - 1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As a conversion is needed somewhere in this function, my personal preference would favor converting at the begining (i = Int(i), which might help having to compile less function specializations, like isvalid or SubString, but this also might not matter). But keep it as you prefer.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I have changed it.

end
end
return nothing
Expand Down Expand Up @@ -272,7 +272,7 @@ julia> findnext("Lang", "JuliaLang", 2)
6:9
```
"""
findnext(t::AbstractString, s::AbstractString, i::Integer) = _search(s, t, i)
findnext(t::AbstractString, s::AbstractString, i::Integer) = _search(s, t, Int(i))

"""
findnext(ch::AbstractChar, string::AbstractString, start::Integer)
Expand Down Expand Up @@ -484,7 +484,7 @@ julia> findprev("Julia", "JuliaLang", 6)
1:5
```
"""
findprev(t::AbstractString, s::AbstractString, i::Integer) = _rsearch(s, t, i)
findprev(t::AbstractString, s::AbstractString, i::Integer) = _rsearch(s, t, Int(i))

"""
findprev(ch::AbstractChar, string::AbstractString, start::Integer)
Expand Down
16 changes: 16 additions & 0 deletions stdlib/SparseArrays/test/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2462,6 +2462,22 @@ end
@test findnext(!iszero, z,i) == findnext(!iszero, z_sp,i)
@test findprev(!iszero, z,i) == findprev(!iszero, z_sp,i)
end

# issue 32568
for T = (UInt, BigInt)
@test findnext(!iszero, x_sp, T(4)) isa keytype(x_sp)
@test findnext(!iszero, x_sp, T(5)) isa keytype(x_sp)
@test findprev(!iszero, x_sp, T(5)) isa keytype(x_sp)
@test findprev(!iszero, x_sp, T(6)) isa keytype(x_sp)
@test findnext(iseven, x_sp, T(4)) isa keytype(x_sp)
@test findnext(iseven, x_sp, T(5)) isa keytype(x_sp)
@test findprev(iseven, x_sp, T(4)) isa keytype(x_sp)
@test findprev(iseven, x_sp, T(5)) isa keytype(x_sp)
@test findnext(!iszero, z_sp, T(4)) isa keytype(z_sp)
@test findnext(!iszero, z_sp, T(5)) isa keytype(z_sp)
@test findprev(!iszero, z_sp, T(4)) isa keytype(z_sp)
@test findprev(!iszero, z_sp, T(5)) isa keytype(z_sp)
end
end

# #20711
Expand Down
12 changes: 12 additions & 0 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,18 @@ end
@test findlast(isequal(0x00), [0x01, 0x00]) == 2
@test findnext(isequal(0x00), [0x00, 0x01, 0x00], 2) == 3
@test findprev(isequal(0x00), [0x00, 0x01, 0x00], 2) == 1

@testset "issue 32568" for T = (UInt, BigInt)
@test findnext(!iszero, a, T(1)) isa keytype(a)
@test findnext(!iszero, a, T(2)) isa keytype(a)
@test findprev(!iszero, a, T(4)) isa keytype(a)
@test findprev(!iszero, a, T(5)) isa keytype(a)
b = [true, false, true]
@test findnext(b, T(2)) isa keytype(b)
@test findnext(b, T(3)) isa keytype(b)
@test findprev(b, T(1)) isa keytype(b)
@test findprev(b, T(2)) isa keytype(b)
end
end
@testset "find with Matrix" begin
A = [1 2 0; 3 4 0]
Expand Down
15 changes: 15 additions & 0 deletions test/bitarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1307,6 +1307,21 @@ timesofar("find")
@test_throws BoundsError findprev(x->true, b1, 11)
@test_throws BoundsError findnext(x->true, b1, -1)

@testset "issue 32568" for T = (UInt, BigInt)
for x = (1, 2)
@test findnext(evens, T(x)) isa keytype(evens)
@test findnext(iseven, evens, T(x)) isa keytype(evens)
@test findnext(isequal(true), evens, T(x)) isa keytype(evens)
@test findnext(isequal(false), evens, T(x)) isa keytype(evens)
end
for x = (3, 4)
@test findprev(evens, T(x)) isa keytype(evens)
@test findprev(iseven, evens, T(x)) isa keytype(evens)
@test findprev(isequal(true), evens, T(x)) isa keytype(evens)
@test findprev(isequal(false), evens, T(x)) isa keytype(evens)
end
end

for l = [1, 63, 64, 65, 127, 128, 129]
f = falses(l)
t = trues(l)
Expand Down
15 changes: 14 additions & 1 deletion test/sorting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,20 @@ end
@test searchsortedlast(500:1.0:600, -1.0e20) == 0
@test searchsortedlast(500:1.0:600, 1.0e20) == 101
end

@testset "issue 32568" begin
for R in numTypes, T in numTypes
for arr in [R[1:5;], R(1):R(5), R(1):2:R(5)]
@test eltype(searchsorted(arr, T(2))) == keytype(arr)
@test eltype(searchsorted(arr, T(2), big(1), big(4), Forward)) == keytype(arr)
@test searchsortedfirst(arr, T(2)) isa keytype(arr)
@test searchsortedfirst(arr, T(2), big(1), big(4), Forward) isa keytype(arr)
@test searchsortedlast(arr, T(2)) isa keytype(arr)
@test searchsortedlast(arr, T(2), big(1), big(4), Forward) isa keytype(arr)
end
end
end

@testset "issue #34157" begin
@test searchsorted(1:2.0, -Inf) === 1:0
@test searchsorted([1,2], -Inf) === 1:0
Expand Down Expand Up @@ -173,7 +187,6 @@ end
@test searchsortedlast(reverse(coll), -huge, rev=true) === lastindex(coll)
@test searchsorted(reverse(coll), huge, rev=true) === firstindex(coll):firstindex(coll) - 1
@test searchsorted(reverse(coll), -huge, rev=true) === lastindex(coll)+1:lastindex(coll)

end
end
end
Expand Down
19 changes: 19 additions & 0 deletions test/strings/search.jl
Original file line number Diff line number Diff line change
Expand Up @@ -389,3 +389,22 @@ s_18109 = "fooα🐨βcd3"
@test findall("aa", "aaaaaa") == [1:2, 3:4, 5:6]
@test findall("aa", "aaaaaa", overlap=true) == [1:2, 2:3, 3:4, 4:5, 5:6]
end

# issue 32568
for T = (UInt, BigInt)
for x = (4, 5)
@test eltype(findnext(r"l", astr, T(x))) == Int
@test findnext(isequal('l'), astr, T(x)) isa Int
@test findprev(isequal('l'), astr, T(x)) isa Int
@test findnext('l', astr, T(x)) isa Int
@test findprev('l', astr, T(x)) isa Int
end
for x = (5, 6)
@test eltype(findprev(",b", "foo,bar,baz", T(x))) == Int
end
for x = (7, 8)
@test eltype(findnext(",b", "foo,bar,baz", T(x))) == Int
@test findnext(isletter, astr, T(x)) isa Int
@test findprev(isletter, astr, T(x)) isa Int
end
end
7 changes: 7 additions & 0 deletions test/tuple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,13 @@ end
@test findprev(isequal(1), (1, 1), 1) == 1
@test findnext(isequal(1), (2, 3), 1) === nothing
@test findprev(isequal(1), (2, 3), 2) === nothing

@testset "issue 32568" begin
@test findnext(isequal(1), (1, 2), big(1)) isa Int
@test findprev(isequal(1), (1, 2), big(2)) isa Int
@test findnext(isequal(1), (1, 1), UInt(2)) isa Int
@test findprev(isequal(1), (1, 1), UInt(1)) isa Int
end
end

@testset "properties" begin
Expand Down