Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add high-level support for architecture-agnostic automatic differentiation #101

Merged
merged 4 commits into from
Jul 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 48 additions & 6 deletions src/ParallelKernel/parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,17 @@ parallel_async(caller::Module, args::Union{Symbol,Expr}...; package::Symbol=get_

function parallel(caller::Module, args::Union{Symbol,Expr}...; package::Symbol=get_package(), async::Bool=false)
posargs, kwargs_expr, kernelarg = split_parallel_args(args)
kwargs, backend_kwargs_expr = extract_kwargs(caller, kwargs_expr, (:stream, :shmem, :launch, :configcall), "@parallel <kernelcall>", true; eval_args=(:launch,))
launch = haskey(kwargs, :launch) ? kwargs.launch : true
configcall = haskey(kwargs, :configcall) ? kwargs.configcall : kernelarg
if isgpu(package) parallel_call_gpu(posargs..., kernelarg, backend_kwargs_expr, async, package; kwargs...)
elseif (package == PKG_THREADS) parallel_call_threads(posargs..., kernelarg, async; launch=launch, configcall=configcall) # Ignore keyword args as they are not for the threads case (noted in doc).
else @KeywordArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package).")
kwargs, backend_kwargs_expr = extract_kwargs(caller, kwargs_expr, (:stream, :shmem, :launch, :configcall, :∇, :ad_mode, :ad_annotations), "@parallel <kernelcall>", true; eval_args=(:launch,))
is_ad_highlevel = haskey(kwargs, :∇)
launch = haskey(kwargs, :launch) ? kwargs.launch : true
configcall = haskey(kwargs, :configcall) ? kwargs.configcall : kernelarg
if is_ad_highlevel
parallel_call_ad(caller, kernelarg, backend_kwargs_expr, async, package, posargs, kwargs)
else
if isgpu(package) parallel_call_gpu(posargs..., kernelarg, backend_kwargs_expr, async, package; kwargs...)
elseif (package == PKG_THREADS) parallel_call_threads(posargs..., kernelarg, async; launch=launch, configcall=configcall) # Ignore keyword args as they are not for the threads case (noted in doc).
else @KeywordArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package).")
end
end
end

Expand Down Expand Up @@ -176,6 +181,43 @@ end

## @PARALLEL CALL FUNCTIONS

function parallel_call_ad(caller::Module, kernelcall::Expr, backend_kwargs_expr::Array, async::Bool, package::Symbol, posargs, kwargs)
ad_mode = haskey(kwargs, :ad_mode) ? kwargs.ad_mode : AD_MODE_DEFAULT
ad_annotations_expr = haskey(kwargs, :ad_annotations) ? extract_tuple(kwargs.ad_annotations; nested=true) : []
ad_vars_expr = extract_tuple(kwargs.∇; nested=true)
~, ~, ad_vars = extract_kwargs(caller, ad_vars_expr, (), "", true; separator=:->)
~, ~, ad_annotations = extract_kwargs(caller, ad_annotations_expr, (), "", true)
ad_vars = map(x->unblock(x), ad_vars)
ad_annotations = map(x->extract_tuple(x), ad_annotations)
f_name = extract_kernelcall_name(kernelcall)
f_posargs, ~ = extract_kernelcall_args(kernelcall)
ad_annotations_byvar = Dict(a => [] for a in f_posargs)
for (keyword, vars) in zip(keys(ad_annotations), values(ad_annotations))
if (keyword ∉ keys(AD_SUPPORTED_ANNOTATIONS)) @KeywordArgumentError("annotation $keyword is not (yet) supported with high-level syntax; use the generic syntax calling directly `autodiff_deferred!`.") end
for var in vars
if (ad_annotations_byvar[var] != []) @KeywordArgumentError("variable $var has more than one annotation. Nested annotations are not (yet) supported with high-level syntax; use the generic syntax calling directly `autodiff_deferred!`.") end
push!(ad_annotations_byvar[var], AD_SUPPORTED_ANNOTATIONS[keyword])
end
end
for var in keys(ad_vars)
if ad_annotations_byvar[var] == []
push!(ad_annotations_byvar[var], AD_DUPLICATE_DEFAULT)
end
end
for var in f_posargs
if ad_annotations_byvar[var] == []
push!(ad_annotations_byvar[var], AD_ANNOTATION_DEFAULT)
end
end
annotated_args = (:($(ad_annotations_byvar[var][1])($((var ∈ keys(ad_vars) ? (var, ad_vars[var]) : (var,))...))) for var in f_posargs)
ad_call = :(autodiff_deferred!($ad_mode, $f_name, $(annotated_args...)))
kwargs_remaining = filter(x->!(x in (:∇, :ad_mode, :ad_annotations)), keys(kwargs))
kwargs_remaining_expr = [:($key=$val) for (key,val) in kwargs_remaining]
if (async) return :( @parallel $(posargs...) $(backend_kwargs_expr...) $(kwargs_remaining_expr...) configcall=$kernelcall $ad_call ) #TODO: the package needs to be passed further here later.
else return :( @parallel_async $(posargs...) $(backend_kwargs_expr...) $(kwargs_remaining_expr...) configcall=$kernelcall $ad_call ) #...
end
end

