Skip to content

Commit

Permalink
Test 2-arg versions
Browse files Browse the repository at this point in the history
  • Loading branch information
sethaxen authored and vchuravy committed Apr 1, 2024
1 parent f075240 commit 87b217f
Showing 1 changed file with 48 additions and 2 deletions.
50 changes: 48 additions & 2 deletions test/rules/LinearAlgebra/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ using Test
elseif Tret <: DuplicatedNoNeed
@test only(ret) dexp
end

if pfun === identity && sz == n && inc == 1
@testset "consistency of 2-arg version" begin
ret2 = autodiff(Forward, fun, Tret, x_annot, y_annot)
@test ret2 == ret
end
end
end

@testset for Tret in (BatchDuplicated, BatchDuplicatedNoNeed),
Expand Down Expand Up @@ -89,6 +96,12 @@ using Test
y[1] = 0
return d
end
function fun_overwrite!(x, y)
d = fun(x, y)
x[1] = 0
y[1] = 0
return d
end

@testset for Tret in (Const, Active),
Tx in (Const, Duplicated),
Expand All @@ -106,9 +119,9 @@ using Test
xcopy, ycopy, ∂xcopy, ∂ycopy = map(copy, (x, y, ∂x, ∂y))

x_annot =
Tx <: Const ? Const(pfun(x)) : Duplicated(pfun(xcopy), pfun(∂xcopy))
Tx <: Const ? Const(pfun(xcopy)) : Duplicated(pfun(xcopy), pfun(∂xcopy))
y_annot =
Ty <: Const ? Const(pfun(y)) : Duplicated(pfun(ycopy), pfun(∂ycopy))
Ty <: Const ? Const(pfun(ycopy)) : Duplicated(pfun(ycopy), pfun(∂ycopy))
activities = (Const(n), x_annot, Const(inc), y_annot, Const(inc))

vexp = fun(n, x, inc, y, inc)
Expand Down Expand Up @@ -138,6 +151,39 @@ using Test
@test ∂ycopy
dexp[2] .* !(Ty <: Const || Tret <: Const) .+
∂y .* ((Ty <: Const) .| (y .== ycopy))

if pfun === identity && sz == n && inc == 1
@testset "consistency of 2-arg version" begin
xcopy2, ycopy2, ∂xcopy2, ∂ycopy2 = map(copy, (x, y, ∂x, ∂y))
x_annot = if Tx <: Const
Const(pfun(xcopy2))
else
Duplicated(pfun(xcopy2), pfun(∂xcopy2))
end
y_annot = if Ty <: Const
Const(pfun(ycopy2))
else
Duplicated(pfun(ycopy2), pfun(∂ycopy2))
end
activities = (x_annot, y_annot)
fwd, rev = autodiff_thunk(
ReverseSplitWithPrimal,
Const{typeof(f)},
Tret,
map(typeof, activities)...,
)
tape, val2, shadow_val = fwd(Const(f), activities...)
if Tret <: Const
dval2, = rev(Const(f), activities..., tape)
else
dval2, = rev(Const(f), activities..., dret, tape)
end
@test all(isnothing, dval2)
@test val2 == val
@test ∂xcopy2 == ∂xcopy
@test ∂ycopy2 == ∂ycopy
end
end
end
end
end
Expand Down

0 comments on commit 87b217f

Please sign in to comment.