From b519147d60683c3a0ed587925bb7221153b64b65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bogumi=C5=82=20Kami=C5=84ski?= Date: Thu, 11 Oct 2018 21:02:23 +0200 Subject: [PATCH] Add handling of an empty iterator for mean and var (#29033) --- stdlib/Statistics/src/Statistics.jl | 10 +++++--- stdlib/Statistics/test/runtests.jl | 39 +++++++++++++++++++++++++---- 2 files changed, 40 insertions(+), 9 deletions(-) diff --git a/stdlib/Statistics/src/Statistics.jl b/stdlib/Statistics/src/Statistics.jl index bacfac70ee770..74b1fb7b572e5 100644 --- a/stdlib/Statistics/src/Statistics.jl +++ b/stdlib/Statistics/src/Statistics.jl @@ -57,7 +57,8 @@ julia> mean([√1, √2, √3]) function mean(f::Base.Callable, itr) y = iterate(itr) if y === nothing - throw(ArgumentError("mean of empty collection undefined: $(repr(itr))")) + return Base.mapreduce_empty_iter(f, Base.add_sum, itr, + Base.IteratorEltype(itr)) / 0 end count = 1 value, state = y @@ -131,7 +132,7 @@ _mean(A::AbstractArray{T}, region) where {T} = mean!(Base.reducedim_init(t -> t/ _mean(A::AbstractArray, ::Colon) = sum(A) / length(A) function mean(r::AbstractRange{<:Real}) - isempty(r) && throw(ArgumentError("mean of an empty range is undefined")) + isempty(r) && return oftype((first(r) + last(r)) / 2, NaN) (first(r) + last(r)) / 2 end @@ -148,7 +149,8 @@ var(iterable; corrected::Bool=true, mean=nothing) = _var(iterable, corrected, me function _var(iterable, corrected::Bool, mean) y = iterate(iterable) if y === nothing - throw(ArgumentError("variance of empty collection undefined: $(repr(iterable))")) + T = eltype(iterable) + return oftype((abs2(zero(T)) + abs2(zero(T)))/2, NaN) end count = 1 value, state = y @@ -265,7 +267,7 @@ varm(A::AbstractArray, m; corrected::Bool=true) = _varm(A, m, corrected, :) function _varm(A::AbstractArray{T}, m, corrected::Bool, ::Colon) where T n = length(A) - n == 0 && return typeof((abs2(zero(T)) + abs2(zero(T)))/2)(NaN) + n == 0 && return oftype((abs2(zero(T)) + abs2(zero(T)))/2, NaN) return centralize_sumabs2(A, m) / (n - Int(corrected)) end diff --git a/stdlib/Statistics/test/runtests.jl b/stdlib/Statistics/test/runtests.jl index 8cd4129e9bbbd..6c26efd12925f 100644 --- a/stdlib/Statistics/test/runtests.jl +++ b/stdlib/Statistics/test/runtests.jl @@ -60,7 +60,7 @@ end end @testset "mean" begin - @test_throws ArgumentError mean(()) + @test_throws MethodError mean(()) @test mean((1,2,3)) === 2. @test mean([0]) === 0. @test mean([1.]) === 1. @@ -86,6 +86,21 @@ end @test ismissing(mean([missing, NaN])) @test isequal(mean([missing 1.0; 2.0 3.0], dims=1), [missing 2.0]) @test mean(skipmissing([1, missing, 2])) === 1.5 + @test isequal(mean(Complex{Float64}[]), NaN+NaN*im) + @test mean(Complex{Float64}[]) isa Complex{Float64} + @test isequal(mean(skipmissing(Complex{Float64}[])), NaN+NaN*im) + @test mean(skipmissing(Complex{Float64}[])) isa Complex{Float64} + @test isequal(mean(abs, Complex{Float64}[]), NaN) + @test mean(abs, Complex{Float64}[]) isa Float64 + @test isequal(mean(abs, skipmissing(Complex{Float64}[])), NaN) + @test mean(abs, skipmissing(Complex{Float64}[])) isa Float64 + @test isequal(mean(Int[]), NaN) + @test mean(Int[]) isa Float64 + @test isequal(mean(skipmissing(Int[])), NaN) + @test mean(skipmissing(Int[])) isa Float64 + @test_throws MethodError mean([]) + @test_throws MethodError mean(skipmissing([])) + @test_throws ArgumentError mean((1 for i in 2:1)) # Check that small types are accumulated using wider type for T in (Int8, UInt8) @@ -104,15 +119,17 @@ end @test f(2:0.1:n) ≈ f([2:0.1:n;]) end end + @test mean(2:1) === NaN + @test mean(big(2):1) isa BigFloat end @testset "var & std" begin # edge case: empty vector # iterable; this has to throw for type stability - @test_throws ArgumentError var(()) - @test_throws ArgumentError var((); corrected=false) - @test_throws ArgumentError var((); mean=2) - @test_throws ArgumentError var((); mean=2, corrected=false) + @test_throws MethodError var(()) + @test_throws MethodError var((); corrected=false) + @test_throws MethodError var((); mean=2) + @test_throws MethodError var((); mean=2, corrected=false) # reduction @test isnan(var(Int[])) @test isnan(var(Int[]; corrected=false)) @@ -245,6 +262,18 @@ end @test ismissing(f([missing, NaN], missing)) @test f(skipmissing([1, missing, 2]), 0) === f([1, 2], 0) end + + @test isequal(var(Complex{Float64}[]), NaN) + @test var(Complex{Float64}[]) isa Float64 + @test isequal(var(skipmissing(Complex{Float64}[])), NaN) + @test var(skipmissing(Complex{Float64}[])) isa Float64 + @test_throws MethodError var([]) + @test_throws MethodError var(skipmissing([])) + @test_throws MethodError var((1 for i in 2:1)) + @test isequal(var(Int[]), NaN) + @test var(Int[]) isa Float64 + @test isequal(var(skipmissing(Int[])), NaN) + @test var(skipmissing(Int[])) isa Float64 end function safe_cov(x, y, zm::Bool, cr::Bool)