RenyiDivergence incompatible with TrackedArray #147

Open
torfjelde opened this issue Aug 24, 2019 · 6 comments

@torfjelde
Contributor

MVE

using Tracker, Distances
renyi_divergence(param([1.0]), param([1.0]), 0.5)

results in

MethodError: no method matching Float64(::Tracker.TrackedReal{Float64})
Closest candidates are:
  Float64(::Real, !Matched::RoundingMode) where T<:AbstractFloat at rounding.jl:185
  Float64(::T<:Number) where T<:Number at boot.jl:725
  Float64(!Matched::Int8) at float.jl:60
  ...

Stacktrace:
 [1] renyi_divergence(::TrackedArray{…,Array{Float64,1}}, ::TrackedArray{…,Array{Float64,1}}, ::Float64) at /home/tor/.julia/dev/Distances/src/metrics.jl:382
 [2] top-level scope at In[35]:1

A possible "solution":

@inline Base.@propagate_inbounds function eval_start(::RenyiDivergence, a::AbstractArray{T_}, b::AbstractArray{T_}) where {T_ <: Real}
    T = eltype(a)  # returns `TrackedReal` if `a isa TrackedArray`
    zero(T), zero(T), T(sum(a)), T(sum(b))
end
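Here is a self-contained sketch of why T(sum(a)) fails while the eltype-based version works. It needs neither Tracker nor Distances. Tagged and TaggedVector are hypothetical stand-ins that mimic the mismatch described in this issue: the array's type parameter is Float64, but its elements and its eltype are a different Real. They are not Tracker's actual types. Likewise, convert_by_parameter and convert_by_eltype are illustrative names, not Distances functions.

```julia
# Hypothetical stand-in for Tracker.TrackedReal: a Real wrapping a Float64.
struct Tagged <: Real
    value::Float64
end
Tagged(x::Tagged) = x                                       # Tagged -> Tagged is a no-op
Base.:+(x::Tagged, y::Tagged) = Tagged(x.value + y.value)   # enough for sum() on a short array

# Hypothetical stand-in for Tracker.TrackedArray: the AbstractArray parameter
# says Float64, but indexing returns Tagged and eltype reports Tagged.
struct TaggedVector <: AbstractVector{Float64}
    data::Vector{Float64}
end
Base.size(v::TaggedVector) = size(v.data)
Base.getindex(v::TaggedVector, i::Int) = Tagged(v.data[i])
Base.IndexStyle(::Type{TaggedVector}) = IndexLinear()
Base.eltype(::Type{TaggedVector}) = Tagged

# Mirrors the original eval_start: T is the array's type parameter (Float64 here),
# so this ends up calling Float64(::Tagged), which has no method.
convert_by_parameter(a::AbstractArray{T}) where {T<:Real} = T(sum(a))

# Mirrors the proposed workaround: the target type comes from eltype (Tagged here).
convert_by_eltype(a::AbstractArray{<:Real}) = eltype(a)(sum(a))

v = TaggedVector([1.0, 2.0])
convert_by_eltype(v).value   # 3.0
```

Under these assumptions, convert_by_parameter(v) throws a MethodError analogous to the one in the stacktrace above, and convert_by_eltype(v) returns a Tagged.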

Thoughts?

@KristofferC
Member

This just seems like a bug in TrackedArray. Why isn't the parameter T a TrackedReal?

@torfjelde
Contributor Author

torfjelde commented Aug 24, 2019

Disclaimer: my understanding of Tracker.jl is limited, to say the least.

But I think calling it a bug is a bit strong, as I'm pretty sure it's intended. By tracking the entire array, you can use matrix operations, etc., to pull back the gradient ("backprop") rather than performing the operations elementwise on an Array{TrackedReal}.

EDIT: I agree though; it's unfortunate :/

@KristofferC
Member

I don't really understand your comment. The docs for AbstractArray say

  AbstractArray{T,N}

  Supertype for N-dimensional arrays (or array-like types) with elements of type T

That doesn't seem to be the case for TrackedArray.

@torfjelde
Contributor Author

Yeah, you're 100% right :) But I don't think there's a way to implement reverse-mode AD that takes advantage of what I'm referring to while still satisfying the AbstractArray definition.

I'm wondering if it's okay to use eltype in this function call rather than the parametric type for the conversion call, to allow this "abuse" of the AbstractArray definition in Tracker.jl. Since Base.eltype(a::AbstractArray{T}) = T anyway, there won't be any negative side effects; the only outcome is that it just works on a TrackedArray.
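Here is a quick check of the claim that eltype and the AbstractArray parameter coincide, using a few array types from Base. array_parameter is a hypothetical helper. The check only covers these Base types and says nothing about how Tracker's types behave.

```julia
# Recover the T in AbstractArray{T} by dispatch (hypothetical helper name).
array_parameter(::AbstractArray{T}) where {T} = T

# A few ordinary AbstractArray subtypes from Base.
examples = Any[
    [1.0, 2.0],                # Vector{Float64}
    1:3,                       # UnitRange{Int}
    view([1, 2, 3], 1:2),      # SubArray with Int elements
    Float32[1 2; 3 4],         # Matrix{Float32}
    BitVector([true, false]),  # AbstractVector{Bool}
]

for a in examples
    T = array_parameter(a)
    # For these types the default eltype returns the parameter, and every element is a T.
    println(typeof(a), ": parameter = ", T, ", eltype = ", eltype(a),
            ", all elements isa T: ", all(x -> x isa T, a))
end
```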

@torfjelde changed the title from "RenyiDivergence incompatible with Tracker.jl" to "RenyiDivergence incompatible with TrackedArray" on Aug 24, 2019
@nalimilan
Member

> I'm wondering if it's okay to use eltype in this function call rather than the parametric type for the conversion call, to allow this "abuse" of the AbstractArray definition in Tracker.jl. Since Base.eltype(a::AbstractArray{T}) = T anyway, there won't be any negative side effects; the only outcome is that it just works on a TrackedArray.

If we started doing that, a lot of code would have to be updated. Since the guarantee that the AbstractArray parameter is equivalent to eltype is relied upon in many packages (and in Base), a proper fix will have to be found instead.

@torfjelde
Contributor Author

Understandable 👍
I guess the only proper fix would be to dispatch differently on TrackedArray for the three signature combinations (tracked/untracked, untracked/tracked, and tracked/tracked).
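Here is a rough sketch of what dispatching on the three combinations could look like. It uses a hypothetical MarkedVector in place of TrackedArray and a hypothetical start_kind in place of eval_start; neither is Tracker or Distances code. Three methods are needed rather than two because of ambiguity. With only the (marked, any) and (any, marked) methods, a call with two marked arrays matches both equally well, and Julia raises a MethodError for the ambiguous call. partial_kind demonstrates this.

```julia
# Hypothetical stand-in for Tracker.TrackedArray (not Tracker's actual type).
struct MarkedVector <: AbstractVector{Float64}
    data::Vector{Float64}
end
Base.size(v::MarkedVector) = size(v.data)
Base.getindex(v::MarkedVector, i::Int) = v.data[i]
Base.IndexStyle(::Type{MarkedVector}) = IndexLinear()

# Hypothetical analogue of eval_start: the generic method handles plain arrays.
start_kind(a::AbstractArray, b::AbstractArray) = :generic
# One extra method per combination involving at least one MarkedVector.
start_kind(a::MarkedVector, b::AbstractArray) = :marked_first
start_kind(a::AbstractArray, b::MarkedVector) = :marked_second
start_kind(a::MarkedVector, b::MarkedVector) = :both_marked   # resolves the ambiguity

# For comparison: only the two one-sided methods, so (marked, marked) is ambiguous.
partial_kind(a::MarkedVector, b::AbstractArray) = :marked_first
partial_kind(a::AbstractArray, b::MarkedVector) = :marked_second

m = MarkedVector([1.0])
p = [1.0]
start_kind(m, m)   # :both_marked
```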
