unthunk getindex and iterate on Composite objects #237

Open
wants to merge 7 commits into
base: main

gxyd
Contributor

@gxyd gxyd commented Oct 17, 2020

Fixes #233

@gxyd
Contributor Author

gxyd commented Oct 18, 2020

Currently something like:

julia> r = Composite{Tuple{Float64,}}(a=(@thunk 2.0^2), b=(@thunk 2.0^3))
julia> collect(r)
MethodError: Cannot `convert` an object of type 
  Float64 to an object of type 
  Thunk
Closest candidates are:
  convert(::Type{T}, !Matched::T) where T at essentials.jl:171
  Thunk(::F) where F at /Users/gaurav/.julia/dev/ChainRulesCore/src/differentials/thunks.jl:95

Stacktrace:
 [1] setindex!(::Array{Thunk,1}, ::Float64, ::Int64) at ./array.jl:826
 [2] copyto!(::Array{Thunk,1}, ::Composite{Tuple{Float64},NamedTuple{(:a, :b),Tuple{Thunk{var"#3#5"},Thunk{var"#4#6"}}}}) at ./abstractarray.jl:724
 [3] _collect(::UnitRange{Int64}, ::Composite{Tuple{Float64},NamedTuple{(:a, :b),Tuple{Thunk{var"#3#5"},Thunk{var"#4#6"}}}}, ::Base.HasEltype, ::Base.HasLength) at ./array.jl:609
 [4] collect(::Composite{Tuple{Float64},NamedTuple{(:a, :b),Tuple{Thunk{var"#3#5"},Thunk{var"#4#6"}}}}) at ./array.jl:603
 [5] top-level scope at In[12]:1

raises an error. Would that need redefining Base.collect as well?

@gxyd
Contributor Author

gxyd commented Oct 18, 2020

Though as expected collect(Float64, r) works just fine:

julia> collect(Float64, r)
2-element Array{Float64,1}:
 4.0
 8.0

I'll add the test case for it.

return nothing
end
end

Contributor Author

Or I can replace the entire code block with:

function Base.iterate(comp::Composite, args...)
    out = iterate(backing(comp), args...)
    if out isa Nothing
        return out
    else
        return (unthunk(out[1]), out[2])
    end
end

Contributor Author

If I look at the benchmarks, the current code in the pull request gives me:

julia> @btime iterate(r, 2)
  47.447 ns (1 allocation: 32 bytes)

while the suggested code gives me:

julia> @btime iterate(r, 2)
  224.165 ns (2 allocations: 48 bytes)

Member

My instinct is that the suggested code is better. I will take a look into why it is slower.

Member

Poking at this, it looks like something funky is happening where the compiler decides not to constant-fold the state in iterate.
On julia-master I get no allocations and near-instant timings for both.

But when timing [x for x in $r], the getindex code has 1 allocation and the iterate-on-backing code has 2 allocations.

Weird.
But I guess this is fine.

Contributor Author

Ok, I've used the iterate approach then.

Member

I mean the getindex approach is fine.

@gxyd
Contributor Author

gxyd commented Oct 20, 2020

Hi @nickrobinson251 can you please review this :)

@oxinabox
Member

julia> r = Composite{Tuple{Float64,}}(a=(@thunk 2.0^2), b=(@thunk 2.0^3))
julia> collect(r)
MethodError: Cannot `convert` an object of type 
  Float64 to an object of type 
  Thunk

I think this is because we are now defining the eltype wrong.
Rather than special casing collect, we might need to special case HasEltype so that if the composite backing type has thunks, it says that it is EltypeUnknown.
This is more involved than I thought.

We should probably also have tests of something that uses iterate more directly than collect.
Maybe a little function for a for loop?
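
A rough sketch of what such a test could look like, written against toy stand-ins so it runs without this package. `ToyThunk`, `ToyComposite`, and `loop_values` are hypothetical names made up for illustration; they are not ChainRulesCore types, and the real `Composite` has more machinery than this.

```julia
using Test

# Hypothetical stand-ins for illustration only, NOT the ChainRulesCore types.
struct ToyThunk{F}
    f::F
end
unthunk(t::ToyThunk) = t.f()   # evaluate the delayed computation
unthunk(x) = x                 # anything else passes through unchanged

struct ToyComposite{B}
    backing::B
end

# Delegate to the backing and unthunk only the element, keeping the state as-is.
function Base.iterate(comp::ToyComposite, args...)
    out = iterate(comp.backing, args...)
    out === nothing && return nothing
    return (unthunk(out[1]), out[2])
end

# A little function that exercises `iterate` through a plain `for` loop,
# so the test does not depend on `collect` or on `eltype`.
function loop_values(itr)
    seen = Any[]
    for x in itr
        push!(seen, x)
    end
    return seen
end

@test loop_values(ToyComposite((ToyThunk(() -> 2.0^2), 3.0))) == [4.0, 3.0]
@test loop_values(ToyComposite((a=ToyThunk(() -> 2.0^2), b=ToyThunk(() -> 2.0^3)))) == [4.0, 8.0]
```

Because the loop only calls `iterate`, this kind of test would keep passing or failing independently of whatever is decided for `collect`.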

@@ -65,6 +67,8 @@ end
# Testing iterate via collect
@test collect(Composite{Foo}(x=2.5)) == [2.5]
@test collect(Composite{Tuple{Float64,}}(2.0)) == [2.0]
@test collect(Float64, Composite{Tuple{Float64,}}(@thunk 2.0^2)) == [4.0]
@test collect(Float64, Composite{Tuple{Float64,}}(a=(@thunk 2.0^2),)) == [4.0]
Member

We don't have tests for dictionary-backed composites.
It is probably OK to continue not having them here for now.

Though I suspect iterating them will be weird, since they iterate pairs, and only the value, not the whole pair, will need to be unthunked.
So perhaps it is worth making sure we handle them right in this PR.
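
To make the pair concern concrete, here is a minimal sketch using toy types. `ToyThunk` and `ToyDictComposite` are hypothetical names for illustration, not the package's types, and this is only one possible way the dictionary case could behave.

```julia
# Hypothetical stand-ins for illustration only, NOT the ChainRulesCore types.
struct ToyThunk{F}
    f::F
end
unthunk(t::ToyThunk) = t.f()
unthunk(x) = x

struct ToyDictComposite{D<:AbstractDict}
    backing::D
end

# A Dict iterates `key => value` pairs, so only the value half is unthunked
# and the pair is rebuilt around it. The state is passed through untouched.
function Base.iterate(comp::ToyDictComposite, args...)
    out = iterate(comp.backing, args...)
    out === nothing && return nothing
    (key, value), state = out
    return (key => unthunk(value), state)
end

comp = ToyDictComposite(Dict(:a => ToyThunk(() -> 2.0^2), :b => 3.0))
result = Dict{Symbol,Float64}()
for (k, v) in comp
    result[k] = v   # v is already a plain Float64 here
end
```

Calling `unthunk` on the whole pair would hit the pass-through method and leave the thunk inside it untouched, which is why the value has to be picked out first.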

Contributor Author

Though I suspect iterating them will be weird, since they iterate pairs, and only the value, not the whole pair, will need to be unthunked.

Agreed.

We don't have tests for dictionary-backed composites.
It is probably OK to continue not having them here for now.

I rather think it would be better to have the implementation for Composite{<:Dict, <:Dict} here, since the iterate approach might work better in general for both Dicts and Tuples: the next state for a Composite of Tuples is state + 1, but that may not be so for a Dict.

WDYT, shall I try that in this PR?

Contributor Author

We usually get element, state = iterate(collection), so we would get the state for free by using the iterate approach, but not from the getindex approach.
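
For reference, a small sketch of that pattern in plain Base Julia. `drain` is a hypothetical helper name; it spells out by hand roughly what a `for` loop does with the iteration protocol.

```julia
# Manual version of `for element in collection`, using the
# `element, state = iterate(collection)` pattern.
function drain(collection)
    elements = Any[]
    next = iterate(collection)
    while next !== nothing
        element, state = next
        push!(elements, element)
        # The state is handed back exactly as received; the caller never
        # computes it, so it does not matter whether it is `state + 1` or not.
        next = iterate(collection, state)
    end
    return elements
end

tuple_elements = drain((1.0, 2.0, 3.0))
dict_elements = drain(Dict(:a => 1.0))
```

The same loop works for the tuple and for the Dict because the state is treated as opaque, which is the property the iterate approach would inherit from the backing.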

@gxyd
Contributor Author

gxyd commented Oct 25, 2020

What about something like:

julia> a = Composite{Tuple{Thunk}}((@thunk 2.0^3), (@thunk 2.0^2))

Is it supposed to give unthunked values on iteration?

@oxinabox
Member

What about something like:

julia> a = Composite{Tuple{Thunk}}((@thunk 2.0^3), (@thunk 2.0^2))

Is it supposed to give unthunked values on iteration?

Yes. (Assuming you mean Tuple{Thunk, Thunk}, since there are two arguments.)
Though Tuple{Thunk, Thunk} is a super weird primal type.
I'm not sure if it should even occur during second-order differentiation.

But since a Thunk, at the end of the day, actually represents the value it returns on being unthunked, and should be operated on as such, it is always correct and desirable to unthunk it before manipulating it.

@gxyd
Contributor Author

gxyd commented Oct 25, 2020

Rather than special casing collect, we might need to special case HasEltype so that if the composite backing type has thunks, it says that it is EltypeUnknown

Did you mean to say, special casing Base.IteratorEltype? Base.HasEltype doesn't even accept any arguments.

@oxinabox
Member

Did you mean to say, special casing Base.IteratorEltype?

Yes
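
A minimal sketch of that idea on toy types, for anyone following along. `ToyThunk` and `ToyComposite` are hypothetical stand-ins, not the ChainRulesCore types, and as a simplifying assumption the trait here is declared unconditionally rather than only when the backing contains thunks.

```julia
# Hypothetical stand-ins for illustration only, NOT the ChainRulesCore types.
struct ToyThunk{F}
    f::F
end
unthunk(t::ToyThunk) = t.f()
unthunk(x) = x

struct ToyComposite{B}
    backing::B
end

Base.length(comp::ToyComposite) = length(comp.backing)

# Iteration yields unthunked values, so the backing's eltype no longer
# describes what comes out.
function Base.iterate(comp::ToyComposite, args...)
    out = iterate(comp.backing, args...)
    out === nothing && return nothing
    return (unthunk(out[1]), out[2])
end

# Assumption for this sketch: always report the element type as unknown.
# `collect` then sizes and types its output from the values it actually sees.
Base.IteratorEltype(::Type{<:ToyComposite}) = Base.EltypeUnknown()

r = ToyComposite((a=ToyThunk(() -> 2.0^2), b=ToyThunk(() -> 2.0^3)))
collected = collect(r)
```

With the trait set this way, `collect(r)` no longer tries to store a `Float64` into an array typed by the backing's element type. A conditional version that keeps `Base.HasEltype()` when no thunks are present would need to inspect the backing type and is not shown here.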

Successfully merging this pull request may close these issues.

getindex and iterate on Composite should unthunk
2 participants