Skip to content

Commit

Permalink
[TOPI] add 3D upsampling Op. (#4584)
Browse files Browse the repository at this point in the history
* [TOPI] add 3D upsampling Op.

* fix lint issues

* change align_corners to coordinate_transformation_mode

* fix resize3d half_pixel

* make a simple function and clean up trilinear_resize3d_python

* fix doc
  • Loading branch information
optima2005 authored and masahi committed Dec 27, 2019
1 parent 1071e24 commit c3deec1
Show file tree
Hide file tree
Showing 14 changed files with 763 additions and 4 deletions.
2 changes: 2 additions & 0 deletions docs/langref/relay_op.rst
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ This level enables typical convnet models.
tvm.relay.nn.global_max_pool2d
tvm.relay.nn.global_avg_pool2d
tvm.relay.nn.upsampling
tvm.relay.nn.upsampling3d
tvm.relay.nn.batch_flatten
tvm.relay.nn.pad
tvm.relay.nn.lrn
Expand Down Expand Up @@ -254,6 +255,7 @@ Level 2 Definitions
.. autofunction:: tvm.relay.nn.global_max_pool2d
.. autofunction:: tvm.relay.nn.global_avg_pool2d
.. autofunction:: tvm.relay.nn.upsampling
.. autofunction:: tvm.relay.nn.upsampling3d
.. autofunction:: tvm.relay.nn.batch_flatten
.. autofunction:: tvm.relay.nn.pad
.. autofunction:: tvm.relay.nn.lrn
Expand Down
33 changes: 33 additions & 0 deletions include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,39 @@ struct UpSamplingAttrs : public tvm::AttrsNode<UpSamplingAttrs> {
}
};

/*! \brief Attributes for upsampling3d operator */
struct UpSampling3DAttrs : public tvm::AttrsNode<UpSampling3DAttrs> {
double scale_d;
double scale_h;
double scale_w;
std::string layout;
std::string method;
std::string coordinate_transformation_mode;

TVM_DECLARE_ATTRS(UpSampling3DAttrs, "relay.attrs.UpSampling3DAttrs") {
TVM_ATTR_FIELD(scale_d)
.describe("The upsampling factor for depth");
TVM_ATTR_FIELD(scale_h)
.describe("The upsampling factor for height");
TVM_ATTR_FIELD(scale_w)
.describe("The upsampling factor for width");
TVM_ATTR_FIELD(layout).set_default("NCDHW")
.describe("Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc."
"'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
"dimensions respectively. Upsampling is applied on the 'D', 'H' and"
"'W' dimensions.");
TVM_ATTR_FIELD(method).set_default("nearest_neighbor")
.describe("Specify the mode to use for scaling."
"nearest_neighbor - Nearest Neighbor"
"trilinear - Trilinear Interpolation");
TVM_ATTR_FIELD(coordinate_transformation_mode).set_default("half_pixel")
.describe("Describes how to transform the coordinate in the resized tensor"
"to the coordinate in the original tensor."
"Refer to the ONNX Resize operator specification for details"
"Available options are half_pixel, align_corners and asymmetric");
}
};

