diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index ed6865300a58..d81652433bf1 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -822,16 +822,34 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
     this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr, local_elem_offset,
                                             smem_ptr, smem_elem_offset);
   } else if (op->op.same_as(builtin::mma_store())) {
+    // args[0] / args[1] are the tile shape (m, n); it selects how many elements of `src`
+    // each thread copies into `dst` and the index pattern used for the copy.
+    int m = Downcast<Integer>(op->args[0])->value;
+    int n = Downcast<Integer>(op->args[1])->value;
     std::string dst = this->PrintExpr(op->args[2]);
     std::string src = this->PrintExpr(op->args[3]);
     std::string src_offset = this->PrintExpr(op->args[4]);
     std::string stride = this->PrintExpr(op->args[5]);
 
-    os << "for (int i = 0; i < 4; ++i) {\n";
-    os << dst << "[(i / 2 * 8 + threadIdx.x / 4) * " << stride
-       << " + (threadIdx.x % 4) * 2 + i % 2]"
-       << " = " << src << "[" << src_offset << " + i];\n";
-    os << "}\n";
+    if (m == 16 && n == 8) {
+      // 4 elements per thread: 2 adjacent columns in row r, then the same columns in row r + 8.
+      os << "for (int i = 0; i < 4; ++i) {\n";
+      os << dst << "[(i / 2 * 8 + threadIdx.x / 4) * " << stride
+         << " + (threadIdx.x % 4) * 2 + i % 2]"
+         << " = " << src << "[" << src_offset << " + i];\n";
+      os << "}\n";
+    } else if (m == 16 && n == 16) {
+      // 8 elements per thread: 4 adjacent columns in row r, then the same columns in row r + 8.
+      os << "for (int i = 0; i < 8; ++i) {\n";
+      os << dst << "[(i / 4 * 8 + threadIdx.x / 4) * " << stride
+         << " + (threadIdx.x % 4) * 4 + i % 4]"
+         << " = " << src << "[" << src_offset << " + i];\n";
+      os << "}\n";
+    } else {
+      // Emitting nothing here would silently drop the store from the generated kernel.
+      LOG(FATAL) << "mma_store only supports m16n8 and m16n16 tiles, but got m = " << m
+                 << ", n = " << n;
+    }
   } else if (op->op.same_as(builtin::mma_fill())) {
     std::string num_elem = this->PrintExpr(op->args[0]);
     std::string dst = this->PrintExpr(op->args[1]);