Skip to content

Commit

Permalink
Add pragma(LDC_musttail)
Browse files Browse the repository at this point in the history
  • Loading branch information
MrSmith33 committed Apr 6, 2023
1 parent 0d4d711 commit f5ba4a4
Show file tree
Hide file tree
Showing 21 changed files with 138 additions and 68 deletions.
4 changes: 4 additions & 0 deletions dmd/expression.d
Original file line number Diff line number Diff line change
Expand Up @@ -5077,6 +5077,10 @@ extern (C++) final class CallExp : UnaExp
bool directcall; // true if a virtual call is devirtualized
bool inDebugStatement; /// true if this was in a debug statement
bool ignoreAttributes; /// don't enforce attributes (e.g. call @gc function in @nogc code)
version (IN_LLVM)
{
bool isMustTail; // If marked with pragma(musttail)
}
VarDeclaration vthis2; // container for multi-context

extern (D) this(const ref Loc loc, Expression e, Expressions* exps)
Expand Down
3 changes: 3 additions & 0 deletions dmd/expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -863,6 +863,9 @@ class CallExp final : public UnaExp
bool directcall; // true if a virtual call is devirtualized
bool inDebugStatement; // true if this was in a debug statement
bool ignoreAttributes; // don't enforce attributes (e.g. call @gc function in @nogc code)
#if IN_LLVM
bool isMustTail; // If marked with pragma(musttail)
#endif
VarDeclaration *vthis2; // container for multi-context

static CallExp *create(const Loc &loc, Expression *e, Expressions *exps);
Expand Down
1 change: 1 addition & 0 deletions dmd/id.d
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,7 @@ immutable Msgtable[] msgtable =
{ "LDC_global_crt_dtor" },
{ "LDC_extern_weak" },
{ "LDC_profile_instr" },
{ "musttail" },

// IN_LLVM: LDC-specific traits
{ "targetCPU" },
Expand Down
1 change: 1 addition & 0 deletions dmd/id.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ struct Id
static Identifier *LDC_inline_ir;
static Identifier *LDC_extern_weak;
static Identifier *LDC_profile_instr;
static Identifier *musttail;
static Identifier *dcReflect;
static Identifier *opencl;
static Identifier *criticalenter;
Expand Down
35 changes: 35 additions & 0 deletions dmd/statementsem.d
Original file line number Diff line number Diff line change
Expand Up @@ -2135,6 +2135,10 @@ else
return setError();
}
}
else if (ps.ident == Id.musttail)
{
pragmaMustTailSemantic(ps);
}
else if (!global.params.ignoreUnsupportedPragmas)
{
ps.error("unrecognized `pragma(%s)`", ps.ident.toChars());
Expand All @@ -2153,6 +2157,37 @@ else
result = ps._body;
}

private void pragmaMustTailSemantic(PragmaStatement ps)
{
if (!ps._body)
{
ps.error("`pragma(musttail)` must be attached to a return statement");
return setError();
}

auto rs = ps._body.isReturnStatement();
if (!rs)
{
ps.error("`pragma(musttail)` must be attached to a return statement");
return setError();
}

if (!rs.exp)
{
ps.error("`pragma(musttail)` must be attached to a return statement returning result of a function call");
return setError();
}

auto ce = rs.exp.isCallExp();
if (!ce)
{
ps.error("`pragma(musttail)` must be attached to a return statement returning result of a function call");
return setError();
}

ce.isMustTail = true;
}

