Add rule for cumprod
#420
Timing these today, there isn't in fact a clear winner: it depends on the size, and on how many zeros are encountered. This approach uses less memory. Some numbers (on two computers):
julia> @btime gradient(x -> sum(cumprod(x, dims=1)), $(rand(10,100)))[1];
6.625 μs (10 allocations: 16.01 KiB) # m1 mac + rosetta, Julia 1.6
20.720 μs (21 allocations: 16.49 KiB) # xeon
julia> x=rand(10,100); x[1:21:end].=0; # half the columns have a zero
julia> @btime gradient(x -> sum(cumprod(x, dims=1)), $x)[1];
6.883 μs (8 allocations: 15.98 KiB) # m1
18.045 μs (21 allocations: 16.49 KiB) # xeon
julia> @btime gradient(x -> sum(cumprod(x, dims=1)), $(rand(3,100)))[1]; # fewer rows
1.625 μs (8 allocations: 5.10 KiB) # m1
7.330 μs (21 allocations: 5.62 KiB) # xeon
julia> @btime gradient(x -> sum(cumprod(x, dims=1)), $(rand(30,100)))[1]; # more rows
36.750 μs (10 allocations: 47.13 KiB) # m1
138.382 μs (23 allocations: 47.65 KiB) # xeon
Compare to FluxML/Zygote.jl#294 (which could probably be optimised a bit)
The last case,
Edit, after 70334fc, which adds a simpler path for the case of no zeros (like the Zygote PR):
julia> @btime gradient(x -> sum(cumprod(x, dims=1)), $(rand(10,100)))[1];
4.863 μs (6 allocations: 23.86 KiB) # m1
15.388 μs (19 allocations: 24.38 KiB) # xeon
julia> @btime gradient(x -> sum(cumprod(x, dims=1)), $(rand(3,100)))[1]; # fewer rows
2.000 μs (6 allocations: 7.55 KiB)
8.765 μs (19 allocations: 8.06 KiB)
julia> @btime gradient(x -> sum(cumprod(x, dims=1)), $(rand(30,100)))[1]; # more rows
12.125 μs (9 allocations: 70.59 KiB)
37.687 μs (22 allocations: 71.11 KiB)
Edit', after b75f94f, the original path is now faster than the fast path was:
julia> @btime gradient(x -> sum(cumprod(x, dims=1)), $(rand(10,100)))[1];
3.036 μs (8 allocations: 15.98 KiB)
11.527 μs (21 allocations: 16.49 KiB)
julia> @btime gradient(x -> sum(cumprod(x, dims=1)), $x)[1]; # half with a zero
2.866 μs (8 allocations: 15.98 KiB)
10.990 μs (21 allocations: 16.49 KiB)
julia> @btime gradient(x -> sum(cumprod(x, dims=1)), $(rand(3,100)))[1]; # fewer rows
1.471 μs (8 allocations: 5.10 KiB)
6.963 μs (21 allocations: 5.62 KiB)
julia> @btime gradient(x -> sum(cumprod(x, dims=1)), $(rand(30,100)))[1]; # more rows
7.903 μs (10 allocations: 47.13 KiB)
25.996 μs (23 allocations: 47.65 KiB)
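For readers following the "no zeros" discussion, here is a minimal sketch of the idea for a plain vector. It is not the code from this PR or from the Zygote PR; the function names `cumprod_pullback_nozeros` and `cumprod_pullback_reference` are hypothetical, and 1-based `Vector` input is assumed. The fast path divides by `x`, so it is only valid when no element is zero; the slow reference leaves each factor out instead of dividing.

```julia
# Hypothetical helpers for illustration only; not the rule added in this PR.

# Fast path, valid only when no element of x is zero:
# dx[j] = (sum over i >= j of dy[i] * y[i]) / x[j], where y = cumprod(x).
function cumprod_pullback_nozeros(x::Vector, dy::Vector)
    y = cumprod(x)
    return reverse(cumsum(reverse(dy .* y))) ./ x
end

# Slow O(n^3) reference that also handles zeros: each partial product
# prod(x[1:i]) is differentiated by omitting x[j], never by dividing.
function cumprod_pullback_reference(x::Vector, dy::Vector)
    dx = zeros(float(eltype(x)), length(x))
    for j in 1:length(x)
        for i in j:length(x)
            p = one(eltype(dx))
            for k in 1:i
                k == j || (p *= x[k])   # skip the factor being differentiated
            end
            dx[j] += dy[i] * p
        end
    end
    return dx
end
```

On `[1, 2, 3]` with cotangent `fill(0.5, 3)` both functions should agree with the value `[9/2, 2, 1]` used in the PR's test. On an input containing a zero, such as `[2.0, 0.0, 3.0]`, the fast path yields a `NaN` from `0/0`, which is why a separate treatment of zeros is needed.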
Codecov Report
@@            Coverage Diff             @@
##           master     #420      +/-   ##
==========================================
- Coverage   98.51%   98.46%   -0.06%
==========================================
  Files          21       21
  Lines        2094     2148      +54
==========================================
+ Hits         2063     2115      +52
- Misses         31       33       +2
@testset "types" begin
    back = unthunk(rrule(cumprod, [1, 2, 3])[2]) # allow integer input
    @test unthunk(back(fill(0.5, 3))[2]) == [9/2, 2, 1]
Why are we testing values here?
It would be good to add comments explaining what we are particularly checking for that test_rrule will not catch.
test_rrule doesn't seem to accept integer input; this tests that the rule still does.
julia> test_rrule(cumprod, [1,2,3])
test_rrule: cumprod on Vector{Int64}: Error During Test at /Users/me/.julia/packages/ChainRulesTestUtils/6oOem/src/testers.jl:227
Got exception outside of a @test
InexactError: Int64(0.99)
Add a comment to that effect?
Yeah, you can't test methods that require integers with finite differencing, since once you apply a finite difference you get a Float64 instead, which means you don't hit the method for integers.
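A small sketch of why integer input and finite differencing don't mix. This is not how test_rrule is implemented internally; the step size `h` and the way the perturbation is applied are assumptions chosen only to reproduce the same kind of `InexactError` shown above.

```julia
x = [1, 2, 3]        # Vector{Int64}
h = 0.01             # hypothetical finite-difference step

# Writing a perturbed value back into the integer array fails,
# because 0.99 cannot be converted to Int64.
err = try
    x[1] = x[1] - h
    nothing
catch e
    e
end
@show err isa InexactError

# Promoting to floats first works, but then cumprod (and any rule for it)
# receives a Vector{Float64}, so the integer code path is never exercised.
xf = float.(x)
xf[1] -= h
@show eltype(cumprod(xf))
```

This is why the PR checks the integer case against hand-computed values instead of relying on test_rrule.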
The comment now says "# rule allows integer input, but test_rrule does not"
Good to go? No rush except that something may change under it...
Sorry, I lost track of this one.
LGTM
Closes #254.
The approach here is much like that in #335 for prod, and quite different to FluxML/Zygote.jl#294.