Vta relay bitpack (apache#34)
* Add bitpacking

* Fix issue in Python wrapper

* Misc fixes

* Fix some bugs in expr.py
jroesch authored and tmoreau89 committed Mar 21, 2019
1 parent 89c831e commit d16b039
Showing 3 changed files with 158 additions and 0 deletions.
137 changes: 137 additions & 0 deletions python/tvm/relay/expr.py
@@ -401,6 +401,143 @@ def realize(self):
        return _expr.TempExprRealize(self)


class ExprFunctor(object):
    """
    An abstract visitor defined over Expr.
    Defines the default dispatch over expressions, and
    implements memoization.
    """
    def __init__(self):
        self.memo_map = {}

    # pylint: disable=no-else-return
    def visit(self, expr):
        """Apply the visitor to an expression."""
        from .op.op import Op
        found = self.memo_map.get(expr)
        if found is not None:
            return found

        if isinstance(expr, Function):
            res = self.visit_function(expr)
        elif isinstance(expr, Call):
            res = self.visit_call(expr)
        elif isinstance(expr, Let):
            res = self.visit_let(expr)
        elif isinstance(expr, Var):
            res = self.visit_var(expr)
        elif isinstance(expr, GlobalVar):
            res = self.visit_global_var(expr)
        elif isinstance(expr, If):
            res = self.visit_if(expr)
        elif isinstance(expr, Tuple):
            res = self.visit_tuple(expr)
        elif isinstance(expr, TupleGetItem):
            res = self.visit_tuple_getitem(expr)
        elif isinstance(expr, Constant):
            res = self.visit_constant(expr)
        elif isinstance(expr, Op):
            res = self.visit_op(expr)
        else:
            raise Exception("unhandled case: {0}".format(type(expr)))

        self.memo_map[expr] = res
        return res

    def visit_function(self, _):
        raise NotImplementedError()

    def visit_let(self, _):
        raise NotImplementedError()

    def visit_call(self, _):
        raise NotImplementedError()

    def visit_var(self, _):
        raise NotImplementedError()

    def visit_type(self, typ):
        return typ

    def visit_if(self, _):
        raise NotImplementedError()

    def visit_tuple(self, _):
        raise NotImplementedError()

    def visit_tuple_getitem(self, _):
        raise NotImplementedError()

    def visit_constant(self, _):
        raise NotImplementedError()

    def visit_global_var(self, _):
        raise NotImplementedError()

    def visit_op(self, _):
        raise NotImplementedError()


class ExprMutator(ExprFunctor):
    """
    A functional visitor over Expr.
    The default behavior recursively traverses the AST
    and reconstructs the AST.
    """
    def visit_function(self, fn):
        new_body = self.visit(fn.body)
        return Function(
            list(fn.params),
            new_body,
            fn.ret_type,
            fn.type_params)

    def visit_let(self, let):
        new_var = self.visit(let.var)
        new_val = self.visit(let.value)
        new_body = self.visit(let.body)
        return Let(new_var, new_val, new_body)

    def visit_call(self, call):
        new_fn = self.visit(call.op)
        new_args = [self.visit(arg) for arg in call.args]
        return Call(new_fn, new_args, call.attrs)

    def visit_var(self, rvar):
        return rvar

    def visit_if(self, ite):
        return If(
            self.visit(ite.cond),
            self.visit(ite.true_branch),
            self.visit(ite.false_branch))

    def visit_tuple(self, tup):
        return Tuple([self.visit(field) for field in tup.fields])

    def visit_tuple_getitem(self, op):
        tuple_value = self.visit(op.tuple_value)
        if not tuple_value.same_as(op.tuple_value):
            return TupleGetItem(tuple_value, op.index)
        return op

    def visit_global_var(self, gvar):
        return gvar

    def visit_constant(self, rconst):
        return rconst

    def visit_op(self, op):
        return op

class TupleWrapper(object):
"""TupleWrapper.
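For context, a minimal sketch (not part of the commit) of how the new visitor classes could be used: subclass ExprMutator, override only the cases of interest, and call visit on the root expression; memo_map ensures shared subexpressions are rewritten only once. The class name DoubleConstants and the sample expression below are illustrative, and the import path reflects where this commit places the classes.

from tvm import relay
from tvm.relay.expr import ExprMutator  # location as of this commit

class DoubleConstants(ExprMutator):
    """Illustrative rewrite: multiply every constant leaf by two."""
    def visit_constant(self, rconst):
        # rconst.data is an NDArray; build a new Constant from the doubled value.
        return relay.const(rconst.data.asnumpy() * 2)

x = relay.var("x", shape=(4,), dtype="float32")
expr = relay.add(x, relay.const(1.0))
rewritten = DoubleConstants().visit(expr)  # x + 1.0 becomes x + 2.0
print(rewritten)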
20 changes: 20 additions & 0 deletions python/tvm/relay/op/transform.py
@@ -616,6 +616,7 @@ def layout_transform(data, src_layout, dst_layout):
    """
    return _make.layout_transform(data, src_layout, dst_layout)


def reverse_reshape(data, newshape):
    """Reshapes the input array where the special values are inferred from
@@ -629,10 +630,15 @@ def reverse_reshape(data, newshape):
    - data.shape = (10,5,4), newshape = (-1,0), reshape results in (40,5)
    - data.shape = (10,5,4), newshape = (-1,0), reverse_reshape results in (40,5)

    Parameters
    ----------
    data : relay.Expr
        The input data to the operator.

    newshape : Union[int, Tuple[int], List[int]]
@@ -646,3 +652,17 @@ def reverse_reshape(data, newshape):
    if isinstance(newshape, int):
        newshape = [newshape]
    return _make._contrib_reverse_reshape(data, list(newshape))


def bitpack(data, lanes):
    """Bitpack the innermost dimension of the tensor.

    Parameters
    ----------
    data : relay.Expr
        The source tensor to be packed.

    lanes : int
        The lanes to pack by.

    Returns
    -------
    ret : relay.Expr
        The transformed tensor.
    """
    return _make.bitpack(data, lanes)
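A minimal usage sketch for the new operator, assuming the C++ side registers the backing _make.bitpack op in this tree; the input shape and lanes value below are illustrative, not taken from the commit.

from tvm import relay
from tvm.relay.op.transform import bitpack  # added by this commit

x = relay.var("x", shape=(1, 64, 16, 16), dtype="int8")
y = bitpack(x, lanes=8)  # pack the innermost dimension by 8 lanes
func = relay.Function([x], y)
print(func)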
1 change: 1 addition & 0 deletions vta/python/vta/top/__init__.py
@@ -4,3 +4,4 @@
from . import op
from . import relay_op
from . import bitpack
from . import relay_bitpack
