From 18188f4ba3f53cc1dab765b8a0d932d21db0ae8a Mon Sep 17 00:00:00 2001 From: brett koonce Date: Fri, 27 Sep 2019 14:39:42 -0500 Subject: [PATCH 01/17] docs: minor spelling tweaks (#4027) --- docs/contribute/code_review.rst | 2 +- docs/contribute/error_handling.rst | 2 +- docs/contribute/pull_request.rst | 2 +- docs/deploy/android.md | 4 ++-- docs/dev/hybrid_script.rst | 2 +- docs/dev/inferbound.rst | 8 ++++---- docs/dev/relay_pass_infra.rst | 2 +- docs/faq.md | 2 +- docs/install/from_source.rst | 2 +- docs/langref/hybrid_script.rst | 4 ++-- docs/langref/relay_expr.rst | 4 ++-- docs/langref/relay_type.rst | 4 ++-- docs/vta/dev/config.rst | 4 ++-- docs/vta/install.md | 2 +- 14 files changed, 22 insertions(+), 22 deletions(-) diff --git a/docs/contribute/code_review.rst b/docs/contribute/code_review.rst index 920a29c6445e..f6f2136be9af 100644 --- a/docs/contribute/code_review.rst +++ b/docs/contribute/code_review.rst @@ -22,7 +22,7 @@ Perform Code Reviews This is a general guideline for code reviewers. First of all, while it is great to add new features to a project, we must also be aware that each line of code we introduce also brings **technical debt** that we may have to eventually pay. -Open source code is maintained by a community with diverse backend, and it is even more important to bring clear, documented and maintainable code. Code reviews are shepherding process to spot potential problems, improve quality of the code. We should, however, not rely on code review process to get the code into a ready state. Contributors are encouraged to polish the code to a ready state before requesting reviews. This is especially expected for code owner and comitter candidates. +Open source code is maintained by a community with diverse backend, and it is even more important to bring clear, documented and maintainable code. Code reviews are shepherding process to spot potential problems, improve quality of the code. 
We should, however, not rely on code review process to get the code into a ready state. Contributors are encouraged to polish the code to a ready state before requesting reviews. This is especially expected for code owner and committer candidates. Here are some checklists for code reviews, it is also helpful reference for contributors diff --git a/docs/contribute/error_handling.rst b/docs/contribute/error_handling.rst index 152f613b8002..4d5e5c54f03c 100644 --- a/docs/contribute/error_handling.rst +++ b/docs/contribute/error_handling.rst @@ -107,7 +107,7 @@ error messages when necessary. def preferred(): # Very clear about what is being raised and what is the error message. - raise OpNotImplemented("Operator relu is not implemented in the MXNet fronend") + raise OpNotImplemented("Operator relu is not implemented in the MXNet frontend") def _op_not_implemented(op_name): return OpNotImplemented("Operator {} is not implemented.").format(op_name) diff --git a/docs/contribute/pull_request.rst b/docs/contribute/pull_request.rst index 4a2a335a95c9..7ad53758c54b 100644 --- a/docs/contribute/pull_request.rst +++ b/docs/contribute/pull_request.rst @@ -51,7 +51,7 @@ We use docker container to create stable CI environments that can be deployed to multiple machines. You can find the prebuilt images in ``_ . Because we want a relatively stable CI environment and make use of pre-cached image, -all of the CI images are built and maintained by comitters. +all of the CI images are built and maintained by committers. Upgrade of CI images can cause problems and need fixes to accommodate the new env. Here is the protocol to update CI image: diff --git a/docs/deploy/android.md b/docs/deploy/android.md index 3e8c1293fb0f..78d67a76b756 100644 --- a/docs/deploy/android.md +++ b/docs/deploy/android.md @@ -22,7 +22,7 @@ NNVM compilation of model for android target could follow same approach like android_rpc. 
-An reference exampe can be found at [chainer-nnvm-example](https://github.com/tkat0/chainer-nnvm-example) +A reference example can be found at [chainer-nnvm-example](https://github.com/tkat0/chainer-nnvm-example) Above example will directly run the compiled model on RPC target. Below modification at [rum_mobile.py](https://github.com/tkat0/chainer-nnvm-example/blob/5b97fd4d41aa4dde4b0aceb0be311054fb5de451/run_mobile.py#L64) will save the compilation output which is required on android target. @@ -39,4 +39,4 @@ deploy_lib.so, deploy_graph.json, deploy_param.params will go to android target. ## TVM Runtime for Android Target Refer [here](https://github.com/dmlc/tvm/blob/master/apps/android_deploy/README.md#build-and-installation) to build CPU/OpenCL version flavor TVM runtime for android target. -From android java TVM API to load model & execute can be refered at this [java](https://github.com/dmlc/tvm/blob/master/apps/android_deploy/app/src/main/java/ml/dmlc/tvm/android/demo/MainActivity.java) sample source. +From android java TVM API to load model & execute can be referred at this [java](https://github.com/dmlc/tvm/blob/master/apps/android_deploy/app/src/main/java/ml/dmlc/tvm/android/demo/MainActivity.java) sample source. diff --git a/docs/dev/hybrid_script.rst b/docs/dev/hybrid_script.rst index 010af6c87f3b..7bb5e234cba0 100644 --- a/docs/dev/hybrid_script.rst +++ b/docs/dev/hybrid_script.rst @@ -83,7 +83,7 @@ In HalideIR, loops have in total 4 types: ``serial``, ``unrolled``, ``parallel`` Variables ~~~~~~~~~ -Because there is no variables in ``HalideIR``, all the mutatable variables will be lowered to an array with size 1. +Because there are no variables in ``HalideIR``, all the mutable variables will be lowered to an array with size 1. It takes the first store of a variable as its declaration. 
Math Intrinsics diff --git a/docs/dev/inferbound.rst b/docs/dev/inferbound.rst index e16871e3b7cf..1b74cabf01cc 100644 --- a/docs/dev/inferbound.rst +++ b/docs/dev/inferbound.rst @@ -71,7 +71,7 @@ A TVM schedule is composed of Stages. Each stage has exactly one Operation, e.g. Array outputs; Array stages; Map stage_map; - // remainder ommitted + // remainder omitted }; class StageNode : public Node { @@ -81,14 +81,14 @@ A TVM schedule is composed of Stages. Each stage has exactly one Operation, e.g. Array all_iter_vars; Array leaf_iter_vars; Array relations; - // remainder ommitted + // remainder omitted }; class OperationNode : public Node { public: virtual Array root_iter_vars(); virtual Array InputTensors(); - // remainder ommitted + // remainder omitted }; class ComputeOpNode : public OperationNode { @@ -97,7 +97,7 @@ A TVM schedule is composed of Stages. Each stage has exactly one Operation, e.g. Array reduce_axis; Array body; Array root_iter_vars(); - // remainder ommitted + // remainder omitted }; } diff --git a/docs/dev/relay_pass_infra.rst b/docs/dev/relay_pass_infra.rst index b358e957b258..98de347734ba 100644 --- a/docs/dev/relay_pass_infra.rst +++ b/docs/dev/relay_pass_infra.rst @@ -83,7 +83,7 @@ more details). For example, during registration of a pass (will be covered in later), the pass developers can specify the name of the pass, the optimization level it will be performed at, and/or the passes that are required. ``opt_level`` could be used to help the pass infra identify if a certain pass -needes to be executed when running under a user-provided optimization level. The +needs to be executed when running under a user-provided optimization level. The ``required`` field can be used by the pass infra to resolve pass dependencies. .. 
code:: c++ diff --git a/docs/faq.md b/docs/faq.md index 74fc82e3bda0..e587c5591d38 100644 --- a/docs/faq.md +++ b/docs/faq.md @@ -31,7 +31,7 @@ This representation is high level, and can be helpful to perform generic optimiz such as memory reuse, layout transformation and automatic differentiation. TVM adopts a low level representation, that explicitly express the choice of memory -layout, parallelization pattern, locality and hardware primtives etc. +layout, parallelization pattern, locality and hardware primitives etc. This level of IR is closer to directly target hardwares. The low level IR adopt ideas from existing image processing languages like Halide, darkroom and loop transformation tools like loopy and polyhedra based analysis. diff --git a/docs/install/from_source.rst b/docs/install/from_source.rst index 852a23f7d790..f70a18b60528 100644 --- a/docs/install/from_source.rst +++ b/docs/install/from_source.rst @@ -86,7 +86,7 @@ The configuration of TVM can be modified by `config.cmake`. - TVM optionally depends on LLVM. LLVM is required for CPU codegen that needs LLVM. - - LLVM 4.0 or higher is needed for build with LLVM. Note that verison of LLVM from default apt may lower than 4.0. + - LLVM 4.0 or higher is needed for build with LLVM. Note that version of LLVM from default apt may be lower than 4.0. - Since LLVM takes long time to build from source, you can download pre-built version of LLVM from `LLVM Download Page `_. diff --git a/docs/langref/hybrid_script.rst b/docs/langref/hybrid_script.rst index 4a3707fdaee8..99f44cede1f0 100644 --- a/docs/langref/hybrid_script.rst +++ b/docs/langref/hybrid_script.rst @@ -130,7 +130,7 @@ Users can access containers by either constants or constants loops annotated. Variables ~~~~~~~~~ -All the mutatable variables will be lowered to an array with size 1. +All the mutable variables will be lowered to an array with size 1. It regards the first store of a variable as its declaration. .. 
note:: @@ -158,7 +158,7 @@ Attributes ~~~~~~~~~~ So far, ONLY tensors' ``shape`` and ``dtype`` attribute are supported! -The ``shape`` atrribute is essentailly a tuple, so you MUST access it as an array. +The ``shape`` attribute is essentially a tuple, so you MUST access it as an array. Currently, only constant-indexed access is supported. .. code-block:: python diff --git a/docs/langref/relay_expr.rst b/docs/langref/relay_expr.rst index a9f1cf8f18ba..4a999a95f851 100644 --- a/docs/langref/relay_expr.rst +++ b/docs/langref/relay_expr.rst @@ -26,7 +26,7 @@ Dataflow and Control Fragments ============================== For the purposes of comparing Relay to traditional computational graph-based IRs, it -can be useful to consider Relay exrpessions in terms of dataflow and control fragments. +can be useful to consider Relay expressions in terms of dataflow and control fragments. Each portion of a Relay program containing expressions that only affect the dataflow can be viewed as a traditional computation graph when writing and expressing transformations. @@ -88,7 +88,7 @@ expression where it is bound, respectively. In the below code segment, notice that :code:`%a` is defined twice. This is permitted, as in most functional languages; in the scope of the second :code:`let` expression, the name :code:`%a` is "shadowed," meaning all -references to :code:`%a` in the inner scope refer to the later defintion, while +references to :code:`%a` in the inner scope refer to the later definition, while references to :code:`%a` in the outer scope continue to refer to the first one. diff --git a/docs/langref/relay_type.rst b/docs/langref/relay_type.rst index d54ea086e16c..ce00dff755c9 100644 --- a/docs/langref/relay_type.rst +++ b/docs/langref/relay_type.rst @@ -290,7 +290,7 @@ parameters must be treated as different types) and be recursive (a constructor for an ADT can take an instance of that ADT, thus an ADT like a tree or list can be inductively built up). 
The representation of ADTs in the type system must -be able to accomodate these facts, as the below sections will detail. +be able to accommodate these facts, as the below sections will detail. Global Type Variable ~~~~~~~~~~~~~~~~~~~~ @@ -316,7 +316,7 @@ Definitions (Type Data) ~~~~~~~~~~~~~~~~~~~~~~~ Besides a name, an ADT needs to store the constructors that are used -to define it and any type paramters used within them. These are +to define it and any type parameters used within them. These are stored in the module, :ref:`analogous to global function definitions`. While type-checking uses of ADTs, the type system sometimes must diff --git a/docs/vta/dev/config.rst b/docs/vta/dev/config.rst index f4b5bcec8af1..d8808e912b9f 100644 --- a/docs/vta/dev/config.rst +++ b/docs/vta/dev/config.rst @@ -68,7 +68,7 @@ below. We provide additional detail below regarding each parameter: - ``TARGET``: Can be set to ``"pynq"``, ``"ultra96"``, ``"sim"`` (fast simulator), or ``"tsim"`` (cycle accurate sim with verilator). - - ``HW_VER``: Hardware version which increments everytime the VTA hardware design changes. This parameter is used to uniquely idenfity hardware bitstreams. + - ``HW_VER``: Hardware version which increments every time the VTA hardware design changes. This parameter is used to uniquely identify hardware bitstreams. - ``LOG_BATCH``: Equivalent to A in multiplication of shape (A, B) x (B, C), or typically, the batch dimension of inner tensor computation. - - ``LOG_BLOCK``: Equivalent to B and C in multiplication of shape (A, B) x (B, C), or typically, the input/output channel dimensions of the innter tensor computation. + - ``LOG_BLOCK``: Equivalent to B and C in multiplication of shape (A, B) x (B, C), or typically, the input/output channel dimensions of the inner tensor computation. 
diff --git a/docs/vta/install.md b/docs/vta/install.md index 2f99d1f3fe48..c43a167292b4 100644 --- a/docs/vta/install.md +++ b/docs/vta/install.md @@ -202,7 +202,7 @@ Before powering up the device, we need to flash the microSD card image with late #### Flash SD Card and Boot Angstrom Linux To flash SD card and boot Linux on DE10-Nano, it is recommended to navigate to the [Resource](https://www.terasic.com.tw/cgi-bin/page/archive.pl?Language=English&CategoryNo=167&No=1046&PartNo=4) tab of the DE10-Nano product page from Terasic Inc. -After registeration and login on the webpage, the prebuild Angstrom Linux image would be available for downloading and flashing. +After registration and login on the webpage, the prebuilt Angstrom Linux image would be available for downloading and flashing. Specifically, to flash the downloaded Linux SD card image into your physical SD card: First, extract the gzipped archive file. From 9151d435f941d888236589b988fa47b51ab6495e Mon Sep 17 00:00:00 2001 From: Alex Gladkov <53275205+alexgl-github@users.noreply.github.com> Date: Fri, 27 Sep 2019 17:27:48 -0700 Subject: [PATCH 02/17] Additional MXNet Convolution and Deconvolution tests (#4026) Add different batch sizes and channel numbers to MXNet Convolution and Deconvolution tests. 
--- tests/python/frontend/mxnet/test_forward.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 34e3fdd4e760..453058556880 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -833,7 +833,7 @@ def verify(data_shape, out_shape, begin, end): def test_forward_convolution(): def verify(data_shape, kernel_size, stride, pad, num_filter): - weight_shape=(num_filter,1,) + kernel_size + weight_shape=(num_filter, data_shape[1],) + kernel_size x = np.random.uniform(size=data_shape).astype("float32") weight = np.random.uniform(size=weight_shape).astype("float32") bias = np.random.uniform(size=num_filter).astype("float32") @@ -852,11 +852,17 @@ def verify(data_shape, kernel_size, stride, pad, num_filter): tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3) verify(data_shape=(1,1,1024*16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4) + verify(data_shape=(20,1,1024*16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4) + verify(data_shape=(1,8,1024*16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4) + verify(data_shape=(20,8,1024*16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4) verify(data_shape=(1, 1, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2) + verify(data_shape=(20, 1, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2) + verify(data_shape=(1, 8, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2) + verify(data_shape=(20, 8, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2) def test_forward_deconvolution(): def verify(data_shape, kernel_size, stride, pad, num_filter): - weight_shape=(1, num_filter) + kernel_size + weight_shape=(data_shape[1], num_filter) + kernel_size x = np.random.uniform(size=data_shape).astype("float32") weight = 
np.random.uniform(size=weight_shape).astype("float32") bias = np.random.uniform(size=num_filter).astype("float32") @@ -875,7 +881,13 @@ def verify(data_shape, kernel_size, stride, pad, num_filter): tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3) verify(data_shape=(1,1,1024*16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4) + verify(data_shape=(20,1,1024*16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4) + verify(data_shape=(1,8,1024*16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4) + verify(data_shape=(20,8,1024*16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4) verify(data_shape=(1, 1, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2) + verify(data_shape=(20, 1, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2) + verify(data_shape=(1, 8, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2) + verify(data_shape=(20, 8, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2) if __name__ == '__main__': From 4f712c797e9522c64c20e475a75063421daa7ca5 Mon Sep 17 00:00:00 2001 From: Ina Dobreva <55383260+inadob@users.noreply.github.com> Date: Sat, 28 Sep 2019 01:30:11 +0100 Subject: [PATCH 03/17] Add parser support for ReLU tflite operator (#4022) --- python/tvm/relay/frontend/tflite.py | 18 ++++++++++++++++++ tests/python/frontend/tflite/test_forward.py | 16 ++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index eba19a47298b..01f6c670de08 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -84,6 +84,7 @@ def __init__(self, model, subgraph, exp_tab): 'PACK': self.convert_pack, 'LOGISTIC': self.convert_logistic, 'TANH':self.convert_tanh, + 'RELU':self.convert_relu, 'SPLIT': self.convert_split, 'TRANSPOSE': self.convert_transpose, 'TILE': self.convert_tile, @@ -345,6 +346,23 @@ def convert_tanh(self, op): 
return out + def convert_relu(self, op): + """Convert TFLite ReLU""" + try: + from tflite.Operator import Operator + except ImportError: + raise ImportError("The tflite package must be installed") + + assert isinstance(op, Operator) + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 1, "input tensors length should be 1" + + input_tensor = input_tensors[0] + in_expr = self.get_expr(input_tensor.tensor_idx) + out = _op.nn.relu(in_expr) + + return out + def convert_concatenation(self, op): """Convert TFLite concatenation""" try: diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 5d97ce8f8ffb..06afa59e0a82 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -836,6 +836,21 @@ def test_forward_tanh(): """ TANH """ _test_tanh(np.arange(6.0, dtype=np.float32).reshape((1, 6))) +####################################################################### +# ReLu +# -------- + +def _test_relu(data): + """ One iteration of ReLU """ + with tf.Graph().as_default(): + in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) + out = nn_ops.relu(in_data) + compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out]) + +def test_forward_relu(): + """ ReLU """ + _test_relu(np.arange(6.0, dtype=np.float32).reshape((1, 6))) + ####################################################################### # Fully Connected # ------- @@ -999,6 +1014,7 @@ def test_forward_ssd_mobilenet_v1(): test_forward_pooling() test_forward_softmax() test_forward_tanh() + test_forward_relu() test_forward_fully_connected() # Elemwise From bbf82e0e19e09073b64e7dd86d966bb248ddda00 Mon Sep 17 00:00:00 2001 From: bindog Date: Sun, 29 Sep 2019 01:22:01 +0800 Subject: [PATCH 04/17] [Fix] Add more pad_mode support for onnx converter (#4029) * [Fix] Add more pad_mode support for onnx converter * robustness fix --- python/tvm/relay/frontend/onnx.py | 22 ++++++++---- 
tests/python/frontend/onnx/test_forward.py | 39 ++++++++++++++-------- 2 files changed, 41 insertions(+), 20 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index e5309367c08b..387e9cfe3ce4 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -326,15 +326,20 @@ def _impl_v1(cls, inputs, attr, params): for i in range(dims): pad_width.append((pads[i], pads[i+dims])) attr['pad_width'] = pad_width + pad_mode = attr.get('mode', 'constant').decode('utf-8') + if pad_mode in ['constant', 'edge', 'reflect']: + attr['pad_mode'] = pad_mode + attr.pop('mode', None) + else: + raise tvm.error.OpAttributeInvalid( + 'Value ' + pad_mode + ' in attribute "mode" is invalid for operator Pad.') return AttrCvt( _op.nn.pad, transforms={ 'value': 'pad_value', }, - ignores=['mode'], - custom_check=(lambda attrs: attrs.get('mode', 'constant').decode("utf-8") == 'constant', - 'split mode != constant'))(inputs, attr, params) + )(inputs, attr, params) @classmethod def _impl_v2(cls, inputs, attr, params): @@ -344,15 +349,20 @@ def _impl_v2(cls, inputs, attr, params): for i in range(dims): pad_width.append((pads[i], pads[i+dims])) attr['pad_width'] = pad_width + pad_mode = attr.get('mode', 'constant').decode('utf-8') + if pad_mode in ['constant', 'edge', 'reflect']: + attr['pad_mode'] = pad_mode + attr.pop('mode', None) + else: + raise tvm.error.OpAttributeInvalid( + 'Value ' + pad_mode + ' in attribute "mode" is invalid for operator Pad.') return AttrCvt( 'pad', transforms={ 'value': 'pad_value', }, - ignores=['mode'], - custom_check=(lambda attrs: attrs.get('mode', 'constant').decode("utf-8") == 'constant', - 'split mode != constant'))(inputs, attr, params) + )(inputs, attr, params) class ParametricSoftPlus(OnnxOpConverter): diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 275ead5c8041..dc9493f32237 100644 --- 
a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -781,21 +781,31 @@ def test_constantfill(): verify_constantfill(True, (2, 3, 4, 5), (2, 3, 4, 5, 4, 5, 6), 10, 'float32', extra_shape=(4, 5, 6)) -def verify_pad(indata, pads, value=0.0): +def verify_pad(indata, pads, mode='constant', value=0.0): indata = np.array(indata).astype(np.float32) # numpy expect result len_dim = len(pads) // 2 np_pads = [(pads[i], pads[i+len_dim]) for i in range(len_dim)] - outdata = np.pad(indata, pad_width=np_pads, mode='constant', constant_values=value) # onnx graph - node = helper.make_node( - 'Pad', - inputs=['input'], - outputs=['output'], - mode='constant', - pads=pads, - value=value - ) + if mode in ['edge', 'reflect']: + outdata = np.pad(indata, pad_width=np_pads, mode=mode) + node = helper.make_node( + 'Pad', + inputs=['input'], + outputs=['output'], + mode=mode, + pads=pads, + ) + else: + outdata = np.pad(indata, pad_width=np_pads, mode='constant', constant_values=value) + node = helper.make_node( + 'Pad', + inputs=['input'], + outputs=['output'], + mode='constant', + pads=pads, + value=value + ) graph = helper.make_graph([node], 'pad_test', inputs = [helper.make_tensor_value_info("input", @@ -809,9 +819,11 @@ def verify_pad(indata, pads, value=0.0): tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5) def test_pad(): - verify_pad(np.random.randn(2, 2).astype(np.float32), [0, 1, 0, 0], 0.0) - verify_pad(np.random.randn(2, 3).astype(np.float32), [1, 0, 0, 1], 0.0) - verify_pad(np.random.randn(3, 2).astype(np.float32), [0, 0, 1, 0], 5.0) + verify_pad(np.random.randn(2, 2).astype(np.float32), [0, 1, 0, 0], 'constant', 0.0) + verify_pad(np.random.randn(2, 3).astype(np.float32), [1, 0, 0, 1], 'constant', 0.0) + verify_pad(np.random.randn(3, 2).astype(np.float32), [0, 0, 1, 0], 'constant', 5.0) + verify_pad(np.random.randn(1, 3, 4, 5).astype(np.float32), [0, 0, 1, 1, 0, 0, 1, 1], 'edge') + verify_pad(np.random.randn(1, 3, 4, 
5).astype(np.float32), [0, 0, 1, 1, 0, 0, 1, 1], 'reflect') def verify_reduce_x(name, indata, axis, keepdims): indata = np.array(indata).astype(np.float32) @@ -1266,7 +1278,6 @@ def test_erf(): test_forward_arg_min_max() test_softmax() test_constantfill() - test_pad() test_reduce_max() test_reduce_min() test_reduce_sum() From f98035b093112ce5dfdde518c86b1511830f7172 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sat, 28 Sep 2019 14:43:44 -0700 Subject: [PATCH 05/17] [ARITH] cleanup the indexmod/div on python side (#4028) --- python/tvm/autotvm/task/task.py | 4 ++- python/tvm/expr.py | 20 +++++------ src/pass/rewrite_unsafe_select.cc | 6 ++-- tests/python/relay/test_op_level3.py | 14 ++++---- tests/python/relay/test_op_level5.py | 3 +- .../unittest/test_autotvm_flop_calculator.py | 6 ++-- tests/python/unittest/test_codegen_cuda.py | 5 +-- tests/python/unittest/test_ir_builder.py | 5 +-- tests/python/unittest/test_lang_buffer.py | 28 +++++++-------- .../test_pass_rewrite_unsafe_select.py | 2 +- .../unittest/test_schedule_tensorize.py | 9 ++--- topi/python/topi/arm_cpu/bitserial_conv2d.py | 14 ++++++-- topi/python/topi/arm_cpu/conv2d.py | 35 ++++++++++++------- topi/python/topi/cuda/nms.py | 15 +++++--- topi/python/topi/cuda/rcnn/proposal.py | 13 ++++--- topi/python/topi/cuda/sort.py | 21 ++++++----- topi/python/topi/cuda/ssd/multibox.py | 18 ++++++---- topi/python/topi/nn/bitserial_conv2d.py | 18 +++++----- topi/python/topi/nn/sparse.py | 5 ++- topi/python/topi/util.py | 9 +++-- topi/python/topi/vision/ssd/multibox.py | 4 +-- topi/python/topi/x86/conv2d_avx_1x1.py | 15 ++++++-- tutorials/optimize/opt_gemm.py | 2 +- vta/python/vta/ir_pass.py | 33 +++++++++-------- 24 files changed, 188 insertions(+), 116 deletions(-) diff --git a/python/tvm/autotvm/task/task.py b/python/tvm/autotvm/task/task.py index 901183f46948..e0db27574898 100644 --- a/python/tvm/autotvm/task/task.py +++ b/python/tvm/autotvm/task/task.py @@ -350,7 +350,9 @@ def _count_flop(exp): return 
_count_flop(exp.value) if isinstance(exp, expr.Var): return 0 - if isinstance(exp, (expr.Add, expr.Sub, expr.Mul, expr.Div, expr.Mod, + if isinstance(exp, (expr.Add, expr.Sub, expr.Mul, + expr.Div, expr.Mod, + expr.FloorDiv, expr.FloorMod, expr.Max, expr.Min, expr.EQ, expr.NE, expr.LT, expr.LE, expr.GT, expr.GE, expr.And, expr.Or, expr.Not)): diff --git a/python/tvm/expr.py b/python/tvm/expr.py index a8bd651d6469..5b7c60d819bd 100644 --- a/python/tvm/expr.py +++ b/python/tvm/expr.py @@ -72,23 +72,23 @@ def __rmul__(self, other): return _generic.multiply(other, self) def __div__(self, other): - # if _dtype_is_int(self) and _dtype_is_int(other): - # raise div_ambiguity_error() + if _dtype_is_int(self) and _dtype_is_int(other): + raise div_ambiguity_error() return _generic.divide(self, other) def __rdiv__(self, other): - # if _dtype_is_int(self) and _dtype_is_int(other): - # raise div_ambiguity_error() + if _dtype_is_int(self) and _dtype_is_int(other): + raise div_ambiguity_error() return _generic.divide(other, self) def __truediv__(self, other): - # if _dtype_is_int(self) and _dtype_is_int(other): - # raise div_ambiguity_error() + if _dtype_is_int(self) and _dtype_is_int(other): + raise div_ambiguity_error() return _generic.divide(self, other) def __rtruediv__(self, other): - # if _dtype_is_int(self) and _dtype_is_int(other): - # raise div_ambiguity_error() + if _dtype_is_int(self) and _dtype_is_int(other): + raise div_ambiguity_error() return _generic.divide(other, self) def __floordiv__(self, other): @@ -100,8 +100,8 @@ def __rfloordiv__(self, other): return _generic.divide(other, self) def __mod__(self, other): - # raise div_ambiguity_error() - return _make._OpMod(self, other) + raise div_ambiguity_error() + # return _make._OpMod(self, other) def __neg__(self): neg_one = _api_internal._const(-1, self.dtype) diff --git a/src/pass/rewrite_unsafe_select.cc b/src/pass/rewrite_unsafe_select.cc index 871efcae615d..62db0b414be1 100644 --- 
a/src/pass/rewrite_unsafe_select.cc +++ b/src/pass/rewrite_unsafe_select.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -64,6 +64,8 @@ class UnsafeExprDetector : public ExprFunctor { bool VisitExpr_(const Mul* op) final { return BinaryOp(op); } bool VisitExpr_(const Div* op) final { return BinaryOp(op); } bool VisitExpr_(const Mod* op) final { return BinaryOp(op); } + bool VisitExpr_(const FloorDiv* op) final { return BinaryOp(op); } + bool VisitExpr_(const FloorMod* op) final { return BinaryOp(op); } bool VisitExpr_(const Min* op) final { return BinaryOp(op); } bool VisitExpr_(const Max* op) final { return BinaryOp(op); } bool VisitExpr_(const EQ* op) final { return BinaryOp(op); } diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 424462fbe0c4..2d92489328af 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -373,6 +373,8 @@ def verify_split(dshape, indices_or_sections, ret_type, axis=None): yy = run_infer_type(y.astuple()) assert yy.checked_type == ret_type + idxd = tvm.indexdiv + d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4") axis = tvm.var("axis") verify_split((5, 5, 2, 2), 5, @@ -393,15 +395,15 @@ def verify_split(dshape, indices_or_sections, ret_type, axis=None): axis=0) verify_split((d1, d2, d3, d4), 4, relay.ty.TupleType(tvm.convert([ - relay.ty.TensorType((d1, d2, d3/4, d4), "float32"), - relay.ty.TensorType((d1, d2, d3/4, d4), "float32"), - relay.ty.TensorType((d1, d2, d3/4, d4), "float32"), - relay.ty.TensorType((d1, d2, d3/4, d4), "float32")])), + 
relay.ty.TensorType((d1, d2, idxd(d3, 4), d4), "float32"), + relay.ty.TensorType((d1, d2, idxd(d3, 4), d4), "float32"), + relay.ty.TensorType((d1, d2, idxd(d3, 4), d4), "float32"), + relay.ty.TensorType((d1, d2, idxd(d3, 4), d4), "float32")])), axis=2) verify_split((d1, d2, d3, d4), 2, relay.ty.TupleType(tvm.convert([ - relay.ty.TensorType((d1/2, d2, d3, d4), "float32"), - relay.ty.TensorType((d1/2, d2, d3, d4), "float32")])), + relay.ty.TensorType((idxd(d1, 2), d2, d3, d4), "float32"), + relay.ty.TensorType((idxd(d1, 2), d2, d3, d4), "float32")])), axis=0) verify_split((d1, d2, d3, d4), (2, 4, 7), relay.ty.TupleType(tvm.convert([ diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index f4ac673cf378..8c107351c81a 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -487,8 +487,9 @@ def verify_yolo_reorg(shape, stride, out_shape): assert zz.checked_type == relay.ty.TensorType(out_shape, "float32") n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") + idxd = tvm.indexdiv verify_yolo_reorg((n, c, 20, 20), 10, (n, c*10*10, 2, 2)) - verify_yolo_reorg((n, c, h, w), 2, (n, c*2*2, h/2, w/2)) + verify_yolo_reorg((n, c, h, w), 2, (n, c*2*2, idxd(h, 2), idxd(w, 2))) def test_yolo_reorg(): def verify_yolo_reorg(shape, stride): diff --git a/tests/python/unittest/test_autotvm_flop_calculator.py b/tests/python/unittest/test_autotvm_flop_calculator.py index 54ade9a05267..5cafd02c45bf 100644 --- a/tests/python/unittest/test_autotvm_flop_calculator.py +++ b/tests/python/unittest/test_autotvm_flop_calculator.py @@ -60,14 +60,14 @@ def test_pack_gemm(): k = tvm.reduce_axis((0, L)) bn = 4 - fld = tvm.floordiv - flm = tvm.floormod + idxd = tvm.indexdiv + idxm = tvm.indexmod A_pack = tvm.compute((N // bn, L, bn), lambda i, j, k: A[i * bn + k][j]) B_pack = tvm.compute((M // bn, L, bn), lambda i, j, k: B[i * bn + k][j]) C_pack = tvm.compute((N // bn, M // bn, bn, bn), lambda i, j, ii, jj: 
tvm.sum(A_pack[i, k, ii].astype(acc_dtype) * B_pack[j, k, jj].astype(acc_dtype), axis=[k])) - C = tvm.compute((N, M), lambda i, j: C_pack[fld(i, bn)][fld(j, bn)][flm(i, bn)][flm(j, bn)]) + C = tvm.compute((N, M), lambda i, j: C_pack[idxd(i, bn)][idxd(j, bn)][idxm(i, bn)][idxm(j, bn)]) s = tvm.create_schedule([C.op]) assert compute_flop(s) == 2 * N * L * M diff --git a/tests/python/unittest/test_codegen_cuda.py b/tests/python/unittest/test_codegen_cuda.py index 63aaf2146ca8..aa3a5374ce48 100644 --- a/tests/python/unittest/test_codegen_cuda.py +++ b/tests/python/unittest/test_codegen_cuda.py @@ -37,7 +37,7 @@ def check_cuda(dtype, n, lanes): print("skip because gpu does not support int8") return A = tvm.placeholder((n,), name='A', dtype="%sx%d" % (dtype, lanes)) - B = tvm.compute((n,), lambda i: A[i]+tvm.const(1, A.dtype), name='B') + B = tvm.compute((n,), lambda i: A[i] + tvm.const(1, A.dtype), name='B') s = tvm.create_schedule(B.op) xo, xi = s[B].split(B.op.axis[0], factor=num_thread) s[B].bind(xo, bx) @@ -165,9 +165,10 @@ def test_cuda_shuffle(): print("skip because cuda is not enabled..") return + idxm = tvm.indexmod a = tvm.placeholder((64, ), 'int32') b = tvm.placeholder((64, ), 'int32') - c = tvm.compute((64, ), lambda x: a[x] + b[x - (x % 4) + (3 - x % 4)]) + c = tvm.compute((64, ), lambda x: a[x] + b[x - idxm(x, 4) + (3 - idxm(x, 4))]) sch = tvm.create_schedule(c.op) x = c.op.axis[0] xo, xi = sch[c].split(x, 4) diff --git a/tests/python/unittest/test_ir_builder.py b/tests/python/unittest/test_ir_builder.py index ef58174d4474..c910c62424f0 100644 --- a/tests/python/unittest/test_ir_builder.py +++ b/tests/python/unittest/test_ir_builder.py @@ -109,14 +109,15 @@ def test_gpu(): dtype = "float32" A = tvm.placeholder((n,), name='A') B = tvm.placeholder((n,), name='B') - fld = tvm.floordiv + idxd = tvm.indexdiv + def test_device_ir(A, B, C): n = A.shape[0] max_threads = 32 ib = tvm.ir_builder.create() bx = tvm.thread_axis("blockIdx.x") tx = 
tvm.thread_axis("threadIdx.x") - ib.scope_attr(bx, "thread_extent", fld(n+max_threads-1, max_threads)) + ib.scope_attr(bx, "thread_extent", idxd(n+max_threads-1, max_threads)) ib.scope_attr(tx, "thread_extent", max_threads) idx = bx.var * max_threads + tx.var Aptr = ib.buffer_ptr(A) diff --git a/tests/python/unittest/test_lang_buffer.py b/tests/python/unittest/test_lang_buffer.py index 9ad8b62821cf..32c17452269e 100644 --- a/tests/python/unittest/test_lang_buffer.py +++ b/tests/python/unittest/test_lang_buffer.py @@ -94,31 +94,31 @@ def test_buffer_index_merge_mult_mod(): def assert_simplified_equal(index_simplified, index_direct): assert tvm.ir_pass.Equal(index_simplified, index_direct),\ "index_simplified=%s, index_direct=%s" %(index_simplified, index_direct) - idxdiv = tvm.indexdiv - idxmod = tvm.indexmod + idxd = tvm.indexdiv + idxm = tvm.indexmod # Test Case1 index_simplified = A_stride.vload( - (idxdiv(idxmod(k0, k1), s), idxmod(idxmod(k0, k1), s) + idxdiv(k0, k1) * k1)) + (idxd(idxm(k0, k1), s), idxm(idxm(k0, k1), s) + idxd(k0, k1) * k1)) index_direct = A_stride.vload((0, k0)) assert_simplified_equal(index_simplified, index_direct) # Test Case2 - index_simplified = A.vload((idxdiv(idxmod(k0, idxdiv(k1, s)), n), - idxmod(idxmod(k0, idxdiv(k1, s)), n) + idxmod(k0, k1))) - index_direct = A.vload((0, idxmod(k0, k1) + idxmod(k0, idxdiv(k1, s)))) + index_simplified = A.vload((idxd(idxm(k0, idxd(k1, s)), n), + idxm(idxm(k0, idxd(k1, s)), n) + idxm(k0, k1))) + index_direct = A.vload((0, idxm(k0, k1) + idxm(k0, idxd(k1, s)))) assert_simplified_equal(index_simplified, index_direct) # Test Case3 - index_simplified = A.vload((idxdiv((idxdiv(k0, idxdiv(k1, s)) * idxdiv(k1, s)), n) + - idxdiv(idxmod(k0, idxdiv(k1, s)), n), - idxmod((idxdiv(k0, idxdiv(k1, s)) * idxdiv(k1, s)), n) + - idxmod(idxmod(k0, idxdiv(k1, s)), n))) + index_simplified = A.vload((idxd((idxd(k0, idxd(k1, s)) * idxd(k1, s)), n) + + idxd(idxm(k0, idxd(k1, s)), n), + idxm((idxd(k0, idxd(k1, s)) * idxd(k1, 
s)), n) + + idxm(idxm(k0, idxd(k1, s)), n))) index_direct = A.vload((0, k0)) assert_simplified_equal(index_simplified, index_direct) # Test Case4 (not able to simplify) - index_simplified = A.vload((idxdiv(idxmod(k0, idxdiv(k1, s)), n), - idxmod(idxmod(k0, idxdiv(k1, n)), n) + idxmod(k0, k1))) - index_direct = A.vload((0, idxdiv(idxmod(k0, idxdiv(k1, s)), n) * n + - (idxmod(idxmod(k0, idxdiv(k1, n)), n) + idxmod(k0, k1)))) + index_simplified = A.vload((idxd(idxm(k0, idxd(k1, s)), n), + idxm(idxm(k0, idxd(k1, n)), n) + idxm(k0, k1))) + index_direct = A.vload((0, idxd(idxm(k0, idxd(k1, s)), n) * n + + (idxm(idxm(k0, idxd(k1, n)), n) + idxm(k0, k1)))) assert_simplified_equal(index_simplified, index_direct) diff --git a/tests/python/unittest/test_pass_rewrite_unsafe_select.py b/tests/python/unittest/test_pass_rewrite_unsafe_select.py index b2d73ec00ce8..4c42899be62a 100644 --- a/tests/python/unittest/test_pass_rewrite_unsafe_select.py +++ b/tests/python/unittest/test_pass_rewrite_unsafe_select.py @@ -28,7 +28,7 @@ def test_rewrite_Select(): tvm.expr.Select(i > 1, A[i-1], 1.0) > 0.0, A[i], 0.1) zz = tvm.ir_pass.RewriteUnsafeSelect(tvm.make.Evaluate(z)).value - a = tvm.expr.Select(i>10, y, z) + a = tvm.expr.Select(tvm.floordiv(i, 4) > 10, y, z) aa = tvm.ir_pass.RewriteUnsafeSelect(tvm.make.Evaluate(a)).value assert yy.name == "tvm_if_then_else" assert zz.name == "tvm_if_then_else" diff --git a/tests/python/unittest/test_schedule_tensorize.py b/tests/python/unittest/test_schedule_tensorize.py index 4bad959c2453..59adf0cc7e99 100644 --- a/tests/python/unittest/test_schedule_tensorize.py +++ b/tests/python/unittest/test_schedule_tensorize.py @@ -221,14 +221,15 @@ def check_rfactor_no_reset_multi_reduction(factor, rfactor): # This tests whether algorithm and intrinsics expressions are simplified # as much as possible first and then checked for equality. 
See Issue #696 def test_tensorize_op(): - tdiv = tvm.truncdiv - tmod = tvm.truncmod + idxd = tvm.indexdiv + idxm = tvm.indexmod + def op_intrin(): bh = 9 bw = 9 x = tvm.placeholder((5, 5), name='A') y = tvm.compute((bh, bw), - lambda i, j: x[tdiv(j,3) + tmod(i,3), tmod(j,3)+ tdiv(i,3)]) + lambda i, j: x[idxd(j,3) + idxm(i,3), idxm(j,3)+ idxd(i,3)]) def intrin_func(ins, outs): xx, = ins @@ -239,7 +240,7 @@ def intrin_func(ins, outs): return tvm.decl_tensor_intrin(y.op, intrin_func) A = tvm.placeholder((5, 5), name='A') - B = tvm.compute((9,9), lambda i, j: A[tdiv(j,3) + tmod(i,3), tmod(j,3) + tdiv(i,3)]) + B = tvm.compute((9,9), lambda i, j: A[idxd(j,3) + idxm(i,3), idxm(j,3) + idxd(i,3)]) bt = op_intrin() s = tvm.create_schedule(B.op) diff --git a/topi/python/topi/arm_cpu/bitserial_conv2d.py b/topi/python/topi/arm_cpu/bitserial_conv2d.py index 072c187ee294..9b8360dd1427 100644 --- a/topi/python/topi/arm_cpu/bitserial_conv2d.py +++ b/topi/python/topi/arm_cpu/bitserial_conv2d.py @@ -70,6 +70,9 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, activation_bits, weigh OW = (PAD_W - KW) // WSTR + 1 oshape = (1, OH, OW, CO) + idxd = tvm.indexdiv + idxm = tvm.indexmod + # Pad input channels of weights and data when it is not a multiple of 8 if CI_packed % 8 != 0: CI_PAD = CI_packed % 8 @@ -106,7 +109,8 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, activation_bits, weigh data_q = bitpack(data, activation_bits, pack_axis=3, bit_axis=3, pack_type='uint8') kernel_vec = _kernel_vec_spatial_pack_nhwc(kernel, weight_bits, VC, len(kernel.shape) == 4) - if kernel_vec.shape[-1] % 8 != 0 and CI_PAD != 0: + idxm = tvm.indexmod + if idxm(kernel_vec.shape[-1], 8) != 0 and CI_PAD != 0: kernel_vec = pad(kernel_vec, [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, CI_PAD]) N, H, W, IB, CI = data_q.shape @@ -147,8 +151,12 @@ def _unipolar_conv(n, h, w, co, vh, vw, vc): else: conv_vec = tvm.compute(ovshape, _bipolar_conv, name='conv_vec', tag='bipolar') - conv = 
tvm.compute(oshape, lambda n, h, w, co: - conv_vec[n][h//VH][w//VW][co//VC][h%VH][w%VW][co%VC].astype(out_dtype), + + conv = tvm.compute(oshape, + lambda n, h, w, co: + conv_vec[n, + idxd(h, VH), idxd(w, VW), idxd(co, VC), + idxm(h, VH), idxm(w, VW), idxm(co, VC)].astype(out_dtype), name='conv', tag='spatial_bitserial_conv_nhwc') return conv diff --git a/topi/python/topi/arm_cpu/conv2d.py b/topi/python/topi/arm_cpu/conv2d.py index 73a97d2bb33c..f5cbbf0f7bad 100644 --- a/topi/python/topi/arm_cpu/conv2d.py +++ b/topi/python/topi/arm_cpu/conv2d.py @@ -171,6 +171,9 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt assert KH == 3 and KW == 3 and HSTR == 1 and WSTR == 1 data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad") + idxd = tvm.indexdiv + idxm = tvm.indexmod + r = KW m = tile_size alpha = m + r - 1 @@ -190,10 +193,11 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt VK = cfg['tile_k'].size[-1] # pack input tile - input_tile = tvm.compute((C, P // VP, alpha, alpha, VP), + input_tile = tvm.compute((C, idxd(P, VP), alpha, alpha, VP), lambda c, b, eps, nu, bb: - data_pad[(b*VP+bb) // (nH*nW)][c][(b*VP+bb) // nW % nH * m + eps] - [(b*VP+bb) % nW * m + nu], + data_pad[idxd(b*VP + bb, nH*nW), c, + idxm(idxd(b*VP + bb, nW), nH) * m + eps, + idxm(b*VP + bb, nW) * m + nu], name='d') # transform kernel @@ -202,22 +206,22 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt else: r_kh = tvm.reduce_axis((0, KH), 'r_kh') r_kw = tvm.reduce_axis((0, KW), 'r_kw') - U = tvm.compute((alpha, alpha, K // VK, C, VK), lambda eps, nu, k, c, kk: + U = tvm.compute((alpha, alpha, idxd(K, VK), C, VK), lambda eps, nu, k, c, kk: tvm.sum(kernel[k * VK + kk][c][r_kh][r_kw].astype(out_dtype) * G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]), name='U') # transform image r_eps = tvm.reduce_axis((0, alpha), 'r_eps') r_nu = tvm.reduce_axis((0, alpha), 'r_nu') - V = tvm.compute((alpha, alpha, P // 
VP, C, VP), lambda eps, nu, b, c, bb: + V = tvm.compute((alpha, alpha, idxd(P, VP), C, VP), lambda eps, nu, b, c, bb: tvm.sum(input_tile[c][b][r_eps][r_nu][bb].astype(out_dtype) * B[r_eps][eps] * B[r_nu][nu], axis=[r_eps, r_nu]), name='V') # batch gemm c = tvm.reduce_axis((0, C), name='c') M = tvm.compute((alpha, alpha, K, P), lambda eps, nu, k, b: - tvm.sum(U[eps][nu][k // VK][c][k % VK] * - V[eps][nu][b // VP][c][b % VP], axis=c), name='M') + tvm.sum(U[eps][nu][idxd(k, VK)][c][idxm(k, VK)] * + V[eps][nu][idxd(b, VP)][c][idxm(b, VP)], axis=c), name='M') # inverse transform r_eps = tvm.reduce_axis((0, alpha), 'r_eps') @@ -228,7 +232,8 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt # unpack output output = tvm.compute((N, K, H, W), lambda n, k, h, w: - Y[k][n * nH * nW + (h//m) * nW + w//m][h % m][w % m], + Y[k][n * nH * nW + idxd(h, m) * nW + idxd(w, m), + idxm(h, m), idxm(w, m)], name='output', tag='winograd_conv2d_output') # we have to manually assign effective GFLOP for winograd @@ -517,6 +522,8 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F): N, CI, H, W = get_const_tuple(data.shape) CO, _, KH, KW = get_const_tuple(kernel.shape) + idxd = tvm.indexdiv + if groups == 1: # query config of this workload workload = autotvm.task.args_to_workload( @@ -535,7 +542,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F): # Store the same config for the altered operator (workload) new_data = data - new_kernel = tvm.placeholder((CO // VC, CI, KH, KW, VC), dtype=kernel.dtype) + new_kernel = tvm.placeholder((idxd(CO, VC), CI, KH, KW, VC), dtype=kernel.dtype) new_workload = autotvm.task.args_to_workload( [new_data, new_kernel, strides, padding, dilation, 'NCHW', out_dtype], conv2d) dispatch_ctx.update(target, new_workload, cfg) @@ -553,7 +560,9 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F): weight = F.nn.contrib_conv2d_winograd_weight_transform(copy_inputs[1], tile_size=tile_size) weight = F.reshape(weight, - 
newshape=(KH + tile_size - 1, KW + tile_size - 1, CO // VC, VC, CI)) + newshape=(KH + tile_size - 1, + KW + tile_size - 1, + idxd(CO, VC), VC, CI)) weight = F.transpose(weight, axes=[0, 1, 2, 4, 3]) copy_inputs[1] = weight @@ -561,7 +570,9 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F): # Store the same config for the altered operator (workload) new_data = data - new_weight = tvm.placeholder((KH + tile_size - 1, KH + tile_size -1, CO // VC, CI, VC), + new_weight = tvm.placeholder((KH + tile_size - 1, + KH + tile_size -1, + idxd(CO, VC), CI, VC), kernel.dtype) new_workload = autotvm.task.args_to_workload( [new_data, new_weight, strides, padding, dilation, @@ -612,7 +623,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F): # Store the same config for the altered operator (workload) new_data = data CO, M, KH, KW = get_const_tuple(kernel.shape) - new_kernel = tvm.placeholder((CO // VC, M, KH, KW, VC), dtype=kernel.dtype) + new_kernel = tvm.placeholder((idxd(CO, VC), M, KH, KW, VC), dtype=kernel.dtype) new_workload = autotvm.task.args_to_workload( [new_data, new_kernel, strides, padding, dilation, out_dtype], depthwise_conv2d_nchw) diff --git a/topi/python/topi/cuda/nms.py b/topi/python/topi/cuda/nms.py index 6ff8a79d3630..33fc7249802b 100644 --- a/topi/python/topi/cuda/nms.py +++ b/topi/python/topi/cuda/nms.py @@ -243,14 +243,16 @@ def get_valid_counts_downsweep(data, idx_in, partial, idx): ib.scope_attr(bx, "thread_extent", nthread_bx) tid = bx * max_threads + tx new_range = num_anchors // elem_per_thread + 1 + idxd = tvm.indexdiv + idxm = tvm.indexmod # Scan: Downsweep: with ib. 
if_scope(tid < batch_size * num_anchors): - i = tid // num_anchors # number of batches - j = tid % num_anchors # number of anchors + i = idxd(tid, num_anchors) # number of batches + j = idxm(tid, num_anchors) # number of anchors with ib.if_scope(j < elem_per_thread): idx[tid] = idx_in[tid] with ib.else_scope(): - idx[tid] = idx_in[tid] + partial[i * new_range + j // elem_per_thread - 1] + idx[tid] = idx_in[tid] + partial[i * new_range + idxd(j, elem_per_thread) - 1] return ib.get() @@ -303,9 +305,12 @@ def get_valid_counts_ir(data, flag, idx, valid_count, out): ib.scope_attr(bx, "thread_extent", nthread_bx) tid = bx * max_threads + tx + idxd = tvm.indexdiv + idxm = tvm.indexmod + with ib.if_scope(tid < batch_size * num_anchors): - i = tid // num_anchors - j = tid % num_anchors + i = idxd(tid, num_anchors) + j = idxm(tid, num_anchors) base_idx = i * num_anchors * elem_length with ib.if_scope(flag[tid] > 0): with ib.for_range(0, elem_length) as k: diff --git a/topi/python/topi/cuda/rcnn/proposal.py b/topi/python/topi/cuda/rcnn/proposal.py index 06226d1b40b9..54f73a10c17e 100644 --- a/topi/python/topi/cuda/rcnn/proposal.py +++ b/topi/python/topi/cuda/rcnn/proposal.py @@ -79,10 +79,13 @@ def predict_bbox_ir(cls_prob_buf, bbox_pred_buf, im_info_buf, out_buf, scales, r p_im_info = ib.buffer_ptr(im_info_buf) p_out = ib.buffer_ptr(out_buf) + idxm = tvm.indexmod + idxd = tvm.indexdiv + with ib.if_scope(tid < batch * height * width): - w = tid % width - h = (tid // width) % height - b = tid // width // height + w = idxm(tid, width) + h = idxm(idxd(tid, width), height) + b = idxd(idxd(tid, width), height) for k in range(num_anchors): out_index = tid * num_anchors + k @@ -163,6 +166,8 @@ def argsort_ir(data_buf, out_index_buf): temp_data = ib.allocate("float32", (1,), name="temp_data", scope="local") temp_index = ib.allocate("int32", (1,), name="temp_index", scope="local") + idxm = tvm.indexmod + with ib.for_range(0, batch, for_type="unroll") as b: start = b * num_bbox for i 
in range(2): @@ -170,7 +175,7 @@ def argsort_ir(data_buf, out_index_buf): with ib.if_scope(bbox_id < num_bbox): index_out[start + bbox_id] = bbox_id with ib.for_range(0, num_bbox) as k: - offset = start + 2 * tid + (k % 2) + offset = start + 2 * tid + idxm(k, 2) with ib.if_scope( tvm.all(offset + 1 < num_bbox, p_data[offset] < p_data[offset + 1])): temp_data[0] = p_data[offset] diff --git a/topi/python/topi/cuda/sort.py b/topi/python/topi/cuda/sort.py index c45465e31624..b02c14b47e60 100644 --- a/topi/python/topi/cuda/sort.py +++ b/topi/python/topi/cuda/sort.py @@ -115,6 +115,8 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None): ib.emit(tvm.make.Call(None, 'tvm_storage_sync', tvm.convert(['shared']), tvm.expr.Call.Intrinsic, None, 0)) + idxd = tvm.indexdiv + idxm = tvm.indexmod with ib.for_range(0, axis_mul_before) as i: with ib.for_range(0, axis_mul_after) as j: @@ -122,13 +124,13 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None): base_idx = i * shape[axis] * axis_mul_after + j # OddEvenTransposeSort with ib.for_range(0, current_sort_num) as k: - with ib.if_scope(tid < (current_sort_num + 1) // 2): - offset = base_idx + (2 * tid + (k % 2)) * axis_mul_after + with ib.if_scope(tid < idxd(current_sort_num + 1, 2)): + offset = base_idx + (2 * tid + idxm(k, 2)) * axis_mul_after if is_ascend: - cond = tvm.all(2 * tid + (k % 2) + 1 < current_sort_num, + cond = tvm.all(2 * tid + idxm(k, 2) + 1 < current_sort_num, values_out[offset] > values_out[offset + axis_mul_after]) else: - cond = tvm.all(2 * tid + (k % 2) + 1 < current_sort_num, + cond = tvm.all(2 * tid + idxm(k, 2) + 1 < current_sort_num, values_out[offset] < values_out[offset + axis_mul_after]) with ib.if_scope(cond): temp_data[0] = values_out[offset] @@ -199,6 +201,9 @@ def sort_nms_ir(data, valid_count, output, axis, is_ascend): temp_index = ib.allocate("int32", (1,), name="temp_index", scope="local") is_ascend = tvm.make.node("IntImm", dtype="int32", value=is_ascend) + idxd = 
tvm.indexdiv + idxm = tvm.indexmod + with ib.for_range(0, axis_mul_before) as i: with ib.for_range(0, axis_mul_after) as j: current_sort_num = valid_count[i * axis_mul_after + j] @@ -207,10 +212,10 @@ def sort_nms_ir(data, valid_count, output, axis, is_ascend): output[base_idx + tid * axis_mul_after] = tid # OddEvenTransposeSort with ib.for_range(0, current_sort_num) as k: - with ib.if_scope(tid < (current_sort_num + 1) // 2): - offset = base_idx + (2 * tid + (k % 2)) * axis_mul_after + with ib.if_scope(tid < idxd(current_sort_num + 1, 2)): + offset = base_idx + (2 * tid + idxm(k, 2)) * axis_mul_after with ib.if_scope(tvm.all(is_ascend == 1, \ - 2 * tid + (k % 2) + 1 < current_sort_num, \ + 2 * tid + idxm(k, 2) + 1 < current_sort_num, \ data[offset] > data[offset + axis_mul_after])): temp_data[0] = data[offset] data[offset] = data[offset + axis_mul_after] @@ -219,7 +224,7 @@ def sort_nms_ir(data, valid_count, output, axis, is_ascend): output[offset] = output[offset + axis_mul_after] output[offset + axis_mul_after] = temp_index[0] with ib.if_scope(tvm.all(is_ascend == 0, \ - 2 * tid + (k % 2) + 1 < current_sort_num, \ + 2 * tid + idxm(k, 2) + 1 < current_sort_num, \ data[offset] < data[offset + axis_mul_after])): temp_data[0] = data[offset] data[offset] = data[offset + axis_mul_after] diff --git a/topi/python/topi/cuda/ssd/multibox.py b/topi/python/topi/cuda/ssd/multibox.py index 03fa9995fa7d..e1af4365520e 100644 --- a/topi/python/topi/cuda/ssd/multibox.py +++ b/topi/python/topi/cuda/ssd/multibox.py @@ -95,8 +95,8 @@ def multibox_prior_ir(data, out, sizes, ratios, steps, offsets): for k in range(num_sizes + num_ratios - 1): w = if_then_else(k < num_sizes, - size_ratio_concat[k] * in_height / in_width / 2.0, - size_ratio_concat[0] * in_height / in_width * + float(size_ratio_concat[k]) * in_height / in_width / 2.0, + float(size_ratio_concat[0]) * in_height / in_width * math.sqrt(size_ratio_concat[k + 1]) / 2.0) h = if_then_else( k < num_sizes, size_ratio_concat[k] / 
2.0, @@ -204,10 +204,12 @@ def transform_loc_pre(cls_prob, valid_count, temp_valid_count, temp_cls_id, temp ib.scope_attr(tx, "thread_extent", nthread_tx) ib.scope_attr(bx, "thread_extent", nthread_bx) tid = bx * max_threads + tx + idxd = tvm.indexdiv + idxm = tvm.indexmod with ib.if_scope(tid < batch_size * num_anchors): - i = tid / num_anchors - j = tid % num_anchors + i = idxd(tid, num_anchors) + j = idxm(tid, num_anchors) valid_count[i] = 0 score[tid] = -1.0 cls_id[tid] = 0 @@ -314,9 +316,13 @@ def transform_loc(loc, loc_base_idx, anchor, anchor_base_idx, clip, vx, vy, vw, ib.scope_attr(bx, "thread_extent", nthread_bx) tid = bx * max_threads + tx + idxd = tvm.indexdiv + idxm = tvm.indexmod + with ib.if_scope(tid < batch_size * num_anchors): - i = tid // num_anchors - j = tid % num_anchors + i = idxd(tid, num_anchors) + j = idxm(tid, num_anchors) + with ib.if_scope(cls_id[tid] > 0): with ib.if_scope(tid == 0): out_base_idx = i * num_anchors * 6 diff --git a/topi/python/topi/nn/bitserial_conv2d.py b/topi/python/topi/nn/bitserial_conv2d.py index 2faabf2bbf89..932c141450ac 100644 --- a/topi/python/topi/nn/bitserial_conv2d.py +++ b/topi/python/topi/nn/bitserial_conv2d.py @@ -313,13 +313,14 @@ def _conv(n, co, h, w, vh, vw, vc): axis=[ci, dh, dw, b1, b2]) conv = tvm.compute(ovshape, _conv, name='conv_out') - idxdiv = tvm.indexdiv - idxmod = tvm.indexmod + idxd = tvm.indexdiv + idxm = tvm.indexmod return tvm.compute( oshape, lambda n, co, h, w: - conv[n][idxdiv(co, VC)][idxdiv(h, VH)][idxdiv( - w, VW)][idxmod(h, VH)][idxmod(w, VW)][idxmod(co, VC)], + conv[n, + idxd(co, VC), idxd(h, VH), idxd(w, VW), + idxm(h, VH), idxm(w, VW), idxm(co, VC)], name='conv_vec', tag='spatial_bitserial_conv_nchw') @autotvm.register_topi_compute(bitserial_conv2d_nhwc, 'cpu', 'direct') @@ -419,12 +420,13 @@ def _conv(n, h, w, co, vh, vw, vc): conv = tvm.compute(ovshape, _conv, name='conv') - idxdiv = tvm.indexdiv - idxmod = tvm.indexmod + idxd = tvm.indexdiv + idxm = tvm.indexmod return 
tvm.compute( oshape, lambda n, h, w, co: - conv[n][idxdiv(h, VH)][idxdiv(w, VW)][idxdiv( - co, VC)][idxmod(h, VH)][idxmod(w, VW)][idxmod(co, VC)], + conv[n, + idxd(h, VH), idxd(w, VW), idxd(co, VC), + idxm(h, VH), idxm(w, VW), idxm(co, VC)], name='output_unpack', tag='spatial_bitserial_conv_nhwc') @tvm.target.generic_func diff --git a/topi/python/topi/nn/sparse.py b/topi/python/topi/nn/sparse.py index 11116b2e6d2c..584126ea2015 100644 --- a/topi/python/topi/nn/sparse.py +++ b/topi/python/topi/nn/sparse.py @@ -94,12 +94,15 @@ def _compute_block(i, nb_j, j): x_val = data[i, bs_c * block_j + c] return tvm.sum(block_ij_val * x_val, axis=[elem_idx, c]) + idxd = tvm.indexdiv + idxm = tvm.indexmod + bsrmm_block = tvm.compute( (m, num_blocks, bs_r), _compute_block, tag="sparse_dense_bsrmm_block") return tvm.compute( (m, num_blocks * bs_r), - lambda m, n: bsrmm_block[m, n // bs_r, n % bs_r], + lambda m, n: bsrmm_block[m, idxd(n, bs_r), idxm(n, bs_r)], tag="sparse_dense_bsrmm") @tvm.target.generic_func diff --git a/topi/python/topi/util.py b/topi/python/topi/util.py index 6de916c8a106..1bf3a102a88f 100644 --- a/topi/python/topi/util.py +++ b/topi/python/topi/util.py @@ -232,10 +232,12 @@ def unravel_index(idx, shape): indices : tuple of int or tvm.expr.IntImm Corresponding coordinate of the 1D index """ + idxd = tvm.indexdiv + idxm = tvm.indexmod indices = [] for i in range(len(shape) - 1, -1, -1): - indices.append(idx % shape[i]) - idx = idx // shape[i] + indices.append(idxm(idx, shape[i])) + idx = idxd(idx, shape[i]) indices = indices[::-1] return indices @@ -257,12 +259,13 @@ def const_matrix(matrix, name="const_matrix"): """ row, col = matrix.shape dtype = str(matrix.dtype) + idxm = tvm.indexmod def select_array(i, j): now = tvm.const(0.0, dtype) for ii in range(row): for jj in range(col): - now = tvm.expr.Select(tvm.all(i % row == ii, j % col == jj), + now = tvm.expr.Select(tvm.all(idxm(i, row) == ii, idxm(j, col) == jj), tvm.const(matrix[ii][jj], dtype), now) return 
now diff --git a/topi/python/topi/vision/ssd/multibox.py b/topi/python/topi/vision/ssd/multibox.py index ca1b4a9eb268..135315b3f086 100644 --- a/topi/python/topi/vision/ssd/multibox.py +++ b/topi/python/topi/vision/ssd/multibox.py @@ -73,10 +73,10 @@ def hybrid_multibox_prior(data, sizes, ratios, steps, offsets): center_w = (j + offset_w) * steps_w for k in const_range(num_sizes + num_ratios - 1): if k < num_sizes: - w = sizes[k] * in_height / in_width / 2.0 + w = float32(sizes[k] * in_height) / in_width / 2.0 h = sizes[k] / 2.0 else: - w = sizes[0] * in_height / in_width \ + w = float32(sizes[0] * in_height) / in_width \ * sqrt(ratios[k - num_sizes + 1] * 1.0) / 2.0 h = sizes[0] / sqrt(ratios[k - num_sizes + 1] * 1.0) / 2.0 count = i * in_width * (num_sizes + num_ratios - 1) \ diff --git a/topi/python/topi/x86/conv2d_avx_1x1.py b/topi/python/topi/x86/conv2d_avx_1x1.py index 3d0978cc94a5..6e36e93b9806 100644 --- a/topi/python/topi/x86/conv2d_avx_1x1.py +++ b/topi/python/topi/x86/conv2d_avx_1x1.py @@ -309,8 +309,15 @@ def _declaration_conv_nhwc_pack(cfg, Input, Filter, stride, padding, dilation, o # packing the Filter to let memory access be consecutive for AVX512 intrinsic # Done in pre-compute stage - packw_shape = (kernel_h, kernel_w, num_filter/16, 16*(channel/4), 4) - PackW = tvm.compute(packw_shape, lambda a, b, c, d, e: Filter[a][b][c*16+d%16][d/16*4+e], + idxd = tvm.indexdiv + idxm = tvm.indexmod + + packw_shape = (kernel_h, kernel_w, idxd(num_filter, 16), 16 * idxd(channel, 4), 4) + PackW = tvm.compute(packw_shape, + lambda a, b, c, d, e: + Filter[a, b, + c*16 + idxm(d, 16), + idxd(d, 16) * 4 + e], name="packed_filter") rc = tvm.reduce_axis((0, in_channel), name='rc') @@ -321,7 +328,9 @@ def _declaration_conv_nhwc_pack(cfg, Input, Filter, stride, padding, dilation, o lambda nn, yy, xx, ff: tvm.sum( PaddedInput[nn, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rc].astype(out_dtype) * - PackW[ry, rx, ff/16, (rc/4)*16+ff%16, 
rc%4].astype(out_dtype), axis=[ry, rx, rc]), + PackW[ry, rx, idxd(ff, 16), + idxd(rc, 4) * 16 + idxm(ff, 16), + idxm(rc, 4)].astype(out_dtype), axis=[ry, rx, rc]), name="Conv2d_1x1_Output_int8", tag="conv2d_nhwc_pack_int8") return Output diff --git a/tutorials/optimize/opt_gemm.py b/tutorials/optimize/opt_gemm.py index 0fb73ec40dbb..a23589a4ab19 100644 --- a/tutorials/optimize/opt_gemm.py +++ b/tutorials/optimize/opt_gemm.py @@ -247,7 +247,7 @@ # We have to re-write the algorithm slightly. packedB = tvm.compute((N / bn, K, bn), lambda x, y, z: B[y, x * bn + z], name='packedB') C = tvm.compute((M, N), - lambda x, y: tvm.sum(A[x, k] * packedB[y / bn, k, y % bn], axis=k), + lambda x, y: tvm.sum(A[x, k] * packedB[y // bn, k, tvm.indexmod(y, bn)], axis=k), name = 'C') s = tvm.create_schedule(C.op) diff --git a/vta/python/vta/ir_pass.py b/vta/python/vta/ir_pass.py index 06a1975b0008..12ef7daac731 100644 --- a/vta/python/vta/ir_pass.py +++ b/vta/python/vta/ir_pass.py @@ -335,6 +335,9 @@ def inject_dma_intrin(stmt_in): Transformed statement """ env = get_env() + idxd = tvm.indexdiv + idxm = tvm.indexmod + def _check_compact(buf): ndim = len(buf.shape) size = tvm.const(1, buf.shape[0].dtype) @@ -369,7 +372,7 @@ def _fold_buffer_dim(buf, scope, elem_block): x_size = 1 x_stride = buf.strides[ndim - base] next_base = base - if not util.equal_const_int(x_stride % elem_block, 0): + if not util.equal_const_int(idxm(x_stride, elem_block), 0): raise RuntimeError( "scope %s need to have block=%d, shape=%s, strides=%s" % ( scope, elem_block, buf.shape, buf.strides)) @@ -394,7 +397,7 @@ def _get_2d_pattern(buf, elem_width, elem_bytes, dtype, scope, allow_fold): raise RuntimeError("Expect buffer type to be %s instead of %s" % (dtype, buf.dtype)) shape, strides = buf.shape, buf.strides - if not util.equal_const_int(buf.elem_offset % elem_block, 0): + if not util.equal_const_int(idxm(buf.elem_offset, elem_block), 0): raise RuntimeError("scope %s need to have block=%d" % (scope, 
elem_block)) if allow_fold: shape, strides = _fold_buffer_dim(buf, scope, elem_block) @@ -421,7 +424,7 @@ def raise_error(): x_size = 1 x_stride = 1 y_size = 1 - return x_size, y_size, x_stride, buf.elem_offset / elem_block + return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block) if not util.equal_const_int(strides[-2] - elem_block, 0): raise_error() @@ -429,15 +432,15 @@ def raise_error(): x_size = shape[-2] x_stride = shape[-2] y_size = 1 - return x_size, y_size, x_stride, buf.elem_offset / elem_block - if not util.equal_const_int(strides[-3] % elem_block, 0): + return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block) + if not util.equal_const_int(idxm(strides[-3], elem_block), 0): raise_error() if ndim == 3: x_size = shape[-2] - x_stride = strides[-3] / elem_block + x_stride = idxd(strides[-3], elem_block) y_size = shape[-3] - return x_size, y_size, x_stride, buf.elem_offset / elem_block + return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block) else: if not util.equal_const_int(strides[-1], 1): @@ -451,7 +454,7 @@ def raise_error(): x_size = 1 x_stride = 1 y_size = 1 - return x_size, y_size, x_stride, buf.elem_offset / elem_block + return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block) if not util.equal_const_int(strides[-3], elem_block): raise_error() @@ -459,15 +462,15 @@ def raise_error(): x_size = shape[-3] x_stride = shape[-3] y_size = 1 - return x_size, y_size, x_stride, buf.elem_offset / elem_block - if not util.equal_const_int(strides[-4] % elem_block, 0): + return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block) + if not util.equal_const_int(idxm(strides[-4], elem_block), 0): raise_error() if ndim == 4: x_size = shape[-3] - x_stride = strides[-4] / elem_block + x_stride = idxd(strides[-4], elem_block) y_size = shape[-4] - return x_size, y_size, x_stride, buf.elem_offset / elem_block + return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block) raise_error() @@ -765,6 +768,8 @@ def 
inject_alu_intrin(stmt_in): Transformed statement """ env = get_env() + idxm = tvm.indexmod + def _do_fold(stmt): def _equal(x, y): return tvm.ir_pass.Equal(tvm.ir_pass.Simplify(x - y), 0) @@ -910,10 +915,10 @@ def _flatten_loop(src_coeff, dst_coeff, extents): assert len(extents) != 0 assert tvm.ir_pass.Equal( tvm.ir_pass.Simplify( - src_coeff[-1] % (env.BATCH * env.BLOCK_OUT)), 0) + idxm(src_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0) assert tvm.ir_pass.Equal( tvm.ir_pass.Simplify( - dst_coeff[-1] % (env.BATCH * env.BLOCK_OUT)), 0) + idxm(dst_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0) assert tvm.ir_pass.Equal(src_coeff[-2], 1) assert tvm.ir_pass.Equal(dst_coeff[-2], 1) if env.BATCH > 1: From 8f18cc443c12593ca66a22269cb2c697f2624955 Mon Sep 17 00:00:00 2001 From: Neo Chien Date: Sun, 29 Sep 2019 11:20:34 +0800 Subject: [PATCH 06/17] [AUTOTVM][DOCS] Add a link to the defining network description of auto-tuning tutorial (#4023) * [AUTOTVM][DOCS] Add a link to autoTVM tutorial to direct the details of building NN with relay * [AUTOTVM][DOCS] Add a link to autoTVM tutorial to direct the details of building NN with relay --- tutorials/autotvm/tune_relay_x86.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tutorials/autotvm/tune_relay_x86.py b/tutorials/autotvm/tune_relay_x86.py index 22a31b79bd0d..93a073170388 100644 --- a/tutorials/autotvm/tune_relay_x86.py +++ b/tutorials/autotvm/tune_relay_x86.py @@ -37,11 +37,13 @@ # Define network # -------------- # First we need to define the network in relay frontend API. -# We can load some pre-defined network from :code:`relay.testing`. +# We can either load some pre-defined network from :code:`relay.testing` +# or building :any:`relay.testing.resnet` with relay. # We can also load models from MXNet, ONNX and TensorFlow. # # In this tutorial, we choose resnet-18 as tuning example. 
+ def get_network(name, batch_size): """Get the symbol definition and random weight of a network""" input_shape = (batch_size, 3, 224, 224) @@ -73,6 +75,7 @@ def get_network(name, batch_size): return mod, params, input_shape, output_shape + # Replace "llvm" with the correct target of your CPU. # For example, for AWS EC2 c5 instance with Intel Xeon # Platinum 8000 series, the target should be "llvm -mcpu=skylake-avx512". @@ -121,6 +124,7 @@ def get_network(name, batch_size): ), } + # You can skip the implementation of this function for this tutorial. def tune_kernels(tasks, measure_option, @@ -165,6 +169,7 @@ def tune_kernels(tasks, autotvm.callback.progress_bar(n_trial, prefix=prefix), autotvm.callback.log_to_file(log_filename)]) + # Use graph tuner to achieve graph level optimal schedules # Set use_DP=False if it takes too long to finish. def tune_graph(graph, dshape, records, opt_sch_file, use_DP=True): From 9b46ace16a4ebd30eaa81ede0e230caa78b5d51b Mon Sep 17 00:00:00 2001 From: egolearner <45122959+egolearner@users.noreply.github.com> Date: Mon, 30 Sep 2019 00:21:18 +0800 Subject: [PATCH 07/17] make tvm compilable by gcc 4.9.2 (#4032) please see https://stackoverflow.com/a/26949099 --- src/op/tensorize.cc | 2 +- src/relay/pass/well_formed.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/op/tensorize.cc b/src/op/tensorize.cc index 476a0a5d67fc..230472f2ddee 100644 --- a/src/op/tensorize.cc +++ b/src/op/tensorize.cc @@ -522,7 +522,7 @@ TVM_REGISTER_API("test.op.MatchTensorizeBody") CHECK(stage->op.as()); *ret = MatchTensorizeBody(stage->op.as(), stage, - {}, + {{}}, as_unordered_map(out_dom), as_unordered_map(in_region), intrin, diff --git a/src/relay/pass/well_formed.cc b/src/relay/pass/well_formed.cc index 27b31deb4f96..99f34bdd1d2a 100644 --- a/src/relay/pass/well_formed.cc +++ b/src/relay/pass/well_formed.cc @@ -43,7 +43,7 @@ class WellFormedChecker : private ExprVisitor, PatternVisitor { struct Scope { WellFormedChecker* wfc; 
explicit Scope(WellFormedChecker* wfc) : wfc(wfc) { - wfc->scope.push_back({}); + wfc->scope.push_back({{}}); } ~Scope() { CHECK_GE(wfc->scope.size(), 0); From 2dac17d8b694758845ddc1aa00a3289b04af4b15 Mon Sep 17 00:00:00 2001 From: Logan Weber <36520469+weberlo@users.noreply.github.com> Date: Sun, 29 Sep 2019 16:48:10 -0700 Subject: [PATCH 08/17] [Relay] Move prelude to text format (#3939) * Fix parser * Doc fix * Add module utility functions necessary for prelude * Implement prelude in text format * Remove programmatically constructed prelude defs * Fix 0-arity type conses in pretty printer and test * Make prelude loading backwards-compatible * Fix patterns * Improve some prelude defs * Fix `ImportFromStd` It needs to also follow the "add unchecked, add checked" pattern * Lint roller * Woops * Address feedback * Fix `test_list_constructor` VM test * Fix `test_adt.py` failures --- include/tvm/expr.h | 2 +- include/tvm/relay/expr_functor.h | 3 +- include/tvm/relay/module.h | 44 +- include/tvm/relay/pattern_functor.h | 3 +- python/tvm/relay/_parser.py | 121 +- python/tvm/relay/grammar/Relay.g4 | 55 +- python/tvm/relay/grammar/py3/RelayParser.py | 1038 ++++++++++-------- python/tvm/relay/grammar/py3/RelayVisitor.py | 24 +- python/tvm/relay/module.py | 22 +- python/tvm/relay/prelude.py | 533 +-------- python/tvm/relay/std/prelude.rly | 299 ++++- src/relay/ir/base.cc | 4 +- src/relay/ir/module.cc | 98 +- src/relay/ir/pretty_printer.cc | 19 +- tests/python/relay/test_ir_parser.py | 3 +- 15 files changed, 1166 insertions(+), 1102 deletions(-) diff --git a/include/tvm/expr.h b/include/tvm/expr.h index 07cfbc7791da..201a2b485aa6 100644 --- a/include/tvm/expr.h +++ b/include/tvm/expr.h @@ -92,7 +92,7 @@ class Var; /*! * \brief A variable node in the IR. * - * A vraible is uniquely identified by its address. + * A variable is uniquely identified by its address. 
* * Each variable is only binded once in the following nodes: * - Allocate diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index 7e6fb7f2a5fe..e0d940c5d1a5 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -117,7 +117,8 @@ class ExprFunctor { virtual R VisitExpr_(const ConstructorNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const MatchNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExprDefault_(const Node* op, Args...) { - throw Error(std::string("Do not have a default for ") + op->type_key()); + LOG(FATAL) << "Do not have a default for " << op->type_key(); + throw; } private: diff --git a/include/tvm/relay/module.h b/include/tvm/relay/module.h index ee9b4873d28a..8b17020a1132 100644 --- a/include/tvm/relay/module.h +++ b/include/tvm/relay/module.h @@ -87,21 +87,34 @@ class ModuleNode : public RelayNode { */ TVM_DLL void Add(const GlobalVar& var, const Function& func, bool update = false); + /*! + * \brief Add a function to the global environment. + * \param var The name of the global function. + * \param func The function. + * + * It does not do type inference as Add does. + */ + TVM_DLL void AddUnchecked(const GlobalVar& var, const Function& func); + /*! * \brief Add a type-level definition to the global environment. * \param var The var of the global type definition. - * \param type The type definition. + * \param type The ADT. + * \param update Controls whether you can replace a definition in the + * environment. */ - TVM_DLL void AddDef(const GlobalTypeVar& var, const TypeData& type); + TVM_DLL void AddDef(const GlobalTypeVar& var, const TypeData& type, bool update = false); /*! - * \brief Add a function to the global environment. + * \brief Add a type definition to the global environment. * \param var The name of the global function. - * \param func The function. + * \param type The ADT. 
+ * \param update Controls whether you can replace a definition in the + * environment. * - * It does not do type inference as Add does. + * It does not do type inference as AddDef does. */ - TVM_DLL void AddUnchecked(const GlobalVar& var, const Function& func); + TVM_DLL void AddDefUnchecked(const GlobalTypeVar& var, const TypeData& type, bool update = false); /*! * \brief Update a function in the global environment. @@ -110,6 +123,13 @@ class ModuleNode : public RelayNode { */ TVM_DLL void Update(const GlobalVar& var, const Function& func); + /*! + * \brief Update a type definition in the global environment. + * \param var The name of the global type definition to update. + * \param type The new ADT. + */ + TVM_DLL void UpdateDef(const GlobalTypeVar& var, const TypeData& type); + /*! * \brief Remove a function from the global environment. * \param var The name of the global function to update. @@ -130,6 +150,12 @@ class ModuleNode : public RelayNode { */ TVM_DLL GlobalVar GetGlobalVar(const std::string& str) const; + /*! + * \brief Collect all global vars defined in this module. + * \returns An array of global vars + */ + tvm::Array GetGlobalVars() const; + /*! * \brief Look up a global function by its name. * \param str The unique string specifying the global variable. @@ -137,6 +163,12 @@ class ModuleNode : public RelayNode { */ TVM_DLL GlobalTypeVar GetGlobalTypeVar(const std::string& str) const; + /*! + * \brief Collect all global type vars defined in this module. + * \returns An array of global type vars + */ + tvm::Array GetGlobalTypeVars() const; + /*! * \brief Look up a global function by its variable. * \param var The global var to lookup. diff --git a/include/tvm/relay/pattern_functor.h b/include/tvm/relay/pattern_functor.h index 611b7431d414..7f1c47e03592 100644 --- a/include/tvm/relay/pattern_functor.h +++ b/include/tvm/relay/pattern_functor.h @@ -103,7 +103,8 @@ class PatternFunctor { virtual R VisitPattern_(const PatternTupleNode* op, Args... 
args) PATTERN_FUNCTOR_DEFAULT; virtual R VisitPatternDefault_(const Node* op, Args...) { - throw Error(std::string("Do not have a default for ") + op->type_key()); + LOG(FATAL) << "Do not have a default for " << op->type_key(); + throw; } private: diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/_parser.py index 8a969e95081e..325108893d06 100644 --- a/python/tvm/relay/_parser.py +++ b/python/tvm/relay/_parser.py @@ -231,7 +231,7 @@ def exit_type_param_scope(self) -> Scope[ty.TypeVar]: def mk_typ(self, name: str, kind: ty.Kind) -> ty.TypeVar: """Create a new TypeVar and add it to the TypeVar scope.""" typ = ty.TypeVar(name, kind) - self.type_var_scopes[0].appendleft((name, typ)) + self.type_var_scopes[0].append((name, typ)) return typ def mk_global_typ_var(self, name, kind): @@ -242,7 +242,7 @@ def mk_global_typ_var(self, name, kind): self.global_type_vars[name] = typ return typ - # TODO: rethink whether we should have type constructors mixed with type vars. + # TODO(weberlo): rethink whether we should have type constructors mixed with type vars. def mk_global_typ_cons(self, name, cons): self._check_existing_typ_expr(name, cons) self.global_type_vars[name] = cons @@ -291,11 +291,15 @@ def visitGeneralIdent(self, ctx): if name.startswith(type_prefix): return ty.scalar_type(name) # Next, look it up in the local then global type params. - type_param = lookup(self.type_var_scopes, name) - if type_param is None: - type_param = self.global_type_vars.get(name, None) - if type_param is not None: - return type_param + type_expr = lookup(self.type_var_scopes, name) + if type_expr is None: + type_expr = self.global_type_vars.get(name, None) + if type_expr is not None: + # Zero-arity constructor calls fall into the general ident case, so in that case, + # we construct a constructor call with no args. + if isinstance(type_expr, adt.Constructor) and not type_expr.inputs: + type_expr = expr.Call(type_expr, []) + return type_expr # Check if it's an operator. 
op_name = ".".join([name.getText() for name in ctx.CNAME()]) if op_name in FUNC_OPS: @@ -321,14 +325,12 @@ def visitGraphVar(self, ctx): def visit_list(self, ctx_list) -> List[Any]: """"Visit a list of contexts.""" - # type: RelayParser.ContextParserRuleContext assert isinstance(ctx_list, list) return [self.visit(ctx) for ctx in ctx_list] - def getTypeExpr(self, ctx) -> Optional[ty.Type]: + def getTypeExpr(self, ctx: Optional[RelayParser.TypeExprContext]) -> Optional[ty.Type]: """Return a (possibly None) Relay type.""" - # type: : Optional[RelayParser.Type_Context] if ctx is None: return None @@ -360,6 +362,10 @@ def visitOpIdent(self, ctx) -> op.Op: def visitParen(self, ctx: RelayParser.ParenContext) -> expr.Expr: return self.visit(ctx.expr()) + # pass through + def visitTypeParen(self, ctx: RelayParser.TypeParenContext) -> expr.Expr: + return self.visit(ctx.typeExpr()) + # pass through def visitBody(self, ctx: RelayParser.BodyContext) -> expr.Expr: return self.visit(ctx.expr()) @@ -466,7 +472,7 @@ def mk_func( type_params = ctx.typeParamList() if type_params is not None: - type_params = type_params.generalIdent() + type_params = type_params.typeExpr() assert type_params for ty_param in type_params: name = ty_param.getText() @@ -498,7 +504,8 @@ def visitFunc(self, ctx: RelayParser.FuncContext) -> expr.Function: def visitFuncDefn(self, ctx: RelayParser.DefnContext) -> None: ident_name = ctx.globalVar().getText()[1:] ident = self.mk_global_var(ident_name) - self.module[ident] = self.mk_func(ctx) + func = self.mk_func(ctx) + self.module[ident] = func def handle_adt_header( self, @@ -512,7 +519,7 @@ def handle_adt_header( type_params = [] else: type_params = [self.mk_typ(type_ident.getText(), ty.Kind.Type) - for type_ident in type_params.generalIdent()] + for type_ident in type_params.typeExpr()] return adt_var, type_params def visitExternAdtDefn(self, ctx: RelayParser.ExternAdtDefnContext): @@ -552,8 +559,6 @@ def visitMatch(self, ctx: RelayParser.MatchContext): 
else: raise RuntimeError(f"unknown match type {match_type}") - # TODO: Will need some kind of type checking to know which ADT is being - # matched on. match_data = self.visit(ctx.expr()) match_clauses = ctx.matchClauseList() if match_clauses is None: @@ -562,39 +567,36 @@ def visitMatch(self, ctx: RelayParser.MatchContext): match_clauses = match_clauses.matchClause() parsed_clauses = [] for clause in match_clauses: - constructor_name = clause.constructorName().getText() - constructor = self.global_type_vars[constructor_name] self.enter_var_scope() - patternList = clause.patternList() - if patternList is None: - patterns = [] - else: - patterns = [self.visit(pattern) for pattern in patternList.pattern()] + pattern = self.visit(clause.pattern()) clause_body = self.visit(clause.expr()) self.exit_var_scope() - # TODO: Do we need to pass `None` if it's a 0-arity cons, or is an empty list fine? - parsed_clauses.append(adt.Clause( - adt.PatternConstructor( - constructor, - patterns - ), - clause_body - )) + parsed_clauses.append(adt.Clause(pattern, clause_body)) return adt.Match(match_data, parsed_clauses, complete=complete_match) - def visitPattern(self, ctx: RelayParser.PatternContext): - text = ctx.getText() - if text == "_": - return adt.PatternWildcard() - elif text.startswith("%"): - text = ctx.localVar().getText() - typ = ctx.typeExpr() - if typ is not None: - typ = self.visit(typ) - var = self.mk_var(text[1:], typ=typ) - return adt.PatternVar(var) + def visitWildcardPattern(self, ctx: RelayParser.WildcardPatternContext): + return adt.PatternWildcard() + + def visitVarPattern(self, ctx: RelayParser.VarPatternContext): + text = ctx.localVar().getText() + typ = ctx.typeExpr() + if typ is not None: + typ = self.visit(typ) + var = self.mk_var(text[1:], typ=typ) + return adt.PatternVar(var) + + def visitConstructorPattern(self, ctx: RelayParser.ConstructorPatternContext): + constructor_name = ctx.constructorName().getText() + constructor = 
self.global_type_vars[constructor_name] + pattern_list = ctx.patternList() + if pattern_list is None: + patterns = [] else: - raise ParseError(f"invalid pattern syntax \"{text}\"") + patterns = [self.visit(pattern) for pattern in pattern_list.pattern()] + return adt.PatternConstructor(constructor, patterns) + + def visitTuplePattern(self, ctx: RelayParser.TuplePatternContext): + return adt.PatternTuple([self.visit(pattern) for pattern in ctx.patternList().pattern()]) def visitCallNoAttr(self, ctx: RelayParser.CallNoAttrContext): return (self.visit_list(ctx.exprList().expr()), None) @@ -610,16 +612,14 @@ def call(self, func, args, attrs, type_args): return expr.Call(func, args, attrs, type_args) @spanify - def visitCall(self, ctx: RelayParser.CallContext): - # type: (RelayParser.CallContext) -> expr.Call + def visitCall(self, ctx: RelayParser.CallContext) -> expr.Call: func = self.visit(ctx.expr()) args, attrs = self.visit(ctx.callList()) res = self.call(func, args, attrs, []) return res @spanify - def visitIfElse(self, ctx: RelayParser.IfElseContext): - # type: (RelayParser.IfElseContext) -> expr.If + def visitIfElse(self, ctx: RelayParser.IfElseContext) -> expr.If: """Construct a Relay If node. 
Creates a new scope for each branch.""" cond = self.visit(ctx.expr()) @@ -634,8 +634,7 @@ def visitIfElse(self, ctx: RelayParser.IfElseContext): return expr.If(cond, true_branch, false_branch) @spanify - def visitGraph(self, ctx: RelayParser.GraphContext): - # type: (RelayParser.GraphContext) -> expr.Expr + def visitGraph(self, ctx: RelayParser.GraphContext) -> expr.Expr: """Visit a graph variable assignment.""" graph_nid = int(ctx.graphVar().getText()[1:]) @@ -655,28 +654,24 @@ def visitGraph(self, ctx: RelayParser.GraphContext): # Types # pylint: disable=unused-argument - def visitIncompleteType(self, ctx: RelayParser.IncompleteTypeContext): - # type (RelayParser.IncompleteTypeContext) -> None: + def visitIncompleteType(self, ctx: RelayParser.IncompleteTypeContext) -> None: return None def visitTypeCallType(self, ctx: RelayParser.TypeCallTypeContext): func = self.visit(ctx.generalIdent()) - args = [self.visit(arg) for arg in ctx.typeParamList().generalIdent()] + args = [self.visit(arg) for arg in ctx.typeParamList().typeExpr()] return ty.TypeCall(func, args) - def visitParensShape(self, ctx: RelayParser.ParensShapeContext): - # type: (RelayParser.ParensShapeContext) -> int + def visitParensShape(self, ctx: RelayParser.ParensShapeContext) -> int: return self.visit(ctx.shape()) - def visitShapeList(self, ctx: RelayParser.ShapeListContext): - # type: (RelayParser.ShapeListContext) -> List[int] + def visitShapeList(self, ctx: RelayParser.ShapeListContext) -> List[int]: return self.visit_list(ctx.shape()) def visitTensor(self, ctx: RelayParser.TensorContext): return tuple(self.visit_list(ctx.expr())) - def visitTensorType(self, ctx: RelayParser.TensorTypeContext): - # type: (RelayParser.TensorTypeContext) -> ty.TensorType + def visitTensorType(self, ctx: RelayParser.TensorTypeContext) -> ty.TensorType: """Create a simple tensor type. 
No generics.""" shape = self.visit(ctx.shapeList()) @@ -689,12 +684,10 @@ def visitTensorType(self, ctx: RelayParser.TensorTypeContext): return ty.TensorType(shape, dtype) - def visitTupleType(self, ctx: RelayParser.TupleTypeContext): - # type: (RelayParser.TupleTypeContext) -> ty.TupleType + def visitTupleType(self, ctx: RelayParser.TupleTypeContext) -> ty.TupleType: return ty.TupleType(self.visit_list(ctx.typeExpr())) - def visitFuncType(self, ctx: RelayParser.FuncTypeContext): - # type: (RelayParser.FuncTypeContext) -> ty.FuncType + def visitFuncType(self, ctx: RelayParser.FuncTypeContext) -> ty.FuncType: types = self.visit_list(ctx.typeExpr()) arg_types = types[:-1] @@ -702,8 +695,7 @@ def visitFuncType(self, ctx: RelayParser.FuncTypeContext): return ty.FuncType(arg_types, ret_type, [], None) -def make_parser(data): - # type: (str) -> RelayParser +def make_parser(data: str) -> RelayParser: """Construct a RelayParser a given data stream.""" input_stream = InputStream(data) lexer = RelayLexer(input_stream) @@ -738,8 +730,7 @@ def reportAttemptingFullContext(self, def reportContextSensitivity(self, recognizer, dfa, startIndex, stopIndex, prediction, configs): raise Exception("Context Sensitivity in:\n" + self.text) -def fromtext(data, source_name=None): - # type: (str, str) -> Union[expr.Expr, module.Module] +def fromtext(data: str, source_name: str = None) -> Union[expr.Expr, module.Module]: """Parse a Relay program.""" if data == "": raise ParseError("cannot parse the empty string.") diff --git a/python/tvm/relay/grammar/Relay.g4 b/python/tvm/relay/grammar/Relay.g4 index b3269fb113fd..bfcd18ffc98f 100644 --- a/python/tvm/relay/grammar/Relay.g4 +++ b/python/tvm/relay/grammar/Relay.g4 @@ -87,33 +87,33 @@ callList expr // operators - : '(' expr ')' # paren + : '(' expr ')' # paren // function application - | expr '(' callList ')' # call - | '-' expr # neg - | expr op=('*'|'/') expr # binOp - | expr op=('+'|'-') expr # binOp - | expr op=('<'|'>'|'<='|'>=') expr # 
binOp - | expr op=('=='|'!=') expr # binOp + | expr '(' callList ')' # call + | '-' expr # neg + | expr op=('*'|'/') expr # binOp + | expr op=('+'|'-') expr # binOp + | expr op=('<'|'>'|'<='|'>=') expr # binOp + | expr op=('=='|'!=') expr # binOp // function definition - | func # funcExpr + | func # funcExpr // tuples and tensors - | '(' ')' # tuple - | '(' expr ',' ')' # tuple - | '(' expr (',' expr)+ ')' # tuple - | '[' (expr (',' expr)*)? ']' # tensor - | 'if' '(' expr ')' body 'else' body # ifElse - | matchType '(' expr ')' '{' matchClauseList? '}' # match - | expr '.' NAT # projection + | '(' ')' # tuple + | '(' expr ',' ')' # tuple + | '(' expr (',' expr)+ ')' # tuple + | '[' (expr (',' expr)*)? ']' # tensor + | 'if' '(' expr ')' body 'else' body # ifElse + | matchType expr '{' matchClauseList? '}' # match + | expr '.' NAT # projection // sequencing - | 'let' var '=' expr ';' expr # let + | 'let' var '=' expr ';' expr # let // sugar for let %_ = expr; expr - | expr ';;' expr # let - | graphVar '=' expr ';' expr # graph - | ident # identExpr - | scalar # scalarExpr - | meta # metaExpr - | QUOTED_STRING # stringExpr + | expr ';;' expr # let + | graphVar '=' expr ';' expr # graph + | ident # identExpr + | scalar # scalarExpr + | meta # metaExpr + | QUOTED_STRING # stringExpr ; func: 'fn' typeParamList? '(' argList ')' ('->' typeExpr)? body ; @@ -128,14 +128,16 @@ constructorName: CNAME ; adtConsDefnList: adtConsDefn (',' adtConsDefn)* ','? ; adtConsDefn: constructorName ('(' typeExpr (',' typeExpr)* ')')? ; matchClauseList: matchClause (',' matchClause)* ','? ; -matchClause: constructorName patternList? '=>' ('{' expr '}' | expr) ; +matchClause: pattern '=>' ('{' expr '}' | expr) ; // complete or incomplete match, respectively matchType : 'match' | 'match?' ; patternList: '(' pattern (',' pattern)* ')'; pattern - : '_' - | localVar (':' typeExpr)? + : '_' # wildcardPattern + | localVar (':' typeExpr)? # varPattern + | constructorName patternList? 
# constructorPattern + | patternList # tuplePattern ; adtCons: constructorName adtConsParamList? ; @@ -155,6 +157,7 @@ attr: CNAME '=' expr ; typeExpr : '(' ')' # tupleType + | '(' typeExpr ')' # typeParen | '(' typeExpr ',' ')' # tupleType | '(' typeExpr (',' typeExpr)+ ')' # tupleType | generalIdent typeParamList # typeCallType @@ -164,7 +167,7 @@ typeExpr | '_' # incompleteType ; -typeParamList: '[' generalIdent (',' generalIdent)* ']' ; +typeParamList: '[' typeExpr (',' typeExpr)* ']' ; shapeList : '(' ')' diff --git a/python/tvm/relay/grammar/py3/RelayParser.py b/python/tvm/relay/grammar/py3/RelayParser.py index 91fde4c04114..f24eed4be92f 100644 --- a/python/tvm/relay/grammar/py3/RelayParser.py +++ b/python/tvm/relay/grammar/py3/RelayParser.py @@ -9,7 +9,7 @@ def serializedATN(): with StringIO() as buf: buf.write("\3\u608b\ua72a\u8133\ub9ed\u417c\u3be7\u7786\u5964\3\62") - buf.write("\u01fc\4\2\t\2\4\3\t\3\4\4\t\4\4\5\t\5\4\6\t\6\4\7\t\7") + buf.write("\u0200\4\2\t\2\4\3\t\3\4\4\t\4\4\5\t\5\4\6\t\6\4\7\t\7") buf.write("\4\b\t\b\4\t\t\t\4\n\t\n\4\13\t\13\4\f\t\f\4\r\t\r\4\16") buf.write("\t\16\4\17\t\17\4\20\t\20\4\21\t\21\4\22\t\22\4\23\t\23") buf.write("\4\24\t\24\4\25\t\25\4\26\t\26\4\27\t\27\4\30\t\30\4\31") @@ -23,238 +23,241 @@ def serializedATN(): buf.write("\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\6\t\u0090\n\t\r\t\16\t") buf.write("\u0091\3\t\3\t\3\t\3\t\3\t\3\t\7\t\u009a\n\t\f\t\16\t") buf.write("\u009d\13\t\5\t\u009f\n\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t") - buf.write("\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\5\t\u00b0\n\t\3\t\3\t") + buf.write("\3\t\3\t\3\t\3\t\3\t\3\t\5\t\u00ae\n\t\3\t\3\t\3\t\3\t") buf.write("\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3") - buf.write("\t\3\t\3\t\3\t\5\t\u00c5\n\t\3\t\3\t\3\t\3\t\3\t\3\t\3") + buf.write("\t\3\t\5\t\u00c3\n\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3") buf.write("\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t\3\t") - buf.write("\3\t\3\t\3\t\7\t\u00de\n\t\f\t\16\t\u00e1\13\t\3\n\3\n") - 
buf.write("\5\n\u00e5\n\n\3\n\3\n\3\n\3\n\3\n\5\n\u00ec\n\n\3\n\3") - buf.write("\n\3\13\3\13\3\13\5\13\u00f3\n\13\3\13\3\13\3\13\3\13") - buf.write("\3\13\5\13\u00fa\n\13\3\13\3\13\3\13\3\13\3\13\3\13\5") - buf.write("\13\u0102\n\13\3\13\3\13\3\13\5\13\u0107\n\13\3\13\3\13") - buf.write("\5\13\u010b\n\13\3\13\3\13\5\13\u010f\n\13\3\f\3\f\3\r") - buf.write("\3\r\3\r\7\r\u0116\n\r\f\r\16\r\u0119\13\r\3\r\5\r\u011c") - buf.write("\n\r\3\16\3\16\3\16\3\16\3\16\7\16\u0123\n\16\f\16\16") - buf.write("\16\u0126\13\16\3\16\3\16\5\16\u012a\n\16\3\17\3\17\3") - buf.write("\17\7\17\u012f\n\17\f\17\16\17\u0132\13\17\3\17\5\17\u0135") - buf.write("\n\17\3\20\3\20\5\20\u0139\n\20\3\20\3\20\3\20\3\20\3") - buf.write("\20\3\20\5\20\u0141\n\20\3\21\3\21\3\22\3\22\3\22\3\22") - buf.write("\7\22\u0149\n\22\f\22\16\22\u014c\13\22\3\22\3\22\3\23") - buf.write("\3\23\3\23\3\23\5\23\u0154\n\23\5\23\u0156\n\23\3\24\3") - buf.write("\24\5\24\u015a\n\24\3\25\3\25\3\25\3\25\7\25\u0160\n\25") - buf.write("\f\25\16\25\u0163\13\25\3\25\3\25\3\26\3\26\5\26\u0169") - buf.write("\n\26\3\27\3\27\3\27\3\27\7\27\u016f\n\27\f\27\16\27\u0172") + buf.write("\3\t\7\t\u00dc\n\t\f\t\16\t\u00df\13\t\3\n\3\n\5\n\u00e3") + buf.write("\n\n\3\n\3\n\3\n\3\n\3\n\5\n\u00ea\n\n\3\n\3\n\3\13\3") + buf.write("\13\3\13\5\13\u00f1\n\13\3\13\3\13\3\13\3\13\3\13\5\13") + buf.write("\u00f8\n\13\3\13\3\13\3\13\3\13\3\13\3\13\5\13\u0100\n") + buf.write("\13\3\13\3\13\3\13\5\13\u0105\n\13\3\13\3\13\5\13\u0109") + buf.write("\n\13\3\13\3\13\5\13\u010d\n\13\3\f\3\f\3\r\3\r\3\r\7") + buf.write("\r\u0114\n\r\f\r\16\r\u0117\13\r\3\r\5\r\u011a\n\r\3\16") + buf.write("\3\16\3\16\3\16\3\16\7\16\u0121\n\16\f\16\16\16\u0124") + buf.write("\13\16\3\16\3\16\5\16\u0128\n\16\3\17\3\17\3\17\7\17\u012d") + buf.write("\n\17\f\17\16\17\u0130\13\17\3\17\5\17\u0133\n\17\3\20") + buf.write("\3\20\3\20\3\20\3\20\3\20\3\20\5\20\u013c\n\20\3\21\3") + buf.write("\21\3\22\3\22\3\22\3\22\7\22\u0144\n\22\f\22\16\22\u0147") + 
buf.write("\13\22\3\22\3\22\3\23\3\23\3\23\3\23\5\23\u014f\n\23\3") + buf.write("\23\3\23\5\23\u0153\n\23\3\23\5\23\u0156\n\23\3\24\3\24") + buf.write("\5\24\u015a\n\24\3\25\3\25\3\25\3\25\7\25\u0160\n\25\f") + buf.write("\25\16\25\u0163\13\25\3\25\3\25\3\26\3\26\5\26\u0169\n") + buf.write("\26\3\27\3\27\3\27\3\27\7\27\u016f\n\27\f\27\16\27\u0172") buf.write("\13\27\3\27\5\27\u0175\n\27\3\30\3\30\3\30\7\30\u017a") buf.write("\n\30\f\30\16\30\u017d\13\30\5\30\u017f\n\30\3\31\3\31") buf.write("\3\31\5\31\u0184\n\31\3\32\3\32\3\32\7\32\u0189\n\32\f") buf.write("\32\16\32\u018c\13\32\3\33\3\33\3\33\3\33\3\34\3\34\3") - buf.write("\34\3\34\3\34\3\34\3\34\3\34\3\34\3\34\3\34\6\34\u019d") - buf.write("\n\34\r\34\16\34\u019e\3\34\3\34\3\34\3\34\3\34\3\34\3") - buf.write("\34\3\34\3\34\3\34\3\34\3\34\3\34\3\34\3\34\5\34\u01b0") - buf.write("\n\34\3\34\3\34\3\34\3\34\7\34\u01b6\n\34\f\34\16\34\u01b9") - buf.write("\13\34\5\34\u01bb\n\34\3\34\3\34\3\34\3\34\5\34\u01c1") - buf.write("\n\34\3\35\3\35\3\35\3\35\7\35\u01c7\n\35\f\35\16\35\u01ca") - buf.write("\13\35\3\35\3\35\3\36\3\36\3\36\3\36\3\36\3\36\6\36\u01d4") - buf.write("\n\36\r\36\16\36\u01d5\3\36\3\36\3\36\5\36\u01db\n\36") - buf.write("\3\37\3\37\3\37\3\37\3\37\3\37\3\37\3\37\3 \3 \3 \3 \3") - buf.write(" \3 \5 \u01eb\n \3!\3!\3!\3!\3\"\3\"\3\"\5\"\u01f4\n\"") - buf.write("\3#\3#\3#\3#\5#\u01fa\n#\3#\2\3\20$\2\4\6\b\n\f\16\20") - buf.write("\22\24\26\30\32\34\36 \"$&(*,.\60\62\64\668:<>@BD\2\b") - buf.write("\4\2\6\6//\3\2$%\3\2&\'\3\2(+\3\2,-\3\2\32\33\2\u022d") - buf.write("\2F\3\2\2\2\4U\3\2\2\2\6]\3\2\2\2\b`\3\2\2\2\nc\3\2\2") - buf.write("\2\fn\3\2\2\2\16z\3\2\2\2\20\u00c4\3\2\2\2\22\u00e2\3") - buf.write("\2\2\2\24\u010e\3\2\2\2\26\u0110\3\2\2\2\30\u0112\3\2") - buf.write("\2\2\32\u011d\3\2\2\2\34\u012b\3\2\2\2\36\u0136\3\2\2") - buf.write("\2 \u0142\3\2\2\2\"\u0144\3\2\2\2$\u0155\3\2\2\2&\u0157") - buf.write("\3\2\2\2(\u015b\3\2\2\2*\u0168\3\2\2\2,\u0174\3\2\2\2") - 
buf.write(".\u017e\3\2\2\2\60\u0180\3\2\2\2\62\u0185\3\2\2\2\64\u018d") - buf.write("\3\2\2\2\66\u01c0\3\2\2\28\u01c2\3\2\2\2:\u01da\3\2\2") - buf.write("\2<\u01dc\3\2\2\2>\u01ea\3\2\2\2@\u01ec\3\2\2\2B\u01f3") - buf.write("\3\2\2\2D\u01f9\3\2\2\2FN\7\37\2\2GI\5\24\13\2HG\3\2\2") - buf.write("\2IL\3\2\2\2JH\3\2\2\2JK\3\2\2\2KO\3\2\2\2LJ\3\2\2\2M") - buf.write("O\5\20\t\2NJ\3\2\2\2NM\3\2\2\2OQ\3\2\2\2PR\7\62\2\2QP") - buf.write("\3\2\2\2QR\3\2\2\2RS\3\2\2\2ST\7\2\2\3T\3\3\2\2\2UZ\7") - buf.write("/\2\2VW\7\3\2\2WY\7/\2\2XV\3\2\2\2Y\\\3\2\2\2ZX\3\2\2") - buf.write("\2Z[\3\2\2\2[\5\3\2\2\2\\Z\3\2\2\2]^\7\4\2\2^_\7/\2\2") - buf.write("_\7\3\2\2\2`a\7\5\2\2ab\t\2\2\2b\t\3\2\2\2cd\7\5\2\2d") - buf.write("e\7\61\2\2e\13\3\2\2\2fk\5\20\t\2gh\7\7\2\2hj\5\20\t\2") - buf.write("ig\3\2\2\2jm\3\2\2\2ki\3\2\2\2kl\3\2\2\2lo\3\2\2\2mk\3") - buf.write("\2\2\2nf\3\2\2\2no\3\2\2\2o\r\3\2\2\2p{\5\f\7\2qr\5\20") - buf.write("\t\2rs\7\7\2\2su\3\2\2\2tq\3\2\2\2ux\3\2\2\2vt\3\2\2\2") - buf.write("vw\3\2\2\2wy\3\2\2\2xv\3\2\2\2y{\5\62\32\2zp\3\2\2\2z") - buf.write("v\3\2\2\2{\17\3\2\2\2|}\b\t\1\2}~\7\b\2\2~\177\5\20\t") - buf.write("\2\177\u0080\7\t\2\2\u0080\u00c5\3\2\2\2\u0081\u0082\7") - buf.write("\'\2\2\u0082\u00c5\5\20\t\26\u0083\u00c5\5\22\n\2\u0084") - buf.write("\u0085\7\b\2\2\u0085\u00c5\7\t\2\2\u0086\u0087\7\b\2\2") - buf.write("\u0087\u0088\5\20\t\2\u0088\u0089\7\7\2\2\u0089\u008a") - buf.write("\7\t\2\2\u008a\u00c5\3\2\2\2\u008b\u008c\7\b\2\2\u008c") - buf.write("\u008f\5\20\t\2\u008d\u008e\7\7\2\2\u008e\u0090\5\20\t") - buf.write("\2\u008f\u008d\3\2\2\2\u0090\u0091\3\2\2\2\u0091\u008f") - buf.write("\3\2\2\2\u0091\u0092\3\2\2\2\u0092\u0093\3\2\2\2\u0093") - buf.write("\u0094\7\t\2\2\u0094\u00c5\3\2\2\2\u0095\u009e\7\n\2\2") - buf.write("\u0096\u009b\5\20\t\2\u0097\u0098\7\7\2\2\u0098\u009a") - buf.write("\5\20\t\2\u0099\u0097\3\2\2\2\u009a\u009d\3\2\2\2\u009b") - buf.write("\u0099\3\2\2\2\u009b\u009c\3\2\2\2\u009c\u009f\3\2\2\2") - 
buf.write("\u009d\u009b\3\2\2\2\u009e\u0096\3\2\2\2\u009e\u009f\3") - buf.write("\2\2\2\u009f\u00a0\3\2\2\2\u00a0\u00c5\7\13\2\2\u00a1") - buf.write("\u00a2\7\f\2\2\u00a2\u00a3\7\b\2\2\u00a3\u00a4\5\20\t") - buf.write("\2\u00a4\u00a5\7\t\2\2\u00a5\u00a6\5@!\2\u00a6\u00a7\7") - buf.write("\r\2\2\u00a7\u00a8\5@!\2\u00a8\u00c5\3\2\2\2\u00a9\u00aa") - buf.write("\5 \21\2\u00aa\u00ab\7\b\2\2\u00ab\u00ac\5\20\t\2\u00ac") - buf.write("\u00ad\7\t\2\2\u00ad\u00af\7\16\2\2\u00ae\u00b0\5\34\17") - buf.write("\2\u00af\u00ae\3\2\2\2\u00af\u00b0\3\2\2\2\u00b0\u00b1") - buf.write("\3\2\2\2\u00b1\u00b2\7\17\2\2\u00b2\u00c5\3\2\2\2\u00b3") - buf.write("\u00b4\7\20\2\2\u00b4\u00b5\5\60\31\2\u00b5\u00b6\7\21") - buf.write("\2\2\u00b6\u00b7\5\20\t\2\u00b7\u00b8\7\22\2\2\u00b8\u00b9") - buf.write("\5\20\t\t\u00b9\u00c5\3\2\2\2\u00ba\u00bb\5\n\6\2\u00bb") - buf.write("\u00bc\7\21\2\2\u00bc\u00bd\5\20\t\2\u00bd\u00be\7\22") - buf.write("\2\2\u00be\u00bf\5\20\t\7\u00bf\u00c5\3\2\2\2\u00c0\u00c5") - buf.write("\5D#\2\u00c1\u00c5\5B\"\2\u00c2\u00c5\5<\37\2\u00c3\u00c5") - buf.write("\7#\2\2\u00c4|\3\2\2\2\u00c4\u0081\3\2\2\2\u00c4\u0083") - buf.write("\3\2\2\2\u00c4\u0084\3\2\2\2\u00c4\u0086\3\2\2\2\u00c4") - buf.write("\u008b\3\2\2\2\u00c4\u0095\3\2\2\2\u00c4\u00a1\3\2\2\2") - buf.write("\u00c4\u00a9\3\2\2\2\u00c4\u00b3\3\2\2\2\u00c4\u00ba\3") - buf.write("\2\2\2\u00c4\u00c0\3\2\2\2\u00c4\u00c1\3\2\2\2\u00c4\u00c2") - buf.write("\3\2\2\2\u00c4\u00c3\3\2\2\2\u00c5\u00df\3\2\2\2\u00c6") - buf.write("\u00c7\f\25\2\2\u00c7\u00c8\t\3\2\2\u00c8\u00de\5\20\t") - buf.write("\26\u00c9\u00ca\f\24\2\2\u00ca\u00cb\t\4\2\2\u00cb\u00de") - buf.write("\5\20\t\25\u00cc\u00cd\f\23\2\2\u00cd\u00ce\t\5\2\2\u00ce") - buf.write("\u00de\5\20\t\24\u00cf\u00d0\f\22\2\2\u00d0\u00d1\t\6") - buf.write("\2\2\u00d1\u00de\5\20\t\23\u00d2\u00d3\f\b\2\2\u00d3\u00d4") - buf.write("\7\23\2\2\u00d4\u00de\5\20\t\t\u00d5\u00d6\f\27\2\2\u00d6") - buf.write("\u00d7\7\b\2\2\u00d7\u00d8\5\16\b\2\u00d8\u00d9\7\t\2") - 
buf.write("\2\u00d9\u00de\3\2\2\2\u00da\u00db\f\n\2\2\u00db\u00dc") - buf.write("\7\3\2\2\u00dc\u00de\7\61\2\2\u00dd\u00c6\3\2\2\2\u00dd") - buf.write("\u00c9\3\2\2\2\u00dd\u00cc\3\2\2\2\u00dd\u00cf\3\2\2\2") - buf.write("\u00dd\u00d2\3\2\2\2\u00dd\u00d5\3\2\2\2\u00dd\u00da\3") - buf.write("\2\2\2\u00de\u00e1\3\2\2\2\u00df\u00dd\3\2\2\2\u00df\u00e0") - buf.write("\3\2\2\2\u00e0\21\3\2\2\2\u00e1\u00df\3\2\2\2\u00e2\u00e4") - buf.write("\7\24\2\2\u00e3\u00e5\58\35\2\u00e4\u00e3\3\2\2\2\u00e4") - buf.write("\u00e5\3\2\2\2\u00e5\u00e6\3\2\2\2\u00e6\u00e7\7\b\2\2") - buf.write("\u00e7\u00e8\5,\27\2\u00e8\u00eb\7\t\2\2\u00e9\u00ea\7") - buf.write("\25\2\2\u00ea\u00ec\5\66\34\2\u00eb\u00e9\3\2\2\2\u00eb") - buf.write("\u00ec\3\2\2\2\u00ec\u00ed\3\2\2\2\u00ed\u00ee\5@!\2\u00ee") - buf.write("\23\3\2\2\2\u00ef\u00f0\7\26\2\2\u00f0\u00f2\5\6\4\2\u00f1") - buf.write("\u00f3\58\35\2\u00f2\u00f1\3\2\2\2\u00f2\u00f3\3\2\2\2") - buf.write("\u00f3\u00f4\3\2\2\2\u00f4\u00f5\7\b\2\2\u00f5\u00f6\5") - buf.write(",\27\2\u00f6\u00f9\7\t\2\2\u00f7\u00f8\7\25\2\2\u00f8") - buf.write("\u00fa\5\66\34\2\u00f9\u00f7\3\2\2\2\u00f9\u00fa\3\2\2") - buf.write("\2\u00fa\u00fb\3\2\2\2\u00fb\u00fc\5@!\2\u00fc\u010f\3") - buf.write("\2\2\2\u00fd\u00fe\7\27\2\2\u00fe\u00ff\7\30\2\2\u00ff") - buf.write("\u0101\5\4\3\2\u0100\u0102\58\35\2\u0101\u0100\3\2\2\2") - buf.write("\u0101\u0102\3\2\2\2\u0102\u010f\3\2\2\2\u0103\u0104\7") - buf.write("\30\2\2\u0104\u0106\5\4\3\2\u0105\u0107\58\35\2\u0106") - buf.write("\u0105\3\2\2\2\u0106\u0107\3\2\2\2\u0107\u0108\3\2\2\2") - buf.write("\u0108\u010a\7\16\2\2\u0109\u010b\5\30\r\2\u010a\u0109") - buf.write("\3\2\2\2\u010a\u010b\3\2\2\2\u010b\u010c\3\2\2\2\u010c") - buf.write("\u010d\7\17\2\2\u010d\u010f\3\2\2\2\u010e\u00ef\3\2\2") - buf.write("\2\u010e\u00fd\3\2\2\2\u010e\u0103\3\2\2\2\u010f\25\3") - buf.write("\2\2\2\u0110\u0111\7/\2\2\u0111\27\3\2\2\2\u0112\u0117") - buf.write("\5\32\16\2\u0113\u0114\7\7\2\2\u0114\u0116\5\32\16\2\u0115") - 
buf.write("\u0113\3\2\2\2\u0116\u0119\3\2\2\2\u0117\u0115\3\2\2\2") - buf.write("\u0117\u0118\3\2\2\2\u0118\u011b\3\2\2\2\u0119\u0117\3") - buf.write("\2\2\2\u011a\u011c\7\7\2\2\u011b\u011a\3\2\2\2\u011b\u011c") - buf.write("\3\2\2\2\u011c\31\3\2\2\2\u011d\u0129\5\26\f\2\u011e\u011f") - buf.write("\7\b\2\2\u011f\u0124\5\66\34\2\u0120\u0121\7\7\2\2\u0121") - buf.write("\u0123\5\66\34\2\u0122\u0120\3\2\2\2\u0123\u0126\3\2\2") - buf.write("\2\u0124\u0122\3\2\2\2\u0124\u0125\3\2\2\2\u0125\u0127") - buf.write("\3\2\2\2\u0126\u0124\3\2\2\2\u0127\u0128\7\t\2\2\u0128") - buf.write("\u012a\3\2\2\2\u0129\u011e\3\2\2\2\u0129\u012a\3\2\2\2") - buf.write("\u012a\33\3\2\2\2\u012b\u0130\5\36\20\2\u012c\u012d\7") - buf.write("\7\2\2\u012d\u012f\5\36\20\2\u012e\u012c\3\2\2\2\u012f") - buf.write("\u0132\3\2\2\2\u0130\u012e\3\2\2\2\u0130\u0131\3\2\2\2") - buf.write("\u0131\u0134\3\2\2\2\u0132\u0130\3\2\2\2\u0133\u0135\7") - buf.write("\7\2\2\u0134\u0133\3\2\2\2\u0134\u0135\3\2\2\2\u0135\35") - buf.write("\3\2\2\2\u0136\u0138\5\26\f\2\u0137\u0139\5\"\22\2\u0138") - buf.write("\u0137\3\2\2\2\u0138\u0139\3\2\2\2\u0139\u013a\3\2\2\2") - buf.write("\u013a\u0140\7\31\2\2\u013b\u013c\7\16\2\2\u013c\u013d") - buf.write("\5\20\t\2\u013d\u013e\7\17\2\2\u013e\u0141\3\2\2\2\u013f") - buf.write("\u0141\5\20\t\2\u0140\u013b\3\2\2\2\u0140\u013f\3\2\2") - buf.write("\2\u0141\37\3\2\2\2\u0142\u0143\t\7\2\2\u0143!\3\2\2\2") - buf.write("\u0144\u0145\7\b\2\2\u0145\u014a\5$\23\2\u0146\u0147\7") - buf.write("\7\2\2\u0147\u0149\5$\23\2\u0148\u0146\3\2\2\2\u0149\u014c") - buf.write("\3\2\2\2\u014a\u0148\3\2\2\2\u014a\u014b\3\2\2\2\u014b") - buf.write("\u014d\3\2\2\2\u014c\u014a\3\2\2\2\u014d\u014e\7\t\2\2") - buf.write("\u014e#\3\2\2\2\u014f\u0156\7\6\2\2\u0150\u0153\5\b\5") - buf.write("\2\u0151\u0152\7\34\2\2\u0152\u0154\5\66\34\2\u0153\u0151") - buf.write("\3\2\2\2\u0153\u0154\3\2\2\2\u0154\u0156\3\2\2\2\u0155") - buf.write("\u014f\3\2\2\2\u0155\u0150\3\2\2\2\u0156%\3\2\2\2\u0157") - 
buf.write("\u0159\5\26\f\2\u0158\u015a\5(\25\2\u0159\u0158\3\2\2") - buf.write("\2\u0159\u015a\3\2\2\2\u015a\'\3\2\2\2\u015b\u015c\7\b") - buf.write("\2\2\u015c\u0161\5*\26\2\u015d\u015e\7\7\2\2\u015e\u0160") - buf.write("\5*\26\2\u015f\u015d\3\2\2\2\u0160\u0163\3\2\2\2\u0161") - buf.write("\u015f\3\2\2\2\u0161\u0162\3\2\2\2\u0162\u0164\3\2\2\2") - buf.write("\u0163\u0161\3\2\2\2\u0164\u0165\7\t\2\2\u0165)\3\2\2") - buf.write("\2\u0166\u0169\5\b\5\2\u0167\u0169\5\26\f\2\u0168\u0166") - buf.write("\3\2\2\2\u0168\u0167\3\2\2\2\u0169+\3\2\2\2\u016a\u0175") - buf.write("\5.\30\2\u016b\u016c\5\60\31\2\u016c\u016d\7\7\2\2\u016d") - buf.write("\u016f\3\2\2\2\u016e\u016b\3\2\2\2\u016f\u0172\3\2\2\2") - buf.write("\u0170\u016e\3\2\2\2\u0170\u0171\3\2\2\2\u0171\u0173\3") - buf.write("\2\2\2\u0172\u0170\3\2\2\2\u0173\u0175\5\62\32\2\u0174") - buf.write("\u016a\3\2\2\2\u0174\u0170\3\2\2\2\u0175-\3\2\2\2\u0176") - buf.write("\u017b\5\60\31\2\u0177\u0178\7\7\2\2\u0178\u017a\5\60") - buf.write("\31\2\u0179\u0177\3\2\2\2\u017a\u017d\3\2\2\2\u017b\u0179") - buf.write("\3\2\2\2\u017b\u017c\3\2\2\2\u017c\u017f\3\2\2\2\u017d") - buf.write("\u017b\3\2\2\2\u017e\u0176\3\2\2\2\u017e\u017f\3\2\2\2") - buf.write("\u017f/\3\2\2\2\u0180\u0183\5\b\5\2\u0181\u0182\7\34\2") - buf.write("\2\u0182\u0184\5\66\34\2\u0183\u0181\3\2\2\2\u0183\u0184") - buf.write("\3\2\2\2\u0184\61\3\2\2\2\u0185\u018a\5\64\33\2\u0186") - buf.write("\u0187\7\7\2\2\u0187\u0189\5\64\33\2\u0188\u0186\3\2\2") - buf.write("\2\u0189\u018c\3\2\2\2\u018a\u0188\3\2\2\2\u018a\u018b") - buf.write("\3\2\2\2\u018b\63\3\2\2\2\u018c\u018a\3\2\2\2\u018d\u018e") - buf.write("\7/\2\2\u018e\u018f\7\21\2\2\u018f\u0190\5\20\t\2\u0190") - buf.write("\65\3\2\2\2\u0191\u0192\7\b\2\2\u0192\u01c1\7\t\2\2\u0193") - buf.write("\u0194\7\b\2\2\u0194\u0195\5\66\34\2\u0195\u0196\7\7\2") - buf.write("\2\u0196\u0197\7\t\2\2\u0197\u01c1\3\2\2\2\u0198\u0199") - buf.write("\7\b\2\2\u0199\u019c\5\66\34\2\u019a\u019b\7\7\2\2\u019b") - 
buf.write("\u019d\5\66\34\2\u019c\u019a\3\2\2\2\u019d\u019e\3\2\2") - buf.write("\2\u019e\u019c\3\2\2\2\u019e\u019f\3\2\2\2\u019f\u01a0") - buf.write("\3\2\2\2\u01a0\u01a1\7\t\2\2\u01a1\u01c1\3\2\2\2\u01a2") - buf.write("\u01a3\5\4\3\2\u01a3\u01a4\58\35\2\u01a4\u01c1\3\2\2\2") - buf.write("\u01a5\u01c1\5\4\3\2\u01a6\u01a7\7\35\2\2\u01a7\u01a8") - buf.write("\7\n\2\2\u01a8\u01a9\5:\36\2\u01a9\u01aa\7\7\2\2\u01aa") - buf.write("\u01ab\5\66\34\2\u01ab\u01ac\7\13\2\2\u01ac\u01c1\3\2") - buf.write("\2\2\u01ad\u01af\7\24\2\2\u01ae\u01b0\58\35\2\u01af\u01ae") - buf.write("\3\2\2\2\u01af\u01b0\3\2\2\2\u01b0\u01b1\3\2\2\2\u01b1") - buf.write("\u01ba\7\b\2\2\u01b2\u01b7\5\66\34\2\u01b3\u01b4\7\7\2") - buf.write("\2\u01b4\u01b6\5\66\34\2\u01b5\u01b3\3\2\2\2\u01b6\u01b9") - buf.write("\3\2\2\2\u01b7\u01b5\3\2\2\2\u01b7\u01b8\3\2\2\2\u01b8") - buf.write("\u01bb\3\2\2\2\u01b9\u01b7\3\2\2\2\u01ba\u01b2\3\2\2\2") - buf.write("\u01ba\u01bb\3\2\2\2\u01bb\u01bc\3\2\2\2\u01bc\u01bd\7") - buf.write("\t\2\2\u01bd\u01be\7\25\2\2\u01be\u01c1\5\66\34\2\u01bf") - buf.write("\u01c1\7\6\2\2\u01c0\u0191\3\2\2\2\u01c0\u0193\3\2\2\2") - buf.write("\u01c0\u0198\3\2\2\2\u01c0\u01a2\3\2\2\2\u01c0\u01a5\3") - buf.write("\2\2\2\u01c0\u01a6\3\2\2\2\u01c0\u01ad\3\2\2\2\u01c0\u01bf") - buf.write("\3\2\2\2\u01c1\67\3\2\2\2\u01c2\u01c3\7\n\2\2\u01c3\u01c8") - buf.write("\5\4\3\2\u01c4\u01c5\7\7\2\2\u01c5\u01c7\5\4\3\2\u01c6") - buf.write("\u01c4\3\2\2\2\u01c7\u01ca\3\2\2\2\u01c8\u01c6\3\2\2\2") - buf.write("\u01c8\u01c9\3\2\2\2\u01c9\u01cb\3\2\2\2\u01ca\u01c8\3") - buf.write("\2\2\2\u01cb\u01cc\7\13\2\2\u01cc9\3\2\2\2\u01cd\u01ce") - buf.write("\7\b\2\2\u01ce\u01db\7\t\2\2\u01cf\u01d0\7\b\2\2\u01d0") - buf.write("\u01d3\5> \2\u01d1\u01d2\7\7\2\2\u01d2\u01d4\5> \2\u01d3") - buf.write("\u01d1\3\2\2\2\u01d4\u01d5\3\2\2\2\u01d5\u01d3\3\2\2\2") - buf.write("\u01d5\u01d6\3\2\2\2\u01d6\u01d7\3\2\2\2\u01d7\u01d8\7") - buf.write("\t\2\2\u01d8\u01db\3\2\2\2\u01d9\u01db\5> \2\u01da\u01cd") - 
buf.write("\3\2\2\2\u01da\u01cf\3\2\2\2\u01da\u01d9\3\2\2\2\u01db") - buf.write(";\3\2\2\2\u01dc\u01dd\7\36\2\2\u01dd\u01de\7\n\2\2\u01de") - buf.write("\u01df\7/\2\2\u01df\u01e0\7\13\2\2\u01e0\u01e1\7\n\2\2") - buf.write("\u01e1\u01e2\7\61\2\2\u01e2\u01e3\7\13\2\2\u01e3=\3\2") - buf.write("\2\2\u01e4\u01eb\5<\37\2\u01e5\u01e6\7\b\2\2\u01e6\u01e7") - buf.write("\5> \2\u01e7\u01e8\7\t\2\2\u01e8\u01eb\3\2\2\2\u01e9\u01eb") - buf.write("\7\61\2\2\u01ea\u01e4\3\2\2\2\u01ea\u01e5\3\2\2\2\u01ea") - buf.write("\u01e9\3\2\2\2\u01eb?\3\2\2\2\u01ec\u01ed\7\16\2\2\u01ed") - buf.write("\u01ee\5\20\t\2\u01ee\u01ef\7\17\2\2\u01efA\3\2\2\2\u01f0") - buf.write("\u01f4\7\60\2\2\u01f1\u01f4\7\61\2\2\u01f2\u01f4\7.\2") - buf.write("\2\u01f3\u01f0\3\2\2\2\u01f3\u01f1\3\2\2\2\u01f3\u01f2") - buf.write("\3\2\2\2\u01f4C\3\2\2\2\u01f5\u01fa\5\4\3\2\u01f6\u01fa") - buf.write("\5\6\4\2\u01f7\u01fa\5\b\5\2\u01f8\u01fa\5\n\6\2\u01f9") - buf.write("\u01f5\3\2\2\2\u01f9\u01f6\3\2\2\2\u01f9\u01f7\3\2\2\2") - buf.write("\u01f9\u01f8\3\2\2\2\u01faE\3\2\2\28JNQZknvz\u0091\u009b") - buf.write("\u009e\u00af\u00c4\u00dd\u00df\u00e4\u00eb\u00f2\u00f9") - buf.write("\u0101\u0106\u010a\u010e\u0117\u011b\u0124\u0129\u0130") - buf.write("\u0134\u0138\u0140\u014a\u0153\u0155\u0159\u0161\u0168") - buf.write("\u0170\u0174\u017b\u017e\u0183\u018a\u019e\u01af\u01b7") - buf.write("\u01ba\u01c0\u01c8\u01d5\u01da\u01ea\u01f3\u01f9") + buf.write("\34\3\34\3\34\3\34\3\34\3\34\3\34\3\34\3\34\3\34\3\34") + buf.write("\3\34\3\34\6\34\u01a1\n\34\r\34\16\34\u01a2\3\34\3\34") + buf.write("\3\34\3\34\3\34\3\34\3\34\3\34\3\34\3\34\3\34\3\34\3\34") + buf.write("\3\34\3\34\5\34\u01b4\n\34\3\34\3\34\3\34\3\34\7\34\u01ba") + buf.write("\n\34\f\34\16\34\u01bd\13\34\5\34\u01bf\n\34\3\34\3\34") + buf.write("\3\34\3\34\5\34\u01c5\n\34\3\35\3\35\3\35\3\35\7\35\u01cb") + buf.write("\n\35\f\35\16\35\u01ce\13\35\3\35\3\35\3\36\3\36\3\36") + buf.write("\3\36\3\36\3\36\6\36\u01d8\n\36\r\36\16\36\u01d9\3\36") + 
buf.write("\3\36\3\36\5\36\u01df\n\36\3\37\3\37\3\37\3\37\3\37\3") + buf.write("\37\3\37\3\37\3 \3 \3 \3 \3 \3 \5 \u01ef\n \3!\3!\3!\3") + buf.write("!\3\"\3\"\3\"\5\"\u01f8\n\"\3#\3#\3#\3#\5#\u01fe\n#\3") + buf.write("#\2\3\20$\2\4\6\b\n\f\16\20\22\24\26\30\32\34\36 \"$&") + buf.write("(*,.\60\62\64\668:<>@BD\2\b\4\2\6\6//\3\2$%\3\2&\'\3\2") + buf.write("(+\3\2,-\3\2\32\33\2\u0234\2F\3\2\2\2\4U\3\2\2\2\6]\3") + buf.write("\2\2\2\b`\3\2\2\2\nc\3\2\2\2\fn\3\2\2\2\16z\3\2\2\2\20") + buf.write("\u00c2\3\2\2\2\22\u00e0\3\2\2\2\24\u010c\3\2\2\2\26\u010e") + buf.write("\3\2\2\2\30\u0110\3\2\2\2\32\u011b\3\2\2\2\34\u0129\3") + buf.write("\2\2\2\36\u0134\3\2\2\2 \u013d\3\2\2\2\"\u013f\3\2\2\2") + buf.write("$\u0155\3\2\2\2&\u0157\3\2\2\2(\u015b\3\2\2\2*\u0168\3") + buf.write("\2\2\2,\u0174\3\2\2\2.\u017e\3\2\2\2\60\u0180\3\2\2\2") + buf.write("\62\u0185\3\2\2\2\64\u018d\3\2\2\2\66\u01c4\3\2\2\28\u01c6") + buf.write("\3\2\2\2:\u01de\3\2\2\2<\u01e0\3\2\2\2>\u01ee\3\2\2\2") + buf.write("@\u01f0\3\2\2\2B\u01f7\3\2\2\2D\u01fd\3\2\2\2FN\7\37\2") + buf.write("\2GI\5\24\13\2HG\3\2\2\2IL\3\2\2\2JH\3\2\2\2JK\3\2\2\2") + buf.write("KO\3\2\2\2LJ\3\2\2\2MO\5\20\t\2NJ\3\2\2\2NM\3\2\2\2OQ") + buf.write("\3\2\2\2PR\7\62\2\2QP\3\2\2\2QR\3\2\2\2RS\3\2\2\2ST\7") + buf.write("\2\2\3T\3\3\2\2\2UZ\7/\2\2VW\7\3\2\2WY\7/\2\2XV\3\2\2") + buf.write("\2Y\\\3\2\2\2ZX\3\2\2\2Z[\3\2\2\2[\5\3\2\2\2\\Z\3\2\2") + buf.write("\2]^\7\4\2\2^_\7/\2\2_\7\3\2\2\2`a\7\5\2\2ab\t\2\2\2b") + buf.write("\t\3\2\2\2cd\7\5\2\2de\7\61\2\2e\13\3\2\2\2fk\5\20\t\2") + buf.write("gh\7\7\2\2hj\5\20\t\2ig\3\2\2\2jm\3\2\2\2ki\3\2\2\2kl") + buf.write("\3\2\2\2lo\3\2\2\2mk\3\2\2\2nf\3\2\2\2no\3\2\2\2o\r\3") + buf.write("\2\2\2p{\5\f\7\2qr\5\20\t\2rs\7\7\2\2su\3\2\2\2tq\3\2") + buf.write("\2\2ux\3\2\2\2vt\3\2\2\2vw\3\2\2\2wy\3\2\2\2xv\3\2\2\2") + buf.write("y{\5\62\32\2zp\3\2\2\2zv\3\2\2\2{\17\3\2\2\2|}\b\t\1\2") + buf.write("}~\7\b\2\2~\177\5\20\t\2\177\u0080\7\t\2\2\u0080\u00c3") + 
buf.write("\3\2\2\2\u0081\u0082\7\'\2\2\u0082\u00c3\5\20\t\26\u0083") + buf.write("\u00c3\5\22\n\2\u0084\u0085\7\b\2\2\u0085\u00c3\7\t\2") + buf.write("\2\u0086\u0087\7\b\2\2\u0087\u0088\5\20\t\2\u0088\u0089") + buf.write("\7\7\2\2\u0089\u008a\7\t\2\2\u008a\u00c3\3\2\2\2\u008b") + buf.write("\u008c\7\b\2\2\u008c\u008f\5\20\t\2\u008d\u008e\7\7\2") + buf.write("\2\u008e\u0090\5\20\t\2\u008f\u008d\3\2\2\2\u0090\u0091") + buf.write("\3\2\2\2\u0091\u008f\3\2\2\2\u0091\u0092\3\2\2\2\u0092") + buf.write("\u0093\3\2\2\2\u0093\u0094\7\t\2\2\u0094\u00c3\3\2\2\2") + buf.write("\u0095\u009e\7\n\2\2\u0096\u009b\5\20\t\2\u0097\u0098") + buf.write("\7\7\2\2\u0098\u009a\5\20\t\2\u0099\u0097\3\2\2\2\u009a") + buf.write("\u009d\3\2\2\2\u009b\u0099\3\2\2\2\u009b\u009c\3\2\2\2") + buf.write("\u009c\u009f\3\2\2\2\u009d\u009b\3\2\2\2\u009e\u0096\3") + buf.write("\2\2\2\u009e\u009f\3\2\2\2\u009f\u00a0\3\2\2\2\u00a0\u00c3") + buf.write("\7\13\2\2\u00a1\u00a2\7\f\2\2\u00a2\u00a3\7\b\2\2\u00a3") + buf.write("\u00a4\5\20\t\2\u00a4\u00a5\7\t\2\2\u00a5\u00a6\5@!\2") + buf.write("\u00a6\u00a7\7\r\2\2\u00a7\u00a8\5@!\2\u00a8\u00c3\3\2") + buf.write("\2\2\u00a9\u00aa\5 \21\2\u00aa\u00ab\5\20\t\2\u00ab\u00ad") + buf.write("\7\16\2\2\u00ac\u00ae\5\34\17\2\u00ad\u00ac\3\2\2\2\u00ad") + buf.write("\u00ae\3\2\2\2\u00ae\u00af\3\2\2\2\u00af\u00b0\7\17\2") + buf.write("\2\u00b0\u00c3\3\2\2\2\u00b1\u00b2\7\20\2\2\u00b2\u00b3") + buf.write("\5\60\31\2\u00b3\u00b4\7\21\2\2\u00b4\u00b5\5\20\t\2\u00b5") + buf.write("\u00b6\7\22\2\2\u00b6\u00b7\5\20\t\t\u00b7\u00c3\3\2\2") + buf.write("\2\u00b8\u00b9\5\n\6\2\u00b9\u00ba\7\21\2\2\u00ba\u00bb") + buf.write("\5\20\t\2\u00bb\u00bc\7\22\2\2\u00bc\u00bd\5\20\t\7\u00bd") + buf.write("\u00c3\3\2\2\2\u00be\u00c3\5D#\2\u00bf\u00c3\5B\"\2\u00c0") + buf.write("\u00c3\5<\37\2\u00c1\u00c3\7#\2\2\u00c2|\3\2\2\2\u00c2") + buf.write("\u0081\3\2\2\2\u00c2\u0083\3\2\2\2\u00c2\u0084\3\2\2\2") + buf.write("\u00c2\u0086\3\2\2\2\u00c2\u008b\3\2\2\2\u00c2\u0095\3") + 
buf.write("\2\2\2\u00c2\u00a1\3\2\2\2\u00c2\u00a9\3\2\2\2\u00c2\u00b1") + buf.write("\3\2\2\2\u00c2\u00b8\3\2\2\2\u00c2\u00be\3\2\2\2\u00c2") + buf.write("\u00bf\3\2\2\2\u00c2\u00c0\3\2\2\2\u00c2\u00c1\3\2\2\2") + buf.write("\u00c3\u00dd\3\2\2\2\u00c4\u00c5\f\25\2\2\u00c5\u00c6") + buf.write("\t\3\2\2\u00c6\u00dc\5\20\t\26\u00c7\u00c8\f\24\2\2\u00c8") + buf.write("\u00c9\t\4\2\2\u00c9\u00dc\5\20\t\25\u00ca\u00cb\f\23") + buf.write("\2\2\u00cb\u00cc\t\5\2\2\u00cc\u00dc\5\20\t\24\u00cd\u00ce") + buf.write("\f\22\2\2\u00ce\u00cf\t\6\2\2\u00cf\u00dc\5\20\t\23\u00d0") + buf.write("\u00d1\f\b\2\2\u00d1\u00d2\7\23\2\2\u00d2\u00dc\5\20\t") + buf.write("\t\u00d3\u00d4\f\27\2\2\u00d4\u00d5\7\b\2\2\u00d5\u00d6") + buf.write("\5\16\b\2\u00d6\u00d7\7\t\2\2\u00d7\u00dc\3\2\2\2\u00d8") + buf.write("\u00d9\f\n\2\2\u00d9\u00da\7\3\2\2\u00da\u00dc\7\61\2") + buf.write("\2\u00db\u00c4\3\2\2\2\u00db\u00c7\3\2\2\2\u00db\u00ca") + buf.write("\3\2\2\2\u00db\u00cd\3\2\2\2\u00db\u00d0\3\2\2\2\u00db") + buf.write("\u00d3\3\2\2\2\u00db\u00d8\3\2\2\2\u00dc\u00df\3\2\2\2") + buf.write("\u00dd\u00db\3\2\2\2\u00dd\u00de\3\2\2\2\u00de\21\3\2") + buf.write("\2\2\u00df\u00dd\3\2\2\2\u00e0\u00e2\7\24\2\2\u00e1\u00e3") + buf.write("\58\35\2\u00e2\u00e1\3\2\2\2\u00e2\u00e3\3\2\2\2\u00e3") + buf.write("\u00e4\3\2\2\2\u00e4\u00e5\7\b\2\2\u00e5\u00e6\5,\27\2") + buf.write("\u00e6\u00e9\7\t\2\2\u00e7\u00e8\7\25\2\2\u00e8\u00ea") + buf.write("\5\66\34\2\u00e9\u00e7\3\2\2\2\u00e9\u00ea\3\2\2\2\u00ea") + buf.write("\u00eb\3\2\2\2\u00eb\u00ec\5@!\2\u00ec\23\3\2\2\2\u00ed") + buf.write("\u00ee\7\26\2\2\u00ee\u00f0\5\6\4\2\u00ef\u00f1\58\35") + buf.write("\2\u00f0\u00ef\3\2\2\2\u00f0\u00f1\3\2\2\2\u00f1\u00f2") + buf.write("\3\2\2\2\u00f2\u00f3\7\b\2\2\u00f3\u00f4\5,\27\2\u00f4") + buf.write("\u00f7\7\t\2\2\u00f5\u00f6\7\25\2\2\u00f6\u00f8\5\66\34") + buf.write("\2\u00f7\u00f5\3\2\2\2\u00f7\u00f8\3\2\2\2\u00f8\u00f9") + buf.write("\3\2\2\2\u00f9\u00fa\5@!\2\u00fa\u010d\3\2\2\2\u00fb\u00fc") + 
buf.write("\7\27\2\2\u00fc\u00fd\7\30\2\2\u00fd\u00ff\5\4\3\2\u00fe") + buf.write("\u0100\58\35\2\u00ff\u00fe\3\2\2\2\u00ff\u0100\3\2\2\2") + buf.write("\u0100\u010d\3\2\2\2\u0101\u0102\7\30\2\2\u0102\u0104") + buf.write("\5\4\3\2\u0103\u0105\58\35\2\u0104\u0103\3\2\2\2\u0104") + buf.write("\u0105\3\2\2\2\u0105\u0106\3\2\2\2\u0106\u0108\7\16\2") + buf.write("\2\u0107\u0109\5\30\r\2\u0108\u0107\3\2\2\2\u0108\u0109") + buf.write("\3\2\2\2\u0109\u010a\3\2\2\2\u010a\u010b\7\17\2\2\u010b") + buf.write("\u010d\3\2\2\2\u010c\u00ed\3\2\2\2\u010c\u00fb\3\2\2\2") + buf.write("\u010c\u0101\3\2\2\2\u010d\25\3\2\2\2\u010e\u010f\7/\2") + buf.write("\2\u010f\27\3\2\2\2\u0110\u0115\5\32\16\2\u0111\u0112") + buf.write("\7\7\2\2\u0112\u0114\5\32\16\2\u0113\u0111\3\2\2\2\u0114") + buf.write("\u0117\3\2\2\2\u0115\u0113\3\2\2\2\u0115\u0116\3\2\2\2") + buf.write("\u0116\u0119\3\2\2\2\u0117\u0115\3\2\2\2\u0118\u011a\7") + buf.write("\7\2\2\u0119\u0118\3\2\2\2\u0119\u011a\3\2\2\2\u011a\31") + buf.write("\3\2\2\2\u011b\u0127\5\26\f\2\u011c\u011d\7\b\2\2\u011d") + buf.write("\u0122\5\66\34\2\u011e\u011f\7\7\2\2\u011f\u0121\5\66") + buf.write("\34\2\u0120\u011e\3\2\2\2\u0121\u0124\3\2\2\2\u0122\u0120") + buf.write("\3\2\2\2\u0122\u0123\3\2\2\2\u0123\u0125\3\2\2\2\u0124") + buf.write("\u0122\3\2\2\2\u0125\u0126\7\t\2\2\u0126\u0128\3\2\2\2") + buf.write("\u0127\u011c\3\2\2\2\u0127\u0128\3\2\2\2\u0128\33\3\2") + buf.write("\2\2\u0129\u012e\5\36\20\2\u012a\u012b\7\7\2\2\u012b\u012d") + buf.write("\5\36\20\2\u012c\u012a\3\2\2\2\u012d\u0130\3\2\2\2\u012e") + buf.write("\u012c\3\2\2\2\u012e\u012f\3\2\2\2\u012f\u0132\3\2\2\2") + buf.write("\u0130\u012e\3\2\2\2\u0131\u0133\7\7\2\2\u0132\u0131\3") + buf.write("\2\2\2\u0132\u0133\3\2\2\2\u0133\35\3\2\2\2\u0134\u0135") + buf.write("\5$\23\2\u0135\u013b\7\31\2\2\u0136\u0137\7\16\2\2\u0137") + buf.write("\u0138\5\20\t\2\u0138\u0139\7\17\2\2\u0139\u013c\3\2\2") + buf.write("\2\u013a\u013c\5\20\t\2\u013b\u0136\3\2\2\2\u013b\u013a") + 
buf.write("\3\2\2\2\u013c\37\3\2\2\2\u013d\u013e\t\7\2\2\u013e!\3") + buf.write("\2\2\2\u013f\u0140\7\b\2\2\u0140\u0145\5$\23\2\u0141\u0142") + buf.write("\7\7\2\2\u0142\u0144\5$\23\2\u0143\u0141\3\2\2\2\u0144") + buf.write("\u0147\3\2\2\2\u0145\u0143\3\2\2\2\u0145\u0146\3\2\2\2") + buf.write("\u0146\u0148\3\2\2\2\u0147\u0145\3\2\2\2\u0148\u0149\7") + buf.write("\t\2\2\u0149#\3\2\2\2\u014a\u0156\7\6\2\2\u014b\u014e") + buf.write("\5\b\5\2\u014c\u014d\7\34\2\2\u014d\u014f\5\66\34\2\u014e") + buf.write("\u014c\3\2\2\2\u014e\u014f\3\2\2\2\u014f\u0156\3\2\2\2") + buf.write("\u0150\u0152\5\26\f\2\u0151\u0153\5\"\22\2\u0152\u0151") + buf.write("\3\2\2\2\u0152\u0153\3\2\2\2\u0153\u0156\3\2\2\2\u0154") + buf.write("\u0156\5\"\22\2\u0155\u014a\3\2\2\2\u0155\u014b\3\2\2") + buf.write("\2\u0155\u0150\3\2\2\2\u0155\u0154\3\2\2\2\u0156%\3\2") + buf.write("\2\2\u0157\u0159\5\26\f\2\u0158\u015a\5(\25\2\u0159\u0158") + buf.write("\3\2\2\2\u0159\u015a\3\2\2\2\u015a\'\3\2\2\2\u015b\u015c") + buf.write("\7\b\2\2\u015c\u0161\5*\26\2\u015d\u015e\7\7\2\2\u015e") + buf.write("\u0160\5*\26\2\u015f\u015d\3\2\2\2\u0160\u0163\3\2\2\2") + buf.write("\u0161\u015f\3\2\2\2\u0161\u0162\3\2\2\2\u0162\u0164\3") + buf.write("\2\2\2\u0163\u0161\3\2\2\2\u0164\u0165\7\t\2\2\u0165)") + buf.write("\3\2\2\2\u0166\u0169\5\b\5\2\u0167\u0169\5\26\f\2\u0168") + buf.write("\u0166\3\2\2\2\u0168\u0167\3\2\2\2\u0169+\3\2\2\2\u016a") + buf.write("\u0175\5.\30\2\u016b\u016c\5\60\31\2\u016c\u016d\7\7\2") + buf.write("\2\u016d\u016f\3\2\2\2\u016e\u016b\3\2\2\2\u016f\u0172") + buf.write("\3\2\2\2\u0170\u016e\3\2\2\2\u0170\u0171\3\2\2\2\u0171") + buf.write("\u0173\3\2\2\2\u0172\u0170\3\2\2\2\u0173\u0175\5\62\32") + buf.write("\2\u0174\u016a\3\2\2\2\u0174\u0170\3\2\2\2\u0175-\3\2") + buf.write("\2\2\u0176\u017b\5\60\31\2\u0177\u0178\7\7\2\2\u0178\u017a") + buf.write("\5\60\31\2\u0179\u0177\3\2\2\2\u017a\u017d\3\2\2\2\u017b") + buf.write("\u0179\3\2\2\2\u017b\u017c\3\2\2\2\u017c\u017f\3\2\2\2") + 
buf.write("\u017d\u017b\3\2\2\2\u017e\u0176\3\2\2\2\u017e\u017f\3") + buf.write("\2\2\2\u017f/\3\2\2\2\u0180\u0183\5\b\5\2\u0181\u0182") + buf.write("\7\34\2\2\u0182\u0184\5\66\34\2\u0183\u0181\3\2\2\2\u0183") + buf.write("\u0184\3\2\2\2\u0184\61\3\2\2\2\u0185\u018a\5\64\33\2") + buf.write("\u0186\u0187\7\7\2\2\u0187\u0189\5\64\33\2\u0188\u0186") + buf.write("\3\2\2\2\u0189\u018c\3\2\2\2\u018a\u0188\3\2\2\2\u018a") + buf.write("\u018b\3\2\2\2\u018b\63\3\2\2\2\u018c\u018a\3\2\2\2\u018d") + buf.write("\u018e\7/\2\2\u018e\u018f\7\21\2\2\u018f\u0190\5\20\t") + buf.write("\2\u0190\65\3\2\2\2\u0191\u0192\7\b\2\2\u0192\u01c5\7") + buf.write("\t\2\2\u0193\u0194\7\b\2\2\u0194\u0195\5\66\34\2\u0195") + buf.write("\u0196\7\t\2\2\u0196\u01c5\3\2\2\2\u0197\u0198\7\b\2\2") + buf.write("\u0198\u0199\5\66\34\2\u0199\u019a\7\7\2\2\u019a\u019b") + buf.write("\7\t\2\2\u019b\u01c5\3\2\2\2\u019c\u019d\7\b\2\2\u019d") + buf.write("\u01a0\5\66\34\2\u019e\u019f\7\7\2\2\u019f\u01a1\5\66") + buf.write("\34\2\u01a0\u019e\3\2\2\2\u01a1\u01a2\3\2\2\2\u01a2\u01a0") + buf.write("\3\2\2\2\u01a2\u01a3\3\2\2\2\u01a3\u01a4\3\2\2\2\u01a4") + buf.write("\u01a5\7\t\2\2\u01a5\u01c5\3\2\2\2\u01a6\u01a7\5\4\3\2") + buf.write("\u01a7\u01a8\58\35\2\u01a8\u01c5\3\2\2\2\u01a9\u01c5\5") + buf.write("\4\3\2\u01aa\u01ab\7\35\2\2\u01ab\u01ac\7\n\2\2\u01ac") + buf.write("\u01ad\5:\36\2\u01ad\u01ae\7\7\2\2\u01ae\u01af\5\66\34") + buf.write("\2\u01af\u01b0\7\13\2\2\u01b0\u01c5\3\2\2\2\u01b1\u01b3") + buf.write("\7\24\2\2\u01b2\u01b4\58\35\2\u01b3\u01b2\3\2\2\2\u01b3") + buf.write("\u01b4\3\2\2\2\u01b4\u01b5\3\2\2\2\u01b5\u01be\7\b\2\2") + buf.write("\u01b6\u01bb\5\66\34\2\u01b7\u01b8\7\7\2\2\u01b8\u01ba") + buf.write("\5\66\34\2\u01b9\u01b7\3\2\2\2\u01ba\u01bd\3\2\2\2\u01bb") + buf.write("\u01b9\3\2\2\2\u01bb\u01bc\3\2\2\2\u01bc\u01bf\3\2\2\2") + buf.write("\u01bd\u01bb\3\2\2\2\u01be\u01b6\3\2\2\2\u01be\u01bf\3") + buf.write("\2\2\2\u01bf\u01c0\3\2\2\2\u01c0\u01c1\7\t\2\2\u01c1\u01c2") + 
buf.write("\7\25\2\2\u01c2\u01c5\5\66\34\2\u01c3\u01c5\7\6\2\2\u01c4") + buf.write("\u0191\3\2\2\2\u01c4\u0193\3\2\2\2\u01c4\u0197\3\2\2\2") + buf.write("\u01c4\u019c\3\2\2\2\u01c4\u01a6\3\2\2\2\u01c4\u01a9\3") + buf.write("\2\2\2\u01c4\u01aa\3\2\2\2\u01c4\u01b1\3\2\2\2\u01c4\u01c3") + buf.write("\3\2\2\2\u01c5\67\3\2\2\2\u01c6\u01c7\7\n\2\2\u01c7\u01cc") + buf.write("\5\66\34\2\u01c8\u01c9\7\7\2\2\u01c9\u01cb\5\66\34\2\u01ca") + buf.write("\u01c8\3\2\2\2\u01cb\u01ce\3\2\2\2\u01cc\u01ca\3\2\2\2") + buf.write("\u01cc\u01cd\3\2\2\2\u01cd\u01cf\3\2\2\2\u01ce\u01cc\3") + buf.write("\2\2\2\u01cf\u01d0\7\13\2\2\u01d09\3\2\2\2\u01d1\u01d2") + buf.write("\7\b\2\2\u01d2\u01df\7\t\2\2\u01d3\u01d4\7\b\2\2\u01d4") + buf.write("\u01d7\5> \2\u01d5\u01d6\7\7\2\2\u01d6\u01d8\5> \2\u01d7") + buf.write("\u01d5\3\2\2\2\u01d8\u01d9\3\2\2\2\u01d9\u01d7\3\2\2\2") + buf.write("\u01d9\u01da\3\2\2\2\u01da\u01db\3\2\2\2\u01db\u01dc\7") + buf.write("\t\2\2\u01dc\u01df\3\2\2\2\u01dd\u01df\5> \2\u01de\u01d1") + buf.write("\3\2\2\2\u01de\u01d3\3\2\2\2\u01de\u01dd\3\2\2\2\u01df") + buf.write(";\3\2\2\2\u01e0\u01e1\7\36\2\2\u01e1\u01e2\7\n\2\2\u01e2") + buf.write("\u01e3\7/\2\2\u01e3\u01e4\7\13\2\2\u01e4\u01e5\7\n\2\2") + buf.write("\u01e5\u01e6\7\61\2\2\u01e6\u01e7\7\13\2\2\u01e7=\3\2") + buf.write("\2\2\u01e8\u01ef\5<\37\2\u01e9\u01ea\7\b\2\2\u01ea\u01eb") + buf.write("\5> \2\u01eb\u01ec\7\t\2\2\u01ec\u01ef\3\2\2\2\u01ed\u01ef") + buf.write("\7\61\2\2\u01ee\u01e8\3\2\2\2\u01ee\u01e9\3\2\2\2\u01ee") + buf.write("\u01ed\3\2\2\2\u01ef?\3\2\2\2\u01f0\u01f1\7\16\2\2\u01f1") + buf.write("\u01f2\5\20\t\2\u01f2\u01f3\7\17\2\2\u01f3A\3\2\2\2\u01f4") + buf.write("\u01f8\7\60\2\2\u01f5\u01f8\7\61\2\2\u01f6\u01f8\7.\2") + buf.write("\2\u01f7\u01f4\3\2\2\2\u01f7\u01f5\3\2\2\2\u01f7\u01f6") + buf.write("\3\2\2\2\u01f8C\3\2\2\2\u01f9\u01fe\5\4\3\2\u01fa\u01fe") + buf.write("\5\6\4\2\u01fb\u01fe\5\b\5\2\u01fc\u01fe\5\n\6\2\u01fd") + buf.write("\u01f9\3\2\2\2\u01fd\u01fa\3\2\2\2\u01fd\u01fb\3\2\2\2") + 
buf.write("\u01fd\u01fc\3\2\2\2\u01feE\3\2\2\28JNQZknvz\u0091\u009b") + buf.write("\u009e\u00ad\u00c2\u00db\u00dd\u00e2\u00e9\u00f0\u00f7") + buf.write("\u00ff\u0104\u0108\u010c\u0115\u0119\u0122\u0127\u012e") + buf.write("\u0132\u013b\u0145\u014e\u0152\u0155\u0159\u0161\u0168") + buf.write("\u0170\u0174\u017b\u017e\u0183\u018a\u01a2\u01b3\u01bb") + buf.write("\u01be\u01c4\u01cc\u01d9\u01de\u01ee\u01f7\u01fd") return buf.getvalue() @@ -1180,7 +1183,7 @@ def expr(self, _p:int=0): self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) - self.state = 194 + self.state = 192 self._errHandler.sync(self) la_ = self._interp.adaptivePredict(self._input,12,self._ctx) if la_ == 1: @@ -1321,22 +1324,18 @@ def expr(self, _p:int=0): self.state = 167 self.matchType() self.state = 168 - self.match(RelayParser.T__5) - self.state = 169 self.expr(0) - self.state = 170 - self.match(RelayParser.T__6) - self.state = 171 + self.state = 169 self.match(RelayParser.T__11) - self.state = 173 + self.state = 171 self._errHandler.sync(self) _la = self._input.LA(1) - if _la==RelayParser.CNAME: - self.state = 172 + if (((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.T__2) | (1 << RelayParser.T__3) | (1 << RelayParser.T__5) | (1 << RelayParser.CNAME))) != 0): + self.state = 170 self.matchClauseList() - self.state = 175 + self.state = 173 self.match(RelayParser.T__12) pass @@ -1344,17 +1343,17 @@ def expr(self, _p:int=0): localctx = RelayParser.LetContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 177 + self.state = 175 self.match(RelayParser.T__13) - self.state = 178 + self.state = 176 self.var() - self.state = 179 + self.state = 177 self.match(RelayParser.T__14) - self.state = 180 + self.state = 178 self.expr(0) - self.state = 181 + self.state = 179 self.match(RelayParser.T__15) - self.state = 182 + self.state = 180 self.expr(7) pass @@ -1362,15 +1361,15 @@ def expr(self, _p:int=0): localctx = RelayParser.GraphContext(self, localctx) self._ctx = 
localctx _prevctx = localctx - self.state = 184 + self.state = 182 self.graphVar() - self.state = 185 + self.state = 183 self.match(RelayParser.T__14) - self.state = 186 + self.state = 184 self.expr(0) - self.state = 187 + self.state = 185 self.match(RelayParser.T__15) - self.state = 188 + self.state = 186 self.expr(5) pass @@ -1378,7 +1377,7 @@ def expr(self, _p:int=0): localctx = RelayParser.IdentExprContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 190 + self.state = 188 self.ident() pass @@ -1386,7 +1385,7 @@ def expr(self, _p:int=0): localctx = RelayParser.ScalarExprContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 191 + self.state = 189 self.scalar() pass @@ -1394,7 +1393,7 @@ def expr(self, _p:int=0): localctx = RelayParser.MetaExprContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 192 + self.state = 190 self.meta() pass @@ -1402,13 +1401,13 @@ def expr(self, _p:int=0): localctx = RelayParser.StringExprContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 193 + self.state = 191 self.match(RelayParser.QUOTED_STRING) pass self._ctx.stop = self._input.LT(-1) - self.state = 221 + self.state = 219 self._errHandler.sync(self) _alt = self._interp.adaptivePredict(self._input,14,self._ctx) while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: @@ -1416,17 +1415,17 @@ def expr(self, _p:int=0): if self._parseListeners is not None: self.triggerExitRuleEvent() _prevctx = localctx - self.state = 219 + self.state = 217 self._errHandler.sync(self) la_ = self._interp.adaptivePredict(self._input,13,self._ctx) if la_ == 1: localctx = RelayParser.BinOpContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) - self.state = 196 + self.state = 194 if not self.precpred(self._ctx, 19): from antlr4.error.Errors import FailedPredicateException raise FailedPredicateException(self, 
"self.precpred(self._ctx, 19)") - self.state = 197 + self.state = 195 localctx.op = self._input.LT(1) _la = self._input.LA(1) if not(_la==RelayParser.MUL or _la==RelayParser.DIV): @@ -1434,18 +1433,18 @@ def expr(self, _p:int=0): else: self._errHandler.reportMatch(self) self.consume() - self.state = 198 + self.state = 196 self.expr(20) pass elif la_ == 2: localctx = RelayParser.BinOpContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) - self.state = 199 + self.state = 197 if not self.precpred(self._ctx, 18): from antlr4.error.Errors import FailedPredicateException raise FailedPredicateException(self, "self.precpred(self._ctx, 18)") - self.state = 200 + self.state = 198 localctx.op = self._input.LT(1) _la = self._input.LA(1) if not(_la==RelayParser.ADD or _la==RelayParser.SUB): @@ -1453,18 +1452,18 @@ def expr(self, _p:int=0): else: self._errHandler.reportMatch(self) self.consume() - self.state = 201 + self.state = 199 self.expr(19) pass elif la_ == 3: localctx = RelayParser.BinOpContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) - self.state = 202 + self.state = 200 if not self.precpred(self._ctx, 17): from antlr4.error.Errors import FailedPredicateException raise FailedPredicateException(self, "self.precpred(self._ctx, 17)") - self.state = 203 + self.state = 201 localctx.op = self._input.LT(1) _la = self._input.LA(1) if not((((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.LT) | (1 << RelayParser.GT) | (1 << RelayParser.LE) | (1 << RelayParser.GE))) != 0)): @@ -1472,18 +1471,18 @@ def expr(self, _p:int=0): else: self._errHandler.reportMatch(self) self.consume() - self.state = 204 + self.state = 202 self.expr(18) pass elif la_ == 4: localctx = RelayParser.BinOpContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) self.pushNewRecursionContext(localctx, _startState, 
self.RULE_expr) - self.state = 205 + self.state = 203 if not self.precpred(self._ctx, 16): from antlr4.error.Errors import FailedPredicateException raise FailedPredicateException(self, "self.precpred(self._ctx, 16)") - self.state = 206 + self.state = 204 localctx.op = self._input.LT(1) _la = self._input.LA(1) if not(_la==RelayParser.EQ or _la==RelayParser.NE): @@ -1491,53 +1490,53 @@ def expr(self, _p:int=0): else: self._errHandler.reportMatch(self) self.consume() - self.state = 207 + self.state = 205 self.expr(17) pass elif la_ == 5: localctx = RelayParser.LetContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) - self.state = 208 + self.state = 206 if not self.precpred(self._ctx, 6): from antlr4.error.Errors import FailedPredicateException raise FailedPredicateException(self, "self.precpred(self._ctx, 6)") - self.state = 209 + self.state = 207 self.match(RelayParser.T__16) - self.state = 210 + self.state = 208 self.expr(7) pass elif la_ == 6: localctx = RelayParser.CallContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) - self.state = 211 + self.state = 209 if not self.precpred(self._ctx, 21): from antlr4.error.Errors import FailedPredicateException raise FailedPredicateException(self, "self.precpred(self._ctx, 21)") - self.state = 212 + self.state = 210 self.match(RelayParser.T__5) - self.state = 213 + self.state = 211 self.callList() - self.state = 214 + self.state = 212 self.match(RelayParser.T__6) pass elif la_ == 7: localctx = RelayParser.ProjectionContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) - self.state = 216 + self.state = 214 if not self.precpred(self._ctx, 8): from antlr4.error.Errors import FailedPredicateException raise FailedPredicateException(self, "self.precpred(self._ctx, 8)") - 
self.state = 217 + self.state = 215 self.match(RelayParser.T__0) - self.state = 218 + self.state = 216 self.match(RelayParser.NAT) pass - self.state = 223 + self.state = 221 self._errHandler.sync(self) _alt = self._interp.adaptivePredict(self._input,14,self._ctx) @@ -1591,33 +1590,33 @@ def func(self): self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) - self.state = 224 + self.state = 222 self.match(RelayParser.T__17) - self.state = 226 + self.state = 224 self._errHandler.sync(self) _la = self._input.LA(1) if _la==RelayParser.T__7: - self.state = 225 + self.state = 223 self.typeParamList() - self.state = 228 + self.state = 226 self.match(RelayParser.T__5) - self.state = 229 + self.state = 227 self.argList() - self.state = 230 + self.state = 228 self.match(RelayParser.T__6) - self.state = 233 + self.state = 231 self._errHandler.sync(self) _la = self._input.LA(1) if _la==RelayParser.T__18: - self.state = 231 + self.state = 229 self.match(RelayParser.T__18) - self.state = 232 + self.state = 230 self.typeExpr() - self.state = 235 + self.state = 233 self.body() except RecognitionException as re: localctx.exception = re @@ -1723,57 +1722,57 @@ def defn(self): self.enterRule(localctx, 18, self.RULE_defn) self._la = 0 # Token type try: - self.state = 268 + self.state = 266 self._errHandler.sync(self) token = self._input.LA(1) if token in [RelayParser.T__19]: localctx = RelayParser.FuncDefnContext(self, localctx) self.enterOuterAlt(localctx, 1) - self.state = 237 + self.state = 235 self.match(RelayParser.T__19) - self.state = 238 + self.state = 236 self.globalVar() - self.state = 240 + self.state = 238 self._errHandler.sync(self) _la = self._input.LA(1) if _la==RelayParser.T__7: - self.state = 239 + self.state = 237 self.typeParamList() - self.state = 242 + self.state = 240 self.match(RelayParser.T__5) - self.state = 243 + self.state = 241 self.argList() - self.state = 244 + self.state = 242 self.match(RelayParser.T__6) - self.state = 247 + self.state = 245 
self._errHandler.sync(self) _la = self._input.LA(1) if _la==RelayParser.T__18: - self.state = 245 + self.state = 243 self.match(RelayParser.T__18) - self.state = 246 + self.state = 244 self.typeExpr() - self.state = 249 + self.state = 247 self.body() pass elif token in [RelayParser.T__20]: localctx = RelayParser.ExternAdtDefnContext(self, localctx) self.enterOuterAlt(localctx, 2) - self.state = 251 + self.state = 249 self.match(RelayParser.T__20) - self.state = 252 + self.state = 250 self.match(RelayParser.T__21) - self.state = 253 + self.state = 251 self.generalIdent() - self.state = 255 + self.state = 253 self._errHandler.sync(self) _la = self._input.LA(1) if _la==RelayParser.T__7: - self.state = 254 + self.state = 252 self.typeParamList() @@ -1781,29 +1780,29 @@ def defn(self): elif token in [RelayParser.T__21]: localctx = RelayParser.AdtDefnContext(self, localctx) self.enterOuterAlt(localctx, 3) - self.state = 257 + self.state = 255 self.match(RelayParser.T__21) - self.state = 258 + self.state = 256 self.generalIdent() - self.state = 260 + self.state = 258 self._errHandler.sync(self) _la = self._input.LA(1) if _la==RelayParser.T__7: - self.state = 259 + self.state = 257 self.typeParamList() - self.state = 262 + self.state = 260 self.match(RelayParser.T__11) - self.state = 264 + self.state = 262 self._errHandler.sync(self) _la = self._input.LA(1) if _la==RelayParser.CNAME: - self.state = 263 + self.state = 261 self.adtConsDefnList() - self.state = 266 + self.state = 264 self.match(RelayParser.T__12) pass else: @@ -1845,7 +1844,7 @@ def constructorName(self): self.enterRule(localctx, 20, self.RULE_constructorName) try: self.enterOuterAlt(localctx, 1) - self.state = 270 + self.state = 268 self.match(RelayParser.CNAME) except RecognitionException as re: localctx.exception = re @@ -1888,26 +1887,26 @@ def adtConsDefnList(self): self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) - self.state = 272 + self.state = 270 self.adtConsDefn() - self.state = 277 + 
self.state = 275 self._errHandler.sync(self) _alt = self._interp.adaptivePredict(self._input,23,self._ctx) while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: if _alt==1: - self.state = 273 + self.state = 271 self.match(RelayParser.T__4) - self.state = 274 + self.state = 272 self.adtConsDefn() - self.state = 279 + self.state = 277 self._errHandler.sync(self) _alt = self._interp.adaptivePredict(self._input,23,self._ctx) - self.state = 281 + self.state = 279 self._errHandler.sync(self) _la = self._input.LA(1) if _la==RelayParser.T__4: - self.state = 280 + self.state = 278 self.match(RelayParser.T__4) @@ -1956,29 +1955,29 @@ def adtConsDefn(self): self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) - self.state = 283 + self.state = 281 self.constructorName() - self.state = 295 + self.state = 293 self._errHandler.sync(self) _la = self._input.LA(1) if _la==RelayParser.T__5: - self.state = 284 + self.state = 282 self.match(RelayParser.T__5) - self.state = 285 + self.state = 283 self.typeExpr() - self.state = 290 + self.state = 288 self._errHandler.sync(self) _la = self._input.LA(1) while _la==RelayParser.T__4: - self.state = 286 + self.state = 284 self.match(RelayParser.T__4) - self.state = 287 + self.state = 285 self.typeExpr() - self.state = 292 + self.state = 290 self._errHandler.sync(self) _la = self._input.LA(1) - self.state = 293 + self.state = 291 self.match(RelayParser.T__6) @@ -2023,26 +2022,26 @@ def matchClauseList(self): self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) - self.state = 297 + self.state = 295 self.matchClause() - self.state = 302 + self.state = 300 self._errHandler.sync(self) _alt = self._interp.adaptivePredict(self._input,27,self._ctx) while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: if _alt==1: - self.state = 298 + self.state = 296 self.match(RelayParser.T__4) - self.state = 299 + self.state = 297 self.matchClause() - self.state = 304 + self.state = 302 self._errHandler.sync(self) _alt = 
self._interp.adaptivePredict(self._input,27,self._ctx) - self.state = 306 + self.state = 304 self._errHandler.sync(self) _la = self._input.LA(1) if _la==RelayParser.T__4: - self.state = 305 + self.state = 303 self.match(RelayParser.T__4) @@ -2061,18 +2060,14 @@ def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): super().__init__(parent, invokingState) self.parser = parser - def constructorName(self): - return self.getTypedRuleContext(RelayParser.ConstructorNameContext,0) + def pattern(self): + return self.getTypedRuleContext(RelayParser.PatternContext,0) def expr(self): return self.getTypedRuleContext(RelayParser.ExprContext,0) - def patternList(self): - return self.getTypedRuleContext(RelayParser.PatternListContext,0) - - def getRuleIndex(self): return RelayParser.RULE_matchClause @@ -2089,34 +2084,25 @@ def matchClause(self): localctx = RelayParser.MatchClauseContext(self, self._ctx, self.state) self.enterRule(localctx, 28, self.RULE_matchClause) - self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) - self.state = 308 - self.constructorName() - self.state = 310 - self._errHandler.sync(self) - _la = self._input.LA(1) - if _la==RelayParser.T__5: - self.state = 309 - self.patternList() - - - self.state = 312 + self.state = 306 + self.pattern() + self.state = 307 self.match(RelayParser.T__22) - self.state = 318 + self.state = 313 self._errHandler.sync(self) token = self._input.LA(1) if token in [RelayParser.T__11]: - self.state = 313 + self.state = 308 self.match(RelayParser.T__11) - self.state = 314 + self.state = 309 self.expr(0) - self.state = 315 + self.state = 310 self.match(RelayParser.T__12) pass elif token in [RelayParser.T__1, RelayParser.T__2, RelayParser.T__5, RelayParser.T__7, RelayParser.T__9, RelayParser.T__13, RelayParser.T__17, RelayParser.T__23, RelayParser.T__24, RelayParser.T__27, RelayParser.QUOTED_STRING, RelayParser.SUB, RelayParser.BOOL_LIT, RelayParser.CNAME, RelayParser.FLOAT, RelayParser.NAT]: - 
self.state = 317 + self.state = 312 self.expr(0) pass else: @@ -2157,7 +2143,7 @@ def matchType(self): self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) - self.state = 320 + self.state = 315 _la = self._input.LA(1) if not(_la==RelayParser.T__23 or _la==RelayParser.T__24): self._errHandler.recoverInline(self) @@ -2205,23 +2191,23 @@ def patternList(self): self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) - self.state = 322 + self.state = 317 self.match(RelayParser.T__5) - self.state = 323 + self.state = 318 self.pattern() - self.state = 328 + self.state = 323 self._errHandler.sync(self) _la = self._input.LA(1) while _la==RelayParser.T__4: - self.state = 324 + self.state = 319 self.match(RelayParser.T__4) - self.state = 325 + self.state = 320 self.pattern() - self.state = 330 + self.state = 325 self._errHandler.sync(self) _la = self._input.LA(1) - self.state = 331 + self.state = 326 self.match(RelayParser.T__6) except RecognitionException as re: localctx.exception = re @@ -2238,26 +2224,88 @@ def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): super().__init__(parent, invokingState) self.parser = parser + + def getRuleIndex(self): + return RelayParser.RULE_pattern + + + def copyFrom(self, ctx:ParserRuleContext): + super().copyFrom(ctx) + + + + class WildcardPatternContext(PatternContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.PatternContext + super().__init__(parser) + self.copyFrom(ctx) + + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitWildcardPattern" ): + return visitor.visitWildcardPattern(self) + else: + return visitor.visitChildren(self) + + + class ConstructorPatternContext(PatternContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.PatternContext + super().__init__(parser) + self.copyFrom(ctx) + + def constructorName(self): + return self.getTypedRuleContext(RelayParser.ConstructorNameContext,0) + + 
def patternList(self): + return self.getTypedRuleContext(RelayParser.PatternListContext,0) + + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitConstructorPattern" ): + return visitor.visitConstructorPattern(self) + else: + return visitor.visitChildren(self) + + + class TuplePatternContext(PatternContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.PatternContext + super().__init__(parser) + self.copyFrom(ctx) + + def patternList(self): + return self.getTypedRuleContext(RelayParser.PatternListContext,0) + + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitTuplePattern" ): + return visitor.visitTuplePattern(self) + else: + return visitor.visitChildren(self) + + + class VarPatternContext(PatternContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.PatternContext + super().__init__(parser) + self.copyFrom(ctx) + def localVar(self): return self.getTypedRuleContext(RelayParser.LocalVarContext,0) - def typeExpr(self): return self.getTypedRuleContext(RelayParser.TypeExprContext,0) - def getRuleIndex(self): - return RelayParser.RULE_pattern - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitPattern" ): - return visitor.visitPattern(self) + if hasattr( visitor, "visitVarPattern" ): + return visitor.visitVarPattern(self) else: return visitor.visitChildren(self) - def pattern(self): localctx = RelayParser.PatternContext(self, self._ctx, self.state) @@ -2268,24 +2316,46 @@ def pattern(self): self._errHandler.sync(self) token = self._input.LA(1) if token in [RelayParser.T__3]: + localctx = RelayParser.WildcardPatternContext(self, localctx) self.enterOuterAlt(localctx, 1) - self.state = 333 + self.state = 328 self.match(RelayParser.T__3) pass elif token in [RelayParser.T__2]: + localctx = RelayParser.VarPatternContext(self, localctx) self.enterOuterAlt(localctx, 2) - self.state = 334 + self.state = 329 self.localVar() - self.state = 
337 + self.state = 332 self._errHandler.sync(self) _la = self._input.LA(1) if _la==RelayParser.T__25: - self.state = 335 + self.state = 330 self.match(RelayParser.T__25) - self.state = 336 + self.state = 331 self.typeExpr() + pass + elif token in [RelayParser.CNAME]: + localctx = RelayParser.ConstructorPatternContext(self, localctx) + self.enterOuterAlt(localctx, 3) + self.state = 334 + self.constructorName() + self.state = 336 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==RelayParser.T__5: + self.state = 335 + self.patternList() + + + pass + elif token in [RelayParser.T__5]: + localctx = RelayParser.TuplePatternContext(self, localctx) + self.enterOuterAlt(localctx, 4) + self.state = 338 + self.patternList() pass else: raise NoViableAltException(self) @@ -2800,6 +2870,23 @@ def copyFrom(self, ctx:ParserRuleContext): + class TypeParenContext(TypeExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.TypeExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def typeExpr(self): + return self.getTypedRuleContext(RelayParser.TypeExprContext,0) + + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitTypeParen" ): + return visitor.visitTypeParen(self) + else: + return visitor.visitChildren(self) + + class TupleTypeContext(TypeExprContext): def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.TypeExprContext @@ -2921,7 +3008,7 @@ def typeExpr(self): self.enterRule(localctx, 52, self.RULE_typeExpr) self._la = 0 # Token type try: - self.state = 446 + self.state = 450 self._errHandler.sync(self) la_ = self._interp.adaptivePredict(self._input,47,self._ctx) if la_ == 1: @@ -2934,123 +3021,134 @@ def typeExpr(self): pass elif la_ == 2: - localctx = RelayParser.TupleTypeContext(self, localctx) + localctx = RelayParser.TypeParenContext(self, localctx) self.enterOuterAlt(localctx, 2) self.state = 401 self.match(RelayParser.T__5) self.state = 402 self.typeExpr() 
self.state = 403 - self.match(RelayParser.T__4) - self.state = 404 self.match(RelayParser.T__6) pass elif la_ == 3: localctx = RelayParser.TupleTypeContext(self, localctx) self.enterOuterAlt(localctx, 3) - self.state = 406 + self.state = 405 self.match(RelayParser.T__5) + self.state = 406 + self.typeExpr() self.state = 407 + self.match(RelayParser.T__4) + self.state = 408 + self.match(RelayParser.T__6) + pass + + elif la_ == 4: + localctx = RelayParser.TupleTypeContext(self, localctx) + self.enterOuterAlt(localctx, 4) + self.state = 410 + self.match(RelayParser.T__5) + self.state = 411 self.typeExpr() - self.state = 410 + self.state = 414 self._errHandler.sync(self) _la = self._input.LA(1) while True: - self.state = 408 + self.state = 412 self.match(RelayParser.T__4) - self.state = 409 + self.state = 413 self.typeExpr() - self.state = 412 + self.state = 416 self._errHandler.sync(self) _la = self._input.LA(1) if not (_la==RelayParser.T__4): break - self.state = 414 + self.state = 418 self.match(RelayParser.T__6) pass - elif la_ == 4: + elif la_ == 5: localctx = RelayParser.TypeCallTypeContext(self, localctx) - self.enterOuterAlt(localctx, 4) - self.state = 416 + self.enterOuterAlt(localctx, 5) + self.state = 420 self.generalIdent() - self.state = 417 + self.state = 421 self.typeParamList() pass - elif la_ == 5: + elif la_ == 6: localctx = RelayParser.TypeIdentTypeContext(self, localctx) - self.enterOuterAlt(localctx, 5) - self.state = 419 + self.enterOuterAlt(localctx, 6) + self.state = 423 self.generalIdent() pass - elif la_ == 6: + elif la_ == 7: localctx = RelayParser.TensorTypeContext(self, localctx) - self.enterOuterAlt(localctx, 6) - self.state = 420 + self.enterOuterAlt(localctx, 7) + self.state = 424 self.match(RelayParser.T__26) - self.state = 421 + self.state = 425 self.match(RelayParser.T__7) - self.state = 422 + self.state = 426 self.shapeList() - self.state = 423 + self.state = 427 self.match(RelayParser.T__4) - self.state = 424 + self.state = 428 
self.typeExpr() - self.state = 425 + self.state = 429 self.match(RelayParser.T__8) pass - elif la_ == 7: + elif la_ == 8: localctx = RelayParser.FuncTypeContext(self, localctx) - self.enterOuterAlt(localctx, 7) - self.state = 427 + self.enterOuterAlt(localctx, 8) + self.state = 431 self.match(RelayParser.T__17) - self.state = 429 + self.state = 433 self._errHandler.sync(self) _la = self._input.LA(1) if _la==RelayParser.T__7: - self.state = 428 + self.state = 432 self.typeParamList() - self.state = 431 + self.state = 435 self.match(RelayParser.T__5) - self.state = 440 + self.state = 444 self._errHandler.sync(self) _la = self._input.LA(1) if (((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.T__3) | (1 << RelayParser.T__5) | (1 << RelayParser.T__17) | (1 << RelayParser.T__26) | (1 << RelayParser.CNAME))) != 0): - self.state = 432 + self.state = 436 self.typeExpr() - self.state = 437 + self.state = 441 self._errHandler.sync(self) _la = self._input.LA(1) while _la==RelayParser.T__4: - self.state = 433 + self.state = 437 self.match(RelayParser.T__4) - self.state = 434 + self.state = 438 self.typeExpr() - self.state = 439 + self.state = 443 self._errHandler.sync(self) _la = self._input.LA(1) - self.state = 442 + self.state = 446 self.match(RelayParser.T__6) - self.state = 443 + self.state = 447 self.match(RelayParser.T__18) - self.state = 444 + self.state = 448 self.typeExpr() pass - elif la_ == 8: + elif la_ == 9: localctx = RelayParser.IncompleteTypeContext(self, localctx) - self.enterOuterAlt(localctx, 8) - self.state = 445 + self.enterOuterAlt(localctx, 9) + self.state = 449 self.match(RelayParser.T__3) pass @@ -3070,11 +3168,11 @@ def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): super().__init__(parent, invokingState) self.parser = parser - def generalIdent(self, i:int=None): + def typeExpr(self, i:int=None): if i is None: - return self.getTypedRuleContexts(RelayParser.GeneralIdentContext) + return 
self.getTypedRuleContexts(RelayParser.TypeExprContext) else: - return self.getTypedRuleContext(RelayParser.GeneralIdentContext,i) + return self.getTypedRuleContext(RelayParser.TypeExprContext,i) def getRuleIndex(self): @@ -3096,23 +3194,23 @@ def typeParamList(self): self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) - self.state = 448 + self.state = 452 self.match(RelayParser.T__7) - self.state = 449 - self.generalIdent() - self.state = 454 + self.state = 453 + self.typeExpr() + self.state = 458 self._errHandler.sync(self) _la = self._input.LA(1) while _la==RelayParser.T__4: - self.state = 450 + self.state = 454 self.match(RelayParser.T__4) - self.state = 451 - self.generalIdent() - self.state = 456 + self.state = 455 + self.typeExpr() + self.state = 460 self._errHandler.sync(self) _la = self._input.LA(1) - self.state = 457 + self.state = 461 self.match(RelayParser.T__8) except RecognitionException as re: localctx.exception = re @@ -3154,44 +3252,44 @@ def shapeList(self): self.enterRule(localctx, 56, self.RULE_shapeList) self._la = 0 # Token type try: - self.state = 472 + self.state = 476 self._errHandler.sync(self) la_ = self._interp.adaptivePredict(self._input,50,self._ctx) if la_ == 1: self.enterOuterAlt(localctx, 1) - self.state = 459 + self.state = 463 self.match(RelayParser.T__5) - self.state = 460 + self.state = 464 self.match(RelayParser.T__6) pass elif la_ == 2: self.enterOuterAlt(localctx, 2) - self.state = 461 + self.state = 465 self.match(RelayParser.T__5) - self.state = 462 + self.state = 466 self.shape() - self.state = 465 + self.state = 469 self._errHandler.sync(self) _la = self._input.LA(1) while True: - self.state = 463 + self.state = 467 self.match(RelayParser.T__4) - self.state = 464 + self.state = 468 self.shape() - self.state = 467 + self.state = 471 self._errHandler.sync(self) _la = self._input.LA(1) if not (_la==RelayParser.T__4): break - self.state = 469 + self.state = 473 self.match(RelayParser.T__6) pass elif la_ == 3: 
self.enterOuterAlt(localctx, 3) - self.state = 471 + self.state = 475 self.shape() pass @@ -3235,19 +3333,19 @@ def meta(self): self.enterRule(localctx, 58, self.RULE_meta) try: self.enterOuterAlt(localctx, 1) - self.state = 474 + self.state = 478 self.match(RelayParser.T__27) - self.state = 475 + self.state = 479 self.match(RelayParser.T__7) - self.state = 476 + self.state = 480 self.match(RelayParser.CNAME) - self.state = 477 + self.state = 481 self.match(RelayParser.T__8) - self.state = 478 + self.state = 482 self.match(RelayParser.T__7) - self.state = 479 + self.state = 483 self.match(RelayParser.NAT) - self.state = 480 + self.state = 484 self.match(RelayParser.T__8) except RecognitionException as re: localctx.exception = re @@ -3330,29 +3428,29 @@ def shape(self): localctx = RelayParser.ShapeContext(self, self._ctx, self.state) self.enterRule(localctx, 60, self.RULE_shape) try: - self.state = 488 + self.state = 492 self._errHandler.sync(self) token = self._input.LA(1) if token in [RelayParser.T__27]: localctx = RelayParser.MetaShapeContext(self, localctx) self.enterOuterAlt(localctx, 1) - self.state = 482 + self.state = 486 self.meta() pass elif token in [RelayParser.T__5]: localctx = RelayParser.ParensShapeContext(self, localctx) self.enterOuterAlt(localctx, 2) - self.state = 483 + self.state = 487 self.match(RelayParser.T__5) - self.state = 484 + self.state = 488 self.shape() - self.state = 485 + self.state = 489 self.match(RelayParser.T__6) pass elif token in [RelayParser.NAT]: localctx = RelayParser.IntShapeContext(self, localctx) self.enterOuterAlt(localctx, 3) - self.state = 487 + self.state = 491 self.match(RelayParser.NAT) pass else: @@ -3395,11 +3493,11 @@ def body(self): self.enterRule(localctx, 62, self.RULE_body) try: self.enterOuterAlt(localctx, 1) - self.state = 490 + self.state = 494 self.match(RelayParser.T__11) - self.state = 491 + self.state = 495 self.expr(0) - self.state = 492 + self.state = 496 self.match(RelayParser.T__12) except 
RecognitionException as re: localctx.exception = re @@ -3480,25 +3578,25 @@ def scalar(self): localctx = RelayParser.ScalarContext(self, self._ctx, self.state) self.enterRule(localctx, 64, self.RULE_scalar) try: - self.state = 497 + self.state = 501 self._errHandler.sync(self) token = self._input.LA(1) if token in [RelayParser.FLOAT]: localctx = RelayParser.ScalarFloatContext(self, localctx) self.enterOuterAlt(localctx, 1) - self.state = 494 + self.state = 498 self.match(RelayParser.FLOAT) pass elif token in [RelayParser.NAT]: localctx = RelayParser.ScalarIntContext(self, localctx) self.enterOuterAlt(localctx, 2) - self.state = 495 + self.state = 499 self.match(RelayParser.NAT) pass elif token in [RelayParser.BOOL_LIT]: localctx = RelayParser.ScalarBoolContext(self, localctx) self.enterOuterAlt(localctx, 3) - self.state = 496 + self.state = 500 self.match(RelayParser.BOOL_LIT) pass else: @@ -3552,30 +3650,30 @@ def ident(self): localctx = RelayParser.IdentContext(self, self._ctx, self.state) self.enterRule(localctx, 66, self.RULE_ident) try: - self.state = 503 + self.state = 507 self._errHandler.sync(self) la_ = self._interp.adaptivePredict(self._input,53,self._ctx) if la_ == 1: self.enterOuterAlt(localctx, 1) - self.state = 499 + self.state = 503 self.generalIdent() pass elif la_ == 2: self.enterOuterAlt(localctx, 2) - self.state = 500 + self.state = 504 self.globalVar() pass elif la_ == 3: self.enterOuterAlt(localctx, 3) - self.state = 501 + self.state = 505 self.localVar() pass elif la_ == 4: self.enterOuterAlt(localctx, 4) - self.state = 502 + self.state = 506 self.graphVar() pass diff --git a/python/tvm/relay/grammar/py3/RelayVisitor.py b/python/tvm/relay/grammar/py3/RelayVisitor.py index 98dd09fef669..c6a7b7a0558c 100644 --- a/python/tvm/relay/grammar/py3/RelayVisitor.py +++ b/python/tvm/relay/grammar/py3/RelayVisitor.py @@ -184,8 +184,23 @@ def visitPatternList(self, ctx:RelayParser.PatternListContext): return self.visitChildren(ctx) - # Visit a parse tree 
produced by RelayParser#pattern. - def visitPattern(self, ctx:RelayParser.PatternContext): + # Visit a parse tree produced by RelayParser#wildcardPattern. + def visitWildcardPattern(self, ctx:RelayParser.WildcardPatternContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#varPattern. + def visitVarPattern(self, ctx:RelayParser.VarPatternContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#constructorPattern. + def visitConstructorPattern(self, ctx:RelayParser.ConstructorPatternContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#tuplePattern. + def visitTuplePattern(self, ctx:RelayParser.TuplePatternContext): return self.visitChildren(ctx) @@ -239,6 +254,11 @@ def visitTupleType(self, ctx:RelayParser.TupleTypeContext): return self.visitChildren(ctx) + # Visit a parse tree produced by RelayParser#typeParen. + def visitTypeParen(self, ctx:RelayParser.TypeParenContext): + return self.visitChildren(ctx) + + # Visit a parse tree produced by RelayParser#typeCallType. def visitTypeCallType(self, ctx:RelayParser.TypeCallTypeContext): return self.visitChildren(ctx) diff --git a/python/tvm/relay/module.py b/python/tvm/relay/module.py index 57980dd09cf2..5513bd711c4f 100644 --- a/python/tvm/relay/module.py +++ b/python/tvm/relay/module.py @@ -96,7 +96,7 @@ def _add(self, var, val, update=False): assert isinstance(val, _ty.Type) if isinstance(var, _base.string_types): var = _ty.GlobalTypeVar(var) - _module.Module_AddDef(self, var, val) + _module.Module_AddDef(self, var, val, update) def __getitem__(self, var): """Lookup a global definition by name or by variable. @@ -149,6 +149,26 @@ def get_global_var(self, name): """ return _module.Module_GetGlobalVar(self, name) + def get_global_vars(self): + """Collect all global vars defined in this module. + + Returns + ------- + global_vars: tvm.Array[GlobalVar] + An array of global vars. 
+ """ + return _module.Module_GetGlobalVars(self) + + def get_global_type_vars(self): + """Collect all global type vars defined in this module. + + Returns + ------- + global_type_vars: tvm.Array[GlobalTypeVar] + An array of global type vars. + """ + return _module.Module_GetGlobalTypeVars(self) + def get_global_type_var(self, name): """Get a global type variable in the function by name. diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index d05b669ee7f1..803d8ef50db5 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -16,502 +16,61 @@ # under the License. # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name """A prelude containing useful global functions and ADT definitions.""" -from .ty import GlobalTypeVar, TypeVar, FuncType, TupleType, scalar_type -from .expr import Var, Function, GlobalVar, Let, If, Tuple, TupleGetItem, const -from .op.tensor import add, subtract, equal -from .adt import Constructor, TypeData, Clause, Match -from .adt import PatternConstructor, PatternVar, PatternWildcard, PatternTuple from .module import Module class Prelude: """Contains standard definitions.""" - def define_list_adt(self): - """Defines a LISP-style list ADT. An empty list is - represented by nil(). A member x can be appended to the - front of a list l via the constructor cons(x, l).""" - self.l = GlobalTypeVar("list") - a = TypeVar("a") - self.nil = Constructor("nil", [], self.l) - self.cons = Constructor("cons", [a, self.l(a)], self.l) - self.mod[self.l] = TypeData(self.l, [a], [self.nil, self.cons]) - - def define_list_hd(self): - """Defines a function to get the head of a list. Assume the list has at least one - element. 
- - hd(l) : list[a] -> a - """ - self.hd = GlobalVar("hd") - a = TypeVar("a") - x = Var("x", self.l(a)) - y = Var("y") - z = Var("z") - cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), y) - self.mod[self.hd] = Function([x], Match(x, [cons_case], False), a, [a]) - - def define_list_tl(self): - """Defines a function to get the tail of a list. - - tl(l) : list[a] -> list[a] - """ - self.tl = GlobalVar("tl") - a = TypeVar("a") - x = Var("x", self.l(a)) - y = Var("y") - z = Var("z") - cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), z) - self.mod[self.tl] = Function([x], Match(x, [cons_case], False), self.l(a), [a]) - - - def define_list_nth(self): - """Defines a function to get the nth element of a list. - - nth(l) : list[a] -> Tensor[(), int32] -> a - """ - self.nth = GlobalVar("nth") - a = TypeVar("a") - x = Var("x", self.l(a)) - n = Var("n", scalar_type('int32')) - - body = If(equal(n, const(0)), - self.hd(x), - self.nth(self.tl(x), subtract(n, const(1)))) - - self.mod[self.nth] = Function([x, n], body, a, [a]) - - - def define_list_update(self): - """Defines a function to update the nth element of a list and return the updated list. - - update(l, i, v) : list[a] -> Tensor[(), int32] -> a -> list[a] - """ - self.update = GlobalVar("update") - a = TypeVar("a") - l = Var("l", self.l(a)) - n = Var("n", scalar_type('int32')) - v = Var("v", a) - - body = If(equal(n, const(0)), - self.cons(v, self.tl(l)), - self.cons(self.hd(l), - self.update(self.tl(l), - subtract(n, const(1)), - v))) - - self.mod[self.update] = Function([l, n, v], body, self.l(a), [a]) - - - def define_list_map(self): - """Defines a function for mapping a function over a list's - elements. That is, map(f, l) returns a new list where - the ith member is f applied to the ith member of l. 
- - map(f, l) : fn(fn(a) -> b, list[a]) -> list[b] - """ - self.map = GlobalVar("map") - a = TypeVar("a") - b = TypeVar("b") - f = Var("f", FuncType([a], b)) - x = Var("x", self.l(a)) - y = Var("y") - z = Var("z") - nil_case = Clause(PatternConstructor(self.nil), self.nil()) - cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), - self.cons(f(y), self.map(f, z))) - self.mod[self.map] = Function([f, x], Match(x, [nil_case, cons_case]), self.l(b), [a, b]) - - - def define_list_foldl(self): - """Defines a left-way fold over a list. - - foldl(f, z, l) : fn(fn(a, b) -> a, a, list[b]) -> a - - foldl(f, z, cons(a1, cons(a2, cons(a3, cons(..., nil))))) - evaluates to f(...f(f(f(z, a1), a2), a3)...) - """ - self.foldl = GlobalVar("foldl") - a = TypeVar("a") - b = TypeVar("b") - f = Var("f", FuncType([a, b], a)) - av = Var("av", a) - bv = Var("bv", self.l(b)) - y = Var("y") - z = Var("z") - nil_case = Clause(PatternConstructor(self.nil), av) - cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), - self.foldl(f, f(av, y), z)) - self.mod[self.foldl] = Function([f, av, bv], - Match(bv, [nil_case, cons_case]), a, [a, b]) - - - def define_list_foldr(self): - """Defines a right-way fold over a list. - - foldr(f, l, z) : fn(fn(a, b) -> b, list[a], b) -> b - - foldr(f, cons(a1, cons(a2, cons(..., cons(an, nil)))), z) - evalutes to f(a1, f(a2, f(..., f(an, z)))...) - """ - self.foldr = GlobalVar("foldr") - a = TypeVar("a") - b = TypeVar("b") - f = Var("f", FuncType([a, b], b)) - av = Var("av", self.l(a)) - bv = Var("bv", b) - y = Var("y") - z = Var("z") - nil_case = Clause(PatternConstructor(self.nil), bv) - cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), - f(y, self.foldr(f, bv, z))) - self.mod[self.foldr] = Function([f, bv, av], - Match(av, [nil_case, cons_case]), b, [a, b]) - - - def define_list_foldr1(self): - """Defines a right-way fold over a nonempty list. 
- - foldr1(f, l) : fn(fn(a, a) -> a, list[a]) -> a - - foldr1(f, cons(a1, cons(a2, cons(..., cons(an, nil))))) - evalutes to f(a1, f(a2, f(..., f(an-1, an)))...) - """ - self.foldr1 = GlobalVar("foldr1") - a = TypeVar("a") - f = Var("f", FuncType([a, a], a)) - av = Var("av", self.l(a)) - x = Var("x") - y = Var("y") - z = Var("z") - one_case = Clause(PatternConstructor(self.cons, - [PatternVar(x), PatternConstructor(self.nil)]), x) - cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), - f(y, self.foldr1(f, z))) - self.mod[self.foldr1] = Function([f, av], - Match(av, [one_case, cons_case], False), a, [a]) - - - def define_list_concat(self): - """Defines a function that concatenates two lists. - - concat(l1, l2) : fn(list[a], list[a]) -> list[a]""" - self.concat = GlobalVar("concat") - a = TypeVar("a") - l1 = Var("l1", self.l(a)) - l2 = Var("l2", self.l(a)) - h = Var("h") - t = Var("t") - updater = Function([h, t], self.cons(h, t)) - self.mod[self.concat] = Function([l1, l2], - self.foldr(updater, l2, l1), - self.l(a), [a]) - - - def define_list_filter(self): - """Defines a function that filters a list. - - filter(f, l) : fn(fn(a) -> Tensor[(), bool], list[a]) -> list[a] - - It returns the sublist of l consisting of the elements for which f returns true. - """ - self.filter = GlobalVar("filter") - a = TypeVar("a") - f = Var("f", FuncType([a], scalar_type("bool"))) - l = Var("l", self.l(a)) - h = Var("h") - t = Var("t") - nil_case = Clause(PatternConstructor(self.nil), self.nil()) - cons_case = Clause(PatternConstructor(self.cons, [PatternVar(h), PatternVar(t)]), - If(f(h), self.cons(h, self.filter(f, t)), self.filter(f, t))) - self.mod[self.filter] = Function([f, l], Match(l, [nil_case, cons_case]), self.l(a), [a]) - - - def define_list_zip(self): - """Defines a function that combines two lists into a list of tuples of their elements. 
- - zip(l, m) : fn(list[a], list[b]) -> list[(a, b)] - - The zipped list will be the length of the shorter list. - """ - self.zip = GlobalVar("zip") - a = TypeVar("a") - b = TypeVar("b") - l1 = Var("l1") - l2 = Var("l2") - h1 = Var("h1") - h2 = Var("h2") - t1 = Var("t1") - t2 = Var("t2") - cons_case = Clause(PatternTuple([PatternConstructor(self.cons, - [PatternVar(h1), PatternVar(t1)]), - PatternConstructor(self.cons, - [PatternVar(h2), PatternVar(t2)])]), - self.cons(Tuple([h1, h2]), self.zip(t1, t2))) - nil_case = Clause(PatternWildcard(), self.nil()) - self.mod[self.zip] = Function([l1, l2], Match(Tuple([l1, l2]), [cons_case, nil_case]), - self.l(TupleType([a, b])), [a, b]) - - - def define_list_rev(self): - """Defines a function that reverses a list. - - rev(l) : fn(list[a]) -> list[a] - """ - self.rev = GlobalVar("rev") - a = TypeVar("a") - l = Var("l", self.l(a)) - x = Var("x") - y = Var("y") - updater = Function([y, x], self.cons(x, y)) - self.mod[self.rev] = Function([l], - self.foldl(updater, self.nil(), l), - self.l(a), [a]) - - - def define_list_map_accumr(self): - """Defines an accumulative map, which is a fold that simulataneously updates - an accumulator value and a list of results. - - map_accumr(f, s, l) : fn(fn(a, b) -> (a, c), a, list[b]) -> (a, list[c]) - - This map proceeds through l from right to left. 
- """ - self.map_accumr = GlobalVar("map_accumr") - a = TypeVar("a") - b = TypeVar("b") - c = TypeVar("c") - f = Var("f", FuncType([a, b], TupleType([a, c]))) - acc = Var("acc", a) - l = Var("l", self.l(b)) - v = Var("v", b) - p = Var("p", TupleType([a, self.l(c)])) - f_out = Var("f_out", TupleType([a, c])) - updater = Function([v, p], - Let(f_out, f(TupleGetItem(p, 0), v), - Tuple([TupleGetItem(f_out, 0), - self.cons(TupleGetItem(f_out, 1), - TupleGetItem(p, 1))])), - TupleType([a, self.l(c)])) - self.mod[self.map_accumr] = Function([f, acc, l], - self.foldr(updater, Tuple([acc, self.nil()]), l), - TupleType([a, self.l(c)]), - [a, b, c]) - - - def define_list_map_accuml(self): - """Defines an accumulative map, which is a fold that simulataneously updates - an accumulator value and a list of results. - - map_accuml(f, s, l) : fn(fn(a, b) -> (a, c), a, list[b]) -> (a, list[c]) - - This map proceeds through l from left to right. - """ - self.map_accuml = GlobalVar("map_accuml") - a = TypeVar("a") - b = TypeVar("b") - c = TypeVar("c") - f = Var("f", FuncType([a, b], TupleType([a, c]))) - acc = Var("acc", a) - l = Var("l", self.l(b)) - v = Var("v", b) - p = Var("p", TupleType([a, self.l(c)])) - f_out = Var("f_out", TupleType([a, c])) - updater = Function([p, v], - Let(f_out, f(TupleGetItem(p, 0), v), - Tuple([TupleGetItem(f_out, 0), - self.cons(TupleGetItem(f_out, 1), - TupleGetItem(p, 1))])), - TupleType([a, self.l(c)])) - self.mod[self.map_accuml] = Function([f, acc, l], - self.foldl(updater, Tuple([acc, self.nil()]), l), - TupleType([a, self.l(c)]), - [a, b, c]) - - - def define_optional_adt(self): - """Defines an optional ADT, which can either contain some other - type or nothing at all.""" - self.optional = GlobalTypeVar("optional") - a = TypeVar("a") - self.some = Constructor("some", [a], self.optional) - self.none = Constructor("none", [], self.optional) - self.mod[self.optional] = TypeData(self.optional, [a], [self.some, self.none]) - - - def 
define_list_unfoldr(self): - """Defines a function that builds up a list starting from a seed value. - - unfoldr(f, s) : fn(fn(a) -> Optional[(a, b)], a) -> list[b] - - f returns an option containing a new seed and an output value. f will - continue to be called on the new seeds until it returns None. All the - output values will be combined into a list, right to left. - """ - self.unfoldr = GlobalVar("unfoldr") - a = TypeVar("a") - b = TypeVar("b") - f = Var("f", FuncType([a], self.optional(TupleType([a, b])))) - s = Var("s", a) - p = Var("p", TupleType([a, b])) - none_case = Clause(PatternConstructor(self.none), self.nil()) - some_case = Clause(PatternConstructor(self.some, [PatternVar(p)]), - self.cons(TupleGetItem(p, 1), - self.unfoldr(f, TupleGetItem(p, 0)))) - self.mod[self.unfoldr] = Function([f, s], Match(f(s), [none_case, some_case]), - self.l(b), [a, b]) - - - def define_list_unfoldl(self): - """Defines a function that builds up a list starting from a seed value. - - unfoldl(f, s) : fn(fn(a) -> Optional[(a, b)], a) -> list[b] - - f returns an option containing a new seed and an output value. f will - continue to be called on the new seeds until it returns None. All the - output values will be combined into a list, left to right. 
- """ - self.unfoldl = GlobalVar("unfoldl") - a = TypeVar("a") - b = TypeVar("b") - f = Var("f", FuncType([a], self.optional(TupleType([a, b])))) - s = Var("s", a) - # easiest way to implement is to do a right unfold and reverse - self.mod[self.unfoldl] = Function([f, s], - self.rev(self.unfoldr(f, s)), - self.l(b), [a, b]) - - - def define_list_sum(self): - """Defines a function that computes the sum of a list of integer scalars.""" - self.sum = GlobalVar("sum") - a = Var("a", self.l(scalar_type('int32'))) - x = Var('x') - y = Var('y') - addf = Function([x, y], add(x, y)) - self.mod[self.sum] = Function([a], self.foldl(addf, const(0), a)) - - - def define_list_length(self): - """Defines a function that returns the length of a list""" - self.length = GlobalVar("length") - a = TypeVar("a") - x = Var("x", self.l(a)) - y = Var("y") - nil_case = Clause(PatternConstructor(self.nil), const(0)) - cons_case = Clause(PatternConstructor(self.cons, [PatternWildcard(), PatternVar(y)]), - add(const(1), self.length(y))) - self.mod[self.length] = Function([x], - Match(x, [nil_case, cons_case]), scalar_type('int32'), [a]) - - - def define_tree_adt(self): - """Defines a tree ADT. A tree can contain any type. - It has only one constructor, rose(x, l), where x is the content - of that point of the tree and l is a list of more trees of the - same type. A leaf is thus rose(x, nil()). - """ - self.tree = GlobalTypeVar("tree") - a = TypeVar("a") - self.rose = Constructor("rose", [a, self.l(self.tree(a))], self.tree) - self.mod[self.tree] = TypeData(self.tree, [a], [self.rose]) - - - def define_tree_map(self): - """Defines a function that maps over a tree. The function - is applied to each subtree's contents. 
- - Signature: fn(f : fn(a) -> b, t : tree[a]) -> tree[b] - """ - self.tmap = GlobalVar("tmap") - a = TypeVar("a") - b = TypeVar("b") - t = Var("t", self.tree(a)) - f = Var("f", FuncType([a], b)) - x = Var("x", self.tree(a)) - y = Var("y") - z = Var("z") - rose_case = Clause(PatternConstructor(self.rose, [PatternVar(y), PatternVar(z)]), - self.rose(f(y), self.map(Function([x], self.tmap(f, x)), z))) - self.mod[self.tmap] = Function([f, t], - Match(t, [rose_case]), self.tree(b), [a, b]) - - - def define_tree_size(self): - """Defines a function that computes the size of a tree. - - Signature: fn(t : tree[a]) -> Tensor[(), int32] - """ - self.size = GlobalVar("size") - a = TypeVar("a") - t = Var("t", self.tree(a)) - z = Var("z") - rose_case = Clause(PatternConstructor(self.rose, [PatternWildcard(), PatternVar(z)]), - add(const(1), self.sum(self.map(self.size, z)))) - self.mod[self.size] = Function([t], - Match(t, [rose_case]), scalar_type('int32'), [a]) - - - def define_iterate(self): - """Defines a function that take a number n and a function f; - returns a closure that takes an argument and applies f - n times to its argument. - - Signature: fn(f : fn(a) -> a, n : Tensor[(), int32]) -> fn(a) -> a - """ - self.iterate = GlobalVar("iterate") - a = TypeVar("a") - f = Var("f", FuncType([a], a)) - x = Var("x", scalar_type('int32')) - body = If(equal(x, const(0)), - self.id, - self.compose(f, - self.iterate(f, subtract(x, const(1))))) - self.mod[self.iterate] = Function([f, x], - body, - FuncType([a], a), - [a]) - - def load_prelude(self): - """ - Parses the portions of the Prelude written in Relay's text format and adds - them to the module. 
- """ - # TODO(@jroesch): we should remove this helper when we port over prelude - self.mod.import_from_std("prelude.rly") - self.id = self.mod.get_global_var("id") - self.compose = self.mod.get_global_var("compose") - - def __init__(self, mod=None): if mod is None: mod = Module() self.mod = mod self.load_prelude() - self.define_list_adt() - self.define_list_hd() - self.define_list_tl() - self.define_list_map() - self.define_list_foldl() - self.define_list_foldr() - self.define_list_foldr1() - self.define_list_concat() - self.define_list_filter() - self.define_list_zip() - self.define_list_rev() - self.define_list_map_accumr() - self.define_list_map_accuml() - self.define_optional_adt() - self.define_list_unfoldr() - self.define_list_unfoldl() - - self.define_list_length() - self.define_list_nth() - self.define_list_update() - self.define_list_sum() - - self.define_tree_adt() - self.define_tree_map() - self.define_tree_size() + def load_prelude(self): + """Parses the Prelude from Relay's text format into a module.""" + # TODO(@jroesch): we should remove this helper when we port over prelude + self.mod.import_from_std("prelude.rly") - self.define_iterate() + self.l = self.mod.get_global_type_var("List") + list_adt = self.mod[self.l] + self.cons = list_adt.constructors[0] + self.nil = list_adt.constructors[1] + + self.optional = self.mod.get_global_type_var("Option") + optional_adt = self.mod[self.optional] + self.some = optional_adt.constructors[0] + self.none = optional_adt.constructors[1] + + self.tree = self.mod.get_global_type_var("Tree") + tree_adt = self.mod[self.tree] + self.rose = tree_adt.constructors[0] + + GLOBAL_DEFS = [ + "id", + "compose", + "flip", + "hd", + "tl", + "nth", + "update", + "map", + "foldl", + "foldr", + "foldr1", + "concat", + "filter", + "zip", + "rev", + "map_accuml", + "map_accumr", + "unfoldl", + "unfoldr", + "sum", + "length", + "tmap", + "size", + "iterate", + ] + for global_def in GLOBAL_DEFS: + setattr(self, global_def, 
self.mod.get_global_var(global_def)) diff --git a/python/tvm/relay/std/prelude.rly b/python/tvm/relay/std/prelude.rly index 6b6047c1b0a7..a5c2c9f8a9cb 100644 --- a/python/tvm/relay/std/prelude.rly +++ b/python/tvm/relay/std/prelude.rly @@ -18,12 +18,299 @@ */ v0.0.4 -def @id[a](%x: a) -> a { - %x +// TODO(weberlo): should we add sugar for scalar types (e.g., `int32` => `Tensor[(), int32]`)? + +def @id[A](%x: A) -> A { + %x +} + +def @compose[A, B, C](%f: fn(B) -> C, %g: fn(A) -> B) { + fn (%x: A) -> C { + %f(%g(%x)) + } +} + +def @flip[A, B, C](%f: fn(A, B) -> C) -> fn(B, A) -> C { + fn(%b: B, %a: A) -> C { + %f(%a, %b) + } +} + +/* + * A LISP-style list ADT. An empty list is represented by `Nil`, and a member + * `x` can be appended to the front of a list `l` via the constructor `Cons(x, l)`. + */ +type List[A] { + Cons(A, List[A]), + Nil, +} + +/* + * Get the head of a list. Assume the list has at least one element. + */ +def @hd[A](%xs: List[A]) -> A { + match? (%xs) { + Cons(%x, _) => %x, + } +} + +/* + * Get the tail of a list. + */ +def @tl[A](%xs: List[A]) -> List[A] { + match? (%xs) { + Cons(_, %rest) => %rest, + } +} + +/* + * Get the `n`th element of a list. + */ +def @nth[A](%xs: List[A], %n: Tensor[(), int32]) -> A { + if (%n == 0) { + @hd(%xs) + } else { + @nth(@tl(%xs), %n - 1) + } +} + +/* + * Return the length of a list. + */ +def @length[A](%xs: List[A]) -> Tensor[(), int32] { + match (%xs) { + Cons(_, %rest) => 1 + @length(%rest), + Nil => 0, + } +} + +/* + * Update the `n`th element of a list and return the updated list. + */ +def @update[A](%xs: List[A], %n: Tensor[(), int32], %v: A) -> List[A] { + if (%n == 0) { + Cons(%v, @tl(%xs)) + } else { + Cons(@hd(%xs), @update(@tl(%xs), %n - 1, %v)) + } +} + +/* + * Map a function over a list's elements. That is, `map(f, xs)` returns a new + * list where the `i`th member is `f` applied to the `i`th member of `xs`. 
+ */ +def @map[A, B](%f: fn(A) -> B, %xs: List[A]) -> List[B] { + match (%xs) { + Cons(%x, %rest) => Cons(%f(%x), @map(%f, %rest)), + Nil => Nil, + } +} + +/* + * A left-way fold over a list. + * + * `foldl(f, z, cons(a1, cons(a2, cons(a3, cons(..., nil)))))` + * evaluates to `f(...f(f(f(z, a1), a2), a3)...)`. + */ +def @foldl[A, B](%f: fn(A, B) -> A, %acc: A, %xs: List[B]) -> A { + match (%xs) { + Cons(%x, %rest) => @foldl(%f, %f(%acc, %x), %rest), + Nil => %acc, + } } -def @compose[a, b, c](%f: fn(b) -> c, %g: fn(a) -> b) { - fn (%x: a) -> c { - %f(%g(%x)) - } +/* + * A right-way fold over a list. + * + * `foldr(f, z, cons(a1, cons(a2, cons(..., cons(an, nil)))))` + * evaluates to `f(a1, f(a2, f(..., f(an, z)))...)`. + */ +def @foldr[A, B](%f: fn(A, B) -> B, %acc: B, %xs: List[A]) -> B { + match (%xs) { + Cons(%x, %rest) => %f(%x, @foldr(%f, %acc, %rest)), + Nil => %acc, + } +} + +/* + * A right-way fold over a nonempty list. + * + * `foldr1(f, cons(a1, cons(a2, cons(..., cons(an, nil)))))` + * evaluates to `f(a1, f(a2, f(..., f(an-1, an)))...)` + */ +def @foldr1[A](%f: fn(A, A) -> A, %xs: List[A]) -> A { + match? (%xs) { + Cons(%x, Nil) => %x, + Cons(%x, %rest) => %f(%x, @foldr1(%f, %rest)), + } +} + +/* + * Computes the sum of a list of integer scalars. + */ +def @sum(%xs: List[Tensor[(), int32]]) { + let %add_f = fn(%x: Tensor[(), int32], %y: Tensor[(), int32]) -> Tensor[(), int32] { + %x + %y + }; + @foldl(%add_f, 0, %xs) +} + +/* + * Concatenates two lists. + */ +def @concat[A](%xs: List[A], %ys: List[A]) -> List[A] { + let %updater = fn(%x: A, %xss: List[A]) -> List[A] { + Cons(%x, %xss) + }; + @foldr(%updater, %ys, %xs) + // TODO(weberlo): write it like below, once VM constructor compilation is fixed + // @foldr(Cons, %ys, %xs) +} + +/* + * Filters a list, returning a sublist of only the values which satisfy the given predicate. 
+ */ +def @filter[A](%f: fn(A) -> Tensor[(), bool], %xs: List[A]) -> List[A] { + match (%xs) { + Cons(%x, %rest) => { + if (%f(%x)) { + Cons(%x, @filter(%f, %rest)) + } else { + @filter(%f, %rest) + } + }, + Nil => Nil, + } +} + +/* + * Combines two lists into a list of tuples of their elements. + * + * The zipped list will be the length of the shorter list. + */ +def @zip[A, B](%xs: List[A], %ys: List[B]) -> List[(A, B)] { + match (%xs, %ys) { + (Cons(%x, %x_rest), Cons(%y, %y_rest)) => Cons((%x, %y), @zip(%x_rest, %y_rest)), + _ => Nil, + } +} + +/* + * Reverses a list. + */ +def @rev[A](%xs: List[A]) -> List[A] { + let %updater = fn(%xss: List[A], %x: A) -> List[A] { + Cons(%x, %xss) + }; + @foldl(%updater, Nil, %xs) + // TODO(weberlo): write it like below, once VM constructor compilation is fixed + // @foldl(@flip(Cons), Nil, %xs) +} + +/* + * An accumulative map, which is a fold that simultaneously updates an + * accumulator value and a list of results. + * + * This map proceeds through the list from right to left. + */ +def @map_accumr[A, B, C](%f: fn(A, B) -> (A, C), %init: A, %xs: List[B]) -> (A, List[C]) { + let %updater = fn(%x: B, %acc: (A, List[C])) -> (A, List[C]) { + let %f_out = %f(%acc.0, %x); + (%f_out.0, Cons(%f_out.1, %acc.1)) + }; + @foldr(%updater, (%init, Nil), %xs) +} + +/* + * An accumulative map, which is a fold that simultaneously updates an + * accumulator value and a list of results. + * + * This map proceeds through the list from left to right. + */ +def @map_accuml[A, B, C](%f: fn(A, B) -> (A, C), %init: A, %xs: List[B]) -> (A, List[C]) { + let %updater = fn(%acc: (A, List[C]), %x: B) -> (A, List[C]) { + let %f_out = %f(%acc.0, %x); + (%f_out.0, Cons(%f_out.1, %acc.1)) + }; + @foldl(%updater, (%init, Nil), %xs) +} + +/* + * An optional ADT, which can either contain some other type or nothing at all. + */ +type Option[A] { + Some(A), + None, +} + +/* + * Builds up a list starting from a seed value.
+ * + * `f` returns an option containing a new seed and an output value. `f` will + * continue to be called on the new seeds until it returns `None`. All the output + * values will be combined into a list, right to left. + */ +def @unfoldr[A, B](%f: fn(A) -> Option[(A, B)], %seed: A) -> List[B] { + match (%f(%seed)) { + Some(%val) => Cons(%val.1, @unfoldr(%f, %val.0)), + None => Nil, + } +} + +/* + * Builds up a list starting from a seed value. + * + * `f` returns an option containing a new seed and an output value. `f` will + * continue to be called on the new seeds until it returns `None`. All the + * output values will be combined into a list, left to right. + */ +def @unfoldl[A, B](%f: fn(A) -> Option[(A, B)], %seed: A) -> List[B] { + @rev(@unfoldr(%f, %seed)) +} + +/* + * A tree ADT. A tree can contain any type. It has only one + * constructor, Rose(x, l), where x is the content of that point of the tree + * and l is a list of more trees of the same type. A leaf is thus Rose(x, + * Nil). + */ +type Tree[A] { + Rose(A, List[Tree[A]]), +} + +/* + * Maps over a tree. The function is applied to each subtree's contents. + */ +def @tmap[A, B](%f: fn(A) -> B, %t: Tree[A]) -> Tree[B] { + match(%t) { + Rose(%v, %sub_trees) => { + let %list_f = fn(%tt: Tree[A]) -> Tree[B] { + @tmap(%f, %tt) + }; + Rose(%f(%v), @map(%list_f, %sub_trees)) + }, + } +} + +/* + * Computes the size of a tree. + */ +def @size[A](%t: Tree[A]) -> Tensor[(), int32] { + match(%t) { + Rose(_, %sub_trees) => { + 1 + @sum(@map(@size, %sub_trees)) + }, + } +} + +/* + * Takes a number n and a function f; returns a closure that takes an argument + * and applies f n times to its argument.
+ */ +def @iterate[A](%f: fn(A) -> A, %n: Tensor[(), int32]) -> (fn(A) -> A) { + if (%n == 0) { + @id + } else { + @compose(%f, @iterate(%f, %n - 1)) + } } diff --git a/src/relay/ir/base.cc b/src/relay/ir/base.cc index f60f6594559c..2032112f2a85 100644 --- a/src/relay/ir/base.cc +++ b/src/relay/ir/base.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc index 2601f355d03e..0e8e6f5591dd 100644 --- a/src/relay/ir/module.cc +++ b/src/relay/ir/module.cc @@ -46,14 +46,14 @@ Module ModuleNode::make(tvm::Map global_funcs, for (const auto& kv : n->functions) { // set global var map - CHECK(!n->global_var_map_.count(kv.first->name_hint)) + CHECK(n->global_var_map_.count(kv.first->name_hint) == 0) << "Duplicate global function name " << kv.first->name_hint; n->global_var_map_.Set(kv.first->name_hint, kv.first); } for (const auto& kv : n->type_definitions) { // set global typevar map - CHECK(!n->global_type_var_map_.count(kv.first->var->name_hint)) + CHECK(n->global_type_var_map_.count(kv.first->var->name_hint) == 0) << "Duplicate global type definition name " << kv.first->var->name_hint; n->global_type_var_map_.Set(kv.first->var->name_hint, kv.first); n->RegisterConstructors(kv.first, kv.second); @@ -73,20 +73,12 @@ GlobalVar ModuleNode::GetGlobalVar(const std::string& name) const { return (*it).second; } -void ModuleNode::AddUnchecked(const GlobalVar& var, - const Function& func) { - auto mod = GetRef(this); - this->functions.Set(var, func); - - auto it = global_var_map_.find(var->name_hint); - if (it != global_var_map_.end()) { - 
CHECK_EQ((*it).second, var); - } else { - CHECK(!global_var_map_.count(var->name_hint)) - << "Duplicate global function name " << var->name_hint; +tvm::Array ModuleNode::GetGlobalVars() const { + std::vector global_vars; + for (const auto& pair : global_var_map_) { + global_vars.push_back(pair.second); } - - global_var_map_.Set(var->name_hint, var); + return tvm::Array(global_vars); } GlobalTypeVar ModuleNode::GetGlobalTypeVar(const std::string& name) const { @@ -97,6 +89,14 @@ GlobalTypeVar ModuleNode::GetGlobalTypeVar(const std::string& name) const { return (*it).second; } +tvm::Array ModuleNode::GetGlobalTypeVars() const { + std::vector global_type_vars; + for (const auto& pair : global_type_var_map_) { + global_type_vars.push_back(pair.second); + } + return tvm::Array(global_type_vars); +} + template tvm::Array concat(const tvm::Array& l, const tvm::Array& r) { tvm::Array ret(l); @@ -151,6 +151,22 @@ void ModuleNode::Add(const GlobalVar& var, AddUnchecked(var, checked_func); } +void ModuleNode::AddUnchecked(const GlobalVar& var, + const Function& func) { + auto mod = GetRef(this); + this->functions.Set(var, func); + + auto it = global_var_map_.find(var->name_hint); + if (it != global_var_map_.end()) { + CHECK_EQ((*it).second, var); + } else { + CHECK(global_var_map_.count(var->name_hint) == 0) + << "Duplicate global function name " << var->name_hint; + } + + global_var_map_.Set(var->name_hint, var); +} + void ModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData& type) { // We hash the global type var name to use as a globally unique prefix for tags. 
// The hash will be used as the most significant byte of the tag, with the index of @@ -163,25 +179,33 @@ void ModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData& } } -void ModuleNode::AddDef(const GlobalTypeVar& var, const TypeData& type) { - this->type_definitions.Set(var, type); - // set global type var map - CHECK(!global_type_var_map_.count(var->var->name_hint)) - << "Duplicate global type definition name " << var->var->name_hint; - - global_type_var_map_.Set(var->var->name_hint, var); - RegisterConstructors(var, type); - +void ModuleNode::AddDef(const GlobalTypeVar& var, const TypeData& type, bool update) { + AddDefUnchecked(var, type, update); // need to kind check at the end because the check can look up // a definition potentially CHECK(KindCheck(type, GetRef(this)) == Kind::kTypeData) << "Invalid or malformed typedata given to module: " << type; } +void ModuleNode::AddDefUnchecked(const GlobalTypeVar& var, const TypeData& type, bool update) { + this->type_definitions.Set(var, type); + if (!update) { + // set global type var map + CHECK(global_type_var_map_.count(var->var->name_hint) == 0) + << "Duplicate global type definition name " << var->var->name_hint; + } + global_type_var_map_.Set(var->var->name_hint, var); + RegisterConstructors(var, type); +} + void ModuleNode::Update(const GlobalVar& var, const Function& func) { this->Add(var, func, true); } +void ModuleNode::UpdateDef(const GlobalTypeVar& var, const TypeData& type) { + this->AddDef(var, type, true); +} + void ModuleNode::Remove(const GlobalVar& var) { auto functions_node = this->functions.CopyOnWrite(); functions_node->data.erase(var.node_); @@ -226,9 +250,20 @@ Constructor ModuleNode::LookupTag(const int32_t tag) { } void ModuleNode::Update(const Module& mod) { + // add functions and type defs. we add them unchecked first, so all definitions + // can reference each other, independent of the order in which they were defined. 
+ for (auto pair : mod->functions) { + this->AddUnchecked(pair.first, pair.second); + } + for (auto pair : mod->type_definitions) { + this->AddDefUnchecked(pair.first, pair.second); + } for (auto pair : mod->functions) { this->Update(pair.first, pair.second); } + for (auto pair : mod->type_definitions) { + this->UpdateDef(pair.first, pair.second); + } } Module ModuleNode::FromExpr( @@ -257,14 +292,7 @@ void ModuleNode::Import(const std::string& path) { std::istreambuf_iterator(src_file), std::istreambuf_iterator() }; auto mod_to_import = FromText(file_contents, path); - - for (auto func : mod_to_import->functions) { - this->Add(func.first, func.second, false); - } - - for (auto type : mod_to_import->type_definitions) { - this->AddDef(type.first, type.second); - } + Update(mod_to_import); } } @@ -315,6 +343,12 @@ TVM_REGISTER_API("relay._module.Module_AddDef") TVM_REGISTER_API("relay._module.Module_GetGlobalVar") .set_body_method(&ModuleNode::GetGlobalVar); +TVM_REGISTER_API("relay._module.Module_GetGlobalVars") +.set_body_method(&ModuleNode::GetGlobalVars); + +TVM_REGISTER_API("relay._module.Module_GetGlobalTypeVars") +.set_body_method(&ModuleNode::GetGlobalTypeVars); + TVM_REGISTER_API("relay._module.Module_ContainGlobalVar") .set_body_method(&ModuleNode::ContainGlobalVar); diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 22bdbcbb2d6a..afc8ad9dcf6a 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -570,7 +570,13 @@ class PrettyPrinter : } else { doc << Print(op->op); } - return doc << "(" << PrintSep(args) << ")"; + + if (cons_node && cons_node->inputs.size() == 0) { + // don't print as a call if it's a 0-arity cons + return doc; + } else { + return doc << "(" << PrintSep(args) << ")"; + } } Doc VisitExpr_(const RefCreateNode* op) final { @@ -641,6 +647,17 @@ class PrettyPrinter : return doc; } + Doc VisitPattern_(const PatternTupleNode* pt) final { + Doc doc; + doc << "("; + std::vector pats; + 
for (const auto& pat : pt->patterns) { + pats.push_back(Print(pat)); + } + doc << PrintSep(pats) << ")"; + return doc; + } + Doc VisitPattern_(const PatternWildcardNode* pw) final { return Doc("_"); } diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index e6104da5d7d1..e9a24bfa31d0 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -800,12 +800,13 @@ def test_adt_cons_expr(): %s def @make_singleton(%%x: int32) -> List[int32] { - Cons(%%x, Nil()) + Cons(%%x, Nil) } """ % LIST_DEFN, mod ) + @raises_parse_error def test_duplicate_adt_defn(): parse_text( From f5f2feeaa2640b570e69bf4465332e6f6773a6d5 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sun, 29 Sep 2019 22:06:58 -0700 Subject: [PATCH 09/17] [ARITH] migrate indexdiv/mod to floordiv/mod (#4008) --- python/tvm/expr.py | 9 +++------ src/lang/attr_functor.h | 13 +++++++++++-- src/lang/attrs.cc | 8 ++++++-- src/lang/buffer.cc | 4 ++-- src/lang/expr_operator.cc | 4 ++-- src/pass/lower_intrin.cc | 14 ++++++++++---- tests/python/unittest/test_codegen_device.py | 2 ++ tests/python/unittest/test_codegen_vm_basic.py | 1 + topi/python/topi/cuda/nms.py | 2 +- 9 files changed, 38 insertions(+), 19 deletions(-) diff --git a/python/tvm/expr.py b/python/tvm/expr.py index 5b7c60d819bd..733f57a68c56 100644 --- a/python/tvm/expr.py +++ b/python/tvm/expr.py @@ -92,16 +92,13 @@ def __rtruediv__(self, other): return _generic.divide(other, self) def __floordiv__(self, other): - # return _generic.floordiv(self, other) - return _generic.divide(self, other) + return _generic.floordiv(self, other) def __rfloordiv__(self, other): - # return _generic.floordiv(other, self) - return _generic.divide(other, self) + return _generic.floordiv(other, self) def __mod__(self, other): - raise div_ambiguity_error() - # return _make._OpMod(self, other) + return _make._OpFloorMod(self, other) def __neg__(self): neg_one = _api_internal._const(-1, self.dtype) diff --git 
a/src/lang/attr_functor.h b/src/lang/attr_functor.h index 249ce523a3cc..995dfb392e87 100644 --- a/src/lang/attr_functor.h +++ b/src/lang/attr_functor.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -87,6 +87,8 @@ class AttrFunctor { virtual R VisitAttr_(const ir::Mul* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::Div* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::Mod* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const ir::FloorDiv* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const ir::FloorMod* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::Min* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::Max* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::GE* op, Args... 
args) ATTR_FUNCTOR_DEFAULT; @@ -119,6 +121,9 @@ class AttrFunctor { ATTR_FUNCTOR_DISPATCH(Sub); ATTR_FUNCTOR_DISPATCH(Mul); ATTR_FUNCTOR_DISPATCH(Div); + ATTR_FUNCTOR_DISPATCH(Mod); + ATTR_FUNCTOR_DISPATCH(FloorDiv); + ATTR_FUNCTOR_DISPATCH(FloorMod); ATTR_FUNCTOR_DISPATCH(Min); ATTR_FUNCTOR_DISPATCH(Max); ATTR_FUNCTOR_DISPATCH(GE); @@ -160,6 +165,8 @@ class AttrsEqualHandler : bool VisitAttr_(const ir::Mul* lhs, const NodeRef& other) final; bool VisitAttr_(const ir::Div* lhs, const NodeRef& other) final; bool VisitAttr_(const ir::Mod* lhs, const NodeRef& other) final; + bool VisitAttr_(const ir::FloorDiv* lhs, const NodeRef& other) final; + bool VisitAttr_(const ir::FloorMod* lhs, const NodeRef& other) final; bool VisitAttr_(const ir::Min* lhs, const NodeRef& other) final; bool VisitAttr_(const ir::Max* lhs, const NodeRef& other) final; bool VisitAttr_(const ir::GE* lhs, const NodeRef& other) final; @@ -201,6 +208,8 @@ class AttrsHashHandler : size_t VisitAttr_(const ir::Mul* op) final; size_t VisitAttr_(const ir::Div* op) final; size_t VisitAttr_(const ir::Mod* op) final; + size_t VisitAttr_(const ir::FloorDiv* op) final; + size_t VisitAttr_(const ir::FloorMod* op) final; size_t VisitAttr_(const ir::Min* op) final; size_t VisitAttr_(const ir::Max* op) final; size_t VisitAttr_(const ir::GE* op) final; diff --git a/src/lang/attrs.cc b/src/lang/attrs.cc index ec2fd742ba14..c5b14ac577ec 100644 --- a/src/lang/attrs.cc +++ b/src/lang/attrs.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -154,6 +154,8 @@ TVM_DEFINE_ATTRS_BINOP_EQUAL(Sub); TVM_DEFINE_ATTRS_BINOP_EQUAL(Mul); TVM_DEFINE_ATTRS_BINOP_EQUAL(Div); TVM_DEFINE_ATTRS_BINOP_EQUAL(Mod); +TVM_DEFINE_ATTRS_BINOP_EQUAL(FloorDiv); +TVM_DEFINE_ATTRS_BINOP_EQUAL(FloorMod); TVM_DEFINE_ATTRS_BINOP_EQUAL(Max); TVM_DEFINE_ATTRS_BINOP_EQUAL(Min); TVM_DEFINE_ATTRS_BINOP_EQUAL(GE); @@ -266,6 +268,8 @@ TVM_DEFINE_ATTRS_BINOP_HASH(Sub); TVM_DEFINE_ATTRS_BINOP_HASH(Mul); TVM_DEFINE_ATTRS_BINOP_HASH(Div); TVM_DEFINE_ATTRS_BINOP_HASH(Mod); +TVM_DEFINE_ATTRS_BINOP_HASH(FloorDiv); +TVM_DEFINE_ATTRS_BINOP_HASH(FloorMod); TVM_DEFINE_ATTRS_BINOP_HASH(Max); TVM_DEFINE_ATTRS_BINOP_HASH(Min); TVM_DEFINE_ATTRS_BINOP_HASH(GE); diff --git a/src/lang/buffer.cc b/src/lang/buffer.cc index 206056bf889b..689b291ae2ed 100644 --- a/src/lang/buffer.cc +++ b/src/lang/buffer.cc @@ -32,8 +32,8 @@ namespace tvm { // TODO(tqchen): change to floormod/div -using IndexMod = ir::Mod; -using IndexDiv = ir::Div; +using IndexMod = ir::FloorMod; +using IndexDiv = ir::FloorDiv; Array SimplifyArray(Array array) { for (size_t i = 0; i < array.size(); ++i) { diff --git a/src/lang/expr_operator.cc b/src/lang/expr_operator.cc index 46a0737eab7e..9c9100b1902e 100644 --- a/src/lang/expr_operator.cc +++ b/src/lang/expr_operator.cc @@ -208,11 +208,11 @@ Expr operator%(Expr a, Expr b) { // TODO(tqchen): switch to floordiv Expr indexdiv(Expr a, Expr b) { - return truncdiv(a, b); + return floordiv(a, b); } Expr indexmod(Expr a, Expr b) { - return truncmod(a, b); + return floormod(a, b); } Expr floordiv(Expr a, Expr b) { diff --git a/src/pass/lower_intrin.cc b/src/pass/lower_intrin.cc index bbc3c572ca7e..3935d23cce0c 100644 --- a/src/pass/lower_intrin.cc +++ 
b/src/pass/lower_intrin.cc @@ -46,6 +46,9 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer { patterns_.push_back("tvm.intrin.rule." + starget + "."); patterns_.push_back("tvm.intrin.rule.default."); fma_ = runtime::Registry::Get(patterns_[0] + "fma"); + if (target == "stackvm") { + support_bitwise_op_ = false; + } } Expr Mutate_(const Call* op, const Expr& e) final { @@ -76,7 +79,8 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer { const DataType& dtype = op->type; CHECK(dtype.is_int() || !dtype.is_uint()); - if (is_const_power_of_two_integer(op->b, &shift)) { + if (support_bitwise_op_ && + is_const_power_of_two_integer(op->b, &shift)) { // lower to right shift if possible. return op->a >> make_const(dtype, shift); } @@ -93,7 +97,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer { // condition on b >= 0. // truncmod(a, b) < 0 will implies ceildiv, // So we need to correct these cases. - if (dtype == Int(32) || dtype == Int(64)) { + if ((dtype == Int(32) || dtype == Int(64)) && support_bitwise_op_) { // equivalent to rdiv + (rmod >= 0 ? 0: -1); return rdiv + (rmod >> make_const(dtype, dtype.bits() - 1)); } else { @@ -122,7 +126,8 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer { const DataType& dtype = op->type; CHECK(dtype.is_int() || !dtype.is_uint()); - if (is_const_power_of_two_integer(op->b, &shift)) { + if (support_bitwise_op_ && + is_const_power_of_two_integer(op->b, &shift)) { // lower to masking if possible. int64_t mask = ( static_cast(1) << static_cast(shift)) - 1; @@ -140,7 +145,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer { // mod(a, b) < 0 will imply we are doing ceildiv, // So we need to correct these cases. Expr rmod = truncmod(op->a, op->b); - if (dtype == Int(32) || dtype == Int(64)) { + if ((dtype == Int(32) || dtype == Int(64)) && support_bitwise_op_) { // (rmod >> shift) & b // -> (rmod >= 0 ? 0: -1) & b // -> rmod >= 0 ? 
0 : b @@ -268,6 +273,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer { // patterns std::vector patterns_; const PackedFunc* fma_{nullptr}; + bool support_bitwise_op_{true}; }; Stmt LowerIntrinStmt(Stmt stmt, const std::string& target) { diff --git a/tests/python/unittest/test_codegen_device.py b/tests/python/unittest/test_codegen_device.py index 6cb424c8a5eb..45ecf9539337 100644 --- a/tests/python/unittest/test_codegen_device.py +++ b/tests/python/unittest/test_codegen_device.py @@ -48,6 +48,8 @@ def test_add_pipeline(): stmt = tvm.ir_pass.Simplify(stmt) fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Db], 0, True) fsplits = [x for x in tvm.ir_pass.SplitHostDevice(fapi)] + # lower the floordiv(use stackvm rules so it works for all targets) + fsplits = [tvm.ir_pass.LowerIntrin(x, "stackvm") for x in fsplits] fsplits[0] = tvm.ir_pass.LowerTVMBuiltin(fsplits[0]) def check_target(device, host="stackvm"): diff --git a/tests/python/unittest/test_codegen_vm_basic.py b/tests/python/unittest/test_codegen_vm_basic.py index a9b382f1fd61..7ff217728034 100644 --- a/tests/python/unittest/test_codegen_vm_basic.py +++ b/tests/python/unittest/test_codegen_vm_basic.py @@ -37,6 +37,7 @@ def tvm_call_back_get_shape(shape0): stmt = tvm.make.Evaluate(tvm.call_packed("tvm_call_back_get_shape", Ab.shape[0])) fapi = tvm.ir_pass.MakeAPI(stmt, "print_shape", [Ab], 0, True) fapi = tvm.ir_pass.LowerTVMBuiltin(fapi) + fapi = tvm.ir_pass.LowerIntrin(fapi, "stackvm") run_jit(fapi, lambda f: f(a)) diff --git a/topi/python/topi/cuda/nms.py b/topi/python/topi/cuda/nms.py index 33fc7249802b..d032527ec273 100644 --- a/topi/python/topi/cuda/nms.py +++ b/topi/python/topi/cuda/nms.py @@ -185,7 +185,7 @@ def get_valid_counts_scan(data, partial_in, partial): ib.scope_attr(bx, "thread_extent", nthread_bx) var = tvm.make.node("FloatImm", dtype="float32", value=2) new_range = num_anchors // elem_per_thread + 1 - iteration = log(cast(new_range, "float32")) // math.log(2) + iteration = 
cast(log(cast(new_range, "float32")) / math.log(2), "int32") # Scan: Kogge-Stone adder with ib.if_scope(tvm.all(bx < batch_size, tx < tvm.min(new_range, num_anchors))): with ib.for_range(0, iteration) as k: From 1bff2c89dec4b5b044c7b2284e65552bc2b2107b Mon Sep 17 00:00:00 2001 From: ndl Date: Mon, 30 Sep 2019 18:25:24 +0200 Subject: [PATCH 10/17] Add dmlc-core to the list of installed header directories. (#4035) There are dependencies on dmlc-core in TVM public API headers (e.g. some headers include dmlc/logging.h) so it needs to be installed as part of TVM for TVM headers to be actually usable. --- CMakeLists.txt | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index abf198de1c53..d60e2ad356d9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -361,6 +361,11 @@ if (INSTALL_DEV) FILES_MATCHING PATTERN "*.h" ) + install( + DIRECTORY "3rdparty/dmlc-core/include/." DESTINATION "include" + FILES_MATCHING + PATTERN "*.h" + ) install( DIRECTORY "nnvm/include/." DESTINATION "include" FILES_MATCHING From d0fe532ed8edf2f1474beec868b0742a3d503c5b Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Mon, 30 Sep 2019 10:06:35 -0700 Subject: [PATCH 11/17] [Relay][Compile_engine] Int64 shape handling for outputs. (#4031) --- src/relay/backend/compile_engine.cc | 25 ++++++++++++++++--- .../relay/test_backend_compile_engine.py | 15 +++++++++++ 2 files changed, 37 insertions(+), 3 deletions(-) diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index c88703ea4d05..a75cdb299bf4 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -219,6 +219,25 @@ class ScheduleGetter : CHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a single tuple input"; } + + // Prepare the call_node->checked_type(). For the call node inputs, we ensure that the shape is + // Int32. Following code ensures the same for the output as well. + // TODO(@icemelon): Support recursive tuple + Type call_node_type = call_node->checked_type(); + if (const auto* tt = call_node->checked_type().as()) { + call_node_type = TensorTypeNode::make(GetShape(tt->shape), tt->dtype); + } else if (const auto* tuple_t = call_node->checked_type().as()) { + std::vector new_fields; + for (auto field : tuple_t->fields) { + if (const auto* tt = field.as()) { + new_fields.push_back(TensorTypeNode::make(GetShape(tt->shape), tt->dtype)); + } else { + new_fields.push_back(field); + } + } + call_node_type = TupleTypeNode::make(new_fields); + } + CHECK(call_node->op.as()) << "Primitive function only allows call into primitive ops"; Op op = Downcast(call_node->op); @@ -232,7 +251,7 @@ class ScheduleGetter : Operation(), 0)); } else { outputs = fcompute[op](call_node->attrs, inputs, - call_node->checked_type(), target_); + call_node_type, target_); } int op_pattern = fpattern[op]; diff --git a/tests/python/relay/test_backend_compile_engine.py b/tests/python/relay/test_backend_compile_engine.py index ea16a8d6122e..b1f41a43148c 100644 --- a/tests/python/relay/test_backend_compile_engine.py +++ b/tests/python/relay/test_backend_compile_engine.py @@ -79,8 +79,23 @@ def test_compile_tuple_dup(): relay.build(relay.Module.from_expr(f), 'llvm') +def test_compile_full(): + # Shape calculations can happen in int64. 
The test checks that full operator + # can handle when shapes are not int32 + shape = (tvm.expr.IntImm('int32', 1), + tvm.expr.IntImm('int64', 16), + tvm.expr.IntImm('int64', 16), + tvm.expr.IntImm('int32', 64)) + output = relay.full(relay.const(0, 'int32'), shape=shape, dtype='int32') + f = relay.Function([], output) + mod = relay.Module.from_expr(f) + mod = relay.qnn.transform.CanonicalizeOps()(mod) + relay.build(mod, 'llvm') + + if __name__ == "__main__": test_compile_engine() test_compile_placeholder_bypass() test_compile_injective_with_tuple() test_compile_tuple_dup() + test_compile_full() From 0cd80478ba3e34a4cae4eb950638529ba9acf066 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Mon, 30 Sep 2019 10:24:36 -0700 Subject: [PATCH 12/17] [QNN] Renaming dense operator. (#4033) --- python/tvm/relay/qnn/op/qnn.py | 12 ++++++------ tests/python/relay/test_qnn_dense.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py index 878a3a72b01e..ed443abb5293 100644 --- a/python/tvm/relay/qnn/op/qnn.py +++ b/python/tvm/relay/qnn/op/qnn.py @@ -313,12 +313,12 @@ def add(lhs, output_scale, output_zero_point) -def quantized_dense(data, - weight, - input_zero_point, - kernel_zero_point, - units=None, - out_dtype="int32"): +def dense(data, + weight, + input_zero_point, + kernel_zero_point, + units=None, + out_dtype="int32"): """Qnn Dense operator. 
Applies a quantized linear transformation diff --git a/tests/python/relay/test_qnn_dense.py b/tests/python/relay/test_qnn_dense.py index 2d14593b331f..f1e0767aff2e 100644 --- a/tests/python/relay/test_qnn_dense.py +++ b/tests/python/relay/test_qnn_dense.py @@ -153,7 +153,7 @@ def qnn_dense_driver(test_configuration): quantized_kernel = relay.var(quantized_kernel_name, shape=test_configuration['kernel_shape'], dtype=in_dtype) - mod = relay.qnn.op.quantized_dense( + mod = relay.qnn.op.dense( quantized_data, quantized_kernel, test_configuration['input_zero_point'], From 85a1d3ff8c18682d6f1c1713933ab26380be36ab Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Mon, 30 Sep 2019 12:14:59 -0700 Subject: [PATCH 13/17] [COMMUNITY] anijain2305 -> reviewer (#4036) --- CONTRIBUTORS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 590ff7d0963a..f1e5a019c949 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -74,6 +74,7 @@ We do encourage everyone to work anything they are interested in. 
- [Hao Lu](https://github.com/hlu1): @hlu1 - [Nick Hynes](https://github.com/nhynes): @nhynes - [Yuwei Hu](https://github.com/Huyuwei): @Huyuwei +- [Animesh Jain](https://github.com/anijain2305): @anijain2305 - [Yizhi Liu](https://github.com/yzhliu) : @yzhliu - [Zhixun Tan](https://github.com/phisiart): @phisiart - [Zhi Chen](https://github.com/zhiics): @zhiics From 5cc17649f491299ddf15a8eb144fbb6732382c9a Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Tue, 1 Oct 2019 23:40:16 +0800 Subject: [PATCH 14/17] [topi] add ARM v8.2 udot (uint8) support (#3978) * [topi] add ARM v8.2 udot (uint8) support * fix test case * fix common conv2d schedule * add back fp32_time in test * fix lint * fix doc, add support for int32_lanes=4, signed int * fix lint * add ic_bn % 4 checker in schedule --- topi/python/topi/arm_cpu/__init__.py | 1 + topi/python/topi/arm_cpu/conv2d_int8.py | 112 ++++++++++ topi/python/topi/arm_cpu/tensor_intrin.py | 110 ++++++++++ topi/python/topi/generic/conv2d.py | 239 ++++++++++++++++++++++ topi/python/topi/nn/conv2d.py | 8 - topi/python/topi/x86/conv2d_avx_1x1.py | 99 +-------- topi/python/topi/x86/conv2d_avx_common.py | 70 +------ topi/python/topi/x86/conv2d_int8.py | 7 +- topi/recipe/conv/test_conv_int8_arm.py | 158 ++++++++++++++ 9 files changed, 633 insertions(+), 171 deletions(-) create mode 100644 topi/python/topi/arm_cpu/conv2d_int8.py create mode 100644 topi/python/topi/arm_cpu/tensor_intrin.py create mode 100644 topi/python/topi/generic/conv2d.py create mode 100644 topi/recipe/conv/test_conv_int8_arm.py diff --git a/topi/python/topi/arm_cpu/__init__.py b/topi/python/topi/arm_cpu/__init__.py index 6cf4d9139343..32751bf58458 100644 --- a/topi/python/topi/arm_cpu/__init__.py +++ b/topi/python/topi/arm_cpu/__init__.py @@ -3,6 +3,7 @@ from . import conv2d from . import depthwise_conv2d from . import conv2d_transpose +from . import conv2d_int8 from . import bitserial_conv2d from . import bitserial_dense from . 
import injective diff --git a/topi/python/topi/arm_cpu/conv2d_int8.py b/topi/python/topi/arm_cpu/conv2d_int8.py new file mode 100644 index 000000000000..8f43f5c210d4 --- /dev/null +++ b/topi/python/topi/arm_cpu/conv2d_int8.py @@ -0,0 +1,112 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,unused-variable,unused-argument,no-member +"""Conv2D int8 schedule on ARM""" + +import tvm +from tvm import autotvm +from .. import generic, tag +from ..util import get_const_tuple +from ..nn.conv2d import conv2d_NCHWc_int8 +from ..generic import conv2d as conv2d_generic +from .. 
import nn +from ..nn.conv2d import _get_workload as _get_conv2d_workload +from .tensor_intrin import dot_int8_int8_int32 + + +def _get_default_config(cfg, data, kernel, strides, padding, out_dtype): + """ + Get default int8 schedule config for the workload + """ + wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype) + is_kernel_1x1 = wkl.hkernel == 1 and wkl.wkernel == 1 + if is_kernel_1x1: + conv2d_generic.fallback_schedule_cpu_1x1_int8( + cfg, wkl, int32_lanes=2, num_int8_elements=4) + else: + conv2d_generic.fallback_schedule_cpu_common_int8( + cfg, wkl, int32_lanes=2, num_int8_elements=4) + + +@autotvm.register_topi_compute(conv2d_NCHWc_int8, ['arm_cpu'], 'direct') +def _declaration_conv_NCHWc_int8(cfg, data, kernel, strides, + padding, dilation, layout, out_layout, out_dtype): + # layout and out_layout are not used here, + # we keep them for debug convenience when dumping autotvm workload + n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape) + in_channel = ic_chunk * ic_bn + + oc_chunk, ic_chunk, kh, kw, ic_bn, oc_bn, n_elems = get_const_tuple(kernel.shape) + num_filter = oc_chunk * oc_bn + + # If no config was set, we can fallback to NCHW config. 
+ if cfg.is_fallback: + _get_default_config(cfg, tvm.placeholder((n, in_channel, ih, iw), dtype=data.dtype), + tvm.placeholder((num_filter, in_channel, kh, kw), dtype=kernel.dtype), + strides, padding, out_dtype) + return nn.conv2d_NCHWc_int8_compute(data, + kernel, + strides, + padding, + dilation, + layout, + out_layout, + out_dtype) + + +@autotvm.register_topi_schedule(generic.schedule_conv2d_NCHWc_int8, ['arm_cpu'], ['direct']) +def _schedule_conv2d_NCHWc_int8(cfg, outs): + """Create schedule for tensors""" + s = tvm.create_schedule([x.op for x in outs]) + scheduled_ops = [] + + def traverse(op): + """Traverse operators from computation graph""" + # inline all one-to-one-mapping operators except the last stage (output) + if tag.is_broadcast(op.tag): + if op not in s.outputs: + s[op].compute_inline() + for tensor in op.input_tensors: + if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops: + traverse(tensor.op) + + if 'conv2d_NCHWc_int8' in op.tag: + conv_out = op.output(0) + kernel = conv_out.op.input_tensors[1] + data_vec = conv_out.op.input_tensors[0] + data = data_vec.op.input_tensors[0] \ + if isinstance(data_vec.op, tvm.tensor.ComputeOp) and "pad" not in data_vec.op.tag \ + else data_vec + if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag: + data_pad = data + data = data_pad.op.input_tensors[0] + + args = [s, cfg, data_vec, conv_out, outs[0]] + # int8 conv kernel is 7-dim + _, _, kh, kw, _, _, _ = get_const_tuple(kernel.shape) + dtype = "uint" if data.dtype == "uint8" else "int" + if kh == 1 and kw == 1: + conv2d_generic.schedule_conv_NCHWc_cpu_1x1_int8( + *args, int32_lanes=4, intrin=dot_int8_int8_int32(int32_lanes=4, dtype=dtype)) + else: + conv2d_generic.schedule_conv_NCHWc_cpu_common_int8( + *args, int32_lanes=4, intrin=dot_int8_int8_int32(int32_lanes=4, dtype=dtype)) + + scheduled_ops.append(op) + + traverse(outs[0].op) + return s diff --git a/topi/python/topi/arm_cpu/tensor_intrin.py 
b/topi/python/topi/arm_cpu/tensor_intrin.py new file mode 100644 index 000000000000..2f300a18e117 --- /dev/null +++ b/topi/python/topi/arm_cpu/tensor_intrin.py @@ -0,0 +1,110 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,unused-variable,unused-argument,no-member +"""Conv2D int8 schedule on ARM""" + +import tvm + +def dot_int8_int8_int32(int32_lanes, dtype='uint'): + """ + Int8 dot product by every 4 elements using ARM v8.2 udot. + This function takes two arrays of int8 datatype -- data[4] and + kernel[int32_lanes][4] -- and computes a dot product of data[4] with every + 4 elements of kernels, resulting in output[int32_lanes] of uint32 datatype. + The pseudo code is as follows. + + .. code-block:: c + + void dot_int8_int8_int32(int8 data[4], int8 kernel[16][4], int32 output[16]){ + for (int i = 0; i < int32_lanes; i++){ + out[i] = 0; + for (int k = 0; k < 4; k++){ + out[i] += data[k] * kernel[i][k] + } + } + } + + Physically, the kernel array sits in a vector register and + the data[4] is broadcasted to another vector register. This + function returns a TensorIntrin that can be used to tensorize + a schedule. 
+ + Parameters + ---------- + int32_lanes: int + How many int32/uint32 to produce + dtype: str, optional, {"uint", "int"} + Whether it works on unsigned int or signed int + + Returns + ------- + intrin : TensorIntrin + The ARM uint8 TensorIntrin that can be used in tensorizing schedule + """ + num_int8_elements = 4 # 4 int8 elements in int32 + + data = tvm.placeholder((num_int8_elements,), dtype='%s8' % dtype, name='data') + kernel = tvm.placeholder((int32_lanes, num_int8_elements), dtype='%s8' % dtype, name='kernel') + + k = tvm.reduce_axis((0, num_int8_elements), name='k') + C = tvm.compute((int32_lanes,), + lambda i: tvm.sum(data[k].astype('%s32' % dtype) * + kernel[i, k].astype('%s32' % dtype), + axis=k), name="C") + + a_buffer = tvm.decl_buffer(data.shape, dtype='%s8' % dtype, name="a_buffer", + offset_factor=1, + strides=[1]) + b_buffer = tvm.decl_buffer(kernel.shape, dtype='%s8' % dtype, name="b_buffer", + offset_factor=1, + strides=[tvm.var('s'), 1]) + + def _intrin_func(ins, outs): + def _instr(index): + ib = tvm.ir_builder.create() + if index == 1: + ib.emit(outs[0].vstore(0, tvm.const(0, '%s32x%d' % (dtype, int32_lanes)))) + return ib.get() + + dtype_a = '%s8x%d' % (dtype, num_int8_elements) + dtype_b = '%s8x%d' % (dtype, int32_lanes * num_int8_elements) + dtype_c = '%s32x%d' % (dtype, int32_lanes) + + a_int8 = ins[0].vload([0], dtype_a) + re_int32 = tvm.call_pure_intrin('%s32' % dtype, 'reinterpret', a_int8) + # broadcast a + vec_ai32 = re_int32.astype(dtype_c) + + vec_a = tvm.call_pure_intrin(dtype_b, 'reinterpret', vec_ai32) + vec_b = ins[1].vload([0, 0], dtype_b) + vec_c = outs[0].vload([0], dtype_c) + + inst = 'udot' if dtype == 'uint' else 'sdot' + inst = 'llvm.aarch64.neon.%s.v%di32.v%di8' % ( + inst, int32_lanes, int32_lanes * num_int8_elements) + vdot = tvm.call_llvm_intrin(dtype_c, + inst, + tvm.const(2, 'uint32'), + vec_c, vec_a, vec_b) + ib.emit(outs[0].vstore(0, vdot)) + return ib.get() + + # body, reset, update + return _instr(0), 
_instr(1), _instr(2) + + with tvm.build_config(offset_factor=1, partition_const_loop=True): + return tvm.decl_tensor_intrin(C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer}) diff --git a/topi/python/topi/generic/conv2d.py b/topi/python/topi/generic/conv2d.py new file mode 100644 index 000000000000..332c2fdad459 --- /dev/null +++ b/topi/python/topi/generic/conv2d.py @@ -0,0 +1,239 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-variable, too-many-locals +# pylint: disable=unused-argument, redefined-builtin +"""Generic convolution schedules""" +from __future__ import absolute_import as _abs +import tvm +from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity +from ..util import get_const_tuple + +def fallback_schedule_cpu_common_int8(cfg, wkl, int32_lanes, num_int8_elements): + """Fallback schedule for conv2d int8 on cpu. + Normally the inner most pattern takes two int8/uint8 tensors + data[num_int8_elements] and kernel[int32_lanes, num_int8_elements], + produces a dot product int32/uint32 output[int32_lanes]. + + Parameters + ---------- + int32_lanes : int + How many numbers of int32/uint32 will be produced using intrinsic. + This is related to output channel. 
+ num_int8_elements : int + How many numbers of input int32/uint32 will be multiplied and reduced. + This is related to input channel. + """ + HPAD, WPAD = wkl.hpad, wkl.wpad + HSTR, WSTR = wkl.hstride, wkl.wstride + out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1 + + assert wkl.out_filter % int32_lanes == 0, \ + "wkl.out_filter=%d, int32_lanes=%d" % (wkl.out_filter, int32_lanes) + assert wkl.in_filter % num_int8_elements == 0, \ + "wkl.in_filter=%d, num_int8_elements=%d" % (wkl.in_filter, num_int8_elements) + + oc_bn = int32_lanes + ic_bn = 1 + for bn in range(oc_bn, 0, -4): + if wkl.in_filter % bn == 0: + ic_bn = bn + break + + reg_n = 1 + for n in range(31, 0, -1): + if out_width % n == 0: + reg_n = n + break + + cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn]) + cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn]) + cfg["tile_ow"] = SplitEntity([out_width // reg_n, reg_n]) + cfg["unroll_kw"] = OtherOptionEntity(False) + + +def fallback_schedule_cpu_1x1_int8(cfg, wkl, int32_lanes, num_int8_elements): + """Fallback schedule for 1x1 conv2d int8 on cpu. + Normally the inner most pattern takes two int8/uint8 tensors + data[num_int8_elements] and kernel[int32_lanes, num_int8_elements], + produces a dot product int32/uint32 output[int32_lanes]. + + Parameters + ---------- + int32_lanes : int + How many numbers of int32/uint32 will be produced using intrinsic. + This is related to output channel. + num_int8_elements : int + How many numbers of input int32/uint32 will be multiplied and reduced. + This is related to input channel. 
+ """ + HPAD, WPAD = wkl.hpad, wkl.wpad + HSTR, WSTR = wkl.hstride, wkl.wstride + out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1 + out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1 + + assert wkl.out_filter % int32_lanes == 0, \ + "wkl.out_filter=%d, int32_lanes=%d" % (wkl.out_filter, int32_lanes) + assert wkl.in_filter % num_int8_elements == 0, \ + "wkl.in_filter=%d, num_int8_elements=%d" % (wkl.in_filter, num_int8_elements) + + oc_bn = int32_lanes + ic_bn = 1 + for bn in range(oc_bn, 0, -4): + if wkl.in_filter % bn == 0: + ic_bn = bn + break + + for ow_factor in range(out_width, 0, -1): + if out_width % ow_factor == 0: + for oh_factor in range(out_height, 0, -1): + if out_height % oh_factor == 0 and ow_factor * oh_factor < 32: + cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn]) + cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn]) + cfg["tile_oh"] = OtherOptionEntity(oh_factor) + cfg["tile_ow"] = SplitEntity([out_width // ow_factor, ow_factor]) + return + raise ValueError("cannot decide default schedule for workload: {}".format(wkl)) + + +def schedule_conv_NCHWc_cpu_common_int8(s, cfg, data, conv_out, last, int32_lanes=16, intrin=None): + """ + Defines the schedule for INT8 for Intel and ARM machines + Uses the Intel/ARM intrinsics to use INT8 operations + More details - https://software.intel.com/en-us/articles/ + lower-numerical-precision-deep-learning-inference-and-training + """ + reg_n, unroll_kw = cfg["tile_ow"].size[-1], cfg["unroll_kw"].val + _, _, _, _, ic_bn = get_const_tuple(data.shape) + _, _, _, _, oc_bn = get_const_tuple(conv_out.shape) + + A = data + if isinstance(s[A].op, tvm.tensor.ComputeOp): + batch, ic_chunk, ih, iw, _ = s[A].op.axis + parallel_axis = s[A].fuse(batch, ic_chunk, ih) + s[A].parallel(parallel_axis) + + # schedule 5-D NCHW[x]c conv + C, O = conv_out, last + CC = s.cache_write(C, 'global') + + batch, oc_chunk, oh, ow, oc_block = s[C].op.axis + ow_chunk, ow_block = s[C].split(ow, 
factor=reg_n) + s[C].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block) + parallel_axis = s[C].fuse(batch, oc_chunk, oh) + s[C].vectorize(oc_block) + if C == O: + s[C].parallel(parallel_axis) + + s[CC].compute_at(s[C], ow_chunk) + _, oc_chunk, oh, ow, oc_block = s[CC].op.axis + kh, kw, ic_outer, ic_f_inner, ic_s_inner = s[CC].op.reduce_axis + + ow_chunk, ow_block = s[CC].split(ow, factor=reg_n) + + assert oc_bn % int32_lanes == 0 + assert ic_bn % 4 == 0 # 4 (u)int8 elements in (u)int32 + + oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=int32_lanes) + + if unroll_kw: + s[CC].reorder(oc_chunk, oh, ow_chunk, ic_outer, kh, ic_f_inner, kw, + ow_block, oc_f_inner, oc_s_inner, ic_s_inner) + s[CC].unroll(kw) + else: + s[CC].reorder(oc_chunk, oh, ow_chunk, ic_outer, kh, kw, ic_f_inner, + ow_block, oc_f_inner, oc_s_inner, ic_s_inner) + + if intrin is not None: + s[CC].tensorize(oc_s_inner, intrin) + s[CC].unroll(ow_block) + s[CC].unroll(oc_f_inner) + + if C != O: + batch, oc_chunk, oh, ow, oc_block = s[O].op.axis + ow_chunk, ow_block = s[O].split(ow, factor=reg_n) + s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block) + parallel_axis = s[O].fuse(batch, oc_chunk, oh) + s[C].compute_at(s[O], parallel_axis) + s[O].vectorize(oc_block) + s[O].parallel(parallel_axis) + + return s + +def schedule_conv_NCHWc_cpu_1x1_int8(s, cfg, data, conv_out, last, int32_lanes=16, intrin=None): + """ + Defines the 1x1 conv schedule for INT8 for Intel and ARM machines + Uses the Intel/ARM intrinsics to use INT8 operations + More details - https://software.intel.com/en-us/articles/ + lower-numerical-precision-deep-learning-inference-and-training + """ + oh_factor, ow_factor = cfg["tile_oh"].val, cfg["tile_ow"].size[-1] + _, _, _, _, ic_bn = get_const_tuple(data.shape) + _, _, _, _, oc_bn = get_const_tuple(conv_out.shape) + + # schedule data + A = data + if isinstance(s[A].op, tvm.tensor.ComputeOp): + batch, ic_chunk, ih, iw, ic_block = s[A].op.axis + parallel_axis = s[A].fuse(batch, 
ic_chunk, ih) + s[A].parallel(parallel_axis) + + C, O = conv_out, last + CC = s.cache_write(C, 'global') + + batch, oc_chunk, oh, ow, oc_block = s[C].op.axis + oh_outer, oh_inner = s[C].split(oh, factor=oh_factor) + ow_outer, ow_inner = s[C].split(ow, factor=ow_factor) + s[C].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block) + s[C].vectorize(oc_block) + + parallel_axis = s[C].fuse(batch, oc_chunk, oh_outer) + s[CC].compute_at(s[C], parallel_axis) + if C == O: + s[C].parallel(parallel_axis) + + _, oc_chunk, oh, ow, oc_block = s[CC].op.axis + kh, kw, ic_outer, ic_f_inner, ic_s_inner = s[CC].op.reduce_axis + + assert oc_bn % int32_lanes == 0 + assert ic_bn % 4 == 0 # 4 (u)int8 elements in (u)int32 + + oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=int32_lanes) + + oh_outer, oh_inner = s[CC].split(oh, factor=oh_factor) + ow_outer, ow_inner = s[CC].split(ow, factor=ow_factor) + + s[CC].reorder(oc_chunk, oh_outer, ow_outer, kh, kw, ic_outer, ic_f_inner, oh_inner, + ow_inner, oc_f_inner, oc_s_inner, ic_s_inner) + s[CC].fuse(oc_chunk, oh_outer) + + if intrin is not None: + s[CC].tensorize(oc_s_inner, intrin) + s[CC].unroll(ow_inner) + s[CC].unroll(oh_inner) + + if C != O: + batch, oc_chunk, oh, ow, oc_block = s[O].op.axis + oh_outer, oh_inner = s[O].split(oh, factor=oh_factor) + ow_outer, ow_inner = s[O].split(ow, factor=ow_factor) + s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block) + + parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer) + s[C].compute_at(s[O], parallel_axis) + s[O].vectorize(oc_block) + s[O].parallel(parallel_axis) + + return s diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py index 904dd54cce66..ffae4b2094e4 100644 --- a/topi/python/topi/nn/conv2d.py +++ b/topi/python/topi/nn/conv2d.py @@ -595,19 +595,11 @@ def conv2d_NCHWc_int8_compute(data, kernel, strides, padding, dilation, layout, n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape) in_channel = ic_chunk * ic_bn - target = 
tvm.target.current_target(allow_none=False) oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn, _ = \ get_const_tuple(kernel.shape) num_filter = oc_chunk * oc_bn groups = ic_chunk // ic_chunk_group - # Since the weight is 7-D and the last element size is 4, we have to - # check ic_bn should be a multiple of 4. - # Similary, oc_bn has to be a multiple of 4. - - assert ic_bn % 4 == 0 - assert oc_bn % 16 == 0 - dilated_kernel_h = (kernel_height - 1) * dilation_h + 1 dilated_kernel_w = (kernel_width - 1) * dilation_w + 1 diff --git a/topi/python/topi/x86/conv2d_avx_1x1.py b/topi/python/topi/x86/conv2d_avx_1x1.py index 6e36e93b9806..96b6e47789f7 100644 --- a/topi/python/topi/x86/conv2d_avx_1x1.py +++ b/topi/python/topi/x86/conv2d_avx_1x1.py @@ -22,6 +22,7 @@ from ..nn.pad import pad from ..nn.util import infer_pad, get_pad_tuple +from ..generic import conv2d as conv2d_generic from ..util import get_const_tuple, simplify from .tensor_intrin import dot_16x1x16_int8_int8_int32 from .util import get_fp32_len @@ -57,36 +58,6 @@ def _fallback_schedule(cfg, wkl): raise ValueError("cannot decide default schedule for workload: {}".format(wkl)) -def _fallback_schedule_int8(cfg, wkl): - simd_width = get_fp32_len() - HPAD, WPAD = wkl.hpad, wkl.wpad - HSTR, WSTR = wkl.hstride, wkl.wstride - out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1 - out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1 - - oc_bn = 16 - assert wkl.out_filter % oc_bn == 0 - - ic_bn = 1 - for bn in range(oc_bn, 0, -4): - if wkl.in_filter % bn == 0: - ic_bn = bn - break - assert wkl.in_filter % 4 == 0 - - for ow_factor in range(out_width, 0, -1): - if out_width % ow_factor == 0: - for oh_factor in range(out_height, 0, -1): - if out_height % oh_factor == 0 and ow_factor * oh_factor < 32: - cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn]) - cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn]) - cfg["tile_oh"] = OtherOptionEntity(oh_factor) - cfg["tile_ow"] = 
SplitEntity([out_width // ow_factor, ow_factor]) - return - raise ValueError("cannot decide default schedule for workload: {}".format(wkl)) - - - def _schedule_conv(s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, output, last): # fetch schedule ic_bn, oc_bn, oh_factor, ow_factor = (cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1], @@ -210,71 +181,9 @@ def _schedule_conv_NCHWc(s, cfg, data, conv_out, last): def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last): - """ - Defines the schedule for INT8 for intel machines - Uses the Intel intrinsics to use INT8 operations - More details - https://software.intel.com/en-us/articles/ - lower-numerical-precision-deep-learning-inference-and-training - """ - int32_lanes = 16 - - oh_factor, ow_factor = cfg["tile_oh"].val, cfg["tile_ow"].size[-1] - _, _, _, _, ic_bn = get_const_tuple(data.shape) - _, _, _, _, oc_bn = get_const_tuple(conv_out.shape) - - # schedule data - A = data - if isinstance(s[A].op, tvm.tensor.ComputeOp): - batch, ic_chunk, ih, iw, ic_block = s[A].op.axis - parallel_axis = s[A].fuse(batch, ic_chunk, ih) - s[A].parallel(parallel_axis) - - C, O = conv_out, last - CC = s.cache_write(C, 'global') - - batch, oc_chunk, oh, ow, oc_block = s[C].op.axis - oh_outer, oh_inner = s[C].split(oh, factor=oh_factor) - ow_outer, ow_inner = s[C].split(ow, factor=ow_factor) - s[C].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block) - s[C].vectorize(oc_block) - - parallel_axis = s[C].fuse(batch, oc_chunk, oh_outer) - s[CC].compute_at(s[C], parallel_axis) - if C == O: - s[C].parallel(parallel_axis) - - _, oc_chunk, oh, ow, oc_block = s[CC].op.axis - kh, kw, ic_outer, ic_f_inner, ic_s_inner = s[CC].op.reduce_axis - - # Skylake and future processors have 16 vector lanes - assert oc_bn % int32_lanes == 0 - - oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=int32_lanes) - - oh_outer, oh_inner = s[CC].split(oh, factor=oh_factor) - ow_outer, ow_inner = s[CC].split(ow, factor=ow_factor) - - 
s[CC].reorder(oc_chunk, oh_outer, ow_outer, kh, kw, ic_outer, ic_f_inner, oh_inner, - ow_inner, oc_f_inner, oc_s_inner, ic_s_inner) - s[CC].fuse(oc_chunk, oh_outer) - - pc = dot_16x1x16_int8_int8_int32() - s[CC].tensorize(oc_s_inner, pc) - s[CC].unroll(ow_inner) - s[CC].unroll(oh_inner) - - if C != O: - batch, oc_chunk, oh, ow, oc_block = s[O].op.axis - oh_outer, oh_inner = s[O].split(oh, factor=oh_factor) - ow_outer, ow_inner = s[O].split(ow, factor=ow_factor) - s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block) - - parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer) - s[C].compute_at(s[O], parallel_axis) - s[O].vectorize(oc_block) - s[O].parallel(parallel_axis) - - return s + return conv2d_generic.schedule_conv_NCHWc_cpu_1x1_int8(s, cfg, data, conv_out, last, + int32_lanes=16, + intrin=dot_16x1x16_int8_int8_int32()) def _declaration_conv_nhwc_pack(cfg, Input, Filter, stride, padding, dilation, out_dtype): diff --git a/topi/python/topi/x86/conv2d_avx_common.py b/topi/python/topi/x86/conv2d_avx_common.py index a7f38acced4a..53b79bdbeec9 100644 --- a/topi/python/topi/x86/conv2d_avx_common.py +++ b/topi/python/topi/x86/conv2d_avx_common.py @@ -21,6 +21,7 @@ from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity from ..nn.util import infer_pad +from ..generic import conv2d as conv2d_generic from ..util import get_const_tuple from .tensor_intrin import dot_16x1x16_int8_int8_int32 from .util import get_fp32_len @@ -56,7 +57,6 @@ def _fallback_schedule(cfg, wkl): def _fallback_schedule_int8(cfg, wkl): - simd_width = get_fp32_len() HPAD, WPAD = wkl.hpad, wkl.wpad HSTR, WSTR = wkl.hstride, wkl.wstride out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1 @@ -207,68 +207,6 @@ def _schedule_conv_NCHWc(s, cfg, data, conv_out, last): def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last): - """ - Defines the schedule for INT8 for intel machines - Uses the Intel intrinsics to use INT8 operations - More details - 
https://software.intel.com/en-us/articles/ - lower-numerical-precision-deep-learning-inference-and-training - """ - int32_lanes = 16 - - reg_n, unroll_kw = cfg["tile_ow"].size[-1], cfg["unroll_kw"].val - _, _, _, _, ic_bn = get_const_tuple(data.shape) - _, _, _, _, oc_bn = get_const_tuple(conv_out.shape) - - A = data - if isinstance(s[A].op, tvm.tensor.ComputeOp): - batch, ic_chunk, ih, iw, _ = s[A].op.axis - parallel_axis = s[A].fuse(batch, ic_chunk, ih) - s[A].parallel(parallel_axis) - - # schedule 5-D NCHW[x]c conv - C, O = conv_out, last - CC = s.cache_write(C, 'global') - - batch, oc_chunk, oh, ow, oc_block = s[C].op.axis - ow_chunk, ow_block = s[C].split(ow, factor=reg_n) - s[C].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block) - parallel_axis = s[C].fuse(batch, oc_chunk, oh) - s[C].vectorize(oc_block) - if C == O: - s[C].parallel(parallel_axis) - - s[CC].compute_at(s[C], ow_chunk) - _, oc_chunk, oh, ow, oc_block = s[CC].op.axis - kh, kw, ic_outer, ic_f_inner, ic_s_inner = s[CC].op.reduce_axis - - ow_chunk, ow_block = s[CC].split(ow, factor=reg_n) - - # Skylake and future processors have 16 vector lanes - assert oc_bn % int32_lanes == 0 - - oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=int32_lanes) - - if unroll_kw: - s[CC].reorder(oc_chunk, oh, ow_chunk, ic_outer, kh, ic_f_inner, kw, - ow_block, oc_f_inner, oc_s_inner, ic_s_inner) - s[CC].unroll(kw) - else: - s[CC].reorder(oc_chunk, oh, ow_chunk, ic_outer, kh, kw, ic_f_inner, - ow_block, oc_f_inner, oc_s_inner, ic_s_inner) - - - pc = dot_16x1x16_int8_int8_int32() - s[CC].tensorize(oc_s_inner, pc) - s[CC].unroll(ow_block) - s[CC].unroll(oc_f_inner) - - if C != O: - batch, oc_chunk, oh, ow, oc_block = s[O].op.axis - ow_chunk, ow_block = s[O].split(ow, factor=reg_n) - s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block) - parallel_axis = s[O].fuse(batch, oc_chunk, oh) - s[C].compute_at(s[O], parallel_axis) - s[O].vectorize(oc_block) - s[O].parallel(parallel_axis) - - return s + return 
conv2d_generic.schedule_conv_NCHWc_cpu_common_int8(s, cfg, data, conv_out, last, + int32_lanes=16, + intrin=dot_16x1x16_int8_int8_int32()) diff --git a/topi/python/topi/x86/conv2d_int8.py b/topi/python/topi/x86/conv2d_int8.py index 3f65db4518f2..f701108071e5 100644 --- a/topi/python/topi/x86/conv2d_int8.py +++ b/topi/python/topi/x86/conv2d_int8.py @@ -24,6 +24,7 @@ from tvm.autotvm.task.topi_integration import deserialize_args from ..nn.conv2d import _get_workload as _get_conv2d_workload from .. import generic, tag +from ..generic import conv2d as conv2d_generic from ..util import get_const_tuple from ..nn.conv2d import conv2d_NCHWc_int8 from .. import nn @@ -38,9 +39,11 @@ def _get_default_config_int8(cfg, data, kernel, strides, padding, out_dtype, is_ wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype, layout) is_kernel_1x1 = wkl.hkernel == 1 and wkl.wkernel == 1 if is_kernel_1x1: - conv2d_avx_1x1._fallback_schedule_int8(cfg, wkl) + conv2d_generic.fallback_schedule_cpu_1x1_int8( + cfg, wkl, int32_lanes=16, num_int8_elements=4) else: - conv2d_avx_common._fallback_schedule_int8(cfg, wkl) + conv2d_generic.fallback_schedule_cpu_common_int8( + cfg, wkl, int32_lanes=16, num_int8_elements=4) def _is_int8_hw_support(data_dtype, kernel_dtype): diff --git a/topi/recipe/conv/test_conv_int8_arm.py b/topi/recipe/conv/test_conv_int8_arm.py new file mode 100644 index 000000000000..ff0d37d9a66d --- /dev/null +++ b/topi/recipe/conv/test_conv_int8_arm.py @@ -0,0 +1,158 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +#pylint: disable-msg=too-many-arguments, too-many-locals, assignment-from-no-return +""" Conv Int8 functional and performance testing""" +import sys +import logging +import numpy as np +import tvm +import topi + +logging.basicConfig(stream=sys.stdout, level=logging.INFO) +LOGGER = logging.getLogger('test_conv_int8_intel') +LOGGER.disabled = False + +# All the WORKLOADS from Resnet except first layer +# Workload is ['height', 'width', 'in_filter', 'out_filter', +# 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride']) +WORKLOADS = [(56, 56, 64, 64, 3, 3, 1, 1, 1, 1), + (56, 56, 64, 64, 1, 1, 0, 0, 1, 1), + (56, 56, 64, 128, 3, 3, 1, 1, 2, 2), + (56, 56, 64, 128, 1, 1, 0, 0, 2, 2), + (28, 28, 128, 128, 3, 3, 1, 1, 1, 1), + (28, 28, 128, 256, 3, 3, 1, 1, 2, 2), + (28, 28, 128, 256, 1, 1, 0, 0, 2, 2), + (14, 14, 256, 256, 3, 3, 1, 1, 1, 1), + (14, 14, 256, 512, 3, 3, 1, 1, 2, 2), + (14, 14, 256, 512, 1, 1, 0, 0, 2, 2), + (7, 7, 512, 512, 3, 3, 1, 1, 1, 1), + (56, 56, 64, 256, 1, 1, 0, 0, 1, 1), + (56, 56, 256, 64, 1, 1, 0, 0, 1, 1), + (56, 56, 256, 128, 1, 1, 0, 0, 2, 2), + (28, 28, 128, 512, 1, 1, 0, 0, 1, 1), + (56, 56, 256, 512, 1, 1, 0, 0, 2, 2), + (28, 28, 512, 128, 1, 1, 0, 0, 1, 1), + (28, 28, 512, 256, 1, 1, 0, 0, 2, 2), + (14, 14, 256, 1024, 1, 1, 0, 0, 1, 1), + (28, 28, 512, 1024, 1, 1, 0, 0, 2, 2), + (14, 14, 1024, 256, 1, 1, 0, 0, 1, 1), + (14, 14, 1024, 512, 1, 1, 0, 0, 2, 2), + (7, 7, 512, 2048, 1, 1, 0, 0, 1, 1), + (14, 14, 1024, 2048, 1, 1, 0, 0, 2, 2), + (7, 7, 2048, 512, 1, 1, 0, 0, 1, 1) + ] + + +TARGET_NAME = 'llvm -device=arm_cpu 
-target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod' +NUM_VEC_LANES = 16 +CTX = tvm.context(TARGET_NAME, 0) + +def get_shape(im_height, im_width, in_filter, out_filter, k_h, k_w, hpad, wpad, + hstride, wstride, out_dtype): + """ + Finds out the shape of all data structures + """ + data_shape = (1, in_filter//NUM_VEC_LANES, im_height, im_width, NUM_VEC_LANES) + + if out_dtype == 'int32' or out_dtype == 'uint32': + kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, k_h, k_w, + NUM_VEC_LANES//4, NUM_VEC_LANES, 4) + elif out_dtype == 'float32': + kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, k_h, k_w, + NUM_VEC_LANES, NUM_VEC_LANES) + out_height = (im_height + 2 * hpad - k_h) // hstride + 1 + out_width = (im_width + 2 * wpad - k_w) // wstride + 1 + o_shape = (1, out_filter//NUM_VEC_LANES, out_height, out_width, NUM_VEC_LANES) + return (data_shape, kernel_shape, o_shape) + + + +def run_inference(data_dtype, kernel_dtype, out_dtype, im_height, im_width, in_filter, + out_filter, k_h, k_w, hpad, wpad, hstride, wstride): + """ + Runs the inference and checks the functional correctness between + compute and schedule outputs + """ + (data_shape, kernel_shape, o_shape) = get_shape(im_height, im_width, in_filter, + out_filter, k_h, k_w, hpad, wpad, + hstride, wstride, out_dtype) + + # Create TVM placeholders + data = tvm.placeholder(data_shape, name='data', dtype=data_dtype) + kernel = tvm.placeholder(kernel_shape, name='kernel', dtype=kernel_dtype) + + # Create the numpy arrays to be used for executing conv models + if data_dtype == 'float32': + data_array = tvm.nd.array(np.random.rand(*data_shape).astype(dtype=data_dtype), CTX) + kernel_array = tvm.nd.array(np.random.rand(*kernel_shape).astype(dtype=kernel_dtype), CTX) + else: + data_array = tvm.nd.array(np.random.randint(100, size=data_shape).astype(data_dtype)) + kernel_array = tvm.nd.array(np.random.randint(100, size=kernel_shape).astype(kernel_dtype)) + + # c_orig will be used for 
declaration output + # c_sch will be used for scheduled computation output + c_orig = tvm.nd.array(np.zeros(o_shape, dtype=out_dtype), CTX) + c_sch = tvm.nd.array(np.zeros(o_shape, dtype=out_dtype), CTX) + + + with tvm.target.create(TARGET_NAME): + if out_dtype == "float32": + conv = topi.nn.conv2d_NCHWc(data, kernel, stride=hstride, + padding=hpad, dilation=(1, 1), + layout='NCHWc', out_layout='NCHWc', out_dtype=out_dtype) + else: + conv = topi.nn.conv2d_NCHWc_int8(data, kernel, strides=hstride, + padding=hpad, dilation=(1, 1), + layout='NCHWc', out_layout='NCHWc', out_dtype=out_dtype) + out = topi.nn.relu(conv) + sch = tvm.create_schedule(out.op) + func = tvm.build(sch, [data, kernel, out], target=TARGET_NAME, name='out') + func(data_array, kernel_array, c_orig) + LOGGER.debug(tvm.lower(sch, [data, kernel], simple_mode=True)) + + # Generate and run the optimized schedule + if out_dtype == "float32": + sconv = topi.generic.nn.schedule_conv2d_NCHWc(outs=[out]) + else: + sconv = topi.generic.nn.schedule_conv2d_NCHWc_int8(outs=[out]) + func = tvm.build(sconv, [data, kernel, out], target=TARGET_NAME, name='conv') + func(data_array, kernel_array, c_sch) + + # Functional check + if data_dtype == 'uint8': + np.testing.assert_equal(c_orig.asnumpy(), c_sch.asnumpy()) + else: + assert np.allclose(c_orig.asnumpy(), c_sch.asnumpy()) + + evaluator = func.time_evaluator(func.entry_name, CTX, number=1000) + LOGGER.debug(tvm.lower(sconv, [data, kernel], simple_mode=True)) + return evaluator(data_array, kernel_array, c_sch).mean + +if __name__ == "__main__": + LOGGER.info("Workload, Kernel_size, FP32_time, INT8_time, Speedup") + SPEEDUP_ARRAY = [] + for i, wkl in enumerate(WORKLOADS): + for dtype in ["uint", "int"]: + fp32_time = run_inference('float32', 'float32', 'float32', *wkl) + int8_time = run_inference('%s8' % dtype, '%s8' % dtype, '%s32' % dtype, *wkl) + kernel_h = wkl[4] + kernel_w = wkl[5] + LOGGER.info("[%s] Workload#" % dtype + str(i) + ", " + str(kernel_h) + "x" +
str(kernel_w) + ", " + + str(fp32_time) + ", " + str(int8_time) + ", " + str(fp32_time/int8_time)) + + SPEEDUP_ARRAY.append(fp32_time/int8_time) + LOGGER.info("Average speedup --> %s" % str(sum(SPEEDUP_ARRAY)/float(len(SPEEDUP_ARRAY)))) From fa4d3ec6e31461697acf910a8aa1f0680308dcaf Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Tue, 1 Oct 2019 13:09:21 -0700 Subject: [PATCH 15/17] [TOPI]Add op argwhere (#3994) * Add op argwhere * Move shape func to _algorithm.py * Add lint rule * Raise exception if rank is not supported * move argwhere to transform * Add argwhere example * Fix lint * Add 1-d support * cleanup * Add more dtype support * CR comment * Improve error message * Docs * raise exception --- include/tvm/relay/attrs/transform.h | 6 + python/tvm/relay/op/_transform.py | 101 +++++++++++++- python/tvm/relay/op/transform.py | 23 +++- src/relay/op/tensor/transform.cc | 34 +++++ tests/python/relay/test_any.py | 30 +++++ topi/python/topi/__init__.py | 1 + topi/python/topi/argwhere.py | 191 +++++++++++++++++++++++++++ topi/python/topi/generic/__init__.py | 1 + topi/python/topi/generic/search.py | 37 ++++++ 9 files changed, 422 insertions(+), 2 deletions(-) create mode 100644 topi/python/topi/argwhere.py create mode 100644 topi/python/topi/generic/search.py diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 52656872ad10..ccdc871e8a78 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -314,6 +314,12 @@ struct OneHotAttrs : public tvm::AttrsNode { } }; // struct OneHotAttrs +/*!
\brief Attributes for ArgWhere operator */ +struct ArgWhereAttrs : public tvm::AttrsNode { + TVM_DECLARE_ATTRS(ArgWhereAttrs, "relay.attrs.ArgWhereAttrs") { + } +}; // struct ArgWhereAttrs + } // namespace relay } // namespace tvm #endif // TVM_RELAY_ATTRS_TRANSFORM_H_ diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index d1c0f09aaedb..687d5b4c5b2c 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -15,8 +15,10 @@ # specific language governing permissions and limitations # under the License. """Backend compiler related feature registration""" -# pylint: disable=invalid-name,unused-argument, len-as-condition +# pylint: disable=invalid-name,unused-argument, len-as-condition, too-many-nested-blocks from __future__ import absolute_import +import tvm +import topi from topi.util import get_const_int, get_const_tuple from . import op as _reg from ._reduce import _schedule_reduce @@ -204,3 +206,100 @@ def take_shape_func(attrs, inputs, out_ndims): axis += data_ndim assert 0 <= axis < data_ndim return [_take_with_axis_shape_func(*inputs, convert(axis), out_ndims[0])] + +@script +def _argwhere_shape_func_1d(condition): + out = output_tensor((2, ), "int64") + out[0] = int64(0) + out[1] = int64(1) + for i1 in range(condition.shape[0]): + if condition[i1] != 0: + out[0] += int64(1) + return out + +@script +def _argwhere_shape_func_2d(condition): + out = output_tensor((2, ), "int64") + out[0] = int64(0) + out[1] = int64(2) + for i1 in range(condition.shape[0]): + for i2 in range(condition.shape[1]): + if condition[i1, i2] != 0: + out[0] += int64(1) + return out + +@script +def _argwhere_shape_func_3d(condition): + out = output_tensor((2, ), "int64") + out[0] = int64(0) + out[1] = int64(3) + for i1 in range(condition.shape[0]): + for i2 in range(condition.shape[1]): + for i3 in range(condition.shape[2]): + if condition[i1, i2, i3] != 0: + out[0] += int64(1) + return out + +@script +def 
_argwhere_shape_func_4d(condition): + out = output_tensor((2, ), "int64") + out[0] = int64(0) + out[1] = int64(4) + for i1 in range(condition.shape[0]): + for i2 in range(condition.shape[1]): + for i3 in range(condition.shape[2]): + for i4 in range(condition.shape[3]): + if condition[i1, i2, i3, i4] != 0: + out[0] += int64(1) + return out + +@script +def _argwhere_shape_func_5d(condition): + out = output_tensor((2, ), "int64") + out[0] = int64(0) + out[1] = int64(5) + for i1 in range(condition.shape[0]): + for i2 in range(condition.shape[1]): + for i3 in range(condition.shape[2]): + for i4 in range(condition.shape[3]): + for i5 in range(condition.shape[4]): + if condition[i1, i2, i3, i4, i5] != 0: + out[0] += int64(1) + return out + +@_reg.register_shape_func("argwhere", True) +def argwhere_shape_func(attrs, inputs, out_ndims): + """ + Shape function for argwhere. + """ + if len(inputs[0].shape) == 1: + return [_argwhere_shape_func_1d(inputs[0])] + elif len(inputs[0].shape) == 2: + return [_argwhere_shape_func_2d(inputs[0])] + elif len(inputs[0].shape) == 3: + return [_argwhere_shape_func_3d(inputs[0])] + elif len(inputs[0].shape) == 4: + return [_argwhere_shape_func_4d(inputs[0])] + elif len(inputs[0].shape) == 5: + return [_argwhere_shape_func_5d(inputs[0])] + raise ValueError("Does not support rank higher than 5 in argwhere") + +@_reg.register_schedule("argwhere") +def schedule_argwhere(_, outs, target): + """Schedule definition of argwhere""" + with target: + return topi.generic.schedule_argwhere(outs) + + +@_reg.register_compute("argwhere") +def compute_argwhere(attrs, inputs, output_type, _): + """Compute definition of argwhere""" + output_shape = [] + for s in output_type.shape: + if hasattr(s, "value"): + output_shape.append(s) + else: + # see Any, replace it with a var + output_shape.append(tvm.var("any_dim", "int32")) + new_output_type = tvm.relay.ty.TensorType(output_shape, "int32") + return [topi.argwhere(new_output_type, inputs[0])] diff --git
a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 7f921d03a62f..88d7a448005c 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -144,7 +144,6 @@ def squeeze(data, axis=None): """ return _make.squeeze(data, axis) - def reshape(data, newshape): """Reshapes the input array. @@ -214,6 +213,28 @@ def reshape(data, newshape): newshape = [newshape] return _make.reshape(data, list(newshape)) +def argwhere(condition): + """Find the indices of elements of a tensor that are + non-zero. + + Parameters + ---------- + condition : relay.Expr + The input condition tensor. + + Returns + ------- + out : relay.Expr + Tensor with the indices of elements that are non-zero. + + Examples + -------- + .. code-block:: python + + condition = [[True, False], [False, True]] + relay.argwhere(condition) = [[0, 0], [1, 1]] + """ + return _make.argwhere(condition) def reshape_like(data, shape_like): """Reshapes the input array by the size of another array. 
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index b92005754020..0002390be809 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -817,6 +817,40 @@ the input array into an output array with the same shape as the second input arr .set_attr("FTVMCompute", ReshapeCompute) .set_attr("TOpPattern", kInjective); +// ArgWhere +bool ArgWhereRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(num_inputs, 1); + auto tt = types[0].as(); + CHECK(tt != nullptr); + const auto& input_shape = tt->shape; + const auto& input_rank = input_shape.size(); + std::vector result_shape; + result_shape.push_back(Any::make()); + result_shape.push_back(IntImm::make(Int(32), input_rank)); + reporter->Assign(types[1], TensorTypeNode::make(result_shape, Int(32))); + return true; +} + +TVM_REGISTER_API("relay.op._make.argwhere") +.set_body_typed([](Expr data) { + static const Op& op = Op::Get("argwhere"); + auto attrs = make_node(); + return CallNode::make(op, {data}, Attrs(attrs), {}); +}); + +RELAY_REGISTER_OP("argwhere") +.describe(R"doc(Find the indices of elements of a tensor that are +non-zero)doc" TVM_ADD_FILELINE) +.set_num_inputs(1) +.set_attrs_type_key("relay.attrs.ArgWhereAttrs") +.add_argument("condition", "Tensor", "The input condition tensor.") +.add_type_rel("ArgWhere", ArgWhereRel) +.set_attr("TOpIsStateful", false) +.set_attr("TOpPattern", kOpaque) +.set_support_level(10); // Take TVM_REGISTER_NODE_TYPE(TakeAttrs); diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 214b88fa1850..d02dcd0b73dd 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -92,6 +92,36 @@ def test_any_reshape(): verify_any_reshape(any_dims(3), (-4, 2, -1, -2), (6, 3, 4), (2, 3, 3, 4)) verify_any_reshape(any_dims(3), (-4, -1, 2, -3), (6, 3, 4), (3, 2, 12)) +def verify_any_argwhere(x_shape, x_np_shape, dtype="bool"): + x = 
relay.var('x', shape=x_shape, dtype=dtype) + y = relay.argwhere(x) + mod = relay.module.Module() + mod["main"] = relay.Function([x], y) + data = np.random.choice([0, 1, 2, 3], size=x_np_shape).astype(dtype) + for kind in ["debug", "vm"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + result = ex.evaluate()(data).asnumpy() + expected = np.argwhere(data) + assert result.shape == expected.shape + tvm.testing.assert_allclose(result.flatten(), expected.flatten()) + +def test_any_argwhere(): + verify_any_argwhere(any_dims(1), (5,)) + verify_any_argwhere(any_dims(2), (5, 5)) + verify_any_argwhere(any_dims(3), (5, 5, 5)) + verify_any_argwhere(any_dims(4), (5, 5, 5, 5)) + verify_any_argwhere(any_dims(5), (5, 5, 5, 5, 5)) + verify_any_argwhere(any_dims(1), (5,), "int32") + verify_any_argwhere(any_dims(2), (5, 5), "int32") + verify_any_argwhere(any_dims(3), (5, 5, 5), "int32") + verify_any_argwhere(any_dims(4), (5, 5, 5, 5), "int32") + verify_any_argwhere(any_dims(5), (5, 5, 5, 5, 5), "int32") + verify_any_argwhere(any_dims(1), (5,), "int8") + verify_any_argwhere(any_dims(2), (5, 5), "int8") + verify_any_argwhere(any_dims(3), (5, 5, 5), "int8") + verify_any_argwhere(any_dims(4), (5, 5, 5, 5), "int8") + verify_any_argwhere(any_dims(5), (5, 5, 5, 5, 5), "int8") + def verify_any_take(data_shape, indices_shape, axis, data_np_shape, indices_np_shape): mod = relay.Module() data = relay.var('data', shape=data_shape, dtype='float32') diff --git a/topi/python/topi/__init__.py b/topi/python/topi/__init__.py index ac855d144aad..fd293a09b9e7 100644 --- a/topi/python/topi/__init__.py +++ b/topi/python/topi/__init__.py @@ -22,6 +22,7 @@ from .transform import * from .broadcast import * from .sort import * +from .argwhere import * from . import nn from . import x86 from . 
import cuda diff --git a/topi/python/topi/argwhere.py b/topi/python/topi/argwhere.py new file mode 100644 index 000000000000..32f4e8718c46 --- /dev/null +++ b/topi/python/topi/argwhere.py @@ -0,0 +1,191 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks +"""Argwhere operator""" +import tvm +from tvm import hybrid + +@hybrid.script +def hybrid_argwhere_1d(output_shape, condition): + """Find the indices of elements of a 1-D tensor that are non-zero. + + Parameters + ---------- + condition : tvm.Tensor + 1-D tensor with boolean values. + + Returns + ------- + out : tvm.Tensor + Indices of non-zero elements. + """ + a = output_tensor(output_shape, "int32") + a1 = condition.shape[0] + valid_index = 0 + for i1 in range(a1): + if condition[i1] != 0: + a[valid_index, 0] = i1 + valid_index += 1 + return a + +@hybrid.script +def hybrid_argwhere_2d(output_shape, condition): + """Find the indices of elements of a 2-D tensor that are non-zero. + + Parameters + ---------- + condition : tvm.Tensor + 2-D tensor with boolean values. + + Returns + ------- + out : tvm.Tensor + Indices of non-zero elements. 
+ """ + a = output_tensor(output_shape, "int32") + a1 = condition.shape[0] + a2 = condition.shape[1] + valid_index = 0 + for i1 in range(a1): + for i2 in range(a2): + if condition[i1, i2] != 0: + a[valid_index, 0] = i1 + a[valid_index, 1] = i2 + valid_index += 1 + return a + +@hybrid.script +def hybrid_argwhere_3d(output_shape, condition): + """Find the indices of elements of a 3-D tensor that are non-zero. + + Parameters + ---------- + condition : tvm.Tensor + 3-D tensor with boolean values. + + Returns + ------- + out : tvm.Tensor + Indices of non-zero elements. + """ + a = output_tensor(output_shape, "int32") + a1 = condition.shape[0] + a2 = condition.shape[1] + a3 = condition.shape[2] + valid_index = 0 + for i1 in range(a1): + for i2 in range(a2): + for i3 in range(a3): + if condition[i1, i2, i3] != 0: + a[valid_index, 0] = i1 + a[valid_index, 1] = i2 + a[valid_index, 2] = i3 + valid_index += 1 + return a + +@hybrid.script +def hybrid_argwhere_4d(output_shape, condition): + """Find the indices of elements of a 4-D tensor that are non-zero. + + Parameters + ---------- + condition : tvm.Tensor + 4-D tensor with boolean values. + + Returns + ------- + out : tvm.Tensor + Indices of non-zero elements. + """ + a = output_tensor(output_shape, "int32") + a1 = condition.shape[0] + a2 = condition.shape[1] + a3 = condition.shape[2] + a4 = condition.shape[3] + valid_index = 0 + for i1 in range(a1): + for i2 in range(a2): + for i3 in range(a3): + for i4 in range(a4): + if condition[i1, i2, i3, i4] != 0: + a[valid_index, 0] = i1 + a[valid_index, 1] = i2 + a[valid_index, 2] = i3 + a[valid_index, 3] = i4 + valid_index += 1 + return a + +@hybrid.script +def hybrid_argwhere_5d(output_shape, condition): + """Find the indices of elements of a 5-D tensor that are non-zero. + + Parameters + ---------- + condition : tvm.Tensor + 5-D tensor with boolean values. + + Returns + ------- + out : tvm.Tensor + Indices of non-zero elements. 
+ """ + a = output_tensor(output_shape, "int32") + a1 = condition.shape[0] + a2 = condition.shape[1] + a3 = condition.shape[2] + a4 = condition.shape[3] + a5 = condition.shape[4] + valid_index = 0 + for i1 in range(a1): + for i2 in range(a2): + for i3 in range(a3): + for i4 in range(a4): + for i5 in range(a5): + if condition[i1, i2, i3, i4, i5] != 0: + a[valid_index, 0] = i1 + a[valid_index, 1] = i2 + a[valid_index, 2] = i3 + a[valid_index, 3] = i4 + a[valid_index, 4] = i5 + valid_index += 1 + return a + +@tvm.target.generic_func +def argwhere(output_shape, condition): + """Find the indices of elements of a tensor that are non-zero. + + Parameters + ---------- + condition : tvm.Tensor + Tensor with boolean values. + + Returns + ------- + out : tvm.Tensor + Indices of non-zero elements. + """ + if len(condition.shape) == 1: + return hybrid_argwhere_1d(output_shape.shape, condition) + if len(condition.shape) == 2: + return hybrid_argwhere_2d(output_shape.shape, condition) + if len(condition.shape) == 3: + return hybrid_argwhere_3d(output_shape.shape, condition) + if len(condition.shape) == 4: + return hybrid_argwhere_4d(output_shape.shape, condition) + if len(condition.shape) == 5: + return hybrid_argwhere_5d(output_shape.shape, condition) + raise ValueError("Does not support rank higher than 5 in argwhere") diff --git a/topi/python/topi/generic/__init__.py b/topi/python/topi/generic/__init__.py index 6bf5f3a053c9..18af0e328471 100644 --- a/topi/python/topi/generic/__init__.py +++ b/topi/python/topi/generic/__init__.py @@ -20,3 +20,4 @@ from .extern import * from .vision import * from .sort import * +from .search import * diff --git a/topi/python/topi/generic/search.py b/topi/python/topi/generic/search.py new file mode 100644 index 000000000000..41045e492e53 --- /dev/null +++ b/topi/python/topi/generic/search.py @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, no-member +"""Generic search operators""" +from __future__ import absolute_import as _abs +import tvm +from .vision import _default_schedule + +@tvm.target.generic_func +def schedule_argwhere(outs): + """Schedule for argwhere operator. + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of argwhere. + + Returns + ------- + s: Schedule + The computation schedule for the op. + """ + return _default_schedule(outs, False) From 2f1edb9961751f0f89c2dcc739fdda1a41b2705d Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Tue, 1 Oct 2019 14:07:49 -0700 Subject: [PATCH 16/17] [COMMUNITY] ajtulloch -> committer (#4043) --- CONTRIBUTORS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index f1e5a019c949..55d818b1d4ea 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -55,6 +55,7 @@ We do encourage everyone to work anything they are interested in. 
- [Siva](https://github.com/srkreddy1238): @srkreddy1238 - frontends, golang - [Haichen Shen](https://github.com/icemelon9) (PMC): @icemelon9 - relay, topi - [Zhixun Tan](https://github.com/phisiart): @phisiart - opengl, web +- [Andrew Tulloch](https://github.com/ajtulloch): @ajtulloch - topi, compiler, runtime - [Leyuan Wang](https://github.com/Laurawly): @Laurawly: - topi - [Yao Wang](https://github.com/kevinthesun): @kevinthesun: - topi, vision - [Jian Weng](https://github.com/were): @were: - hybrid script From 2d537621982b56d2cead7650f0c0ab683c5483cb Mon Sep 17 00:00:00 2001 From: Cody Hao Yu Date: Tue, 1 Oct 2019 16:20:29 -0700 Subject: [PATCH 17/17] Fix split's last factor issue (#4044) --- python/tvm/autotvm/task/space.py | 3 ++- tests/python/unittest/test_autotvm_space.py | 20 ++++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/python/tvm/autotvm/task/space.py b/python/tvm/autotvm/task/space.py index defb0612144c..f1422bf28213 100644 --- a/python/tvm/autotvm/task/space.py +++ b/python/tvm/autotvm/task/space.py @@ -226,7 +226,8 @@ def __init__(self, axes, policy, **kwargs): def _generate_space(self, now, tmp_stack, enforce_no_tail=False): """Generate space by DFS""" if now == self.num_output - 1: - if not enforce_no_tail or self.product % np.prod(tmp_stack, dtype=np.int64) == 0: + prod = np.prod(tmp_stack, dtype=np.int64) + if self.product % prod == 0 or (not enforce_no_tail and prod < self.product): self.entities.append(SplitEntity([-1] + tmp_stack[::-1])) else: for factor in self.factors: diff --git a/tests/python/unittest/test_autotvm_space.py b/tests/python/unittest/test_autotvm_space.py index 1da3fb0182ba..85d572412f9e 100644 --- a/tests/python/unittest/test_autotvm_space.py +++ b/tests/python/unittest/test_autotvm_space.py @@ -42,6 +42,26 @@ def test_split(): assert len(cfg) == 64 assert len(cfg.space_map['tile_y']) == 8 + # test policy + cfg = ConfigSpace() + cfg.define_split('tile_x', cfg.axis(256), policy='factors', 
num_outputs=3) + assert len(cfg.space_map['tile_x']) == 45 + + cfg.define_split('tile_y', cfg.axis(256), policy='power2', num_outputs=3) + assert len(cfg.space_map['tile_y']) == 45 + + cfg.define_split('tile_z', cfg.axis(256), policy='verbose', num_outputs=3) + assert len(cfg.space_map['tile_z']) == 45 + + cfg.define_split('tile_a', cfg.axis(224), policy='factors', num_outputs=3) + assert len(cfg.space_map['tile_a']) == 63 + + cfg.define_split('tile_b', cfg.axis(224), policy='power2', num_outputs=3) + assert len(cfg.space_map['tile_b']) == 36 + + cfg.define_split('tile_c', cfg.axis(224), policy='verbose', num_outputs=3) + assert len(cfg.space_map['tile_c']) == 84 + # test fallback cfg = FallbackConfigEntity() cfg.define_split('tile_n', cfg.axis(128), num_outputs=3)