Skip to content

Commit

Permalink
Fix VTA to fit the new IR Pattern
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Apr 20, 2020
1 parent 016d03a commit fdee3d1
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions vta/python/vta/ir_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,7 @@ def inject_conv2d_transpose_skip(stmt_in):
selects = []

def _find_basics(op):
if isinstance(op, tvm.tir.Call):
if isinstance(op, tvm.tir.BufferLoad):
calls.append(op)
elif isinstance(op, tvm.tir.Select):
selects.append(op)
Expand All @@ -664,7 +664,7 @@ def _do_fold(op):
body = op.body.body
while isinstance(body, tvm.tir.IfThenElse):
body = body.then_case
args = body.args
args = body.indices
res_tensor = body.func.output(0)
tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1, 0, 1, 0, env.BLOCK_OUT)
inner = tvm.tir.AttrStmt(
Expand Down Expand Up @@ -696,19 +696,19 @@ def _do_fold(op):
0, 0, 0))
inner = irb.get()

args = conv_call.args
args = conv_call.indices
tpl = (args[0], 1, args[1], 1, args[2], 1, args[3],
1, 0, 1, 0, env.BLOCK_OUT)
inner = tvm.tir.AttrStmt(
[dout, res_tensor], 'buffer_bind_scope',
tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
args = kernel_call.args
args = kernel_call.indices
tpl = (args[0], 1, args[1], 1, args[2], 1, args[3],
1, 0, env.BLOCK_OUT, 0, env.BLOCK_IN)
inner = tvm.tir.AttrStmt(
[dwgt, kernel_tensor], 'buffer_bind_scope',
tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
args = data_call.args
args = data_call.indices
tpl = (args[0], 1, args[1], 1, args[2], 1, args[3],
1, 0, 1, 0, env.BLOCK_IN)
inner = tvm.tir.AttrStmt(
Expand Down

0 comments on commit fdee3d1

Please sign in to comment.