Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CUDA] Support multiple TIR-level dynamic shared memory allocations #8571

Merged
merged 21 commits into from
Jul 31, 2021
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,11 @@ TVM_DLL Pass LegalizePackedCalls();
*/
TVM_DLL Pass FlattenBuffer();

/*!
* A pass to merge multiple TIR-level dynamic shared memory allocations into one
*/
TVM_DLL Pass MergeDynamicSharedMemoryAllocations();

} // namespace transform
} // namespace tir
} // namespace tvm
Expand Down
5 changes: 4 additions & 1 deletion python/tvm/driver/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,10 @@ def _build_for_device(input_mod, target, target_host):
mod_mixed = input_mod
mod_mixed = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod_mixed)

opt_mixed = [tvm.tir.transform.VerifyMemory()]
opt_mixed = [
tvm.tir.transform.VerifyMemory(),
tvm.tir.transform.MergeDynamicSharedMemoryAllocations(),
]
if len(mod_mixed.functions) == 1:
opt_mixed += [tvm.tir.transform.Apply(lambda f: f.with_attr("tir.is_entry_func", True))]

Expand Down
12 changes: 12 additions & 0 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,3 +655,15 @@ def FlattenBuffer():
The result pass
"""
return _ffi_api.FlattenBuffer() # type: ignore


def MergeDynamicSharedMemoryAllocations():
"""This pass merges multiple TIR-level dynamic shared memory allocations
into one allocation.

Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.MergeDynamicSharedMemoryAllocations() # type: ignore
1 change: 1 addition & 0 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,7 @@ std::pair<IRModule, IRModule> SplitDevHostFuncs(IRModule mod_mixed, const Target
Array<tvm::transform::Pass> mixed_pass_list = {BindTarget(target),
tir::transform::VerifyMemory()};

mixed_pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations());
if (pass_ctx->GetConfig<Bool>("tir.detect_global_barrier", Bool(false)).value()) {
mixed_pass_list.push_back(tir::transform::ThreadSync("global"));
}
Expand Down
149 changes: 149 additions & 0 deletions src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file merge_dynamic_shared_memory_allocations.cc
* \brief Each GPU kernel is allowed to have only one dynamic shared memory allocation.
* This pass merges multiple TIR-level dynamic shared memory allocations into one allocation.
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <unordered_map>
#include <unordered_set>

#include "../../runtime/thread_storage_scope.h"
#include "ir_utils.h"

