Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

callsite union splitting #17212

Merged
merged 3 commits into from
Jul 2, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 107 additions & 3 deletions base/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ const MAX_TUPLETYPE_LEN = 15
const MAX_TUPLE_DEPTH = 4

const MAX_TUPLE_SPLAT = 16
const MAX_UNION_SPLITTING = 6

# alloc_elim_pass! relies on `Slot_AssignedOnce | Slot_UsedUndef` being
# SSA. This should be true now but can break if we start to track conditional
Expand Down Expand Up @@ -2359,6 +2360,16 @@ function inline_as_constant(val::ANY, argexprs, sv)
return (QuoteNode(val), stmts)
end

function countunionsplit(atypes::Vector{Any})
nu = 1
for ti in atypes
if isa(ti, Union)
nu *= length((ti::Union).types)
end
end
return nu
end

# inline functions whose bodies are "inline_worthy"
# where the function body doesn't contain any argument more than once.
# static parameters are ok if all the static parameter values are leaf types,
Expand Down Expand Up @@ -2413,13 +2424,106 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference
return NF
end

atype_unlimited = argtypes_to_type(atypes)
local atype_unlimited = argtypes_to_type(atypes)
function invoke_NF()
# converts a :call to :invoke
cache_linfo = ccall(:jl_get_spec_lambda, Any, (Any,), atype_unlimited)
if cache_linfo !== nothing
local nu = countunionsplit(atypes)
nu > MAX_UNION_SPLITTING && return NF

if nu > 1
local spec_hit = nothing
local spec_miss = nothing
local error_label = nothing
local linfo_var = add_slot!(enclosing, LambdaInfo, false)
local ex = copy(e)
local stmts = []
for i = 1:length(atypes); local i
local ti = atypes[i]
if isa(ti, Union)
aei = ex.args[i]
if !effect_free(aei, sv, false)
newvar = newvar!(sv, ti)
push!(stmts, Expr(:(=), newvar, aei))
ex.args[i] = newvar
end
end
end
function splitunion(atypes::Vector{Any}, i::Int)
if i == 0
local sig = argtypes_to_type(atypes)
local li = ccall(:jl_get_spec_lambda, Any, (Any,), sig)
li === nothing && return false
local stmt = []
push!(stmt, Expr(:(=), linfo_var, li))
spec_hit === nothing && (spec_hit = genlabel(sv))
push!(stmt, GotoNode(spec_hit.label))
return stmt
else
local ti = atypes[i]
if isa(ti, Union)
local all = true
local stmts = []
local aei = ex.args[i]
for ty in (ti::Union).types; local ty
atypes[i] = ty
local match = splitunion(atypes, i - 1)
if match !== false
after = genlabel(sv)
unshift!(match, Expr(:gotoifnot, Expr(:call, GlobalRef(Core, :isa), aei, ty), after.label))
append!(stmts, match)
push!(stmts, after)
else
all = false
end
end
if all
error_label === nothing && (error_label = genlabel(sv))
push!(stmts, GotoNode(error_label.label))
else
spec_miss === nothing && (spec_miss = genlabel(sv))
push!(stmts, GotoNode(spec_miss.label))
end
atypes[i] = ti
return isempty(stmts) ? false : stmts
else
return splitunion(atypes, i - 1)
end
end
end
local match = splitunion(atypes, length(atypes))
if match !== false && spec_hit !== nothing
append!(stmts, match)
if error_label !== nothing
push!(stmts, error_label)
push!(stmts, Expr(:call, GlobalRef(_topmod(sv.mod), :error), "error in type inference due to #265"))
Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"...issue #265"? (This is effectively "jargon.")

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

was not addressed

