Skip to content

Commit

Permalink
update c++ codegen
Browse files (browse the repository at this point in the history)
  • Loading branch information
masahi committed Feb 4, 2022
1 parent 6206e38 commit 4a383e2
Showing 1 changed file with 15 additions and 23 deletions.
38 changes: 15 additions & 23 deletions src/relay/backend/contrib/cutlass/codegen.cc
Original file line number | Diff line number | Diff line change
Expand Up @@ -400,8 +400,14 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
CutlassPrint(conv2d_decl, " problem_size,\n");
CutlassPrint(conv2d_decl, " {static_cast<ElementInputA*>(ptr_a), layout_A},\n");
CutlassPrint(conv2d_decl, " {static_cast<ElementInputB*>(ptr_b), layout_B},\n");
CutlassPrint(conv2d_decl, " tensor_c,\n");
CutlassPrint(conv2d_decl, " tensor_d,\n");

if (use_split_k) {
CutlassPrint(conv2d_decl, "{nullptr, TensorNHWC()},\n");
CutlassPrint(conv2d_decl, "{nullptr, TensorNHWC()},\n");
} else {
CutlassPrint(conv2d_decl, " tensor_c,\n");
CutlassPrint(conv2d_decl, " tensor_d,\n");
}

if (has_residual_block) {
ICHECK(use_split_k == false) << "Split-k not supported for residual block fusion";
Expand All @@ -426,13 +432,18 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
// Check the problem size is supported or not
CutlassPrint(conv2d_decl, "cutlass::Status status = conv2d_op.can_implement(arguments);\n");
CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n");

if (use_split_k) {
CutlassPrint(
conv2d_decl,
"arguments.ref_D.reset(reinterpret_cast<ElementCompute*>(workspace.get()), layout_D);\n");
}

// Initialize CUTLASS kernel with arguments and workspace pointer
CutlassPrint(conv2d_decl, "status = conv2d_op.initialize(arguments, workspace.get());\n");
CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n");

if (use_split_k) {
CutlassPrint(conv2d_decl,
"\narguments.ref_D.reset(reinterpret_cast<ElementOutput*>(workspace.get())); \n");
CutlassPrint(
conv2d_decl,
"arguments.output_op = {ElementComputeEpilogue(1), ElementComputeEpilogue(0)}; \n");
Expand All @@ -445,25 +456,6 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n");

if (use_split_k) {
CutlassPrint(conv2d_decl, "\nusing EpilogueOutputOp = Conv2d::EpilogueOutputOp;\n");
CutlassPrint(conv2d_decl, "using ReductionOp = cutlass::reduction::thread::ReduceAdd<\n");
CutlassPrint(conv2d_decl, " Conv2d::ElementAccumulator,\n");
CutlassPrint(conv2d_decl, " typename EpilogueOutputOp::ElementAccumulator,\n");
CutlassPrint(conv2d_decl, " EpilogueOutputOp::kCount\n");
CutlassPrint(conv2d_decl, " >;\n");
CutlassPrint(conv2d_decl, "\n");
CutlassPrint(conv2d_decl,
"using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK<\n");
CutlassPrint(conv2d_decl, " cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>,\n");
CutlassPrint(conv2d_decl, " EpilogueOutputOp,\n");
CutlassPrint(conv2d_decl, " ReductionOp\n");
CutlassPrint(conv2d_decl, " >;\n");
CutlassPrint(conv2d_decl, "\n");
CutlassPrint(
conv2d_decl,
"using ReductionDevice = cutlass::reduction::device::ReduceSplitK<ReductionKernel>;\n");
CutlassPrint(conv2d_decl,
"using ReductionStrideIndex = typename ReductionDevice::StrideIndex;\n");
CutlassPrint(conv2d_decl, " ReductionDevice reduction_op;\n");
CutlassPrint(conv2d_decl,
" const static cutlass::conv::Operator kConvolutionalOperator = "
Expand Down

0 comments on commit 4a383e2

Please sign in to comment.