From 5f62154ec0d256bfc7c481ac2decd046888c9cec Mon Sep 17 00:00:00 2001 From: JiaKui Hu Date: Fri, 15 Apr 2022 13:14:35 +0800 Subject: [PATCH] [RELAY][FRONTEND] Initial OneFlow frontend support. (#8790) * add relay.f.frontend.fm_oneflow support cnns * support cuda * fix mobilenetv2 and reviews * fix: model without meta info * support eager and yolo, add test * fix: license * add: tutorials * fix: support new graph * fix some comments * refine * fix concat op convert bug * refine * refine * change cuda to cpu * fix bug * fix ci error in tvm * fix pylint check * delete useless file * add skimage package in docker * fix ci error * fix bug * add oneflow fronted test in ci * merge conflict * fix tutorial * try to find error in ci * revert * merge conflict * black oneflow * Delete from_oneflow.py Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> Co-authored-by: BBuf <1182563586@qq.com> --- python/tvm/relay/frontend/__init__.py | 1 + python/tvm/relay/frontend/oneflow.py | 1821 +++++++++++++++++ tests/python/frontend/oneflow/test_forward.py | 723 +++++++ tests/scripts/task_python_frontend.sh | 3 + 4 files changed, 2548 insertions(+) create mode 100644 python/tvm/relay/frontend/oneflow.py create mode 100644 tests/python/frontend/oneflow/test_forward.py diff --git a/python/tvm/relay/frontend/__init__.py b/python/tvm/relay/frontend/__init__.py index aa49b63203f2c..fbbd4f99212d9 100644 --- a/python/tvm/relay/frontend/__init__.py +++ b/python/tvm/relay/frontend/__init__.py @@ -23,6 +23,7 @@ from .mxnet import from_mxnet from .mxnet_qnn_op_utils import quantize_conv_bias_mkldnn_from_var from .keras import from_keras +from .oneflow import from_oneflow from .onnx import from_onnx from .tflite import from_tflite from .coreml import from_coreml diff --git a/python/tvm/relay/frontend/oneflow.py b/python/tvm/relay/frontend/oneflow.py new file mode 100644 index 0000000000000..c15b7b3c249cf --- /dev/null +++ b/python/tvm/relay/frontend/oneflow.py @@ -0,0 +1,1821 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, import-self, len-as-condition, unused-argument, too-many-lines +# pylint: disable=import-outside-toplevel +"""OneFlow: OneFlow is a performance-centered and open-source deep learning framework.""" + +import os +import re +import copy +import warnings + +import numpy as np +import tvm +from tvm.ir import IRModule +from tvm.topi.utils import get_const_tuple + +from .. import analysis +from .. import expr as _expr +from .. import function as _function +from .. import op as _op +from .. 
import ty as _ty +from .common import ( + AttrCvt, + Renamer, + fold_constant, + get_relay_op, + infer_channels, + infer_shape, + infer_type, + new_var, +) + +__all__ = ["from_oneflow"] + +FLOW_2_STR_DTYPE = { + 2: "float32", + 3: "float64", + 6: "int64", + 5: "int32", + 4: "int8", + 7: "uint8", + 9: "float16", +} + + +def is_input_op(node): + """Return true when the node is the input of the graph.""" + return node.WhichOneof("op_type") == "input_conf" + + +def is_user_op(node): + """Return true when the node is the intermediate variables of graph.""" + return node.WhichOneof("op_type") == "user_conf" + + +def is_output_op(node): + """Return true when the node is the output of the graph.""" + return node.WhichOneof("op_type") == "output_conf" + + +def is_param_op(node): + """Return true when the node is the intermediate variables of model(saved).""" + return node.WhichOneof("op_type") == "variable_conf" + + +def get_node_info(node): + """ + Get basic information about nodes: shape, data_type + """ + # list->tuple + shape = tuple(node.input_conf.blob_conf.shape.dim) + # get data type + dtype = node.input_conf.blob_conf.data_type + if dtype in list(FLOW_2_NP_DTYPE.keys()): + data_type = FLOW_2_NP_DTYPE[dtype] + else: + raise IndexError("Please check the data type of your node: %s" % node.name) + + return shape, data_type + + +def _dtype_shape_promotion(inputs): + """Promote data type and shape for list of tensors.""" + + dtype_order = ["bool", "int8", "int16", "int32", "int64", "float32", "float64"] + + ranks = [len(infer_shape(x)) for x in inputs] + if set(ranks) == set([1, 0]): + for i, r in enumerate(ranks): + if r == 0: + inputs[i] = _op.expand_dims(inputs[i], axis=0) + + dtypes = set(dtype_order.index(infer_type(x).checked_type.dtype) for x in inputs) + if len(dtypes) == 1: + return inputs + max_dtype = dtype_order[max(dtypes)] + for i, input_op in enumerate(inputs): + if infer_type(input_op).checked_type.dtype != max_dtype: + inputs[i] = input_op.astype(max_dtype) + return inputs + + +def parse_attr(attr): + """Parse attribute of user op in oneflow.""" + attrs = {} + for a in attr: + attr_str = str(attr[a]) + + if attr_str[0:7] == "at_list": + attr_str_ = attr_str.split(" ")[0] + + if attr_str_ == "at_list_float": + attrs[a] = tuple(attr[a].at_list_float.val) + elif attr_str_ == "at_list_int32": + attrs[a] = tuple(attr[a].at_list_int32.val) + elif attr_str_ == "at_list_int64": + attrs[a] = tuple(attr[a].at_list_int64.val) + + elif attr_str.split(":")[0] == "at_string": + attrs[a] = attr[a].at_string + + elif attr_str.split(" ")[0] == "at_shape": + attrs[a] = tuple(list(attr[a].at_shape.dim)) + + else: + attr_str_ = attr_str.split(":")[0] + if attr_str_ == "at_bool": + attrs[a] = attr[a].at_bool + elif attr_str_ == "at_double": + attrs[a] = attr[a].at_double + elif attr_str_ == "at_float": + attrs[a] = attr[a].at_float + elif attr_str_ == "at_int32": + attrs[a] = attr[a].at_int32 + elif attr_str_ == "at_int64": + attrs[a] = attr[a].at_int64 + + return attrs + + +def shape_of(x, dtype="int64"): + ttype = infer_type(x).checked_type + if not _ty.is_dynamic(ttype): + shape = list(ttype.shape) + return _expr.const(shape, dtype) + + return _op.shape_of(x, dtype) + + +def dimension_constraint(): + def _dim_check(attrs): + if len(attrs["kernel_size"]) in [1, 2, 3]: + return True + return False + + return _dim_check, "Only 1d, 2d and 3d kernel supported." 
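+
+
+# Illustrative sketch (not part of the frontend): `_dtype_shape_promotion`
+# above promotes every input to the widest dtype in `dtype_order`. The helper
+# name below is ours, added only to show the promotion rule in plain Python:
+def _promoted_dtype_example(dtypes):
+    """e.g. _promoted_dtype_example(["int32", "float32"]) -> "float32"."""
+    dtype_order = ["bool", "int8", "int16", "int32", "int64", "float32", "float64"]
+    return dtype_order[max(dtype_order.index(d) for d in dtypes)]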
+ + +class OneFlowOpConverter(object): + """A helper class for holding oneflow op converters.""" + + @classmethod + def get_converter(cls): + """ + Get converter matches given opset. + Parameters + ---------- + None + + Returns + ------- + converter, which should be `_impl_vx`. + """ + version = 1 + if hasattr(cls, "_impl_v{}".format(version)): + return getattr(cls, "_impl_v{}".format(version)) + raise NotImplementedError("version {} of {} not implemented".format(version, cls.__name__)) + + +class Pool(OneFlowOpConverter): + """A helper class for pool op converters.""" + + name = "" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + data = inputs[0] + attrs.pop("data_format") + out = AttrCvt( + op_name=cls.name, + transforms={ + "kernel_size": "pool_size", + "stride": "strides", + "dilations": ("dilation", 1), + }, + ignores=["return_indices", "divisor_override"], + custom_check=dimension_constraint(), + )([data], attrs, params) + + return out + + +class AdaptiveAvgPool2d(OneFlowOpConverter): + """Operator converter for AdaptiveAvgPool2d""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + return _op.nn.adaptive_avg_pool2d(inputs[0], output_size=attrs["output_size"]) + + +class AdaptiveMaxPool2d(OneFlowOpConverter): + """Operator converter for AdaptiveMaxPool2d""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + return _op.nn.adaptive_max_pool2d(inputs[0], output_size=attrs["output_size"]) + + +class GlobalAveragePool(OneFlowOpConverter): + """Operator converter for GlobalAveragePool""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + rank = len(infer_shape(inputs[0])) + if rank == 3: + return _op.nn.global_avg_pool1d(inputs[0]) + if rank == 4: + return _op.nn.global_avg_pool2d(inputs[0]) + if rank == 5: + return _op.nn.global_avg_pool3d(inputs[0]) + raise NotImplementedError( + "Global average pooling is only implemented for 1D, 2D, and 3D kernels, got %dD." + % (rank - 2), + ) + + +class GlobalMaxPool(OneFlowOpConverter): + """Operator converter for GlobalMaxPool""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + rank = len(infer_shape(inputs[0])) + if rank == 3: + return _op.nn.global_max_pool1d(inputs[0]) + if rank == 4: + return _op.nn.global_max_pool2d(inputs[0]) + if rank == 5: + return _op.nn.global_max_pool3d(inputs[0]) + raise NotImplementedError( + "Global max pooling is only implemented for 1D, 2D, and 3D kernels, got %dD." + % (rank - 2), + ) + + +class Conv(OneFlowOpConverter): + """A helper class for conv op converters.""" + + name = "" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + # The kernel is imported from model_dir_path, without the ".weight" logo, etc. + # The data is obtained through the graph, its op contains "-input_" + in_names = ["-input_"] + kernel_names = [".weight"] + for i in inputs: + IN_NAMES = any(x in str(i) for x in in_names) + KERNEL_NAMES = any(x in str(i) for x in kernel_names) + if IN_NAMES: + data = i + elif KERNEL_NAMES: + kernel = i + else: + data = i + + # Use shape of input to determine convolution type. 
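+        # The kernel is laid out as OIHW (OIW for conv1d), so when the op
+        # attributes do not carry an explicit kernel_size it can be recovered
+        # from the spatial dims of the weight shape, i.e. shape[2:].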
+ kernel_type = infer_type(kernel) + kernel_shapes = [get_const_tuple(kernel_type.checked_type.shape)] + + if "kernel_size" not in attrs: + attrs["kernel_size"] = kernel_shapes[0][2:] + if "dilation_rate" in attrs: + attrs["dilation"] = list(attrs["dilation_rate"]) + attrs.pop("dilation_rate") + + pad_v = attrs.get("padding_before", [0, 0]) + attrs["padding"] = [pad_v[0], pad_v[1], pad_v[0], pad_v[1]] + + group_conv1d = False + if cls.name == "conv1d" and attrs.get("groups") != 1: + group_conv1d = True + # Expand input from NCW to NCHW + data = _op.expand_dims(data, axis=2) + # Expand kernel from OIW to OIHW + kernel = _op.expand_dims(kernel, axis=2) + # Add new value to kernel_shape, strices, dilation, pads, if needed + attrs["kernel_size"] = [1] + list(attrs["kernel_size"]) + if "strides" in attrs: + attrs["strides"] = [1] + list(attrs["strides"]) + if "dilations" in attrs: + attrs["dilation"] = [1] + list(attrs["dilations"]) + + out = AttrCvt( + op_name=cls.name, + transforms={ + "group": ("groups", 1), + }, + ignores=["data_format", "filters", "padding_after", "padding_before"], + custom_check=dimension_constraint(), + )([data, kernel], attrs, params) + + # If this was a group_conv1d, squish output back to NCW. + if group_conv1d: + out = _op.squeeze(out, axis=[2]) + + return out + + +class ConvTranspose(OneFlowOpConverter): + """Operator converter for ConvTranspose.""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + in_names = ["-input_"] + kernel_names = [".weight"] + for i in inputs: + IN_NAMES = any(x in str(i) for x in in_names) + KERNEL_NAMES = any(x in str(i) for x in kernel_names) + if IN_NAMES: + data = i + elif KERNEL_NAMES: + kernel = i + else: + data = i + + # get number of channels + attrs["channels"] = attrs.get("filters", 1) + attrs["groups"] = attrs.get("group", 1) + + kernel_type = infer_type(kernel) + kernel_shapes = [get_const_tuple(kernel_type.checked_type.shape)] + + if "kernel_size" not in attrs: + attrs["kernel_size"] = kernel_shapes[0][2:] + + if "dilation_rate" in attrs: + attrs["dilation"] = list(attrs["dilation_rate"]) + attrs.pop("dilation_rate") + + pad_v = attrs.get("padding_before", [0, 0]) + attrs["padding"] = [pad_v[0], pad_v[1], pad_v[0], pad_v[1]] + + out = AttrCvt( + op_name=cls.name, + transforms={ + "group": ("groups", 1), + }, + disables=["filters", "data_format", "padding_before"], + custom_check=dimension_constraint(), + )([data, kernel], attrs, params) + + return out + + +class Upsample(OneFlowOpConverter): + """A helper class for upsample op converters""" + + name = "" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + data = inputs[0] + input_shape = infer_shape(data) + dims = len(input_shape) + + width_scale = attrs.get("width_scale", 1.0) + height_scale = attrs.get("height_scale", 1.0) + align_corners = attrs.get("align_corners", False) + + if "nearest" in cls.name: + method = "nearest_neighbor" + elif "trilinear" in cls.name: + method = "trilinear" + elif "bilinear" in cls.name: + method = "bilinear" + + # in 3d case, we use the purely static op + if dims == 5: + if isinstance(scales, _expr.Expr): + scale_h = _op.take(scales, _op.const(3)) + scale_w = _op.take(scales, _op.const(4)) + scale_d = _op.take(scales, _op.const(1)) + else: + assert len(scales) == 5 + scale_h = scales[-2] + scale_w = scales[-1] + scale_d = scales[-3] + + layout = "NCDHW" + out = _op.nn.upsampling3d( + data, + scale_d, + scale_h, + scale_w, + layout=layout, + method=method, + coordinate_transformation_mode="asymmetric", + ) + # in 2d case, 
use dynamic op + else: + if isinstance(height_scale, _expr.Expr): + height_scale = _op.take(height_scale, _op.const(3)) + width_scale = _op.take(width_scale, _op.const(4)) + layout = "NCHW" + + out = _op.nn.upsampling( + inputs[0], + height_scale, + width_scale, + layout=layout, + method=method, + align_corners=align_corners, + ) + return out + + +class UpsampleNearest(Upsample): + """Operator converter for Upsample Nearest""" + + name = "upsample_nearest" + + +class UpsampleBiLinear(Upsample): + """Operator converter for Upsample Bilinear""" + + name = "upsample_bilinear" + + +class Conv2d(Conv): + """Operator converter for Conv2d""" + + name = "conv2d" + + +class ConvTranspose2d(ConvTranspose): + """Operator converter for ConvTranspose2d""" + + name = "conv2d_transpose" + + +class BatchNorm(OneFlowOpConverter): + """Operator converter for BatchNorm""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + # sort the inputs + sorted_inputs = copy.deepcopy(inputs) + for i in inputs: + IN_NAMES = "-input_" in str(i) + if IN_NAMES: + sorted_inputs[0] = i + elif "weight" in str(i) and not IN_NAMES: + sorted_inputs[1] = i + elif "bias" in str(i) and not IN_NAMES: + sorted_inputs[2] = i + elif "mean" in str(i) and not IN_NAMES: + sorted_inputs[3] = i + elif "var" in str(i) and not IN_NAMES: + sorted_inputs[4] = i + + if "data_format" in attrs: + if attrs["data_format"] == "channel_first": + attrs["axis"] = 1 + + out = AttrCvt(op_name="batch_norm", ignores=["training"], disables=["momentum"])( + sorted_inputs, attrs, params + ) + return out[0] + + +class Flatten(OneFlowOpConverter): + """Operator converter for Flatten""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + axis = attrs.get("axis", 1) + ishape = _op.shape_of(inputs[0]) + ndim = infer_shape(ishape)[0] + if axis < 0: + axis = axis + ndim + + if axis == 1: + out = _op.nn.batch_flatten(inputs[0]) + else: + pre_shape = _op.prod(_op.strided_slice(ishape, [0], [axis], [1]), keepdims=True) + post_shape = _op.prod(_op.strided_slice(ishape, [axis], [ndim], [1]), keepdims=True) + newshape = _op.concatenate([pre_shape, post_shape], axis=0) + out = _op.reshape(inputs[0], newshape) + return out + + +class MatMul(OneFlowOpConverter): + """Operator converter for MatMul""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + assert len(inputs) == 2, "Gemm op take 2 inputs, {} given".format(len(inputs)) + # Similar to 'class Conv' + true_names = ["weight"] + false_names = ["-input_"] + for i in inputs: + T_NAMES = any(x in str(i) for x in true_names) + F_NAMES = any(x in str(i) for x in false_names) + if T_NAMES and not F_NAMES: + matmul_b = i + else: + matmul_a = i + + dtype = infer_type(matmul_a).checked_type.dtype + + # Y = alpha * A * B + alpha = float(attrs.get("alpha", 1.0)) + transA = bool(attrs.get("transpose_a", False)) + transB = bool(attrs.get("transpose_b", False)) + + # get number of channels + channels = infer_channels(matmul_b, not transB) + if transA: + matmul_a = _op.transpose(matmul_a, axes=(1, 0)) + if not transB: + matmul_b = _op.transpose(matmul_b, axes=(1, 0)) + matmul_a = _op.nn.batch_flatten(matmul_a) + if alpha != 1.0: + matmul_a *= _expr.const(alpha, dtype=dtype) + + return _op.nn.dense(matmul_a, matmul_b, units=channels) + + +class Reduce(OneFlowOpConverter): + """Operator converter for reduce ops""" + + name = "" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + attr = {"axis": attrs.get("axis", 0), "keepdims": attrs.get("keepdims", True)} + return AttrCvt(cls.name)(inputs, attr) 
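+
+
+# Illustrative sketch (not part of the converter): the MatMul class above maps
+# OneFlow matmul onto relay nn.dense (which computes A @ B^T), transposing the
+# weight when transpose_b is False and folding alpha into the flattened data.
+# A plain numpy-style reference of the op being converted (the function name
+# below is ours, for illustration only):
+def _matmul_reference_example(a, b, alpha=1.0, transpose_a=False, transpose_b=False):
+    """Y = alpha * op(a) @ op(b), matching the attributes handled by MatMul."""
+    a = a.T if transpose_a else a
+    b = b.T if transpose_b else b
+    return alpha * (a @ b)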
+ + +class ReduceMax(Reduce): + """Operator converter for ReduceMax""" + + name = "max" + + +class ReduceMin(Reduce): + """Operator converter for ReduceMin""" + + name = "min" + + +class ReduceSum(Reduce): + """Operator converter for ReduceSum""" + + name = "sum" + + +class ReduceMean(Reduce): + """Operator converter for ReduceMean""" + + name = "mean" + + +class Square(OneFlowOpConverter): + """Operator converter for square""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + assert len(inputs) == 1, "Square op {} take 1 inputs, {} given".format( + cls.name, len(inputs) + ) + return _op.multiply(inputs[0], inputs[0]) + + +class Add(OneFlowOpConverter): + """Operator converter for Add""" + + name = "add" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + assert len(inputs) == 2, "Math op {} take 2 inputs, {} given".format(cls.name, len(inputs)) + axis = int(attrs.get("axis", 0)) + + true_names = ["weight", "bias"] + false_names = ["-input_"] + + for i in inputs: + T_NAMES = any(x in str(i) for x in true_names) + F_NAMES = any(x in str(i) for x in false_names) + if T_NAMES and not F_NAMES: + add_b = i + else: + add_a = i + + # fix the shape + add_shape = infer_shape(add_a) + if len(add_shape) > 2: + add_b = _op.expand_dims(add_b, axis=axis, num_newaxis=len(add_shape) - 2) + add_b_shape = list(infer_shape(add_b)) + add_b_shape.insert(0, add_shape[0]) + + add_b = _op.reshape(add_b, tuple(add_b_shape)) + out = get_relay_op(cls.name)(add_a, add_b) + + return out + + +class Expand(OneFlowOpConverter): + """Operator converter for Expand""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + input_shape = infer_shape(inputs[0]) + assert input_shape == attrs["in_shape"], "shape wrong" + + new_shape = attrs["out_shape"] + out = _op.broadcast_to(inputs[0], shape=new_shape) + + return out + + +class ExpandDim(OneFlowOpConverter): + """Operator converter for ExpandDim""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + + return _op.expand_dims(inputs[0], axis=attrs.get("axis", 0)) + + +class BroadcastMath(OneFlowOpConverter): + """Operator converter for broadcast math ops""" + + name = "" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + assert len(inputs) == 2, "Math op {} take 2 inputs, {} given".format(cls.name, len(inputs)) + beta_names = ["weight", "bias", "mean", "var", "Constant"] + + for i in inputs: + T_NAMES = any([x in str(i) for x in beta_names]) + if T_NAMES and "-input_" not in str(i): + input_b = i + else: + input_a = i + + if cls.name == "divide": + length = [] + for i in inputs: + length.append(len(str(i))) + for i in inputs: + if len(str(i)) == max(length): + input_a = i + else: + input_b = i + if cls.name == "subtract": + length = [] + for i in inputs: + length.append(len(str(i))) + for i in inputs: + if len(str(i)) == max(length): + input_b = i + else: + input_a = i + try: + return get_relay_op(cls.name)(input_a, input_b) + except UnboundLocalError: + return get_relay_op(cls.name)(*inputs) + + +class BroadcastMul(BroadcastMath): + """Operator converter for Mul broadcast""" + + name = "multiply" + + +class BroadcastAdd(BroadcastMath): + """Operator converter for Add broadcast""" + + name = "add" + + +class BroadcastSub(BroadcastMath): + """Operator converter for Sub broadcast""" + + name = "subtract" + + +class BroadcastDiv(BroadcastMath): + """Operator converter for Div broadcast""" + + name = "divide" + + +class Greater(OneFlowOpConverter): + """Operator converter for greater""" + + @classmethod + def _impl_v1(cls, 
inputs, attrs, params): + return _op.greater(inputs[0], inputs[1]) + + +class Log1p(OneFlowOpConverter): + """Operator converter for Log1p""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + return _op.log(inputs[0] + _expr.const(1.0)) + + +class Expm1(OneFlowOpConverter): + """Operator converter for Expm1""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + return _op.exp(inputs[0]) - _expr.const(1.0) + + +class Unary(OneFlowOpConverter): + """A helper class for unary op converters""" + + name = "" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + assert len(inputs) == 1, "Unary math op {} takes 1 input, {} given".format( + cls.name, len(inputs) + ) + return get_relay_op(cls.name)(*inputs) + + +class Absolute(Unary): + """Operator converter for Absolute.""" + + name = "abs" + + +class AddN(OneFlowOpConverter): + """Operator converter for Add_n""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + assert len(inputs) > 0, "add_n take >=1 inputs, but 0 given." + + res = inputs[0] + for each in inputs[1:]: + res = _op.add(res, each) + return res + + +class ScalarAdd(OneFlowOpConverter): + """Operator convert for Add_scalar""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + assert len(inputs) == 1, "add_scalar take == 1 inputs, but {} given.".format(len(inputs)) + + if attrs.get("has_int_operand", True): + res = inputs[0] + _expr.const(attrs["int_operand"]) + elif attrs.get("has_float_operand", True): + res = inputs[0] + _expr.const(attrs["float_operand"]) + else: + raise AttributeError( + "please check if has_int_operand or has_float_operand in your attrs" + ) + + return res + + +class ScalarMul(OneFlowOpConverter): + """Operator convert for Mul_scalar""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + assert len(inputs) == 1, "add_scalar take == 1 inputs, but {} given.".format(len(inputs)) + + if attrs.get("has_int_operand", True): + res = inputs[0] * _expr.const(attrs["int_operand"], dtype="float32") + elif attrs.get("has_float_operand", True): + res = inputs[0] * _expr.const(attrs["float_operand"]) + else: + raise AttributeError( + "please check if has_int_operand or has_float_operand in your attrs" + ) + + return res + + +class ScalarPow(OneFlowOpConverter): + """Operator convert for Pow_scalar""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + exponent = attrs.get("exponent", 1.0) + exponent = _expr.const(exponent, dtype="float32") + return _op.power(inputs[0], exponent) + + +class MaxPool2d(Pool): + """Operator converter for MaxPool""" + + name = "max_pool2d" + + +class AveragePool2d(Pool): + """Operator converter for AveragePool.""" + + name = "avg_pool2d" + + +class Affine(OneFlowOpConverter): + """Operator converter for Affine transformation.""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + alpha = _expr.const(attrs.get("alpha", 1.0)) + beta = _expr.const(attrs.get("beta", 0.0)) + return (alpha * inputs[0]) + beta + + +class Reshape(OneFlowOpConverter): + """Operator converter for Reshape.""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + return _op.reshape(inputs[0], attrs["shape"]) + + +class Softmax(OneFlowOpConverter): + """Operator converter for Softmax.""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + axis = attrs.get("axis", 1) + ndim = len(infer_shape(inputs[0])) + if axis < 0: + axis += ndim + axes = list(range(axis, ndim)) + x = inputs[0] + m = _op.max(x, axes, keepdims=True) + e = _op.exp(x - m) + return e / _op.sum(e, axes, 
keepdims=True) + + +class LogSoftmax(OneFlowOpConverter): + """Operator converter for LogSoftmax.""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + axis = attrs.get("axis", 1) + ndim = len(infer_shape(inputs[0])) + if axis < 0: + axis += ndim + axes = list(range(axis, ndim)) + x = inputs[0] + m = _op.max(x, axes, keepdims=True) + e = _op.exp(x - m) + s = _op.sum(e, axes, keepdims=True) + return x - m - _op.log(s) + + +class Dropout(OneFlowOpConverter): + """Operator converter for Dropout.""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + out = AttrCvt("dropout", {"ratio": "rate"}, ignores=["is_test"]) + return out + + +class ThresholdedRelu(OneFlowOpConverter): + """Operator converter for ThresholdedRelu.""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + alpha = float(attrs.get("alpha", 1.0)) + alpha_tensor = _op.full_like(inputs[0], fill_value=_expr.const(alpha)) + mask = _op.greater(inputs[0], alpha_tensor).astype("float32") + return inputs[0] * mask + + +class Elu(OneFlowOpConverter): + """Operator converter for Elu""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + alpha = float(attrs.get("alpha", 1.0)) + return _expr.const(-alpha) * _op.nn.relu( + _expr.const(1.0) - _op.exp(inputs[0]) + ) + _op.nn.relu(inputs[0]) + + +class PReLU(OneFlowOpConverter): + """Operator converter for PReLU""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + assert len(inputs) == 2, "PReLU need 2 inputs, but {} given".format(len(inputs)) + for i in inputs: + if "-input_" in str(i): + prelu_a = i + else: + prelu_b = i + + input_shape = shape_of(prelu_a) + alpha = _op.broadcast_to_like(prelu_b, prelu_a) + alpha = _op.reshape(alpha, [-1]) + + output = _op.nn.prelu(_op.reshape(prelu_a, [-1]), alpha, axis=0) + out = _op.reshape(output, input_shape) + return out + + +class Selu(OneFlowOpConverter): + """Operator converter for Selu""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + alpha = float(attrs.get("alpha", 1.67326319217681884765625)) + gamma = float(attrs.get("gamma", 1.05070102214813232421875)) + return _expr.const(gamma) * ( + _expr.const(-alpha) * _op.nn.relu(_expr.const(1.0) - _op.exp(inputs[0])) + + _op.nn.relu(inputs[0]) + ) + + +class Silu(OneFlowOpConverter): + """Operator converter for Silu""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + a = inputs[0] + b = _op.sigmoid(inputs[0]) + return _op.multiply(a, b) + + +class Gelu(OneFlowOpConverter): + """Operator converter for Gelu""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + data = inputs[0] + return data * ( + _expr.const(0.5) + _op.erf(data * _expr.const(0.5**0.5)) * _expr.const(0.5) + ) + + +class HardTanh(OneFlowOpConverter): + """Operator converter for HardTanh""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + tanh_min = attrs.get("min_val", 0.0) + tanh_max = attrs.get("max_val", 0.0) + return _op.tensor.clip(inputs[0], tanh_min, tanh_max) + + +class Softplus(OneFlowOpConverter): + """Operator converter for Softplus""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + data = inputs[0] + data_dtype = infer_type(data).checked_type.dtype + data = _op.exp(data) + _expr.const(1, dtype=data_dtype) + return _op.log(data) + + +class Softsign(OneFlowOpConverter): + """Operator converter for Softsign""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + return inputs[0] / (_expr.const(1.0) + Absolute.get_converter()(inputs, attrs, params)) + + +class Concat(OneFlowOpConverter): + 
"""Operator converter for Concat""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + attrs.pop("max_dim_size") + inputs = _dtype_shape_promotion(inputs) + return _op.concatenate(inputs, axis=attrs["axis"]) + + +class Clip(OneFlowOpConverter): + """Operator converter for Clip""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + attr = {} + dtype = infer_type(inputs[0]) + + if "float" in str(dtype): + attr["a_min"] = attrs["floating_min"] + attr["a_max"] = attrs["floating_max"] + elif "int" in str(dtype): + attr["a_min"] = attrs["integral_min"] + attr["a_max"] = attrs["integral_max"] + else: + attr["a_min"] = -np.inf + attr["a_max"] = np.inf + + out = AttrCvt("clip")(inputs, attr, params) + return out + + +class Slice(OneFlowOpConverter): + """Operator converter for Slice""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + starts = list(attrs["start"]) + ends = list(attrs["stop"]) + steps = list(attrs["step"]) + return _op.strided_slice(inputs[0], starts, ends, steps) + + +class Split(OneFlowOpConverter): + """Operator converter for Split""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + splits = attrs.get("split", None) + if splits is not None: + indices = [] + attrs["indices_or_sections"] = [] + index = 0 + for i in splits[:-1]: + index += i + indices.append(index) + output = _op.split(inputs[0], indices, attrs.get("axis", 0)) + # If the output of split is a single value, unpack if from the TupleWrapper + if len(output) == 1: + output = output[0] + return output + + +class Scatter(OneFlowOpConverter): + """Operator converter for Scatter""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + axis = attrs.get("axis", 0) + return _op.scatter(inputs[0], inputs[1], inputs[2], axis) + + +class Unsqueeze(OneFlowOpConverter): + """Operator converter for Unsqueeze""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + axes = sorted(attrs["axes"]) + for axis in axes: + inputs[0] = _op.expand_dims(inputs[0], axis=axis, num_newaxis=1) + return inputs[0] + + +class Sign(OneFlowOpConverter): + """Operator converter for Sign""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + return _op.sign(inputs[0]) + + +class Reciprocal(OneFlowOpConverter): + """Operator converter for Reciprocal""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + dtype = infer_type(inputs[0]).checked_type.dtype + return _expr.const(1.0, dtype=dtype) / inputs[0] + + +class Erf(OneFlowOpConverter): + """Operator converter for Erf""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + return _op.erf(inputs[0]) + + +class Erfc(OneFlowOpConverter): + """Operator converter for Erfs""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + return _expr.const(1.0) - _op.erf(inputs[0]) + + +class HardSigmoid(OneFlowOpConverter): + """Operator converter for HardSigmoid""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + alpha = attrs.get("alpha", 0.2) + beta = attrs.get("beta", 0.5) + transformX = (inputs[0] * _expr.const(alpha)) + _expr.const(beta) + attr = {"a_min": 0, "a_max": 1} + return AttrCvt("clip")([transformX], attr) + + +class OneHot(OneFlowOpConverter): + """Operator converter for OneHot""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + # Extract relay one_hot inputs. + indices, depth, values = inputs + ndim = len(infer_shape(indices)) + # Split onnx on off values into two separate expressions. 
+ off_value, on_value = _op.take(values, _op.const(0)), _op.take(values, _op.const(1)) + # Extract the datatype of the output from on_value. + dtype = infer_type(on_value).checked_type.dtype + ind_dtype = infer_type(indices).checked_type.dtype + # Normalize the indices to a positive range + indices = _op.where( + indices < _op.const(0, ind_dtype), indices + _op.cast(depth, ind_dtype), indices + ) + # set default value when axis is not set in the model + axis = attrs.get("axis", -1) + if axis < 0: + axis += ndim + 1 + + return _op.one_hot(indices, on_value, off_value, depth, axis, dtype=dtype) + + +class Where(OneFlowOpConverter): + """Operator converter for Where""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + condition_rank = len(infer_shape(inputs[0])) + x_rank = len(infer_shape(inputs[1])) + y_rank = len(infer_shape(inputs[2])) + ranks = [condition_rank, x_rank, y_rank] + + # If one rank is longer than others, then we can broadcast + # to that shape. + max_rank = max(ranks) + max_rank_idxs = [i for i, x in enumerate(ranks) if x == max_rank] + broadcast_shape = shape_of(inputs[max_rank_idxs[0]]) + # If two or more inputs have the same rank, compute the broadcast + # shape by taking the maximum value of each dimensions. + if len(max_rank_idxs) > 1: + for idx in max_rank_idxs: + broadcast_shape = _op.maximum(broadcast_shape, shape_of(inputs[idx])) + + broadcast_shape = fold_constant(broadcast_shape) + + condition = _op.broadcast_to(inputs[0], broadcast_shape) + x = _op.broadcast_to(inputs[1], broadcast_shape) + y = _op.broadcast_to(inputs[2], broadcast_shape) + return _op.where(condition, x, y) + + +class Constant(OneFlowOpConverter): + """Operator converter for Constant""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + is_float = attrs.get("is_floating_value", True) + shape = attrs.get("shape", (1,)) + if is_float: + dtype = "float32" + value = attrs.pop("floating_value") + else: + dtype = "int8" + value = attrs.pop("integer_value") + np_array = np.zeros(shape) + np_array.fill(value) + value = _expr.const(np_array, dtype) + return value + + +class Range(OneFlowOpConverter): + """Operator converter for Range""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + if len(inputs) != 0: + raise ValueError("Expect no inputs but get {}".format(len(inputs))) + start = attrs.get("start", 0.0) + limit = attrs.get("limit", 1.0) + delta = attrs.get("delta", 1.0) + return _op.arange( + _expr.const(start, dtype="float32"), + _expr.const(limit, dtype="float32"), + _expr.const(delta, dtype="float32"), + ) + + +class Cast(OneFlowOpConverter): + """Operator converter for Cast""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + attrs["dtype"] = infer_type(inputs[0]).checked_type.dtype + return AttrCvt(op_name="cast")(inputs, attrs) + + +def get_convert_map(): + # supported oneflow2relay op + return { + # defs/math + "bias_add": Add.get_converter(), + "scalar_add": ScalarAdd.get_converter(), + "scalar_mul": ScalarMul.get_converter(), + "scalar_pow": ScalarPow.get_converter(), + "reduce_sum": ReduceSum.get_converter(), + "reduce_max": ReduceMax.get_converter(), + "reduce_min": ReduceMin.get_converter(), + "reduce_mean": ReduceMean.get_converter(), + "broadcast_add": BroadcastAdd.get_converter(), + "broadcast_mul": BroadcastMul.get_converter(), + "broadcast_sub": BroadcastSub.get_converter(), + "broadcast_div": BroadcastDiv.get_converter(), + "broadcast_greater": Greater.get_converter(), + "log": Renamer("log"), + "log1p": Log1p.get_converter(), + 
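+        # direct renames: the trigonometric and hyperbolic ops below map
+        # one-to-one onto the relay ops of the same name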
"acos": Renamer("acos"), + "acosh": Renamer("acosh"), + "asin": Renamer("asin"), + "asinh": Renamer("asinh"), + "atan": Renamer("atan"), + "atanh": Renamer("atanh"), + "cos": Renamer("cos"), + "cosh": Renamer("cosh"), + "sin": Renamer("sin"), + "sinh": Renamer("sinh"), + "tan": Renamer("tan"), + "tanh": Renamer("tanh"), + "pow": Renamer("power"), + "exp": Renamer("exp"), + "expm1": Expm1.get_converter(), + "floor": Renamer("floor"), + "ceil": Renamer("ceil"), + "round": Renamer("round"), + "add_n": AddN.get_converter(), + "sqrt": Renamer("sqrt"), + "rsqrt": Renamer("rsqrt"), + "square": Square.get_converter(), + "sign": Sign.get_converter(), + "erf": Erf.get_converter(), + "erfc": Erfc.get_converter(), + "reciprocal_no_nan": Reciprocal.get_converter(), + # defs/activation + "softmax": Softmax.get_converter(), + "softsign": Softsign.get_converter(), + "hardtanh": HardTanh.get_converter(), + "relu": Renamer("relu"), + "leaky_relu": Renamer("leaky_relu"), + "prelu": PReLU.get_converter(), + "selu": Selu.get_converter(), + "silu": Silu.get_converter(), + "gelu": Gelu.get_converter(), + # defs/nn + "conv2d": Conv2d.get_converter(), + "deconv2d": ConvTranspose2d.get_converter(), + "maxpool_2d": MaxPool2d.get_converter(), + "avgpool_2d": AveragePool2d.get_converter(), + "adaptive_avg_pool2d": AdaptiveAvgPool2d.get_converter(), + "adaptive_max_pool2d": AdaptiveMaxPool2d.get_converter(), + "dropout": Dropout.get_converter(), + "normalization": BatchNorm.get_converter(), + "upsample_nearest_2d": UpsampleNearest.get_converter(), + "upsample_bilinear_2d": UpsampleBiLinear.get_converter(), + # defs/tensor + "matmul": MatMul.get_converter(), + "concat": Concat.get_converter(), + "clip_by_scalar": Clip.get_converter(), + "slice": Slice.get_converter(), + "expand": Expand.get_converter(), + "transpose": AttrCvt("transpose", {"perm": "axes"}), + "expand_dims": ExpandDim.get_converter(), + "range": Range.get_converter(), + "cast": Cast.get_converter(), + # defs/others + "reshape": Reshape.get_converter(), + "constant": Constant.get_converter(), + # "where": Where.get_converter(), + "flatten": Flatten.get_converter(), + "sigmoid": Renamer("sigmoid"), + "sigmoid_v2": Renamer("sigmoid"), + "hardsigmoid": HardSigmoid.get_converter(), + "squeeze": AttrCvt("squeeze", {"axes": "axis"}), + "unsqueeze": Unsqueeze.get_converter(), + } + + +class oneflow_input(object): + """ + Dual purpose list or dictionary access object + """ + + def __init__(self): + self.input_keys = [] + self.input_dict = {} + self.n = 0 + + def __getitem__(self, item): + if isinstance(item, int): + if item > (len(self.input_keys) - 1): + return None + return self.input_dict[self.input_keys[item]] + if isinstance(item, str): + if item not in self.input_keys: + return None + return self.input_dict[item] + if isinstance(item, slice): + keys = self.input_keys[item] + return [self.input_dict[key] for key in keys] + + raise ValueError("Only integer, string, and slice accesses allowed.") + + def __setitem__(self, item, value): + if isinstance(item, int): + self.input_dict[self.input_keys[item]] = value + elif isinstance(item, str): + self.input_keys.append(item) + self.input_dict[item] = value + else: + raise ValueError("Only integer and string indexed writes allowed.") + + def keys(self): + return self.input_keys + + def __len__(self): + return len(self.input_keys) + + def __iter__(self): + self.n = 0 + return self + + def __next__(self): + if self.n < len(self.input_keys): + output = self.input_dict[self.input_keys[self.n]] + self.n += 1 + return 
output + + raise StopIteration + + +def deal_with_input_convert( + node_input, node_input_shape, node_input_dtype, node_path, _nodes, _input_path_2_name +): + """deal with input convert in oneflow.""" + if node_input not in _nodes: + if ( + node_path not in _input_path_2_name + or "-input_" in node_input + or "FreeEagerTensor" in node_input + ): + _nodes[node_input] = new_var( + node_input, + shape=node_input_shape, + dtype=node_input_dtype, + ) + else: + names = _input_path_2_name[node_path] + node_replace = None + for k in names: + if k in _nodes: + node_replace = k + if node_replace is not None: + op_replace = copy.deepcopy(_nodes[node_replace]) + _nodes[node_input] = op_replace + else: + print("{} will not be in _nodes".format(node_input)) + + +def deal_parameter_convert( + node_input_paths, model_dir_path, _input_path_2_name, _model_array, _params, _nodes +): + """deal with parameter(weight) convert in oneflow.""" + for node_input_path in node_input_paths: + node_path = os.path.join(model_dir_path, node_input_path.replace("m.", "")) + node_input_name = node_input_path.split("/")[0] + _input_path_2_name[node_path] = node_input_name + for param_name in _model_array: + node_p = _model_array[param_name] + if node_path == node_p["path"]: + node_array = node_p["params"] + _params[node_input_name] = node_array + _nodes[node_input_name] = new_var( + node_input_name, shape=node_array.shape, dtype=str(node_array.dtype) + ) + break + + +class OneflowGraph(object): + """ + A helper class for handling Relay expression + + Parameters + ---------- + shape : dict of str to tuple, optional + The input shape to the graph + dtype : dict of str to str + The input types to the graph + + node name: + 1. param: m.layer4.1.bn1.weight / ... + 2. buffer: m.layer4.1.bn1.running_mean / ... + 3. node inputs: m.layer4.1.bn1-input_0 + 4. node outputs: m.layer4.1.bn1-output_0 + """ + + def __init__(self, shape, dtype, nodes, model_dir_path): + self._nodes = {} + self._params = {} + self._inputs = {} + self._num_input = 0 + self._num_param = 0 + self._input_names = [] + self._model_array = {} + self._input_path_2_name = {} + self._output_path_2_name = {} + self._init_variable_node = [] + self._shape = shape + self._dtype = dtype + self._identity_list = [] + self._sort_inputs = {} + + import oneflow + + model = oneflow.load(model_dir_path) + # model_array: keys: layer_name,values: dict('path', 'params') + for layer_name in model: + layer = model[layer_name] + layer_node = {} + layer_node["path"] = os.path.join(model_dir_path, layer_name, "out") # get path + if "System-Train" in layer_name: + continue + node_name = "m." 
+ layer_name + shape = self._shape[node_name] + dtype = self._dtype[node_name] + array = layer.detach().cpu().numpy() + layer_node["params"] = array.reshape(shape) + self._model_array[layer_name] = layer_node + + for node_name in nodes: + node = nodes[node_name] + if is_user_op(node): + for input_name in node.user_conf.input: + node_input_paths = getattr(node.user_conf.input[input_name], "s") + deal_parameter_convert( + node_input_paths, + model_dir_path, + self._input_path_2_name, + self._model_array, + self._params, + self._nodes, + ) + for output_name in node.user_conf.output: + node_output_paths = getattr(node.user_conf.output[output_name], "s") + for node_output_path in node_output_paths: + node_path = os.path.join(model_dir_path, node_output_path.replace("m.", "")) + node_output_name = node_output_path.split("/")[0] + self._output_path_2_name[node_path] = node_output_name + elif is_output_op(node): + node_output_path = getattr(node.output_conf, "in") + output_path = os.path.join( + model_dir_path, getattr(node.output_conf, "in").replace("m.", "") + ) + self._output_path_2_name[output_path] = node_name + elif is_param_op(node): + if "FreeEagerTensor" in node.name: + shape = tuple(node.variable_conf.shape.dim) + dtype = FLOW_2_STR_DTYPE[node.variable_conf.data_type] + self._shape[node.name] = shape + self._dtype[node.name] = dtype + self._init_variable_node.append(node.name) + if self._init_variable_node != []: + print("{} should be defined by user".format(self._init_variable_node)) + + def _parse_input(self, node, model_dir_path): + for input_name in node.user_conf.input: + node_input_paths = getattr(node.user_conf.input[input_name], "s") + for i in node_input_paths: + node_input = i.split("/")[0] + node_input_shape = self._shape[node_input] + node_input_dtype = self._dtype[node_input] + node_path = os.path.join(model_dir_path, i.replace("m.", "")) + deal_with_input_convert( + node_input, + node_input_shape, + node_input_dtype, + node_path, + self._nodes, + self._input_path_2_name, + ) + + def _parse_output(self, op_name, outputs, cnt_init=0): + """ + o: m.classifier.1-output_xxx + new_o: m.classifier.1-conv2d_0 + "_"+new_o is in self._shape + """ + for o in outputs: + if "-output_" not in o: + new_o = o.replace("-" + op_name, "-output") + new_o = new_o.replace("_" + new_o.split("_")[-1], "_0") + self._shape[o] = self._shape["_" + new_o] + self._dtype[o] = self._dtype["_" + new_o] + elif len(outputs) > 1: + outputs.remove(o) + if op_name.lower() == "dropout": + if len(outputs) == 1: + return outputs + outputs = outputs[:-1] + elif op_name.lower() == "constant": + outputs = [self._init_variable_node[cnt_init]] + + if len(outputs) > 1: + outputs = list(set(outputs)) + + return outputs + + def from_oneflow(self, nodes, model_dir_path, freeze_params=True, user_input=None): + """ + Parameters + ---------- + nodes : dict, keys: node.name, value: node + contain the graph + model_dir_path: str + The path of parameter + freeze_params: bool + If freeze_params is True, + the computational graph input is the input of the first layer of the network, + which cannot be specified by the user, e.g. 
+ Default input is: %v_ResNetGraph_0-input_0: Tensor[(1, 3, 224, 224), float32] + User-defined input is: %_0-input_0: Tensor[(1, 3, 640, 480), float32] + If freeze_params is on, then conv1-in will be the graph input, not Input_0 + user_input: dict + User-defined input information for the graph + { + node1_name: + { + 'name': node1_name, # str, like "%v_ResNetGraph_0-input_0" + 'shape': node1_shape, # tuple + 'dtype': node1_dtype # str, like "float32" + } + ... + } + We recommend that users specify the input by specifying the job function, + rather than by this function + + Returns + ------- + mod : tvm.IRModule + The returned relay module + params : dict + A dict of name: tvm.nd.array pairs, used as pretrained weights + """ + # step 1: get the graph input + if not freeze_params: + for node_init_name in user_input: + if "-input_" not in node_init_name: + raise KeyError( + "user_input['name'] should contain '-input_' " + + "to let program know that this is input node" + ) + self._nodes[node_init_name] = new_var( + node_init_name, + shape=user_input[node_init_name]["shape"], + dtype=user_input[node_init_name]["dtype"], + ) + self._inputs[node_init_name] = self._nodes[node_init_name] + + # step 2: find out if unsupported ops are used + convert_map = get_convert_map() + unsupported_ops = set() + for node_name in nodes: + node = nodes[node_name] + if is_user_op(node): + # op names, not the layer names + op_name = node.user_conf.op_type_name + if ( + op_name not in convert_map + and "constant" not in op_name + and op_name not in self._identity_list + ): + unsupported_ops.add(op_name) + # find out the unsupported op + if unsupported_ops: + msg = "The following operators are not supported for frontend OneFlow: " + msg += ", ".join(unsupported_ops) + raise tvm.error.OpNotImplemented(msg) + + # step 3: convert op + for node_name in nodes: + node = nodes[node_name] + if is_user_op(node): + # If there is a user-defined node, skip the following steps + if node_name in self._inputs: + continue + + op_name = node.user_conf.op_type_name + op_attr = parse_attr(node.user_conf.attr) + + self._parse_input(node, model_dir_path=model_dir_path) + + node_inputs = oneflow_input() + for input_name in node.user_conf.input: + node_input_paths = getattr(node.user_conf.input[input_name], "s") + for i in node_input_paths: + node_input = i.split("/")[0] + node_inputs[node_input] = self._nodes[node_input] + + node_outputs = [] + for output_name in node.user_conf.output: + node_output_paths = getattr(node.user_conf.output[output_name], "s") + for i in node_output_paths: + node_output_path = os.path.join(model_dir_path, i.replace("m.", "")) + if node_output_path in self._input_path_2_name: + node_outputs.append(self._input_path_2_name[node_output_path]) + elif node_output_path in self._output_path_2_name: + node_outputs.append(self._output_path_2_name[node_output_path]) + node_outputs = self._parse_output(op_name, node_outputs) + + # convert + op = self._convert_operator(op_name, node_inputs, op_attr) + + if not isinstance(op, _expr.TupleWrapper): + outputs_num = 1 + else: + outputs_num = len(op) + + assert ( + len(node_outputs) == outputs_num + ), "Number of output mismatch {} vs {} in {}.".format( + len(node_outputs), outputs_num, op_name + ) + + if outputs_num == 1: + op = fold_constant(op) + else: + op = _expr.TupleWrapper(fold_constant(op.astuple()), len(op)) + + op_temp = [] + op_temp.append(op) + for i, _ in enumerate(node_outputs): + if isinstance(node_outputs[i], list): + for k in node_outputs[i]: + self._nodes[k] = 
op_temp[i] + else: + self._nodes[node_outputs[i]] = op_temp[i] + + # step 4: get the outputs + outputs = [] + for node_name in nodes: + node = nodes[node_name] + if is_output_op(node): + node_name_v2 = getattr(node.output_conf, "in").split("/")[0] + if node_name in self._nodes: + outputs.append(self._nodes[node_name]) + elif node_name_v2 in self._nodes: + outputs.append(self._nodes[node_name_v2]) + outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) + + # step 5: get the relay IR + free_vars = analysis.free_vars(outputs) + + nodes = {v: k for k, v in self._nodes.items()} + free_vars = [nodes[var] for var in free_vars] + + # step 6: make sure the '-input_0' is the first in self._inputs + for free_var in free_vars: + if free_var not in self._inputs: + self._inputs[free_var] = self._nodes[free_var] + + input_names = list(self._inputs.keys()) + for i, _ in enumerate(input_names): + if i != 0 and "-input_0" in input_names[i]: + str_buffer = copy.deepcopy(input_names[i]) + del input_names[i] + input_names.insert(0, str_buffer) + break + + for input_name in input_names: + if input_name in self._inputs: + self._sort_inputs[input_name] = self._inputs[input_name] + else: + raise IndexError("{} is not in self._inputs".format(input_name)) + + # step 7: create a function from our output expression and all input variables. + func = _function.Function([v for _, v in self._sort_inputs.items()], outputs) + + return IRModule.from_expr(func), self._params + + def _convert_operator(self, op_name, node_inputs, op_attr): + """ + Parameters + ---------- + op_name : str + Operator name, such as conv2d、relu + node_inputs : list of tvm.relay.function.Function + List of inputs. + op_attr : dict + Dict of operator attributes + + Returns + ------- + sym : tvm.relay.function.Function + Converted relay function + """ + convert_map = get_convert_map() + if op_name in self._identity_list: + sym = get_relay_op(op_name)(*node_inputs, **op_attr) + elif op_name in convert_map: + sym = convert_map[op_name](node_inputs, op_attr, self._params) + else: + raise NotImplementedError("Operator {} not implemented.".format(op_name)) + + return sym + + +def from_oneflow(graph, model_dir_path, freeze_params=True, user_input=None): + """ + see OneflowGraph.from_oneflow + """ + try: + import oneflow as flow + except ImportError: + raise ImportError("please check that OneFlow is installed") + + if not freeze_params and user_input is None: + raise ValueError("if you want to specify graph input, please give the 'user_input'") + if freeze_params and user_input is not None: + warnings.warn("'user_input' will not work, please check the 'freeze_params'") + + # get info of nodes + shape = {} + dtype = {} + graph_str = repr(graph) + size_where = 2 + if "cuda" in graph_str: + size_where = 3 + + p_size = re.compile(r"size=\(.*?\)", re.S) + p_type = re.compile(r"dtype=.*?\)", re.S) + types = ["INPUT", "PARAMETER", "BUFFER", "OUTPUT"] + for t in types: + data = re.finditer(t + ":.*", graph_str) + for i in data: + attrs = i.group().split(":") + size_str = re.findall(p_size, attrs[size_where]) + type_str = re.findall(p_type, attrs[size_where]) + assert size_str != [], "size should not be None, please check your repr(graph)" + + size_attr = size_str[0].replace("size=", "") + if size_attr[-2] == ",": + size_attr = size_attr.replace(",", "") + data_size = tuple(map(int, size_attr[1:-1].split(", "))) + node_name = attrs[1] + shape[node_name] = data_size + dtype[node_name] = "float32" + + if type_str != []: + type_attr = 
type_str[0].replace("dtype=", "").replace(")", "") + if type_attr[-1] == ",": + type_attr = type_attr.replace(",", "") + dtype[node_name] = type_attr.replace("oneflow.", "") + + # get graph proto, if you don't _compile the graph, the _graph_proto will be None + graph_input = re.search(r"INPUT:.*", graph_str).group().split(":") + shape_input = tuple( + map( + int, + re.findall(p_size, graph_input[size_where])[0].replace("size=", "")[1:-1].split(", "), + ) + ) + if not graph._is_compiled: + graph._compile(flow.rand(shape_input)) + graph_proto = graph._graph_proto + + # get all nodes + nodes = {} + for op in graph_proto.net.op: + nodes[op.name] = op + + g = OneflowGraph(shape, dtype, nodes, model_dir_path) + + # Use the graph proto as a scope so that ops can access other nodes if needed. + mod, params = g.from_oneflow( + nodes=nodes, + model_dir_path=model_dir_path, + freeze_params=freeze_params, + user_input=user_input, + ) + + return mod, params diff --git a/tests/python/frontend/oneflow/test_forward.py b/tests/python/frontend/oneflow/test_forward.py new file mode 100644 index 0000000000000..d144cdad2bc52 --- /dev/null +++ b/tests/python/frontend/oneflow/test_forward.py @@ -0,0 +1,723 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
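+
+# A sketch of the end-to-end frontend usage these tests exercise; the model,
+# save path and input shape below are illustrative placeholders, not fixtures
+# defined in this file:
+#
+#     graph = OneFlowGraph(model)                # flow.nn.Graph wrapper, defined below
+#     graph._compile(flow.rand(1, 3, 224, 224))  # trace the graph once
+#     flow.save(model.state_dict(), "test_model")
+#     mod, params = relay.frontend.from_oneflow(graph, "test_model")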
+# pylint: disable=import-self, invalid-name +# pylint: disable=arguments-differ, unused-argument, unused-import +"""Unit tests for various models and operators""" +import os +import sys + +import numpy as np +import pytest +import tvm +import tvm.testing +import tvm.topi.testing +from tvm import relay +from tvm.contrib import graph_executor + +import oneflow as flow + +MODEL_HOME = "test_model" + + +def mkdir(path): + # init + path = path.strip() + path = path.rstrip("\\") + + if not os.path.exists(path): + os.makedirs(path) + else: + print("{} is already here".format(path)) + + +def rmdir(path): + for root, dirs, files in os.walk(path, topdown=False): + for name in files: + os.remove(os.path.join(root, name)) + for name in dirs: + os.rmdir(os.path.join(root, name)) + os.removedirs(path) + + +def assert_shape(out1, out2): + if out1.shape != out2.shape: + msg = "Output shapes {} and {} don't match" + raise AssertionError(msg.format(out1.shape, out2.shape)) + + +class OneFlowGraph(flow.nn.Graph): + def __init__(self, module): + super().__init__() + self.m = module + + def build(self, x): + out = self.m(x) + return out + + +class OneFlowGraph_v2(flow.nn.Graph): + def __init__(self, module): + super().__init__() + self.m = module + + def build(self, x1, x2, x3): + out = self.m(x1, x2, x3) + return out + + +def get_oneflow_output(model, inputs): + flow_output = model(inputs) + return flow_output.numpy() + + +def get_oneflow_concat_output(model, input1, input2, input3): + flow_output = model(input1, input2, input3).numpy() + return flow_output + + +def get_tvm_output(graph, model_path, inputs: flow.tensor, target="llvm", dtype="float32"): + inputs_numpy = inputs.numpy() + if target == "llvm": + device = tvm.cpu(0) + elif target == "cuda": + device = tvm.cuda(0) + + mod, params = relay.frontend.from_oneflow(graph, model_path) + with tvm.transform.PassContext(opt_level=10): + intrp = relay.build_module.create_executor("graph", mod, device, target) + tvm_output = intrp.evaluate()(tvm.nd.array(inputs_numpy.astype(dtype)), **params).numpy() + return tvm_output + + +def get_tvm_concat_output( + graph, + model_path, + input1: flow.tensor, + input2: flow.tensor, + input3: flow.tensor, + target="llvm", + dtype="float32", +): + input1_numpy = input1.numpy() + input2_numpy = input2.numpy() + input3_numpy = input3.numpy() + if target == "llvm": + device = tvm.cpu(0) + elif target == "cuda": + device = tvm.cuda(0) + + mod, params = relay.frontend.from_oneflow(graph, model_path) + with tvm.transform.PassContext(opt_level=10): + intrp = relay.build_module.create_executor("graph", mod, device, target) + tvm_output = intrp.evaluate()( + tvm.nd.array(input1_numpy.astype(dtype)), + tvm.nd.array(input2_numpy.astype(dtype)), + tvm.nd.array(input3_numpy.astype(dtype)), + **params, + ).numpy() + return tvm_output + + +def verify_conv( + model, + name="", + rtol=1e-5, + atol=1e-5, + inputs=flow.tensor( + np.random.rand(1, 3, 224, 224), + dtype=flow.float32, + ), + device="llvm", +): + if device == "cuda": + model.to(device) + inputs = inputs.to(device) + + graph = OneFlowGraph(model) + graph._compile(inputs) + + mkdir(MODEL_HOME) + flow.save(model.state_dict(), MODEL_HOME) + + out_flow = get_oneflow_output(graph, inputs) + out_tvm = get_tvm_output(graph, MODEL_HOME, inputs, target=device) + rmdir(MODEL_HOME) + + assert_shape(out_flow, out_tvm) + tvm.testing.assert_allclose(out_flow, out_tvm, rtol=rtol, atol=atol) + + +def verify_pool( + model, + name="", + rtol=1e-5, + atol=1e-5, + inputs=flow.tensor( + 
np.random.rand(1, 3, 224, 224), + dtype=flow.float32, + ), + device="llvm", +): + if device == "cuda": + model.to(device) + inputs = inputs.to(device) + + graph = OneFlowGraph(model) + graph._compile(inputs) + + mkdir(MODEL_HOME) + flow.save(model.state_dict(), MODEL_HOME) + + out_flow = get_oneflow_output(graph, inputs) + out_tvm = get_tvm_output(graph, MODEL_HOME, inputs, target=device) + rmdir(MODEL_HOME) + + assert_shape(out_flow, out_tvm) + tvm.testing.assert_allclose(out_flow, out_tvm, rtol=rtol, atol=atol) + + +def verify_normalization( + model, + name="", + rtol=1e-5, + atol=1e-5, + inputs=flow.tensor( + np.random.rand(1, 3, 224, 224), + dtype=flow.float32, + ), + device="llvm", +): + if device == "cuda": + model.to(device) + inputs = inputs.to(device) + + graph = OneFlowGraph(model) + graph._compile(inputs) + + # write params + mkdir(MODEL_HOME) + flow.save(model.state_dict(), MODEL_HOME) + + out_flow = get_oneflow_output(graph, inputs) + out_tvm = get_tvm_output(graph, MODEL_HOME, inputs, target=device) + rmdir(MODEL_HOME) + + assert_shape(out_flow, out_tvm) + tvm.testing.assert_allclose(out_flow, out_tvm, rtol=rtol, atol=atol) + + +def verify_upsample( + model, + name="", + rtol=1e-5, + atol=1e-5, + inputs=flow.tensor( + np.random.rand(1, 3, 50, 50), + dtype=flow.float32, + ), + device="llvm", +): + if device == "cuda": + model.to(device) + inputs = inputs.to(device) + + graph = OneFlowGraph(model) + graph._compile(inputs) + + mkdir(MODEL_HOME) + flow.save(model.state_dict(), MODEL_HOME) + + out_flow = get_oneflow_output(graph, inputs) + out_tvm = get_tvm_output(graph, MODEL_HOME, inputs, target=device) + rmdir(MODEL_HOME) + + assert_shape(out_flow, out_tvm) + tvm.testing.assert_allclose(out_flow, out_tvm, rtol=rtol, atol=atol) + + +def verify_convtran( + model, + name="", + rtol=1e-5, + atol=1e-5, + inputs=flow.tensor( + np.random.rand(1, 3, 50, 50), + dtype=flow.float32, + ), + device="llvm", +): + if device == "cuda": + model.to(device) + inputs = inputs.to(device) + + graph = OneFlowGraph(model) + graph._compile(inputs) + + mkdir(MODEL_HOME) + flow.save(model.state_dict(), MODEL_HOME) + + out_flow = get_oneflow_output(graph, inputs) + out_tvm = get_tvm_output(graph, MODEL_HOME, inputs, target=device) + rmdir(MODEL_HOME) + + assert_shape(out_flow, out_tvm) + tvm.testing.assert_allclose(out_flow, out_tvm, rtol=rtol, atol=atol) + + +def verify_activation( + model, + name="", + rtol=1e-5, + atol=1e-5, + inputs=flow.tensor( + np.random.rand(10, 10), + dtype=flow.float32, + ), + device="llvm", +): + if device == "cuda": + model.to(device) + inputs = inputs.to(device) + + graph = OneFlowGraph(model) + graph._compile(inputs) + + mkdir(MODEL_HOME) + flow.save(model.state_dict(), MODEL_HOME) + + out_flow = get_oneflow_output(graph, inputs) + out_tvm = get_tvm_output(graph, MODEL_HOME, inputs, target=device) + rmdir(MODEL_HOME) + + assert_shape(out_flow, out_tvm) + tvm.testing.assert_allclose(out_flow, out_tvm, rtol=rtol, atol=atol) + + +def verify_math( + model, + name="", + rtol=1e-5, + atol=1e-5, + inputs=flow.tensor( + np.random.rand(100, 1), + dtype=flow.float32, + ), + device="llvm", +): + if device == "cuda": + model.to(device) + inputs = inputs.to(device) + + graph = OneFlowGraph(model) + graph._compile(inputs) + + mkdir(MODEL_HOME) + flow.save(model.state_dict(), MODEL_HOME) + + out_flow = get_oneflow_output(graph, inputs) + out_tvm = get_tvm_output(graph, MODEL_HOME, inputs, target=device) + rmdir(MODEL_HOME) + + assert_shape(out_flow, out_tvm) + 
tvm.testing.assert_allclose(out_flow, out_tvm, rtol=rtol, atol=atol) + + +def verify_concat( + model, + name="", + rtol=1e-5, + atol=1e-5, + inputs1=flow.tensor(np.random.randn(2, 5, 5, 4), dtype=flow.float32), + inputs2=flow.tensor(np.random.randn(2, 5, 5, 2), dtype=flow.float32), + inputs3=flow.tensor(np.random.randn(2, 5, 5, 3), dtype=flow.float32), + device="llvm", +): + if device == "cuda": + model.to(device) + inputs1 = inputs1.to(device) + inputs2 = inputs2.to(device) + inputs3 = inputs3.to(device) + + graph = OneFlowGraph_v2(model) + graph._compile(inputs1, inputs2, inputs3) + + mkdir(MODEL_HOME) + flow.save(model.state_dict(), MODEL_HOME) + + out_flow = get_oneflow_concat_output(graph, inputs1, inputs2, inputs3) + out_tvm = get_tvm_concat_output(graph, MODEL_HOME, inputs1, inputs2, inputs3, target=device) + rmdir(MODEL_HOME) + + assert_shape(out_flow, out_tvm) + tvm.testing.assert_allclose(out_flow, out_tvm, rtol=rtol, atol=atol) + + +# defs/nn +@tvm.testing.uses_gpu +def test_conv2d(): + class Conv2dModel(flow.nn.Module): + def __init__(self): + super().__init__() + self.conv = flow.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + x = self.conv(x) + return x + + if os.path.exists(MODEL_HOME): + rmdir(MODEL_HOME) + + model = Conv2dModel() + model.eval() + + for device in ["llvm"]: + verify_conv(model, device=device) + + +@tvm.testing.uses_gpu +def test_pool2d(): + class MaxPool2dModel(flow.nn.Module): + def __init__(self): + super().__init__() + self.pool = flow.nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + def forward(self, x): + x = self.pool(x) + return x + + class AvgPool2dModel(flow.nn.Module): + def __init__(self): + super().__init__() + self.pool = flow.nn.AvgPool2d(kernel_size=3, stride=2, padding=1) + + def forward(self, x): + x = self.pool(x) + return x + + class AdaptiveAvgPool2dModel(flow.nn.Module): + def __init__(self): + super().__init__() + self.pool = flow.nn.AdaptiveAvgPool2d((None, 7)) + + def forward(self, x): + x = self.pool(x) + return x + + if os.path.exists(MODEL_HOME): + rmdir(MODEL_HOME) + + model1 = MaxPool2dModel().eval() + model2 = AvgPool2dModel().eval() + model3 = AdaptiveAvgPool2dModel().eval() + + for device in ["llvm"]: + verify_pool(model1, device=device) + verify_pool(model2, device=device) + verify_pool(model3, device=device) + + +@tvm.testing.uses_gpu +def test_normalization(): + class BatchNorm2dModel(flow.nn.Module): + def __init__(self): + super().__init__() + self.normalization = flow.nn.BatchNorm2d(3) + + def forward(self, x): + x = self.normalization(x) + return x + + if os.path.exists(MODEL_HOME): + rmdir(MODEL_HOME) + + model = BatchNorm2dModel().eval() + + for device in ["llvm"]: + verify_normalization(model, device=device) + + +@tvm.testing.uses_gpu +def test_upsample(): + class UpsampleModel(flow.nn.Module): + def __init__(self): + super().__init__() + self.upsample = flow.nn.Upsample(scale_factor=2.0, mode="nearest") + + def forward(self, x): + x = self.upsample(x) + return x + + class UpsampleBiliModel(flow.nn.Module): + def __init__(self): + super().__init__() + self.upsample = flow.nn.UpsamplingBilinear2d(scale_factor=2.0) + + def forward(self, x): + x = self.upsample(x) + return x + + if os.path.exists(MODEL_HOME): + rmdir(MODEL_HOME) + + model1 = UpsampleModel().eval() + model2 = UpsampleBiliModel().eval() + + for device in ["llvm"]: + verify_upsample(model1, device=device) + verify_upsample(model2, device=device) + + +@tvm.testing.uses_gpu +def test_convtran(): + class 
ConvTranModel(flow.nn.Module): + def __init__(self): + super().__init__() + self.convtran = flow.nn.ConvTranspose2d(3, 4, (3, 5), stride=(2, 1), padding=(4, 2)) + + def forward(self, x): + x = self.convtran(x) + return x + + if os.path.exists(MODEL_HOME): + rmdir(MODEL_HOME) + + model = ConvTranModel().eval() + + for device in ["llvm"]: + verify_convtran(model, device=device) + + +@tvm.testing.uses_gpu +def test_activation(): + class Softmax(flow.nn.Module): + def __init__(self): + super().__init__() + self.active = flow.nn.Softmax() + + def forward(self, x): + x = self.active(x) + return x + + class Softplus(flow.nn.Module): + def __init__(self): + super().__init__() + self.active = flow.nn.Softplus() + + def forward(self, x): + x = self.active(x) + return x + + class Softsign(flow.nn.Module): + def __init__(self): + super().__init__() + self.active = flow.nn.Softsign() + + def forward(self, x): + x = self.active(x) + return x + + class Tanh(flow.nn.Module): + def __init__(self): + super().__init__() + self.active = flow.nn.Tanh() + + def forward(self, x): + x = self.active(x) + return x + + class ReLU(flow.nn.Module): + def __init__(self): + super().__init__() + self.active = flow.nn.ReLU() + + def forward(self, x): + x = self.active(x) + return x + + class ReLU6(flow.nn.Module): + def __init__(self): + super().__init__() + self.active = flow.nn.ReLU6() + + def forward(self, x): + x = self.active(x) + return x + + class PReLU(flow.nn.Module): + def __init__(self): + super().__init__() + self.active = flow.nn.PReLU() + + def forward(self, x): + x = self.active(x) + return x + + class SELU(flow.nn.Module): + def __init__(self): + super().__init__() + self.active = flow.nn.SELU() + + def forward(self, x): + x = self.active(x) + return x + + class SiLU(flow.nn.Module): + def __init__(self): + super().__init__() + self.active = flow.nn.SiLU() + + def forward(self, x): + x = self.active(x) + return x + + class LeakyReLU(flow.nn.Module): + def __init__(self): + super().__init__() + self.active = flow.nn.LeakyReLU(0.1) + + def forward(self, x): + x = self.active(x) + return x + + class GELU(flow.nn.Module): + def __init__(self): + super().__init__() + self.active = flow.nn.GELU() + + def forward(self, x): + x = self.active(x) + return x + + if os.path.exists(MODEL_HOME): + rmdir(MODEL_HOME) + + model1 = Softmax().eval() + model2 = Softplus().eval() + model3 = Softsign().eval() + model4 = Tanh().eval() + model5 = ReLU().eval() + model6 = ReLU6().eval() + model7 = PReLU().eval() + model8 = SELU().eval() + model9 = SiLU().eval() + model10 = LeakyReLU().eval() + model11 = GELU().eval() + + for device in ["llvm"]: + verify_activation(model1, device=device) + # verify_activation(model2, device=device) # NO PASS + verify_activation(model3, device=device) + verify_activation(model4, device=device) + verify_activation(model5, device=device) + verify_activation(model6, device=device) + verify_activation(model7, device=device) + verify_activation(model8, device=device) + verify_activation(model9, device=device) + verify_activation(model10, device=device) + verify_activation(model11, device=device) + + +@tvm.testing.uses_gpu +def test_math(): + class Sigmoid(flow.nn.Module): + def forward(self, x): + return flow.sigmoid(x) + + class Sign(flow.nn.Module): + def forward(self, x): + return flow.sign(x) + + class Reciprocal(flow.nn.Module): + def forward(self, x): + return flow.reciprocal(x) + + class Pow(flow.nn.Module): + def forward(self, x): + return flow.pow(x, 2.0) + + class Log(flow.nn.Module): + def 
forward(self, x): + return flow.log(x) + + class Log2(flow.nn.Module): + def forward(self, x): + return flow.log1p(x) + + class Exp(flow.nn.Module): + def forward(self, x): + return flow.exp(x) + + class Exp2(flow.nn.Module): + def forward(self, x): + return flow.expm1(x) + + model1 = Sigmoid().eval() + model2 = Sign().eval() + model3 = Log().eval() + model4 = Log2().eval() + model5 = Exp().eval() + model6 = Exp2().eval() + + for device in ["llvm"]: + verify_math(model1, device=device) + verify_math(model2, device=device) + verify_math(model3, device=device) + verify_math(model4, device=device) + verify_math(model5, device=device) + verify_math(model6, device=device) + + +@tvm.testing.uses_gpu +def test_slice(): + class Slice(flow.nn.Module): + def forward(self, x): + tup_list = [[None, None, None], [0, 5, 2], [0, 6, 3]] + out = flow.slice(x, slice_tup_list=tup_list) + return out + + model = Slice().eval() + + for device in ["llvm"]: + verify_math( + model, device=device, inputs=flow.tensor(np.random.randn(3, 6, 9).astype(np.float32)) + ) + + +@tvm.testing.uses_gpu +def test_concat(): + class Concat(flow.nn.Module): + def forward(self, x1, x2, x3): + out = flow.cat([x1, x2, x3], dim=-1) + return out + + model = Concat().eval() + + for device in ["llvm"]: + verify_concat(model, device=device) + + +if __name__ == "__main__": + test_conv2d() + test_pool2d() + test_normalization() + test_upsample() + test_convtran() + test_activation() + test_math() + test_slice() + test_concat() + rmdir("log") diff --git a/tests/scripts/task_python_frontend.sh b/tests/scripts/task_python_frontend.sh index bbcba37c6d012..2c7e34fac592f 100755 --- a/tests/scripts/task_python_frontend.sh +++ b/tests/scripts/task_python_frontend.sh @@ -58,3 +58,6 @@ run_pytest cython python-frontend-paddlepaddle tests/python/frontend/paddlepaddl echo "Running relay CoreML frontend test..." run_pytest cython python-frontend-coreml tests/python/frontend/coreml + +echo "Running relay OneFlow frontend test..." +run_pytest cython python-frontend-oneflow tests/python/frontend/oneflow