[CUDA] Support multiple TIR-level dynamic shared memory allocations (a…
masahi authored and ylc committed Sep 29, 2021
1 parent 0bfcf97 commit 7c677ad
Showing 6 changed files with 430 additions and 1 deletion.
5 changes: 5 additions & 0 deletions include/tvm/tir/transform.h
@@ -437,6 +437,11 @@ TVM_DLL Pass FlattenBuffer();
 */
TVM_DLL Pass FlattenBuffer();

/*!
 * \brief A pass to merge multiple TIR-level dynamic shared memory allocations into one
 */
TVM_DLL Pass MergeDynamicSharedMemoryAllocations();

} // namespace transform
} // namespace tir
} // namespace tvm
5 changes: 4 additions & 1 deletion python/tvm/driver/build_module.py
@@ -161,7 +161,10 @@ def _build_for_device(input_mod, target, target_host):
    mod_mixed = input_mod
    mod_mixed = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod_mixed)

    opt_mixed = [tvm.tir.transform.VerifyMemory()]
    opt_mixed = [
        tvm.tir.transform.VerifyMemory(),
        tvm.tir.transform.MergeDynamicSharedMemoryAllocations(),
    ]
    if len(mod_mixed.functions) == 1:
        opt_mixed += [tvm.tir.transform.Apply(lambda f: f.with_attr("tir.is_entry_func", True))]

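Outside of build_module.py, the same two lowering steps can be composed by hand. A minimal sketch, assuming `mod` is an IRModule of TIR PrimFuncs with a GPU target already attached (the variable names are illustrative, not from the commit):

import tvm

# Hypothetical standalone composition of the pass list shown above.
seq = tvm.transform.Sequential(
    [
        tvm.tir.transform.VerifyMemory(),
        tvm.tir.transform.MergeDynamicSharedMemoryAllocations(),
    ]
)
mod = seq(mod)  # mod: an assumed IRModule with "shared.dyn" allocations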
12 changes: 12 additions & 0 deletions python/tvm/tir/transform/transform.py
@@ -666,3 +666,15 @@ def FlattenBuffer():
        The result pass
    """
    return _ffi_api.FlattenBuffer()  # type: ignore


def MergeDynamicSharedMemoryAllocations():
    """This pass merges multiple TIR-level dynamic shared memory allocations
    into one allocation.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    return _ffi_api.MergeDynamicSharedMemoryAllocations()  # type: ignore
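As a usage sketch, the returned pass is applied directly to an IRModule; `mod` below is an assumed module containing GPU PrimFuncs with several "shared.dyn" allocations:

import tvm

# Functions with at most one dynamic shared memory allocation are left unchanged.
mod = tvm.tir.transform.MergeDynamicSharedMemoryAllocations()(mod)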
1 change: 1 addition & 0 deletions src/driver/driver_api.cc
@@ -378,6 +378,7 @@ std::pair<IRModule, IRModule> SplitDevHostFuncs(IRModule mod_mixed, const Target
  Array<tvm::transform::Pass> mixed_pass_list = {BindTarget(target),
                                                 tir::transform::VerifyMemory()};

  mixed_pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations());
  if (pass_ctx->GetConfig<Bool>("tir.detect_global_barrier", Bool(false)).value()) {
    mixed_pass_list.push_back(tir::transform::ThreadSync("global"));
  }
149 changes: 149 additions & 0 deletions src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
@@ -0,0 +1,149 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file merge_dynamic_shared_memory_allocations.cc
 * \brief Each GPU kernel is allowed to have only one dynamic shared memory allocation.
 * This pass merges multiple TIR-level dynamic shared memory allocations into one allocation.
 */
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <unordered_map>
#include <unordered_set>

#include "../../runtime/thread_storage_scope.h"
#include "ir_utils.h"

namespace tvm {
namespace tir {

// Returns true if the buffer is a TIR-level dynamic shared memory allocation,
// i.e. its storage scope is "shared.dyn".
bool IsDynamicSharedMemory(Var buffer_var) {
  auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
  return storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn";
}

// Collects all dynamic shared memory allocations in a statement.
class AllocateCollector : public StmtExprVisitor {
 public:
  void VisitStmt_(const AllocateNode* op) final {
    if (IsDynamicSharedMemory(op->buffer_var)) {
      dyn_shmem_allocs_.insert(op);
    }
    StmtExprVisitor::VisitStmt_(op);
  }

  std::unordered_set<const AllocateNode*> dyn_shmem_allocs_;
};

// Replaces each dynamic shared memory allocation with an offset into a single
// merged uint8 buffer, and rewrites all loads and stores accordingly.
class DynamicSharedMemoryRewriter : public StmtExprMutator {
 public:
  explicit DynamicSharedMemoryRewriter(
      const std::unordered_set<const AllocateNode*>& dyn_shmem_allocs)
      : dyn_shmem_allocs_{dyn_shmem_allocs} {}

  Stmt VisitStmt_(const AttrStmtNode* op) final {
    if (op->attr_key == attr::thread_extent && !allocated) {
      // Allocate one dynamic shared memory allocation at the beginning of thread scope
      int align = 1;
      for (const auto& alloc : dyn_shmem_allocs_) {
        ICHECK_EQ(alloc->dtype.lanes(), 1) << "vector dtype allocation not supported.";
        align = std::max(align, alloc->dtype.bytes());
      }
      for (const auto& alloc : dyn_shmem_allocs_) {
        ICHECK_EQ(alloc->extents.size(), 1);
        buffer_byte_offsets_[alloc->buffer_var.get()] = merged_alloc_size_;
        merged_alloc_size_ += alloc->extents[0] * align;
      }

      allocated = true;
      auto new_body = Allocate(merged_buf_var_, DataType::UInt(8), {merged_alloc_size_},
                               const_true(), StmtExprMutator::VisitStmt(op->body));
      return AttrStmt(op->node, op->attr_key, op->value, new_body, op->span);
    }
    return StmtMutator::VisitStmt_(op);
  }

  Stmt VisitStmt_(const AllocateNode* op) final {
    if (IsDynamicSharedMemory(op->buffer_var)) {
      // Drop the original allocation; references to it are rewritten below.
      return StmtExprMutator::VisitStmt(op->body);
    }
    return StmtExprMutator::VisitStmt_(op);
  }

  PrimExpr VisitExpr_(const LoadNode* op) final {
    if (IsDynamicSharedMemory(op->buffer_var)) {
      auto offset = GetBufferOffset(op->buffer_var, op->dtype);
      auto index = StmtExprMutator::VisitExpr(op->index);
      return Load(op->dtype, merged_buf_var_, offset + index, op->predicate, op->span);
    }
    return StmtExprMutator::VisitExpr_(op);
  }

  Stmt VisitStmt_(const StoreNode* op) final {
    if (IsDynamicSharedMemory(op->buffer_var)) {
      auto offset = GetBufferOffset(op->buffer_var, op->value->dtype);
      auto index = StmtExprMutator::VisitExpr(op->index);
      auto value = StmtExprMutator::VisitExpr(op->value);
      return Store(merged_buf_var_, value, offset + index, op->predicate, op->span);
    }
    return StmtExprMutator::VisitStmt_(op);
  }

 private:
  // Converts the byte offset of a buffer within the merged allocation into
  // an element offset for the given dtype.
  PrimExpr GetBufferOffset(Var buffer_var, DataType dtype) {
    auto it = buffer_byte_offsets_.find(buffer_var.get());
    ICHECK(it != buffer_byte_offsets_.end());
    return indexdiv(it->second, dtype.bytes());
  }

  Var merged_buf_var_{"buf_dyn_shmem", PointerType(PrimType(DataType::UInt(8)), "shared.dyn")};
  std::unordered_set<const AllocateNode*> dyn_shmem_allocs_;
  PrimExpr merged_alloc_size_{0};
  std::unordered_map<const VarNode*, PrimExpr> buffer_byte_offsets_;
  bool allocated{false};
};

Stmt MergeDynamicSharedMemoryAllocations(Stmt stmt) {
  AllocateCollector collector;
  collector(stmt);
  if (collector.dyn_shmem_allocs_.size() > 1) {
    return DynamicSharedMemoryRewriter(collector.dyn_shmem_allocs_)(std::move(stmt));
  }
  return stmt;
}

namespace transform {

Pass MergeDynamicSharedMemoryAllocations() {
  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
    auto* n = f.CopyOnWrite();
    n->body = MergeDynamicSharedMemoryAllocations(std::move(n->body));
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tir.MergeDynamicSharedMemoryAllocations", {});
}

TVM_REGISTER_GLOBAL("tir.transform.MergeDynamicSharedMemoryAllocations")
    .set_body_typed(MergeDynamicSharedMemoryAllocations);

} // namespace transform
} // namespace tir
} // namespace tvm
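To make the merge arithmetic concrete, the following minimal Python sketch mirrors the offset computation that DynamicSharedMemoryRewriter performs; the two allocations and their sizes are illustrative assumptions, not values from the commit.

# Hypothetical dynamic shared allocations: (name, dtype size in bytes, extent).
allocs = [("A", 4, 128), ("B", 2, 64)]  # e.g. float32[128] and float16[64]

align = max(nbytes for _, nbytes, _ in allocs)  # max element size: 4 bytes

byte_offsets, merged_size = {}, 0
for name, _, extent in allocs:
    byte_offsets[name] = merged_size  # start of this buffer in the merged one
    merged_size += extent * align     # every element slot padded to `align` bytes

print(merged_size, byte_offsets)  # 768 {'A': 0, 'B': 512}

The merged uint8 buffer is 128*4 + 64*4 = 768 bytes, and a float16 access B[i] is rewritten to use element offset 512 // 2 + i in buf_dyn_shmem, matching the indexdiv conversion in GetBufferOffset.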