Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid nesting testsets in test_rule #158

Merged
merged 2 commits into from
Jun 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/ChainRulesTestUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@ export TestIterator
export check_equal, test_scalar, test_frule, test_rrule, generate_well_conditioned_matrix
export ⊢


include("generate_tangent.jl")
include("data_generation.jl")
include("iterator.jl")

include("output_control.jl")
include("check_result.jl")

include("finite_difference_calls.jl")
Expand Down
88 changes: 53 additions & 35 deletions src/check_result.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,50 +4,55 @@
# Note that this must work well both on Differential types and Primal types

"""
check_equal(actual, expected; kwargs...)
check_equal(actual, expected, [msg]; kwargs...)

`@test`'s that `actual ≈ expected`, but breaks up data such that human readable results
are shown on failures.
Understands things like `unthunk`ing `ChainRulesCore.Thunk`s, etc.

If provided, `msg` is printed on a failure. Often additional items are appended to `msg`
to give breadcrumbs into nested structures.

All keyword arguments are passed to `isapprox`.
"""
function check_equal(
actual::Union{AbstractArray{<:Number},Number},
expected::Union{AbstractArray{<:Number},Number};
expected::Union{AbstractArray{<:Number},Number},
msg="";
kwargs...,
)
@test isapprox(actual, expected; kwargs...)
@test_msg msg isapprox(actual, expected; kwargs...)
end

for (T1, T2) in ((AbstractThunk, Any), (AbstractThunk, AbstractThunk), (Any, AbstractThunk))
@eval function check_equal(actual::$T1, expected::$T2; kwargs...)
return check_equal(unthunk(actual), unthunk(expected); kwargs...)
@eval function check_equal(actual::$T1, expected::$T2, msg=""; kwargs...)
return check_equal(unthunk(actual), unthunk(expected), msg; kwargs...)
end
end

check_equal(::ZeroTangent, x; kwargs...) = check_equal(zero(x), x; kwargs...)
check_equal(x, ::ZeroTangent; kwargs...) = check_equal(x, zero(x); kwargs...)
check_equal(x::ZeroTangent, y::ZeroTangent; kwargs...) = @test true
check_equal(::ZeroTangent, x, msg=""; kwargs...) = check_equal(zero(x), x, msg; kwargs...)
check_equal(x, ::ZeroTangent, msg=""; kwargs...) = check_equal(x, zero(x), msg; kwargs...)
check_equal(x::ZeroTangent, y::ZeroTangent, msg=""; kwargs...) = @test true

# remove once https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113
check_equal(x::NoTangent, y::Nothing; kwargs...) = @test true
check_equal(x::Nothing, y::NoTangent; kwargs...) = @test true
check_equal(x::NoTangent, y::Nothing, msg=""; kwargs...) = @test true
check_equal(x::Nothing, y::NoTangent, msg=""; kwargs...) = @test true

# Checking equality with `NotImplemented` reports `@test_broken` since the derivative has intentionally
# not yet been implemented
# `@test_broken x == y` yields more descriptive messages than `@test_broken false`
check_equal(x::ChainRulesCore.NotImplemented, y; kwargs...) = @test_broken x == y
check_equal(x, y::ChainRulesCore.NotImplemented; kwargs...) = @test_broken x == y
check_equal(x::ChainRulesCore.NotImplemented, y, msg=""; kwargs...) = @test_broken x == y
check_equal(x, y::ChainRulesCore.NotImplemented, msg=""; kwargs...) = @test_broken x == y
# In this case we check for equality (messages etc. have to be equal)
function check_equal(
x::ChainRulesCore.NotImplemented, y::ChainRulesCore.NotImplemented; kwargs...
x::ChainRulesCore.NotImplemented, y::ChainRulesCore.NotImplemented, msg=""; kwargs...
)
return @test x == y
return @test_msg msg x == y
end

