removing out_shape in relu def and lint issues
Venkat Rasagna Reddy Komatireddy committed May 27, 2022
1 parent ea29518 commit 12ba658
Showing 2 changed files with 26 additions and 37 deletions.
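Since ReLU is elementwise, the output shape always matches the input shape, so the compute definition can take it from the input tensor itself; that is what makes the out_shape parameter removable. A minimal sketch of the idea (editor's illustration, not code from this commit; the placeholder shape is made up):

import tvm
from tvm import te

A = te.placeholder((1, 8, 4, 32), name="A", dtype="float16")
zero = tvm.tir.const(0, "float16")
# The output shape is read from A.shape, so no separate out_shape argument is needed.
B = te.compute(A.shape, lambda n, h, w, c: te.max(A[n, h, w, c], zero), name="reluf16")
# B.shape == A.shape == (1, 8, 4, 32)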
python/tvm/topi/hexagon/slice_ops/relu.py (22 additions, 27 deletions)
@@ -14,43 +14,38 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
 """Relu slice op for Hexagon:"""

 import tvm
 from tvm import te, tir
-from tvm.ir.module import IRModule
-from tvm.script import tir as T


-def relu_te_compute(Input, out_shape, dtype):
+def relu_te_compute(A, dtype):
     x = tvm.tir.const(0, dtype)
-    Output = te.compute(
-        out_shape, lambda n, h, w, c: tvm.te.max(Input[n, h, w, c], x), name="reluf16"
-    )
-    return Output
+    B = te.compute(A.shape,
+                   lambda n, h, w, c: tvm.te.max(A[n, h, w, c], x), name="reluf16")
+    return B


-def reluf16_te_sched(Output, Input, transform_crouton_activation):
-    s = tvm.te.create_schedule(Output.op)
-    s[Input].transform_layout(transform_crouton_activation)
-    out_axes = s[Output].transform_layout(transform_crouton_activation)
-    fused = s[Output].fuse(out_axes[6], out_axes[7])
-    s[Output].vectorize(fused)
+def reluf16_te_sched(B, A, layout):
+    s = tvm.te.create_schedule(B.op)
+    s[A].transform_layout(layout)
+    out_axes = s[B].transform_layout(layout)
+    fused = s[B].fuse(out_axes[6], out_axes[7])
+    s[B].vectorize(fused)
     return s


-def reluf16_stir_sched(func, transform_crouton_activation):
+def reluf16_stir_sched(func, layout):
     sch = tir.Schedule(func, debug_mask="all")
     block = sch.get_block("reluf16")
-    n, i, j, k = sch.get_loops(block)
-    i1, i2 = sch.split(i, [None, 8])
-    j1, j2 = sch.split(j, [None, 4])
-    k1, k2 = sch.split(k, [None, 32])
-    j3, j4 = sch.split(j2, [None, 2])
-    sch.reorder(n, i1, j1, k1, i2, j3, k2, j4)
-    sch.transform_layout(block, 0, "read", transform_crouton_activation)
-    sch.set_axis_separator(block, 0, "read", [4])
-    sch.transform_layout(block, 0, "write", transform_crouton_activation)
+    n, h, w, c = sch.get_loops(block)
+    ho, hi = sch.split(h, [None, 8])
+    wo, wi = sch.split(w, [None, 4])
+    co, ci = sch.split(c, [None, 32])
+    wio, wii = sch.split(wi, [None, 2])
+    sch.reorder(n, ho, wo, co, hi, wio, ci, wii)
+    sch.transform_layout(block, 0, "read", layout)
+    sch.set_axis_separator(block, 0, "read", [4])
+    sch.transform_layout(block, 0, "write", layout)
     sch.set_axis_separator(block, 0, "write", [4])
-    fused = sch.fuse(k2, j4)
+    fused = sch.fuse(ci, wii)
     sch.vectorize(fused)
     return sch
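For orientation, here is a sketch of how the updated entry points are meant to be driven, mirroring the test below; the slice_ops import alias and the input shape are assumptions, not part of the commit:

from tvm import te
from tvm.topi.hexagon import slice_ops as sl  # assumed alias; the test below calls sl.*

dtype = "float16"
InputTensor = te.placeholder((1, 8, 4, 32), name="InputTensor", dtype=dtype)
OutputTensor = sl.relu_te_compute(InputTensor, dtype)  # out_shape is no longer passed

# fp16 activation crouton layout: 8h x 4w x 32c blocks, with the w axis further split 2x2.
def layout_func(n, h, w, c):
    return [n, h // 8, w // 4, c // 32, h % 8, (w % 4) // 2, c % 32, w % 2]

reluf16_func = te.create_prim_func([InputTensor, OutputTensor])
tir_s = sl.reluf16_stir_sched(reluf16_func, layout_func)  # splits, reorders and vectorizes as above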
tests/python/contrib/test_hexagon/test_relu_slice.py (4 additions, 10 deletions)
@@ -29,24 +29,20 @@

 from .infrastructure import allocate_hexagon_array

-
 def transform_numpy(arr_np):
     N, H, W, C = arr_np.shape
     return arr_np.reshape([N, H // 8, 8, W // 4, 2, 2, C // 32, 32]).transpose(
         0, 1, 3, 6, 2, 4, 7, 5
     )

-
 def transform_2d(arr_np):
     N, H, W, C, h, w1, c, w2 = arr_np.shape
     return arr_np.reshape(N * H * W * C, h * w1 * c * w2)

-
 @tvm.testing.fixture
 def input_np(in_shape, dtype):
     return np.random.uniform(size=in_shape).astype(dtype)

-
 @tvm.testing.fixture
 def input_np_padded(input_np, in_shape, padded_in_shape):
     pad_height = padded_in_shape[1] - in_shape[1]
@@ -57,7 +53,6 @@ def input_np_padded(input_np, in_shape, padded_in_shape):
     )
     return input_padded

-
 class BaseRelu:
     in_shape = tvm.testing.parameter(
         (1, 8, 4, 32),
@@ -72,7 +67,6 @@ class BaseRelu:
     dtype = tvm.testing.parameter("float16")
     working_scope = tvm.testing.parameter("global.vtcm")

-
 class TestReluSlice(BaseRelu):
     @tvm.testing.fixture
     def padded_in_shape(self, in_shape):
@@ -102,9 +96,9 @@ def test_relu(
     ):
         InputTensor = tvm.te.placeholder(padded_in_shape, name="InputTensor", dtype=dtype)

-        OutputTensor = sl.relu_te_compute(InputTensor, in_shape, dtype)
+        OutputTensor = sl.relu_te_compute(InputTensor, dtype)

-        def transform_crouton_activation(n, h, w, c):
+        def layout_func(n, h, w, c):
             return [n, h // 8, w // 4, c // 32, h % 8, (w % 4) // 2, c % 32, w % 2]

         target_hexagon = tvm.target.hexagon("v69", codegen_options="emit-llvm, emit-asm=1")
@@ -113,7 +107,7 @@ def transform_crouton_activation(n, h, w, c):
         reluf16_func = te.create_prim_func([InputTensor, OutputTensor])
         tir_s = sl.reluf16_stir_sched(
             reluf16_func,
-            transform_crouton_activation,
+            layout_func,
         )

         func_name = "reluf16"
@@ -147,4 +141,4 @@ def transform_crouton_activation(n, h, w, c):
         mod(input_arr, output_arr)
         output_np = output_arr.numpy()

-        np.testing.assert_allclose(output_np, output_np_tr_2d, atol=1.0, rtol=0.05)
+        np.testing.assert_allclose(output_np, output_np_tr_2d)
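As a concrete check of the two layout helpers above (transform_numpy and transform_2d), here is what they produce for the smallest parameterized shape, (1, 8, 4, 32); this is an editor's sketch with the shapes worked out by hand, not part of the test:

import numpy as np

arr = np.random.uniform(size=(1, 8, 4, 32)).astype("float16")
# transform_numpy: carve the NHWC array into 8h x 4w x 32c croutons, with the w axis split 2x2.
crouton = arr.reshape([1, 8 // 8, 8, 4 // 4, 2, 2, 32 // 32, 32]).transpose(0, 1, 3, 6, 2, 4, 7, 5)
# crouton.shape == (1, 1, 1, 1, 8, 2, 32, 2): exactly one crouton for this input.
# transform_2d: flatten each crouton into one row of 8 * 2 * 32 * 2 = 1024 float16 values (2 KiB).
flat = crouton.reshape(1 * 1 * 1 * 1, 8 * 2 * 32 * 2)
# flat.shape == (1, 1024)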
