diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index d81652433bf1..31528c90a06f 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -836,10 +836,15 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
          << " = " << src << "[" << src_offset << " + i];\n";
       os << "}\n";
     } else if (m == 16 && n == 16) {
-      os << "for (int i = 0; i < 8; ++i) {\n";
-      os << dst << "[(i / 4 * 8 + threadIdx.x / 4) * " << stride
-         << " + (threadIdx.x % 4) * 4 + i % 4]"
-         << " = " << src << "[" << src_offset << " + i];\n";
+      // Emit a 2 x 4 loop nest: `outer` selects the 8-column half of the
+      // destination row and `i` the element within that half. The source index
+      // is therefore outer * 4 + i, visiting src_offset + 0..7 exactly once.
+      os << "for (int outer = 0; outer < 2; ++outer) {\n";
+      os << "for (int i = 0; i < 4; ++i) {\n";
+      os << dst << "[(i / 2 * 8 + threadIdx.x / 4) * " << stride
+         << " + outer * 8 + (threadIdx.x % 4) * 2 + i % 2]"
+         << " = " << src << "[" << src_offset << " + outer * 4 + i];\n";
+      os << "}\n";
       os << "}\n";
     }
   } else if (op->op.same_as(builtin::mma_fill())) {
@@ -848,7 +853,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
     std::string dst_offset = this->PrintExpr(op->args[2]);
 
     os << "for (int i = 0; i < " << num_elem << "; ++i) {\n";
-    os << dst << "[" << dst_offset << " + i] = 0.0;" ;
+    os << dst << "[" << dst_offset << " + i] = 0.0;";
     os << "}\n";
   } else {
     CodeGenC::VisitExpr_(op, os);