Skip to content

Commit

Permalink
Add handling of an empty iterator for mean and var (#29033)
Browse files Browse the repository at this point in the history
  • Loading branch information
bkamins authored and nalimilan committed Oct 11, 2018
1 parent e13b285 commit b519147
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 9 deletions.
10 changes: 6 additions & 4 deletions stdlib/Statistics/src/Statistics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ julia> mean([√1, √2, √3])
function mean(f::Base.Callable, itr)
y = iterate(itr)
if y === nothing
throw(ArgumentError("mean of empty collection undefined: $(repr(itr))"))
return Base.mapreduce_empty_iter(f, Base.add_sum, itr,
Base.IteratorEltype(itr)) / 0
end
count = 1
value, state = y
Expand Down Expand Up @@ -131,7 +132,7 @@ _mean(A::AbstractArray{T}, region) where {T} = mean!(Base.reducedim_init(t -> t/
"""
    _mean(A::AbstractArray, ::Colon)

Return the mean over every element of `A`: the sum of the elements divided by the
number of elements.
"""
function _mean(A::AbstractArray, ::Colon)
    total = sum(A)
    n = length(A)
    return total / n
end

# Mean of a real-valued range, taken as the midpoint of its endpoints; only
# `first(r)` and `last(r)` are read.
function mean(r::AbstractRange{<:Real})
# NOTE(review): the two `isempty` guards below look like the removed/added pair of a
# diff hunk whose -/+ markers were lost in this listing — confirm against the commit.
# Read literally, the `throw` runs first and the NaN-returning guard after it can
# never execute.
isempty(r) && throw(ArgumentError("mean of an empty range is undefined"))
# Empty range: return NaN converted to the type of the midpoint expression, so the
# result has the same type as in the non-empty case.
isempty(r) && return oftype((first(r) + last(r)) / 2, NaN)
(first(r) + last(r)) / 2
end

Expand All @@ -148,7 +149,8 @@ var(iterable; corrected::Bool=true, mean=nothing) = _var(iterable, corrected, me
function _var(iterable, corrected::Bool, mean)
y = iterate(iterable)
if y === nothing
throw(ArgumentError("variance of empty collection undefined: $(repr(iterable))"))
T = eltype(iterable)
return oftype((abs2(zero(T)) + abs2(zero(T)))/2, NaN)
end
count = 1
value, state = y
Expand Down Expand Up @@ -265,7 +267,7 @@ varm(A::AbstractArray, m; corrected::Bool=true) = _varm(A, m, corrected, :)

# Whole-array (`::Colon`) method: divides `centralize_sumabs2(A, m)` by `n - 1` when
# `corrected` is true and by `n` otherwise, where `n = length(A)`.
# `centralize_sumabs2` is defined outside this excerpt — presumably the sum of squared
# deviations of `A` from `m`; verify against its definition.
function _varm(A::AbstractArray{T}, m, corrected::Bool, ::Colon) where T
n = length(A)
# NOTE(review): the two `n == 0` guards below look like the removed/added pair of a
# diff hunk whose -/+ markers were lost in this listing — confirm against the commit.
# Both produce NaN in the type of (abs2(zero(T)) + abs2(zero(T)))/2, the first via the
# type's constructor and the second via `oftype`; read literally, only the first is
# reachable.
n == 0 && return typeof((abs2(zero(T)) + abs2(zero(T)))/2)(NaN)
n == 0 && return oftype((abs2(zero(T)) + abs2(zero(T)))/2, NaN)
# Int(corrected) is 1 or 0, so the denominator is n - 1 or n.
return centralize_sumabs2(A, m) / (n - Int(corrected))
end

Expand Down
39 changes: 34 additions & 5 deletions stdlib/Statistics/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ end
end

@testset "mean" begin
@test_throws ArgumentError mean(())
@test_throws MethodError mean(())
@test mean((1,2,3)) === 2.
@test mean([0]) === 0.
@test mean([1.]) === 1.
Expand All @@ -86,6 +86,21 @@ end
@test ismissing(mean([missing, NaN]))
@test isequal(mean([missing 1.0; 2.0 3.0], dims=1), [missing 2.0])
@test mean(skipmissing([1, missing, 2])) === 1.5
@test isequal(mean(Complex{Float64}[]), NaN+NaN*im)
@test mean(Complex{Float64}[]) isa Complex{Float64}
@test isequal(mean(skipmissing(Complex{Float64}[])), NaN+NaN*im)
@test mean(skipmissing(Complex{Float64}[])) isa Complex{Float64}
@test isequal(mean(abs, Complex{Float64}[]), NaN)
@test mean(abs, Complex{Float64}[]) isa Float64
@test isequal(mean(abs, skipmissing(Complex{Float64}[])), NaN)
@test mean(abs, skipmissing(Complex{Float64}[])) isa Float64
@test isequal(mean(Int[]), NaN)
@test mean(Int[]) isa Float64
@test isequal(mean(skipmissing(Int[])), NaN)
@test mean(skipmissing(Int[])) isa Float64
@test_throws MethodError mean([])
@test_throws MethodError mean(skipmissing([]))
@test_throws ArgumentError mean((1 for i in 2:1))

# Check that small types are accumulated using wider type
for T in (Int8, UInt8)
Expand All @@ -104,15 +119,17 @@ end
@test f(2:0.1:n) ≈ f([2:0.1:n;])
end
end
@test mean(2:1) === NaN
@test mean(big(2):1) isa BigFloat
end

@testset "var & std" begin
# edge case: empty vector
# iterable; this has to throw for type stability
@test_throws ArgumentError var(())
@test_throws ArgumentError var((); corrected=false)
@test_throws ArgumentError var((); mean=2)
@test_throws ArgumentError var((); mean=2, corrected=false)
@test_throws MethodError var(())
@test_throws MethodError var((); corrected=false)
@test_throws MethodError var((); mean=2)
@test_throws MethodError var((); mean=2, corrected=false)
# reduction
@test isnan(var(Int[]))
@test isnan(var(Int[]; corrected=false))
Expand Down Expand Up @@ -245,6 +262,18 @@ end
@test ismissing(f([missing, NaN], missing))
@test f(skipmissing([1, missing, 2]), 0) === f([1, 2], 0)
end

@test isequal(var(Complex{Float64}[]), NaN)
@test var(Complex{Float64}[]) isa Float64
@test isequal(var(skipmissing(Complex{Float64}[])), NaN)
@test var(skipmissing(Complex{Float64}[])) isa Float64
@test_throws MethodError var([])
@test_throws MethodError var(skipmissing([]))
@test_throws MethodError var((1 for i in 2:1))
@test isequal(var(Int[]), NaN)
@test var(Int[]) isa Float64
@test isequal(var(skipmissing(Int[])), NaN)
@test var(skipmissing(Int[])) isa Float64
end

function safe_cov(x, y, zm::Bool, cr::Bool)
Expand Down

0 comments on commit b519147

Please sign in to comment.