Skip to content

Commit

Permalink
[BUG] ToBasicBlockNormalForm immutability (apache#8778)
Browse files Browse the repository at this point in the history
* ToBasicBlockNormalForm immutability

* better comment on ToBasicBlock

* refine comment of ToBasicBlockForm
  • Loading branch information
ganler authored and ylc committed Sep 29, 2021
1 parent 58a59b0 commit 721994e
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 5 deletions.
11 changes: 7 additions & 4 deletions src/relay/transforms/to_basic_block_normal_form.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,11 @@ Expr ToBasicBlockNormalFormAux(const Expr& e) {
IRModule ToBasicBlockNormalForm(const IRModule& mod) {
DLOG(INFO) << "ToBBlock:" << std::endl << mod;

// Create a new module by shallow copy.
auto mod_ = IRModule(mod->functions, mod->type_definitions, mod->Imports(), mod->source_map);

tvm::Map<GlobalVar, Function> updates;
auto funcs = mod->functions;
auto funcs = mod_->functions;
for (const auto& it : funcs) {
ICHECK_EQ(FreeVars(it.second).size(), 0) << "Expected no free variables";
if (const auto* n = it.second.as<FunctionNode>()) {
Expand All @@ -63,12 +66,12 @@ IRModule ToBasicBlockNormalForm(const IRModule& mod) {
}

for (auto pair : updates) {
mod->Add(pair.first, pair.second, true);
mod_->Add(pair.first, pair.second, true);
}

DLOG(INFO) << "ToBBlock: transformed" << std::endl << mod;
DLOG(INFO) << "ToBBlock: transformed" << std::endl << mod_;

return mod;
return mod_;
}

bool BasicBlockNormalFormCheck(const Expr& e) {
Expand Down
24 changes: 23 additions & 1 deletion tests/python/relay/test_pass_to_basic_block_normal_form.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from tvm.relay.analysis import detect_feature
from tvm.relay import op, create_executor, transform
from tvm.relay.prelude import Prelude
from tvm.relay.testing import count
from tvm.relay.testing import count, create_workload
from tvm.relay.analysis import Feature
from tvm.relay.analysis import check_basic_block_normal_form

Expand Down Expand Up @@ -491,5 +491,27 @@ def test_higher_order_nested():
check_basic_block_normal_form(bblock)


def test_immutability():
simple_net = relay.nn.conv2d(
data=relay.var("data", relay.TensorType((1, 3, 224, 224), "float32")),
weight=relay.var("weight"),
kernel_size=(5, 5),
channels=3,
padding=(1, 1),
)
simple_net = relay.Function(relay.analysis.free_vars(simple_net), simple_net)
mod, _ = create_workload(simple_net)

old_mod = mod

with tvm.transform.PassContext(opt_level=4):
with tvm.target.Target("llvm"):
seq = tvm.transform.Sequential(passes=[transform.ToBasicBlockNormalForm()], opt_level=4)
new_mod = seq(mod)

assert old_mod.astext() == mod.astext()
assert old_mod.astext() != new_mod.astext()


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit 721994e

Please sign in to comment.