Skip to content

Commit

Permalink
[PASS] Refactor a couple of TIR passes - BindTarget, AnnotateEntryFun…
Browse files Browse the repository at this point in the history
…c, Filter, LowerInitBlock (#11628)

This PR fixes a few inconsistent pass registration and add testcases for them. 
- `LowerInitBlock` had mismatch between its pass name and ffi key.
- `BindTarget`, `AnnotateEntryFunc`, `Filter` were not following the name convention of tir passes and they were not registered in FFI registry.
  • Loading branch information
sunggg committed Jun 9, 2022
1 parent ebc9b6d commit 87502dd
Show file tree
Hide file tree
Showing 7 changed files with 258 additions and 57 deletions.
19 changes: 19 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#define TVM_TIR_TRANSFORM_H_

#include <tvm/ir/transform.h>
#include <tvm/target/target.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>

Expand Down Expand Up @@ -625,6 +626,24 @@ TVM_DLL Pass ExtractPrimFuncConstants();
*/
TVM_DLL Pass RenormalizeSplitPattern();

/*!
* \brief Annotate a PrimFunc with a given target.
* \return The pass.
*/
TVM_DLL Pass BindTarget(Target target);

/*!
* \brief Set a PrimFunc as the entry point if it is only function in IRModule.
* \return The pass.
*/
TVM_DLL Pass AnnotateEntryFunc();

/*!
* \brief Filter PrimFuncs with a given condition.
* \return The pass.
*/
TVM_DLL Pass Filter(runtime::TypedPackedFunc<bool(PrimFunc)> fcond);

} // namespace transform
} // namespace tir
} // namespace tvm
Expand Down
61 changes: 40 additions & 21 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
# under the License.
"""Wrapping existing transformations."""
# pylint: disable=invalid-name
from typing import Optional
from typing import Optional, Callable

from . import _ffi_api
from . import function_pass as _fpass

Expand All @@ -43,26 +44,6 @@ def _transform(func, mod, ctx):
return _fpass.prim_func_pass(_transform, opt_level=0, name="Apply") # type: ignore


def Filter(fcond):
"""Filter functions by the calling convention attribute.
Parameters
----------
fcond : tvm.tir.PrimFunc -> bool
The condition of the filtering.
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
# pylint: disable=unused-argument
def _transform(func, mod, ctx):
return func if fcond(func) else None

return _fpass.prim_func_pass(_transform, opt_level=0, name="Filter") # type: ignore


def InjectPrefetch():
"""Inject prefetch instructions into stmt.
Expand Down Expand Up @@ -806,3 +787,41 @@ def RenormalizeSplitPattern():
The result pass
"""
return _ffi_api.RenormalizeSplitPattern() # type: ignore


def BindTarget(target):
"""Annotate a PrimFunc with a given target.
Parameters
-------
target : tvm.target.Target
target
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.BindTarget(target) # type: ignore


def AnnotateEntryFunc():
"""Set a PrimFunc as the entry point if it is only function in IRModule.
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.AnnotateEntryFunc() # type: ignore


def Filter(fcond: Callable):
"""Filter out PrimFuncs that does not satisfy the given condition.
`fcond` should be a function that takes a primfunc and returns boolean.
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.Filter(fcond) # type: ignore
45 changes: 11 additions & 34 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -164,32 +164,6 @@ TVM_REGISTER_GLOBAL("driver.get_binds")
return out_arr;
});

transform::Pass BindTarget(Target target) {
auto fpass = [target](tir::PrimFunc f, IRModule m, transform::PassContext ctx) {
return WithAttr(std::move(f), tvm::attr::kTarget, target);
};
return tir::transform::CreatePrimFuncPass(fpass, 0, "BindTarget", {});
}

static transform::Pass AnnotateEntryFunc(bool b) {
auto fpass = [](tir::PrimFunc f, IRModule m, transform::PassContext ctx) {
return WithAttr(std::move(f), tir::attr::kIsEntryFunc, Bool(true));
};
return tir::transform::CreatePrimFuncPass(fpass, 0, "AnnotateEntryFunc", {});
}

template <typename FCond>
transform::Pass Filter(FCond fcond) {
auto fpass = [fcond](tir::PrimFunc f, IRModule m, transform::PassContext ctx) {
if (fcond(f)) {
return f;
} else {
return tir::PrimFunc(nullptr);
}
};
return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
}

Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
transform::PassContext pass_ctx = transform::PassContext::Current();

Expand Down Expand Up @@ -564,12 +538,12 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)

Array<Pass> mixed_pass_list;

mixed_pass_list.push_back(BindTarget(target));
mixed_pass_list.push_back(tir::transform::BindTarget(target));

mixed_pass_list.push_back(tir::transform::VerifyMemory());

if (ShouldAnnotateEntryFunc(mixed_mod)) {
mixed_pass_list.push_back(AnnotateEntryFunc(true));
mixed_pass_list.push_back(tir::transform::AnnotateEntryFunc());
}

