Skip to content

Commit

Permalink
adjust to the upstream irinterp refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
aviatesk committed Apr 5, 2023
1 parent 411a5f9 commit af1b39a
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 51 deletions.
18 changes: 12 additions & 6 deletions src/codegen/forward_demand.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,10 @@ Internal method which generates the code for forward mode diffentiation
- `ir` the IR being differnetation
- `to_diff`: collection of all SSA values for which the derivative is to be taken,
- `to_diff`: collection of all SSA values for which the derivative is to be taken,
paired with the order (first deriviative, second derivative etc)
- `visit_custom!(ir, stmt, order::Int, recurse::Bool)`:
- `visit_custom!(ir, stmt, order::Int, recurse::Bool)`:
decides if the custom `transform!` should be applied to a `stmt` or not
Default: `false` for all statements
- `transform!(ir, ssa::SSAValue, order::Int)` mutates `ir` to do a custom tranformation.
Expand Down Expand Up @@ -289,10 +289,12 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue, Int}};
end


function forward_diff!(ir::IRCode, interp, mi::MethodInstance, world, to_diff::Vector{Pair{SSAValue, Int}}; kwargs...)
function forward_diff!(interp::ADInterpreter, ir::IRCode, src::CodeInfo, mi::MethodInstance,
to_diff::Vector{Pair{SSAValue, Int}}; kwargs...)
forward_diff_no_inf!(ir, to_diff; kwargs...)

# Step 3: Re-inference

ir = compact!(ir)