"""
_can_pass_early(actual, expected; kwargs...)
Used to check if `actual` is basically equal to `expected`, so we don't need to check deeper;
Used to check if `actual` is basically equal to `expected`, so we don't need to check deeper
and can just report `check_equal` as passing.

If either `==` or `≈` return true then so does this.
Expand All @@ -64,60 +69,71 @@ function _can_pass_early(actual, expected; kwargs...)
return false
end

function check_equal(actual::AbstractArray, expected::AbstractArray; kwargs...)
function check_equal(actual::AbstractArray, expected::AbstractArray, msg=""; kwargs...)
if _can_pass_early(actual, expected)
@test true
else
@test eachindex(actual) == eachindex(expected)
@testset "$(typeof(actual))[$ii]" for ii in eachindex(actual)
check_equal(actual[ii], expected[ii]; kwargs...)
@test_msg "$msg: indices must match" eachindex(actual) == eachindex(expected)
for ii in eachindex(actual)
new_msg = "$msg $(typeof(actual))[$ii]"
check_equal(actual[ii], expected[ii], new_msg; kwargs...)
end
end
end

function check_equal(actual::Tangent{P}, expected::Tangent{P}; kwargs...) where {P}
function check_equal(actual::Tangent{P}, expected::Tangent{P}, msg=""; kwargs...) where {P}
if _can_pass_early(actual, expected)
@test true
else
all_keys = union(keys(actual), keys(expected))
@testset "$P.$ii" for ii in all_keys
check_equal(getproperty(actual, ii), getproperty(expected, ii); kwargs...)
for ii in all_keys
new_msg = "$msg $P.$ii"
check_equal(
getproperty(actual, ii), getproperty(expected, ii), new_msg; kwargs...
)
end
end
end

function check_equal(
::Tangent{ActualPrimal}, expected::Tangent{ExpectedPrimal}; kwargs...
::Tangent{ActualPrimal}, expected::Tangent{ExpectedPrimal}, msg=""; kwargs...
) where {ActualPrimal,ExpectedPrimal}
# this will certainly fail as we have another dispatch for that, but this will give us a
# good error message
@test ActualPrimal === ExpectedPrimal
end

# A structural differential and a natural differential
function check_equal(actual::Tangent{P,T}, expected; kwargs...) where {T,P}
function check_equal(actual::Tangent{P,T}, expected, msg=""; kwargs...) where {T,P}
if _can_pass_early(actual, expected)
@test true
else
@assert (T <: NamedTuple) # it should be a structural differential if we hit this

# We are only checking the properties that are in the Tangent
# the natural differential is allowed to have other properties that we ignore
@testset "$P.$ii" for ii in propertynames(actual)
check_equal(getproperty(actual, ii), getproperty(expected, ii); kwargs...)
for ii in propertynames(actual)
new_msg = "$msg $P.$ii"
check_equal(
getproperty(actual, ii), getproperty(expected, ii), new_msg; kwargs...
)
end
end
end
check_equal(x, y::Tangent; kwargs...) = check_equal(y, x; kwargs...)
check_equal(x, y::Tangent, msg=""; kwargs...) = check_equal(y, x, msg; kwargs...)

# This catches comparisons of Tangents and Tuples/NamedTuple
# and gives an error message complaining about that
# and gives an error message complaining about that. The `@test` will definitely fail
const LegacyZygoteCompTypes = Union{Tuple,NamedTuple}
check_equal(::C, ::T; kwargs...) where {C<:Tangent,T<:LegacyZygoteCompTypes} = @test C === T
check_equal(::T, ::C; kwargs...) where {C<:Tangent,T<:LegacyZygoteCompTypes} = @test T === C
function check_equal(x::Tangent, y::LegacyZygoteCompTypes, msg=""; kwargs...)
@test_msg "$msg: for structural differentials use `Tangent`" typeof(x) === typeof(y)
end
function check_equal(x::LegacyZygoteCompTypes, y::Tangent, msg=""; kwargs...)
return check_equal(y, x, msg; kwargs...)
end

# Generic fallback, probably a tuple or something
function check_equal(actual::A, expected::E; kwargs...) where {A,E}
function check_equal(actual::A, expected::E, msg=""; kwargs...) where {A,E}
if _can_pass_early(actual, expected)
@test true
else
Expand All @@ -130,6 +146,8 @@ function check_equal(actual::A, expected::E; kwargs...) where {A,E}
end
end

###########################################################################################

"""
_check_add!!_behaviour(acc, val)