namespace tvm {
namespace tir {

bool IsDynamicSharedMemory(Var buffer_var) {
auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
return storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn";
}

class AllocateCollector : public StmtExprVisitor {
public:
void VisitStmt_(const AllocateNode* op) final {
if (IsDynamicSharedMemory(op->buffer_var)) {
dyn_shmem_allocs_.insert(op);
}
StmtExprVisitor::VisitStmt_(op);
}

std::unordered_set<const AllocateNode*> dyn_shmem_allocs_;
};

class DynamicSharedMemoryRewriter : public StmtExprMutator {
public:
explicit DynamicSharedMemoryRewriter(
const std::unordered_set<const AllocateNode*>& dyn_shmem_allocs)
: dyn_shmem_allocs_{dyn_shmem_allocs} {}

Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::thread_extent && !allocated) {
// Allocate one dynamic shared memory allocation at the beginning of thread scope
int align = 1;
for (auto& alloc : dyn_shmem_allocs_) {
masahi marked this conversation as resolved.
Show resolved Hide resolved
ICHECK_EQ(alloc->dtype.lanes(), 1) << "vector dtype allocation not supported.";
align = std::max(align, alloc->dtype.bytes());
}
for (auto& alloc : dyn_shmem_allocs_) {
masahi marked this conversation as resolved.
Show resolved Hide resolved
ICHECK_EQ(alloc->extents.size(), 1);
buffer_byte_offsets_[alloc->buffer_var.get()] = merged_alloc_size_;
merged_alloc_size_ += alloc->extents[0] * align;
}

allocated = true;
auto new_body = Allocate(merged_buf_var_, DataType::UInt(8), {merged_alloc_size_},
const_true(), StmtExprMutator::VisitStmt(op->body));
return AttrStmt(op->node, op->attr_key, op->value, new_body, op->span);
}
return StmtMutator::VisitStmt_(op);
}

Stmt VisitStmt_(const AllocateNode* op) final {
if (IsDynamicSharedMemory(op->buffer_var)) {
return StmtExprMutator::VisitStmt(op->body);
}
return StmtExprMutator::VisitStmt_(op);
}

PrimExpr VisitExpr_(const LoadNode* op) final {
if (IsDynamicSharedMemory(op->buffer_var)) {
auto offset = GetBufferOffset(op->buffer_var, op->dtype);
auto index = StmtExprMutator::VisitExpr(op->index);
return Load(op->dtype, merged_buf_var_, offset + index, op->predicate, op->span);
}
return StmtExprMutator::VisitExpr_(op);
}

Stmt VisitStmt_(const StoreNode* op) final {
if (IsDynamicSharedMemory(op->buffer_var)) {
auto offset = GetBufferOffset(op->buffer_var, op->value->dtype);
auto index = StmtExprMutator::VisitExpr(op->index);
auto value = StmtExprMutator::VisitExpr(op->value);
return Store(merged_buf_var_, value, offset + index, op->predicate, op->span);
}
return StmtExprMutator::VisitStmt_(op);
}

private:
PrimExpr GetBufferOffset(Var buffer_var, DataType dtype) {
auto it = buffer_byte_offsets_.find(buffer_var.get());
ICHECK(it != buffer_byte_offsets_.end());
return indexdiv(it->second, dtype.bytes());
}

Var merged_buf_var_{"buf_dyn_shmem", PointerType(PrimType(DataType::UInt(8)), "shared.dyn")};
std::unordered_set<const AllocateNode*> dyn_shmem_allocs_;
PrimExpr merged_alloc_size_{0};
std::unordered_map<const VarNode*, PrimExpr> buffer_byte_offsets_;
bool allocated{false};
};

Stmt MergeDynamicSharedMemoryAllocations(Stmt stmt) {
AllocateCollector collector;
collector(stmt);
if (collector.dyn_shmem_allocs_.size() > 1) {
return DynamicSharedMemoryRewriter(collector.dyn_shmem_allocs_)(std::move(stmt));
}
return stmt;
}

namespace transform {

Pass MergeDynamicSharedMemoryAllocations() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
n->body = MergeDynamicSharedMemoryAllocations(std::move(n->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.MergeDynamicSharedMemoryAllocations", {});
}

TVM_REGISTER_GLOBAL("tir.transform.MergeDynamicSharedMemoryAllocations")
.set_body_typed(MergeDynamicSharedMemoryAllocations);

} // namespace transform
} // namespace tir
} // namespace tvm
143 changes: 143 additions & 0 deletions tests/python/unittest/test_tir_ir_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,147 @@ def check_target(target):
check_target(target)


@tvm.testing.requires_gpu
masahi marked this conversation as resolved.
Show resolved Hide resolved
def test_matmul_dyn_shared():
n = 1024
A = te.placeholder((n, n), name="A", dtype="float16")
B = te.placeholder((n, n), name="B", dtype="float16")

def syncthread():
return tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"]))

def test_matmul_ir(A, B, C):
ib = tvm.tir.ir_builder.create()
block = 16

tx = te.thread_axis("threadIdx.x")
ty = te.thread_axis("threadIdx.y")
bx = te.thread_axis("blockIdx.x")
by = te.thread_axis("blockIdx.y")
ib.scope_attr(tx, "thread_extent", block)
ib.scope_attr(ty, "thread_extent", block)
ib.scope_attr(bx, "thread_extent", n / block)
masahi marked this conversation as resolved.
Show resolved Hide resolved
ib.scope_attr(by, "thread_extent", n / block)

A_sh = ib.allocate(A.dtype, (block, block), scope="shared.dyn", name="A_sh") # fp16
B_sh = ib.allocate(B.dtype, (block, block), scope="shared.dyn", name="B_sh") # fp16
# Create a dynamic shared memory for the accumulation.
# This is for testing merging dynamic shared memory alloctions with different data type.
# In practice, there is no need to allocate a shared memory for C.
C_sh = ib.allocate(C.dtype, (block, block), scope="shared.dyn", name="C_sh") # fp32