function parallel_call_gpu(ranges::Union{Symbol,Expr}, nblocks::Union{Symbol,Expr}, nthreads::Union{Symbol,Expr}, kernelcall::Expr, backend_kwargs_expr::Array, async::Bool, package::Symbol; stream::Union{Symbol,Expr}=default_stream(package), shmem::Union{Symbol,Expr,Nothing}=nothing, launch::Bool=true, configcall::Expr=kernelcall)
ranges = :(ParallelStencil.ParallelKernel.promote_ranges($ranges))
if (package == PKG_CUDA) int_type = INT_CUDA
Expand Down
29 changes: 18 additions & 11 deletions src/ParallelKernel/shared.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ else
end
import Enzyme
using CellArrays, StaticArrays, MacroTools
import MacroTools: postwalk, splitdef, combinedef, isexpr # NOTE: inexpr_walk used instead of MacroTools.inexpr
import MacroTools: postwalk, splitdef, combinedef, isexpr, unblock # NOTE: inexpr_walk used instead of MacroTools.inexpr


## CONSTANTS AND TYPES (and the macros wrapping them)
Expand Down Expand Up @@ -58,6 +58,10 @@ const SUPPORTED_LITERALTYPES = [Float16, Float32, Float64, Complex{Fl
const SUPPORTED_NUMBERTYPES = [Float16, Float32, Float64, Complex{Float16}, Complex{Float32}, Complex{Float64}]
const PKNumber = Union{Float16, Float32, Float64, Complex{Float16}, Complex{Float32}, Complex{Float64}} # NOTE: this always needs to correspond to SUPPORTED_NUMBERTYPES!
const NUMBERTYPE_NONE = DataType
const AD_MODE_DEFAULT = :(Enzyme.Reverse)
const AD_DUPLICATE_DEFAULT = Enzyme.DuplicatedNoNeed
const AD_ANNOTATION_DEFAULT = Enzyme.Const
const AD_SUPPORTED_ANNOTATIONS = (Const=Enzyme.Const, Active=Enzyme.Active, Duplicated=Enzyme.Duplicated, DuplicatedNoNeed=Enzyme.DuplicatedNoNeed)
const ERRMSG_UNSUPPORTED_PACKAGE = "unsupported package for parallelization"
const ERRMSG_CHECK_PACKAGE = "package has to be functional and one of the following: $(join(SUPPORTED_PACKAGES,", "))"
const ERRMSG_CHECK_NUMBERTYPE = "numbertype has to be one of the following: $(join(SUPPORTED_NUMBERTYPES,", "))"
Expand Down Expand Up @@ -200,10 +204,11 @@ function extract_args(call::Expr, macroname::Symbol)
end

extract_kernelcall_args(call::Expr) = split_args(call.args[2:end]; in_kernelcall=true)
extract_kernelcall_name(call::Expr) = call.args[1]

function is_kwarg(arg; in_kernelcall=false)
function is_kwarg(arg; in_kernelcall=false, separator=:(=))
if in_kernelcall return ( isa(arg, Expr) && inexpr_walk(arg, :kw; match_only_head=true) )
else return ( isa(arg, Expr) && (arg.head == :(=)) && isa(arg.args[1], Symbol))
else return ( isa(arg, Expr) && (arg.head == separator) && isa(arg.args[1], Symbol))
end
end

Expand All @@ -220,8 +225,8 @@ function split_args(args; in_kernelcall=false)
return posargs, kwargs
end

function split_kwargs(kwargs)
if !all(is_kwarg.(kwargs)) @ModuleInternalError("not all of kwargs are keyword arguments.") end
function split_kwargs(kwargs; separator=:(=))
if !all(is_kwarg.(kwargs; separator=separator)) @ModuleInternalError("not all of kwargs are keyword arguments.") end
return Dict(x.args[1] => x.args[2] for x in kwargs)
end

Expand All @@ -241,16 +246,16 @@ function extract_kwargvalues(kwargs_expr, valid_kwargs, macroname)
return extract_values(kwargs, valid_kwargs)
end

function extract_kwargs(caller::Module, kwargs_expr, valid_kwargs, macroname, has_unknown_kwargs; eval_args=())
kwargs = split_kwargs(kwargs_expr)
function extract_kwargs(caller::Module, kwargs_expr, valid_kwargs, macroname, has_unknown_kwargs; eval_args=(), separator=:(=))
kwargs = split_kwargs(kwargs_expr, separator=separator)
if (!has_unknown_kwargs) validate_kwargkeys(kwargs, valid_kwargs, macroname) end
for k in keys(kwargs)
if (k in eval_args) kwargs[k] = eval_arg(caller, kwargs[k]) end
end
kwargs_known = NamedTuple(filter(x -> x.first ∈ valid_kwargs, kwargs))
kwargs_unknown = NamedTuple(filter(x -> x.first ∉ valid_kwargs, kwargs))
kwargs_unknown_expr = [:($k = $(kwargs_unknown[k])) for k in keys(kwargs_unknown)]
return kwargs_known, kwargs_unknown_expr
return kwargs_known, kwargs_unknown_expr, kwargs_unknown
end

function extract_kwargs(caller::Module, kwargs_expr, valid_kwargs, macroname; eval_args=())
Expand Down Expand Up @@ -314,9 +319,11 @@ inexpr_walk(expr, s::Symbol; match_only_head=false) = false

Base.unquoted(s::Symbol) = s

function extract_tuple(t::Union{Expr,Symbol}) # NOTE: this could return a tuple, but would require to change all small arrays to tuples...
if isa(t, Expr)
return Base.unquoted.(t.args)
function extract_tuple(t::Union{Expr,Symbol}; nested=false) # NOTE: this could return a tuple, but would require to change all small arrays to tuples...
if isa(t, Expr) && t.head == :tuple
if (nested) return t.args
else return Base.unquoted.(t.args)
end
else
return [t]
end
Expand Down
7 changes: 5 additions & 2 deletions src/parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,13 @@ function parallel(source::LineNumberNode, caller::Module, args::Union{Symbol,Exp
elseif is_call(args[end])
posargs, kwargs_expr, kernelarg = split_parallel_args(args)
kwargs, backend_kwargs_expr = extract_kwargs(caller, kwargs_expr, (:memopt, :configcall), "@parallel <kernelcall>", true; eval_args=(:memopt,))
memopt = haskey(kwargs, :memopt) ? kwargs.memopt : get_memopt()
memopt = haskey(kwargs, :memopt) ? kwargs.memopt : get_memopt()
configcall = haskey(kwargs, :configcall) ? kwargs.configcall : kernelarg
configcall_kwarg_expr = :(configcall=$configcall)
if memopt
is_ad_highlevel = haskey(kwargs, :∇)
if is_ad_highlevel
ParallelKernel.parallel_call_ad(caller, kernelarg, backend_kwargs_expr, async, package, posargs, kwargs)
elseif memopt
if (length(posargs) > 1) @ArgumentError("maximum one positional argument (ranges) is allowed in a @parallel memopt=true call.") end
parallel_call_memopt(caller, posargs..., kernelarg, backend_kwargs_expr, async; kwargs...)
else
Expand Down
7 changes: 7 additions & 0 deletions test/ParallelKernel/test_parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ end
@test @prettystring(1, @parallel stream=mystream f(A)) == "f(A, ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[3])))"
end;
end;
@testset "@parallel ∇" begin
@test @prettystring(1, @parallel ∇=B->B̄ f!(A, B, a)) == "@parallel_async configcall = f!(A, B, a) autodiff_deferred!(Enzyme.Reverse, f!, (EnzymeCore.Const)(A), (EnzymeCore.DuplicatedNoNeed)(B, B̄), (EnzymeCore.Const)(a))"
@test @prettystring(1, @parallel ∇=(A->Ā, B->B̄) f!(A, B, a)) == "@parallel_async configcall = f!(A, B, a) autodiff_deferred!(Enzyme.Reverse, f!, (EnzymeCore.DuplicatedNoNeed)(A, Ā), (EnzymeCore.DuplicatedNoNeed)(B, B̄), (EnzymeCore.Const)(a))"
@test @prettystring(1, @parallel ∇=(A->Ā, B->B̄) ad_mode=Enzyme.Forward f!(A, B, a)) == "@parallel_async configcall = f!(A, B, a) autodiff_deferred!(Enzyme.Forward, f!, (EnzymeCore.DuplicatedNoNeed)(A, Ā), (EnzymeCore.DuplicatedNoNeed)(B, B̄), (EnzymeCore.Const)(a))"
@test @prettystring(1, @parallel ∇=(A->Ā, B->B̄) ad_mode=Enzyme.Forward ad_annotations=(Duplicated=B) f!(A, B, a)) == "@parallel_async configcall = f!(A, B, a) autodiff_deferred!(Enzyme.Forward, f!, (EnzymeCore.DuplicatedNoNeed)(A, Ā), (EnzymeCore.Duplicated)(B, B̄), (EnzymeCore.Const)(a))"
@test @prettystring(1, @parallel ∇=(A->Ā, B->B̄) ad_mode=Enzyme.Forward ad_annotations=(Duplicated=(B,A), Active=b) f!(A, B, a, b)) == "@parallel_async configcall = f!(A, B, a, b) autodiff_deferred!(Enzyme.Forward, f!, (EnzymeCore.Duplicated)(A, Ā), (EnzymeCore.Duplicated)(B, B̄), (EnzymeCore.Const)(a), (EnzymeCore.Active)(b))"
end;
@testset "@parallel_indices" begin
@testset "addition of range arguments" begin
expansion = @gorgeousstring(1, @parallel_indices (ix,iy) f(a::T, b::T) where T <: Union{Array{Float32}, Array{Float64}} = (println("a=$a, b=$b)"); return))
Expand Down
Loading