Expand All @@ -146,19 +164,19 @@ function _check_add!!_behaviour(acc, val; kwargs...)
# e.g. if it is immutable. We do test the `add!!` return value.
# That is what people should rely on. The mutation is just to save allocations.
acc_mutated = deepcopy(acc) # prevent this test changing others
return check_equal(add!!(acc_mutated, val), acc + val; kwargs...)
return check_equal(add!!(acc_mutated, val), acc + val, "in add!!"; kwargs...)
end

# Checking equality with `NotImplemented` reports `@test_broken` since the derivative has intentionally
# not yet been implemented
# Checking equality with `NotImplemented` reports `@test_broken` since the derivative has
# intentionally not yet been implemented
# `@test_broken x == y` yields more descriptive messages than `@test_broken false`
function _check_add!!_behaviour(acc_mutated, acc::ChainRulesCore.NotImplemented; kwargs...)
return @test_broken acc_mutated == acc
end
function _check_add!!_behaviour(acc_mutated::ChainRulesCore.NotImplemented, acc; kwargs...)
return @test_broken acc_mutated == acc
end
# In this case we check for equality (messages etc. have to be equal)
# In this case we check for equality (not implemented messages etc. have to be equal)
function _check_add!!_behaviour(
acc_mutated::ChainRulesCore.NotImplemented,
acc::ChainRulesCore.NotImplemented;
Expand Down
62 changes: 62 additions & 0 deletions src/output_control.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Test.get_test_result generates code that uses the following so we must import them
using Test: Returned, Threw, eval_test
mzgubic marked this conversation as resolved.
Show resolved Hide resolved

"""
    ExprAndMsg(ex, msg)

A cunning hack to carry an extra message along with the original expression in a test.

`@test_msg` constructs one of these from the stringified test expression (`ex`) and the
user-supplied message (`msg`), and hands it to `Test.do_test` in place of the plain
expression. The `Base.print` method for this type prints `ex` and, when `msg` is not
empty, an additional `Problem:` line containing `msg`.
"""
struct ExprAndMsg
ex
msg
end

"""
@test_msg msg condition kws...

This behaves like `Test.@test condition kws...`, except that if it fails it also prints
the `msg`.
If `msg == ""` then this is just like `@test`: nothing extra is printed.

### Examples
```julia
julia> @test_msg "It is required that the total is under 10" sum(1:1000) < 10;
Test Failed at REPL[1]:1
Expression: sum(1:1000) < 10
Problem: It is required that the total is under 10
Evaluated: 500500 < 10
ERROR: There was an error during testing


julia> @test_msg "It is required that the total is under 10" error("not working at all");
Error During Test at REPL[2]:1
Test threw exception
Expression: error("not working at all")
Problem: It is required that the total is under 10
"not working at all"
Stacktrace:

julia> a = "";

julia> @test_msg a sum(1:1000) < 10;
Test Failed at REPL[153]:1
Expression: sum(1:1000) < 10
Evaluated: 500500 < 10
ERROR: There was an error during testing
```
"""
macro test_msg(msg, ex, kws...)
oxinabox marked this conversation as resolved.
Show resolved Hide resolved
# This code is basically an evil hack that accesses the internals of the Test stdlib.
# Code below is based on the `@test` macro definition as it was in Julia 1.6.
# https://github.com/JuliaLang/julia/blob/v1.6.1/stdlib/Test/src/Test.jl#L371-L376
Test.test_expr!("@test_msg msg", ex, kws...)

result = Test.get_test_result(ex, __source__)
return :(Test.do_test($result, $ExprAndMsg($(string(ex)), $(esc(msg)))))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is using undocumented functions a code smell of some sort? I think it indicates that either:

  1. Test package should be easier to extend/modify
    or
  2. We are not achieving our goal in the right way with the tools available

I am leaning towards 1) from other experience with Test. On the other hand, I don't fully understand this PR (namely the test_msg macro). What do get_test_result and do_test actually do?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AFAIK it is 1) I will add a comment. I don't really understand what they do either, just that this is the thing that is needed.
It is kinda copied directly from the Test stdlib.
This whole thing with ExprAndMsg is a huge hack.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

comment added

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, fair enough. LGTM just needs a version bump

end

