Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MetaSchedule] Apply-History-Best Task Filtering #11692

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions include/tvm/meta_schedule/apply_history_best.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@
#include <tvm/runtime/packed_func.h>
#include <tvm/target/target.h>

namespace tvm {
namespace te {
class Tensor;
} // namespace te
} // namespace tvm

namespace tvm {
namespace meta_schedule {

Expand All @@ -38,12 +44,21 @@ namespace meta_schedule {
*/
class ApplyHistoryBestNode : public runtime::Object {
public:
using FTEFilterFunc =
runtime::TypedPackedFunc<Optional<tir::PrimFunc>(const Array<te::Tensor, void>&)>;

/*! \brief The database to be queried from */
Database database{nullptr};
/*! \brief The filtering function for TE computation */
FTEFilterFunc te_filter_func{nullptr};
/*! \brief The logging function to be used */
PackedFunc logging_func;

void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("database", &database); }
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("database", &database);
// `te_filter_func` is not visited
// `logging_func` is not visited
}
/*!
* \brief Query the best entry from the database
* \param task_name The name of the task to be queried
Expand All @@ -67,9 +82,11 @@ class ApplyHistoryBest : public runtime::ObjectRef {
/*!
* \brief Constructor
* \param database The database to be queried from
* \param te_filter_func The filtering function for TE computation
* \param logging_func The logging function to use
*/
explicit ApplyHistoryBest(Database database, PackedFunc logging_func);
explicit ApplyHistoryBest(Database database, ApplyHistoryBestNode::FTEFilterFunc te_filter_func,
PackedFunc logging_func);
/*!
* \brief The current ApplyHistoryBest in the context
* \return The ApplyHistoryBest in the current scope.
Expand Down
23 changes: 23 additions & 0 deletions include/tvm/meta_schedule/extracted_task.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,15 @@
#include <tvm/runtime/object.h>
#include <tvm/target/target.h>

namespace tvm {
namespace tir {
class PrimFunc;
} // namespace tir
namespace te {
class Tensor;
} // namespace te
} // namespace tvm

namespace tvm {
namespace meta_schedule {

Expand Down Expand Up @@ -67,6 +76,20 @@ class ExtractedTask : public runtime::ObjectRef {
ExtractedTaskNode);
};

/*!
* \brief The default TE task filter
* \param args The input/output arguments of the TE compute graph
* \return NullOpt if the task is filtered out, otherwise the task in PrimFunc
*/
Optional<tvm::tir::PrimFunc> DefaultTaskFilter(const Array<tvm::te::Tensor, void>& args);

/*!
* \brief The default TE task filter, with `te.extern` allowed
* \param args The input/output arguments of the TE compute graph
* \return NullOpt if the task is filtered out, otherwise the task in PrimFunc
*/
Optional<tir::PrimFunc> DefaultTaskFilterAllowExtern(const Array<tvm::te::Tensor, void>& args);

} // namespace meta_schedule
} // namespace tvm

Expand Down
26 changes: 22 additions & 4 deletions python/tvm/meta_schedule/apply_history_best.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@
# under the License.
"""A context manager that injects the best tuning record in the database into compilation"""
import logging
from typing import List, Optional, Union
from typing import Callable, List, Optional, Union

from tvm._ffi import register_object
from tvm._ffi import get_global_func, register_object
from tvm.ir import IRModule
from tvm.runtime import Object
from tvm.target import Target
from tvm.te import Tensor
from tvm.tir import PrimFunc

from . import _ffi_api
from .database import Database
Expand All @@ -38,13 +40,29 @@ class ApplyHistoryBest(Object):
----------
database : Database
The database to be queried from
te_filter_func : Union[str, None, Callable[[List[Tensor]], PrimFunc]] = None
The filtering function for TE computation
If it's a string, it's the name of the filtering function. Built in functions are
- "meta_schedule.DefaultTaskFilter"
- "meta_schedule.DefaultTaskFilterAllowExtern"
If it's None, it's the default filtering function
If it's a callable, it's the filtering function
"""

database: Database

def __init__(self, database: Database) -> None:
def __init__(
self,
database: Database,
te_filter_func: Union[str, None, Callable[[List[Tensor]], PrimFunc]] = None,
) -> None:
if isinstance(te_filter_func, str):
te_filter_func = get_global_func(te_filter_func)
self.__init_handle_by_constructor__(
_ffi_api.ApplyHistoryBest, database, make_logging_func(logger) # type: ignore # pylint: disable=no-member
_ffi_api.ApplyHistoryBest, # type: ignore # pylint: disable=no-member
database,
te_filter_func,
make_logging_func(logger),
)

def query(
Expand Down
16 changes: 12 additions & 4 deletions python/tvm/meta_schedule/relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""MetaSchedule-Relay integration"""
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np # type: ignore
from tvm import nd
Expand All @@ -24,6 +24,7 @@
from tvm.runtime import NDArray
from tvm.target import Target
from tvm.te import Tensor
from tvm.tir import PrimFunc

