Skip to content
This repository has been archived by the owner on Jan 20, 2024. It is now read-only.

Enable target op offloading #221

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -2092,6 +2092,9 @@ class OpenMPIRBuilder {
/// duplicating the body code.
enum BodyGenTy { Priv, DupNoPriv, NoPriv };

/// Callback type used to obtain the map information for a construct.
///
/// \param CodeGenIP The insertion point handed to the callback.
/// \return A reference to the MapInfosTy supplied by the callback.
using GenMapInfoCallbackTy =
function_ref<MapInfosTy &(InsertPointTy CodeGenIP)>;

/// Generator for '#omp target data'
///
/// \param Loc The location where the target data construct was encountered.
Expand All @@ -2112,8 +2115,7 @@ class OpenMPIRBuilder {
OpenMPIRBuilder::InsertPointTy createTargetData(
const LocationDescription &Loc, InsertPointTy AllocaIP,
InsertPointTy CodeGenIP, Value *DeviceID, Value *IfCond,
TargetDataInfo &Info,
function_ref<MapInfosTy &(InsertPointTy CodeGenIP)> GenMapInfoCB,
TargetDataInfo &Info, GenMapInfoCallbackTy GenMapInfoCB,
omp::RuntimeFunction *MapperFunc = nullptr,
function_ref<InsertPointTy(InsertPointTy CodeGenIP,
BodyGenTy BodyGenType)>
Expand All @@ -2137,10 +2139,12 @@ class OpenMPIRBuilder {
/// as arguments to the outlined function.
/// \param BodyGenCB Callback that will generate the region code.
/// \param AllocaIP Insertion point forwarded to the offloading-array and
/// kernel-launch emission when not compiling for the target device.
/// \param GenMapInfoCB Callback that returns the map information used to
/// build the offloading arrays for the kernel launch.
InsertPointTy createTarget(const LocationDescription &Loc,
OpenMPIRBuilder::InsertPointTy AllocaIP,
OpenMPIRBuilder::InsertPointTy CodeGenIP,
TargetRegionEntryInfo &EntryInfo, int32_t NumTeams,
int32_t NumThreads,
SmallVectorImpl<Value *> &Inputs,
GenMapInfoCallbackTy GenMapInfoCB,
TargetBodyGenCallbackTy BodyGenCB);

/// Declarations for LLVM-IR types (simple, array, function and structure) are
Expand Down
166 changes: 146 additions & 20 deletions llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4161,8 +4161,7 @@ Constant *OpenMPIRBuilder::registerTargetRegionFunction(
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetData(
const LocationDescription &Loc, InsertPointTy AllocaIP,
InsertPointTy CodeGenIP, Value *DeviceID, Value *IfCond,
TargetDataInfo &Info,
function_ref<MapInfosTy &(InsertPointTy CodeGenIP)> GenMapInfoCB,
TargetDataInfo &Info, GenMapInfoCallbackTy GenMapInfoCB,
omp::RuntimeFunction *MapperFunc,
function_ref<InsertPointTy(InsertPointTy CodeGenIP, BodyGenTy BodyGenType)>
BodyGenCB,
Expand Down Expand Up @@ -4293,19 +4292,85 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetData(
return Builder.saveIP();
}

/// Spill the outlined-function argument \p Arg into a fresh stack slot and
/// reload it, yielding a value that can stand in for \p Input inside the
/// outlined function.
///
/// \param Builder   Builder positioned where the alloca/store/load are emitted.
/// \param AddrSpace Address space used for the alloca.
/// \param Input     The original captured value; must be pointer-typed.
/// \param Arg       The outlined-function argument carrying the value.
/// \returns The value loaded back from the stack slot.
static Value *castInput(IRBuilderBase &Builder, unsigned AddrSpace,
                        Value *Input, Argument &Arg) {
  assert(Input->getType()->isPointerTy() &&
         "Only handling pointer parameters for now");

  // The slot is i64-sized and lives in the requested address space.
  auto *Slot =
      Builder.CreateAlloca(Type::getInt64Ty(Builder.getContext()), AddrSpace);

  // View the slot through the input's pointer type before writing to it.
  auto *SlotPtr =
      Builder.CreatePointerBitCastOrAddrSpaceCast(Slot, Input->getType());
  Builder.CreateStore(&Arg, SlotPtr);

  // NOTE(review): the reload type is hard-coded to i32* regardless of the
  // type of Input -- confirm this is intended for non-i32 pointees.
  Type *ReloadTy = Type::getInt32Ty(Builder.getContext())->getPointerTo();
  return Builder.CreateLoad(ReloadTy, SlotPtr);
}

/// Emit a global array named \p Name in \p M that lists every value in
/// \p List, each cast to \p Int8PtrTy.
///
/// The array is created with appending linkage and placed in the
/// "llvm.metadata" section. Nothing is emitted when \p List is empty.
///
/// \param Name      Name of the global to create (e.g. "llvm.compiler.used").
/// \param List      Handles of the values to record; each must be a Constant.
/// \param Int8PtrTy Pointer type every entry is cast to.
/// \param M         Module that receives the new global.
///
/// NOTE(review): every call creates a new global named \p Name -- confirm
/// this is reached at most once per module for a given name.
static void emitUsed(StringRef Name, std::vector<llvm::WeakTrackingVH> &List,
                     Type *Int8PtrTy, Module &M) {
  if (List.empty())
    return;

  // Convert List to what ConstantArray needs. The array always ends up with
  // exactly List.size() entries, so no second emptiness check is required.
  SmallVector<Constant *, 8> UsedArray;
  UsedArray.reserve(List.size());
  for (llvm::WeakTrackingVH &Handle : List)
    UsedArray.push_back(ConstantExpr::getPointerBitCastOrAddrSpaceCast(
        cast<Constant>(&*Handle), Int8PtrTy));

  ArrayType *ATy = ArrayType::get(Int8PtrTy, UsedArray.size());

  auto *GV =
      new GlobalVariable(M, ATy, /*isConstant=*/false,
                         llvm::GlobalValue::AppendingLinkage,
                         llvm::ConstantArray::get(ATy, UsedArray), Name);

  GV->setSection("llvm.metadata");
}

/// Create the "<FunctionName>_exec_mode" global describing how the kernel is
/// executed and record it in \p LLVMCompilerUsed.
///
/// The global is a constant i8 with weak-any linkage and protected
/// visibility. Its value is OMP_TGT_EXEC_MODE_SPMD when \p Mode is true and
/// OMP_TGT_EXEC_MODE_GENERIC otherwise.
///
/// \param OMPBuilder       Builder whose module receives the global.
/// \param Builder          Used only to obtain the LLVMContext.
/// \param FunctionName     Kernel name used as the prefix of the global name.
/// \param Mode             Selects SPMD (true) or generic (false) mode.
/// \param LLVMCompilerUsed List the new global is appended to.
static void
emitExecutionMode(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
                  StringRef FunctionName, bool Mode,
                  std::vector<llvm::WeakTrackingVH> &LLVMCompilerUsed) {
  auto *ByteTy = Type::getInt8Ty(Builder.getContext());

  auto *ExecModeVar = new llvm::GlobalVariable(
      OMPBuilder.M, ByteTy, /*isConstant=*/true,
      llvm::GlobalValue::WeakAnyLinkage,
      llvm::ConstantInt::get(ByteTy, Mode ? OMP_TGT_EXEC_MODE_SPMD
                                          : OMP_TGT_EXEC_MODE_GENERIC),
      Twine(FunctionName, "_exec_mode"));
  ExecModeVar->setVisibility(llvm::GlobalVariable::ProtectedVisibility);

  LLVMCompilerUsed.emplace_back(ExecModeVar);
}

static Function *
createOutlinedFunction(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
StringRef FuncName, SmallVectorImpl<Value *> &Inputs,
OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc) {
SmallVector<Type *> ParameterTypes;
for (auto &Arg : Inputs)
ParameterTypes.push_back(Arg->getType());
if (OMPBuilder.Config.isTargetDevice()) {
// All parameters are passed as i64
ParameterTypes.assign(Inputs.size(),
Type::getInt64Ty(Builder.getContext()));
} else {
for (auto &Arg : Inputs)
ParameterTypes.push_back(Arg->getType());
}

auto FuncType = FunctionType::get(Builder.getVoidTy(), ParameterTypes,
/*isVarArg*/ false);
auto Func = Function::Create(FuncType, GlobalValue::InternalLinkage, FuncName,
Builder.GetInsertBlock()->getModule());

if (OMPBuilder.Config.isTargetDevice()) {
std::vector<llvm::WeakTrackingVH> LLVMCompilerUsed;
emitExecutionMode(OMPBuilder, Builder, FuncName, false, LLVMCompilerUsed);
Type *Int8PtrTy = Type::getInt8Ty(Builder.getContext())->getPointerTo();
emitUsed("llvm.compiler.used", LLVMCompilerUsed, Int8PtrTy, OMPBuilder.M);
}
// Save insert point.
auto OldInsertPoint = Builder.saveIP();

Expand All @@ -4326,16 +4391,27 @@ createOutlinedFunction(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
// Insert return instruction.
Builder.CreateRetVoid();

Builder.SetInsertPoint(&Func->getEntryBlock(),
Func->getEntryBlock().getFirstNonPHIOrDbgOrAlloca());

// Rewrite uses of input values to parameters.
for (auto InArg : zip(Inputs, Func->args())) {
Value *Input = std::get<0>(InArg);
Argument &Arg = std::get<1>(InArg);

Value *CastInput =
OMPBuilder.Config.isTargetDevice()
? castInput(Builder,
OMPBuilder.M.getDataLayout().getAllocaAddrSpace(),
Input, Arg)
: &Arg;

// Collect all the instructions
assert(CastInput->getType()->isPointerTy() && "Not Pointer Type");
for (User *User : make_early_inc_range(Input->users()))
if (auto Instr = dyn_cast<Instruction>(User))
if (Instr->getFunction() == Func)
Instr->replaceUsesOfWith(Input, &Arg);
Instr->replaceUsesOfWith(Input, CastInput);
}

// Restore insert point.
Expand All @@ -4347,42 +4423,92 @@ createOutlinedFunction(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
static void
emitTargetOutlinedFunction(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
TargetRegionEntryInfo &EntryInfo,
Function *&OutlinedFn, int32_t NumTeams,
int32_t NumThreads, SmallVectorImpl<Value *> &Inputs,
OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc) {
Function *&OutlinedFn, Constant *&OutlinedFnID,
int32_t NumTeams, int32_t NumThreads,
SmallVectorImpl<Value *> &Inputs,
OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
OpenMPIRBuilder::InsertPointTy AllocaIP) {

OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
[&OMPBuilder, &Builder, &Inputs, &CBFunc](StringRef EntryFnName) {
[&OMPBuilder, &Builder, &Inputs, &CBFunc,
&AllocaIP](StringRef EntryFnName) {
return createOutlinedFunction(OMPBuilder, Builder, EntryFnName, Inputs,
CBFunc);
};

Constant *OutlinedFnID;
OMPBuilder.emitTargetRegionFunction(EntryInfo, GenerateOutlinedFunction,
NumTeams, NumThreads, true, OutlinedFn,
OutlinedFnID);
}

static void emitTargetCall(IRBuilderBase &Builder, Function *OutlinedFn,
SmallVectorImpl<Value *> &Args) {
// TODO: Add kernel launch call
Builder.CreateCall(OutlinedFn, Args);
static void emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
OpenMPIRBuilder::InsertPointTy AllocaIP,
Function *OutlinedFn, Constant *OutlinedFnID,
int32_t NumTeams, int32_t NumThreads,
SmallVectorImpl<Value *> &Args,
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB) {

llvm::OpenMPIRBuilder::TargetDataInfo Info(
/*RequiresDevicePointerInfo=*/false,
/*SeparateBeginEndCalls=*/true);

auto MapInfo = GenMapInfoCB(Builder.saveIP());
OMPBuilder.emitOffloadingArrays(AllocaIP, Builder.saveIP(), MapInfo, Info,
/*IsNonContiguous=*/true);

OpenMPIRBuilder::TargetDataRTArgs RTArgs;
OMPBuilder.emitOffloadingArraysArgument(Builder, RTArgs, Info);

// emitKernelLaunch
auto &&emitTargetCallFallbackCB =
[&](OpenMPIRBuilder::InsertPointTy IP) -> OpenMPIRBuilder::InsertPointTy {
Builder.restoreIP(IP);
Builder.CreateCall(OutlinedFn, Args);
return Builder.saveIP();
};

unsigned NumTargetItems = MapInfo.BasePointers.size();
llvm::Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
llvm::Value *NumTeamsVal = Builder.getInt32(NumTeams);
llvm::Value *NumThreadsVal = Builder.getInt32(NumThreads);
uint32_t SrcLocStrSize;
llvm::Constant *SrcLocStr =
OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
llvm::Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
llvm::omp::IdentFlag(0), 0);
llvm::Value *NumIterations = Builder.getInt64(0);
llvm::Value *DynCGGroupMem = Builder.getInt32(0);

bool HasNoWait = false;

OpenMPIRBuilder::TargetKernelArgs KArgs(NumTargetItems, RTArgs, NumIterations,
NumTeamsVal, NumThreadsVal,
DynCGGroupMem, HasNoWait);

Builder.restoreIP(OMPBuilder.emitKernelLaunch(
Builder, OutlinedFn, OutlinedFnID, emitTargetCallFallbackCB, KArgs,
DeviceID, RTLoc, AllocaIP));
}

/// Generate a target region: outline its body and, when not compiling for the
/// target device, emit the launch of the outlined function.
///
/// \param Loc          Location of the construct; if updateToLocation(Loc)
///                     fails, an empty InsertPointTy is returned and nothing
///                     is emitted.
/// \param AllocaIP     Insertion point forwarded to the outlining and launch
///                     helpers.
/// \param CodeGenIP    Insertion point the builder is moved to before any code
///                     is generated.
/// \param EntryInfo    Entry information identifying the target region.
/// \param NumTeams     Forwarded to the outlining and launch helpers.
/// \param NumThreads   Forwarded to the outlining and launch helpers.
/// \param Args         Values captured by the region.
/// \param GenMapInfoCB Callback returning the map information for the launch.
/// \param CBFunc       Callback that generates the region body.
/// \returns The builder's insertion point after code generation.
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget(
    const LocationDescription &Loc, OpenMPIRBuilder::InsertPointTy AllocaIP,
    OpenMPIRBuilder::InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
    int32_t NumTeams, int32_t NumThreads, SmallVectorImpl<Value *> &Args,
    GenMapInfoCallbackTy GenMapInfoCB, TargetBodyGenCallbackTy CBFunc) {
  if (!updateToLocation(Loc))
    return InsertPointTy();

  Builder.restoreIP(CodeGenIP);

  // Both are filled in by emitTargetOutlinedFunction; start from a known
  // value rather than an indeterminate pointer.
  Function *OutlinedFn = nullptr;
  Constant *OutlinedFnID = nullptr;
  emitTargetOutlinedFunction(*this, Builder, EntryInfo, OutlinedFn,
                             OutlinedFnID, NumTeams, NumThreads, Args, CBFunc,
                             AllocaIP);

  // The launch is emitted only when not compiling for the target device.
  if (!Config.isTargetDevice())
    emitTargetCall(*this, Builder, AllocaIP, OutlinedFn, OutlinedFnID, NumTeams,
                   NumThreads, Args, GenMapInfoCB);

  return Builder.saveIP();
}

Expand Down
57 changes: 52 additions & 5 deletions llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5068,6 +5068,36 @@ TEST_F(OpenMPIRBuilderTest, TargetDataRegion) {
EXPECT_FALSE(verifyModule(*M, &errs()));
}

namespace {

/// Test helper: append one default map entry per value in \p args to
/// \p combinedInfo.
///
/// Each entry uses the value as both base pointer and pointer, a stub
/// source-location string as its name, the FROM | TARGET_PARAM mapping flags,
/// and the alloc size of the value's type as its size. As soon as a
/// non-pointer value is met, every list in \p combinedInfo is cleared and the
/// function returns.
void CreateDefaultMapInfos(llvm::OpenMPIRBuilder &ompBuilder,
                           llvm::SmallVectorImpl<llvm::Value *> &args,
                           llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo) {
  using MapFlags = llvm::omp::OpenMPOffloadMappingFlags;

  // Drops everything collected so far.
  const auto discardAll = [&combinedInfo] {
    combinedInfo.BasePointers.clear();
    combinedInfo.Pointers.clear();
    combinedInfo.Sizes.clear();
    combinedInfo.Types.clear();
    combinedInfo.Names.clear();
  };

  for (llvm::Value *value : args) {
    llvm::Type *valueTy = value->getType();
    if (!valueTy->isPointerTy()) {
      // Only pointer arguments are handled; give up on the whole set.
      discardAll();
      return;
    }

    combinedInfo.BasePointers.emplace_back(value);
    combinedInfo.Pointers.emplace_back(value);

    uint32_t srcLocStrSize;
    combinedInfo.Names.emplace_back(ompBuilder.getOrCreateSrcLocStr(
        "Unknown loc - stub implementation", srcLocStrSize));

    combinedInfo.Types.emplace_back(
        MapFlags(MapFlags::OMP_MAP_FROM | MapFlags::OMP_MAP_TARGET_PARAM));

    combinedInfo.Sizes.emplace_back(ompBuilder.Builder.getInt64(
        ompBuilder.M.getDataLayout().getTypeAllocSize(valueTy)));
  }
}

} // namespace

TEST_F(OpenMPIRBuilderTest, TargetRegion) {
using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
OpenMPIRBuilder OMPBuilder(*M);
Expand Down Expand Up @@ -5099,10 +5129,19 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
Inputs.push_back(BPtr);
Inputs.push_back(CPtr);

llvm::OpenMPIRBuilder::MapInfosTy CombinedInfos;
auto GenMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP)
-> llvm::OpenMPIRBuilder::MapInfosTy & {
CreateDefaultMapInfos(OMPBuilder, Inputs, CombinedInfos);
return CombinedInfos;
};

TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17);
OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL});
Builder.restoreIP(OMPBuilder.createTarget(OmpLoc, Builder.saveIP(), EntryInfo,
-1, -1, Inputs, BodyGenCB));

Builder.restoreIP(OMPBuilder.createTarget(OmpLoc, Builder.saveIP(),
Builder.saveIP(), EntryInfo, -1, -1,
Inputs, GenMapInfoCB, BodyGenCB));
OMPBuilder.finalize();
Builder.CreateRetVoid();

Expand Down Expand Up @@ -5138,6 +5177,13 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
Constant::getIntegerValue(Type::getInt32Ty(Ctx), APInt(32, 0)),
Constant::getNullValue(Type::getInt32PtrTy(Ctx))};

llvm::OpenMPIRBuilder::MapInfosTy CombinedInfos;
auto GenMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP)
-> llvm::OpenMPIRBuilder::MapInfosTy & {
CreateDefaultMapInfos(OMPBuilder, CapturedArgs, CombinedInfos);
return CombinedInfos;
};

auto BodyGenCB = [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
OpenMPIRBuilder::InsertPointTy CodeGenIP)
-> OpenMPIRBuilder::InsertPointTy {
Expand All @@ -5151,9 +5197,10 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2,
/*Line=*/3, /*Count=*/0);

Builder.restoreIP(
OMPBuilder.createTarget(Loc, EntryIP, EntryInfo, /*NumTeams=*/-1,
/*NumThreads=*/-1, CapturedArgs, BodyGenCB));
Builder.restoreIP(OMPBuilder.createTarget(
Loc, EntryIP, EntryIP, EntryInfo, /*NumTeams=*/-1,
/*NumThreads=*/-1, CapturedArgs, GenMapInfoCB, BodyGenCB));

Builder.CreateRetVoid();
OMPBuilder.finalize();

Expand Down
Loading