extra_reprocess = CC.BitSet()
Expand All @@ -302,9 +304,13 @@ function forward_diff!(ir::IRCode, interp, mi::MethodInstance, world, to_diff::V
end
end

interp′ = enable_reinference(interp)
irsv = IRInterpretationState(interp′, ir, mi, world, ir.argtypes[1:mi.def.nargs])
rt = CC._ir_abstract_constant_propagation(interp′, irsv; extra_reprocess)
method_info = CC.MethodInfo(src)
argtypes = ir.argtypes[1:mi.def.nargs]
world = CC.get_world_counter(interp)
irsv = IRInterpretationState(interp, method_info, ir, mi, argtypes, world, src.min_world, src.max_world)
rt = CC._ir_abstract_constant_propagation(enable_reinference(interp), irsv; extra_reprocess)

ir = compact!(ir)

return ir
end
9 changes: 4 additions & 5 deletions src/stage2/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ function dontuse_nth_order_forward_stage2(tt::Type, order::Int=1)
interp = ADInterpreter(; forward=true, backward=false)
match = Base._which(tt)
frame = Core.Compiler.typeinf_frame(interp, match.method, match.spec_types, match.sparams, #=run_optimizer=#true)
mi = frame.linfo

ir = copy((interp.opt[0][frame.linfo].inferred).ir::IRCode)
src = CC.copy(interp.unopt[0][mi].src)
ir = CC.copy((@atomic :monotonic interp.opt[0][mi].inferred).ir::IRCode)

# Find all Return Nodes
vals = Pair{SSAValue, Int}[]
Expand Down Expand Up @@ -43,10 +45,7 @@ function dontuse_nth_order_forward_stage2(tt::Type, order::Int=1)
return insert_node!(ir, SSAValue(1), NewInstruction(Expr(:call, ∂xⁿ{order}(), arg), typeof(∂xⁿ{order}()(1.0))))
end

ir = forward_diff!(interp, ir, src, mi, vals; visit_custom!, transform!)

irsv = CC.IRInterpretationState(interp, ir, frame.linfo, CC.get_world_counter(interp), ir.argtypes[1:frame.linfo.def.nargs])
ir = forward_diff!(ir, interp, frame.linfo, CC.get_world_counter(interp), vals; visit_custom!, transform!)

ir = compact!(ir)
return OpaqueClosure(ir)
end
79 changes: 39 additions & 40 deletions src/stage2/interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ function Compiler3.get_codeinstance(graph::ADGraph, cursor::ADCursor)
end
=#

using Core.Compiler: AbstractInterpreter, NativeInterpreter, InferenceState,
InferenceResult, CodeInstance, WorldRange, ArgInfo, StmtInfo
using Core: MethodInstance, CodeInstance
using .CC: AbstractInterpreter, ArgInfo, Effects, InferenceResult, InferenceState,
IRInterpretationState, NativeInterpreter, OptimizationState, StmtInfo, WorldRange

const OptCache = Dict{MethodInstance, CodeInstance}
const UnoptCache = Dict{Union{MethodInstance, InferenceResult}, Cthulhu.InferredSource}
Expand Down Expand Up @@ -120,7 +121,7 @@ function Cthulhu.lookup(interp::ADInterpreter, curs::ADCursor, optimize::Bool; a
opt = codeinst.inferred
if opt !== nothing
opt = opt::Cthulhu.OptimizedSource
src = Core.Compiler.copy(opt.ir)
src = CC.copy(opt.ir)
codeinf = opt.src
infos = src.stmts.info
slottypes = src.argtypes
Expand Down Expand Up @@ -162,7 +163,6 @@ function Cthulhu.custom_toggles(interp::ADInterpreter)
end

# TODO: Something is going very wrong here
using Core.Compiler: Effects, OptimizationState
function Cthulhu.get_effects(interp::ADInterpreter, mi::MethodInstance, opt::Bool)
if haskey(interp.unopt[0], mi)
return interp.unopt[0][mi].effects
Expand All @@ -171,7 +171,7 @@ function Cthulhu.get_effects(interp::ADInterpreter, mi::MethodInstance, opt::Boo
end
end

function Core.Compiler.is_same_frame(interp::ADInterpreter, linfo::MethodInstance, frame::InferenceState)
function CC.is_same_frame(interp::ADInterpreter, linfo::MethodInstance, frame::InferenceState)
linfo === frame.linfo || return false
return interp.current_level === frame.interp.current_level
end
Expand Down Expand Up @@ -224,7 +224,7 @@ function Cthulhu.navigate(curs::ADCursor, callsite::Cthulhu.Callsite)
return ADCursor(curs.level, Cthulhu.get_mi(callsite))
end

function Cthulhu.process_info(interp::ADInterpreter, @nospecialize(info::Core.Compiler.CallInfo), argtypes::Cthulhu.ArgTypes, @nospecialize(rt), optimize::Bool)
function Cthulhu.process_info(interp::ADInterpreter, @nospecialize(info::CC.CallInfo), argtypes::Cthulhu.ArgTypes, @nospecialize(rt), optimize::Bool)
if isa(info, RecurseInfo)
newargtypes = argtypes[2:end]
callinfos = Cthulhu.process_info(interp, info.info, newargtypes, Cthulhu.unwrapType(widenconst(rt)), optimize)
Expand Down Expand Up @@ -252,33 +252,33 @@ function Cthulhu.process_info(interp::ADInterpreter, @nospecialize(info::Core.Co
elseif isa(info, CompClosInfo)
return Any[CompClosCallInfo(rt)]
end
return invoke(Cthulhu.process_info, Tuple{AbstractInterpreter, Core.Compiler.CallInfo, Cthulhu.ArgTypes, Any, Bool},
return invoke(Cthulhu.process_info, Tuple{AbstractInterpreter, CC.CallInfo, Cthulhu.ArgTypes, Any, Bool},
interp, info, argtypes, rt, optimize)
end

Core.Compiler.InferenceParams(ei::ADInterpreter) = InferenceParams(ei.native_interpreter)
Core.Compiler.OptimizationParams(ei::ADInterpreter) = OptimizationParams(ei.native_interpreter)
Core.Compiler.get_world_counter(ei::ADInterpreter) = get_world_counter(ei.native_interpreter)
Core.Compiler.get_inference_cache(ei::ADInterpreter) = get_inference_cache(ei.native_interpreter)
CC.InferenceParams(ei::ADInterpreter) = InferenceParams(ei.native_interpreter)
CC.OptimizationParams(ei::ADInterpreter) = OptimizationParams(ei.native_interpreter)
CC.get_world_counter(ei::ADInterpreter) = get_world_counter(ei.native_interpreter)
CC.get_inference_cache(ei::ADInterpreter) = get_inference_cache(ei.native_interpreter)

# No need to do any locking since we're not putting our results into the runtime cache
Core.Compiler.lock_mi_inference(ei::ADInterpreter, mi::MethodInstance) = nothing
Core.Compiler.unlock_mi_inference(ei::ADInterpreter, mi::MethodInstance) = nothing
CC.lock_mi_inference(ei::ADInterpreter, mi::MethodInstance) = nothing
CC.unlock_mi_inference(ei::ADInterpreter, mi::MethodInstance) = nothing

struct CodeInfoView
d::Dict{MethodInstance, Any}
end

function Core.Compiler.code_cache(ei::ADInterpreter)
function CC.code_cache(ei::ADInterpreter)
while ei.current_level > lastindex(ei.opt)
push!(ei.opt, Dict{MethodInstance, Any}())
end
ei.opt[ei.current_level]
end
Core.Compiler.may_optimize(ei::ADInterpreter) = true
Core.Compiler.may_compress(ei::ADInterpreter) = false
Core.Compiler.may_discard_trees(ei::ADInterpreter) = false
function Core.Compiler.get(view::CodeInfoView, mi::MethodInstance, default)
CC.may_optimize(ei::ADInterpreter) = true
CC.may_compress(ei::ADInterpreter) = false
CC.may_discard_trees(ei::ADInterpreter) = false
function CC.get(view::CodeInfoView, mi::MethodInstance, default)
r = get(view.d, mi, nothing)
if r === nothing
return default
Expand All @@ -298,23 +298,23 @@ end
Cthulhu.get_remarks(interp::ADInterpreter, key::Union{MethodInstance,InferenceResult}) = get(interp.remarks[interp.current_level], key, nothing)

#=
function Core.Compiler.const_prop_heuristic(interp::AbstractInterpreter, method::Method, mi::MethodInstance)
function CC.const_prop_heuristic(interp::AbstractInterpreter, method::Method, mi::MethodInstance)
return true
end
=#

function Core.Compiler.finish(state::InferenceState, interp::ADInterpreter)
res = @invoke Core.Compiler.finish(state::InferenceState, interp::AbstractInterpreter)
key = Core.Compiler.any(state.result.overridden_by_const) ? state.result : state.linfo
function CC.finish(state::InferenceState, interp::ADInterpreter)
res = @invoke CC.finish(state::InferenceState, interp::AbstractInterpreter)
key = CC.any(state.result.overridden_by_const) ? state.result : state.linfo
interp.unopt[interp.current_level][key] = Cthulhu.InferredSource(
copy(state.src),
copy(state.stmt_info),
isdefined(Core.Compiler, :Effects) ? state.ipo_effects : nothing,
state.ipo_effects,
state.result.result)
return res
end

function Core.Compiler.transform_result_for_cache(interp::ADInterpreter,
function CC.transform_result_for_cache(interp::ADInterpreter,
linfo::MethodInstance, valid_worlds::WorldRange, result::InferenceResult)
return Cthulhu.create_cthulhu_source(result.src, result.ipo_effects)
end
Expand All @@ -325,28 +325,27 @@ function CC.inlining_policy(interp::ADInterpreter,
if isa(info, FRuleCallInfo)
return nothing
end
if isdefined(CC, :SemiConcreteResult) && isa(src, CC.SemiConcreteResult)
if isa(src, CC.SemiConcreteResult)
return src
end
@assert isa(src, Cthulhu.OptimizedSource) || isnothing(src)
if isa(src, Cthulhu.OptimizedSource)
if CC.is_stmt_inline(stmt_flag) || src.isinlineable
return src.ir
end
else
# the default inlining policy may try additional effor to find the source in a local cache
return @invoke CC.inlining_policy(interp::AbstractInterpreter,
nothing, info::CC.CallInfo, stmt_flag::UInt8, mi::MethodInstance, argtypes::Vector{Any})
return nothing
end
return nothing
# the default inlining policy may try additional effor to find the source in a local cache
return @invoke CC.inlining_policy(interp::AbstractInterpreter,
nothing, info::CC.CallInfo, stmt_flag::UInt8, mi::MethodInstance, argtypes::Vector{Any})
end

function dummy() end
const dummym = first(methods(dummy))

function CC.abstract_call_gf_by_type(interp::ADInterpreter, @nospecialize(f),
arginfo::ArgInfo, si::StmtInfo, @nospecialize(atype),
sv::IRCode, max_methods::Int)
sv::IRInterpretationState, max_methods::Int)

if interp.reinference
# Create a dummy inference state to serve as the root
Expand All @@ -359,41 +358,41 @@ function CC.abstract_call_gf_by_type(interp::ADInterpreter, @nospecialize(f),
return r
end

return CallMeta(Any, CC.Effects(), CC.NoCallInfo())
return CallMeta(Any, Effects(), CC.NoCallInfo())
end

#=
function Core.Compiler.optimize(interp::ADInterpreter, opt::OptimizationState,
function CC.optimize(interp::ADInterpreter, opt::OptimizationState,
params::OptimizationParams, caller::InferenceResult)
# TODO: Enable some amount of inlining
#@timeit "optimizer" ir = run_passes(opt.src, opt, caller)
sv = opt
ci = opt.src
ir = Core.Compiler.convert_to_ircode(ci, sv)
ir = Core.Compiler.slot2reg(ir, ci, sv)
ir = CC.convert_to_ircode(ci, sv)
ir = CC.slot2reg(ir, ci, sv)
# TODO: Domsorting can produce an updated domtree - no need to recompute here
ir = Core.Compiler.compact!(ir)
return Core.Compiler.finish(interp, opt, params, ir, caller)
ir = CC.compact!(ir)
return CC.finish(interp, opt, params, ir, caller)
end
=#

function Core.Compiler.finish!(interp::ADInterpreter, caller::InferenceResult)
function CC.finish!(interp::ADInterpreter, caller::InferenceResult)
effects = caller.ipo_effects
caller.src = Cthulhu.create_cthulhu_source(caller.src, effects)
end

function ir2codeinst(ir::IRCode, inst::CodeInstance, ci::CodeInfo)
CodeInstance(inst.def, inst.rettype, isdefined(inst, :rettype_const) ? inst.rettype_const : nothing,
Cthulhu.OptimizedSource(Core.Compiler.copy(ir), ci, inst.inferred.isinlineable, Core.Compiler.decode_effects(inst.purity_bits)),
Cthulhu.OptimizedSource(CC.copy(ir), ci, inst.inferred.isinlineable, CC.decode_effects(inst.purity_bits)),
Int32(0), inst.min_world, inst.max_world, inst.ipo_purity_bits, inst.purity_bits,
inst.argescapes, inst.relocatability)
end

using Core: OpaqueClosure
function codegen(interp::ADInterpreter, curs::ADCursor, cache=Dict{ADCursor, OpaqueClosure}())
ir = Core.Compiler.copy(Cthulhu.get_optimized_codeinst(interp, curs).inferred.ir)
ir = CC.copy(Cthulhu.get_optimized_codeinst(interp, curs).inferred.ir)
codeinst = interp.opt[curs.level][curs.mi]
ci = codeinst.inferred.src
if curs.level >= 1
Expand Down

0 comments on commit af1b39a

Please sign in to comment.