/*! \brief Attributes used for the padding operator */
struct PadAttrs : public tvm::AttrsNode<PadAttrs> {
double pad_value;
Expand Down
19 changes: 19 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,25 @@ def compute_upsampling(attrs, inputs, out_dtype, target):
align_corners = attrs.align_corners
return [topi.nn.upsampling(inputs[0], scale_h, scale_w, layout, method, align_corners)]

# upsampling3d
reg.register_schedule("nn.upsampling3d", reg.schedule_injective)

def schedule_upsampling3d(_, outs, target):
"""Schedule definition of upsampling3d"""
with target:
return topi.generic.schedule_injective(outs)

@reg.register_compute("nn.upsampling3d")
def compute_upsampling3d(attrs, inputs, out_dtype, target):
scale_d = attrs.scale_d
scale_h = attrs.scale_h
scale_w = attrs.scale_w
layout = attrs.layout
method = attrs.method
coordinate_transformation_mode = attrs.coordinate_transformation_mode
return [topi.nn.upsampling3d(inputs[0], scale_d, scale_h, scale_w, layout, method,\
coordinate_transformation_mode)]

# pad
reg.register_schedule("nn.pad", schedule_broadcast)

Expand Down
52 changes: 52 additions & 0 deletions python/tvm/relay/op/nn/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,6 +771,58 @@ def upsampling(data,
return _make.upsampling(data, scale_h, scale_w, layout, method, align_corners)


def upsampling3d(data,
scale_d=1,
scale_h=1,
scale_w=1,
layout="NCDHW",
method="nearest_neighbor",
coordinate_transformation_mode="half_pixel"):
"""3D Upsampling.
This operator takes data as input and does 3D scaling to the given scale factor.
In the default case, where the data_layout is `NCDHW`
with data of shape (n, c, d, h, w)
out will have a shape (n, c, d*scale_d, h*scale_h, w*scale_w)
method indicates the algorithm to be used while calculating the out value
and method can be one of ("trilinear", "nearest_neighbor")
Parameters
----------
data : tvm.relay.Expr
The input data to the operator.
scale_d : tvm.relay.Expr
The scale factor for depth upsampling.
scale_h : tvm.relay.Expr
The scale factor for height upsampling.
scale_w : tvm.relay.Expr
The scale factor for width upsampling.
layout : str, optional
Layout of the input.
method : str, optional
Scale method to used [nearest_neighbor, trilinear].
coordinate_transformation_mode: string, optional
Describes how to transform the coordinate in the resized tensor
to the coordinate in the original tensor.
Refer to the ONNX Resize operator specification for details.
Available options are "half_pixel", "align_corners" and "asymmetric".
Returns
-------
result : tvm.relay.Expr
The computed result.
"""
return _make.upsampling3d(data, scale_d, scale_h, scale_w, layout, method,
coordinate_transformation_mode)


def batch_flatten(data):
"""BatchFlatten.
Expand Down
4 changes: 4 additions & 0 deletions python/tvm/relay/op/op_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ class FIFOBufferAttrs(Attrs):
class UpSamplingAttrs(Attrs):
"""Attributes for nn.upsampling"""

@register_relay_attr_node
class UpSampling3DAttrs(Attrs):
"""Attributes for nn.upsampling3d"""

@register_relay_attr_node
class PadAttrs(Attrs):
"""Attributes for nn.pad"""
Expand Down
90 changes: 87 additions & 3 deletions src/relay/op/nn/upsampling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ namespace tvm {
namespace relay {

TVM_REGISTER_NODE_TYPE(UpSamplingAttrs);
TVM_REGISTER_NODE_TYPE(UpSampling3DAttrs);

template <typename T>
Array<Array<Layout> > UpsamplingInferCorrectLayout(
Expand All @@ -50,8 +51,11 @@ Array<Array<Layout> > UpsamplingInferCorrectLayout(
Layout input = new_in_layouts[0];
if (input.IndexOf(LayoutAxis::Get('W')) == raw_layout.IndexOf(LayoutAxis::Get('W')) &&
input.IndexOf(LayoutAxis::Get('H')) == raw_layout.IndexOf(LayoutAxis::Get('H')) &&
!input.Contains(LayoutAxis::Get('w')) && !input.Contains(LayoutAxis::Get('h'))) {
params->layout = input.name(); // modify self to follow the input layout
!input.Contains(LayoutAxis::Get('w')) && !input.Contains(LayoutAxis::Get('h'))&&
(input.IndexOf(LayoutAxis::Get('D')) == -1 ||
(input.IndexOf(LayoutAxis::Get('D')) == raw_layout.IndexOf(LayoutAxis::Get('D')) &&
!input.Contains(LayoutAxis::Get('d'))))) {
params->layout = input.name(); // modify self to follow the input layout
}
}

Expand Down Expand Up @@ -108,7 +112,6 @@ Expr MakeUpSampling(Expr data,
return CallNode::make(op, {data}, Attrs(attrs), {});
}


TVM_REGISTER_API("relay.op.nn._make.upsampling")
.set_body_typed(MakeUpSampling);

Expand Down Expand Up @@ -138,5 +141,86 @@ RELAY_REGISTER_OP("nn.upsampling")
.set_attr<TOpPattern>("TOpPattern", kInjective);


// UpSampling3D
bool UpSampling3DRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;

static const Layout kNCDHW("NCDHW");

const UpSampling3DAttrs* param = attrs.as<UpSampling3DAttrs>();
CHECK(param != nullptr);
const Layout in_layout(param->layout);

auto layout_converter = BijectiveLayoutNode::make(in_layout, kNCDHW);
CHECK(layout_converter.defined())
<< "UpSampling3D only support input layouts that are convertible from NCDHW."
<< " But got " << in_layout;

auto oshape = layout_converter.ForwardShape(data->shape);
oshape.Set(2, ir::Cast::make(oshape[2].dtype(), tvm::round(oshape[2] * param->scale_d)));
oshape.Set(3, ir::Cast::make(oshape[3].dtype(), tvm::round(oshape[3] * param->scale_h)));
oshape.Set(4, ir::Cast::make(oshape[4].dtype(), tvm::round(oshape[4] * param->scale_w)));

// assign output type
reporter->Assign(types[1],
TensorTypeNode::make(layout_converter.BackwardShape(oshape),
data->dtype));
return true;
}

// Positional relay function to create upsampling3d operator
// used by frontend FFI.
Expr MakeUpSampling3D(Expr data,
double scale_d,
double scale_h,
double scale_w,
std::string layout,
std::string method,
std::string coordinate_transformation_mode) {
auto attrs = make_node<UpSampling3DAttrs>();
attrs->layout = std::move(layout);
attrs->method = std::move(method);
attrs->scale_d = scale_d;
attrs->scale_h = scale_h;
attrs->scale_w = scale_w;
attrs->coordinate_transformation_mode = coordinate_transformation_mode;
static const Op& op = Op::Get("nn.upsampling3d");
return CallNode::make(op, {data}, Attrs(attrs), {});
}

TVM_REGISTER_API("relay.op.nn._make.upsampling3d")
.set_body_typed(MakeUpSampling3D);


RELAY_REGISTER_OP("nn.upsampling3d")
.describe(R"code(Perform upsampling on input array with nearest neighbour or
bilinear interpolation.
- **data**: data is 5D array of shape
(batch_size, channels, in_depth, in_height, in_width) for NCDHW
(batch_size, in_depth, in_height, in_width, channels) for NDHWC
- **out**: Output is 5D array of shape
for layout NCDHW
(batch_size, channels, in_depth*scale, in_height*scale, in_width*scale)
for layout NDHWC
(batch_size, in_depth*scale, in_height*scale, in_width*scale, channels)
)code" TVM_ADD_FILELINE)
.set_attrs_type<UpSampling3DAttrs>()
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
.add_type_rel("UpSampling3D", UpSampling3DRel)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
UpsamplingInferCorrectLayout<UpSampling3DAttrs>)
.set_attr<TOpPattern>("TOpPattern", kInjective);

} // namespace relay
} // namespace tvm
62 changes: 62 additions & 0 deletions tests/python/relay/test_op_level2.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,22 @@ def test_upsampling_infer_type():
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, c, 200, 400), "float32")

def test_upsampling3d_infer_type():
n, c, d, h, w = tvm.var("n"), tvm.var("c"), tvm.var("d"), tvm.var("h"), tvm.var("w")
scale = tvm.const(2.0, "float64")
x = relay.var("x", relay.TensorType((n, c, d, h, w), "float32"))
y = relay.nn.upsampling3d(x, scale_d=2, scale_h=2, scale_w=2, layout="NCDHW", method="trilinear")

yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, c, tvm.expr.Cast("int32", tvm.round(d*scale)),
tvm.expr.Cast("int32", tvm.round(h*scale)),
tvm.expr.Cast("int32", tvm.round(w*scale))),
"float32")
n, c = tvm.var("n"), tvm.var("c")
x = relay.var("x", relay.TensorType((n, c, 100, 100, 200), "float32"))
y = relay.nn.upsampling3d(x, scale_d=2, scale_h=2, scale_w=2, layout="NCDHW", method="trilinear")
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, c, 200, 200, 400), "float32")

def _test_pool2d(opfunc, reffunc):
n, c, h, w = tvm.var("n"), 10, 224, 224
Expand Down Expand Up @@ -782,6 +798,50 @@ def test_upsampling():
_test_upsampling("NHWC", "nearest_neighbor")
_test_upsampling("NHWC", "bilinear", True)

def _test_upsampling3d(layout, method, coordinate_transformation_mode="half_pixel"):
n, c, d, h, w = tvm.var("n"), 8, 16, 16, 16
scale_d = 2.0
scale_h = 2.0
scale_w = 2.0
dtype = "float32"
def get_shape():
if layout == "NCDHW":
return (c, d, h, w), (c, int(round(d*scale_d)), int(round(h*scale_h)),\
int(round(w*scale_w)))
else:
return (d, h, w, c), (int(round(d*scale_d)), int(round(h*scale_h)),\
int(round(w*scale_w)), c)
ishape, oshape = get_shape()
x = relay.var("x", relay.TensorType((n,) + ishape, dtype))
y = relay.nn.upsampling3d(x, scale_d=scale_d, scale_h=scale_h, scale_w=scale_w,\
layout=layout, method=method,\
coordinate_transformation_mode=coordinate_transformation_mode)

yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n,) + oshape, dtype)
dshape = (1,) + ishape
x = relay.var("x", shape=dshape)
y = relay.nn.upsampling3d(x, scale_d=scale_d, scale_h=scale_h, scale_w=scale_w,\
layout=layout, method=method,\
coordinate_transformation_mode=coordinate_transformation_mode)
func = relay.Function([x], y)
data = np.random.uniform(size=dshape).astype(dtype)
if method == "nearest_neighbor":
ref = topi.testing.upsampling3d_python(data, (scale_d, scale_h, scale_w), layout)
else:
ref = topi.testing.trilinear_resize3d_python(data, (int(round(d*scale_d)),\
int(round(h*scale_h)),\
int(round(w*scale_w))), layout)
for target, ctx in ctx_list():
executor = relay.create_executor("graph", ctx=ctx, target=target)
out = executor.evaluate(func)(data)
tvm.testing.assert_allclose(out.asnumpy(), ref, rtol=1e-5, atol=1e-5)

def test_upsampling3d():
_test_upsampling3d("NCDHW", "nearest_neighbor")
_test_upsampling3d("NCDHW", "trilinear", "align_corners")
_test_upsampling3d("NDHWC", "nearest_neighbor")
_test_upsampling3d("NDHWC", "trilinear", "align_corners")

def test_conv2d_int8_intrinsics():
def _compile(ic, oc, target, data_layout, kernel_layout, dtypes):
Expand Down Expand Up @@ -935,6 +995,7 @@ def test_bitpack_infer_type():
test_conv2d_infer_type()
test_bitpack_infer_type()
test_upsampling_infer_type()
test_upsampling3d_infer_type()
test_flatten_infer_type()
test_pad_infer_type()
test_pad_run()
Expand All @@ -948,4 +1009,5 @@ def test_bitpack_infer_type():
test_bitserial_conv2d_infer_type()
test_batch_flatten()
test_upsampling()
test_upsampling3d()
test_conv2d_int8_intrinsics()
Loading

0 comments on commit c3deec1

Please sign in to comment.