Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SimplifyCFG] Supporting hoisting/sinking callbases with differing attrs #109472

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion llvm/include/llvm/IR/Attributes.h
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,15 @@ class Attribute {
/// Return true if the target-dependent attribute is present.
bool hasAttribute(StringRef Val) const;

/// Reports whether this attribute's kind is representable as an
/// Attribute::AttrKind enum value, i.e. whether it is an Enum, Integer, Type,
/// ConstantRange, or ConstantRangeList attribute.
bool hasKindAsEnum() const {
  if (isEnumAttribute() || isIntAttribute() || isTypeAttribute())
    return true;
  return isConstantRangeAttribute() || isConstantRangeListAttribute();
}

/// Return the attribute's kind as an enum (Attribute::AttrKind). This
/// requires the attribute to be an enum, integer, or type attribute.
/// requires the attribute be representable as an enum (see: `hasKindAsEnum`).
Attribute::AttrKind getKindAsEnum() const;

/// Return the attribute's value as an integer. This requires that the
Expand Down
13 changes: 8 additions & 5 deletions llvm/include/llvm/IR/Instruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -881,16 +881,19 @@ class Instruction : public User,
/// This is like isIdenticalTo, except that it ignores the
/// SubclassOptionalData flags, which may specify conditions under which the
/// instruction's result is undefined.
bool isIdenticalToWhenDefined(const Instruction *I) const LLVM_READONLY;
bool isIdenticalToWhenDefined(const Instruction *I,
bool IgnoreAttrs = false) const LLVM_READONLY;

/// When checking for operation equivalence (using isSameOperationAs) it is
/// sometimes useful to ignore certain attributes.
enum OperationEquivalenceFlags {
/// Check for equivalence ignoring load/store alignment.
CompareIgnoringAlignment = 1<<0,
CompareIgnoringAlignment = 1 << 0,
/// Check for equivalence treating a type and a vector of that type
/// as equivalent.
CompareUsingScalarTypes = 1<<1
CompareUsingScalarTypes = 1 << 1,
/// Check for equivalence ignoring callbase attrs.
CompareIgnoringAttrs = 1 << 2,
};

/// This function determines if the specified instruction executes the same
Expand All @@ -911,8 +914,8 @@ class Instruction : public User,
/// @returns true if the specific instruction has the same opcode specific
/// characteristics as the current one. Determine if one instruction has the
/// same state as another.
bool hasSameSpecialState(const Instruction *I2,
bool IgnoreAlignment = false) const LLVM_READONLY;
bool hasSameSpecialState(const Instruction *I2, bool IgnoreAlignment = false,
bool IgnoreAttrs = false) const LLVM_READONLY;

/// Return true if there are any uses of this instruction in blocks other than
/// the specified block. Note that PHI nodes are considered to evaluate their
Expand Down
3 changes: 1 addition & 2 deletions llvm/lib/IR/Attributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -362,8 +362,7 @@ bool Attribute::isConstantRangeListAttribute() const {

Attribute::AttrKind Attribute::getKindAsEnum() const {
if (!pImpl) return None;
assert((isEnumAttribute() || isIntAttribute() || isTypeAttribute() ||
isConstantRangeAttribute() || isConstantRangeListAttribute()) &&
assert(hasKindAsEnum() &&
"Invalid attribute type to get the kind as an enum!");
return pImpl->getKindAsEnum();
}
Expand Down
25 changes: 15 additions & 10 deletions llvm/lib/IR/Instruction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -785,7 +785,8 @@ const char *Instruction::getOpcodeName(unsigned OpCode) {
/// This must be kept in sync with FunctionComparator::cmpOperations in
/// lib/Transforms/IPO/MergeFunctions.cpp.
bool Instruction::hasSameSpecialState(const Instruction *I2,
bool IgnoreAlignment) const {
bool IgnoreAlignment,
bool IgnoreAttrs) const {
auto I1 = this;
assert(I1->getOpcode() == I2->getOpcode() &&
"Can not compare special state of different instructions");
Expand All @@ -811,15 +812,18 @@ bool Instruction::hasSameSpecialState(const Instruction *I2,
if (const CallInst *CI = dyn_cast<CallInst>(I1))
return CI->isTailCall() == cast<CallInst>(I2)->isTailCall() &&
CI->getCallingConv() == cast<CallInst>(I2)->getCallingConv() &&
CI->getAttributes() == cast<CallInst>(I2)->getAttributes() &&
(IgnoreAttrs ||
CI->getAttributes() == cast<CallInst>(I2)->getAttributes()) &&
CI->hasIdenticalOperandBundleSchema(*cast<CallInst>(I2));
if (const InvokeInst *CI = dyn_cast<InvokeInst>(I1))
return CI->getCallingConv() == cast<InvokeInst>(I2)->getCallingConv() &&
CI->getAttributes() == cast<InvokeInst>(I2)->getAttributes() &&
(IgnoreAttrs ||
CI->getAttributes() == cast<InvokeInst>(I2)->getAttributes()) &&
CI->hasIdenticalOperandBundleSchema(*cast<InvokeInst>(I2));
if (const CallBrInst *CI = dyn_cast<CallBrInst>(I1))
return CI->getCallingConv() == cast<CallBrInst>(I2)->getCallingConv() &&
CI->getAttributes() == cast<CallBrInst>(I2)->getAttributes() &&
(IgnoreAttrs ||
CI->getAttributes() == cast<CallBrInst>(I2)->getAttributes()) &&
CI->hasIdenticalOperandBundleSchema(*cast<CallBrInst>(I2));
if (const InsertValueInst *IVI = dyn_cast<InsertValueInst>(I1))
return IVI->getIndices() == cast<InsertValueInst>(I2)->getIndices();
Expand Down Expand Up @@ -857,10 +861,10 @@ bool Instruction::isIdenticalTo(const Instruction *I) const {
SubclassOptionalData == I->SubclassOptionalData;
}

bool Instruction::isIdenticalToWhenDefined(const Instruction *I) const {
bool Instruction::isIdenticalToWhenDefined(const Instruction *I,
bool IgnoreAttrs) const {
if (getOpcode() != I->getOpcode() ||
getNumOperands() != I->getNumOperands() ||
getType() != I->getType())
getNumOperands() != I->getNumOperands() || getType() != I->getType())
return false;

// If both instructions have no operands, they are identical.
Expand All @@ -879,15 +883,16 @@ bool Instruction::isIdenticalToWhenDefined(const Instruction *I) const {
otherPHI->block_begin());
}

return this->hasSameSpecialState(I);
return this->hasSameSpecialState(I, /*IgnoreAlignment=*/false, IgnoreAttrs);
}

// Keep this in sync with FunctionComparator::cmpOperations in
// lib/Transforms/IPO/MergeFunctions.cpp.
bool Instruction::isSameOperationAs(const Instruction *I,
unsigned flags) const {
bool IgnoreAlignment = flags & CompareIgnoringAlignment;
bool UseScalarTypes = flags & CompareUsingScalarTypes;
bool UseScalarTypes = flags & CompareUsingScalarTypes;
bool IgnoreAttrs = flags & CompareIgnoringAttrs;

if (getOpcode() != I->getOpcode() ||
getNumOperands() != I->getNumOperands() ||
Expand All @@ -905,7 +910,7 @@ bool Instruction::isSameOperationAs(const Instruction *I,
getOperand(i)->getType() != I->getOperand(i)->getType())
return false;

return this->hasSameSpecialState(I, IgnoreAlignment);
return this->hasSameSpecialState(I, IgnoreAlignment, IgnoreAttrs);
}

bool Instruction::isUsedOutsideOfBlock(const BasicBlock *BB) const {
Expand Down
165 changes: 162 additions & 3 deletions llvm/lib/Transforms/Utils/SimplifyCFG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1591,10 +1591,150 @@ static void hoistLockstepIdenticalDbgVariableRecords(
}
}

// Try to compute one attribute list that is valid for both CB0 and CB1, so
// the two callbases can be merged into a single one when hoisting/sinking.
//
// Returns the intersected AttributeList on success. Returns std::nullopt if
// some attribute differs between the two callbases and is not known to be safe
// to drop or weaken (or is not an enum-kind attribute, e.g. a string
// attribute).
//
// Note: At the moment the logic is only concerned with correctness (i.e. we
// can't sink a callbase with a `ByVal` attr on a param with one that doesn't
// have the attr). But there may be some attributes that are not preferable to
// drop, e.g. a certain Range attr might trivialize inlining, so intersecting
// it with a callbase without the attr might not be profitable.
static std::optional<AttributeList> tryIntersectAttrs(const CallBase *CB0,
                                                      const CallBase *CB1) {
  assert(CB0->getCalledFunction() == CB1->getCalledFunction() &&
         "Merging attrs for different functions!");

  AttributeList AL0 = CB0->getAttributes();
  AttributeList AL1 = CB1->getAttributes();

  // Trivial case if attributes match.
  if (AL0 == AL1)
    return AL0;

  // Otherwise go through all attributes present and make sure they either match
  // or that dropping them is okay.
  LLVMContext &Ctx = CB0->getContext();
  auto IntersectAttrs = [&Ctx](AttributeSet AS0,
                               AttributeSet AS1) -> std::optional<AttrBuilder> {
    AttrBuilder Intersected(Ctx);

    // `align` is normally just a fact about the pointer and may be weakened or
    // dropped. Combined with `byval`/`preallocated` it describes the stack
    // slot the call forms, i.e. it is part of the call ABI, so it must be
    // preserved exactly.
    const bool AlignMustMatch = AS0.hasAttribute(Attribute::ByVal) ||
                                AS1.hasAttribute(Attribute::ByVal) ||
                                AS0.hasAttribute(Attribute::Preallocated) ||
                                AS1.hasAttribute(Attribute::Preallocated);

    // Union of both sets: visits every attribute kind present on either side.
    AttributeSet AllAttrs = AS0.addAttributes(Ctx, AS1);
    for (Attribute Attr : AllAttrs) {
      if (!Attr.isValid())
        return std::nullopt;

      // Only supporting enum attrs for now.
      if (!Attr.hasKindAsEnum())
        return std::nullopt;

      Attribute::AttrKind Kind = Attr.getKindAsEnum();
      bool BothContain = AS0.hasAttribute(Kind) && AS1.hasAttribute(Kind);
      switch (Kind) {
      // Attributes that can safely be intersected and can safely be thrown
      // away.
      case Attribute::Cold:
      case Attribute::Hot:
      case Attribute::MustProgress:
      case Attribute::NoAlias:
      case Attribute::NoCallback:
      case Attribute::NoCapture:
      case Attribute::NoFree:
      case Attribute::NoRecurse:
      case Attribute::NoReturn:
      case Attribute::NoSync:
      case Attribute::NoUndef:
      case Attribute::NoUnwind:
      case Attribute::NonNull:
      case Attribute::OptimizeForSize:
      // TODO: We could merge ReadNone + Readonly -> ReadOnly
      case Attribute::ReadNone:
      case Attribute::ReadOnly:
      case Attribute::Returned:
      case Attribute::Speculatable:
      case Attribute::WillReturn:
      case Attribute::Writable:
      case Attribute::WriteOnly:
        if (BothContain)
          Intersected.addAttribute(Attr);
        break;
      // Alignment/Dereferenceable/DereferenceableOrNull/Memory/Range we can
      // safely throw out, but intersection requires us to compare the values
      // at hand.
      case Attribute::Alignment:
        if (AlignMustMatch) {
          // ABI-affecting alignment: both sides must carry the same value.
          // An absent attribute compares unequal to a present one.
          if (AS0.getAttribute(Kind) != AS1.getAttribute(Kind))
            return std::nullopt;
          Intersected.addAttribute(Attr);
          break;
        }
        if (BothContain)
          Intersected.addAlignmentAttr(
              std::min(AS0.getAlignment().valueOrOne(),
                       AS1.getAlignment().valueOrOne()));
        break;
      case Attribute::Dereferenceable:
        if (BothContain)
          Intersected.addDereferenceableAttr(std::min(
              AS0.getDereferenceableBytes(), AS1.getDereferenceableBytes()));
        break;
      case Attribute::DereferenceableOrNull:
        if (BothContain)
          Intersected.addDereferenceableOrNullAttr(
              std::min(AS0.getDereferenceableOrNullBytes(),
                       AS1.getDereferenceableOrNullBytes()));
        break;
      case Attribute::Memory:
        // The merged call may have the memory effects of either original.
        if (BothContain)
          Intersected.addMemoryAttr(AS0.getMemoryEffects() |
                                    AS1.getMemoryEffects());
        break;
      case Attribute::Range:
        if (BothContain) {
          ConstantRange Range0 = AS0.getAttribute(Attribute::Range).getRange();
          ConstantRange Range1 = AS1.getAttribute(Attribute::Range).getRange();
          ConstantRange NewRange = Range0.unionWith(Range1);
          // A full-set range carries no information; just drop the attribute.
          if (!NewRange.isFullSet())
            Intersected.addRangeAttr(NewRange);
        }
        break;
      default:
        // Except for the above attrs we know we can intersect safely, fail if
        // the attributes don't match.
        if (!BothContain)
          return std::nullopt;
        if (AS0.getAttribute(Kind) != AS1.getAttribute(Kind))
          return std::nullopt;
        Intersected.addAttribute(Attr);
        break;
      }
    }
    return Intersected;
  };

  // Intersect all attribute types (ret/fn/param).
  AttributeList IntersectedAttrs{};
  auto IntersectedRetAttrs =
      IntersectAttrs(AL0.getRetAttrs(), AL1.getRetAttrs());
  if (!IntersectedRetAttrs)
    return std::nullopt;
  IntersectedAttrs =
      IntersectedAttrs.addRetAttributes(Ctx, *IntersectedRetAttrs);

  auto IntersectedFnAttrs = IntersectAttrs(AL0.getFnAttrs(), AL1.getFnAttrs());
  if (!IntersectedFnAttrs)
    return std::nullopt;
  IntersectedAttrs = IntersectedAttrs.addFnAttributes(Ctx, *IntersectedFnAttrs);

  for (unsigned ParamIdx = 0; ParamIdx < CB0->arg_size(); ++ParamIdx) {
    auto IntersectedParamAttrs = IntersectAttrs(AL0.getParamAttrs(ParamIdx),
                                                AL1.getParamAttrs(ParamIdx));
    if (!IntersectedParamAttrs)
      return std::nullopt;
    IntersectedAttrs = IntersectedAttrs.addParamAttributes(
        Ctx, ParamIdx, *IntersectedParamAttrs);
  }
  return IntersectedAttrs;
}

static bool areIdenticalUpToCommutativity(const Instruction *I1,
const Instruction *I2) {
if (I1->isIdenticalToWhenDefined(I2))
if (I1->isIdenticalToWhenDefined(I2, /*IgnoreAttrs=*/true)) {
if (auto *CB1 = dyn_cast<CallBase>(I1))
return tryIntersectAttrs(CB1, cast<CallBase>(I2)).has_value();
return true;
}

if (auto *Cmp1 = dyn_cast<CmpInst>(I1))
if (auto *Cmp2 = dyn_cast<CmpInst>(I2))
Expand Down Expand Up @@ -1775,6 +1915,14 @@ bool SimplifyCFGOpt::hoistCommonCodeFromSuccessors(BasicBlock *BB,
if (!I2->use_empty())
I2->replaceAllUsesWith(I1);
I1->andIRFlags(I2);
if (auto *CB = dyn_cast<CallBase>(I1)) {
auto IntersectedAttrs = tryIntersectAttrs(CB, cast<CallBase>(I2));
assert(IntersectedAttrs &&
"We should not be trying to hoist callbases "
"with non-intersectable attributes");
CB->setAttributes(*IntersectedAttrs);
}

combineMetadataForCSE(I1, I2, true);
// I1 and I2 are being combined into a single instruction. Its debug
// location is the merged locations of the original instructions.
Expand Down Expand Up @@ -1995,7 +2143,7 @@ static bool canSinkInstructions(
const Instruction *I0 = Insts.front();
const auto I0MMRA = MMRAMetadata(*I0);
for (auto *I : Insts) {
if (!I->isSameOperationAs(I0))
if (!I->isSameOperationAs(I0, Instruction::CompareIgnoringAttrs))
return false;

// swifterror pointers can only be used by a load or store; sinking a load
Expand Down Expand Up @@ -2029,7 +2177,7 @@ static bool canSinkInstructions(
// I.e. if we have two direct calls to different callees, we don't want to
// turn that into an indirect call. Likewise, if we have an indirect call,
// and a direct call, we don't actually want to have a single indirect call.
if (isa<CallBase>(I0)) {
if (auto *CB = dyn_cast<CallBase>(I0)) {
auto IsIndirectCall = [](const Instruction *I) {
return cast<CallBase>(I)->isIndirectCall();
};
Expand All @@ -2048,6 +2196,11 @@ static bool canSinkInstructions(
else if (Callee != CurrCallee)
return false;
}
}
// Check that we can intersect the attributes if we sink.
for (const Instruction *I : Insts) {
if (I != I0 && !tryIntersectAttrs(CB, cast<CallBase>(I)))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we cache the result of intersection?

return false;
}
}

Expand Down Expand Up @@ -2152,6 +2305,12 @@ static void sinkLastInstruction(ArrayRef<BasicBlock*> Blocks) {
I0->applyMergedLocation(I0->getDebugLoc(), I->getDebugLoc());
combineMetadataForCSE(I0, I, true);
I0->andIRFlags(I);
if (auto *CB = dyn_cast<CallBase>(I0)) {
auto IntersectedAttrs = tryIntersectAttrs(CB, cast<CallBase>(I));
assert(IntersectedAttrs && "We should not be trying to sink callbases "
"with non-intersectable attributes");
CB->setAttributes(*IntersectedAttrs);
}
}

for (User *U : make_early_inc_range(I0->users())) {
Expand Down
Loading
Loading