Skip to content

Commit

Permalink
fixing linter formatting part 3
Browse files Browse the repository at this point in the history
  • Loading branch information
Wheest committed Dec 21, 2020
1 parent 05bcbc1 commit 955844b
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 8 deletions.
8 changes: 3 additions & 5 deletions python/tvm/topi/arm_cpu/group_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
from ..nn.conv2d import _get_workload as _get_conv2d_workload



def group_conv2d_nchw(data, kernel, strides, padding, dilation, groups, out_dtype):
"""Compute group_conv2d with NCHW layout"""
return group_conv2d_nchw_spatial_pack(
Expand All @@ -44,7 +43,7 @@ def schedule_group_conv2d_nchw(outs):
return schedule_group_conv2d_nchwc(outs)


def _get_default_config(cfg, data, kernel, strides, padding, groups, out_dtype, layout='NCHW'):
def _get_default_config(cfg, data, kernel, strides, padding, groups, out_dtype, layout="NCHW"):
"""
Get default schedule config for the workload
"""
Expand Down Expand Up @@ -99,7 +98,7 @@ def _fallback_schedule(cfg, wkl):

@autotvm.register_topi_compute("group_conv2d_nchw.arm_cpu")
def group_conv2d_nchw_spatial_pack(
cfg, data, kernel, strides, padding, dilation, groups, out_dtype='float32'
cfg, data, kernel, strides, padding, dilation, groups, out_dtype="float32"
):
"""
Compute group conv2d with NCHW layout, using GSPC algorithm.
Expand Down Expand Up @@ -202,7 +201,7 @@ def group_conv2d_nchw_spatial_pack(
)

# convolution
oshape = (groups, batch_size, kernels_per_group//oc_bn, out_height, out_width, oc_bn)
oshape = (groups, batch_size, kernels_per_group // oc_bn, out_height, out_width, oc_bn)
unpack_shape = (batch_size, out_channel, out_height, out_width)

ic = te.reduce_axis((0, (kernel_depth)), name="ic")
Expand Down Expand Up @@ -299,7 +298,6 @@ def _schedule_gspc_nchw(s, cfg, data, data_pad, data_vec, kernel_vec, conv_out,
cfg["unroll_kw"].val,
)


# no stride and padding info here
padding = infer_pad(data, data_pad)
hpad, wpad = padding
Expand Down
5 changes: 2 additions & 3 deletions python/tvm/topi/x86/group_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def _fallback_schedule(cfg, wkl):

@autotvm.register_topi_compute("group_conv2d_nchw.x86")
def group_conv2d_nchw_spatial_pack(
cfg, data, kernel, strides, padding, dilation, groups, out_dtype="float32"
cfg, data, kernel, strides, padding, dilation, groups, out_dtype="float32"
):
"""
Compute group conv2d with NCHW layout, using GSPC algorithm.
Expand Down Expand Up @@ -289,8 +289,7 @@ def traverse(op):
return s


def _schedule_gspc_nchw(s, cfg, data, data_pad, data_vec, kernel_vec,
conv_out, output, last):
def _schedule_gspc_nchw(s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, output, last):
"""Schedule GSPC"""
ic_bn, oc_bn, reg_n, unroll_kw = (
cfg["tile_ic"].size[-1],
Expand Down

0 comments on commit 955844b

Please sign in to comment.