Skip to content

Commit

Permalink
Rollup merge of rust-lang#55073 - alexcrichton:demote-simd, r=nagisa
Browse files Browse the repository at this point in the history
The issue of passing around SIMD types as values between functions has
seen [quite a lot] of [discussion], and although we thought [we fixed
it][quite a lot] it [wasn't]! This PR is a change to rustc to, again,
try to fix this issue.

The fundamental problem here remains the same, if a SIMD vector argument
is passed by-value in LLVM's function type, then if the caller and
callee disagree on target features a miscompile happens. We solve this
by never passing SIMD vectors by-value, but LLVM will still thwart us
with its argument promotion pass to promote by-ref SIMD arguments to
by-val SIMD arguments.

This commit is an attempt to thwart LLVM thwarting us. We, just before
codegen, will take yet another look at the LLVM module and demote any
by-value SIMD arguments we see. This is a very manual attempt by us to
ensure the codegen for a module keeps working, and it unfortunately is
likely producing suboptimal code, even in release mode. The saving grace
for this, in theory, is that if SIMD types are passed by-value across
a boundary in release mode it's pretty unlikely to be performance
sensitive (as it's already doing a load/store, and otherwise
perf-sensitive bits should be inlined).

The implementation here is basically a big wad of C++. It was largely
copied from LLVM's own argument promotion pass, only doing the reverse.
In local testing this...

Closes rust-lang#50154
Closes rust-lang#52636
Closes rust-lang#54583
Closes rust-lang#55059

[quite a lot]: rust-lang#47743
[discussion]: rust-lang#44367
[wasn't]: rust-lang#50154
  • Loading branch information
Manishearth committed Oct 20, 2018
2 parents 22cc2ae + 3cc8f73 commit b860765
Show file tree
Hide file tree
Showing 9 changed files with 332 additions and 9 deletions.
12 changes: 5 additions & 7 deletions src/librustc_codegen_llvm/back/lto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,7 @@ impl LtoModuleCodegen {
let module = module.take().unwrap();
{
let config = cgcx.config(module.kind);
let llmod = module.module_llvm.llmod();
let tm = &*module.module_llvm.tm;
run_pass_manager(cgcx, tm, llmod, config, false);
run_pass_manager(cgcx, &module, config, false);
timeline.record("fat-done");
}
Ok(module)
Expand Down Expand Up @@ -557,8 +555,7 @@ fn thin_lto(cgcx: &CodegenContext,
}

fn run_pass_manager(cgcx: &CodegenContext,
tm: &llvm::TargetMachine,
llmod: &llvm::Module,
module: &ModuleCodegen,
config: &ModuleConfig,
thin: bool) {
// Now we have one massive module inside of llmod. Time to run the
Expand All @@ -569,7 +566,8 @@ fn run_pass_manager(cgcx: &CodegenContext,
debug!("running the pass manager");
unsafe {
let pm = llvm::LLVMCreatePassManager();
llvm::LLVMRustAddAnalysisPasses(tm, pm, llmod);
let llmod = module.module_llvm.llmod();
llvm::LLVMRustAddAnalysisPasses(module.module_llvm.tm, pm, llmod);

if config.verify_llvm_ir {
let pass = llvm::LLVMRustFindAndCreatePass("verify\0".as_ptr() as *const _);
Expand Down Expand Up @@ -864,7 +862,7 @@ impl ThinModule {
// little differently.
info!("running thin lto passes over {}", module.name);
let config = cgcx.config(module.kind);
run_pass_manager(cgcx, module.module_llvm.tm, llmod, config, true);
run_pass_manager(cgcx, &module, config, true);
cgcx.save_temp_bitcode(&module, "thin-lto-after-pm");
timeline.record("thin-done");
}
Expand Down
34 changes: 33 additions & 1 deletion src/librustc_codegen_llvm/back/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,7 @@ unsafe fn optimize(cgcx: &CodegenContext,
None,
&format!("llvm module passes [{}]", module_name.unwrap()),
|| {
llvm::LLVMRunPassManager(mpm, llmod)
llvm::LLVMRunPassManager(mpm, llmod);
});

// Deallocate managers that we're now done with
Expand Down Expand Up @@ -691,6 +691,38 @@ unsafe fn codegen(cgcx: &CodegenContext,
create_msvc_imps(cgcx, llcx, llmod);
}

// Ok now this one's a super interesting invocations. SIMD in rustc is
// difficult where we want some parts of the program to be able to use
// some SIMD features while other parts of the program don't. The real
// tough part is that we want this to actually work correctly!
//
// We go to great lengths to make sure this works, and one crucial
// aspect is that vector arguments (simd types) are never passed by
// value in the ABI of functions. It turns out, however, that LLVM will
// undo our "clever work" of passing vector types by reference. Its
// argument promotion pass will promote these by-ref arguments to
// by-val. That, however, introduces codegen errors!
//
// The upstream LLVM bug [1] has unfortunatey not really seen a lot of
// activity. The Rust bug [2], however, has seen quite a lot of reports
// of this in the wild. As a result, this is worked around locally here.
// We have a custom transformation, `LLVMRustDemoteSimdArguments`, which
// does the opposite of argument promotion by demoting any by-value SIMD
// arguments in function signatures to pointers intead of being
// by-value.
//
// This operates at the LLVM IR layer because LLVM is thwarting our
// codegen and this is the only chance we get to make sure it's correct
// before we hit codegen.
//
// Hopefully one day the upstream LLVM bug will be fixed and we'll no
// longer need this!
//
// [1]: https://bugs.llvm.org/show_bug.cgi?id=37358
// [2]: https://github.com/rust-lang/rust/issues/50154
llvm::LLVMRustDemoteSimdArguments(llmod);
cgcx.save_temp_bitcode(&module, "simd-demoted");

// A codegen-specific pass manager is used to generate object
// files for an LLVM module.
//
Expand Down
2 changes: 2 additions & 0 deletions src/librustc_codegen_llvm/llvm/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1138,6 +1138,8 @@ extern "C" {
/// Runs a pass manager on a module.
pub fn LLVMRunPassManager(PM: &PassManager<'a>, M: &'a Module) -> Bool;

pub fn LLVMRustDemoteSimdArguments(M: &'a Module);

pub fn LLVMInitializePasses();

pub fn LLVMPassManagerBuilderCreate() -> &'static mut PassManagerBuilder;
Expand Down
4 changes: 3 additions & 1 deletion src/librustc_llvm/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,9 @@ fn main() {
}

build_helper::rerun_if_changed_anything_in_dir(Path::new("../rustllvm"));
cfg.file("../rustllvm/PassWrapper.cpp")
cfg
.file("../rustllvm/DemoteSimd.cpp")
.file("../rustllvm/PassWrapper.cpp")
.file("../rustllvm/RustWrapper.cpp")
.file("../rustllvm/ArchiveWrapper.cpp")
.file("../rustllvm/Linker.cpp")
Expand Down
189 changes: 189 additions & 0 deletions src/rustllvm/DemoteSimd.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
// Copyright 2018 The Rust Project Developers. See the COPYRIGHT
// file at the top-level directory of this distribution and at
// http://rust-lang.org/COPYRIGHT.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

#include <vector>
#include <set>

#include "rustllvm.h"

#if LLVM_VERSION_GE(5, 0)

#include "llvm/IR/CallSite.h"
#include "llvm/IR/Module.h"
#include "llvm/ADT/STLExtras.h"

using namespace llvm;

static std::vector<Function*>
GetFunctionsWithSimdArgs(Module *M) {
std::vector<Function*> Ret;

for (auto &F : M->functions()) {
// Skip all intrinsic calls as these are always tightly controlled to "work
// correctly", so no need to fixup any of these.
if (F.isIntrinsic())
continue;

// We're only interested in rustc-defined functions, not unstably-defined
// imported SIMD ffi functions.
if (F.isDeclaration())
continue;

// Argument promotion only happens on internal functions, so skip demoting
// arguments in external functions like FFI shims and such.
if (!F.hasLocalLinkage())
continue;

// If any argument to this function is a by-value vector type, then that's
// bad! The compiler didn't generate any functions that looked like this,
// and we try to rely on LLVM to not do this! Argument promotion may,
// however, promote arguments from behind references. In any case, figure
// out if we're interested in demoting this argument.
if (any_of(F.args(), [](Argument &arg) { return arg.getType()->isVectorTy(); }))
Ret.push_back(&F);
}

return Ret;
}

extern "C" void
LLVMRustDemoteSimdArguments(LLVMModuleRef Mod) {
Module *M = unwrap(Mod);

auto Functions = GetFunctionsWithSimdArgs(M);

for (auto F : Functions) {
// Build up our list of new parameters and new argument attributes.
// We're only changing those arguments which are vector types.
SmallVector<Type*, 8> Params;
SmallVector<AttributeSet, 8> ArgAttrVec;
auto PAL = F->getAttributes();
for (auto &Arg : F->args()) {
auto *Ty = Arg.getType();
if (Ty->isVectorTy()) {
Params.push_back(PointerType::get(Ty, 0));
ArgAttrVec.push_back(AttributeSet());
} else {
Params.push_back(Ty);
ArgAttrVec.push_back(PAL.getParamAttributes(Arg.getArgNo()));
}
}

// Replace `F` with a new function with our new signature. I'm... not really
// sure how this works, but this is all the steps `ArgumentPromotion` does
// to replace a signature as well.
assert(!F->isVarArg()); // ArgumentPromotion should skip these fns
FunctionType *NFTy = FunctionType::get(F->getReturnType(), Params, false);
Function *NF = Function::Create(NFTy, F->getLinkage(), F->getName());
NF->copyAttributesFrom(F);
NF->setSubprogram(F->getSubprogram());
F->setSubprogram(nullptr);
NF->setAttributes(AttributeList::get(F->getContext(),
PAL.getFnAttributes(),
PAL.getRetAttributes(),
ArgAttrVec));
ArgAttrVec.clear();
F->getParent()->getFunctionList().insert(F->getIterator(), NF);
NF->takeName(F);

// Iterate over all invocations of `F`, updating all `call` instructions to
// store immediate vector types in a local `alloc` instead of a by-value
// vector.
//
// Like before, much of this is copied from the `ArgumentPromotion` pass in
// LLVM.
SmallVector<Value*, 16> Args;
while (!F->use_empty()) {
CallSite CS(F->user_back());
assert(CS.getCalledFunction() == F);
Instruction *Call = CS.getInstruction();
const AttributeList &CallPAL = CS.getAttributes();

// Loop over the operands, inserting an `alloca` and a store for any
// argument we're demoting to be by reference
//
// FIXME: we probably want to figure out an LLVM pass to run and clean up
// this function and instructions we're generating, we should in theory
// only generate a maximum number of `alloca` instructions rather than
// one-per-variable unconditionally.
CallSite::arg_iterator AI = CS.arg_begin();
size_t ArgNo = 0;
for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E;
++I, ++AI, ++ArgNo) {
if (I->getType()->isVectorTy()) {
AllocaInst *AllocA = new AllocaInst(I->getType(), 0, nullptr, "", Call);
new StoreInst(*AI, AllocA, Call);
Args.push_back(AllocA);
ArgAttrVec.push_back(AttributeSet());
} else {
Args.push_back(*AI);
ArgAttrVec.push_back(CallPAL.getParamAttributes(ArgNo));
}
}
assert(AI == CS.arg_end());

// Create a new call instructions which we'll use to replace the old call
// instruction, copying over as many attributes and such as possible.
SmallVector<OperandBundleDef, 1> OpBundles;
CS.getOperandBundlesAsDefs(OpBundles);

CallSite NewCS;
if (InvokeInst *II = dyn_cast<InvokeInst>(Call)) {
InvokeInst::Create(NF, II->getNormalDest(), II->getUnwindDest(),
Args, OpBundles, "", Call);
} else {
auto *NewCall = CallInst::Create(NF, Args, OpBundles, "", Call);
NewCall->setTailCallKind(cast<CallInst>(Call)->getTailCallKind());
NewCS = NewCall;
}
NewCS.setCallingConv(CS.getCallingConv());
NewCS.setAttributes(
AttributeList::get(F->getContext(), CallPAL.getFnAttributes(),
CallPAL.getRetAttributes(), ArgAttrVec));
NewCS->setDebugLoc(Call->getDebugLoc());
Args.clear();
ArgAttrVec.clear();
Call->replaceAllUsesWith(NewCS.getInstruction());
NewCS->takeName(Call);
Call->eraseFromParent();
}

// Splice the body of the old function right into the new function.
NF->getBasicBlockList().splice(NF->begin(), F->getBasicBlockList());

// Update our new function to replace all uses of the by-value argument with
// loads of the pointer argument we've generated.
//
// FIXME: we probably want to only generate one load instruction per
// function? Or maybe run an LLVM pass to clean up this function?
for (Function::arg_iterator I = F->arg_begin(),
E = F->arg_end(),
I2 = NF->arg_begin();
I != E;
++I, ++I2) {
if (I->getType()->isVectorTy()) {
I->replaceAllUsesWith(new LoadInst(&*I2, "", &NF->begin()->front()));
} else {
I->replaceAllUsesWith(&*I2);
}
I2->takeName(&*I);
}

// Delete all references to the old function, it should be entirely dead
// now.
M->getFunctionList().remove(F);
}
}

#else // LLVM_VERSION_GE(8, 0)
extern "C" void
LLVMRustDemoteSimdArguments(LLVMModuleRef Mod) {
}
#endif // LLVM_VERSION_GE(8, 0)
13 changes: 13 additions & 0 deletions src/test/run-make/simd-argument-promotion-thwarted/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
-include ../../run-make-fulldeps/tools.mk

ifeq ($(TARGET),x86_64-unknown-linux-gnu)
all:
$(RUSTC) t1.rs -C opt-level=3
$(TMPDIR)/t1
$(RUSTC) t2.rs -C opt-level=3
$(TMPDIR)/t2
$(RUSTC) t3.rs -C opt-level=3
$(TMPDIR)/t3
else
all:
endif
21 changes: 21 additions & 0 deletions src/test/run-make/simd-argument-promotion-thwarted/t1.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
use std::arch::x86_64;

fn main() {
if !is_x86_feature_detected!("avx2") {
return println!("AVX2 is not supported on this machine/build.");
}
let load_bytes: [u8; 32] = [0x0f; 32];
let lb_ptr = load_bytes.as_ptr();
let reg_load = unsafe {
x86_64::_mm256_loadu_si256(
lb_ptr as *const x86_64::__m256i
)
};
println!("{:?}", reg_load);
let mut store_bytes: [u8; 32] = [0; 32];
let sb_ptr = store_bytes.as_mut_ptr();
unsafe {
x86_64::_mm256_storeu_si256(sb_ptr as *mut x86_64::__m256i, reg_load);
}
assert_eq!(load_bytes, store_bytes);
}
14 changes: 14 additions & 0 deletions src/test/run-make/simd-argument-promotion-thwarted/t2.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
use std::arch::x86_64::*;

fn main() {
if !is_x86_feature_detected!("avx") {
return println!("AVX is not supported on this machine/build.");
}
unsafe {
let f = _mm256_set_pd(2.0, 2.0, 2.0, 2.0);
let r = _mm256_mul_pd(f, f);

union A { a: __m256d, b: [f64; 4] }
assert_eq!(A { a: r }.b, [4.0, 4.0, 4.0, 4.0]);
}
}
Loading

0 comments on commit b860765

Please sign in to comment.