diff --git a/nnvm/example/mobilenet_inference_gpu.py b/nnvm/example/mobilenet_inference_gpu.py new file mode 100644 index 0000000000000..4331ec4c0f022 --- /dev/null +++ b/nnvm/example/mobilenet_inference_gpu.py @@ -0,0 +1,117 @@ +"""Forward propagation of MobileNet on GPU.""" +import numpy as np +import time +import os + +import tvm +import topi +import nnvm.symbol as sym +import nnvm.compiler +import nnvm.runtime +from tvm.contrib import nvcc + +TASK="mobilenet" + +target = 'cuda' +ctx = tvm.gpu(0) + +@tvm.register_func +def tvm_callback_cuda_compile(code): + ptx = nvcc.compile_cuda(code, target="ptx", options=["-arch=sm_60"]) + return ptx + +def write_code(code, fname): + with open(fname, "w") as f: + f.write(code) + +@tvm.register_func +def tvm_callback_cuda_postproc(code): + if not os.path.exists("perf"): + os.mkdir("perf") + write_code(code, "perf/%s_generated.cu" % TASK) + return code + +dtype = 'float32' +epsilon = 1e-10 + 1e-5 + +def conv_block(data, name, channels, kernel_size=(3,3), strides=(1,1), padding=(1,1)): + # convolution + bn + relu + conv = sym.conv2d(data=data, channels=channels, kernel_size=kernel_size, strides=strides, + padding=padding, use_bias=False, layout='NCHW', name=name + '_conv') + bn = sym.batch_norm(data=conv, epsilon=epsilon, name=name + '_bn') + act = sym.relu(data=bn, name=name + '_relu') + return act + +def separable_conv_block(data, name, depthwise_channels, pointwise_channels, kernel_size=(3,3), downsample=False, padding=(1,1)): + if downsample: + strides = (2,2) + else: + strides = (1,1) + # depthwise convolution + bn + relu + conv1 = sym.conv2d(data=data, channels=depthwise_channels, groups=depthwise_channels, kernel_size=kernel_size, strides=strides, + padding=padding, use_bias=False, layout='NCHW', name=name + '_conv1') + bn1 = sym.batch_norm(data=conv1, epsilon=epsilon, name=name + '_bn1') + act1 = sym.relu(data=bn1, name=name + '_relu1') + # pointwise convolution + bn + relu + conv2 = sym.conv2d(data=act1, channels=pointwise_channels, kernel_size=(1,1), strides=(1,1), + padding=(0,0), use_bias=False, layout='NCHW', name=name + '_conv2') + bn2 = sym.batch_norm(data=conv2, epsilon=epsilon, name=name + '_bn2') + act2 = sym.relu(data=bn2, name=name + '_relu2') + return act2 + +def mobile_net(num_classes=1000, alpha=1.0, is_shallow=False): + data = sym.Variable("data") + body = conv_block(data, 'conv_block_1', int(32*alpha), strides=(2,2)) + body = separable_conv_block(body, 'separable_conv_block_1', int(32*alpha), int(64*alpha)) + body = separable_conv_block(body, 'separable_conv_block_2', int(64*alpha), int(128*alpha), downsample=True) + body = separable_conv_block(body, 'separable_conv_block_3', int(128*alpha), int(128*alpha)) + body = separable_conv_block(body, 'separable_conv_block_4', int(128*alpha), int(256*alpha), downsample=True) + body = separable_conv_block(body, 'separable_conv_block_5', int(256*alpha), int(256*alpha)) + body = separable_conv_block(body, 'separable_conv_block_6', int(256*alpha), int(512*alpha), downsample=True) + if is_shallow: + body = separable_conv_block(body, 'separable_conv_block_7', int(512*alpha), int(1024*alpha), downsample=True) + body = separable_conv_block(body, 'separable_conv_block_8', int(1024*alpha), int(1024*alpha)) + else: + for i in range(7, 12): + body = separable_conv_block(body, 'separable_conv_block_%d' % i, int(512*alpha), int(512*alpha)) + body = separable_conv_block(body, 'separable_conv_block_12', int(512*alpha), int(1024*alpha), downsample=True) + body = separable_conv_block(body, 'separable_conv_block_13', int(1024*alpha), int(1024*alpha)) + pool = sym.global_avg_pool2d(data=body, name='pool') + flatten = sym.flatten(data=pool, name='flatten') + fc = sym.dense(data=flatten, units=num_classes, use_bias=False, name='fc') + softmax = sym.softmax(data=fc, name='softmax') + return softmax + + +batch_size = 1 +num_classes = 1000 +image_shape = (3,224,224) +data_shape = (batch_size,) + image_shape +out_shape = (batch_size, num_classes) + +net = mobile_net(num_classes=num_classes, alpha=1.0, is_shallow=False) + +# build graph +with nnvm.compiler.build_config(opt_level=2): + graph, lib, _ = nnvm.compiler.build(net, target, {'data': data_shape}) +# prepare params +params = {} +names = graph.index.input_names +shapes = [graph.json_attr("shape")[graph.index.entry_id(x)] for x in names] +for i in range(len(names)): + params[names[i]] = tvm.nd.array(np.random.uniform(-0.1, 0.1, size=shapes[i]).astype(dtype), ctx=ctx) +# create runtime module +module = nnvm.runtime.create(graph, lib, ctx) +# set input +module.set_input(**params) +# run +print("run") +module.run() +ctx.sync() +start = time.time() +for i in range(1000): + module.run() + ctx.sync() +print("average time cost of 1000 runs = %g ms" % ((time.time() - start))) +# get output +out = module.get_output(0, tvm.nd.empty(out_shape, dtype)) diff --git a/nnvm/include/nnvm/top/nn.h b/nnvm/include/nnvm/top/nn.h index a48a1287fa71b..672709a683944 100644 --- a/nnvm/include/nnvm/top/nn.h +++ b/nnvm/include/nnvm/top/nn.h @@ -202,7 +202,6 @@ struct Pool2DParam : public dmlc::Parameter { TShape pool_size; TShape strides; TShape padding; - int groups; int layout; bool ceil_mode; @@ -214,12 +213,6 @@ struct Pool2DParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(padding).set_default(TShape({0, 0})) .describe("If padding is non-zero, then the input is implicitly zero-padded" "on both sides for padding number of points"); - DMLC_DECLARE_FIELD(groups).set_default(1) - .describe("Controls the connections between inputs and outputs." - "At groups=1, all inputs are convolved to all outputs." - "At groups=2, the operation becomes equivalent to having two convolution" - "layers side by side, each seeing half the input channels, and producing" - "half the output channels, and both subsequently concatenated."); DMLC_DECLARE_FIELD(layout) .add_enum("NCHW", kNCHW) .add_enum("NHWC", kNHWC) diff --git a/nnvm/python/nnvm/top/nn.py b/nnvm/python/nnvm/top/nn.py index c10830b795567..c3b32d01a2537 100644 --- a/nnvm/python/nnvm/top/nn.py +++ b/nnvm/python/nnvm/top/nn.py @@ -18,6 +18,7 @@ def compute_relu(attrs, inputs, _): reg.register_schedule("relu", _fschedule_broadcast) reg.register_pattern("relu", OpPattern.ELEMWISE) + # leaky_relu @reg.register_compute("leaky_relu") def compute_leaky_relu(attrs, inputs, _): @@ -27,6 +28,7 @@ def compute_leaky_relu(attrs, inputs, _): reg.register_schedule("leaky_relu", _fschedule_broadcast) reg.register_pattern("leaky_relu", OpPattern.ELEMWISE) + # flatten @reg.register_compute("flatten") def compute_flatten(attrs, inputs, _): @@ -73,11 +75,10 @@ def schedule_dense(_, outs, target): # naive schedule return tvm.create_schedule([x.op for x in outs]) -# register extern for now, change me when fusion is enabled. reg.register_pattern("dense", OpPattern.OUT_ELEMWISE_FUSABLE) -# conv +# conv2d @reg.register_compute("conv2d") def compute_conv2d(attrs, inputs, _): """Compute definition of conv2d""" @@ -113,3 +114,89 @@ def schedule_conv2d(attrs, outs, target): return tvm.create_schedule([x.op for x in outs]) reg.register_pattern("conv2d", OpPattern.OUT_ELEMWISE_FUSABLE) + + +# max_pool2d +@reg.register_compute("max_pool2d") +def compute_max_pool2d(attrs, inputs, _): + """Compute definition of max_pool2d""" + pool_size = attrs.get_int_tuple("pool_size") + strides = attrs.get_int_tuple("strides") + padding = attrs.get_int_tuple("padding") + layout = attrs["layout"] + ceil_mode = attrs["ceil_mode"] + assert layout == "NCHW", "only support nchw for now" + assert ceil_mode == "False", "not support ceil_mode now" + return topi.nn.pool(inputs[0], pool_size, strides, padding, pool_type='max') + +@reg.register_schedule("max_pool2d") +def schedule_max_pool2d(_, outs, target): + """Schedule definition of max_pool2d""" + if target == "cuda": + return topi.cuda.schedule_pool(outs) + # naive schedule + return tvm.create_schedule([x.op for x in outs]) + +reg.register_pattern("max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE) + + +# avg_pool2d +@reg.register_compute("avg_pool2d") +def compute_avg_pool2d(attrs, inputs, _): + """Compute definition of avg_pool2d""" + pool_size = attrs.get_int_tuple("pool_size") + strides = attrs.get_int_tuple("strides") + padding = attrs.get_int_tuple("padding") + layout = attrs["layout"] + ceil_mode = attrs["ceil_mode"] + assert layout == "NCHW", "only support nchw for now" + assert ceil_mode == "False", "not support ceil_mode now" + return topi.nn.pool(inputs[0], pool_size, strides, padding, pool_type='avg') + +@reg.register_schedule("avg_pool2d") +def schedule_avg_pool2d(_, outs, target): + """Schedule definition of avg_pool2d""" + if target == "cuda": + return topi.cuda.schedule_pool(outs) + # naive schedule + return tvm.create_schedule([x.op for x in outs]) + +reg.register_pattern("avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE) + + +# global_max_pool2d +@reg.register_compute("global_max_pool2d") +def compute_global_max_pool2d(attrs, inputs, _): + """Compute definition of global_max_pool2d""" + layout = attrs["layout"] + assert layout == "NCHW", "only support nchw for now" + return topi.nn.global_pool(inputs[0], pool_type='max') + +@reg.register_schedule("global_max_pool2d") +def schedule_global_max_pool2d(_, outs, target): + """Schedule definition of global_max_pool2d""" + if target == "cuda": + return topi.cuda.schedule_global_pool(outs) + # naive schedule + return tvm.create_schedule([x.op for x in outs]) + +reg.register_pattern("global_max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE) + + +# global_avg_pool2d +@reg.register_compute("global_avg_pool2d") +def compute_global_avg_pool2d(attrs, inputs, _): + """Compute definition of global_avg_pool2d""" + layout = attrs["layout"] + assert layout == "NCHW", "only support nchw for now" + return topi.nn.global_pool(inputs[0], pool_type='avg') + +@reg.register_schedule("global_avg_pool2d") +def schedule_global_avg_pool2d(_, outs, target): + """Schedule definition of global_avg_pool2d""" + if target == "cuda": + return topi.cuda.schedule_global_pool(outs) + # naive schedule + return tvm.create_schedule([x.op for x in outs]) + +reg.register_pattern("global_avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE) diff --git a/nnvm/tests/python/compiler/test_top_level1.py b/nnvm/tests/python/compiler/test_top_level1.py index baeaaf86040a0..7443449e4096a 100644 --- a/nnvm/tests/python/compiler/test_top_level1.py +++ b/nnvm/tests/python/compiler/test_top_level1.py @@ -16,7 +16,6 @@ def test_relu(): for target, ctx in ctx_list(): graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape}) m = nnvm.runtime.create(graph, lib, ctx) - # get member functions data = np.random.uniform(size=dshape).astype(dtype) m.run(x=data) data = (data < 0) * data * 0.3 + (data>0) * data - 0.2 @@ -34,17 +33,10 @@ def test_exp(): for target, ctx in ctx_list(): graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape}) m = nnvm.runtime.create(graph, lib, ctx) - # get member functions - set_input, run, get_output = m["set_input"], m["run"], m["get_output"] - # set input - data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype)) - set_input("x", data) - # execute - run() - # get output - out = tvm.nd.empty(oshape, dtype) - get_output(0, out) - y_np = np.exp(data.asnumpy()) + data = np.random.uniform(size=dshape).astype(dtype) + m.run(x=data) + out = m.get_output(0, tvm.nd.empty(oshape, dtype)) + y_np = np.exp(data) np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5) @@ -58,17 +50,10 @@ def test_log(): with nnvm.compiler.build_config(opt_level=1): graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape}) m = nnvm.runtime.create(graph, lib, ctx) - # get member functions - set_input, run, get_output = m["set_input"], m["run"], m["get_output"] - # set input - data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype)) - set_input("x", data) - # execute - run() - # get output - out = tvm.nd.empty(oshape, dtype) - get_output(0, out) - y_np = np.log(data.asnumpy()) + data = np.random.uniform(size=dshape).astype(dtype) + m.run(x=data) + out = m.get_output(0, tvm.nd.empty(oshape, dtype)) + y_np = np.log(data) np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5) @@ -82,17 +67,10 @@ def test_tanh(): with nnvm.compiler.build_config(opt_level=1): graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape}) m = nnvm.runtime.create(graph, lib, ctx) - # get member functions - set_input, run, get_output = m["set_input"], m["run"], m["get_output"] - # set input - data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype)) - set_input("x", data) - # execute - run() - # get output - out = tvm.nd.empty(oshape, dtype) - get_output(0, out) - y_np = np.sinh(data.asnumpy()) / np.cosh(data.asnumpy()) + data = np.random.uniform(size=dshape).astype(dtype) + m.run(x=data) + out = m.get_output(0, tvm.nd.empty(oshape, dtype)) + y_np = np.sinh(data) / np.cosh(data) np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5) @@ -105,17 +83,10 @@ def test_sigmoid(): for target, ctx in ctx_list(): graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape}) m = nnvm.runtime.create(graph, lib, ctx) - # get member functions - set_input, run, get_output = m["set_input"], m["run"], m["get_output"] - # set input - data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype)) - set_input("x", data) - # execute - run() - # get output - out = tvm.nd.empty(oshape, dtype) - get_output(0, out) - y_np = 1.0 / (1.0 + np.exp(-data.asnumpy())) + data = np.random.uniform(size=dshape).astype(dtype) + m.run(x=data) + out = m.get_output(0, tvm.nd.empty(oshape, dtype)) + y_np = 1.0 / (1.0 + np.exp(-data)) np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5) @@ -129,17 +100,10 @@ def test_softmax(): with nnvm.compiler.build_config(opt_level=1): graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape}) m = nnvm.runtime.create(graph, lib, ctx) - # get member functions - set_input, run, get_output = m["set_input"], m["run"], m["get_output"] - # set input - data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype)) - set_input("x", data) - # execute - run() - # get output - out = tvm.nd.empty(oshape, dtype) - get_output(0, out) - y_np = topi.testing.softmax_python(data.asnumpy()) + data = np.random.uniform(size=dshape).astype(dtype) + m.run(x=data) + out = m.get_output(0, tvm.nd.empty(oshape, dtype)) + y_np = topi.testing.softmax_python(data) np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5) diff --git a/nnvm/tests/python/compiler/test_top_level2.py b/nnvm/tests/python/compiler/test_top_level2.py index 79f29f40ce5c1..c6f38e06fb789 100644 --- a/nnvm/tests/python/compiler/test_top_level2.py +++ b/nnvm/tests/python/compiler/test_top_level2.py @@ -10,8 +10,8 @@ def test_conv2d(): x = sym.Variable("x") - y = sym.conv2d(x, channels=10, kernel_size=(3, 3), - name="y", use_bias=False, padding=(1,1)) + y = sym.conv2d(x, channels=10, kernel_size=(3,3), + name="y", padding=(1,1)) dtype = "float32" dshape = (1, 3, 18, 18) kshape = (10, 3, 3, 3) @@ -20,26 +20,20 @@ def test_conv2d(): for target, ctx in ctx_list(): graph, lib, _ = nnvm.compiler.build(y, target, shape_dict) m = nnvm.runtime.create(graph, lib, ctx) - # get member functions - set_input, run, get_output = m["set_input"], m["run"], m["get_output"] - # set input data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype)) kernel = tvm.nd.array(np.random.uniform(size=kshape).astype(dtype)) - set_input("x", data) - set_input("y_weight", kernel) - # execute - run() - # get output - out = tvm.nd.empty(oshape, dtype) - get_output(0, out) + bias = tvm.nd.array(np.random.uniform(size=kshape[0]).astype(dtype)) + m.run(x=data, y_weight=kernel, y_bias=bias) + out = m.get_output(0, tvm.nd.empty(oshape, dtype)) c_np = topi.testing.conv2d_nchw_python( data.asnumpy(), kernel.asnumpy(), 1, 1) + c_np = c_np + bias.asnumpy().reshape(kshape[0], 1, 1) np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5) def test_grouped_conv2d(): x = sym.Variable("x") - y = sym.conv2d(x, channels=32, kernel_size=(3, 3), groups=32, + y = sym.conv2d(x, channels=32, kernel_size=(3,3), groups=32, name="y", padding=(1,1)) dtype = "float32" dshape = (1, 32, 18, 18) @@ -49,12 +43,10 @@ def test_grouped_conv2d(): for target, ctx in ctx_list(): graph, lib, _ = nnvm.compiler.build(y, target, shape_dict) m = nnvm.runtime.create(graph, lib, ctx) - # set input data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype)) kernel = tvm.nd.array(np.random.uniform(size=kshape).astype(dtype)) bias = tvm.nd.array(np.random.uniform(size=kshape[0]).astype(dtype)) m.run(x=data, y_weight=kernel, y_bias=bias) - # get output out = m.get_output(0, tvm.nd.empty(oshape, dtype)) c_np = topi.testing.depthwise_conv2d_python_nchw( data.asnumpy(), kernel.asnumpy(), (1,1), 'SAME') @@ -62,6 +54,78 @@ def test_grouped_conv2d(): np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5) +def test_max_pool2d(): + x = sym.Variable("x") + y = sym.max_pool2d(x, pool_size=(2,2), strides=(2,2), padding=(0,0), name="y") + dtype = "float32" + dshape = (1, 3, 28, 28) + oshape = (1, 3, 14, 14) + shape_dict = {"x": dshape} + for target, ctx in ctx_list(): + graph, lib, _ = nnvm.compiler.build(y, target, shape_dict) + m = nnvm.runtime.create(graph, lib, ctx) + data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype)) + m.run(x=data) + out = m.get_output(0, tvm.nd.empty(oshape, dtype)) + b_np = np.max(data.asnumpy().reshape(1,3,14,2,14,2), axis=(3,5)) + np.testing.assert_allclose(out.asnumpy(), b_np, rtol=1e-5) + + +def test_avg_pool2d(): + x = sym.Variable("x") + y = sym.avg_pool2d(x, pool_size=(2,2), strides=(2,2), padding=(0,0), name="y") + dtype = "float32" + dshape = (1, 3, 28, 28) + oshape = (1, 3, 14, 14) + shape_dict = {"x": dshape} + for target, ctx in ctx_list(): + graph, lib, _ = nnvm.compiler.build(y, target, shape_dict) + m = nnvm.runtime.create(graph, lib, ctx) + data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype)) + m.run(x=data) + out = m.get_output(0, tvm.nd.empty(oshape, dtype)) + b_np = np.mean(data.asnumpy().reshape(1,3,14,2,14,2), axis=(3,5)) + np.testing.assert_allclose(out.asnumpy(), b_np, rtol=1e-5) + + +def test_global_max_pool2d(): + x = sym.Variable("x") + y = sym.global_max_pool2d(x, name="y") + dtype = "float32" + dshape = (1, 1024, 7, 7) + oshape = (1, 1024, 1, 1) + shape_dict = {"x": dshape} + for target, ctx in ctx_list(): + graph, lib, _ = nnvm.compiler.build(y, target, shape_dict) + m = nnvm.runtime.create(graph, lib, ctx) + data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype)) + m.run(x=data) + out = m.get_output(0, tvm.nd.empty(oshape, dtype)) + b_np = np.max(data.asnumpy(), axis=(2,3), keepdims=True) + np.testing.assert_allclose(out.asnumpy(), b_np, rtol=1e-5) + + +def test_global_avg_pool2d(): + x = sym.Variable("x") + y = sym.global_avg_pool2d(x, name="y") + dtype = "float32" + dshape = (1, 1024, 7, 7) + oshape = (1, 1024, 1, 1) + shape_dict = {"x": dshape} + for target, ctx in ctx_list(): + graph, lib, _ = nnvm.compiler.build(y, target, shape_dict) + m = nnvm.runtime.create(graph, lib, ctx) + data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype)) + m.run(x=data) + out = m.get_output(0, tvm.nd.empty(oshape, dtype)) + b_np = np.mean(data.asnumpy(), axis=(2,3), keepdims=True) + np.testing.assert_allclose(out.asnumpy(), b_np, rtol=1e-5) + + if __name__ == "__main__": test_conv2d() test_grouped_conv2d() + test_max_pool2d() + test_avg_pool2d() + test_global_max_pool2d() + test_global_avg_pool2d() diff --git a/nnvm/tests/python/frontend/mxnet/test_forward.py b/nnvm/tests/python/frontend/mxnet/test_forward.py index 8fcc106e4db21..807ab6eab375b 100644 --- a/nnvm/tests/python/frontend/mxnet/test_forward.py +++ b/nnvm/tests/python/frontend/mxnet/test_forward.py @@ -23,7 +23,7 @@ def default_ctx(): else: return tvm.cpu(0) -def test_mxnet_frontend_impl(mx_symbol, data_shape=(2, 3, 224, 224), out_shape=(2, 1000)): +def test_mxnet_frontend_impl(mx_symbol, data_shape=(1, 3, 224, 224), out_shape=(1, 1000)): def get_mxnet_output(symbol, x, dtype='float32'): from collections import namedtuple Batch = namedtuple('Batch', ['data']) @@ -83,6 +83,5 @@ def test_forward_resnet(): if __name__ == '__main__': test_forward_mlp() - # waiting for max_pool2d - # test_forward_vgg() - # test_forward_resnet() + test_forward_vgg() + test_forward_resnet()