# Print the test expression carried by `x`, followed (only when the message is not empty)
# by a "Problem:" line holding the extra message.
#
# Always returns `nothing`. Previously the short-circuit `&&` was the last expression, so
# the method returned `false` whenever `msg` was empty and `nothing` otherwise.
function Base.print(io::IO, x::ExprAndMsg)
    print(io, x.ex)
    # An empty message means "behave exactly like plain `@test`": print nothing extra.
    isempty(x.msg) || print(io, "\n Problem: ", x.msg)
    return nothing
end


### helpers for printing in log messages etc
# Fallback: the printed name of the value's own type.
function _string_typeof(x)
    return string(typeof(x))
end

# Tuples: the per-element results joined with commas.
function _string_typeof(xs::Tuple)
    return join(map(_string_typeof, xs), ",")
end

# Primal/tangent pairs: only show primal
function _string_typeof(x::PrimalAndTangent)
    return _string_typeof(primal(x))
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
_string_typeof(x::PrimalAndTangent) = _string_typeof(primal(x)) # only show primal
_string_typeof(x::PrimalAndTangent) = "$(_string_typeof(primal(x)))$(_string_typeof(tangent(x)))"

Are we ever interested in the tangent type? Probably not?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah i considered that. Note that that would display things like Vector{Float64} ⊢ NoTangent() which is a bit of a play on words since the ⊢ operator doesn't actually accept types.

I think we can leave it as is (displaying just primals) for now, and consider adding that later.

15 changes: 7 additions & 8 deletions src/testers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ function test_frule(
# To simplify some of the calls we make later lets group the kwargs for reuse
isapprox_kwargs = (; rtol=rtol, atol=atol, kwargs...)

@testset "test_frule: $f on $(join(typeof.(inputs), ","))" begin
@testset "test_frule: $f on $(_string_typeof(inputs))" begin
_ensure_not_running_on_functor(f, "test_frule")

xẋs = auto_primal_and_tangent.(inputs)
Expand Down Expand Up @@ -167,7 +167,7 @@ function test_rrule(
# To simplify some of the calls we make later lets group the kwargs for reuse
isapprox_kwargs = (; rtol=rtol, atol=atol, kwargs...)

@testset "test_rrule: $f on $(join(typeof.(inputs), ","))" begin
@testset "test_rrule: $f on $(_string_typeof(inputs))" begin
_ensure_not_running_on_functor(f, "test_rrule")

# Check correctness of evaluation.
Expand Down Expand Up @@ -226,12 +226,11 @@ function test_rrule(
end

function check_thunking_is_appropriate(x̄s)
@testset "Don't thunk only non_zero argument" begin
num_zeros = count(x -> x isa AbstractZero, x̄s)
num_thunks = count(x -> x isa Thunk, x̄s)
if num_zeros + num_thunks == length(x̄s)
@test num_thunks !== 1
end
num_zeros = count(x -> x isa AbstractZero, x̄s)
num_thunks = count(x -> x isa Thunk, x̄s)
if num_zeros + num_thunks == length(x̄s)
# num_thunks can be either 0, or greater than 1.
@test_msg "Should not thunk only non_zero argument" num_thunks != 1
end
end

Expand Down
4 changes: 2 additions & 2 deletions test/testers.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# For some reason if these aren't defined here, then they are interpreted as closures
# Define test functions here: if they are defined where they are used, it is too easy to
# mistakenly create closures over variables that only share names by coincidence.
# Test helper: returns `x` unchanged when `err` is false; otherwise (the default) throws
# an error with the message "futestkws_err".
function futestkws(x; err=true)
    err && error("futestkws_err")
    return x
end

# Test helper: returns `x` (ignoring `y`) when `err` is false; otherwise (the default)
# throws an error with the message "fbtestkws_err".
function fbtestkws(x, y; err=true)
    err && error("fbtestkws_err")
    return x
end
Expand Down Expand Up @@ -268,7 +269,6 @@ end
return first(x), first_pullback
end

#CTuple{N} = Tangent{NTuple{N, Float64}} # shorter for testing
@testset "test_frule" begin
test_frule(first, (2.0, 3.0))
test_frule(first, Tuple(randn(4)))
Expand Down