Skip to content

Commit

Permalink
add (apache#4311)
Browse files Browse the repository at this point in the history
  • Loading branch information
icemelon authored and Xingyu Zhou committed Nov 13, 2019
1 parent f625f72 commit 7da13f2
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 3 deletions.
68 changes: 65 additions & 3 deletions python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@

import json
import tvm
from topi.util import get_const_tuple
from .. import analysis
from .. import expr as _expr
from .. import op as _op
from .. import module as _module
from .. import scope_builder as _scope_builder
from ... import nd as _nd

from .common import StrAttrsDict
Expand Down Expand Up @@ -1037,6 +1039,47 @@ def _mx_contrib_fifo_buffer(inputs, attrs):
new_attrs['axis'] = attrs.get_int('axis')
return _op.nn.fifo_buffer(*inputs, **new_attrs)

def _mx_cond(inputs, attrs, subgraphs):
assert len(subgraphs) == 3
cond_input_locs = json.loads(attrs.get_str("cond_input_locs"))
then_input_locs = json.loads(attrs.get_str("then_input_locs"))
else_input_locs = json.loads(attrs.get_str("else_input_locs"))
num_outputs = attrs.get_int("num_outputs")

input_args = []
for i, arg in enumerate(inputs):
var = _expr.var("arg%s" % i, _infer_type(arg).checked_type)
input_args.append(var)
cond_args = [input_args[i] for i in cond_input_locs]
then_args = [input_args[i] for i in then_input_locs]
else_args = [input_args[i] for i in else_input_locs]

cond_arg_shapes = [arg.type_annotation.shape for arg in cond_args]
cond_arg_dtype_info = [arg.type_annotation.dtype for arg in cond_args]
cond_func = _from_mxnet_impl(subgraphs[0], cond_arg_shapes, cond_arg_dtype_info)
cond = _expr.Call(cond_func, cond_args).astype("bool")
cond_shape = get_const_tuple(_infer_type(cond).checked_type.shape)
if len(cond_shape) > 0:
assert len(cond_shape) == 1 and cond_shape[0] == 1, "Condition is not scalar"
cond = _op.take(cond, _expr.const(1, "int"))

sb = _scope_builder.ScopeBuilder()
with sb.if_scope(cond):
then_arg_shapes = [arg.type_annotation.shape for arg in then_args]
then_arg_dtype_info = [arg.type_annotation.dtype for arg in then_args]
then_func = _from_mxnet_impl(subgraphs[1], then_arg_shapes, then_arg_dtype_info)
sb.ret(_expr.Call(then_func, then_args))
with sb.else_scope():
else_arg_shapes = [arg.type_annotation.shape for arg in else_args]
else_arg_dtype_info = [arg.type_annotation.dtype for arg in else_args]
else_func = _from_mxnet_impl(subgraphs[2], else_arg_shapes, else_arg_dtype_info)
sb.ret(_expr.Call(else_func, else_args))
func = _expr.Function(input_args, sb.get())
ret = _expr.Call(func, inputs)
if num_outputs > 1:
ret = _expr.TupleWrapper(ret, num_outputs)
return ret


# Note: due to attribute conversion constraint
# ops in the identity set must be attribute free
Expand Down Expand Up @@ -1204,6 +1247,8 @@ def _mx_contrib_fifo_buffer(inputs, attrs):
# NLP
"RNN" : _mx_rnn_layer,
"_rnn_param_concat" : _mx_rnn_param_concat,
# control flow
"_cond" : _mx_cond,
# Depricated:
"Crop" : _mx_crop_like,
# List of missing operators that are present in NNVMv1
Expand Down Expand Up @@ -1245,24 +1290,41 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info, mod=None):
Converted relay Function
"""
assert symbol is not None
jgraph = json.loads(symbol.tojson())
if isinstance(symbol, dict):
jgraph = symbol
else:
jgraph = json.loads(symbol.tojson())
jnodes = jgraph["nodes"]
node_map = {}
shape_idx = 0

for nid, node in enumerate(jnodes):
children = [node_map[e[0]][e[1]] for e in node["inputs"]]
attrs = StrAttrsDict(node.get("attrs", {}))
node_name = node["name"]
op_name = node["op"]
if op_name == "null":
shape = shape_dict[node_name] if node_name in shape_dict else None
if isinstance(shape_dict, dict):
shape = shape_dict[node_name] if node_name in shape_dict else None
elif isinstance(shape_dict, (list, tuple)):
shape = shape_dict[shape_idx]
else:
raise ValueError("Unknown type of shape_dict: %s" + type(shape_dict))
if isinstance(dtype_info, dict):
dtype = dtype_info[node_name] if node_name in dtype_info else "float32"
elif isinstance(dtype_info, (list, tuple)):
dtype = dtype_info[shape_idx]
else:
dtype = dtype_info
if isinstance(shape_dict, (list, tuple)):
shape_idx += 1
node_map[nid] = [_expr.var(node_name, shape=shape, dtype=dtype)]
elif op_name in _convert_map:
res = _convert_map[op_name](children, attrs)
if op_name in ['_cond', '_foreach', '_while_loop']:
subgraphs = node['subgraphs']
res = _convert_map[op_name](children, attrs, subgraphs)
else:
res = _convert_map[op_name](children, attrs)
if res is None:
# defer conversion, used in RNN state initialization
res = [node]
Expand Down
26 changes: 26 additions & 0 deletions tests/python/frontend/mxnet/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -909,6 +909,31 @@ def verify(data_shape, kernel_size, stride, pad, num_filter):
verify(data_shape=(1, 8, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2)
verify(data_shape=(20, 8, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2)

def test_forward_cond():
def verify(a_np, b_np):
a_nd, b_nd = mx.nd.array(a_np), mx.nd.array(b_np)
pred = a_nd * b_nd < 5
then_func = lambda: (a_nd + 5) * (b_nd + 5)
else_func = lambda: (a_nd - 5) * (b_nd - 5)
ref_res = mx.nd.contrib.cond(pred, then_func, else_func)

a_sym, b_sym = mx.sym.var("a"), mx.sym.var("b")
pred = a_sym * b_sym < 5
then_func = lambda: (a_sym + 5) * (b_sym + 5)
else_func = lambda: (a_sym - 5) * (b_sym - 5)
mx_sym = mx.sym.contrib.cond(pred, then_func, else_func)

shape_dict = {"a": a_np.shape, "b": b_np.shape}
mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)
for target, ctx in ctx_list():
for kind in ["debug", "vm"]:
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(a_np, b_np)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3)

verify(np.asarray([1.0], 'float32'), np.asarray([2.0],'float32'))
verify(np.asarray([4.0], 'float32'), np.asarray([3.0],'float32'))


if __name__ == '__main__':
test_forward_mlp()
Expand Down Expand Up @@ -963,3 +988,4 @@ def verify(data_shape, kernel_size, stride, pad, num_filter):
test_forward_one_hot()
test_forward_convolution()
test_forward_deconvolution()
test_forward_cond()

0 comments on commit 7da13f2

Please sign in to comment.