[TOP][Example] register pool, global_pool; add mobilenet example (apache#32)

* register pool, global_pool; add mobilenet example

* tests of pool and global_pool

* use new API of runtime module

* small fix
Huyuwei authored and sergei-mironov committed Aug 8, 2018
1 parent e4cb94f commit 0432407
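The "use new API of runtime module" bullet refers to the graph-runtime interface exercised throughout this diff: wrap the compiled graph in a module and call set_input/run/get_output on it directly, instead of pulling packed functions out by name. A minimal, self-contained sketch of that pattern (the "llvm" target, the toy relu graph, and the shapes are illustrative assumptions, not part of this commit):

import numpy as np
import tvm
import nnvm.symbol as sym
import nnvm.compiler
import nnvm.runtime

x = sym.Variable("x")
y = sym.relu(data=x)                                      # any simple registered op works here
dshape = (1, 16)
graph, lib, _ = nnvm.compiler.build(y, "llvm", {"x": dshape})   # assumed CPU target
m = nnvm.runtime.create(graph, lib, tvm.cpu(0))           # runtime module wrapping the compiled graph
data = np.random.uniform(size=dshape).astype("float32")
m.run(x=data)                                             # bind the input by keyword and execute
out = m.get_output(0, tvm.nd.empty(dshape, "float32"))    # copy output 0 into a destination array
print(out.asnumpy())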
Showing 6 changed files with 308 additions and 84 deletions.
117 changes: 117 additions & 0 deletions nnvm/example/mobilenet_inference_gpu.py
@@ -0,0 +1,117 @@
"""Forward propagation of MobileNet on GPU."""
import numpy as np
import time
import os

import tvm
import topi
import nnvm.symbol as sym
import nnvm.compiler
import nnvm.runtime
from tvm.contrib import nvcc

TASK="mobilenet"

target = 'cuda'
ctx = tvm.gpu(0)

@tvm.register_func
def tvm_callback_cuda_compile(code):
    ptx = nvcc.compile_cuda(code, target="ptx", options=["-arch=sm_60"])
    return ptx

def write_code(code, fname):
    with open(fname, "w") as f:
        f.write(code)

@tvm.register_func
def tvm_callback_cuda_postproc(code):
    if not os.path.exists("perf"):
        os.mkdir("perf")
    write_code(code, "perf/%s_generated.cu" % TASK)
    return code

dtype = 'float32'
epsilon = 1e-10 + 1e-5

def conv_block(data, name, channels, kernel_size=(3,3), strides=(1,1), padding=(1,1)):
    # convolution + bn + relu
    conv = sym.conv2d(data=data, channels=channels, kernel_size=kernel_size, strides=strides,
                      padding=padding, use_bias=False, layout='NCHW', name=name + '_conv')
    bn = sym.batch_norm(data=conv, epsilon=epsilon, name=name + '_bn')
    act = sym.relu(data=bn, name=name + '_relu')
    return act

def separable_conv_block(data, name, depthwise_channels, pointwise_channels, kernel_size=(3,3), downsample=False, padding=(1,1)):
    if downsample:
        strides = (2,2)
    else:
        strides = (1,1)
    # depthwise convolution + bn + relu
    conv1 = sym.conv2d(data=data, channels=depthwise_channels, groups=depthwise_channels, kernel_size=kernel_size, strides=strides,
                       padding=padding, use_bias=False, layout='NCHW', name=name + '_conv1')
    bn1 = sym.batch_norm(data=conv1, epsilon=epsilon, name=name + '_bn1')
    act1 = sym.relu(data=bn1, name=name + '_relu1')
    # pointwise convolution + bn + relu
    conv2 = sym.conv2d(data=act1, channels=pointwise_channels, kernel_size=(1,1), strides=(1,1),
                       padding=(0,0), use_bias=False, layout='NCHW', name=name + '_conv2')
    bn2 = sym.batch_norm(data=conv2, epsilon=epsilon, name=name + '_bn2')
    act2 = sym.relu(data=bn2, name=name + '_relu2')
    return act2

def mobile_net(num_classes=1000, alpha=1.0, is_shallow=False):
    data = sym.Variable("data")
    body = conv_block(data, 'conv_block_1', int(32*alpha), strides=(2,2))
    body = separable_conv_block(body, 'separable_conv_block_1', int(32*alpha), int(64*alpha))
    body = separable_conv_block(body, 'separable_conv_block_2', int(64*alpha), int(128*alpha), downsample=True)
    body = separable_conv_block(body, 'separable_conv_block_3', int(128*alpha), int(128*alpha))
    body = separable_conv_block(body, 'separable_conv_block_4', int(128*alpha), int(256*alpha), downsample=True)
    body = separable_conv_block(body, 'separable_conv_block_5', int(256*alpha), int(256*alpha))
    body = separable_conv_block(body, 'separable_conv_block_6', int(256*alpha), int(512*alpha), downsample=True)
    if is_shallow:
        body = separable_conv_block(body, 'separable_conv_block_7', int(512*alpha), int(1024*alpha), downsample=True)
        body = separable_conv_block(body, 'separable_conv_block_8', int(1024*alpha), int(1024*alpha))
    else:
        for i in range(7, 12):
            body = separable_conv_block(body, 'separable_conv_block_%d' % i, int(512*alpha), int(512*alpha))
        body = separable_conv_block(body, 'separable_conv_block_12', int(512*alpha), int(1024*alpha), downsample=True)
        body = separable_conv_block(body, 'separable_conv_block_13', int(1024*alpha), int(1024*alpha))
    pool = sym.global_avg_pool2d(data=body, name='pool')
    flatten = sym.flatten(data=pool, name='flatten')
    fc = sym.dense(data=flatten, units=num_classes, use_bias=False, name='fc')
    softmax = sym.softmax(data=fc, name='softmax')
    return softmax


batch_size = 1
num_classes = 1000
image_shape = (3,224,224)
data_shape = (batch_size,) + image_shape
out_shape = (batch_size, num_classes)

net = mobile_net(num_classes=num_classes, alpha=1.0, is_shallow=False)

# build graph
with nnvm.compiler.build_config(opt_level=2):
    graph, lib, _ = nnvm.compiler.build(net, target, {'data': data_shape})
# prepare params
params = {}
names = graph.index.input_names
shapes = [graph.json_attr("shape")[graph.index.entry_id(x)] for x in names]
for i in range(len(names)):
    params[names[i]] = tvm.nd.array(np.random.uniform(-0.1, 0.1, size=shapes[i]).astype(dtype), ctx=ctx)
# create runtime module
module = nnvm.runtime.create(graph, lib, ctx)
# set input
module.set_input(**params)
# run
print("run")
module.run()
ctx.sync()
start = time.time()
for i in range(1000):
    module.run()
ctx.sync()
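# note: total elapsed seconds over 1000 runs equals the average per-run time in milliseconds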
print("average time cost of 1000 runs = %g ms" % ((time.time() - start)))
# get output
out = module.get_output(0, tvm.nd.empty(out_shape, dtype))
7 changes: 0 additions & 7 deletions nnvm/include/nnvm/top/nn.h
@@ -202,7 +202,6 @@ struct Pool2DParam : public dmlc::Parameter<Pool2DParam> {
  TShape pool_size;
  TShape strides;
  TShape padding;
  int groups;
  int layout;
  bool ceil_mode;

@@ -214,12 +213,6 @@ struct Pool2DParam : public dmlc::Parameter<Pool2DParam> {
    DMLC_DECLARE_FIELD(padding).set_default(TShape({0, 0}))
      .describe("If padding is non-zero, then the input is implicitly zero-padded"
                "on both sides for padding number of points");
    DMLC_DECLARE_FIELD(groups).set_default(1)
      .describe("Controls the connections between inputs and outputs."
                "At groups=1, all inputs are convolved to all outputs."
                "At groups=2, the operation becomes equivalent to having two convolution"
                "layers side by side, each seeing half the input channels, and producing"
                "half the output channels, and both subsequently concatenated.");
    DMLC_DECLARE_FIELD(layout)
      .add_enum("NCHW", kNCHW)
      .add_enum("NHWC", kNHWC)
91 changes: 89 additions & 2 deletions nnvm/python/nnvm/top/nn.py
@@ -18,6 +18,7 @@ def compute_relu(attrs, inputs, _):
reg.register_schedule("relu", _fschedule_broadcast)
reg.register_pattern("relu", OpPattern.ELEMWISE)


# leaky_relu
@reg.register_compute("leaky_relu")
def compute_leaky_relu(attrs, inputs, _):
@@ -27,6 +28,7 @@ def compute_leaky_relu(attrs, inputs, _):
reg.register_schedule("leaky_relu", _fschedule_broadcast)
reg.register_pattern("leaky_relu", OpPattern.ELEMWISE)


# flatten
@reg.register_compute("flatten")
def compute_flatten(attrs, inputs, _):
@@ -73,11 +75,10 @@ def schedule_dense(_, outs, target):
    # naive schedule
    return tvm.create_schedule([x.op for x in outs])

# register extern for now, change me when fusion is enabled.
reg.register_pattern("dense", OpPattern.OUT_ELEMWISE_FUSABLE)


# conv
# conv2d
@reg.register_compute("conv2d")
def compute_conv2d(attrs, inputs, _):
    """Compute definition of conv2d"""
@@ -113,3 +114,89 @@ def schedule_conv2d(attrs, outs, target):
    return tvm.create_schedule([x.op for x in outs])

reg.register_pattern("conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)


# max_pool2d
@reg.register_compute("max_pool2d")
def compute_max_pool2d(attrs, inputs, _):
    """Compute definition of max_pool2d"""
    pool_size = attrs.get_int_tuple("pool_size")
    strides = attrs.get_int_tuple("strides")
    padding = attrs.get_int_tuple("padding")
    layout = attrs["layout"]
    ceil_mode = attrs["ceil_mode"]
    assert layout == "NCHW", "only support nchw for now"
    assert ceil_mode == "False", "not support ceil_mode now"
    return topi.nn.pool(inputs[0], pool_size, strides, padding, pool_type='max')

@reg.register_schedule("max_pool2d")
def schedule_max_pool2d(_, outs, target):
    """Schedule definition of max_pool2d"""
    if target == "cuda":
        return topi.cuda.schedule_pool(outs)
    # naive schedule
    return tvm.create_schedule([x.op for x in outs])

reg.register_pattern("max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)


# avg_pool2d
@reg.register_compute("avg_pool2d")
def compute_avg_pool2d(attrs, inputs, _):
    """Compute definition of avg_pool2d"""
    pool_size = attrs.get_int_tuple("pool_size")
    strides = attrs.get_int_tuple("strides")
    padding = attrs.get_int_tuple("padding")
    layout = attrs["layout"]
    ceil_mode = attrs["ceil_mode"]
    assert layout == "NCHW", "only support nchw for now"
    assert ceil_mode == "False", "not support ceil_mode now"
    return topi.nn.pool(inputs[0], pool_size, strides, padding, pool_type='avg')

@reg.register_schedule("avg_pool2d")
def schedule_avg_pool2d(_, outs, target):
    """Schedule definition of avg_pool2d"""
    if target == "cuda":
        return topi.cuda.schedule_pool(outs)
    # naive schedule
    return tvm.create_schedule([x.op for x in outs])

reg.register_pattern("avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)


# global_max_pool2d
@reg.register_compute("global_max_pool2d")
def compute_global_max_pool2d(attrs, inputs, _):
    """Compute definition of global_max_pool2d"""
    layout = attrs["layout"]
    assert layout == "NCHW", "only support nchw for now"
    return topi.nn.global_pool(inputs[0], pool_type='max')

@reg.register_schedule("global_max_pool2d")
def schedule_global_max_pool2d(_, outs, target):
    """Schedule definition of global_max_pool2d"""
    if target == "cuda":
        return topi.cuda.schedule_global_pool(outs)
    # naive schedule
    return tvm.create_schedule([x.op for x in outs])

reg.register_pattern("global_max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)


# global_avg_pool2d
@reg.register_compute("global_avg_pool2d")
def compute_global_avg_pool2d(attrs, inputs, _):
    """Compute definition of global_avg_pool2d"""
    layout = attrs["layout"]
    assert layout == "NCHW", "only support nchw for now"
    return topi.nn.global_pool(inputs[0], pool_type='avg')

@reg.register_schedule("global_avg_pool2d")
def schedule_global_avg_pool2d(_, outs, target):
    """Schedule definition of global_avg_pool2d"""
    if target == "cuda":
        return topi.cuda.schedule_global_pool(outs)
    # naive schedule
    return tvm.create_schedule([x.op for x in outs])

reg.register_pattern("global_avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
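To tie the registrations above back to the frontend, here is a small end-to-end sketch of the new pool operators at the symbol level; the keyword names follow the Pool2DParam fields and the mobilenet example, and the "llvm" target and shapes are illustrative assumptions rather than part of this diff:

import numpy as np
import tvm
import nnvm.symbol as sym
import nnvm.compiler
import nnvm.runtime

x = sym.Variable("x")
y = sym.max_pool2d(data=x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0), name="pool")  # NCHW only, per the asserts above
y = sym.global_avg_pool2d(data=y, name="gpool")          # collapses H and W to 1x1, as in the mobilenet example
dshape = (1, 16, 32, 32)                                 # NCHW input
graph, lib, _ = nnvm.compiler.build(y, "llvm", {"x": dshape})
m = nnvm.runtime.create(graph, lib, tvm.cpu(0))
m.run(x=np.random.uniform(size=dshape).astype("float32"))
out = m.get_output(0, tvm.nd.empty((1, 16, 1, 1), "float32"))  # global pool leaves one value per channel
print(out.asnumpy().shape)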
76 changes: 20 additions & 56 deletions nnvm/tests/python/compiler/test_top_level1.py
@@ -16,7 +16,6 @@ def test_relu():
    for target, ctx in ctx_list():
        graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
        m = nnvm.runtime.create(graph, lib, ctx)
        # get member functions
        data = np.random.uniform(size=dshape).astype(dtype)
        m.run(x=data)
        data = (data < 0) * data * 0.3 + (data>0) * data - 0.2
@@ -34,17 +33,10 @@ def test_exp():
    for target, ctx in ctx_list():
        graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
        m = nnvm.runtime.create(graph, lib, ctx)
        # get member functions
        set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
        # set input
        data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
        set_input("x", data)
        # execute
        run()
        # get output
        out = tvm.nd.empty(oshape, dtype)
        get_output(0, out)
        y_np = np.exp(data.asnumpy())
        data = np.random.uniform(size=dshape).astype(dtype)
        m.run(x=data)
        out = m.get_output(0, tvm.nd.empty(oshape, dtype))
        y_np = np.exp(data)
        np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)


@@ -58,17 +50,10 @@ def test_log():
        with nnvm.compiler.build_config(opt_level=1):
            graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
        m = nnvm.runtime.create(graph, lib, ctx)
        # get member functions
        set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
        # set input
        data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
        set_input("x", data)
        # execute
        run()
        # get output
        out = tvm.nd.empty(oshape, dtype)
        get_output(0, out)
        y_np = np.log(data.asnumpy())
        data = np.random.uniform(size=dshape).astype(dtype)
        m.run(x=data)
        out = m.get_output(0, tvm.nd.empty(oshape, dtype))
        y_np = np.log(data)
        np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)


@@ -82,17 +67,10 @@ def test_tanh():
        with nnvm.compiler.build_config(opt_level=1):
            graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
        m = nnvm.runtime.create(graph, lib, ctx)
        # get member functions
        set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
        # set input
        data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
        set_input("x", data)
        # execute
        run()
        # get output
        out = tvm.nd.empty(oshape, dtype)
        get_output(0, out)
        y_np = np.sinh(data.asnumpy()) / np.cosh(data.asnumpy())
        data = np.random.uniform(size=dshape).astype(dtype)
        m.run(x=data)
        out = m.get_output(0, tvm.nd.empty(oshape, dtype))
        y_np = np.sinh(data) / np.cosh(data)
        np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)


@@ -105,17 +83,10 @@ def test_sigmoid():
    for target, ctx in ctx_list():
        graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
        m = nnvm.runtime.create(graph, lib, ctx)
        # get member functions
        set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
        # set input
        data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
        set_input("x", data)
        # execute
        run()
        # get output
        out = tvm.nd.empty(oshape, dtype)
        get_output(0, out)
        y_np = 1.0 / (1.0 + np.exp(-data.asnumpy()))
        data = np.random.uniform(size=dshape).astype(dtype)
        m.run(x=data)
        out = m.get_output(0, tvm.nd.empty(oshape, dtype))
        y_np = 1.0 / (1.0 + np.exp(-data))
        np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)


@@ -129,17 +100,10 @@ def test_softmax():
        with nnvm.compiler.build_config(opt_level=1):
            graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
        m = nnvm.runtime.create(graph, lib, ctx)
        # get member functions
        set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
        # set input
        data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
        set_input("x", data)
        # execute
        run()
        # get output
        out = tvm.nd.empty(oshape, dtype)
        get_output(0, out)
        y_np = topi.testing.softmax_python(data.asnumpy())
        data = np.random.uniform(size=dshape).astype(dtype)
        m.run(x=data)
        out = m.get_output(0, tvm.nd.empty(oshape, dtype))
        y_np = topi.testing.softmax_python(data)
        np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)