A_ptr = ib.buffer_ptr(A)
B_ptr = ib.buffer_ptr(B)
C_ptr = ib.buffer_ptr(C)

C_sh[ty, tx] = 0.0

with ib.for_range(0, n // block, name="i") as i:
A_sh[ty, tx] = A_ptr[by * block + ty, i * block + tx]
B_sh[ty, tx] = B_ptr[i * block + ty, bx * block + tx]
ib.emit(syncthread())

with ib.for_range(0, block, name="k") as k:
C_sh[ty, tx] += cast(A_sh[ty, k] * B_sh[k, tx], "float32")

ib.emit(syncthread())

C_ptr[by * block + ty, bx * block + tx] = C_sh[ty, tx]

return ib.get()

C = te.extern(
A.shape,
[A, B],
lambda ins, outs: test_matmul_ir(ins[0], ins[1], outs[0]),
name="reduce",
dtype="float32",
)
s = te.create_schedule(C.op)

def check_target(target):
if not tvm.testing.device_enabled(target):
return

fmatmul = tvm.build(s, [A, B, C], target)
dev = tvm.device(target, 0)

size = (n, n)
a_np = np.random.uniform(size=size).astype(A.dtype)
b_np = np.random.uniform(size=size).astype(B.dtype)
a = tvm.nd.array(a_np, dev)
b = tvm.nd.array(b_np, dev)
c = tvm.nd.array(np.zeros(size, dtype=C.dtype), dev)
fmatmul(a, b, c)
np_ref = np.dot(a_np.astype("float32"), b_np.astype("float32"))
tvm.testing.assert_allclose(c.numpy(), np_ref, 1e-4, 1e-4)

for target in ["cuda", "nvptx"]:
check_target(target)


@tvm.testing.requires_gpu
def test_dyn_shared_vectorized_store():
"""Test vectorized store into dynamic shared memory"""
n = te.size_var("n")
A = te.placeholder((n,), name="A", dtype="float16")
B = te.placeholder((n,), name="B", dtype="float32")

def test_device_ir(A, B, C):
n = A.shape[0]
ib = tvm.tir.ir_builder.create()

values_per_thread = 4
tx = te.thread_axis("threadIdx.x")
ib.scope_attr(tx, "thread_extent", tvm.tir.indexdiv(n, values_per_thread))

A_sh = ib.allocate(A.dtype, (n,), scope="shared.dyn") # fp16
B_sh = ib.allocate(B.dtype, (n,), scope="shared.dyn") # fp32

Aptr = ib.buffer_ptr(A)
Bptr = ib.buffer_ptr(B)
Cptr = ib.buffer_ptr(C)

with ib.for_range(0, values_per_thread, kind="vectorize") as i:
A_sh[tx * values_per_thread + i] = Aptr[tx * values_per_thread + i]
B_sh[tx * values_per_thread + i] = Bptr[tx * values_per_thread + i]

with ib.for_range(0, values_per_thread) as i:
Cptr[tx * values_per_thread + i] = (
cast(A_sh[tx * values_per_thread + i], "float32") + B_sh[tx * values_per_thread + i]
)

return ib.get()

C = te.extern(
(n,),
[A, B],
lambda ins, outs: test_device_ir(ins[0], ins[1], outs[0]),
name="vadd",
dtype="float32",
)
s = te.create_schedule(C.op)

def check_target(target):
if not tvm.testing.device_enabled(target):
return

fadd = tvm.build(s, [A, B, C], target)
dev = tvm.device(target, 0)

for n in [512, 1024]:
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev)
b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), dev)
c = tvm.nd.array(np.zeros((n,), dtype=C.dtype), dev)
fadd(a, b, c)
tvm.testing.assert_allclose(
c.numpy(), a.numpy().astype("float32") + b.numpy(), 1e-4, 1e-4
)

for target in ["cuda", "nvptx"]:
check_target(target)


if __name__ == "__main__":
test_prefetch()
test_if()
Expand All @@ -565,3 +706,5 @@ def check_target(target):
test_while_mandel()
test_while_binary_search()
test_dyn_shared()
test_matmul_dyn_shared()
test_dyn_shared_vectorized_store()