
[NVPTX] deprecate nvvm.rotate.* intrinsics, cleanup funnel-shift handling #107655

Open
wants to merge 3 commits into base: main
Conversation

AlexMaclean
Member

This change deprecates the following intrinsics which can be trivially converted to llvm funnel-shift intrinsics:

  • @llvm.nvvm.rotate.b32
  • @llvm.nvvm.rotate.right.b64
  • @llvm.nvvm.rotate.b64
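The conversion is trivial because a rotate is just a funnel shift with both inputs equal: `rotl(x, n)` is `fshl(x, x, n)` and `rotr(x, n)` is `fshr(x, x, n)`. A small Python model of the `llvm.fshl`/`llvm.fshr` semantics illustrates the equivalence (this is an explanatory sketch, not LLVM code; the real intrinsics operate on fixed-width integers):

```python
def fshl(a: int, b: int, n: int, bits: int) -> int:
    """llvm.fshl semantics: form the double-wide value a:b, shift it left
    by n (taken modulo the bit width), and return the high half."""
    n %= bits
    mask = (1 << bits) - 1
    return ((((a & mask) << bits) | (b & mask)) << n >> bits) & mask

def fshr(a: int, b: int, n: int, bits: int) -> int:
    """llvm.fshr semantics: form the double-wide value a:b, shift it right
    by n (taken modulo the bit width), and return the low half."""
    n %= bits
    mask = (1 << bits) - 1
    return ((((a & mask) << bits) | (b & mask)) >> n) & mask

def rotl(x: int, n: int, bits: int) -> int:
    """Rotate left -- what nvvm.rotate.b32 / nvvm.rotate.b64 compute."""
    n %= bits
    mask = (1 << bits) - 1
    return ((x << n) | ((x & mask) >> (bits - n))) & mask
```

For the 64-bit variants, the legacy intrinsics take an i32 shift amount while `llvm.fshl`/`llvm.fshr` require all three operands to have the same type, so the upgrade also zero-extends the amount to i64 (the `CreateZExt` calls in the patch).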

@llvmbot
Collaborator

llvmbot commented Sep 6, 2024

@llvm/pr-subscribers-backend-nvptx

@llvm/pr-subscribers-llvm-ir

Author: Alex MacLean (AlexMaclean)

Changes

This change deprecates the following intrinsics which can be trivially converted to llvm funnel-shift intrinsics:

  • @llvm.nvvm.rotate.b32
  • @llvm.nvvm.rotate.right.b64
  • @llvm.nvvm.rotate.b64

Patch is 39.52 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/107655.diff

6 Files Affected:

  • (modified) llvm/include/llvm/IR/IntrinsicsNVVM.td (-16)
  • (modified) llvm/lib/IR/AutoUpgrade.cpp (+103-75)
  • (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+218-161)
  • (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+2-127)
  • (modified) llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll (+16)
  • (modified) llvm/test/CodeGen/NVPTX/rotate.ll (+36-60)
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 39685c920d948d..1b94c56ca9b828 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -4489,22 +4489,6 @@ def int_nvvm_sust_p_3d_v4i32_trap
               "llvm.nvvm.sust.p.3d.v4i32.trap">,
     ClangBuiltin<"__nvvm_sust_p_3d_v4i32_trap">;
 
-
-def int_nvvm_rotate_b32
-  : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty],
-              [IntrNoMem, IntrSpeculatable], "llvm.nvvm.rotate.b32">,
-              ClangBuiltin<"__nvvm_rotate_b32">;
-
-def int_nvvm_rotate_b64
-  : DefaultAttrsIntrinsic<[llvm_i64_ty], [llvm_i64_ty, llvm_i32_ty],
-             [IntrNoMem, IntrSpeculatable], "llvm.nvvm.rotate.b64">,
-             ClangBuiltin<"__nvvm_rotate_b64">;
-
-def int_nvvm_rotate_right_b64
-  : DefaultAttrsIntrinsic<[llvm_i64_ty], [llvm_i64_ty, llvm_i32_ty],
-              [IntrNoMem, IntrSpeculatable], "llvm.nvvm.rotate.right.b64">,
-              ClangBuiltin<"__nvvm_rotate_right_b64">;
-
 def int_nvvm_swap_lo_hi_b64
   : DefaultAttrsIntrinsic<[llvm_i64_ty], [llvm_i64_ty],
               [IntrNoMem, IntrSpeculatable], "llvm.nvvm.swap.lo.hi.b64">,
diff --git a/llvm/lib/IR/AutoUpgrade.cpp b/llvm/lib/IR/AutoUpgrade.cpp
index 69dae5e32dbbe8..2b8e93dff46684 100644
--- a/llvm/lib/IR/AutoUpgrade.cpp
+++ b/llvm/lib/IR/AutoUpgrade.cpp
@@ -1268,6 +1268,9 @@ static bool upgradeIntrinsicFunction1(Function *F, Function *&NewFn,
       else if (Name.consume_front("atomic.load.add."))
         // nvvm.atomic.load.add.{f32.p,f64.p}
         Expand = Name.starts_with("f32.p") || Name.starts_with("f64.p");
+      else if (Name.consume_front("rotate."))
+        // nvvm.rotate.{b32,b64,right.b64}
+        Expand = Name == "b32" || Name == "b64" || Name == "right.b64";
       else
         Expand = false;
 
@@ -2254,6 +2257,104 @@ void llvm::UpgradeInlineAsmString(std::string *AsmStr) {
   }
 }
 
+static Value *upgradeNVVMIntrinsicCall(StringRef Name, CallBase *CI,
+                                       Function *F, IRBuilder<> &Builder) {
+  Value *Rep = nullptr;
+
+  if (Name == "abs.i" || Name == "abs.ll") {
+    Value *Arg = CI->getArgOperand(0);
+    Value *Neg = Builder.CreateNeg(Arg, "neg");
+    Value *Cmp = Builder.CreateICmpSGE(
+        Arg, llvm::Constant::getNullValue(Arg->getType()), "abs.cond");
+    Rep = Builder.CreateSelect(Cmp, Arg, Neg, "abs");
+  } else if (Name.starts_with("atomic.load.add.f32.p") ||
+             Name.starts_with("atomic.load.add.f64.p")) {
+    Value *Ptr = CI->getArgOperand(0);
+    Value *Val = CI->getArgOperand(1);
+    Rep = Builder.CreateAtomicRMW(AtomicRMWInst::FAdd, Ptr, Val, MaybeAlign(),
+                                  AtomicOrdering::SequentiallyConsistent);
+  } else if (Name.consume_front("max.") &&
+             (Name == "s" || Name == "i" || Name == "ll" || Name == "us" ||
+              Name == "ui" || Name == "ull")) {
+    Value *Arg0 = CI->getArgOperand(0);
+    Value *Arg1 = CI->getArgOperand(1);
+    Value *Cmp = Name.starts_with("u")
+                     ? Builder.CreateICmpUGE(Arg0, Arg1, "max.cond")
+                     : Builder.CreateICmpSGE(Arg0, Arg1, "max.cond");
+    Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "max");
+  } else if (Name.consume_front("min.") &&
+             (Name == "s" || Name == "i" || Name == "ll" || Name == "us" ||
+              Name == "ui" || Name == "ull")) {
+    Value *Arg0 = CI->getArgOperand(0);
+    Value *Arg1 = CI->getArgOperand(1);
+    Value *Cmp = Name.starts_with("u")
+                     ? Builder.CreateICmpULE(Arg0, Arg1, "min.cond")
+                     : Builder.CreateICmpSLE(Arg0, Arg1, "min.cond");
+    Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "min");
+  } else if (Name == "clz.ll") {
+    // llvm.nvvm.clz.ll returns an i32, but llvm.ctlz.i64 returns an i64.
+    Value *Arg = CI->getArgOperand(0);
+    Value *Ctlz = Builder.CreateCall(
+        Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctlz,
+                                  {Arg->getType()}),
+        {Arg, Builder.getFalse()}, "ctlz");
+    Rep = Builder.CreateTrunc(Ctlz, Builder.getInt32Ty(), "ctlz.trunc");
+  } else if (Name == "popc.ll") {
+    // llvm.nvvm.popc.ll returns an i32, but llvm.ctpop.i64 returns an
+    // i64.
+    Value *Arg = CI->getArgOperand(0);
+    Value *Popc = Builder.CreateCall(
+        Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctpop,
+                                  {Arg->getType()}),
+        Arg, "ctpop");
+    Rep = Builder.CreateTrunc(Popc, Builder.getInt32Ty(), "ctpop.trunc");
+  } else if (Name == "h2f") {
+    Rep = Builder.CreateCall(
+        Intrinsic::getDeclaration(F->getParent(), Intrinsic::convert_from_fp16,
+                                  {Builder.getFloatTy()}),
+        CI->getArgOperand(0), "h2f");
+  } else if (Name == "rotate.b32") {
+    Value *Arg = CI->getOperand(0);
+    Value *ShiftAmt = CI->getOperand(1);
+    Rep = Builder.CreateIntrinsic(Builder.getInt32Ty(), Intrinsic::fshl,
+                                  {Arg, Arg, ShiftAmt});
+  } else if (Name == "rotate.b64") {
+    Type *Int64Ty = Builder.getInt64Ty();
+    Value *Arg = CI->getOperand(0);
+    Value *ZExtShiftAmt = Builder.CreateZExt(CI->getOperand(1), Int64Ty);
+    Rep = Builder.CreateIntrinsic(Int64Ty, Intrinsic::fshl,
+                                  {Arg, Arg, ZExtShiftAmt});
+  } else if (Name == "rotate.right.b64") {
+    Type *Int64Ty = Builder.getInt64Ty();
+    Value *Arg = CI->getOperand(0);
+    Value *ZExtShiftAmt = Builder.CreateZExt(CI->getOperand(1), Int64Ty);
+    Rep = Builder.CreateIntrinsic(Int64Ty, Intrinsic::fshr,
+                                  {Arg, Arg, ZExtShiftAmt});
+  } else {
+    Intrinsic::ID IID = shouldUpgradeNVPTXBF16Intrinsic(Name);
+    if (IID != Intrinsic::not_intrinsic &&
+        !F->getReturnType()->getScalarType()->isBFloatTy()) {
+      rename(F);
+      Function *NewFn = Intrinsic::getDeclaration(F->getParent(), IID);
+      SmallVector<Value *, 2> Args;
+      for (size_t I = 0; I < NewFn->arg_size(); ++I) {
+        Value *Arg = CI->getArgOperand(I);
+        Type *OldType = Arg->getType();
+        Type *NewType = NewFn->getArg(I)->getType();
+        Args.push_back(
+            (OldType->isIntegerTy() && NewType->getScalarType()->isBFloatTy())
+                ? Builder.CreateBitCast(Arg, NewType)
+                : Arg);
+      }
+      Rep = Builder.CreateCall(NewFn, Args);
+      if (F->getReturnType()->isIntegerTy())
+        Rep = Builder.CreateBitCast(Rep, F->getReturnType());
+    }
+  }
+
+  return Rep;
+}
+
 static Value *upgradeX86IntrinsicCall(StringRef Name, CallBase *CI, Function *F,
                                       IRBuilder<> &Builder) {
   LLVMContext &C = F->getContext();
@@ -4204,81 +4305,8 @@ void llvm::UpgradeIntrinsicCall(CallBase *CI, Function *NewFn) {
 
     if (!IsX86 && Name == "stackprotectorcheck") {
       Rep = nullptr;
-    } else if (IsNVVM && (Name == "abs.i" || Name == "abs.ll")) {
-      Value *Arg = CI->getArgOperand(0);
-      Value *Neg = Builder.CreateNeg(Arg, "neg");
-      Value *Cmp = Builder.CreateICmpSGE(
-          Arg, llvm::Constant::getNullValue(Arg->getType()), "abs.cond");
-      Rep = Builder.CreateSelect(Cmp, Arg, Neg, "abs");
-    } else if (IsNVVM && (Name.starts_with("atomic.load.add.f32.p") ||
-                          Name.starts_with("atomic.load.add.f64.p"))) {
-      Value *Ptr = CI->getArgOperand(0);
-      Value *Val = CI->getArgOperand(1);
-      Rep = Builder.CreateAtomicRMW(AtomicRMWInst::FAdd, Ptr, Val, MaybeAlign(),
-                                    AtomicOrdering::SequentiallyConsistent);
-    } else if (IsNVVM && Name.consume_front("max.") &&
-               (Name == "s" || Name == "i" || Name == "ll" || Name == "us" ||
-                Name == "ui" || Name == "ull")) {
-      Value *Arg0 = CI->getArgOperand(0);
-      Value *Arg1 = CI->getArgOperand(1);
-      Value *Cmp = Name.starts_with("u")
-                       ? Builder.CreateICmpUGE(Arg0, Arg1, "max.cond")
-                       : Builder.CreateICmpSGE(Arg0, Arg1, "max.cond");
-      Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "max");
-    } else if (IsNVVM && Name.consume_front("min.") &&
-               (Name == "s" || Name == "i" || Name == "ll" || Name == "us" ||
-                Name == "ui" || Name == "ull")) {
-      Value *Arg0 = CI->getArgOperand(0);
-      Value *Arg1 = CI->getArgOperand(1);
-      Value *Cmp = Name.starts_with("u")
-                       ? Builder.CreateICmpULE(Arg0, Arg1, "min.cond")
-                       : Builder.CreateICmpSLE(Arg0, Arg1, "min.cond");
-      Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "min");
-    } else if (IsNVVM && Name == "clz.ll") {
-      // llvm.nvvm.clz.ll returns an i32, but llvm.ctlz.i64 returns an i64.
-      Value *Arg = CI->getArgOperand(0);
-      Value *Ctlz = Builder.CreateCall(
-          Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctlz,
-                                    {Arg->getType()}),
-          {Arg, Builder.getFalse()}, "ctlz");
-      Rep = Builder.CreateTrunc(Ctlz, Builder.getInt32Ty(), "ctlz.trunc");
-    } else if (IsNVVM && Name == "popc.ll") {
-      // llvm.nvvm.popc.ll returns an i32, but llvm.ctpop.i64 returns an
-      // i64.
-      Value *Arg = CI->getArgOperand(0);
-      Value *Popc = Builder.CreateCall(
-          Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctpop,
-                                    {Arg->getType()}),
-          Arg, "ctpop");
-      Rep = Builder.CreateTrunc(Popc, Builder.getInt32Ty(), "ctpop.trunc");
-    } else if (IsNVVM) {
-      if (Name == "h2f") {
-        Rep =
-            Builder.CreateCall(Intrinsic::getDeclaration(
-                                   F->getParent(), Intrinsic::convert_from_fp16,
-                                   {Builder.getFloatTy()}),
-                               CI->getArgOperand(0), "h2f");
-      } else {
-        Intrinsic::ID IID = shouldUpgradeNVPTXBF16Intrinsic(Name);
-        if (IID != Intrinsic::not_intrinsic &&
-            !F->getReturnType()->getScalarType()->isBFloatTy()) {
-          rename(F);
-          NewFn = Intrinsic::getDeclaration(F->getParent(), IID);
-          SmallVector<Value *, 2> Args;
-          for (size_t I = 0; I < NewFn->arg_size(); ++I) {
-            Value *Arg = CI->getArgOperand(I);
-            Type *OldType = Arg->getType();
-            Type *NewType = NewFn->getArg(I)->getType();
-            Args.push_back((OldType->isIntegerTy() &&
-                            NewType->getScalarType()->isBFloatTy())
-                               ? Builder.CreateBitCast(Arg, NewType)
-                               : Arg);
-          }
-          Rep = Builder.CreateCall(NewFn, Args);
-          if (F->getReturnType()->isIntegerTy())
-            Rep = Builder.CreateBitCast(Rep, F->getReturnType());
-        }
-      }
+    } else if (IsNVVM) {
+      Rep = upgradeNVVMIntrinsicCall(Name, CI, F, Builder);
     } else if (IsX86) {
       Rep = upgradeX86IntrinsicCall(Name, CI, F, Builder);
     } else if (IsARM) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index b7e210805db904..0d7e2d4a98d88b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -1665,167 +1665,6 @@ def BREV64 :
              "brev.b64 \t$dst, $a;",
              [(set Int64Regs:$dst, (bitreverse Int64Regs:$a))]>;
 
-//
-// Rotate: Use ptx shf instruction if available.
-//
-
-// 32 bit r2 = rotl r1, n
-//    =>
-//        r2 = shf.l r1, r1, n
-def ROTL32imm_hw :
-  NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, i32imm:$amt),
-            "shf.l.wrap.b32 \t$dst, $src, $src, $amt;",
-            [(set Int32Regs:$dst, (rotl (i32 Int32Regs:$src), (i32 imm:$amt)))]>,
-           Requires<[hasHWROT32]>;
-
-def ROTL32reg_hw :
-  NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$amt),
-            "shf.l.wrap.b32 \t$dst, $src, $src, $amt;",
-            [(set Int32Regs:$dst, (rotl (i32 Int32Regs:$src), (i32 Int32Regs:$amt)))]>,
-           Requires<[hasHWROT32]>;
-
-// 32 bit r2 = rotr r1, n
-//    =>
-//        r2 = shf.r r1, r1, n
-def ROTR32imm_hw :
-  NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, i32imm:$amt),
-            "shf.r.wrap.b32 \t$dst, $src, $src, $amt;",
-            [(set Int32Regs:$dst, (rotr (i32 Int32Regs:$src), (i32 imm:$amt)))]>,
-           Requires<[hasHWROT32]>;
-
-def ROTR32reg_hw :
-  NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$amt),
-            "shf.r.wrap.b32 \t$dst, $src, $src, $amt;",
-            [(set Int32Regs:$dst, (rotr (i32 Int32Regs:$src), (i32 Int32Regs:$amt)))]>,
-           Requires<[hasHWROT32]>;
-
-// 32-bit software rotate by immediate.  $amt2 should equal 32 - $amt1.
-def ROT32imm_sw :
-  NVPTXInst<(outs Int32Regs:$dst),
-            (ins Int32Regs:$src, i32imm:$amt1, i32imm:$amt2),
-            "{{\n\t"
-            ".reg .b32 %lhs;\n\t"
-            ".reg .b32 %rhs;\n\t"
-            "shl.b32 \t%lhs, $src, $amt1;\n\t"
-            "shr.b32 \t%rhs, $src, $amt2;\n\t"
-            "add.u32 \t$dst, %lhs, %rhs;\n\t"
-            "}}",
-            []>;
-
-def SUB_FRM_32 : SDNodeXForm<imm, [{
-  return CurDAG->getTargetConstant(32 - N->getZExtValue(), SDLoc(N), MVT::i32);
-}]>;
-
-def : Pat<(rotl (i32 Int32Regs:$src), (i32 imm:$amt)),
-          (ROT32imm_sw Int32Regs:$src, imm:$amt, (SUB_FRM_32 node:$amt))>,
-      Requires<[noHWROT32]>;
-def : Pat<(rotr (i32 Int32Regs:$src), (i32 imm:$amt)),
-          (ROT32imm_sw Int32Regs:$src, (SUB_FRM_32 node:$amt), imm:$amt)>,
-      Requires<[noHWROT32]>;
-
-// 32-bit software rotate left by register.
-def ROTL32reg_sw :
-  NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$amt),
-            "{{\n\t"
-            ".reg .b32 %lhs;\n\t"
-            ".reg .b32 %rhs;\n\t"
-            ".reg .b32 %amt2;\n\t"
-            "shl.b32 \t%lhs, $src, $amt;\n\t"
-            "sub.s32 \t%amt2, 32, $amt;\n\t"
-            "shr.b32 \t%rhs, $src, %amt2;\n\t"
-            "add.u32 \t$dst, %lhs, %rhs;\n\t"
-            "}}",
-            [(set Int32Regs:$dst, (rotl (i32 Int32Regs:$src), (i32 Int32Regs:$amt)))]>,
-           Requires<[noHWROT32]>;
-
-// 32-bit software rotate right by register.
-def ROTR32reg_sw :
-  NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$amt),
-            "{{\n\t"
-            ".reg .b32 %lhs;\n\t"
-            ".reg .b32 %rhs;\n\t"
-            ".reg .b32 %amt2;\n\t"
-            "shr.b32 \t%lhs, $src, $amt;\n\t"
-            "sub.s32 \t%amt2, 32, $amt;\n\t"
-            "shl.b32 \t%rhs, $src, %amt2;\n\t"
-            "add.u32 \t$dst, %lhs, %rhs;\n\t"
-            "}}",
-            [(set Int32Regs:$dst, (rotr (i32 Int32Regs:$src), (i32 Int32Regs:$amt)))]>,
-           Requires<[noHWROT32]>;
-
-// 64-bit software rotate by immediate.  $amt2 should equal 64 - $amt1.
-def ROT64imm_sw :
-  NVPTXInst<(outs Int64Regs:$dst),
-            (ins Int64Regs:$src, i32imm:$amt1, i32imm:$amt2),
-            "{{\n\t"
-            ".reg .b64 %lhs;\n\t"
-            ".reg .b64 %rhs;\n\t"
-            "shl.b64 \t%lhs, $src, $amt1;\n\t"
-            "shr.b64 \t%rhs, $src, $amt2;\n\t"
-            "add.u64 \t$dst, %lhs, %rhs;\n\t"
-            "}}",
-            []>;
-
-def SUB_FRM_64 : SDNodeXForm<imm, [{
-    return CurDAG->getTargetConstant(64-N->getZExtValue(), SDLoc(N), MVT::i32);
-}]>;
-
-def : Pat<(rotl Int64Regs:$src, (i32 imm:$amt)),
-          (ROT64imm_sw Int64Regs:$src, imm:$amt, (SUB_FRM_64 node:$amt))>;
-def : Pat<(rotr Int64Regs:$src, (i32 imm:$amt)),
-          (ROT64imm_sw Int64Regs:$src, (SUB_FRM_64 node:$amt), imm:$amt)>;
-
-// 64-bit software rotate left by register.
-def ROTL64reg_sw :
-  NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src, Int32Regs:$amt),
-            "{{\n\t"
-            ".reg .b64 %lhs;\n\t"
-            ".reg .b64 %rhs;\n\t"
-            ".reg .u32 %amt2;\n\t"
-            "and.b32 \t%amt2, $amt, 63;\n\t"
-            "shl.b64 \t%lhs, $src, %amt2;\n\t"
-            "sub.u32 \t%amt2, 64, %amt2;\n\t"
-            "shr.b64 \t%rhs, $src, %amt2;\n\t"
-            "add.u64 \t$dst, %lhs, %rhs;\n\t"
-            "}}",
-            [(set Int64Regs:$dst, (rotl Int64Regs:$src, (i32 Int32Regs:$amt)))]>;
-
-def ROTR64reg_sw :
-  NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src, Int32Regs:$amt),
-            "{{\n\t"
-            ".reg .b64 %lhs;\n\t"
-            ".reg .b64 %rhs;\n\t"
-            ".reg .u32 %amt2;\n\t"
-            "and.b32 \t%amt2, $amt, 63;\n\t"
-            "shr.b64 \t%lhs, $src, %amt2;\n\t"
-            "sub.u32 \t%amt2, 64, %amt2;\n\t"
-            "shl.b64 \t%rhs, $src, %amt2;\n\t"
-            "add.u64 \t$dst, %lhs, %rhs;\n\t"
-            "}}",
-            [(set Int64Regs:$dst, (rotr Int64Regs:$src, (i32 Int32Regs:$amt)))]>;
-
-//
-// Funnnel shift in clamp mode
-//
-
-// Create SDNodes so they can be used in the DAG code, e.g.
-// NVPTXISelLowering (LowerShiftLeftParts and LowerShiftRightParts)
-def FUN_SHFL_CLAMP : SDNode<"NVPTXISD::FUN_SHFL_CLAMP", SDTIntShiftDOp, []>;
-def FUN_SHFR_CLAMP : SDNode<"NVPTXISD::FUN_SHFR_CLAMP", SDTIntShiftDOp, []>;
-
-def FUNSHFLCLAMP :
-  NVPTXInst<(outs Int32Regs:$dst),
-            (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt),
-            "shf.l.clamp.b32 \t$dst, $lo, $hi, $amt;",
-            [(set Int32Regs:$dst,
-              (FUN_SHFL_CLAMP (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 Int32Regs:$amt)))]>;
-
-def FUNSHFRCLAMP :
-  NVPTXInst<(outs Int32Regs:$dst),
-            (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt),
-            "shf.r.clamp.b32 \t$dst, $lo, $hi, $amt;",
-            [(set Int32Regs:$dst,
-             (FUN_SHFR_CLAMP (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 Int32Regs:$amt)))]>;
 
 //
 // BFE - bit-field extract
@@ -3651,6 +3490,224 @@ def : Pat<(v2i16 (build_vector (i16 Int16Regs:$a), (i16 Int16Regs:$b))),
 def: Pat<(v2i16 (scalar_to_vector (i16 Int16Regs:$a))),
          (CVT_u32_u16 Int16Regs:$a, CvtNONE)>;
 
+//
+// Rotate: Use ptx shf instruction if available.
+//
+
+// Create SDNodes so they can be used in the DAG code, e.g.
+// NVPTXISelLowering (LowerShiftLeftParts and LowerShiftRightParts)
+def FUN_SHFL_CLAMP : SDNode<"NVPTXISD::FUN_SHFL_CLAMP", SDTIntShiftDOp, []>;
+def FUN_SHFR_CLAMP : SDNode<"NVPTXISD::FUN_SHFR_CLAMP", SDTIntShiftDOp, []>;
+
+// Funnel shift, requires >= sm_32.  Does not trap if amt is out of range, so
+// no side effects.
+let hasSideEffects = false in {
+
+  def SHF_L_CLAMP_B32_REG :
+    NVPTXInst<(outs Int32Regs:$dst),
+              (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt),
+              "shf.l.clamp.b32 \t$dst, $lo, $hi, $amt;",
+              [(set Int32Regs:$dst,
+                (FUN_SHFL_CLAMP (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 Int32Regs:$amt)))]>,
+    Requires<[hasHWROT32]>;
+
+  def SHF_R_CLAMP_B32_REG :
+    NVPTXInst<(outs Int32Regs:$dst),
+              (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt),
+              "shf.r.clamp.b32 \t$dst, $lo, $hi, $amt;",
+              [(set Int32Regs:$dst,
+              (FUN_SHFR_CLAMP (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 Int32Regs:$amt)))]>,
+    Requires<[hasHWROT32]>;
+
+  def SHF_L_WRAP_B32_IMM
+    : NVPTXInst<(outs Int32Regs:$dst),
+                (ins  Int32Regs:$lo, Int32Regs:$hi, i32imm:$amt),
+                "shf.l.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>,
+      Requires<[hasHWROT32]>;
+
+  def SHF_L_WRAP_B32_REG
+    : NVPTXInst<(outs Int32Regs:$dst),
+                (ins  Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt),
+                "shf.l.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>,
+      Requires<[hasHWROT32]>;
+
+  def SHF_R_WRAP_B32_IMM
+    : NVPTXInst<(outs Int32Regs:$dst),
+                (ins  Int32Regs:$lo, Int32Regs:$hi, i32imm:$amt),
+                "shf.r.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>,
+      Requires<[hasHWROT32]>;
+
+  def SHF_R_WRAP_B32_REG
+    : NVPTXInst<(outs Int32Regs:$dst),
+                (ins  Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt),
+                "shf.r.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>,
+      Requires<[hasHWROT32]>;
+}
+
+// 32 bit r2 = rotl r1, n
+//    =>
+//        r2 = shf.l r1, r1, n
+def : Pat<(rotl (i32 Int32Regs:$src), (i32 imm:$amt)),
+          (SHF_L_WRAP_B32_IMM Int32Regs:$src, Int32Regs:$src, imm:$amt)>,
+   ...
[truncated]


github-actions bot commented Sep 6, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@AlexMaclean AlexMaclean force-pushed the dev/amaclean/upstream-remove-rotate branch 2 times, most recently from c91d6ff to 88bc7ef on September 12, 2024
@AlexMaclean
Member Author

@Artem-B / @jlebar ping

@jlebar
Member

jlebar commented Sep 12, 2024 via email

@AlexMaclean
Member Author

@Artem-B ping

Member

@Artem-B Artem-B left a comment


Typo in the patch title: handeling -> handling.

LGTM otherwise.

@AlexMaclean AlexMaclean changed the title [NVPTX] deprecate nvvm.rotate.* intrinsics, cleanup funnel-shift handeling [NVPTX] deprecate nvvm.rotate.* intrinsics, cleanup funnel-shift handling Sep 20, 2024
@AlexMaclean AlexMaclean force-pushed the dev/amaclean/upstream-remove-rotate branch from 88bc7ef to 9a0201c on September 21, 2024
@AlexMaclean
Member Author

@Artem-B, I took a deeper look at this and realized the way we were lowering 64-bit rotation isn't actually correct.

Here's an Alive2 proof in LLVM IR showing the expansion is incorrect: https://alive2.llvm.org/ce/z/RQVDDG

I've moved to just using the SelectionDAG expansion for all rotate instructions; this will properly convert to a 32-bit funnel shift when appropriate and expand to a valid multi-instruction sequence otherwise.

Could you please take another look?
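For context on why ad-hoc IR-level rotate expansions are easy to get wrong (the precise counterexample is in the Alive2 link above): in LLVM IR, a shift by an amount greater than or equal to the bit width yields poison, so expanding `rotl(x, n)` as `shl(x, n) | lshr(x, 64 - n)` misbehaves at `n == 0`, whereas `llvm.fshl` takes its shift amount modulo the width and is defined for all amounts. A hedged Python sketch of that edge case, modeling poison as `None` (illustrative only; this is not necessarily the exact bug the proof exhibits):

```python
BITS = 64
MASK = (1 << BITS) - 1

def shl(x: int, k: int):
    """LLVM shl: poison (modeled as None) when the shift amount >= bit width."""
    return None if k >= BITS else (x << k) & MASK

def lshr(x: int, k: int):
    """LLVM lshr: poison (modeled as None) when the shift amount >= bit width."""
    return None if k >= BITS else (x & MASK) >> k

def naive_rotl(x: int, n: int):
    """Naive expansion shl(x, n) | lshr(x, 64 - n): poison when n == 0."""
    lo, hi = shl(x, n), lshr(x, BITS - n)
    return None if lo is None or hi is None else lo | hi

def fshl_rotl(x: int, n: int) -> int:
    """fshl(x, x, n): the amount is reduced modulo 64, so it is always defined."""
    n %= BITS
    return ((x << n) | ((x & MASK) >> (BITS - n))) & MASK
```

This is why deferring to the generic SelectionDAG expansion, which knows the funnel-shift semantics, is safer than hand-rolled shift-and-or patterns.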

4 participants