Skip to content

Commit

Permalink
Add various variants of genRegNumFromMask() and use them
Browse files Browse the repository at this point in the history
  • Loading branch information
kunalspathak committed May 30, 2024
1 parent 8b7084a commit cedb079
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 8 deletions.
75 changes: 75 additions & 0 deletions src/coreclr/jit/compiler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -933,8 +933,40 @@ inline unsigned Compiler::funGetFuncIdx(BasicBlock* block)
// Assumptions:
// The mask contains one and only one register.

inline regNumber genRegNumFromMask(const SingleTypeRegSet& mask)
{
assert(mask != RBM_NONE); // Must have one bit set, so can't have a mask of zero

/* Convert the mask to a register number */

regNumber regNum = (regNumber)genLog2(mask);

/* Make sure we got it right */
assert(genRegMask(regNum) == mask);

return regNum;
}

//------------------------------------------------------------------------------
// genRegNumFromMask : Maps a single register mask having gpr/float to a register number.
// If the mask can contain predicate register, use genRegNumFromMask(reg, type)
//
// Arguments:
// mask - the register mask
//
// Return Value:
// The number of the register contained in the mask.
//
// Assumptions:
// The mask contains one and only one register.

inline regNumber genRegNumFromMask(const regMaskTP& mask)
{
#ifdef HAS_MORE_THAN_64_REGISTERS
// This method is only used for gpr/float
assert(mask.getHigh() == RBM_NONE);
#endif

assert(mask.IsNonEmpty()); // Must have one bit set, so can't have a mask of zero

/* Convert the mask to a register number */
Expand All @@ -947,6 +979,49 @@ inline regNumber genRegNumFromMask(const regMaskTP& mask)
return regNum;
}

//------------------------------------------------------------------------------
// genRegNumFromMask : Maps a single register mask to a register number.
//
// Arguments:
// mask - the register mask
// type - The
//
// Return Value:
// The number of the register contained in the mask.
//
// Assumptions:
// The mask contains one and only one register.

inline regNumber genRegNumFromMask(const regMaskTP& mask, var_types type)
{
#ifdef HAS_MORE_THAN_64_REGISTERS
// Must have exactly one bit set
assert(PopCount(mask) == 1);

int index = regMaskTP::mapTypeToRegTypeIndex(type);

#ifdef DEBUG
// Make sure the bit number of right `type` is set in the mask
// If typeIndex == 2, then it better be the bit from high mask
// No need to check for typeIndex == 0/1 because above, we already
// verified that PopCount() == 1
assert(index <= 2);
assert((index != 2) || (PopCount(mask.getHigh()) == 1));
#endif // DEBUG

// If this is mask type, add `64` to the regNumber
regNumber regNum = (regNumber)genLog2(mask[index]);
return (regNumber)(((index == 2) << 6) + regNum);

/* Make sure we got it right */
assert(genRegMask(regNum) == mask);

return regNum;
#else
return genRegNumFromMask(mask.getLow());
#endif
}

//------------------------------------------------------------------------------
// genFirstRegNumFromMask : Maps first bit set in the register mask to a register number.
//
Expand Down
12 changes: 6 additions & 6 deletions src/coreclr/jit/lsra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3017,7 +3017,7 @@ regNumber LinearScan::allocateRegMinimal(Interval* currentInterva
return REG_NA;
}

foundReg = genRegNumFromMask(foundRegBit);
foundReg = genRegNumFromMask(foundRegBit, currentInterval->registerType);
availablePhysRegRecord = getRegisterRecord(foundReg);
Interval* assignedInterval = availablePhysRegRecord->assignedInterval;
if ((assignedInterval != currentInterval) &&
Expand Down Expand Up @@ -8382,10 +8382,10 @@ void LinearScan::resolveRegisters()

if (varDsc->lvIsParam)
{
regMaskTP initialRegMask = interval->firstRefPosition->registerAssignment;
SingleTypeRegSet initialRegMask = interval->firstRefPosition->registerAssignment;
regNumber initialReg = (initialRegMask == RBM_NONE || interval->firstRefPosition->spillAfter)
? REG_STK
: genRegNumFromMask(initialRegMask);
: genRegNumFromMask(initialRegMask, interval->registerType);

#ifdef TARGET_ARM
if (varTypeIsMultiReg(varDsc))
Expand Down Expand Up @@ -8809,7 +8809,7 @@ regNumber LinearScan::getTempRegForResolution(BasicBlock* fromBlock,
#endif
}

regNumber tempReg = genRegNumFromMask(genFindLowestBit(freeRegs));
regNumber tempReg = genRegNumFromMask(genFindLowestBit(freeRegs), type);
return tempReg;
}
}
Expand Down Expand Up @@ -13449,7 +13449,7 @@ SingleTypeRegSet LinearScan::RegisterSelection::select(Interval*
//
bool thisIsSingleReg = isSingleRegister(newRelatedPreferences);
if (!thisIsSingleReg ||
linearScan->isFree(linearScan->getRegisterRecord(genRegNumFromMask(newRelatedPreferences))))
linearScan->isFree(linearScan->getRegisterRecord(genRegNumFromMask(newRelatedPreferences, regType))))
{
relatedPreferences = newRelatedPreferences;
// If this Interval has a downstream def without a single-register preference, continue to iterate.
Expand Down Expand Up @@ -13544,7 +13544,7 @@ SingleTypeRegSet LinearScan::RegisterSelection::select(Interval*
if (candidates == refPosition->registerAssignment)
{
found = true;
if (linearScan->nextIntervalRef[genRegNumFromMask(candidates)] > lastLocation)
if (linearScan->nextIntervalRef[genRegNumFromMask(candidates, regType)] > lastLocation)
{
unassignedSet = candidates;
}
Expand Down
3 changes: 2 additions & 1 deletion src/coreclr/jit/lsra.h
Original file line number Diff line number Diff line change
Expand Up @@ -2639,7 +2639,8 @@ class RefPosition
return REG_NA;
}

return genRegNumFromMask(registerAssignment);
return genRegNumFromMask(registerAssignment, getRegisterType());
}

RegisterType getRegisterType()
{
Expand Down
2 changes: 1 addition & 1 deletion src/coreclr/jit/lsrabuild.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -600,7 +600,7 @@ RefPosition* LinearScan::newRefPosition(Interval* theInterval,

if (insertFixedRef)
{
regNumber physicalReg = genRegNumFromMask(mask);
regNumber physicalReg = genRegNumFromMask(mask, theInterval->registerType);
RefPosition* pos = newRefPosition(physicalReg, theLocation, RefTypeFixedReg, nullptr, mask);
assert(theInterval != nullptr);
#if defined(TARGET_LOONGARCH64) || defined(TARGET_RISCV64)
Expand Down

0 comments on commit cedb079

Please sign in to comment.