Skip to content
This repository has been archived by the owner on Jan 20, 2024. It is now read-only.

Commit

Permalink
Merge pull request #221 from jsjodin/jsjodin/target-offload-mlir
Browse files Browse the repository at this point in the history
Enble target op offloading
  • Loading branch information
gregrodgers committed Jul 14, 2023
2 parents 6a84d33 + 88bc445 commit 6d63895
Show file tree
Hide file tree
Showing 4 changed files with 247 additions and 30 deletions.
8 changes: 6 additions & 2 deletions llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -2092,6 +2092,9 @@ class OpenMPIRBuilder {
/// duplicating the body code.
enum BodyGenTy { Priv, DupNoPriv, NoPriv };

using GenMapInfoCallbackTy =
function_ref<MapInfosTy &(InsertPointTy CodeGenIP)>;

/// Generator for '#omp target data'
///
/// \param Loc The location where the target data construct was encountered.
Expand All @@ -2112,8 +2115,7 @@ class OpenMPIRBuilder {
OpenMPIRBuilder::InsertPointTy createTargetData(
const LocationDescription &Loc, InsertPointTy AllocaIP,
InsertPointTy CodeGenIP, Value *DeviceID, Value *IfCond,
TargetDataInfo &Info,
function_ref<MapInfosTy &(InsertPointTy CodeGenIP)> GenMapInfoCB,
TargetDataInfo &Info, GenMapInfoCallbackTy GenMapInfoCB,
omp::RuntimeFunction *MapperFunc = nullptr,
function_ref<InsertPointTy(InsertPointTy CodeGenIP,
BodyGenTy BodyGenType)>
Expand All @@ -2137,10 +2139,12 @@ class OpenMPIRBuilder {
/// as arguments to the outlined function.
/// \param BodyGenCB Callback that will generate the region code.
InsertPointTy createTarget(const LocationDescription &Loc,
OpenMPIRBuilder::InsertPointTy AllocaIP,
OpenMPIRBuilder::InsertPointTy CodeGenIP,
TargetRegionEntryInfo &EntryInfo, int32_t NumTeams,
int32_t NumThreads,
SmallVectorImpl<Value *> &Inputs,
GenMapInfoCallbackTy GenMapInfoCB,
TargetBodyGenCallbackTy BodyGenCB);

/// Declarations for LLVM-IR types (simple, array, function and structure) are
Expand Down
166 changes: 146 additions & 20 deletions llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4161,8 +4161,7 @@ Constant *OpenMPIRBuilder::registerTargetRegionFunction(
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetData(
const LocationDescription &Loc, InsertPointTy AllocaIP,
InsertPointTy CodeGenIP, Value *DeviceID, Value *IfCond,
TargetDataInfo &Info,
function_ref<MapInfosTy &(InsertPointTy CodeGenIP)> GenMapInfoCB,
TargetDataInfo &Info, GenMapInfoCallbackTy GenMapInfoCB,
omp::RuntimeFunction *MapperFunc,
function_ref<InsertPointTy(InsertPointTy CodeGenIP, BodyGenTy BodyGenType)>
BodyGenCB,
Expand Down Expand Up @@ -4293,19 +4292,85 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetData(
return Builder.saveIP();
}

static Value *castInput(IRBuilderBase &Builder, unsigned AddrSpace,
Value *Input, Argument &Arg) {
assert(Input->getType()->isPointerTy() &&
"Only handling pointer parameters for now");
auto Addr =
Builder.CreateAlloca(Type::getInt64Ty(Builder.getContext()), AddrSpace);
auto AddrAscast =
Builder.CreatePointerBitCastOrAddrSpaceCast(Addr, Input->getType());
Builder.CreateStore(&Arg, AddrAscast);
auto CastAddr = Builder.CreateLoad(
Type::getInt32Ty(Builder.getContext())->getPointerTo(), AddrAscast);

return CastAddr;
}

static void emitUsed(StringRef Name, std::vector<llvm::WeakTrackingVH> &List,
Type *Int8PtrTy, Module &M) {
if (List.empty())
return;

// Convert List to what ConstantArray needs.
SmallVector<Constant *, 8> UsedArray;
UsedArray.resize(List.size());
for (unsigned i = 0, e = List.size(); i != e; ++i) {
UsedArray[i] = ConstantExpr::getPointerBitCastOrAddrSpaceCast(
cast<Constant>(&*List[i]), Int8PtrTy);
}

if (UsedArray.empty())
return;
ArrayType *ATy = ArrayType::get(Int8PtrTy, UsedArray.size());

auto *GV =
new GlobalVariable(M, ATy, false, llvm::GlobalValue::AppendingLinkage,
llvm::ConstantArray::get(ATy, UsedArray), Name);

GV->setSection("llvm.metadata");
}

static void
emitExecutionMode(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
StringRef FunctionName, bool Mode,
std::vector<llvm::WeakTrackingVH> &LLVMCompilerUsed) {
auto Int8Ty = Type::getInt8Ty(Builder.getContext());
auto *GVMode = new llvm::GlobalVariable(
OMPBuilder.M, Int8Ty, /*isConstant=*/true,
llvm::GlobalValue::WeakAnyLinkage,
llvm::ConstantInt::get(Int8Ty, Mode ? OMP_TGT_EXEC_MODE_SPMD
: OMP_TGT_EXEC_MODE_GENERIC),
Twine(FunctionName, "_exec_mode"));
GVMode->setVisibility(llvm::GlobalVariable::ProtectedVisibility);
LLVMCompilerUsed.emplace_back(GVMode);
}

static Function *
createOutlinedFunction(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
StringRef FuncName, SmallVectorImpl<Value *> &Inputs,
OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc) {
SmallVector<Type *> ParameterTypes;
for (auto &Arg : Inputs)
ParameterTypes.push_back(Arg->getType());
if (OMPBuilder.Config.isTargetDevice()) {
// All parameters are passed as i64
ParameterTypes.assign(Inputs.size(),
Type::getInt64Ty(Builder.getContext()));
} else {
for (auto &Arg : Inputs)
ParameterTypes.push_back(Arg->getType());
}

auto FuncType = FunctionType::get(Builder.getVoidTy(), ParameterTypes,
/*isVarArg*/ false);
auto Func = Function::Create(FuncType, GlobalValue::InternalLinkage, FuncName,
Builder.GetInsertBlock()->getModule());

if (OMPBuilder.Config.isTargetDevice()) {
std::vector<llvm::WeakTrackingVH> LLVMCompilerUsed;
emitExecutionMode(OMPBuilder, Builder, FuncName, false, LLVMCompilerUsed);
Type *Int8PtrTy = Type::getInt8Ty(Builder.getContext())->getPointerTo();
emitUsed("llvm.compiler.used", LLVMCompilerUsed, Int8PtrTy, OMPBuilder.M);
}
// Save insert point.
auto OldInsertPoint = Builder.saveIP();

Expand All @@ -4326,16 +4391,27 @@ createOutlinedFunction(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
// Insert return instruction.
Builder.CreateRetVoid();

Builder.SetInsertPoint(&Func->getEntryBlock(),
Func->getEntryBlock().getFirstNonPHIOrDbgOrAlloca());

// Rewrite uses of input valus to parameters.
for (auto InArg : zip(Inputs, Func->args())) {
Value *Input = std::get<0>(InArg);
Argument &Arg = std::get<1>(InArg);

Value *CastInput =
OMPBuilder.Config.isTargetDevice()
? castInput(Builder,
OMPBuilder.M.getDataLayout().getAllocaAddrSpace(),
Input, Arg)
: &Arg;

// Collect all the instructions
assert(CastInput->getType()->isPointerTy() && "Not Pointer Type");
for (User *User : make_early_inc_range(Input->users()))
if (auto Instr = dyn_cast<Instruction>(User))
if (Instr->getFunction() == Func)
Instr->replaceUsesOfWith(Input, &Arg);
Instr->replaceUsesOfWith(Input, CastInput);
}

// Restore insert point.
Expand All @@ -4347,42 +4423,92 @@ createOutlinedFunction(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
static void
emitTargetOutlinedFunction(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
TargetRegionEntryInfo &EntryInfo,
Function *&OutlinedFn, int32_t NumTeams,
int32_t NumThreads, SmallVectorImpl<Value *> &Inputs,
OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc) {
Function *&OutlinedFn, Constant *&OutlinedFnID,
int32_t NumTeams, int32_t NumThreads,
SmallVectorImpl<Value *> &Inputs,
OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
OpenMPIRBuilder::InsertPointTy AllocaIP) {

OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
[&OMPBuilder, &Builder, &Inputs, &CBFunc](StringRef EntryFnName) {
[&OMPBuilder, &Builder, &Inputs, &CBFunc,
&AllocaIP](StringRef EntryFnName) {
return createOutlinedFunction(OMPBuilder, Builder, EntryFnName, Inputs,
CBFunc);
};

Constant *OutlinedFnID;
OMPBuilder.emitTargetRegionFunction(EntryInfo, GenerateOutlinedFunction,
NumTeams, NumThreads, true, OutlinedFn,
OutlinedFnID);
}

static void emitTargetCall(IRBuilderBase &Builder, Function *OutlinedFn,
SmallVectorImpl<Value *> &Args) {
// TODO: Add kernel launch call
Builder.CreateCall(OutlinedFn, Args);
static void emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
OpenMPIRBuilder::InsertPointTy AllocaIP,
Function *OutlinedFn, Constant *OutlinedFnID,
int32_t NumTeams, int32_t NumThreads,
SmallVectorImpl<Value *> &Args,
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB) {

llvm::OpenMPIRBuilder::TargetDataInfo Info(
/*RequiresDevicePointerInfo=*/false,
/*SeparateBeginEndCalls=*/true);

auto MapInfo = GenMapInfoCB(Builder.saveIP());
OMPBuilder.emitOffloadingArrays(AllocaIP, Builder.saveIP(), MapInfo, Info,
/*IsNonContiguous=*/true);

OpenMPIRBuilder::TargetDataRTArgs RTArgs;
OMPBuilder.emitOffloadingArraysArgument(Builder, RTArgs, Info);

// emitKernelLaunch
auto &&emitTargetCallFallbackCB =
[&](OpenMPIRBuilder::InsertPointTy IP) -> OpenMPIRBuilder::InsertPointTy {
Builder.restoreIP(IP);
Builder.CreateCall(OutlinedFn, Args);
return Builder.saveIP();
};

unsigned NumTargetItems = MapInfo.BasePointers.size();
llvm::Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
llvm::Value *NumTeamsVal = Builder.getInt32(NumTeams);
llvm::Value *NumThreadsVal = Builder.getInt32(NumThreads);
uint32_t SrcLocStrSize;
llvm::Constant *SrcLocStr =
OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
llvm::Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
llvm::omp::IdentFlag(0), 0);
llvm::Value *NumIterations = Builder.getInt64(0);
llvm::Value *DynCGGroupMem = Builder.getInt32(0);

bool HasNoWait = false;

OpenMPIRBuilder::TargetKernelArgs KArgs(NumTargetItems, RTArgs, NumIterations,
NumTeamsVal, NumThreadsVal,
DynCGGroupMem, HasNoWait);

Builder.restoreIP(OMPBuilder.emitKernelLaunch(
Builder, OutlinedFn, OutlinedFnID, emitTargetCallFallbackCB, KArgs,
DeviceID, RTLoc, AllocaIP));
}

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget(
const LocationDescription &Loc, OpenMPIRBuilder::InsertPointTy CodeGenIP,
TargetRegionEntryInfo &EntryInfo, int32_t NumTeams, int32_t NumThreads,
SmallVectorImpl<Value *> &Args, TargetBodyGenCallbackTy CBFunc) {
const LocationDescription &Loc, OpenMPIRBuilder::InsertPointTy AllocaIP,
OpenMPIRBuilder::InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
int32_t NumTeams, int32_t NumThreads, SmallVectorImpl<Value *> &Args,
GenMapInfoCallbackTy GenMapInfoCB, TargetBodyGenCallbackTy CBFunc) {
if (!updateToLocation(Loc))
return InsertPointTy();

Builder.restoreIP(CodeGenIP);

Function *OutlinedFn;
emitTargetOutlinedFunction(*this, Builder, EntryInfo, OutlinedFn, NumTeams,
NumThreads, Args, CBFunc);
Constant *OutlinedFnID;
emitTargetOutlinedFunction(*this, Builder, EntryInfo, OutlinedFn,
OutlinedFnID, NumTeams, NumThreads, Args, CBFunc,
AllocaIP);
if (!Config.isTargetDevice())
emitTargetCall(Builder, OutlinedFn, Args);
emitTargetCall(*this, Builder, AllocaIP, OutlinedFn, OutlinedFnID, NumTeams,
NumThreads, Args, GenMapInfoCB);

return Builder.saveIP();
}

Expand Down
57 changes: 52 additions & 5 deletions llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5068,6 +5068,36 @@ TEST_F(OpenMPIRBuilderTest, TargetDataRegion) {
EXPECT_FALSE(verifyModule(*M, &errs()));
}

namespace {

// Some basic handling of argument mapping for the moment
void CreateDefaultMapInfos(llvm::OpenMPIRBuilder &ompBuilder,
llvm::SmallVectorImpl<llvm::Value *> &args,
llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo) {
for (auto arg : args) {
if (!arg->getType()->isPointerTy()) {
combinedInfo.BasePointers.clear();
combinedInfo.Pointers.clear();
combinedInfo.Sizes.clear();
combinedInfo.Types.clear();
combinedInfo.Names.clear();
return;
}
combinedInfo.BasePointers.emplace_back(arg);
combinedInfo.Pointers.emplace_back(arg);
uint32_t SrcLocStrSize;
combinedInfo.Names.emplace_back(ompBuilder.getOrCreateSrcLocStr(
"Unknown loc - stub implementation", SrcLocStrSize));
combinedInfo.Types.emplace_back(llvm::omp::OpenMPOffloadMappingFlags(
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM |
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM));
combinedInfo.Sizes.emplace_back(ompBuilder.Builder.getInt64(
ompBuilder.M.getDataLayout().getTypeAllocSize(arg->getType())));
}
}

} // namespace

TEST_F(OpenMPIRBuilderTest, TargetRegion) {
using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
OpenMPIRBuilder OMPBuilder(*M);
Expand Down Expand Up @@ -5099,10 +5129,19 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
Inputs.push_back(BPtr);
Inputs.push_back(CPtr);

llvm::OpenMPIRBuilder::MapInfosTy CombinedInfos;
auto GenMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP)
-> llvm::OpenMPIRBuilder::MapInfosTy & {
CreateDefaultMapInfos(OMPBuilder, Inputs, CombinedInfos);
return CombinedInfos;
};

TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17);
OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL});
Builder.restoreIP(OMPBuilder.createTarget(OmpLoc, Builder.saveIP(), EntryInfo,
-1, -1, Inputs, BodyGenCB));

Builder.restoreIP(OMPBuilder.createTarget(OmpLoc, Builder.saveIP(),
Builder.saveIP(), EntryInfo, -1, -1,
Inputs, GenMapInfoCB, BodyGenCB));
OMPBuilder.finalize();
Builder.CreateRetVoid();

Expand Down Expand Up @@ -5138,6 +5177,13 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
Constant::getIntegerValue(Type::getInt32Ty(Ctx), APInt(32, 0)),
Constant::getNullValue(Type::getInt32PtrTy(Ctx))};

llvm::OpenMPIRBuilder::MapInfosTy CombinedInfos;
auto GenMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP)
-> llvm::OpenMPIRBuilder::MapInfosTy & {
CreateDefaultMapInfos(OMPBuilder, CapturedArgs, CombinedInfos);
return CombinedInfos;
};

auto BodyGenCB = [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
OpenMPIRBuilder::InsertPointTy CodeGenIP)
-> OpenMPIRBuilder::InsertPointTy {
Expand All @@ -5151,9 +5197,10 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2,
/*Line=*/3, /*Count=*/0);

Builder.restoreIP(
OMPBuilder.createTarget(Loc, EntryIP, EntryInfo, /*NumTeams=*/-1,
/*NumThreads=*/-1, CapturedArgs, BodyGenCB));
Builder.restoreIP(OMPBuilder.createTarget(
Loc, EntryIP, EntryIP, EntryInfo, /*NumTeams=*/-1,
/*NumThreads=*/-1, CapturedArgs, GenMapInfoCB, BodyGenCB));

Builder.CreateRetVoid();
OMPBuilder.finalize();

Expand Down
Loading

0 comments on commit 6d63895

Please sign in to comment.