Skip to content

Commit

Permalink
incorporate apache#9797
Browse files Browse the repository at this point in the history
  • Loading branch information
liaopeiyuan committed Sep 15, 2022
1 parent 298e999 commit c5116fa
Show file tree
Hide file tree
Showing 5 changed files with 954 additions and 177 deletions.
143 changes: 134 additions & 9 deletions python/tvm/relay/op/contrib/tachikoma.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,17 @@
- The other way is to implement the function by themselves to
check the attributes of the op and decide if it should be offloaded to Tachikoma.
"""
import logging

import tvm.ir
from tvm.relay import transform
from tvm.relay.build_module import bind_params_by_name

from ...dataflow_pattern import wildcard, is_op
from .register import register_pattern_table

logger = logging.getLogger("Tachikoma")


def _register_external_op_helper(op_name, supported=True):
"""The helper function to indicate that a given operator can be supported
Expand All @@ -49,6 +56,7 @@ def _register_external_op_helper(op_name, supported=True):
f : callable
A function that returns if the operator is supported by Tachikoma.
"""

@tvm.ir.register_op_attr(op_name, "target.tachikoma")
def _func_wrapper(expr):
return supported
Expand All @@ -60,26 +68,143 @@ def _func_wrapper(expr):
_register_external_op_helper("nn.conv2d")
_register_external_op_helper("nn.dense")
_register_external_op_helper("nn.relu")
_register_external_op_helper("tanh")
_register_external_op_helper("sigmoid")
_register_external_op_helper("add")
_register_external_op_helper("subtract")
_register_external_op_helper("multiply")


def make_pattern(with_bias=True):
def make_conv_pattern(with_bias=True, with_eltwise=None):
"""Create patterns related to nn.conv2d.
Parameters
----------
with_bias : bool
Whether attach `bias_add` to `nn.conv2d`.
with_eltwise : str
The attached elementwise post-op name.
Returns
-------
conv_out : CallPattern
Call node sequence.
"""
data = wildcard()
weight = wildcard()
bias = wildcard()
conv = is_op('nn.conv2d')(data, weight)
conv = is_op("nn.conv2d")(data, weight)
if with_bias:
conv_out = is_op('add')(conv, bias)
conv_out = is_op("add")(conv, bias)
else:
conv_out = conv
return is_op('nn.relu')(conv_out)
if with_eltwise:
return is_op(with_eltwise)(conv_out)
return conv_out


def make_dense_pattern(with_bias=True, with_eltwise=None):
"""Create patterns related to nn.dense.
Parameters
----------
with_bias : bool
Whether attach `bias_add` to `nn.dense`.
with_eltwise : str
The attached elementwise post-op name.
Returns
-------
dense_out : CallPattern
Call node sequence.
"""
data = wildcard()
weight = wildcard()
bias = wildcard()
dense = is_op("nn.dense")(data, weight)
if with_bias:
dense_out = is_op("add")(dense, bias)
else:
dense_out = dense
if with_eltwise:
dense_out = is_op(with_eltwise)(dense_out)
return dense_out


def make_tachikoma_pattern(op, with_bias, with_eltwise):
"""Create tachikoma patterns.
Parameters
----------
op : str
The first call node's op name.
with_bias : bool
Whether attach `bias_add` to `nn.dense`.
with_eltwise : str
The attached elementwise post-op name.
Returns
-------
pattern : Tuple(pattern_name, CallPattern)
Created pattern name, along with its CallPattern.
"""
pat_name = "tachikoma." + op
pat_name += "_bias" if with_bias else ""
pat_name += ("_" + with_eltwise.split(".")[-1]) if with_eltwise else ""
if op == "conv2d":
tachikoma_pattern = (pat_name, make_conv_pattern(with_bias, with_eltwise))
elif op == "dense":
tachikoma_pattern = (pat_name, make_dense_pattern(with_bias, with_eltwise))
else:
logger.warning("Currently, only conv2d and dense op are supported, but got %s.", op)
tachikoma_pattern = ()
return tachikoma_pattern


@register_pattern_table("tachikoma")
def pattern_table():
conv2d_bias_relu_pat = ("tachikoma.conv2d_bias_relu", make_pattern(with_bias=True))
conv2d_relu_pat = ("tachikoma.conv2d_relu", make_pattern(with_bias=False))
tachikoma_patterns = [conv2d_bias_relu_pat, conv2d_relu_pat]
return tachikoma_patterns
"""Create tachikoma patterns.
Returns
-------
tachikoma_patterns : List[tachikoma_pattern]
Created patterns.
"""
elt_list = ["nn.relu", "tanh", "sigmoid", None]
tachikoma_patterns = []
for with_bias in [True, False]:
for elt in elt_list:
if not with_bias and not elt:
return tachikoma_patterns
tachikoma_patterns.append(make_tachikoma_pattern("conv2d", with_bias, elt))
tachikoma_patterns.append(make_tachikoma_pattern("dense", with_bias, elt))
return tachikoma_patterns


def partition_for_tachikoma(mod, params=None):
"""Partition the graph greedily offloading supported operators to Tachikoma.
Parameters
----------
mod : Module
The module to run passes on.
params : Optional[Dict[str, NDArray]]
Constant input parameters.
Returns
-------
mod : Module
Annotated and partitioned module.
"""

if params:
mod["main"] = bind_params_by_name(mod["main"], params)
seq = tvm.transform.Sequential(
[
transform.CanonicalizeOps(),
transform.InferType(),
transform.SimplifyInference(),
transform.FoldConstant(),
transform.FoldScaleAxis(),
# fold consecutive add ops to simplify pattern `conv2d-bias_add-bn-relu`
transform.SimplifyExpr(),
transform.FoldConstant(),
transform.MergeComposite(pattern_table()),
transform.AnnotateTarget("tachikoma"),
transform.MergeCompilerRegions(),
transform.PartitionGraph(),
]
)
with tvm.transform.PassContext(opt_level=3):
mod = seq(mod)
return mod
Loading

0 comments on commit c5116fa

Please sign in to comment.