Skip to content

Commit

Permalink
[RELAY] Remove re-exports of tvm.transform
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Apr 14, 2020
1 parent f08d5d7 commit 2febc97
Show file tree
Hide file tree
Showing 38 changed files with 168 additions and 228 deletions.
8 changes: 8 additions & 0 deletions docs/api/python/ir.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,11 @@ tvm.ir
:members:
:imported-members:
:autosummary:


tvm.transform
-------------
.. automodule:: tvm.transform
:members:
:imported-members:
:autosummary:
2 changes: 1 addition & 1 deletion docs/dev/convert_layout.rst
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ ConvertLayout pass is extremely easy to use. The pass is not a part of default r
# Convert the layout to NCHW
# RemoveUnunsedFunctions is used to clean up the graph.
seq = relay.transform.Sequential([relay.transform.RemoveUnusedFunctions(),
seq = tvm.transform.Sequential([relay.transform.RemoveUnusedFunctions(),
relay.transform.ConvertLayout('NCHW')])
with relay.transform.PassContext(opt_level=3):
mod = seq(mod)
Expand Down
4 changes: 2 additions & 2 deletions docs/dev/relay_pass_infra.rst
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,7 @@ using ``Sequential`` associated with other types of passes.
func = relay.Function([x], z2)
# Customize the optimization pipeline.
seq = _transform.Sequential([
seq = tvm.transform.Sequential([
relay.transform.InferType(),
relay.transform.FoldConstant(),
relay.transform.EliminateCommonSubexpr(),
Expand All @@ -609,7 +609,7 @@ sequential pass example could be like the following to enable IR dumping for

.. code:: python
seq = _transform.Sequential([
seq = tvm.transform.Sequential([
relay.transform.InferType(),
relay.transform.FoldConstant(),
relay.transform.PrintIR(),
Expand Down
4 changes: 3 additions & 1 deletion include/tvm/ir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -361,9 +361,11 @@ TVM_DLL Pass CreateModulePass(

/*!
* \brief A special trace pass that prints the header and IR to LOG(INFO).
* \param header The header to be attached to the output.
* \param show_meta_data Whether should we show meta data.
* \return The pass.
*/
TVM_DLL Pass PrintIR(std::string header);
TVM_DLL Pass PrintIR(std::string header="", bool show_meta_data=false);

} // namespace transform
} // namespace tvm
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/ir/json_compact.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def _update_global_key(item, _):
"relay.PassInfo": _rename("transform.PassInfo"),
"relay.PassContext": _rename("transform.PassContext"),
"relay.ModulePass": _rename("transform.ModulePass"),
"relay.Sequantial": _rename("transform.Sequantial"),
"relay.Sequential": _rename("transform.Sequential"),
# TIR
"Variable": _update_tir_var("tir.Var"),
"SizeVar": _update_tir_var("tir.SizeVar"),
Expand Down
7 changes: 5 additions & 2 deletions python/tvm/ir/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,16 +329,19 @@ def create_module_pass(pass_arg):
return create_module_pass


def PrintIR(header):
def PrintIR(header="", show_meta_data=False):
"""A special trace pass that prints the header and IR.
Parameters
----------
header : str
The header to be displayed along with the dump.
show_meta_data : bool
A boolean flag to indicate if meta data should be printed.
Returns
--------
The pass
"""
return _ffi_transform_api.PrintIR(header)
return _ffi_transform_api.PrintIR(header, show_meta_data)
11 changes: 0 additions & 11 deletions python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,20 +128,9 @@
# Scope builder
ScopeBuilder = scope_builder.ScopeBuilder

module_pass = transform.module_pass
function_pass = transform.function_pass

# Parser
fromtext = parser.fromtext

# Param Serialization
save_param_dict = param_dict.save_param_dict
load_param_dict = param_dict.load_param_dict

# Pass manager
PassInfo = transform.PassInfo
PassContext = transform.PassContext
Pass = transform.Pass
ModulePass = transform.ModulePass
FunctionPass = transform.FunctionPass
Sequential = transform.Sequential
8 changes: 4 additions & 4 deletions python/tvm/relay/backend/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,10 +210,10 @@ def optimize(self):
opt_mod : tvm.IRModule
The optimized module.
"""
seq = transform.Sequential([transform.SimplifyInference(),
transform.FuseOps(0),
transform.ToANormalForm(),
transform.InferType()])
seq = tvm.transform.Sequential([transform.SimplifyInference(),
transform.FuseOps(0),
transform.ToANormalForm(),
transform.InferType()])
return seq(self.mod)

def _make_executor(self, expr=None):
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relay/qnn/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def @main(%quantized_data: Tensor[(200), int32]) -> Tensor[(200), int8] {
Returns
-------
ret : tvm.relay.Pass
ret : tvm.transform.Pass
The registered pass that canonicalizes QNN ops to Relay ops.
"""

Expand Down Expand Up @@ -108,7 +108,7 @@ def Legalize():
Returns
-------
ret : tvm.relay.Pass
ret : tvm.transform.Pass
The registered pass that legalizes QNN ops.
"""

Expand Down
33 changes: 18 additions & 15 deletions python/tvm/relay/quantize/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#pylint: disable=unused-argument, not-context-manager
"""Automatic quantization toolkit."""
import tvm.ir
import tvm
from tvm.runtime import Object

from . import _quantize
Expand Down Expand Up @@ -240,7 +241,7 @@ def partition():
Returns
-------
ret: tvm.relay.Pass
ret: tvm.transform.Pass
The registered pass for VTA rewrite.
"""
return _quantize.QuantizePartition()
Expand All @@ -253,7 +254,7 @@ def annotate():
Returns
-------
ret: tvm.relay.Pass
ret: tvm.transform.Pass
The registered pass for quantization annotation.
"""
return _quantize.QuantizeAnnotate()
Expand All @@ -267,7 +268,7 @@ def realize():
Returns
-------
ret: tvm.relay.Pass
ret: tvm.transform.Pass
The registered pass for quantization realization.
"""
return _quantize.QuantizeRealize()
Expand Down Expand Up @@ -298,11 +299,12 @@ def prerequisite_optimize(mod, params=None):
""" Prerequisite optimization passes for quantization. Perform
"SimplifyInference", "FoldScaleAxis", "FoldConstant", and
"CanonicalizeOps" optimization before quantization. """
optimize = _transform.Sequential([_transform.SimplifyInference(),
_transform.FoldConstant(),
_transform.FoldScaleAxis(),
_transform.CanonicalizeOps(),
_transform.FoldConstant()])
optimize = tvm.transform.Sequential(
[_transform.SimplifyInference(),
_transform.FoldConstant(),
_transform.FoldScaleAxis(),
_transform.CanonicalizeOps(),
_transform.FoldConstant()])

if params:
mod['main'] = _bind_params(mod['main'], params)
Expand Down Expand Up @@ -336,19 +338,20 @@ def quantize(mod, params=None, dataset=None):
"""
mod = prerequisite_optimize(mod, params)

calibrate_pass = _transform.module_pass(calibrate(dataset), opt_level=1,
name="QuantizeCalibrate")
calibrate_pass = tvm.transform.module_pass(
calibrate(dataset), opt_level=1,
name="QuantizeCalibrate")
quant_passes = [partition(),
annotate(),
calibrate_pass]
if not current_qconfig().do_simulation:
quant_passes.append(realize())
quant_passes.append(_transform.FoldConstant())
quantize_seq = _transform.Sequential(quant_passes)
with _transform.PassContext(opt_level=3,
required_pass=["QuantizeAnnotate",
"QuantizeCalibrate",
"QuantizeRealize"]):
quantize_seq = tvm.transform.Sequential(quant_passes)
with tvm.transform.PassContext(opt_level=3,
required_pass=["QuantizeAnnotate",
"QuantizeCalibrate",
"QuantizeRealize"]):
with quantize_context():
mod = quantize_seq(mod)

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from ..transform import gradient

def run_opt_pass(expr, opt_pass):
assert isinstance(opt_pass, transform.Pass)
assert isinstance(opt_pass, tvm.transform.Pass)
mod = tvm.IRModule.from_expr(expr)
mod = opt_pass(mod)
entry = mod["main"]
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/testing/py_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def optimize(self, prog: Expr):

# necessary pass: SimplifyInference (otherwise we can't generate code for some operators)
# and fusion (to get primitive functions)
opts = relay.transform.Sequential([relay.transform.SimplifyInference(),
opts = tvm.transform.Sequential([relay.transform.SimplifyInference(),
relay.transform.FuseOps(fuse_opt_level=0)])
mod = opts(mod)
optimized = mod['main']
Expand Down
Loading

0 comments on commit 2febc97

Please sign in to comment.