override void visit(StaticAssertStatement s)
{
s.sa.semantic2(sc);
Expand Down
10 changes: 5 additions & 5 deletions gen/aa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,11 @@ DLValue *DtoAAIndex(const Loc &loc, Type *type, DValue *aa, DValue *key,
DtoTypeInfoOf(loc, aa->type->unSharedOf()->mutableOf(), /*base=*/false);
LLValue *castedAATI = DtoBitCast(rawAATI, funcTy->getParamType(1));
LLValue *valsize = DtoConstSize_t(getTypeAllocSize(DtoType(type)));
ret = gIR->CreateCallOrInvoke(func, aaval, castedAATI, valsize, pkey,
ret = gIR->CreateCallOrInvoke(loc, func, aaval, castedAATI, valsize, pkey,
"aa.index");
} else {
LLValue *keyti = to_keyti(loc, aa, funcTy->getParamType(1));
ret = gIR->CreateCallOrInvoke(func, aaval, keyti, pkey, "aa.index");
ret = gIR->CreateCallOrInvoke(loc, func, aaval, keyti, pkey, "aa.index");
}

// cast return value
Expand Down Expand Up @@ -130,7 +130,7 @@ DValue *DtoAAIn(const Loc &loc, Type *type, DValue *aa, DValue *key) {
pkey = DtoBitCast(pkey, getVoidPtrType());

// call runtime
LLValue *ret = gIR->CreateCallOrInvoke(func, aaval, keyti, pkey, "aa.in");
LLValue *ret = gIR->CreateCallOrInvoke(loc, func, aaval, keyti, pkey, "aa.in");

// cast return value
LLType *targettype = DtoType(type);
Expand Down Expand Up @@ -174,7 +174,7 @@ DValue *DtoAARemove(const Loc &loc, DValue *aa, DValue *key) {
pkey = DtoBitCast(pkey, funcTy->getParamType(2));

// call runtime
LLValue *res = gIR->CreateCallOrInvoke(func, aaval, keyti, pkey);
LLValue *res = gIR->CreateCallOrInvoke(loc, func, aaval, keyti, pkey);

return new DImValue(Type::tbool, res);
}
Expand All @@ -192,7 +192,7 @@ LLValue *DtoAAEquals(const Loc &loc, EXP op, DValue *l, DValue *r) {
LLValue *abval = DtoBitCast(DtoRVal(r), funcTy->getParamType(2));
LLValue *aaTypeInfo = DtoTypeInfoOf(loc, t);
LLValue *res =
gIR->CreateCallOrInvoke(func, aaTypeInfo, aaval, abval, "aaEqRes");
gIR->CreateCallOrInvoke(loc, func, aaTypeInfo, aaval, abval, "aaEqRes");

const auto predicate = eqTokToICmpPred(op, /* invert = */ true);
res = gIR->ir->CreateICmp(predicate, res, DtoConstInt(0));
Expand Down
20 changes: 10 additions & 10 deletions gen/arrays.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ static void copySlice(const Loc &loc, LLValue *dstarr, LLValue *dstlen,
if (checksEnabled && !knownInBounds) {
LLFunction *fn = getRuntimeFunction(loc, gIR->module, "_d_array_slice_copy");
gIR->CreateCallOrInvoke(
fn, {dstarr, dstlen, srcarr, srclen, DtoConstSize_t(elementSize)}, "",
loc, fn, {dstarr, dstlen, srcarr, srclen, DtoConstSize_t(elementSize)}, "",
/*isNothrow=*/true);
} else {
// We might have dstarr == srcarr at compile time, but as long as
Expand Down Expand Up @@ -271,7 +271,7 @@ void DtoArrayAssign(const Loc &loc, DValue *lhs, DValue *rhs, EXP op,
loc, gIR->module,
!canSkipPostblit ? "_d_arrayassign_l" : "_d_arrayassign_r");
gIR->CreateCallOrInvoke(
fn, DtoTypeInfoOf(loc, elemType), DtoSlice(rhsPtr, rhsLength, getI8Type()),
loc, fn, DtoTypeInfoOf(loc, elemType), DtoSlice(rhsPtr, rhsLength, getI8Type()),
DtoSlice(lhsPtr, lhsLength, getI8Type()), DtoBitCast(tmpSwap, getVoidPtrType()));
}
} else {
Expand Down Expand Up @@ -305,7 +305,7 @@ void DtoArrayAssign(const Loc &loc, DValue *lhs, DValue *rhs, EXP op,
LLFunction *fn =
getRuntimeFunction(loc, gIR->module, "_d_arraysetassign");
gIR->CreateCallOrInvoke(
fn, lhsPtr, DtoBitCast(makeLValue(loc, rhs), getVoidPtrType()),
loc, fn, lhsPtr, DtoBitCast(makeLValue(loc, rhs), getVoidPtrType()),
gIR->ir->CreateTruncOrBitCast(lhsLength,
LLType::getInt32Ty(gIR->context())),
DtoTypeInfoOf(loc, stripModifiers(t2)));
Expand Down Expand Up @@ -672,7 +672,7 @@ DSliceValue *DtoNewDynArray(const Loc &loc, Type *arrayType, DValue *dim,

// call allocator
LLValue *newArray =
gIR->CreateCallOrInvoke(fn, arrayTypeInfo, arrayLen, ".gc_mem");
gIR->CreateCallOrInvoke(loc, fn, arrayTypeInfo, arrayLen, ".gc_mem");

// return a DSliceValue with the well-known length for better optimizability
auto ptr =
Expand Down Expand Up @@ -741,7 +741,7 @@ DSliceValue *DtoNewMulDimDynArray(const Loc &loc, Type *arrayType,

// call allocator
LLValue *newptr =
gIR->CreateCallOrInvoke(fn, arrayTypeInfo, DtoLoad(dtype, darray), ".gc_mem");
gIR->CreateCallOrInvoke(loc, fn, arrayTypeInfo, DtoLoad(dtype, darray), ".gc_mem");

IF_LOG Logger::cout() << "final ptr = " << *newptr << '\n';

Expand Down Expand Up @@ -769,7 +769,7 @@ DSliceValue *DtoResizeDynArray(const Loc &loc, Type *arrayType, DValue *array,
: "_d_arraysetlengthiT");

LLValue *newArray = gIR->CreateCallOrInvoke(
fn, DtoTypeInfoOf(loc, arrayType), newdim,
loc, fn, DtoTypeInfoOf(loc, arrayType), newdim,
DtoBitCast(DtoLVal(array), fn->getFunctionType()->getParamType(2)),
".gc_mem");

Expand Down Expand Up @@ -871,7 +871,7 @@ DSliceValue *DtoCatArrays(const Loc &loc, Type *arrayType, Expression *exp1,
args.push_back(loadArray(exp2,2));
}

auto newArray = gIR->CreateCallOrInvoke(fn, args, ".appendedArray");
auto newArray = gIR->CreateCallOrInvoke(loc, fn, args, ".appendedArray");
return getSlice(arrayType, newArray);
}

Expand All @@ -886,7 +886,7 @@ DSliceValue *DtoAppendDChar(const Loc &loc, DValue *arr, Expression *exp,

// Call function (ref string x, dchar c)
LLValue *newArray = gIR->CreateCallOrInvoke(
fn, DtoBitCast(DtoLVal(arr), fn->getFunctionType()->getParamType(0)),
loc, fn, DtoBitCast(DtoLVal(arr), fn->getFunctionType()->getParamType(0)),
DtoBitCast(valueToAppend, fn->getFunctionType()->getParamType(1)),
".appendedArray");

Expand Down Expand Up @@ -942,7 +942,7 @@ LLValue *DtoArrayEqCmp_impl(const Loc &loc, const char *func, DValue *l,
args.push_back(DtoBitCast(tival, fn->getFunctionType()->getParamType(2)));
}

return gIR->CreateCallOrInvoke(fn, args);
return gIR->CreateCallOrInvoke(loc, fn, args);
}

/// When `true` is returned, the type can be compared using `memcmp`.
Expand Down Expand Up @@ -1324,7 +1324,7 @@ static void emitRangeErrorImpl(IRState *irs, const Loc &loc,
args.push_back(DtoModuleFileName(module, loc));
args.push_back(DtoConstUint(loc.linnum));
args.insert(args.end(), extraArgs.begin(), extraArgs.end());
irs->CreateCallOrInvoke(fn, args);
irs->CreateCallOrInvoke(loc, fn, args);
irs->ir->CreateUnreachable();
break;
}
Expand Down
8 changes: 4 additions & 4 deletions gen/classes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ DValue *DtoNewClass(const Loc &loc, TypeClass *tc, NewExp *newexp) {
LLConstant *ci =
DtoBitCast(irClass->getClassInfoSymbol(), DtoType(getClassInfoType()));
mem = gIR->CreateCallOrInvoke(
fn, ci, useEHAlloc ? ".newthrowable_alloc" : ".newclass_gc_alloc");
loc, fn, ci, useEHAlloc ? ".newthrowable_alloc" : ".newclass_gc_alloc");
mem = DtoBitCast(mem, DtoType(tc),
useEHAlloc ? ".newthrowable" : ".newclass_gc");
doInit = !useEHAlloc;
Expand Down Expand Up @@ -183,7 +183,7 @@ void DtoFinalizeClass(const Loc &loc, LLValue *inst) {
getRuntimeFunction(loc, gIR->module, "_d_callfinalizer");

gIR->CreateCallOrInvoke(
fn, DtoBitCast(inst, fn->getFunctionType()->getParamType(0)), "");
loc, fn, DtoBitCast(inst, fn->getFunctionType()->getParamType(0)), "");
}

////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -378,7 +378,7 @@ DValue *DtoDynamicCastObject(const Loc &loc, DValue *val, Type *_to) {
assert(funcTy->getParamType(1) == cinfo->getType());

// call it
LLValue *ret = gIR->CreateCallOrInvoke(func, obj, cinfo);
LLValue *ret = gIR->CreateCallOrInvoke(loc, func, obj, cinfo);

// cast return value
ret = DtoBitCast(ret, DtoType(_to));
Expand Down Expand Up @@ -412,7 +412,7 @@ DValue *DtoDynamicCastInterface(const Loc &loc, DValue *val, Type *_to) {
cinfo = DtoBitCast(cinfo, funcTy->getParamType(1));

// call it
LLValue *ret = gIR->CreateCallOrInvoke(func, ptr, cinfo);
LLValue *ret = gIR->CreateCallOrInvoke(loc, func, ptr, cinfo);

// cast return value
ret = DtoBitCast(ret, DtoType(_to));
Expand Down
3 changes: 2 additions & 1 deletion gen/dpragma.d
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ extern (C++) enum LDCPragma : int {
LLVMbitop_bts,
LLVMbitop_vld,
LLVMbitop_vst,
LLVMextern_weak
LLVMextern_weak,
LLVMmusttail,
};

extern (C++) LDCPragma DtoGetPragma(Scope* sc, PragmaDeclaration decl, ref const(char)* arg1str);
Expand Down
13 changes: 11 additions & 2 deletions gen/funcgenstate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include "gen/funcgenstate.h"

#include "dmd/errors.h"
#include "dmd/identifier.h"
#include "gen/llvm.h"
#include "gen/llvmhelpers.h"
Expand Down Expand Up @@ -103,10 +104,10 @@ FuncGenState::FuncGenState(IrFunction &irFunc, IRState &irs)
: irFunc(irFunc), scopes(irs), jumpTargets(scopes), switchTargets(),
irs(irs) {}

LLCallBasePtr FuncGenState::callOrInvoke(llvm::Value *callee,
LLCallBasePtr FuncGenState::callOrInvoke(const Loc &loc, llvm::Value *callee,
llvm::FunctionType *calleeType,
llvm::ArrayRef<llvm::Value *> args,
const char *name, bool isNothrow) {
const char *name, bool isNothrow, bool isMustTail) {
// If this is a direct call, we might be able to use the callee attributes
// to our advantage.
llvm::Function *calleeFn = llvm::dyn_cast<llvm::Function>(callee);
Expand Down Expand Up @@ -135,9 +136,17 @@ LLCallBasePtr FuncGenState::callOrInvoke(llvm::Value *callee,
if (calleeFn) {
call->setAttributes(calleeFn->getAttributes());
}
if (isMustTail) {
call->setTailCallKind(llvm::CallInst::TCK_MustTail);
}
return call;
}

if (isMustTail) {
error(loc, "cannot perform tail-call, there is code after call");
fatal();
}

llvm::BasicBlock *landingPad = scopes.getLandingPad();

llvm::BasicBlock *postinvoke = irs.insertBB("postinvoke");
Expand Down
5 changes: 3 additions & 2 deletions gen/funcgenstate.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,10 +199,11 @@ class FuncGenState {

/// Emits a call or invoke to the given callee, depending on whether there
/// are catches/cleanups active or not.
LLCallBasePtr callOrInvoke(llvm::Value *callee,
LLCallBasePtr callOrInvoke(const Loc &loc, llvm::Value *callee,
llvm::FunctionType *calleeType,
llvm::ArrayRef<llvm::Value *> args,
const char *name = "", bool isNothrow = false);
const char *name = "", bool isNothrow = false,
bool isMustTail = false);

private:
IRState &irs;
Expand Down
27 changes: 14 additions & 13 deletions gen/irstate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,43 +79,44 @@ llvm::BasicBlock *IRState::insertBB(const llvm::Twine &name) {
return insertBBAfter(scopebb(), name);
}

llvm::Instruction *IRState::CreateCallOrInvoke(LLFunction *Callee,
llvm::Instruction *IRState::CreateCallOrInvoke(const Loc &loc, LLFunction *Callee,
const char *Name) {
return CreateCallOrInvoke(Callee, {}, Name);
return CreateCallOrInvoke(loc, Callee, {}, Name);
}

llvm::Instruction *IRState::CreateCallOrInvoke(LLFunction *Callee,
llvm::Instruction *IRState::CreateCallOrInvoke(const Loc &loc, LLFunction *Callee,
llvm::ArrayRef<LLValue *> Args,
const char *Name,
bool isNothrow) {
return funcGen().callOrInvoke(Callee, Callee->getFunctionType(), Args, Name,
isNothrow);
return funcGen().callOrInvoke(loc, Callee, Callee->getFunctionType(), Args,
Name, isNothrow);
}

llvm::Instruction *IRState::CreateCallOrInvoke(LLFunction *Callee,
llvm::Instruction *IRState::CreateCallOrInvoke(const Loc &loc,
LLFunction *Callee,
LLValue *Arg1,
const char *Name) {
return CreateCallOrInvoke(Callee, llvm::ArrayRef<LLValue *>(Arg1), Name);
return CreateCallOrInvoke(loc, Callee, llvm::ArrayRef<LLValue *>(Arg1), Name);
}

llvm::Instruction *IRState::CreateCallOrInvoke(LLFunction *Callee,
llvm::Instruction *IRState::CreateCallOrInvoke(const Loc &loc, LLFunction *Callee,
LLValue *Arg1, LLValue *Arg2,
const char *Name) {
return CreateCallOrInvoke(Callee, {Arg1, Arg2}, Name);
return CreateCallOrInvoke(loc, Callee, {Arg1, Arg2}, Name);
}

llvm::Instruction *IRState::CreateCallOrInvoke(LLFunction *Callee,
llvm::Instruction *IRState::CreateCallOrInvoke(const Loc &loc, LLFunction *Callee,
LLValue *Arg1, LLValue *Arg2,
LLValue *Arg3,
const char *Name) {
return CreateCallOrInvoke(Callee, {Arg1, Arg2, Arg3}, Name);
return CreateCallOrInvoke(loc, Callee, {Arg1, Arg2, Arg3}, Name);
}

llvm::Instruction *IRState::CreateCallOrInvoke(LLFunction *Callee,
llvm::Instruction *IRState::CreateCallOrInvoke(const Loc &loc, LLFunction *Callee,
LLValue *Arg1, LLValue *Arg2,
LLValue *Arg3, LLValue *Arg4,
const char *Name) {
return CreateCallOrInvoke(Callee, {Arg1, Arg2, Arg3, Arg4}, Name);
return CreateCallOrInvoke(loc, Callee, {Arg1, Arg2, Arg3, Arg4}, Name);
}

bool IRState::emitArrayBoundsChecks() {
Expand Down
Loading

0 comments on commit f5ba4a4

Please sign in to comment.