[MetaSchedule] Enable Task Filtering
junrushao committed May 31, 2022
1 parent ac5d781 commit 5f96e8a
Showing 8 changed files with 140 additions and 98 deletions.
8 changes: 6 additions & 2 deletions python/tvm/meta_schedule/relay_integration.py
@@ -15,14 +15,15 @@
# specific language governing permissions and limitations
# under the License.
"""MetaSchedule-Relay integration"""
from typing import Any, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional

import numpy as np # type: ignore
from tvm import nd
from tvm._ffi import get_global_func
from tvm.ir import IRModule, transform
from tvm.runtime import NDArray
from tvm.target import Target
from tvm.te import Tensor

from .extracted_task import ExtractedTask
from .utils import autotvm_silencer
@@ -36,6 +37,7 @@ def extract_task_from_relay(
opt_level: int = 3,
pass_config: Optional[Dict[str, Any]] = None,
disabled_pass: Optional[List[str]] = None,
filter_func: Callable[[List[Tensor]], bool] = None,
) -> List[ExtractedTask]:
"""Extract tuning tasks from a relay program.
@@ -53,6 +55,8 @@
The pass config of the compiler
disabled_pass : Optional[List[str]]
The list of disabled passes of the compiler
filter_func : Callable[[List[tvm.te.Tensor]], bool]
The filter function to filter out the extracted tasks
Returns
-------
@@ -90,4 +94,4 @@ def extract_task_from_relay(
config=pass_config,
disabled_pass=disabled_pass,
):
return list(extract_task_func(mod, target, relay_params))
return list(extract_task_func(mod, target, relay_params, filter_func))
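
For reference, a minimal usage sketch of the new `filter_func` parameter follows. It assumes a Relay module `mod` and its `params` are already available and an "llvm" target; `keep_all_tasks` is a hypothetical filter name used only for illustration and simply accepts every task.

```python
from typing import List

from tvm import meta_schedule as ms
from tvm import te


def keep_all_tasks(args: List[te.Tensor]) -> bool:
    # Returning True keeps the extracted task; returning False drops it.
    return True


extracted_tasks = ms.extract_task_from_relay(
    mod,
    target="llvm",
    params=params,
    filter_func=keep_all_tasks,
)
```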
2 changes: 1 addition & 1 deletion python/tvm/te/__init__.py
@@ -39,7 +39,7 @@
from .tag import tag_scope
from .operation import placeholder, compute, scan, extern, var, size_var, const
from .operation import thread_axis, reduce_axis
from .operation import create_prim_func, create_prim_func_from_outputs
from .operation import create_prim_func

from .tensor import PlaceholderOp, ComputeOp, TensorComputeOp, ScanOp, ExternOp, HybridOp
from .autodiff import gradient
29 changes: 5 additions & 24 deletions python/tvm/te/operation.py
@@ -15,17 +15,18 @@
# specific language governing permissions and limitations
# under the License.
""" Operation class for computation declaration."""
import inspect

# pylint: disable=invalid-name
from numbers import Integral as _Integral
from typing import List, Union
import inspect
from typing import List

import tvm._ffi
import tvm.tir
import tvm.tir._ffi_api
from tvm._ffi.base import string_types
from tvm.ir import Array
from tvm.runtime import convert
import tvm.tir
import tvm.tir._ffi_api

from . import _ffi_api
from . import tag as _tag
@@ -528,23 +529,3 @@ def tir_matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
if not isinstance(ops, (list, tuple, Array)):
ops = [ops]
return _ffi_api.CreatePrimFunc(ops)


def create_prim_func_from_outputs(
outputs: Union[_tensor.Tensor, List[_tensor.Tensor]],
) -> tvm.tir.PrimFunc:
"""Create a TensorIR PrimFunc from output tensor(s) in TE
Parameters
----------
outputs : Union[Tensor, List[Tensor]]
The source expression.
Returns
-------
func : tir.PrimFunc
The created function.
"""
if not isinstance(outputs, (list, tuple, Array)):
outputs = [outputs]
return _ffi_api.CreatePrimFuncFromOutputs(outputs)
80 changes: 55 additions & 25 deletions src/relay/backend/task_extraction.cc
@@ -31,25 +31,58 @@ namespace tvm {
namespace relay {
namespace backend {

namespace metaschedule {

using meta_schedule::ExtractedTask;
bool DefaultTaskFilter(const Array<te::Tensor>& args) {
using namespace ::tvm::te;
std::vector<Tensor> stack;
std::unordered_set<const TensorNode*> visited;
for (const Tensor& v : args) {
for (const PrimExpr& e : v->shape) {
// Dynamic shape is not supported for now
if (!e->IsInstance<IntImmNode>()) {
return false;
}
}
if (!visited.count(v.get())) {
visited.insert(v.get());
stack.push_back(v);
}
}
while (!stack.empty()) {
Tensor tensor = stack.back();
stack.pop_back();
if (tensor->op->IsInstance<PlaceholderOpNode>()) {
// do nothing
} else if (tensor->op->IsInstance<ComputeOpNode>()) {
Array<Tensor> inputs = tensor->op->InputTensors();
for (const Tensor& v : inputs) {
if (!visited.count(v.get())) {
visited.insert(v.get());
stack.push_back(v);
}
}
} else {
return false;
}
}
return true;
}

Array<ExtractedTask> ExtractTask(IRModule mod, Target target,
Map<String, runtime::NDArray> params) {
Array<meta_schedule::ExtractedTask> ExtractTask(
IRModule mod, Target target, Map<String, runtime::NDArray> params,
runtime::TypedPackedFunc<bool(const Array<te::Tensor>&)> filter_func) {
using meta_schedule::ExtractedTask;
if (filter_func == nullptr) {
filter_func = DefaultTaskFilter;
}
backend::BindParamsInModule(mod, params);

// is_vm=true for backward compatibility
Array<Pass> pass_seqs = relay::backend::GetPassPrefix(/*is_homogenous=*/true, /*is_vm=*/true);
pass_seqs.push_back(transform::FuseOps());

transform::Sequential seq(pass_seqs);
auto opt_mod = seq(std::move(mod));
mod = transform::Sequential(pass_seqs)(std::move(mod));

std::vector<ExtractedTask> tasks;
std::unordered_map<tec::CCacheKey, ExtractedTask> cache;

PostOrderVisit(opt_mod->Lookup("main"), [target, &tasks, &cache](const Expr& exp) {
PostOrderVisit(mod->Lookup("main"), [&target, &tasks, &cache, &filter_func](const Expr& exp) {
if (exp->IsInstance<FunctionNode>()) {
Function relay_func = Downcast<Function>(exp);
if (!relay_func->HasNonzeroAttr(attr::kPrimitive)) {
@@ -61,17 +94,19 @@ Array<ExtractedTask> ExtractTask(IRModule mod, Target target,
it->second->weight += 1;
return;
}
Array<te::Tensor> inputs_outputs;
Array<te::Tensor> inputs_outputs{nullptr};
std::string fused_name;
std::tie(inputs_outputs, fused_name) =
tec::LowerTECompute(relay_func, target, /*return_inputs=*/true);
auto prim_func = tir::CreatePrimFunc(inputs_outputs);
GlobalVar prim_fn_var(fused_name);
IRModule relay_mod({{prim_fn_var, relay_func}});
IRModule tir_mod({{prim_fn_var, prim_func}});
ExtractedTask extracted_task(fused_name, relay_mod, target, {tir_mod}, 1);
tasks.push_back(extracted_task);
cache.emplace(cache_key, extracted_task);
if (filter_func(inputs_outputs)) {
tir::PrimFunc prim_func = tir::CreatePrimFunc(inputs_outputs);
GlobalVar prim_fn_var(fused_name);
IRModule relay_mod({{prim_fn_var, relay_func}});
IRModule tir_mod({{prim_fn_var, prim_func}});
ExtractedTask extracted_task(fused_name, relay_mod, target, {tir_mod}, 1);
tasks.push_back(extracted_task);
cache.emplace(cache_key, extracted_task);
}
}
});
// Tasks are extracted via post order visit, return the reversed list.
@@ -83,12 +118,7 @@ Array<ExtractedTask> ExtractTask(IRModule mod, Target target,
return tasks;
}

} // namespace metaschedule

TVM_REGISTER_GLOBAL("relay.backend.MetaScheduleExtractTask")
.set_body_typed([](IRModule mod, Target target, Map<String, runtime::NDArray> params) {
return metaschedule::ExtractTask(mod, target, params);
});
TVM_REGISTER_GLOBAL("relay.backend.MetaScheduleExtractTask").set_body_typed(ExtractTask);

} // namespace backend
} // namespace relay
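
For readers who prefer Python, below is a hedged sketch of the same rule the new C++ `DefaultTaskFilter` applies: a task is kept only if every argument tensor has a static shape and every producing op reachable from the arguments is either a placeholder or an ordinary compute op. The function name `default_task_filter_py` is made up for illustration, and the `t.handle.value` identity trick follows the test added in this commit.

```python
from typing import List

from tvm import te, tir


def default_task_filter_py(args: List[te.Tensor]) -> bool:
    visited = set()
    stack = []
    for t in args:
        # Dynamic shapes are not supported for now.
        if not all(isinstance(dim, tir.IntImm) for dim in t.shape):
            return False
        if t.handle.value not in visited:
            visited.add(t.handle.value)
            stack.append(t)
    while stack:
        t = stack.pop()
        if isinstance(t.op, te.PlaceholderOp):
            continue
        if isinstance(t.op, te.ComputeOp):
            for inp in t.op.input_tensors:
                if inp.handle.value not in visited:
                    visited.add(inp.handle.value)
                    stack.append(inp)
        else:
            # Extern/scan/hybrid ops are rejected, matching the C++ default.
            return False
    return True
```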
33 changes: 0 additions & 33 deletions src/te/operation/create_primfunc.cc
@@ -458,40 +458,7 @@ PrimFunc CreatePrimFunc(const Array<te::Tensor>& arg_list) {
return LayoutFreePlaceholdersNormalizer().Process(std::move(func));
}

PrimFunc CreatePrimFuncFromOutputs(const Array<te::Tensor>& outputs) {
std::vector<te::Tensor> stack;
std::unordered_set<const te::TensorNode*> visited;
for (const te::Tensor& output : outputs) {
if (!visited.count(output.get())) {
visited.insert(output.get());
stack.push_back(output);
}
}

Array<te::Tensor> arg_list;
while (!stack.empty()) {
te::Tensor tensor = stack.back();
stack.pop_back();
if (tensor->op->IsInstance<te::PlaceholderOpNode>()) {
arg_list.push_back(tensor);
} else if (tensor->op->IsInstance<te::ComputeOpNode>()) {
Array<te::Tensor> inputs = tensor->op->InputTensors();
for (const te::Tensor& input : inputs) {
if (!visited.count(input.get())) {
visited.insert(input.get());
stack.push_back(input);
}
}
}
}
for (const te::Tensor& output : outputs) {
arg_list.push_back(output);
}
return CreatePrimFunc(arg_list);
}

TVM_REGISTER_GLOBAL("te.CreatePrimFunc").set_body_typed(CreatePrimFunc);
TVM_REGISTER_GLOBAL("te.CreatePrimFuncFromOutputs").set_body_typed(CreatePrimFuncFromOutputs);

} // namespace tir
} // namespace tvm
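
With `CreatePrimFuncFromOutputs` removed on both the C++ and Python sides, callers can list the inputs together with the outputs and use `te.create_prim_func` directly. A minimal migration sketch, with made-up tensor names:

```python
from tvm import te

# Illustrative computation; the names A and B are invented for this sketch.
A = te.placeholder((128, 128), name="A", dtype="float32")
B = te.compute((128, 128), lambda i, j: A[i, j] + 1.0, name="B")

# Previously: prim_func = te.create_prim_func_from_outputs([B])
# After this change, pass inputs and outputs explicitly:
prim_func = te.create_prim_func([A, B])
```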
3 changes: 0 additions & 3 deletions src/te/operation/create_primfunc.h
@@ -30,9 +30,6 @@ namespace tir {
/*! \brief Use Tensor Expression to create a schedulable TensorIR func. */
PrimFunc CreatePrimFunc(const Array<te::Tensor>& arg_list);

/*! \brief Create a schedulable TensorIR func from TE compute outputs. */
PrimFunc CreatePrimFuncFromOutputs(const Array<te::Tensor>& outputs);

} // namespace tir
} // namespace tvm

20 changes: 10 additions & 10 deletions src/tir/schedule/concrete_schedule.cc
@@ -199,16 +199,16 @@ Schedule ConcreteScheduleNode::Copy() {
* \param level An ScheduleErrorRenderLevel enum, level of error rendering
* \sa ScheduleErrorRenderLevel
*/
#define TVM_TIR_SCHEDULE_END(primitive, level) \
} \
catch (const ScheduleError& error) { \
if ((level) == ScheduleErrorRenderLevel::kDetail) { \
throw tvm::runtime::Error(error.RenderReport(primitive)); \
} else if ((level) == ScheduleErrorRenderLevel::kFast) { \
throw tvm::runtime::Error(error.FastErrorString()); \
} else if ((level) == ScheduleErrorRenderLevel::kNone) { \
throw tvm::runtime::Error("ScheduleError: (not rendered)"); \
} \
#define TVM_TIR_SCHEDULE_END(primitive, level) \
} \
catch (const ScheduleError& error) { \
if ((level) == ScheduleErrorRenderLevel::kDetail) { \
throw tvm::runtime::Error(error.RenderReport(primitive) + "\n" + runtime::Backtrace()); \
} else if ((level) == ScheduleErrorRenderLevel::kFast) { \
throw tvm::runtime::Error(error.FastErrorString()); \
} else if ((level) == ScheduleErrorRenderLevel::kNone) { \
throw tvm::runtime::Error("ScheduleError: (not rendered)"); \
} \
}

/******** Schedule: Schedule: Sampling ********/
63 changes: 63 additions & 0 deletions tests/python/unittest/test_meta_schedule_integration.py
@@ -196,6 +196,69 @@ def test_meta_schedule_integration_extract_from_bert_base():
assert expected_shape == shape, t.task_name


@requires_torch
def test_meta_schedule_integration_extract_from_resnet_with_filter_func():
def filter_func(args) -> bool:
from tvm import te, tir

has_complex_op = False
visited = set()

def traverse(t):
nonlocal has_complex_op
assert t.handle is not None
if t.handle.value in visited:
return
if isinstance(t.op, te.PlaceholderOp):
pass
elif isinstance(t.op, te.ComputeOp):
has_complex_op = has_complex_op or any(
[isinstance(e, tir.Reduce) for e in t.op.body]
)
for x in t.op.input_tensors:
traverse(x)
visited.add(t.handle.value)

for t in args:
traverse(t)
return has_complex_op

mod, params, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
extracted_tasks = ms.extract_task_from_relay(
mod,
target="llvm",
params=params,
filter_func=filter_func,
)
expected_task_names = [
"fused_" + s
for s in [
"nn_max_pool2d",
"nn_adaptive_avg_pool2d",
"nn_dense_add",
"nn_conv2d_add",
"nn_conv2d_add_1",
"nn_conv2d_add_2",
"nn_conv2d_add_add_nn_relu",
"nn_conv2d_add_add_nn_relu_1",
"nn_conv2d_add_nn_relu",
"nn_conv2d_add_nn_relu_1",
"nn_conv2d_add_nn_relu_2",
"nn_conv2d_add_nn_relu_3",
"nn_conv2d_add_nn_relu_4",
"nn_conv2d_add_nn_relu_5",
"nn_contrib_conv2d_winograd_without_weight_transform_add_add_nn_relu",
"nn_contrib_conv2d_winograd_without_weight_transform_add_add_nn_relu_1",
"nn_contrib_conv2d_winograd_without_weight_transform_add_nn_relu",
"nn_contrib_conv2d_winograd_without_weight_transform_add_nn_relu_1",
]
]

assert len(extracted_tasks) == len(expected_task_names)
for t in extracted_tasks:
assert t.task_name in expected_task_names, t.task_name


@requires_torch
def test_meta_schedule_integration_apply_history_best():
mod, _, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
