[Hexagon] Slice op relu (#11449)
* Add support for relu slice op.

* Format code

* Remove out_shape from relu def and fix lint issues

* Changes as per the new format

Co-authored-by: Venkat Rasagna Komatireddy <89959097+rasagna-quic@users.noreply.github.com>
Co-authored-by: Venkat Rasagna Reddy Komatireddy <rasagna@hu-rasagna-hyd.qualcomm.com>
3 people committed Jul 19, 2022
1 parent ae015d9 commit a1f27e5
Showing 3 changed files with 189 additions and 0 deletions.
1 change: 1 addition & 0 deletions python/tvm/topi/hexagon/slice_ops/__init__.py
@@ -31,3 +31,4 @@
)
from .conv2d import *
from .reshape import reshape_compute, reshape_stir_schedule
from .relu import relu_compute, relu_stir_schedule
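
With this export, the new op is reachable from the slice_ops namespace. A minimal usage sketch (the shape, dtype, and layout strings are illustrative, borrowed from the test parameters below):

    from tvm import te
    from tvm.topi.hexagon import slice_ops

    A = te.placeholder((1, 8, 4, 32), dtype="float16", name="A")
    B = slice_ops.relu_compute(A)
    sch = slice_ops.relu_stir_schedule(A, B, "nhwc-8h2w32c2w-2d", "nhwc-8h2w32c2w-2d")
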
65 changes: 65 additions & 0 deletions python/tvm/topi/hexagon/slice_ops/relu.py
@@ -0,0 +1,65 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Hexagon slice relu op"""

from tvm import te, tir, topi
from ..utils import get_layout_transform_fn


def relu_compute(Input):
    """Relu topi compute"""
    return topi.nn.relu(Input)


def relu_te_sched(Output, Input, layout):
    """
    Schedule assumes the layout function to be bijective
    """
    s = te.create_schedule(Output.op)
    s[Input].transform_layout(layout)
    out_axes = s[Output].transform_layout(layout)
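    # After the layout transform, the output axes are
    # [n, h_outer, w_outer, c_outer, 8h, 2w, 32c, 2w]; fusing axes 6 and 7
    # (the 32c axis and the innermost 2w axis) yields a 64-lane axis to vectorize.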
    fused = s[Output].fuse(out_axes[6], out_axes[7])
    s[Output].vectorize(fused)
    return s


def relu_stir_schedule(Input, Output, input_layout, output_layout):
    """
    Schedule assumes the layout function to be bijective
    """
    if (input_layout != output_layout) or (output_layout != "nhwc-8h2w32c2w-2d"):
        raise RuntimeError(
            f"Unexpected input_layout '{input_layout}' or output_layout '{output_layout}'"
        )
    relu_func = te.create_prim_func([Input, Output])
    sch = tir.Schedule(relu_func, debug_mask="all")
    block = sch.get_block("compute")
    sch.transform_layout(block, Input.name, get_layout_transform_fn(input_layout))
    sch.transform_layout(block, Output.name, get_layout_transform_fn(output_layout))

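    # Tile the loops to the fixed nhwc-8h2w32c2w-2d block shape:
    # h in chunks of 8, w in chunks of 4 (split again into 2x2), c in chunks of 32.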
    n, h, w, c = sch.get_loops(block)
    h_o, h_i = sch.split(h, [None, 8])
    w_o, w_i = sch.split(w, [None, 4])
    c_o, c_i = sch.split(c, [None, 32])
    w_io, w_ii = sch.split(w_i, [None, 2])

    sch.reorder(n, h_o, w_o, c_o, h_i, w_io, c_i, w_ii)

    fused = sch.fuse(c_i, w_ii)
    sch.vectorize(fused)
    return sch
123 changes: 123 additions & 0 deletions tests/python/contrib/test_hexagon/topi/test_relu_slice.py
@@ -0,0 +1,123 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import numpy as np

import tvm
import tvm.testing
from tvm import te
from tvm.topi.hexagon.slice_ops.relu import relu_compute, relu_stir_schedule

from ..infrastructure import allocate_hexagon_array, transform_numpy


@tvm.testing.fixture
def input_np(in_shape, dtype):
    return np.random.uniform(size=in_shape).astype(dtype)


@tvm.testing.fixture
def ref_output_np(input_np):
    output_np = input_np * (input_np > 0)
    return output_np


@tvm.testing.fixture
def transformed_input_np(input_np, input_layout):
    return transform_numpy(input_np, "nhwc", input_layout)


@tvm.testing.fixture
def transformed_ref_output_np(ref_output_np, output_layout):
    return transform_numpy(ref_output_np, "nhwc", output_layout)


class BaseRelu:
    in_shape = tvm.testing.parameter(
        (1, 8, 4, 32),
        (1, 16, 4, 32),
        (1, 16, 8, 32),
        (1, 16, 8, 64),
        (2, 8, 4, 32),
        (2, 16, 4, 32),
        (2, 16, 8, 32),
        (2, 16, 8, 64),
    )
    dtype = tvm.testing.parameter("float16")
    working_scope = tvm.testing.parameter("global.vtcm")
    input_layout = tvm.testing.parameter("nhwc-8h2w32c2w-2d")
    output_layout = tvm.testing.parameter("nhwc-8h2w32c2w-2d")


class TestReluSlice(BaseRelu):
    @tvm.testing.requires_hexagon
    def test_relu(
        self,
        in_shape,
        dtype,
        input_layout,
        output_layout,
        transformed_input_np,
        transformed_ref_output_np,
        target,
        working_scope,
        hexagon_session,
    ):
        InputTensor = te.placeholder(in_shape, name="InputTensor", dtype=dtype)

        OutputTensor = relu_compute(InputTensor)

        target_hexagon = tvm.target.hexagon("v69")
        target = tvm.target.Target(target_hexagon, host=target_hexagon)

        tir_s = relu_stir_schedule(InputTensor, OutputTensor, input_layout, output_layout)

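        # Note: the "-2d" layouts are backed by discontiguous VTCM allocations;
        # axis_separators=[4] splits the flattened buffer after the first four
        # transformed axes (n, h_o, w_o, c_o).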
        input_data = allocate_hexagon_array(
            hexagon_session.device,
            data=transformed_input_np,
            axis_separators=[4],
            mem_scope=working_scope,
        )
        output_data = allocate_hexagon_array(
            hexagon_session.device,
            tensor_shape=transformed_ref_output_np.shape,
            dtype=transformed_ref_output_np.dtype,
            axis_separators=[4],
            mem_scope=working_scope,
        )

        func_name = "relu"
        with tvm.transform.PassContext(opt_level=3):
            runtime_module = tvm.build(tir_s.mod, target=target, name=func_name)

        mod = hexagon_session.load_module(runtime_module)

        mod(input_data, output_data)
        output_np = output_data.numpy()

        tvm.testing.assert_allclose(
            output_np,
            transformed_ref_output_np,
            rtol=1e-3,
            atol=1e-3,
        )


if __name__ == "__main__":
    tvm.testing.main()
