Test all kinds of tangent types (Thunk, ZeroTangent, Tangent{T} etc) #159

Closed
oxinabox opened this issue Jun 1, 2021 · 9 comments

@oxinabox
Member

oxinabox commented Jun 1, 2021

This is a generalization of #98 and is a key part of JuliaDiff/ChainRules.jl#408

Let's say we have a tangent dx.
We should also test that we can put @thunk(x̄) in its place, as well as ZeroTangent.

Further, following JuliaDiff/ChainRulesCore.jl#286
if dx <: AbstractArray (other than Array) we should test canonicalize(Composite{typeof(x)}, dx)
and conversely
if dx <: Composite{P} where P<:AbstractArray we should test canonicalize(P, dx)
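
As a rough illustration of what "putting a thunk or a zero in its place" means for a pullback, here is a minimal sketch using ChainRulesCore. The function `double` and its rule are hypothetical and exist only for this example; the point is that the same pullback is called with a plain tangent, a `@thunk`, and a `ZeroTangent()`.

```julia
using ChainRulesCore

# Hypothetical function and rule, defined only for this illustration.
double(x) = 2x
function ChainRulesCore.rrule(::typeof(double), x)
    # unthunk is a no-op on non-thunk inputs, so the pullback accepts
    # plain, thunked and zero cotangents alike.
    double_pullback(ȳ) = (NoTangent(), 2 * unthunk(ȳ))
    return double(x), double_pullback
end

y, pb = rrule(double, 3.0)
ȳ = 1.5

plain   = pb(ȳ)              # ordinary cotangent
thunked = pb(@thunk(ȳ))      # same cotangent, wrapped in a Thunk
zeroed  = pb(ZeroTangent())  # zero cotangent
```

A rule that forgets the `unthunk`, or that calls a method with no `ZeroTangent` overload, would fail on the second or third call even though it passes with the plain tangent, which is the kind of bug this issue is about.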

@mzgubic
Member

mzgubic commented Jun 8, 2021

Just to confirm: for test_rrule we just want to change the output_tangent, and not the cotangents of the args, right?
And conversely, for test_frule, we just want to change the tangents of the args, and not the tangent of the output?

One other point is that requiring new output_tangent to pass the tests is a breaking change. I was initially planning to do this in two parts, i.e. adding ZeroTangents and thunks first, and adding canonicalize once we figure out the arrays.

Do we want to do it in one go instead? I guess that means we solve arrays first?

@mzgubic
Member

mzgubic commented Jun 8, 2021

Also, the current implementation I have in mind is something like

# Sketch: re-run the existing test once per alternative form of the output tangent.
function test_rrule(f, args; output_tangent=Auto(), test_other_tangents=true, other_kwargs...)
    y = f(args...)
    ȳ = output_tangent isa Auto ? rand_tangent(y) : output_tangent
    if test_other_tangents
        for ybar in _all_possible_tangents(ȳ)
            # recurse with the flag off so each alternative is tested exactly once
            test_rrule(f, args; output_tangent=ybar, test_other_tangents=false, other_kwargs...)
        end
    else
        do_the_current_thing()
    end
end
    

It's kind of ugly, do we have other ideas?

@oxinabox
Member Author

oxinabox commented Jun 8, 2021

Just to confirm: for test_rrule we just want to change the output_tangent, and not the cotangents of the args, right?
And conversely, for test_frule, we just want to change the tangents of the args, and not the tangent of the output?

This is correct

One other point is that requiring new output_tangent to pass the tests is a breaking change. I was initially planning to do this in two parts, i.e. adding ZeroTangents and thunks first, and adding canonicalize once we figure out the arrays.

Two parts sounds better.
Yes it means multiple rounds of breaking changes, but it is ChainRulesTestUtils so it is only a test dependency.

Potentially we can do something like what we do for inference and have its default controlled by a global?
That way we can start to roll it out for ChainRules.jl then once we have it all working and are happy with it we can change the default and tag a breaking release.

Possibly, rather than a bool, we want a typed object, for extensibility.
It is probably fine in kwarg position, where it can then be passed positionally to an inner function and used for dispatch.

@mzgubic
Member

mzgubic commented Jun 8, 2021

Thanks for reviewing

Potentially we can do something like what we do for inference and have its default controlled by a global?

Could you elaborate please?

@oxinabox
Member Author

oxinabox commented Jun 8, 2021

For example in ChainRulesTestUtils we have:

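# Union{} means no alternative tangent types are tested by default
const DEFAULT_ALT_TANGENTS_TO_TEST = Ref{Type}(Union{})

# the Ref is dereferenced at call time, so changing the global changes the default
function test_rrule(f, args; output_tangent=Auto(), test_other_tangents=DEFAULT_ALT_TANGENTS_TO_TEST[], other_kwargs...)
     ...
end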

Then in ChainRules.jl's tests we write:

ChainRulesTestUtils.DEFAULT_ALT_TANGENTS_TO_TEST[] = AbstractZeroTangent

Then we make those tests pass and we add thunks:

ChainRulesTestUtils.DEFAULT_ALT_TANGENTS_TO_TEST[] = Union{AbstractZeroTangent, AbstractThunk}

But the whole time, packages that haven't opted into the new default behaviour by changing this global are unchanged.

This is normally a pretty evil antipattern of globals, since it doesn't compose at all: you can't have your dependencies setting it to different values.
But I think for a test-only dependency it is fine. (I may eat these words.)
No one uses ChainRulesTestUtils as a runtime dependency.
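
To make the opt-in mechanism concrete, here is a small self-contained sketch of a keyword default that is read from a global `Ref` at call time. Every name in it (`AltTangent`, `ZeroAlt`, `ThunkAlt`, `DEFAULT_ALT`, `kinds_to_test`) is a hypothetical stand-in, not ChainRulesTestUtils API; it only shows how flipping the global changes behaviour for callers that do not pass the keyword.

```julia
# Hypothetical stand-ins for the alternative tangent kinds.
abstract type AltTangent end
struct ZeroAlt <: AltTangent end
struct ThunkAlt <: AltTangent end

# Global default: Union{} has no subtypes other than itself, so nothing is selected.
const DEFAULT_ALT = Ref{Type}(Union{})

# The default expression DEFAULT_ALT[] is evaluated on every call,
# so later changes to the Ref take effect without redefining the function.
function kinds_to_test(; alt::Type=DEFAULT_ALT[])
    return [K for K in (ZeroAlt, ThunkAlt) if K <: alt]
end

before = kinds_to_test()                    # nothing opted in yet

DEFAULT_ALT[] = ZeroAlt                     # opt in to zero tangents
after_zero = kinds_to_test()

DEFAULT_ALT[] = Union{ZeroAlt, ThunkAlt}    # then add thunks
after_both = kinds_to_test()

per_call = kinds_to_test(alt=ThunkAlt)      # an explicit keyword still overrides the global
```

Using a type (here a `Union`) rather than a `Bool` keeps the setting extensible: a new kind of alternative tangent is added by widening the union, without changing the function signature.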

@mzgubic
Member

mzgubic commented Jun 10, 2021

This is somewhat more involved than I hoped. The easiest thing to do would be to:

# list of concrete alternative output tangents to try; empty by default
const DEFAULT_ALT_TANGENTS = Ref{Any}([])

function test_rrule(f, args; output_tangent=Auto(), other_tangents=DEFAULT_ALT_TANGENTS[], other_kwargs...)
    for ybar in other_tangents
        # pass an empty list so the recursive call does not loop again
        test_rrule(f, args; output_tangent=ybar, other_tangents=[], other_kwargs...)
    end
    do_the_current_thing()
end

But that would prevent us from doing canonicalize on the output_tangent later. I.e. if the output_tangent is a BlockDiagonal, we want to test the Tangent{BlockDiagonal} tangent as well.

To do that, we need something like passing functions:

# each entry maps the original output tangent to an alternative representation
const DEFAULT_ALT_TANGENTS = Ref{Vector{Function}}(Function[x -> canonicalize(x), x -> @thunk(x), x -> ZeroTangent()])

function test_rrule(f, args; output_tangent=Auto(), new_tangents=DEFAULT_ALT_TANGENTS[], other_kwargs...)
    y = f(args...)
    ȳ = output_tangent isa Auto ? rand_tangent(y) : output_tangent
    for new_tangent in new_tangents
        # pass an empty list so the recursive call does not loop again
        test_rrule(f, args; output_tangent=new_tangent(ȳ), new_tangents=Function[], other_kwargs...)
    end
    do_the_current_thing()
end

Also, it looks like we will need project before we can implement ZeroTangent() as another differential (for the purpose of finite differencing).

@oxinabox
Member Author

oxinabox commented Jun 10, 2021

I am not sure how much we actually need to finite-difference test all of these,
or whether some other tests would be fine.
For example, checking that if all inputs are ZeroTangent we get something that satisfies iszero.
Potentially even just checking that things don't error might be a useful first step.
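
A minimal sketch of what such cheaper checks could look like, assuming ChainRulesCore is available. The helper `cheap_alt_tangent_checks` and the example pullback `pb` are hypothetical and not part of ChainRulesTestUtils; no finite differencing is involved.

```julia
using ChainRulesCore

# Hypothetical helper: checks a pullback without any finite differencing.
function cheap_alt_tangent_checks(pullback, ȳ)
    # A zero cotangent should pull back to tangents that all satisfy iszero.
    zero_ok = all(iszero, pullback(ZeroTangent()))
    # A thunked cotangent should run without error and agree with the plain one.
    thunk_ok = map(unthunk, pullback(@thunk(ȳ))) == map(unthunk, pullback(ȳ))
    return zero_ok && thunk_ok
end

# Hypothetical pullback for x -> x^2 evaluated at x = 3.0.
pb(ȳ) = (NoTangent(), 2 * 3.0 * unthunk(ȳ))

result = cheap_alt_tangent_checks(pb, 1.5)
```

Because each check is just one or two extra pullback calls, the cost is small compared with repeating the full finite-difference comparison for every alternative tangent.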

@mzgubic
Member

mzgubic commented Jun 10, 2021

Yeah, that sounds good. I've added a PR that does finite differencing (FD) because I already had it. But I worry it will take 4 times as long to test, which would be 2 hours for ChainRules.

We still need to pass the functions, though, because of canonicalize.

@oxinabox
Member Author

We are now happy enough with this, just testing Thunk.
Since pullback only has 1 input, we don't actually need that to be handled by the pullback -- it can be handled by the AD.
If pulling back a Tangent, we could be transforming it in interesting ways, but it seems like a lot of work for little gain.

As for trying structural vs natural differentials, we now have a fairly strong push towards never using a structural differential if there is a good natural one.
So testing both does not seem necessary at this time. I suspect we will not be seeing structural differentials for things with natural differentials very often.

We might want more for forwards mode later but that will be nonbreaking, since we will just be asserting things that had to be true in well-behaved code.

So I am going to call this done.
We can open new issues in the future
