Skip to content

Commit

Permalink
[MetaSchedule] disallow_dynamic_loop (#9997)
Browse files Browse the repository at this point in the history
* [MetaSchedule] disallow_dynamic_loop

Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Xiyou Zhou <xiyou@octoml.ai>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>

* Update src/meta_schedule/postproc/disallow_dynamic_loop.cc

Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>

Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Xiyou Zhou <xiyou@octoml.ai>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>
  • Loading branch information
7 people committed Jan 22, 2022
1 parent 9a6423c commit 64f2939
Show file tree
Hide file tree
Showing 4 changed files with 217 additions and 0 deletions.
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/postproc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,6 @@
# under the License.
"""The tvm.meta_schedule.postproc package."""
from .postproc import Postproc, PyPostproc
from .disallow_dynamic_loop import DisallowDynamicLoop
from .rewrite_reduction_block import RewriteReductionBlock
from .verify_gpu_code import VerifyGPUCode
31 changes: 31 additions & 0 deletions python/tvm/meta_schedule/postproc/disallow_dynamic_loop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""A postprocessor that checks if the IRModule has any loop with non-constant extent"""

from tvm._ffi.registry import register_object
from .. import _ffi_api
from .postproc import Postproc


@register_object("meta_schedule.DisallowDynamicLoop")
class DisallowDynamicLoop(Postproc):
"""A postprocessor that checks if the IRModule has any loop with non-constant extent"""

def __init__(self) -> None:
self.__init_handle_by_constructor__(
_ffi_api.PostprocDisallowDynamicLoop, # type: ignore # pylint: disable=no-member
)
85 changes: 85 additions & 0 deletions src/meta_schedule/postproc/disallow_dynamic_loop.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "../utils.h"

namespace tvm {
namespace tir {

/*! \brief Check if an IRModule has any dynamic loop. */
struct DynamicExtentFinder : private StmtVisitor {
public:
static bool Find(const IRModule& mod) {
DynamicExtentFinder finder;
for (const auto& kv : mod->functions) {
const BaseFunc& func = kv.second;
if (const auto* prim_func = func.as<PrimFuncNode>()) {
finder(prim_func->body);
if (finder.found_) {
return true;
}
}
}
return false;
}

private:
void VisitStmt_(const ForNode* loop) final {
if (!loop->extent->IsInstance<IntImmNode>()) {
found_ = true;
} else {
StmtVisitor::VisitStmt_(loop);
}
}

void VisitStmt(const Stmt& stmt) final {
if (!found_) {
StmtVisitor::VisitStmt(stmt);
}
}

bool found_ = false;
};

} // namespace tir

namespace meta_schedule {

/*! \brief Check if the IRModule has any loop with non-constant extent. */
class DisallowDynamicLoopNode : public PostprocNode {
public:
// Inherited from PostprocNode
void InitializeWithTuneContext(const TuneContext& context) final {}
// Inherited from PostprocNode
bool Apply(const tir::Schedule& sch) final { return !tir::DynamicExtentFinder::Find(sch->mod()); }

static constexpr const char* _type_key = "meta_schedule.DisallowDynamicLoop";
TVM_DECLARE_FINAL_OBJECT_INFO(DisallowDynamicLoopNode, PostprocNode);
};

Postproc Postproc::DisallowDynamicLoop() {
ObjectPtr<DisallowDynamicLoopNode> n = make_object<DisallowDynamicLoopNode>();
return Postproc(n);
}

TVM_REGISTER_NODE_TYPE(DisallowDynamicLoopNode);
TVM_REGISTER_GLOBAL("meta_schedule.PostprocDisallowDynamicLoop")
.set_body_typed(Postproc::DisallowDynamicLoop);

} // namespace meta_schedule
} // namespace tvm
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring

import tvm
from tvm import tir
from tvm.meta_schedule import TuneContext
from tvm.meta_schedule.postproc import DisallowDynamicLoop
from tvm.script import tir as T
from tvm.target import Target


def _target() -> Target:
return Target("cuda", host="llvm")


def _create_context(mod, target) -> TuneContext:
ctx = TuneContext(
mod=mod,
target=target,
postprocs=[
DisallowDynamicLoop(),
],
task_name="test",
)
for rule in ctx.postprocs:
rule.initialize_with_tune_context(ctx)
return ctx


# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument
# fmt: off

@tvm.script.ir_module
class Matmul:
@T.prim_func
def main(a: T.handle, b: T.handle, c: T.handle) -> None:
T.func_attr({"global_symbol": "main"})
A = T.match_buffer(a, (1024, 1024), "float32")
B = T.match_buffer(b, (1024, 1024), "float32")
C = T.match_buffer(c, (1024, 1024), "float32")
for i, j, k in T.grid(1024, 1024, 1024):
with T.block("matmul"):
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
with T.init():
C[vi, vj] = 0.0
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]


@tvm.script.ir_module
class DynamicLoop:
@T.prim_func
def main(a: T.handle, b: T.handle, c: T.handle) -> None:
T.func_attr({"global_symbol": "main"})
A = T.match_buffer(a, (1024, 1024), "float32")
B = T.match_buffer(b, (1024, 1024), "float32")
C = T.match_buffer(c, (1024, 1024), "float32")
for i, j in T.grid(1024, 1024):
for k in T.serial(0, i):
with T.block("matmul"):
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
with T.init():
C[vi, vj] = 0.0
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

# fmt: on
# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument


def test_postproc_disallow_dynamic_loops():
mod = Matmul
ctx = _create_context(mod, target=_target())
sch = tir.Schedule(mod, debug_mask="all")
assert ctx.postprocs[0].apply(sch)


def test_postproc_disallow_dynamic_loops_fail():
mod = DynamicLoop
ctx = _create_context(mod, target=_target())
sch = tir.Schedule(mod, debug_mask="all")
assert not ctx.postprocs[0].apply(sch)


if __name__ == "__main__":
test_postproc_disallow_dynamic_loops()
test_postproc_disallow_dynamic_loops_fail()

0 comments on commit 64f2939

Please sign in to comment.