test_rrule trying to find Tangent on axes #268

Closed
theogf opened this issue Jan 31, 2023 · 1 comment

Comments

@theogf
Contributor

theogf commented Jan 31, 2023

This is more a question than a bug report.

I defined the following rrule for _map applied to a Fill (_map is a thin wrapper, since Zygote does not allow playing with map itself :) )

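using ChainRulesCore, FillArrays

_map(f, args...) = map(f, args...)
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(_map), f::Tf, x::F) where {Tf,F<:Fill}
    # Differentiate f once on the single stored value instead of on every element.
    y_el, back = ChainRulesCore.rrule_via_ad(config, f, x.value)
    function _map_Fill_rrule(Δ)
        # Pull back the cotangent of the output's value field.
        Δf, Δx_el = back(Δ.value)
        # axes is structural, so it carries no tangent.
        return NoTangent(), Δf, Tangent{F}(value = Δx_el, axes = NoTangent())
    end
    return Fill(y_el, axes(x)), _map_Fill_rrule
end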

The result seems correct but I cannot call test_rrule on it:

test_rrule(_map, sum, Fill(randn(3, 4), 4))

The error traces back to the jacobian function from FiniteDifferences, which tries to differentiate through the axes field of Fill.

I tried to specify a Tangent for the Fill via ⊢ Tangent{typeof(x)}(value=randn(3, 4), axes=NoTangent()), but without success...

Could you help me figure out what I need to do?

@theogf theogf changed the title test_rrule trying to find tangent on axes test_rrule trying to find Tangent on axes Jan 31, 2023
@theogf
Contributor Author

theogf commented Jan 31, 2023

Never mind, dispatching on to_vec solves the issue, as proposed in #258.
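
For readers landing here, a minimal sketch of what dispatching on to_vec for a Fill might look like. This is not the exact method from #258; it assumes that only the value field holds differentiable data and that the axes can simply be captured from the original array. The helper name Fill_from_vec is made up for illustration.

```julia
using FillArrays
using FiniteDifferences

# Assumption: only `value` is differentiable. `axes` is structural, so it is
# captured from the original array instead of being flattened into the vector.
function FiniteDifferences.to_vec(x::Fill)
    # Flatten the single stored value with whatever to_vec method applies to it.
    value_vec, value_from_vec = to_vec(x.value)
    # Rebuild a Fill with the same axes from a (possibly perturbed) flat vector.
    Fill_from_vec(v) = Fill(value_from_vec(v), axes(x))
    return value_vec, Fill_from_vec
end
```

With a method like this in scope, FiniteDifferences should perturb only the stored value when test_rrule compares the rrule against finite differences, so the axes field is never differentiated.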

@theogf theogf closed this as completed Jan 31, 2023