Skip to content

Commit

Permalink
Fixed MMA store codegen for 16x8x16
Browse files: browse the repository at this point in the history
  • Loading branch information
masahi committed May 17, 2022
1 parent ec81250 commit 18e8d73
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions src/target/source/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -836,10 +836,12 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
<< " = " << src << "[" << src_offset << " + i];\n";
os << "}\n";
} else if (m == 16 && n == 16) {
os << "for (int i = 0; i < 8; ++i) {\n";
os << dst << "[(i / 4 * 8 + threadIdx.x / 4) * " << stride
<< " + (threadIdx.x % 4) * 4 + i % 4]"
<< " = " << src << "[" << src_offset << " + i];\n";
os << "for (int outer = 0; outer < 2; ++outer) {\n";
os << "for (int i = 0; i < 4; ++i) {\n";
os << dst << "[(i / 2 * 8 + threadIdx.x / 4) * " << stride
<< " + outer * 8 + (threadIdx.x % 4) * 2 + i % 2]"
<< " = " << src << "[" << src_offset << " + i * outer * 4];\n";
os << "}\n";
os << "}\n";
}
} else if (op->op.same_as(builtin::mma_fill())) {
Expand All @@ -848,7 +850,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
std::string dst_offset = this->PrintExpr(op->args[2]);

os << "for (int i = 0; i < " << num_elem << "; ++i) {\n";
os << dst << "[" << dst_offset << " + i] = 0.0;" ;
os << dst << "[" << dst_offset << " + i] = 0.0;";
os << "}\n";
} else {
CodeGenC::VisitExpr_(op, os);
Expand Down

0 comments on commit 18e8d73

Please sign in to comment.