from .extracted_task import ExtractedTask
from .utils import autotvm_silencer
Expand All @@ -37,7 +38,7 @@ def extract_task_from_relay(
opt_level: int = 3,
pass_config: Optional[Dict[str, Any]] = None,
disabled_pass: Optional[List[str]] = None,
filter_func: Callable[[List[Tensor]], bool] = None,
te_filter_func: Union[str, None, Callable[[List[Tensor]], PrimFunc]] = None,
) -> List[ExtractedTask]:
"""Extract tuning tasks from a relay program.

Expand All @@ -55,8 +56,13 @@ def extract_task_from_relay(
The pass config of the compiler
disabled_pass : Optional[List[str]]
The list of disabled passes of the compiler
filter_func : Callable[[List[tvm.te.Tensor]], bool]
te_filter_func : Callable[[List[tvm.te.Tensor]], bool]
The filter function to filter out the extracted tasks
If it's a string, it's the name of the filtering function. Built in functions are
- "meta_schedule.DefaultTaskFilter"
- "meta_schedule.DefaultTaskFilterAllowExtern"
If it's None, it's the default filtering function
If it's a callable, it's the filtering function

Returns
-------
Expand All @@ -68,6 +74,8 @@ def extract_task_from_relay(

# pylint: enable=import-outside-toplevel

if isinstance(te_filter_func, str):
te_filter_func = get_global_func(te_filter_func)
extract_task_func = get_global_func(
"relay.backend.MetaScheduleExtractTask",
allow_missing=False,
Expand All @@ -94,4 +102,4 @@ def extract_task_from_relay(
config=pass_config,
disabled_pass=disabled_pass,
):
return list(extract_task_func(mod, target, relay_params, filter_func))
return list(extract_task_func(mod, target, relay_params, te_filter_func))
15 changes: 14 additions & 1 deletion python/tvm/meta_schedule/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def apply_fixed_schedules(
target: Union[str, Target],
params: Optional[Dict[str, NDArray]],
schedule_fn: Callable[[ms.ExtractedTask, Schedule], bool],
te_filter_func=None,
):
"""Apply fixed schedules (manually written, without any tunable knobs) as specified by
schedule_fn to extracted tasks, and return a database that can be passed to ApplyHistoryBest.
Expand All @@ -45,14 +46,26 @@ def apply_fixed_schedules(
schedule_fn : Callable[[ExtractedTask, Schedule], bool]
A callable that is applied for each extracted task and the corresponding default schedule.
Returns True if the given schedule should be committed to the database, False otherwise.
te_filter_func : Union[str, None, Callable[[List[Tensor]], PrimFunc]] = None
The filtering function for TE computation
If it's a string, it's the name of the filtering function. Built in functions are
- "meta_schedule.DefaultTaskFilter"
- "meta_schedule.DefaultTaskFilterAllowExtern"
If it's None, it's the default filtering function
If it's a callable, it's the filtering function

Returns
-------
database : Database
The database containing dummy tuning records for manually scheduled traces.
"""
target = Target(target) if isinstance(target, str) else target
extracted_tasks = ms.extract_task_from_relay(relay_mod, target, params)
extracted_tasks = ms.extract_task_from_relay(
relay_mod,
target,
params,
te_filter_func=te_filter_func,
)
database = ms.database.MemoryDatabase()
for task in extracted_tasks:
mod = ms.default_config.mod(task.dispatched[0])
Expand Down
15 changes: 12 additions & 3 deletions src/meta_schedule/apply_history_best.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/te/tensor.h>

#include "./utils.h"

namespace tvm {
Expand Down Expand Up @@ -87,10 +89,16 @@ void ApplyHistoryBest::ExitWithScope() {

/**************** ApplyHistoryBest ****************/

ApplyHistoryBest::ApplyHistoryBest(Database database, PackedFunc logging_func) {
ApplyHistoryBest::ApplyHistoryBest(Database database,
ApplyHistoryBestNode::FTEFilterFunc te_filter_func,
PackedFunc logging_func) {
ObjectPtr<ApplyHistoryBestNode> n = make_object<ApplyHistoryBestNode>();
n->database = database;
n->te_filter_func = te_filter_func;
n->logging_func = logging_func;
if (te_filter_func == nullptr) {
n->te_filter_func = DefaultTaskFilter;
}
data_ = n;
}

Expand Down Expand Up @@ -129,8 +137,9 @@ Optional<IRModule> ApplyHistoryBestNode::Query(runtime::String task_name, IRModu

TVM_REGISTER_NODE_TYPE(ApplyHistoryBestNode);
TVM_REGISTER_GLOBAL("meta_schedule.ApplyHistoryBest")
.set_body_typed([](Database database, PackedFunc logging_func) -> ApplyHistoryBest {
return ApplyHistoryBest(database, logging_func);
.set_body_typed([](Database database, ApplyHistoryBestNode::FTEFilterFunc te_filter_func,
PackedFunc logging_func) -> ApplyHistoryBest {
return ApplyHistoryBest(database, te_filter_func, logging_func);
});
TVM_REGISTER_GLOBAL("meta_schedule.ApplyHistoryBestEnterScope")
.set_body_typed(ApplyHistoryBestInternal::EnterScope);
Expand Down
55 changes: 54 additions & 1 deletion src/meta_schedule/extracted_task.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@
* under the License.
*/
#include <tvm/meta_schedule/extracted_task.h>
#include <tvm/te/operation.h>
#include <tvm/te/tensor.h>
#include <tvm/tir/function.h>

#include "../te/operation/create_primfunc.h"
#include "./utils.h"

namespace tvm {
namespace meta_schedule {
Expand All @@ -32,12 +38,59 @@ ExtractedTask::ExtractedTask(String task_name, IRModule mod, Target target,
data_ = n;
}

Optional<tir::PrimFunc> DefaultTaskFilterImpl(const Array<te::Tensor>& args, bool allow_extern_op) {
using namespace ::tvm::te;
std::vector<Tensor> stack;
std::unordered_set<const TensorNode*> visited;
for (const Tensor& v : args) {
for (const PrimExpr& e : v->shape) {
// Dynamic shape is not supported for now
if (!e->IsInstance<IntImmNode>()) {
return NullOpt;
}
}
if (!visited.count(v.get())) {
visited.insert(v.get());
stack.push_back(v);
}
}
while (!stack.empty()) {
Tensor tensor = stack.back();
stack.pop_back();
if (tensor->op->IsInstance<PlaceholderOpNode>()) {
// do nothing
} else if (tensor->op->IsInstance<ComputeOpNode>() ||
(allow_extern_op && tensor->op->IsInstance<ExternOpNode>())) {
Array<Tensor> inputs = tensor->op->InputTensors();
for (const Tensor& v : inputs) {
if (!visited.count(v.get())) {
visited.insert(v.get());
stack.push_back(v);
}
}
} else {
return NullOpt;
}
}
return te::CreatePrimFunc(args);
}

Optional<tir::PrimFunc> DefaultTaskFilter(const Array<te::Tensor>& args) {
return DefaultTaskFilterImpl(args, false);
}

Optional<tir::PrimFunc> DefaultTaskFilterAllowExtern(const Array<te::Tensor>& args) {
return DefaultTaskFilterImpl(args, true);
}

TVM_REGISTER_NODE_TYPE(ExtractedTaskNode);
TVM_REGISTER_GLOBAL("meta_schedule.ExtractedTask")
.set_body_typed([](String task_name, IRModule mod, Target target, Array<IRModule> dispatched,
int weight) -> ExtractedTask {
return ExtractedTask(task_name, mod, target, dispatched, weight);
});

TVM_REGISTER_GLOBAL("meta_schedule.DefaultTaskFilter").set_body_typed(DefaultTaskFilter);
TVM_REGISTER_GLOBAL("meta_schedule.DefaultTaskFilterAllowExtern")
.set_body_typed(DefaultTaskFilterAllowExtern);
} // namespace meta_schedule
} // namespace tvm
1 change: 1 addition & 0 deletions src/meta_schedule/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <tvm/meta_schedule/builder.h>
#include <tvm/meta_schedule/cost_model.h>
#include <tvm/meta_schedule/database.h>
#include <tvm/meta_schedule/extracted_task.h>
#include <tvm/meta_schedule/feature_extractor.h>
#include <tvm/meta_schedule/measure_callback.h>
#include <tvm/meta_schedule/profiler.h>
Expand Down
Loading