From d5c40802221d83f576425ec8112b290f5ca9d3c1 Mon Sep 17 00:00:00 2001 From: joddiy Date: Wed, 20 May 2020 04:34:54 +0800 Subject: [PATCH] refactor soonx backend --- python/singa/sonnx.py | 1630 +++++++++++++++++++---------------------- 1 file changed, 740 insertions(+), 890 deletions(-) diff --git a/python/singa/sonnx.py b/python/singa/sonnx.py index 69ec9867a..a4f411533 100755 --- a/python/singa/sonnx.py +++ b/python/singa/sonnx.py @@ -28,8 +28,9 @@ import warnings from . import singa_wrap as singa -from . import autograd +from . import autograd, layer from . import tensor +from . import module from singa import utils import collections @@ -829,7 +830,7 @@ def handle_special_ops(cls, op, X, W): @classmethod def _common_singa_tensor_to_onnx_node(cls, op, op_t): """ - get a onnx node from a singa operator, prepare its type, inputs and outputs + get a onnx node from singa operator, prepare its type, inputs and outputs Args: op: a given operator Args: @@ -963,14 +964,22 @@ def __init__(self, node): self.name = str(node.name) self.op_type = str(node.op_type) self.attrs = OnnxAttributes.from_onnx(node.attribute) - # there may some inputs which we regard as attribute, so we mark them there - self.consumed_inputs = list() + # inputs as attributes in singa + self.attr_inputs = {} + # inputs as weights in singa + self.weight_inputs = {} self.inputs = list(node.input) self.outputs = list(node.output) def getattr(self, key, default=None): return self.attrs[key] if key in self.attrs else default + def set_attr_inputs(self, key, name): + self.attr_inputs[key] = name + + def set_weight_inputs(self, key, name): + self.weight_inputs[key] = name + class OnnxAttributes(dict): """ @@ -989,1084 +998,867 @@ def from_onnx(args): class SingaBackend(Backend): # This number indicates the onnx operator set version - _known_opset_version = 11 + _opset_version = 11 + + _ir_version = 0x0000000000000006 # beceuase singa's operators are different from onnx. 
# we define a dict for the name projection _rename_operators = { - 'Relu': 'relu', - 'Softmax': 'SoftMax', - 'Sigmoid': 'sigmoid', - 'Add': 'add', - 'MatMul': 'matmul', - 'Conv': '_Conv2d', - 'MaxPool': '_Pooling2d', - 'AveragePool': '_Pooling2d', - 'BatchNormalization': 'batchnorm_2d', - 'Concat': 'Concat', - 'Flatten': 'Flatten', - 'Gemm': 'Gemm', - 'Reshape': 'Reshape', - 'Sum': 'sum', - 'Cos': 'cos', - 'Cosh': 'cosh', - 'Sin': 'sin', - 'Sinh': 'sinh', - 'Tan': 'tan', - 'Tanh': 'tanh', - 'Acos': 'acos', - 'Acosh': 'acosh', - 'Asin': 'asin', - 'Asinh': 'asinh', - 'Atan': 'atan', - 'Atanh': 'atanh', - 'Selu': 'SeLU', - 'Elu': 'Elu', - 'Equal': 'equal', - 'Less': 'less', - 'Sign': 'sign', - 'Div': 'div', - 'Sub': 'sub', - 'Sqrt': 'sqrt', - 'Log': 'log', - 'Greater': 'greater', - 'HardSigmoid': 'HardSigmoid', - 'Identity': 'identity', - 'Softplus': 'softplus', - 'Softsign': 'softsign', - 'Mean': 'mean', - 'Pow': 'pow', - 'Clip': 'Clip', - 'PRelu': 'prelu', - 'Mul': 'mul', - 'Transpose': 'Transpose', - 'Max': 'max', - 'Min': 'min', - 'Shape': 'shape', - 'And': '_and', - 'Or': '_or', - 'Xor': '_xor', - 'Not': '_not', - 'Neg': 'negative', - 'Reciprocal': 'reciprocal', - 'ConstantOfShape': 'ConstantOfShape', - 'Dropout': 'Dropout', - 'ReduceSum': 'ReduceSum', - 'ReduceMean': 'ReduceMean', - 'LeakyRelu': 'LeakyRelu', - 'GlobalAveragePool': 'GlobalAveragePool', - 'Squeeze': 'Squeeze', + # common op + 'Relu': 'ReLU', + 'Sigmoid': 'Sigmoid', + 'Add': 'Add', + 'MatMul': 'MatMul', + 'Sum': 'Sum', + 'Cos': 'Cos', + 'Cosh': 'Cosh', + 'Sin': 'Sin', + 'Sinh': 'Sinh', + 'Tan': 'Tan', + 'Tanh': 'Tanh', + 'Acos': 'Acos', + 'Acosh': 'Acosh', + 'Asin': 'Asin', + 'Asinh': 'Asinh', + 'Atan': 'Atan', + 'Atanh': 'Atanh', + 'Equal': 'Equal', + 'Less': 'Less', + 'Sign': 'Sign', + 'Div': 'Div', + 'Sub': 'Sub', + 'Sqrt': 'Sqrt', + 'Log': 'Log', + 'Greater': 'Greater', + 'Identity': 'Identity', + 'Softplus': 'Softplus', + 'Softsign': 'Softsign', + 'Mean': 'Mean', + 'Pow': 'Pow', + 'PRelu': 'Prelu', + 'Mul': 'Mul', + 'Max': 'Max', + 'Min': 'Min', + 'Shape': 'Shape', + 'And': 'And', + 'Or': 'Or', + 'Xor': 'Xor', + 'Not': 'Not', + 'Neg': 'Negative', + 'Reciprocal': 'Reciprocal', 'Unsqueeze': 'Unsqueeze', - 'Slice': 'Slice', + 'NonZero': 'Nonzero', 'Ceil': 'Ceil', - 'Split': 'Split', - 'Gather': 'Gather', - 'Tile': 'Tile', - 'NonZero': 'nonzero', + # special op 'Cast': 'Cast', - 'OneHot': 'OneHot', + 'Split': 'Split', + 'Squeeze': 'Squeeze', + 'GlobalAveragePool': 'GlobalAveragePool', + 'LeakyRelu': 'LeakyRelu', + 'ReduceSum': 'ReduceSum', + 'ReduceMean': 'ReduceMean', + 'Dropout': 'Dropout', + 'ConstantOfShape': 'ConstantOfShape', + 'Transpose': 'Transpose', + 'HardSigmoid': 'HardSigmoid', + 'Elu': 'Elu', + 'Selu': 'SeLU', + 'Concat': 'Concat', + 'Softmax': 'SoftMax', + 'Gemm': 'Gemm', + 'Flatten': 'Flatten', + # 'OneHot': 'OneHot', + # 'Tile': 'Tile', + # 'Gather': 'Gather', + # 'Reshape': 'Reshape', + # 'Slice': 'Slice', + # 'Clip': 'Clip', + 'BatchNormalization': 'layer.BatchNorm2d', # layer + 'Conv': 'layer.Conv2d', # layer + 'MaxPool': 'layer.Pooling2d', # layer + 'AveragePool': 'layer.Pooling2d', # layer } # this dict indicates the operators that need extra handle # each indicates a function name _special_operators = { - 'Conv': '_create_conv', - 'MaxPool': '_create_max_avg_pool', - 'AveragePool': '_create_max_avg_pool', - 'BatchNormalization': '_create_batchnorm', + 'Cast': '_create_cast', + 'Split': '_create_split', + 'Squeeze': '_create_squeeze_unsqueeze', + 'Unsqueeze': '_create_squeeze_unsqueeze', + 
'GlobalAveragePool': '_create_global_average_pool', + 'LeakyRelu': '_create_leakyrelu', + 'ReduceSum': '_create_reduce_ops', + 'ReduceMean': '_create_reduce_ops', + 'Dropout': '_create_dropout', + 'ConstantOfShape': '_create_constant_of_shape', + 'Transpose': '_create_transpose', + 'HardSigmoid': '_create_hardsigmoid', + 'Elu': '_create_elu', + 'Selu': '_create_selu', 'Concat': '_create_concat', - 'Flatten': '_create_flatten', + 'Softmax': '_create_softmax', 'Gemm': '_create_gemm', + 'Flatten': '_create_flatten', + 'OneHot': '_create_onehot', + 'Tile': '_create_tile', + 'Gather': '_create_gather', 'Reshape': '_create_reshape', - 'Softmax': '_create_softmax', - 'Selu': '_create_selu', - 'Elu': '_create_elu', - 'HardSigmoid': '_create_hardsigmoid', - 'Clip': '_create_clip', - 'Transpose': '_create_transpose', - 'ConstantOfShape': '_create_constantOfShape', - 'Dropout': '_create_dropout', - 'ReduceSum': '_create_reduceOp', - 'ReduceMean': '_create_reduceOp', - 'LeakyRelu': '_create_leakyrelu', - 'GlobalAveragePool': '_create_globalaveragepool', - 'Squeeze': '_create_squeeze', - 'Unsqueeze': '_create_squeeze', 'Slice': '_create_slice', - 'Split': '_create_split', - 'Gather': '_create_gather', - 'Tile': '_create_tile', - 'Cast': '_create_cast', - 'OneHot': '_create_onehot', - 'Constant': "_create_constant" + 'Clip': '_create_clip', + 'BatchNormalization': '_create_batch_norm', + 'Conv': '_create_conv', + 'MaxPool': '_create_max_avg_pool', + 'AveragePool': '_create_max_avg_pool', + 'BatchNormalization': '_create_batch_norm', + 'Conv': '_create_conv', + 'MaxPool': '_create_max_avg_pool', + 'AveragePool': '_create_max_avg_pool', } - @classmethod - def _create_constant(cls, onnx_node, inputs, opset_version): - """ - parse onnx constatn node to weights - Args: - onnx_node: a given onnx node - Args: - inputs: the input tensor - Args: - opset_version: the opset version - Returns: - handle, the handle of singa operator - Returns: - forward, the autograd of singa operator - """ - tmp_tensor = onnx_node.getattr('value') - np_dtype = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[tmp_tensor.data_type] - np_tensor = np.frombuffer(tmp_tensor.raw_data, dtype=np_dtype) - if np_tensor.dtype == "int64": - np_tensor = np_tensor.astype(np.int32) - # todo, we cannot support scalar tensor - if np.ndim(np_tensor) == 0: - np_tensor = np.array(np_tensor, ndmin=1) - return None, np_tensor - - @classmethod - def _create_onehot(cls, onnx_node, inputs, opset_version): - """ - get the OneHot operator from onnx node - Args: - onnx_node: a given onnx node - Args: - inputs: the input tensor - Args: - opset_version: the opset version - Returns: - handle, the handle of singa operator - Returns: - forward, the autograd of singa operator - """ - axis = onnx_node.getattr("axis", -1) - # we move several inputs to singa's attribuates - # and mark them so we don't use them when we run this operator - depth = tensor.to_numpy(inputs.pop(1)).astype(np.int32) - value = tensor.to_numpy(inputs.pop(1)) - onnx_node.consumed_inputs.extend(onnx_node.inputs[1:]) - _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs, - opset_version) - return _, forward(axis, depth, value) + # singa only supports float32 and int32 + _type_map = { + TensorProto.FLOAT: tensor.float32, # FLOAT to float32 + TensorProto.UINT8: None, # UINT8 + TensorProto.INT8: tensor.int32, # INT8 to int32 + TensorProto.UINT16: None, # UINT16 + TensorProto.INT16: tensor.int32, # INT16 to int32 + TensorProto.INT32: tensor.int32, # INT32 to int32 + TensorProto.INT64: tensor.int32, # 
INT64 to int32 + TensorProto.STRING: None, # stirng + TensorProto.BOOL: None, # bool + } @classmethod - def _create_cast(cls, onnx_node, inputs, opset_version): + def _create_cast(cls, onnx_node, operator, opset_version=_opset_version): """ get the Cast operator from onnx node Args: - onnx_node: a given onnx node - Args: - inputs: the input tensor - Args: - opset_version: the opset version - Returns: - handle, the handle of singa operator - Returns: - forward, the autograd of singa operator - """ - to = onnx_node.getattr("to") - # singa only supports float32 and int32 - map_dict = { - TensorProto.FLOAT: tensor.float32, # FLOAT to float32 - TensorProto.UINT8: None, # UINT8 - TensorProto.INT8: tensor.int32, # INT8 to int32 - TensorProto.UINT16: None, # UINT16 - TensorProto.INT16: tensor.int32, # INT16 to int32 - TensorProto.INT32: tensor.int32, # INT32 to int32 - TensorProto.INT64: tensor.int32, # INT64 to int32 - TensorProto.STRING: None, # stirng - TensorProto.BOOL: None, # bool - } - to = map_dict[to] - assert to != None, "not support cast type: {}".format(to) - _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs, - opset_version) - return _, forward(to) - - @classmethod - def _create_tile(cls, onnx_node, inputs, opset_version): - """ - get the Tile operator from onnx node - Args: - onnx_node: a given onnx node - Args: - inputs: the input tensor - Args: - opset_version: the opset version - Returns: - handle, the handle of singa operator - Returns: - forward, the autograd of singa operator - """ - # we move several inputs to singa's attribuates - # and mark them so we don't use them when we run this operator - repeats = tensor.to_numpy(inputs.pop(1)).astype(np.int32).tolist() - onnx_node.consumed_inputs.append(onnx_node.inputs[1]) - _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs, - opset_version) - return _, forward(repeats) - - @classmethod - def _create_gather(cls, onnx_node, inputs, opset_version): - """ - get the Gather operator from onnx node - Args: - onnx_node: a given onnx node - Args: - inputs: the input tensor - Args: - opset_version: the opset version + onnx_node (OnnxNode): a given onnx node + operator (Operator Class): a singa operator class + opset_version (int): the opset version Returns: - handle, the handle of singa operator - Returns: - forward, the autograd of singa operator + singa operator instance """ - axis = onnx_node.getattr("axis", 0) - # we move several inputs to singa's attribuates - # and mark them so we don't use them when we run this operator - indices = tensor.to_numpy(inputs.pop(1)).astype(np.int32).tolist() - onnx_node.consumed_inputs.append(onnx_node.inputs[1]) - _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs, - opset_version) - return _, forward(axis, indices) + to_type = cls._type_map[onnx_node.getattr("to")] + assert to_type != None, "not support cast type: {}".format(to_type) + return operator(to_type) @classmethod - def _create_split(cls, onnx_node, inputs, opset_version): + def _create_split(cls, onnx_node, operator, opset_version=_opset_version): """ get the Split operator from onnx node Args: - onnx_node: a given onnx node - Args: - inputs: the input tensor - Args: - opset_version: the opset version - Returns: - handle, the handle of singa operator + onnx_node (OnnxNode): a given onnx node + operator (Operator Class): a singa operator class + opset_version (int): the opset version Returns: - forward, the autograd of singa operator + singa operator instance """ axis = onnx_node.getattr("axis", 0) 
split = onnx_node.getattr("split", None) num_output = len(onnx_node.outputs) - _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs, - opset_version) - return _, forward(axis, split, num_output) - - @classmethod - def _create_slice(cls, onnx_node, inputs, opset_version): - """ - get the Slice operator from onnx node - Args: - onnx_node: a given onnx node - Args: - inputs: the input tensor - Args: - opset_version: the opset version - Returns: - handle, the handle of singa operator - Returns: - forward, the autograd of singa operator - """ - # we move several inputs to singa's attribuates - # and mark them so we don't use them when we run this operator - starts = tensor.to_numpy(inputs.pop(1)).astype(np.int32).tolist() - ends = tensor.to_numpy(inputs.pop(1)).astype(np.int32).tolist() - # sometime onnx may ignore these two inputs, axes and step - if len(inputs) >= 2 and onnx_node.inputs[3] != '': - axes = tensor.to_numpy(inputs.pop(1)).astype(np.int32).tolist() - else: - axes = None - steps = tensor.to_numpy(inputs.pop(1)).astype( - np.int32).tolist() if len(inputs) >= 2 else None - onnx_node.consumed_inputs.extend(onnx_node.inputs[1:]) - _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs, - opset_version) - return _, forward(starts, ends, axes, steps) + return operator(axis, split, num_output) @classmethod - def _create_squeeze(cls, onnx_node, inputs, opset_version): + def _create_squeeze_unsqueeze(cls, + onnx_node, + operator, + opset_version=_opset_version): """ get the Squeeze and Unsqueeze operator from onnx node Args: - onnx_node: a given onnx node - Args: - inputs: the input tensor - Args: - opset_version: the opset version - Returns: - handle, the handle of singa operator + onnx_node (OnnxNode): a given onnx node + operator (Operator Class): a singa operator class + opset_version (int): the opset version Returns: - forward, the autograd of singa operator + singa operator instance """ axes = onnx_node.getattr("axes") - _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs, - opset_version) - return _, forward(axes) + return operator(axes) @classmethod - def _create_globalaveragepool(cls, onnx_node, inputs, opset_version): + def _create_global_average_pool(cls, + onnx_node, + operator, + opset_version=_opset_version): """ get the GlobalAveragePool operator from onnx node Args: - onnx_node: a given onnx node - Args: - inputs: the input tensor - Args: - opset_version: the opset version - Returns: - handle, the handle of singa operator + onnx_node (OnnxNode): a given onnx node + operator (Operator Class): a singa operator class + opset_version (int): the opset version Returns: - forward, the autograd of singa operator + singa operator instance """ data_format = onnx_node.getattr("data_format", 'channels_first') - _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs, - opset_version) - return _, forward(data_format) + return operator(data_format) @classmethod - def _create_leakyrelu(cls, onnx_node, inputs, opset_version): + def _create_leakyrelu(cls, + onnx_node, + operator, + opset_version=_opset_version): """ get the LeakyRelu operator from onnx node Args: - onnx_node: a given onnx node - Args: - inputs: the input tensor - Args: - opset_version: the opset version + onnx_node (OnnxNode): a given onnx node + operator (Operator Class): a singa operator class + opset_version (int): the opset version Returns: - handle, the handle of singa operator - Returns: - forward, the autograd of singa operator + singa operator instance """ alpha = 
onnx_node.getattr("alpha", 0.01) - _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs, - opset_version) - return _, forward(alpha) + return operator(alpha) @classmethod - def _create_reduceOp(cls, onnx_node, inputs, opset_version): + def _create_reduce_ops(cls, + onnx_node, + operator, + opset_version=_opset_version): """ get the ReduceSum, ReduceMean, ReduceMax, ReduceMin, etc, operator from onnx node Args: - onnx_node: a given onnx node - Args: - inputs: the input tensor - Args: - opset_version: the opset version + onnx_node (OnnxNode): a given onnx node + operator (Operator Class): a singa operator class + opset_version (int): the opset version Returns: - handle, the handle of singa operator - Returns: - forward, the autograd of singa operator + singa operator instance """ axes = onnx_node.getattr("axes", None) keepdims = onnx_node.getattr("keepdims", 1) - _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs, - opset_version) - return _, forward(axes, keepdims) + return operator(axes, keepdims) @classmethod - def _create_dropout(cls, onnx_node, inputs, opset_version): + def _create_dropout(cls, onnx_node, operator, opset_version=_opset_version): """ get the Dropout operator from onnx node Args: - onnx_node: a given onnx node - Args: - inputs: the input tensor - Args: - opset_version: the opset version + onnx_node (OnnxNode): a given onnx node + operator (Operator Class): a singa operator class + opset_version (int): the opset version Returns: - handle, the handle of singa operator - Returns: - forward, the autograd of singa operator + singa operator instance """ ratio = onnx_node.getattr("ratio", 0) - _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs, - opset_version) - return _, forward(ratio) + return operator(ratio) @classmethod - def _create_constantOfShape(cls, onnx_node, inputs, opset_version): + def _create_constant_of_shape(cls, + onnx_node, + operator, + opset_version=_opset_version): """ get the ConstantOfShape operator from onnx node Args: - onnx_node: a given onnx node - Args: - inputs: the input tensor - Args: - opset_version: the opset version + onnx_node (OnnxNode): a given onnx node + operator (Operator Class): a singa operator class + opset_version (int): the opset version Returns: - handle, the handle of singa operator - Returns: - forward, the autograd of singa operator + singa operator instance """ value = onnx_node.getattr("value", 0) if isinstance(value, onnx.TensorProto): value = numpy_helper.to_array(value)[0].item() - _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs, - opset_version) - return _, forward(value) + return operator(value) @classmethod - def _create_transpose(cls, onnx_node, inputs, opset_version): + def _create_transpose(cls, + onnx_node, + operator, + opset_version=_opset_version): """ get the Transpose operator from onnx node Args: - onnx_node: a given onnx node - Args: - inputs: the input tensor - Args: - opset_version: the opset version - Returns: - handle, the handle of singa operator - Returns: - forward, the autograd of singa operator - """ - shape = inputs[0].shape - perm = onnx_node.getattr("perm", list(range(len(shape) - 1, -1, -1))) - _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs, - opset_version) - return _, forward(perm) - - @classmethod - def _create_clip(cls, onnx_node, inputs, opset_version): - """ - get the clip operator from onnx node - Args: - onnx_node: a given onnx node - Args: - inputs: the input tensor - Args: - opset_version: the opset version + 
onnx_node (OnnxNode): a given onnx node + operator (Operator Class): a singa operator class + opset_version (int): the opset version Returns: - handle, the handle of singa operator - Returns: - forward, the autograd of singa operator + singa operator instance """ - # sometime onnx may ignore these two inputs, min or max or both - if len(inputs) >= 2 and onnx_node.inputs[1] != '': - min_v = tensor.to_numpy(inputs.pop(1)).tolist()[0] - else: - min_v = None - if len(inputs) >= 2 and onnx_node.inputs[2] != '': - max_v = tensor.to_numpy(inputs.pop(1)).tolist()[0] - else: - max_v = None - onnx_node.consumed_inputs.extend(onnx_node.inputs[1:]) - _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs, - opset_version) - return _, forward(min_v, max_v) + perm = onnx_node.getattr("perm") + return operator(perm) @classmethod - def _create_hardsigmoid(cls, onnx_node, inputs, opset_version): + def _create_hardsigmoid(cls, + onnx_node, + operator, + opset_version=_opset_version): """ - get the HardSigmoid operator from onnx node - Args: - onnx_node: a given onnx node + get the hardsigmoid operator from onnx node Args: - inputs: the input tensor - Args: - opset_version: the opset version - Returns: - handle, the handle of singa operator + onnx_node (OnnxNode): a given onnx node + operator (Operator Class): a singa operator class + + opset_version (int): the opset version Returns: - forward, the autograd of singa operator + singa operator instance """ alpha = onnx_node.getattr("alpha", 0.2) beta = onnx_node.getattr("beta", 0.5) - _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs, - opset_version) - return _, forward(alpha, beta) + return operator(alpha, beta) @classmethod - def _create_elu(cls, onnx_node, inputs, opset_version): + def _create_elu(cls, onnx_node, operator, opset_version=_opset_version): """ get the elu operator from onnx node Args: - onnx_node: a given onnx node - Args: - inputs: the input tensor - Args: - opset_version: the opset version - Returns: - handle, the handle of singa operator + onnx_node (OnnxNode): a given onnx node + operator (Operator Class): a singa operator class + opset_version (int): the opset version Returns: - forward, the autograd of singa operator + singa operator instance """ alpha = onnx_node.getattr("alpha", 1.) 
- _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs, - opset_version) - return _, forward(alpha) + return operator(alpha) @classmethod - def _create_selu(cls, onnx_node, inputs, opset_version): + def _create_selu(cls, onnx_node, operator, opset_version=_opset_version): """ get the selu operator from onnx node Args: - onnx_node: a given onnx node - Args: - inputs: the input tensor - Args: - opset_version: the opset version - Returns: - handle, the handle of singa operator + onnx_node (OnnxNode): a given onnx node + operator (Operator Class): a singa operator class + opset_version (int): the opset version Returns: - forward, the autograd of singa operator + singa operator instance """ alpha = onnx_node.getattr("alpha", 1.67326) gamma = onnx_node.getattr("gamma", 1.0507) - _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs, - opset_version) - return _, forward(alpha, gamma) + return operator(alpha, gamma) @classmethod - def _create_reshape(cls, onnx_node, inputs, opset_version): + def _create_concat(cls, onnx_node, operator, opset_version=_opset_version): """ - get the reshape operator from onnx node - Args: - onnx_node: a given onnx node - Args: - inputs: the input tensor + get the concat operator from onnx node Args: - opset_version: the opset version - Returns: - the handle of singa operator + onnx_node (OnnxNode): a given onnx node + operator (Operator Class): a singa operator class + opset_version (int): the opset version Returns: - the autograd of singa operator + singa operator instance """ - shape = tensor.to_numpy(inputs.pop(1)).astype(np.int32).tolist() - onnx_node.consumed_inputs.append(onnx_node.inputs[1]) - _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs, - opset_version) - return _, forward(shape) + factor = onnx_node.getattr('axis') + return operator(axis=factor) @classmethod - def _create_conv(cls, onnx_node, inputs, opset_version): + def _create_softmax(cls, onnx_node, operator, opset_version=_opset_version): """ - get the conv operator from onnx node + get the softmax operator from onnx node Args: - onnx_node: a given onnx node - Args: - inputs: the input tensor - Args: - opset_version: the opset version - Returns: - handle, the handle of singa operator + onnx_node (OnnxNode): a given onnx node + operator (Operator Class): a singa operator class + opset_version (int): the opset version Returns: - forward, the autograd of singa operator + singa operator instance """ - kernel = tuple(onnx_node.attrs["kernel_shape"]) - padding = tuple( - onnx_node.attrs["pads"]) if "pads" in onnx_node.attrs else (0, 0) - stride = tuple(onnx_node.getattr('strides', (1, 1))) - # default the odd_padding is 0, once there are same pad mode, we modify it - # for odd_padding, please refer the autegrade.py - odd_padding = (0, 0, 0, 0) - if "auto_pad" in onnx_node.attrs: - auto_pad = utils.force_unicode(onnx_node.attrs['auto_pad']) - if auto_pad in ('SAME_UPPER', 'SAME_LOWER'): - padding, odd_padding = utils.get_padding_shape( - auto_pad, inputs[0].shape[2:], kernel, stride) - - # not support dilation - dilation = onnx_node.getattr('dilations', 1) - if dilation != 1 and list(dilation) != [1, 1]: - raise ValueError("Not implemented yet for dilation") - group = onnx_node.getattr('group', 1) - - # only support 1d or 2d - if len(kernel) > 2: - raise ValueError("Only implemented for 1d or 2d") - - bias = len(inputs) == 3 - x = inputs[0] - x_shape = inputs[0].shape - in_channels = x_shape[1] - w_shape = inputs[1].shape - out_channels = w_shape[0] - assert 
w_shape[1] == in_channels // group - - if inputs[0].device.id() == -1: - if group != 1: - raise NotImplementedError - else: - handle = singa.ConvHandle(x.data, kernel, stride, padding, - in_channels, out_channels, bias, - group) - else: - handle = singa.CudnnConvHandle(x.data, kernel, stride, padding, - in_channels, out_channels, bias, - group) - - _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs, - opset_version) - return _, forward(handle, odd_padding) + factor = onnx_node.getattr('axis', 1) + return operator(axis=factor) @classmethod - def _create_max_avg_pool(cls, onnx_node, inputs, opset_version): + def _create_gemm(cls, onnx_node, operator, opset_version=_opset_version): """ - get the max or avg pool operator from onnx node - Args: - onnx_node: a given onnx node - Args: - inputs: the input tensor + get the gemm operator from onnx node Args: - opset_version: the opset version - Returns: - handle, the handle of singa operator + onnx_node (OnnxNode): a given onnx node + operator (Operator Class): a singa operator class + opset_version (int): the opset version Returns: - forward, the autograd of singa operator + singa operator instance """ - kernel = tuple(onnx_node.attrs["kernel_shape"]) - padding = tuple( - onnx_node.attrs["pads"]) if "pads" in onnx_node.attrs else (0, 0) - stride = tuple(onnx_node.getattr('strides', (1, 1))) - # default the odd_padding is 0, once there are same pad mode, we modify it - # for odd_padding, please refer the autegrade.py - odd_padding = (0, 0, 0, 0) - if "auto_pad" in onnx_node.attrs: - auto_pad = utils.force_unicode(onnx_node.attrs['auto_pad']) - if auto_pad in ('SAME_UPPER', 'SAME_LOWER'): - padding, odd_padding = utils.get_padding_shape( - auto_pad, inputs[0].shape[2:], kernel, stride) - - # not support count_include_pad and auto_pad - if "count_include_pad" in onnx_node.attrs or "ceil_mode" in onnx_node.attrs: - raise ValueError( - "Not implemented yet for count_include_pad or ceil_mode") - - # only support 2d - if len(kernel) != 2: - raise ValueError("Not implemented yet") - - is_max = onnx_node.op_type == 'MaxPool' - x = inputs[0] - if x.device.id() == -1: - handle = singa.PoolingHandle(x.data, kernel, stride, padding, - is_max) - else: - handle = singa.CudnnPoolingHandle(x.data, kernel, stride, padding, - is_max) - - _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs, - opset_version) - return _, forward(handle, odd_padding) + alpha = onnx_node.getattr('alpha', 1.) + beta = onnx_node.getattr('beta', 1.) 
+ transA = onnx_node.getattr('transA', 0) + transB = onnx_node.getattr('transB', 0) + return operator(alpha=alpha, beta=beta, transA=transA, transB=transB) @classmethod - def _create_batchnorm(cls, onnx_node, inputs, opset_version): + def _create_flatten(cls, onnx_node, operator, opset_version=_opset_version): """ - get the batch norm operator from onnx node - Args:onnx_node: a given onnx node - Args:inputs: the input tensor - Args:opset_version: the opset version - Returns: the handle of singa operator - Returns: the autograd of singa operator + get the flatten operator from onnx node + Args: + onnx_node (OnnxNode): a given onnx node + operator (Operator Class): a singa operator class + opset_version (int): the opset version + Returns: + singa operator instance """ - x = inputs[0] - factor = onnx_node.getattr('momentum', 0.9) - if x.device.id() == -1: - handle = singa.BatchNormHandle(factor, x.data) - else: - handle = singa.CudnnBatchNormHandle(factor, x.data) - - _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs, - opset_version) - return handle, forward + factor = onnx_node.getattr('axis', 1) + return operator(axis=factor) @classmethod - def _create_concat(cls, onnx_node, inputs, opset_version): + def _create_onehot(cls, onnx_node, operator, opset_version=_opset_version): """ - get the concat operator from onnx node - Args: - onnx_node: a given onnx node - Args: - inputs: the input tensor + get the OneHot operator from onnx node Args: - opset_version: the opset version - Returns: - the handle of singa operator + onnx_node (OnnxNode): a given onnx node + operator (Operator Class): a singa operator class + opset_version (int): the opset version Returns: - the autograd of singa operator + singa operator instance """ - factor = onnx_node.attrs["axis"] - if factor < 0: - factor = len(inputs[0].shape - ) + factor # in order to support the negative axis - _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs, - opset_version) - return None, forward(axis=factor) + axis = onnx_node.getattr("axis", -1) + onnx_node.set_attr_inputs(onnx_node.inputs[1], 'depth') + onnx_node.set_attr_inputs(onnx_node.inputs[2], 'values') + return operator(axis, None, None) @classmethod - def _create_softmax(cls, onnx_node, inputs, opset_version): + def _create_tile(cls, onnx_node, operator, opset_version=_opset_version): """ - get the concat operator from onnx node - Args: - onnx_node: a given onnx node - Args: - inputs: the input tensor + get the Tile operator from onnx node Args: - opset_version: the opset version - Returns: - the handle of singa operator + onnx_node (OnnxNode): a given onnx node + operator (Operator Class): a singa operator class + opset_version (int): the opset version Returns: - the autograd of singa operator + singa operator instance """ - factor = onnx_node.getattr('axis', 1) - if factor < 0: - # in order to support the negative axis - factor = len(inputs[0].shape) + factor - _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs, - opset_version) - return None, forward(axis=factor) + onnx_node.set_attr_inputs(onnx_node.inputs[1], 'repeats') + return operator(None) @classmethod - def _create_gemm(cls, onnx_node, inputs, opset_version): + def _create_gather(cls, onnx_node, operator, opset_version=_opset_version): """ - get the gemm operator from onnx node - Args: - onnx_node: a given onnx node - Args: - inputs: the input tensor + get the Gather operator from onnx node Args: - opset_version: the opset version + onnx_node (OnnxNode): a given onnx node + 
operator (Operator Class): a singa operator class + opset_version (int): the opset version Returns: - the handle of singa operator + singa operator instance + """ + axis = onnx_node.getattr("axis", 0) + onnx_node.set_attr_inputs(onnx_node.inputs[1], 'indices') + return operator(axis, None) + + @classmethod + def _create_reshape(cls, onnx_node, operator, opset_version=_opset_version): + """ + get the reshape operator from onnx node + Args: + onnx_node (OnnxNode): a given onnx node + operator (Operator Class): a singa operator class + opset_version (int): the opset version Returns: - the autograd of singa operator + singa operator instance """ - x = inputs[0] - alpha = onnx_node.getattr('alpha', 1.) - beta = onnx_node.getattr('beta', 1.) - transA = onnx_node.getattr('transA', 0) - transB = onnx_node.getattr('transB', 0) - _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs, - opset_version) - return None, forward(alpha=alpha, - beta=beta, - transA=transA, - transB=transB) + onnx_node.set_attr_inputs(onnx_node.inputs[1], 'shape') + return operator(None) @classmethod - def _create_flatten(cls, onnx_node, inputs, opset_version): + def _create_slice(cls, onnx_node, operator, opset_version=_opset_version): """ - get the flatten operator from onnx node + get the Slice operator from onnx node Args: - onnx_node: a given onnx node + onnx_node (OnnxNode): a given onnx node + operator (Operator Class): a singa operator class + opset_version (int): the opset version + Returns: + singa operator instance + """ + onnx_node.set_attr_inputs(onnx_node.inputs[1], 'starts') + onnx_node.set_attr_inputs(onnx_node.inputs[2], 'ends') + if len(onnx_node.inputs) >= 4 and onnx_node.inputs[3] != '': + onnx_node.set_attr_inputs(onnx_node.inputs[3], 'axes') + if len(onnx_node.inputs) == 5: + onnx_node.set_attr_inputs(onnx_node.inputs[4], 'steps') + return operator(None, None, None, None) + + @classmethod + def _create_clip(cls, onnx_node, operator, opset_version=_opset_version): + """ + get the clip operator from onnx node Args: - inputs: the input tensor + onnx_node (OnnxNode): a given onnx node + operator (Operator Class): a singa operator class + opset_version (int): the opset version + Returns: + singa operator instance + """ + if len(onnx_node.inputs) >= 2 and onnx_node.inputs[1] != '': + onnx_node.set_attr_inputs(onnx_node.inputs[1], 'min') + if len(onnx_node.inputs) == 3: + onnx_node.set_attr_inputs(onnx_node.inputs[2], 'max') + return operator(None, None) + + @classmethod + def _create_batch_norm(cls, + onnx_node, + operator, + opset_version=_opset_version): + """ + get the batch norm operator from onnx node Args: - opset_version: the opset version + onnx_node (OnnxNode): a given onnx node + operator (Operator Class): a singa operator class + opset_version (int): the opset version Returns: - the handle of singa operator + singa operator instance + """ + factor = onnx_node.getattr('momentum', 0.9) + onnx_node.set_weight_inputs(onnx_node.inputs[1], 'scale') + onnx_node.set_weight_inputs(onnx_node.inputs[2], 'bias') + onnx_node.set_weight_inputs(onnx_node.inputs[3], 'running_mean') + onnx_node.set_weight_inputs(onnx_node.inputs[4], 'running_var') + return operator(factor) + + @classmethod + def _create_conv(cls, onnx_node, operator, opset_version=_opset_version): + """ + get the conv operator from onnx node + Args: + onnx_node (OnnxNode): a given onnx node + operator (Operator Class): a singa operator class + opset_version (int): the opset version 
Returns: - the autograd of singa operator + singa operator instance """ - factor = onnx_node.getattr('axis', 1) - if factor < 0: - # in order to support the negative axis - factor = len(inputs[0].shape) + factor + kernel_size = tuple(onnx_node.getattr('kernel_shape')) + padding = tuple(onnx_node.getattr('pads', (0, 0))) + stride = tuple(onnx_node.getattr('strides', (1, 1))) + auto_pad = utils.force_unicode(onnx_node.getattr('auto_pad', 'NOTSET')) + + # do not support dilation + dilation = onnx_node.getattr('dilations', 1) + if dilation != 1 and list(dilation) != [1, 1]: + raise ValueError("Not implemented yet for dilation") + group = onnx_node.getattr('group', 1) + + # only support 1d or 2d + if len(kernel_size) > 2: + raise ValueError("Only implemented for 1d or 2d") - _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs, - opset_version) - return None, forward(axis=factor) + onnx_node.set_weight_inputs(onnx_node.inputs[1], 'W') + bias = False + if len(onnx_node.inputs) == 3: + onnx_node.set_weight_inputs(onnx_node.inputs[2], 'B') + bias = True + # todo, in_channels, out_channels + + return operator(None, None, kernel_size, stride, padding, dilation, + group, bias, auto_pad) @classmethod - def _common_onnx_node_to_singa_op(cls, onnx_node, inputs, opset_version): + def _create_max_avg_pool(cls, + onnx_node, + operator, + opset_version=_opset_version): """ - get a common singa operator(only autograd) from a onnx node - other special operators also can call this func to get autograd - Args: - onnx_node: a given onnx node - Args: - tensor_map: the input tensor + get the max or average pooling operator from onnx node Args: - opset_version: the opset version - Returns: - a dict of tensors + onnx_node (OnnxNode): a given onnx node + operator (Operator Class): a singa operator class + opset_version (int): the opset version Returns: - a list of SingaOps('name', 'op', 'handle', 'forward') + singa operator instance """ - onnx_op_type = onnx_node.op_type - assert onnx_op_type in cls._rename_operators, "not support operator: {}".format( - onnx_op_type) - autograd_op = getattr(autograd, cls._rename_operators[onnx_op_type]) - return None, autograd_op + kernel_size = tuple(onnx_node.getattr('kernel_shape')) + padding = tuple(onnx_node.getattr('pads', (0, 0))) + stride = tuple(onnx_node.getattr('strides', (1, 1))) + auto_pad = utils.force_unicode(onnx_node.getattr('auto_pad', 'NOTSET')) + + # do not support count_include_pad and ceil_mode + ceil_mode = onnx_node.getattr('ceil_mode', 0) + count_include_pad = onnx_node.getattr('count_include_pad', 0) + if ceil_mode != 0 or count_include_pad != 0: + raise ValueError( + "Not implemented yet for count_include_pad or ceil_mode") + + # only support 1d or 2d + if len(kernel_size) > 2: + raise ValueError("Only implemented for 1d or 2d") + + is_max = onnx_node.op_type == 'MaxPool' + return operator(kernel_size, stride, padding, is_max, auto_pad) @classmethod - def _onnx_node_to_singa_op(cls, - onnx_node, - inputs, - opset_version=_known_opset_version): + def _onnx_constant_to_np(cls, onnx_node, opset_version): """ - get a singa operator(handle and autograd) from a onnx node + parse an onnx constant node to a numpy array Args: - onnx_node: a given onnx node - Args: - inputs: the input list - Args: - opset_version: the opset version + onnx_node (OnnxNode): a given onnx node + opset_version (int): the opset version Returns: - a dict of tensors + a numpy ndarray + """ + onnx_tensor = onnx_node.getattr('value') + np_dtype = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[onnx_tensor.data_type] + 
np_tensor = np.frombuffer(onnx_tensor.raw_data, dtype=np_dtype) + return tensor.from_numpy(np_tensor) + + @classmethod + def _onnx_node_to_singa_op(cls, onnx_node, opset_version=_opset_version): + """ + get singa operator from a onnx node + Args: + onnx_node (OnnxNode): a given onnx node + opset_version (int): the opset version Returns: - a list of SingaOps('name', 'op', 'handle', 'forward') + singa operator instance """ + onnx_op_type = onnx_node.op_type + assert onnx_op_type in cls._rename_operators, "not support operator: {}".format( + onnx_op_type) + renamed_op = cls._rename_operators[onnx_op_type] + if renamed_op.startswith('layer.'): + singa_op = getattr(layer, renamed_op[6:]) + else: + singa_op = getattr(autograd, renamed_op) if onnx_node.op_type in cls._special_operators: translator = getattr(cls, cls._special_operators[onnx_node.op_type]) + return translator(onnx_node, singa_op, opset_version) else: - translator = cls._common_onnx_node_to_singa_op - return translator(onnx_node, inputs, opset_version) + return singa_op() @classmethod - def run_node(cls, onnx_node, inputs, opset_version=_known_opset_version): + def run_node(cls, node, inputs, device='CPU', opset_version=_opset_version): """ run a single singa operator from a onnx node Args: - onnx_node: a given onnx node - Args: - inputs: the input tensor - Args: - device: the used device - Args: - opset_version: the opset version - Returns: - list, the output of the + node (NodeProto): a given onnx node + inputs (ndarray[]): a list of numpy ndarray + device (string): CPU or CUDA + opset_version (int): the opset version + Returns: + list, the output """ - valid_inputs = [x for x in onnx_node.inputs if x != ""] + node = OnnxNode(node) + valid_inputs = [x for x in node.inputs if x != ""] assert len(valid_inputs) == len( - inputs), "{}: expected {} but got {}".format( - onnx_node.op_type, len(valid_inputs), len(inputs)) - - tmp_inputs = [inputs[x] for x in onnx_node.inputs if x != ""] - handle, forward = cls._onnx_node_to_singa_op(onnx_node, tmp_inputs, - opset_version) - # only give the inputs it needs - # consumed_inputs are the inputs marked as attributes - # so we remove it here - tmp_inputs = [ - inputs[x] - for x in onnx_node.inputs - if x not in onnx_node.consumed_inputs - ] - return cls._run_node(onnx_node, tmp_inputs, handle, forward, - opset_version) + inputs), "{}: expected {} inputs, but got {}. ".format( + node.op_type, len(valid_inputs), len(inputs)) + + operator = cls._onnx_node_to_singa_op(node, opset_version) + # seperate weights with inputs, and init inputs as Tensor + weights = {} + _inputs = [] + for (key, val) in zip(valid_inputs, inputs): + val = val.astype(cls._type_map[val.dtype]) + if key in node.weight_inputs: + weights[key] = val + else: + x = tensor.from_numpy(val) + if device == 'CPU': + assert singa.USE_CUDA, "Your SINGA doesn't compile GPU module." 
+ dev = device.create_cuda_gpu(set_default=False) + else: + dev = device.get_default_device() + x.to_device(dev) + _inputs.append(x) + inputs = _inputs + # set params + params = {} + for key, name in node.weight_inputs: + params[name] = weights[key] + operator.set_params(params) + outputs = cls._run_node(operator, inputs) + outputs_dict = OrderedDict() + for (key, val) in zip(node.outputs, outputs): + outputs_dict[key] = val + return outputs_dict @classmethod - def _run_node(cls, - onnx_node, - inputs, - handle, - forward, - opset_version=_known_opset_version): + def _run_node(cls, operator, inputs): """ - run a single singa operator from a onnx node - Args:inputs: - the input tensor - Args:handle: - the handle of singa operator - Args:forward: - the forward of singa operator - Args: - opset_version: the opset version - Returns: - list, the output of the + run a single singa operator from singa operator + Args: + operator (Operator): the Operator instance + inputs (Tensor[]): a list of SINGA Tensor + Returns: + list, the output """ - outputs = forward(*inputs) if handle is None else forward( - handle, *inputs) + outputs = operator(*inputs) if not isinstance(outputs, collections.Iterable): outputs = [outputs] - outputs_dict = OrderedDict() - for (key, val) in zip(onnx_node.outputs, outputs): - outputs_dict[key] = val - return outputs_dict + return outputs @classmethod - def _init_graph_parameter(cls, graph, init_inputs, device): + def _parse_graph_params(cls, graph, device): """ - init the singa tensor from onnx infos + parse the parameters from onnx graph Args: - graph: a given onnx graph - Args: - init_inputs: a list of inputs, which used to init the operators - Args: - device: the used device + graph (Graph): a given onnx graph + device (string): CPU or CUDA Returns: - a dict of tensors + a dict of numpy ndarray """ - tensor_map = {} - # due to https://github.com/onnx/onnx/issues/2417 - # sometimes, input contains all initializer's info - # sometimes, may not - all_inputs = OrderedDict() - for t in graph.input: - all_inputs[t.name] = t - # so we refresh the input by the initializer - for t in graph.initializer: - all_inputs[t.name] = t + params = {} initializers = {t.name for t in graph.initializer} - inp_idx = 0 - for name, x in all_inputs.items(): - if name in initializers: - # if it has initializer, we use its value as the input - np_tensor = numpy_helper.to_array(x) - if np_tensor.dtype == "int64": - np_tensor = np_tensor.astype(np.int32) - # todo, we cannot support scalar tensor - if np.ndim(np_tensor) == 0: - np_tensor = np.array(np_tensor, ndmin=1) - else: - # if not, means it's a input rather than a inner weight - # so if the user gives values, we use these values - # if not, we just use the shape of input gived by onnx to init a random value - # HOWEVER, the random value may not be correct for some inputs, such as gather which needs indices - # so if have operators, the user must give inputs - x_shape = tuple( - dim.dim_value for dim in x.type.tensor_type.shape.dim) - if init_inputs is not None: - np_tensor = init_inputs[inp_idx] - inp_idx += 1 - else: - np_tensor = np.random.randn(*x_shape).astype(np.float32) - tmp_tensor = tensor.from_numpy(np_tensor) - tmp_tensor.to_device(device) - # todo, for backward - tmp_tensor.stores_grad = (name in initializers) - tensor_map[x.name] = tmp_tensor - return tensor_map + for tp in graph.initializer: + val = numpy_helper.to_array(tp) + val = val.astype(cls._type_map[val.dtype]) + params[tp.name] = val + return params @classmethod - 
def _onnx_model_to_singa_net(cls, model, init_inputs, device, - opset_version): + def _parse_graph_inputs_outputs(cls, graph, params, device): """ - get all intermediate tensors and operators from onnx model - Args: - model: a given onnx model - Args: - init_inputs: a list of inputs, which used to init the operators + parse the inits, outputs from onnx graph Args: - device: the used device - Args: - opset_version: the opset version - Returns: - a dict of tensors + graph (Graph): a given onnx graph + device (string): # CPU or CUDA Returns: - a list of SingaOps('name', 'op', 'handle', 'forward') + a dict of ValueInfo + a dict of ValueInfo """ - # init all tensor input and weight as a tensor map - tensor_map = cls._init_graph_parameter(model.graph, init_inputs, device) - # only weights tensor - weights = {x.name: tensor_map[x.name] for x in model.graph.initializer} + inputs = {} + outputs = {} + for t in graph.input: + if t.name not in params: + inputs[t.name] = t + for t in graph.output: + outputs[t.name] = t + return inputs, outputs + + @classmethod + def _onnx_model_to_singa_ops(cls, graph, device, opset_version): + """ + get all intermediate params, operators, and input info from onnx model + Args: + graph (Graph): the loaded ONNX graph + device (string): CPU or CUDA + opset_version (int): the opset version + Returns: + a dict of weights + a dict of ValueInfo + a dict of ValueInfo + a list of SingaOps('node', 'forward') + """ + # init all tensor input and params as a tensor map + params = cls._parse_graph_params(graph, device) + inputs, outputs = cls._parse_graph_inputs(graph, params, device) # the parsed operators queue - singa_ops = [] - singa_op = namedtuple('SingaOps', ['name', 'op', 'handle', 'forward']) - for node in model.graph.node: + operators = [] + operator_tuple = namedtuple('operator_tuple', ['node', 'operator']) + for node in graph.node: node = OnnxNode(node) - # only give the inputs it needs - # consumed_inputs are the inputs marked as attributes - # so we remove it here - inputs = [ - tensor_map[x] - for x in node.inputs - if x not in node.consumed_inputs - ] - handle, forward = cls._onnx_node_to_singa_op( - node, inputs, opset_version) - # if it is Constant, we hanlde it as a weight - # otherwise, we run it and add its output into map for being used by later operators + # convert Constant to param if node.op_type == 'Constant': - tmp_tensor = tensor.from_numpy(forward) - tmp_tensor.to_device(device) - tmp_name = node.outputs.pop(0) - weights[tmp_name] = tmp_tensor - tensor_map[tmp_name] = tmp_tensor + params[node.outputs[0]] = cls._onnx_constant_to_np(node) else: - outputs = cls._run_node(node, inputs, handle, forward) - for key, val in outputs.items(): - tensor_map[key] = val - singa_ops.extend([singa_op(node.name, node, handle, forward)]) - return weights, singa_ops + node_params = [inputs[x] for x in node.inputs if x in params] + op = cls._onnx_node_to_singa_op(node, node_params, + opset_version) + operators.extend([operator_tuple(node, op)]) + return params, inputs, outputs, operators @classmethod - def prepare(cls, model, device, **kwargs): + def prepare(cls, model, device='CPU', **kwargs): """ - get the batch norm operator from onnx node + parse the ONNX and to create layers Args: - model: a given onnx node - Args: - device: the used device - Returns: - a list of output values + model (ModelProto): the loaded ONNX model + device (string): CPU or CUDA + Returns: + a SingaRep instance to stores the layers and weights """ super(SingaBackend, cls).prepare(model, device, 
**kwargs) - # when parsing graph, we use the shape of input gived by onnx to init a random value - # HOWEVER, the random value may not be correct for some inputs, such as gather which needs indices - # so if have operators, the user must give inputs - init_inputs = kwargs.get("init_inputs", None) - # whether initializers are moved into inputs, due to https://github.com/onnx/onnx/issues/2417 - # sometimes, input contains all initializer's info, sometimes, may not - cls.keep_initializers_as_inputs = kwargs.get( - 'keep_initializers_as_inputs', True) # optimize and infer the shape of the model try: model = onnx.utils.polish_model(model) except IndexError as err: - # due to https://github.com/onnx/onnx/issues/2417 model = onnx.shape_inference.infer_shapes(model) # check the opset version and ir version + # SINGA supports opset version 11 and IR version 6 (ONNX 1.6.0) opset_version = None for imp in model.opset_import: if not imp.HasField("domain") or imp.domain == "": opset_version = imp.version - if imp.version > cls._known_opset_version: + if imp.version > cls._opset_version: warnings.warn( - "This version of singa targets ONNX operator set version {}, but the model we are trying to import uses version {}. We will try to import it anyway, but if the model uses operators which had BC-breaking changes in the intervening versions, import will fail." - .format(cls._known_opset_version, imp.version)) + "The imported operator set version {} is larger than the supported version {}." + .format(imp.version, cls._opset_version)) else: warnings.warn("Unrecognized operator set {}".format(imp.domain)) - if opset_version is None: - if model.ir_version >= 0x00000003: - raise RuntimeError( - "Model with IR version >= 3 did not specify ONNX operator set version (singa requires it)" - ) - else: - opset_version = 1 - weights, singa_ops = cls._onnx_model_to_singa_net( - model, init_inputs, device, opset_version) - return SingaRep(model, weights, singa_ops, - cls.keep_initializers_as_inputs) + + if model.ir_version > cls._ir_version: + warnings.warn( + "The imported IR version {} is larger than the supported version {}."
+ .format(model.ir_version, cls._ir_version)) + + graph = model.graph + params, inputs, outputs, layers = cls._onnx_model_to_singa_ops( + graph, device, opset_version) + del graph.initializer # remove the inited initializer + return SingaRep(params, inputs, outputs, layers, device) class SingaRep(BackendRep): - def __init__(self, - model, - weights, - singa_ops, - keep_initializers_as_inputs=True): + def __init__(self, params, inputs, outputs, layers, device): """ + https://github.com/onnx/onnx/blob/master/docs/ImplementingAnOnnxBackend.md SingaRep provides the intermediate representation of Singa, the user can run the forward of the singa model by run func, or, the user can append more layers after the singa_ops to do the transfer learning Args: - model: a given operator - Args: - weights: the tensor of weights - Args: - singa_ops: the tensor of the operator + params (dict{}): a dict of params, data type is numpy ndarray + inputs (ValueInfo): a dict of inputs + outputs (ValueInfo): a dict of outputs + layers (namedtuple('operator_tuple', ['node', 'operator'])[]): a list of singa operators + device (string): CPU or CUDA """ super(SingaRep, self).__init__() - self.model = model - self.tensor_map = weights - self.keep_initializers_as_inputs = keep_initializers_as_inputs - # this each item of singa_ops is: ('name', 'op', 'handle', 'forward') - # the name is a string, op is OnnxNode, - # handle is Singa handle to store the tensor into singa operator - # the forward is singa autograd operator - self.singa_ops = singa_ops + self.states = params + self.inputs = inputs + self.outputs = outputs + self.layers = layers + self.device = device + self.has_initialized = False + self.x = [] + + def initialize(self, x): + """ + init tensors and operators + Args: + x (np.ndarray[]): a list of numpy ndarray as inputs + """ + weights = {} + _inputs = [] + # init inputs as Tensor + for (key, val) in zip(self.inputs, x): + val = val.astype(SingaBackend._type_map[val.dtype]) + x = tensor.from_numpy(val) + if self.device != 'CPU': + assert singa.USE_CUDA, "Your SINGA doesn't compile GPU module." 
+ dev = device.create_cuda_gpu(set_default=False) + else: + dev = device.get_default_device() + x.to_device(dev) + _inputs.append(x) + x = _inputs + + layers = [] # layers by topo order + for node, operator in self.layers: + if node.weight_inputs: + states = {} + for key, name in node.weight_inputs: + states[name] = self.states[key] + # init the operator + operator.initialize(x) + operator.has_initialized = True + # set states + operator.set_states(states) + self.__dict__[node.name] = operator + # init the tensor count + for inp in node.inputs: + if inp not in self.tensor_count: + self.tensor_count[inp] = 0 + else: + self.tensor_count[inp] += 1 + layers.append(node) + # gc + del self.states + self.layers = layers + # todo, onnx input -> singa attributes, dynamic - def run(self, inputs, **kwargs): + def run(self, x, **kwargs): """ run the forward of singa model Args: @@ -2074,65 +1866,123 @@ def run(self, inputs, **kwargs): Returns: the onnx node """ - graph = self.model.graph + outputs_dict = OrderedDict([(outp.name, None) for outp in self.outputs]) + # last_layers means we run this model until the last #N layers - last_layers = kwargs.get('last_layers', len(self.singa_ops)) - if last_layers != len(self.singa_ops): - final_outputs = self.singa_ops[last_layers-1].op.outputs - else: - final_outputs = [outp.name for outp in graph.output] - # whether return all outputs - all_outputs = kwargs.get('all_outputs', False) - # get a specific op by its name - op_name = kwargs.get('op_name', None) - # record the tensor we added from input - tmp_tensor_map = {name: val for name, val in self.tensor_map.items()} - - # the dict will be returned - ret_outputs = OrderedDict() - if self.keep_initializers_as_inputs: - require_input_len = len(graph.input) - len(graph.initializer) - actual_input_len = len(inputs) - else: - require_input_len = len(graph.input) - actual_input_len = len(inputs) - assert require_input_len == actual_input_len, "The length of graph input is different from the tensor input: %d, %d" % ( - require_input_len, actual_input_len) - # run the handle by the order of the list(the list is Topological Sorting) - for inp in graph.input: - if inp.name not in tmp_tensor_map: - tmp_tensor_map[inp.name] = inputs.pop(0) - - for _, op, handle, forward in self.singa_ops[:last_layers]: - if len(op.consumed_inputs) != 0: - # because if op has consumed_inputs, it means it moved some inputs into attributes - # so when running, we should update these attributes - handle, forward = get_op(op, - [tmp_tensor_map[x] for x in op.inputs]) + last_layers = kwargs.get('last_layers', len(self.layers)) + if last_layers != len(self.layers): + for outp in self.layers[last_layers - 1].outputs: + outputs_dict[outp] = None + + if self.has_initialized == False: + self.initialize(x) + self.has_initialized = True + + tensor_dict = {} + for (key, val) in zip(self.inputs, self.x): + # todo check shape + tensor_dict[key.name] = val + + # run the layer by the topo order + for node in self.layers[:last_layers]: inputs = [ - tmp_tensor_map[x] - for x in op.inputs - if x not in op.consumed_inputs + tensor_dict[inp] + for inp in node.inputs + if inp not in node.weight_inputs ] - outputs = _run_node(op, inputs, handle, forward) - for key, val in outputs.items(): - tmp_tensor_map[key] = val - ret_outputs[key] = val - - if op_name is not None: - if op_name in outputs: - return outputs[op_name] - else: - raise RuntimeError( - "The op_name {} does not exist, please check. 
The available op_names are: {}" .format(op_name, [val for key, val in op_name.items()])) - - # return all outputs if all_outputs==True - # else return last outputs - if all_outputs: - return ret_outputs - else: - return [ret_outputs[outp] for outp in final_outputs] + outputs = _run_node(self.__dict__[node.name], inputs) + # release the input tensor + for inp in node.inputs: + if inp in self.tensor_count: + self.tensor_count[inp] -= 1 + if self.tensor_count[inp] == 0: + del tensor_dict[inp] + del self.tensor_count[inp] + # store the output + for (outp, val) in zip(node.outputs, outputs): + tensor_dict[outp] = val + if outp in outputs_dict: + outputs_dict[outp] = val + return outputs_dict + + +class SONNXModel(module.Module): + + def __init__(self, onnx_model): + """ + Init a SINGA Module from an ONNX model + Args: + onnx_model (ModelProto): a loaded onnx model + """ + super(SONNXModel, self).__init__() + singa_rep = SingaBackend.prepare(onnx_model) + self.states = singa_rep.states + self.layers = singa_rep.layers + self.inputs = singa_rep.inputs + self.outputs = singa_rep.outputs + self.tensor_count = {} + + def forward(self, *x, aux_output=()): + """ + The forward of the SINGA model + Args: + x (Tensor[]): a list of input Tensors + aux_output (string()): a set of required output names + Returns: + an OrderedDict of Tensors + """ + tensor_dict = {} + for (key, val) in zip(self.inputs, x): + # todo check shape + tensor_dict[key.name] = val + + outputs_dict = OrderedDict([(outp.name, None) for outp in self.outputs]) + for outp in aux_output: + outputs_dict[outp] = None + # run the layer by the topo order + for node in self.layers: + inputs = [ + tensor_dict[inp] + for inp in node.inputs + if inp not in node.weight_inputs + ] + outputs = _run_node(self.__dict__[node.name], inputs) + # release the input tensor + for inp in node.inputs: + if inp in self.tensor_count: + self.tensor_count[inp] -= 1 + if self.tensor_count[inp] == 0: + del tensor_dict[inp] + del self.tensor_count[inp] + # store the output + for (outp, val) in zip(node.outputs, outputs): + tensor_dict[outp] = val + if outp in aux_output: + outputs_dict[outp] = val + return outputs_dict + + def compile(self, x, is_train, use_graph, sequential): + super(SONNXModel, self).compile(x, is_train, use_graph, sequential) + layers = [] # layers by topo order + for node, operator in self.layers: + # onnx weights -> singa params + if node.weight_inputs: + states = {} + for key, name in node.weight_inputs.items(): + states[name] = self.states[key] + operator.set_states(states) + self.__dict__[node.name] = operator + # init the tensor count + for inp in node.inputs: + if inp not in self.tensor_count: + self.tensor_count[inp] = 0 + else: + self.tensor_count[inp] += 1 + layers.append(node) + # gc + del self.states + self.layers = layers + # todo, onnx input -> singa attributes, dynamic run_node = SingaBackend.run_node
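
A rough usage sketch of the API introduced by this patch, for illustration only and not part of the diff. The model path, the (1, 3, 224, 224) float32 input shape, and the compile() arguments are assumptions; the remaining TODO items in the patch may need to be resolved before this runs end to end:

    import onnx
    import numpy as np
    from singa import device, tensor
    from singa import sonnx

    onnx_model = onnx.load("model.onnx")  # hypothetical model path

    # Option 1: wrap the ONNX graph as a SINGA module via the new SONNXModel class
    model = sonnx.SONNXModel(onnx_model)
    dev = device.get_default_device()  # CPU; use device.create_cuda_gpu() for GPU
    x = tensor.from_numpy(np.random.randn(1, 3, 224, 224).astype(np.float32))
    x.to_device(dev)
    model.compile([x], is_train=False, use_graph=True, sequential=True)
    y = model.forward(x)  # OrderedDict: output name -> Tensor

    # Option 2: use the backend representation directly
    rep = sonnx.SingaBackend.prepare(onnx_model, device='CPU')
    y = rep.run([np.random.randn(1, 3, 224, 224).astype(np.float32)])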