Skip to content

Commit

Permalink
add 16x8x16 mma store codegen
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 17, 2022
1 parent e08df2a commit ec81250
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions src/target/source/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -822,16 +822,26 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr, local_elem_offset,
smem_ptr, smem_elem_offset);
} else if (op->op.same_as(builtin::mma_store())) {
// The first two operands of mma_store carry the tile shape: args[0] is m and
// args[1] is n. Reading args[1] for both forced m == n, so the
// (m == 16 && n == 8) case could never be selected and no store was emitted.
int m = Downcast<Integer>(op->args[0])->value;
int n = Downcast<Integer>(op->args[1])->value;
std::string dst = this->PrintExpr(op->args[2]);
std::string src = this->PrintExpr(op->args[3]);
std::string src_offset = this->PrintExpr(op->args[4]);
std::string stride = this->PrintExpr(op->args[5]);

os << "for (int i = 0; i < 4; ++i) {\n";
os << dst << "[(i / 2 * 8 + threadIdx.x / 4) * " << stride
<< " + (threadIdx.x % 4) * 2 + i % 2]"
<< " = " << src << "[" << src_offset << " + i];\n";
os << "}\n";
if (m == 16 && n == 8) {
os << "for (int i = 0; i < 4; ++i) {\n";
os << dst << "[(i / 2 * 8 + threadIdx.x / 4) * " << stride
<< " + (threadIdx.x % 4) * 2 + i % 2]"
<< " = " << src << "[" << src_offset << " + i];\n";
os << "}\n";
} else if (m == 16 && n == 16) {
os << "for (int i = 0; i < 8; ++i) {\n";
os << dst << "[(i / 4 * 8 + threadIdx.x / 4) * " << stride
<< " + (threadIdx.x % 4) * 4 + i % 4]"
<< " = " << src << "[" << src_offset << " + i];\n";
os << "}\n";
}
} else if (op->op.same_as(builtin::mma_fill())) {
std::string num_elem = this->PrintExpr(op->args[0]);
std::string dst = this->PrintExpr(op->args[1]);
Expand Down

0 comments on commit ec81250

Please sign in to comment.