bool detect_global_barrier =
Expand Down Expand Up @@ -606,14 +580,16 @@ TVM_REGISTER_GLOBAL("driver.mixed_mod_passes")

transform::Sequential HostModulePassManager(IRModule mixed_mod, Target target_host) {
Array<tvm::transform::Pass> host_pass_list;
host_pass_list.push_back(Filter([](const tir::PrimFunc& f) {

runtime::TypedPackedFunc<bool(tir::PrimFunc)> fcond = [](const tir::PrimFunc& f) {
return f->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) !=
CallingConv::kDeviceKernelLaunch;
}));
};
host_pass_list.push_back(tir::transform::Filter(fcond));

ICHECK(mixed_mod.defined()) << "This module must be defined";

host_pass_list.push_back(BindTarget(target_host));
host_pass_list.push_back(tir::transform::BindTarget(target_host));

host_pass_list.push_back(tir::transform::LowerTVMBuiltin());
host_pass_list.push_back(tir::transform::LowerCustomDatatypes());
Expand All @@ -631,12 +607,13 @@ TVM_REGISTER_GLOBAL("driver.host_mod_passes")

transform::Sequential DeviceModulePassManager(IRModule mixed_mod, Target target) {
Array<Pass> device_pass_list;
device_pass_list.push_back(Filter([](const tir::PrimFunc& f) {
runtime::TypedPackedFunc<bool(tir::PrimFunc)> fcond = [](const tir::PrimFunc& f) {
return f->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) ==
CallingConv::kDeviceKernelLaunch;
}));
};
device_pass_list.push_back(tir::transform::Filter(fcond));

device_pass_list.push_back(BindTarget(target));
device_pass_list.push_back(tir::transform::BindTarget(target));

device_pass_list.push_back(tir::transform::LowerWarpMemory());
device_pass_list.push_back(tir::transform::Simplify());
Expand Down
2 changes: 1 addition & 1 deletion src/tir/transforms/lower_init_block.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ Pass LowerInitBlock() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
return LowerInitBlock(std::move(f));
};
return CreatePrimFuncPass(pass_func, 0, "tir.LowerReduction", {});
return CreatePrimFuncPass(pass_func, 0, "tir.LowerInitBlock", {});
}

TVM_REGISTER_GLOBAL("tir.transform.LowerInitBlock").set_body_typed(LowerInitBlock);
Expand Down
63 changes: 63 additions & 0 deletions src/tir/transforms/primfunc_utils.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file primfunc_utils.cc
* \brief Passes that serve as helper functions.
*/

#include <tvm/driver/driver_api.h>
#include <tvm/tir/transform.h>

namespace tvm {
namespace tir {
namespace transform {
transform::Pass BindTarget(Target target) {
auto fpass = [target](tir::PrimFunc f, IRModule m, transform::PassContext ctx) {
return WithAttr(std::move(f), tvm::attr::kTarget, target);
};
return tir::transform::CreatePrimFuncPass(fpass, 0, "tir.BindTarget", {});
}

transform::Pass AnnotateEntryFunc() {
auto fpass = [](tir::PrimFunc f, IRModule m, transform::PassContext ctx) {
ICHECK(m->functions.size() == 1);
return WithAttr(std::move(f), tir::attr::kIsEntryFunc, Bool(true));
};
return tir::transform::CreatePrimFuncPass(fpass, 0, "tir.AnnotateEntryFunc", {});
}

transform::Pass Filter(runtime::TypedPackedFunc<bool(PrimFunc)> fcond) {
auto fpass = [fcond](tir::PrimFunc f, IRModule m, transform::PassContext ctx) {
if (fcond(f)) {
return f;
} else {
return tir::PrimFunc(nullptr);
}
};
return tir::transform::CreatePrimFuncPass(fpass, 0, "tir.Filter", {});
}

TVM_REGISTER_GLOBAL("tir.transform.BindTarget").set_body_typed(BindTarget);
TVM_REGISTER_GLOBAL("tir.transform.AnnotateEntryFunc").set_body_typed(AnnotateEntryFunc);
TVM_REGISTER_GLOBAL("tir.transform.Filter").set_body_typed(Filter);

} // namespace transform
} // namespace tir
} // namespace tvm
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class PoolAllocationToOffsetConverter : public StmtExprMutator {
PoolInfo pool_info = pool_allocation->pool_info;
int byte_pool_offset = pool_allocation->byte_offset->value;
int required_pool_size_for_allocation =
byte_pool_offset + CalculateExtentsSize(allocate_node.operator->());
byte_pool_offset + static_cast<int>(CalculateExtentsSize(allocate_node.operator->()));
if (all_pools_sizes_.find(pool_info) == all_pools_sizes_.end()) {
all_pools_sizes_[pool_info] = required_pool_size_for_allocation;
} else {
Expand Down
Loading

0 comments on commit 87502dd

Please sign in to comment.