end
local ret_var, merge
if spec_miss !== nothing
ret_var = add_slot!(enclosing, ex.typ, false)
merge = genlabel(sv)
push!(stmts, spec_miss)
push!(stmts, Expr(:(=), ret_var, ex))
push!(stmts, GotoNode(merge.label))
else
ret_var = newvar!(sv, ex.typ)
end
push!(stmts, spec_hit)
ex = copy(ex)
ex.head = :invoke
unshift!(ex.args, linfo_var)
push!(stmts, Expr(:(=), ret_var, ex))
if spec_miss !== nothing
push!(stmts, merge)
end
#println(stmts)
return (ret_var, stmts)
end
else
local cache_linfo = ccall(:jl_get_spec_lambda, Any, (Any,), atype_unlimited)
cache_linfo === nothing && return NF
e.head = :invoke
unshift!(e.args, cache_linfo)
return e
end
return NF
end
Expand Down
4 changes: 2 additions & 2 deletions base/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ function _methods_by_ftype(t::ANY, lim)
if 1 < nu <= 64
return _methods(Any[tp...], length(tp), lim, [])
end
# TODO: the following can return incorrect answers that the above branch would have corrected
# XXX: the following can return incorrect answers that the above branch would have corrected
return ccall(:jl_matching_methods, Any, (Any,Cint,Cint), t, lim, 0)
end
function _methods(t::Array,i,lim::Integer,matching::Array{Any,1})
Expand All @@ -206,7 +206,7 @@ function _methods(t::Array,i,lim::Integer,matching::Array{Any,1})
for ty in (ti::Union).types
t[i] = ty
if _methods(t,i-1,lim,matching) === false
t[i] = ty
t[i] = ti
return false
end
end
Expand Down
9 changes: 9 additions & 0 deletions base/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,15 @@ show_unquoted(io::IO, ex::LabelNode, ::Int, ::Int) = print(io, ex.label, ":
show_unquoted(io::IO, ex::GotoNode, ::Int, ::Int) = print(io, "goto ", ex.label)
show_unquoted(io::IO, ex::GlobalRef, ::Int, ::Int) = print(io, ex.mod, '.', ex.name)

function show_unquoted(io::IO, ex::LambdaInfo, ::Int, ::Int)
if isdefined(ex, :specTypes)
print(io, "LambdaInfo for ")
show_lambda_types(io, ex.specTypes.parameters)
else
show(io, ex)
end
end

function show_unquoted(io::IO, ex::Slot, ::Int, ::Int)
typ = isa(ex,TypedSlot) ? ex.typ : Any
slotid = ex.id
Expand Down
2 changes: 1 addition & 1 deletion src/cgutils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ static jl_cgval_t emit_typeof(const jl_cgval_t &p, jl_codectx_t *ctx)
{
// given p, compute its type
if (!p.constant && p.isboxed && !jl_is_leaf_type(p.typ)) {
return mark_julia_type(emit_typeof(p.V), true, jl_datatype_type, ctx);
return mark_julia_type(emit_typeof(p.V), true, jl_datatype_type, ctx, /*needsroot*/false);
}
jl_value_t *aty = p.typ;
if (jl_is_type_type(aty)) // convert Int::Type{Int} ==> typeof(Int) ==> DataType
Expand Down
57 changes: 29 additions & 28 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2700,38 +2700,39 @@ static jl_cgval_t emit_invoke(jl_expr_t *ex, jl_codectx_t *ctx)
size_t arglen = jl_array_dim0(ex->args);
size_t nargs = arglen - 1;
assert(arglen >= 2);
jl_lambda_info_t *li = (jl_lambda_info_t*)args[0];
assert(jl_is_lambda_info(li));

if (li->jlcall_api == 2) {
assert(li->constval);
return mark_julia_const(li->constval);
}
else if (li->functionObjectsDecls.functionObject == NULL) {
assert(!li->inCompile);
if (li->code == jl_nothing && !li->inInference && li->inferred) {
// XXX: it was inferred in the past, so it's almost valid to re-infer it now
jl_type_infer(li, 0);
jl_cgval_t lival = emit_expr(args[0], ctx);
if (lival.constant) {
jl_lambda_info_t *li = (jl_lambda_info_t*)lival.constant;
assert(jl_is_lambda_info(li));
if (li->jlcall_api == 2) {
assert(li->constval);
return mark_julia_const(li->constval);
}
if (li->functionObjectsDecls.functionObject == NULL) {
assert(!li->inCompile);
if (li->code == jl_nothing && !li->inInference && li->inferred) {
// XXX: it was inferred in the past, so it's almost valid to re-infer it now
jl_type_infer(li, 0);
}
if (!li->inInference && li->inferred && li->code != jl_nothing) {
jl_compile_linfo(li);
}
}
if (!li->inInference && li->inferred && li->code != jl_nothing) {
jl_compile_linfo(li);
Value *theFptr = (Value*)li->functionObjectsDecls.functionObject;
if (theFptr && li->jlcall_api == 0) {
jl_cgval_t fval = emit_expr(args[1], ctx);
jl_cgval_t result = emit_call_function_object(li, fval, theFptr, &args[1], nargs - 1, (jl_value_t*)ex, ctx);
if (result.typ == jl_bottom_type)
CreateTrap(builder);
return result;
}
}
Value *theFptr = (Value*)li->functionObjectsDecls.functionObject;
jl_cgval_t result;
if (theFptr && li->jlcall_api == 0) {
jl_cgval_t fval = emit_expr(args[1], ctx);
result = emit_call_function_object(li, fval, theFptr, &args[1], nargs - 1, (jl_value_t*)ex, ctx);
}
else {
result = mark_julia_type(emit_jlcall(prepare_call(jlinvoke_func), literal_pointer_val((jl_value_t*)li),
&args[1], nargs, ctx),
true, expr_type((jl_value_t*)ex, ctx), ctx);
}

if (result.typ == jl_bottom_type) {
jl_cgval_t result = mark_julia_type(emit_jlcall(prepare_call(jlinvoke_func), boxed(lival, ctx, false),
&args[1], nargs, ctx),
true, expr_type((jl_value_t*)ex, ctx), ctx);
if (result.typ == jl_bottom_type)
CreateTrap(builder);
}
return result;
}

Expand Down Expand Up @@ -3997,7 +3998,7 @@ static Function *gen_jlcall_wrapper(jl_lambda_info_t *lam, Function *f, bool sre
bool retboxed;
(void)julia_type_to_llvm(jlretty, &retboxed);
if (sret) { assert(!retboxed); }
jl_cgval_t retval = sret ? mark_julia_slot(result, jlretty, tbaa_stack) : mark_julia_type(call, retboxed, jlretty, &ctx);
jl_cgval_t retval = sret ? mark_julia_slot(result, jlretty, tbaa_stack) : mark_julia_type(call, retboxed, jlretty, &ctx, /*needsroot*/false);
builder.CreateRet(boxed(retval, &ctx, false)); // no gcroot needed since this on the return path

return w;
Expand Down
31 changes: 23 additions & 8 deletions src/debuginfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -483,8 +483,8 @@ class JuliaJITEventListener: public JITEventListener
else
SectionAddrCheck = SectionLoadAddr;
create_PRUNTIME_FUNCTION(
(uint8_t*)(intptr_t)Addr, (size_t)Size, sName,
(uint8_t*)(intptr_t)SectionLoadAddr, (size_t)SectionSize, UnwindData);
(uint8_t*)(uintptr_t)Addr, (size_t)Size, sName,
(uint8_t*)(uintptr_t)SectionLoadAddr, (size_t)SectionSize, UnwindData);
#endif
StringMap<jl_lambda_info_t*>::iterator linfo_it = linfo_in_flight.find(sName);
jl_lambda_info_t *linfo = NULL;
Expand Down Expand Up @@ -559,8 +559,8 @@ class JuliaJITEventListener: public JITEventListener
else
SectionAddrCheck = SectionLoadAddr;
create_PRUNTIME_FUNCTION(
(uint8_t*)(intptr_t)Addr, (size_t)Size, sName,
(uint8_t*)(intptr_t)SectionLoadAddr, (size_t)SectionSize, UnwindData);
(uint8_t*)(uintptr_t)Addr, (size_t)Size, sName,
(uint8_t*)(uintptr_t)SectionLoadAddr, (size_t)SectionSize, UnwindData);
#endif
StringMap<jl_lambda_info_t*>::iterator linfo_it = linfo_in_flight.find(sName);
jl_lambda_info_t *linfo = NULL;
Expand Down Expand Up @@ -1256,7 +1256,7 @@ int jl_getFunctionInfo(jl_frame_t **frames_out, size_t pointer, int skipC, int n
// Without MCJIT we use the FuncInfo structure containing address maps
std::map<size_t, FuncInfo, revcomp> &info = jl_jit_events->getMap();
std::map<size_t, FuncInfo, revcomp>::iterator it = info.lower_bound(pointer);
if (it != info.end() && (intptr_t)(*it).first + (*it).second.lengthAdr >= pointer) {
if (it != info.end() && (uintptr_t)(*it).first + (*it).second.lengthAdr >= pointer) {
// We do this to hide the jlcall wrappers when getting julia backtraces,
// but it is still good to have them for regular lookup of C frames.
if (skipC && (*it).second.lines.empty()) {
Expand Down Expand Up @@ -1330,6 +1330,21 @@ int jl_getFunctionInfo(jl_frame_t **frames_out, size_t pointer, int skipC, int n
return jl_getDylibFunctionInfo(frames_out, pointer, skipC, noInline);
}

extern "C" jl_lambda_info_t *jl_gdblookuplinfo(void *p)
{
#ifndef USE_MCJIT
std::map<size_t, FuncInfo, revcomp> &info = jl_jit_events->getMap();
std::map<size_t, FuncInfo, revcomp>::iterator it = info.lower_bound((size_t)p);
jl_lambda_info_t *li = NULL;
if (it != info.end() && (uintptr_t)(*it).first + (*it).second.lengthAdr >= (uintptr_t)p)
li = (*it).second.linfo;
uv_rwlock_rdunlock(&threadsafe);
return li;
#else
return jl_jit_events->lookupLinfo((size_t)p);
#endif
}

#if defined(LLVM37) && (defined(_OS_LINUX_) || (defined(_OS_DARWIN_) && defined(LLVM_SHLIB)))
extern "C" void __register_frame(void*);
extern "C" void __deregister_frame(void*);
Expand Down Expand Up @@ -1745,7 +1760,7 @@ uint64_t jl_getUnwindInfo(uint64_t dwAddr)
std::map<size_t, ObjectInfo, revcomp>::iterator it = objmap.lower_bound(dwAddr);
uint64_t ipstart = 0; // ip of the start of the section (if found)
if (it != objmap.end() && dwAddr < it->first + it->second.SectionSize) {
ipstart = (uint64_t)(intptr_t)(*it).first;
ipstart = (uint64_t)(uintptr_t)(*it).first;
}
uv_rwlock_rdunlock(&threadsafe);
return ipstart;
Expand All @@ -1758,8 +1773,8 @@ uint64_t jl_getUnwindInfo(uint64_t dwAddr)
std::map<size_t, FuncInfo, revcomp> &info = jl_jit_events->getMap();
std::map<size_t, FuncInfo, revcomp>::iterator it = info.lower_bound(dwAddr);
uint64_t ipstart = 0; // ip of the first instruction in the function (if found)
if (it != info.end() && (intptr_t)(*it).first + (*it).second.lengthAdr > dwAddr) {
ipstart = (uint64_t)(intptr_t)(*it).first;
if (it != info.end() && (uintptr_t)(*it).first + (*it).second.lengthAdr > dwAddr) {
ipstart = (uint64_t)(uintptr_t)(*it).first;
}
uv_rwlock_rdunlock(&threadsafe);
return ipstart;
Expand Down