[PASS] PrecomputePrune, add testcase (apache#14)
* [PASS] PrecomputePrune, add testcase

* update comment
tqchen authored and sergei-mironov committed Aug 8, 2018
1 parent 4799a90 commit f691494
Showing 14 changed files with 312 additions and 50 deletions.
3 changes: 2 additions & 1 deletion nnvm/Makefile
@@ -30,7 +30,7 @@ ifneq ($(ADD_CFLAGS), NONE)
 endif
 
 ifneq ($(ADD_LDFLAGS), NONE)
-    LFFLAGS += $(ADD_LDFLAGS)
+    LDFLAGS += $(ADD_LDFLAGS)
 endif
 
 # plugin
@@ -46,6 +46,7 @@ ifeq ($(UNAME_S), Darwin)
     SHARED_LIBRARY_SUFFIX := dylib
     WHOLE_ARCH= -all_load
     NO_WHOLE_ARCH= -noall_load
+    LDFLAGS += -undefined dynamic_lookup
 else
     SHARED_LIBRARY_SUFFIX := so
     WHOLE_ARCH= --whole-archive
2 changes: 1 addition & 1 deletion nnvm/python/nnvm/compiler/__init__.py
@@ -4,7 +4,7 @@
 import tvm
 
 from . import build_module
-from . build_module import build
+from . build_module import build, precompute_prune, _run_graph
 
 from .. import symbol as _symbol
 from .. import graph as _graph
85 changes: 78 additions & 7 deletions nnvm/python/nnvm/compiler/build_module.py
@@ -3,8 +3,9 @@
 from __future__ import absolute_import as _abs
 
 import tvm
-from . import graph_attr
+from . import graph_attr, graph_pass
 from .. import graph as _graph
+from .. import runtime
 
 @tvm.register_func("nnvm.compiler.lower")
 def _lower(sch, inputs, func_name):
@@ -18,9 +19,6 @@ def _build(funcs, target):
     return tvm.build(funcs, target=target)
 
 
-_move_module = tvm.get_global_func("nnvm.compiler._move_module")
-
-
 def optimize(graph):
     """Perform graph optimization
@@ -70,10 +68,83 @@ def build(graph, target, shape, dtype="float32"):
         raise TypeError("require shape to be dict")
 
     graph = graph if isinstance(graph, _graph.Graph) else _graph.create(graph)
-    graph = graph_attr.set_shape(graph, shape)
-    graph = graph_attr.set_dtype(graph, dtype)
+    graph = graph_attr.set_shape_inputs(graph, shape)
+    graph = graph_attr.set_dtype_inputs(graph, dtype)
     graph._set_json_attr("target", target, "str")
     graph = graph.apply("InferShape").apply("InferType")
     graph = graph.apply("GraphFusePartition").apply("GraphFuse")
-    libmod = _move_module(graph)
+    libmod = graph_attr._move_out_module(graph, "module")
     return graph, libmod


+def _run_graph(graph, params):
+    """Helper utility to build and run a graph and return its outputs (CPU only).
+
+    Parameters
+    ----------
+    graph : Graph
+        The graph to be executed.
+
+    params : dict of str to ndarray
+        The parameter dictionary.
+
+    Returns
+    -------
+    out_list : list of tvm.NDArray
+        The outputs, in output order.
+    """
+    graph = graph if isinstance(graph, _graph.Graph) else _graph.create(graph)
+    shape = {k: v.shape for k, v in params.items()}
+    dtype = {k: v.dtype for k, v in params.items()}
+    target = "llvm"
+    ctx = tvm.cpu(0)
+    _, oshape = graph_pass.infer_shape(graph, **shape)
+    _, odtype = graph_pass.infer_dtype(graph, **dtype)
+    graph, libmod = build(graph, target, shape, dtype)
+    m = runtime.create(graph, libmod, ctx)
+    set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
+    for k, v in params.items():
+        set_input(k, tvm.nd.array(v))
+    run()
+    out_data = []
+    for i, (oshape_i, odtype_i) in enumerate(zip(oshape, odtype)):
+        arr = tvm.nd.empty(oshape_i, odtype_i, ctx)
+        get_output(i, arr)
+        out_data.append(arr)
+    return out_data
+
+
+def precompute_prune(graph, params):
+    """Precompute the part of the graph that can be computed ahead of time.
+
+    This creates a new graph that contains only the ops whose results
+    depend on the runtime inputs, together with an updated param dict
+    that holds the precomputed intermediate results.
+
+    Parameters
+    ----------
+    graph : Graph
+        The input graph
+
+    params : dict of str -> tvm.NDArray
+        The parameter dictionary of the graph
+
+    Returns
+    -------
+    pruned_graph : Graph
+        The pruned graph
+
+    new_params : dict of str -> tvm.NDArray
+        The updated dictionary of parameters.
+    """
+    graph = graph if isinstance(graph, _graph.Graph) else _graph.create(graph)
+    graph._set_json_attr("param_name_list", list(params.keys()), "list_str")
+    graph = graph.apply("PrecomputePrune")
+    pre_graph = graph_attr._move_out_graph(graph, "precompute_graph")
+    if not pre_graph.symbol.list_output_names():
+        # Nothing could be precomputed; return graph and params unchanged.
+        return graph, params
+    out_names = pre_graph.json_attr("output_names")
+    out_arrs = _run_graph(pre_graph, params)
+    return graph, dict(zip(out_names, out_arrs))
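
Taken together, the two new entry points compose as follows. A minimal sketch, assuming a hypothetical two-variable graph with made-up names and shapes (not part of the commit):

    import numpy as np
    import tvm
    import nnvm.symbol as sym
    import nnvm.graph as nnvm_graph
    import nnvm.compiler

    # "x" is a runtime input; "w" is a parameter known at compile time.
    x = sym.Variable("x")
    w = sym.Variable("w")
    # exp(w) depends only on parameters, so PrecomputePrune can fold it away.
    z = sym.broadcast_add(x, sym.exp(w))

    params = {"w": tvm.nd.array(np.ones((2, 4), dtype="float32"))}
    graph, params = nnvm.compiler.precompute_prune(nnvm_graph.create(z), params)

    # Build the pruned graph; its remaining inputs are "x" plus the
    # precomputed parameter outputs returned above.
    shape = {k: v.shape for k, v in params.items()}
    shape["x"] = (2, 4)
    deploy_graph, libmod = nnvm.compiler.build(graph, "llvm", shape)

If the pruning succeeds, exp(w) is evaluated once on CPU through _run_graph and stored in the returned params, leaving only the input-dependent broadcast_add in the deployed graph.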
64 changes: 48 additions & 16 deletions nnvm/python/nnvm/compiler/graph_attr.py
@@ -1,8 +1,11 @@
 # pylint: disable=invalid-name
 """Utilities to access graph attributes"""
 from __future__ import absolute_import as _abs
 
-def set_shape(g, shape):
-    """Set the shape of graph nodes in the graph attribute.
+import tvm
+
+def set_shape_inputs(g, shape):
+    """Set the shape of input graph nodes in the graph attribute.
 
     Parameters
     ----------
@@ -17,20 +20,24 @@ def set_shape(g, shape):
     g : Graph
         The updated graph with updated shape.
     """
-    index = g.index
-    list_shape = [[]] * index.num_node_entries
-    for k, v in shape.items():
-        list_shape[index.entry_id(k)] = v
-    g._set_json_attr("shape", list_shape, 'list_shape')
+    list_shape = [
+        shape.get(name, ()) for name in g.index.input_names]
+    g._set_json_attr("shape_inputs", list_shape, 'list_shape')
     return g


-DTYPE_DICT = {
+DTYPE_TO_TCODE = {
     "default": -1,
     "float32": 0
 }
 
-def set_dtype(g, dtype):
-    """Set the dtype of graph nodes
+TCODE_TO_DTYPE = {
+    -1: None,
+    0: "float32"
+}
+
+def set_dtype_inputs(g, dtype):
+    """Set the dtype of input graph nodes
 
     Parameters
     ----------
@@ -45,12 +52,37 @@ def set_dtype(g, dtype):
     g : Graph
         The updated graph with updated dtype.
     """
-    index = g.index
     if isinstance(dtype, dict):
-        list_dtype = [-1] * index.num_node_entries
-        for k, v in dtype.items():
-            list_dtype[index.entry_id(k)] = DTYPE_DICT[v]
+        list_dtype = [
+            DTYPE_TO_TCODE[dtype.get(name, "default")]
+            for name in g.index.input_names]
     else:
-        list_dtype = [DTYPE_DICT[dtype]] * index.num_node_entries
-    g._set_json_attr("dtype", list_dtype, "list_int")
+        list_dtype = [DTYPE_TO_TCODE[dtype]] * len(g.index.input_names)
+    g._set_json_attr("dtype_inputs", list_dtype, "list_int")
     return g


+def set_layout_inputs(g, layout):
+    """Set the layout of input graph nodes
+
+    Parameters
+    ----------
+    g : Graph
+        The input graph
+
+    layout : dict of str to str or str
+        The input layout
+
+    Returns
+    -------
+    g : Graph
+        The updated graph with updated layout.
+    """
+    list_layout = [
+        layout.get(name, "default") for name in g.index.input_names]
+    g._set_json_attr("layout_inputs", list_layout, 'list_str')
+    return g
+
+
+_move_out_module = tvm.get_global_func("nnvm.graph_attr._move_module")
+_move_out_graph = tvm.get_global_func("nnvm.graph_attr._move_graph")
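
A short sketch of how the per-input setters are meant to be used ahead of the inference passes (the input name and shape here are hypothetical):

    import nnvm.symbol as sym
    import nnvm.graph as nnvm_graph
    from nnvm.compiler import graph_attr

    g = nnvm_graph.create(sym.exp(sym.Variable("x")))
    # One entry per input name; missing inputs default to () and "default".
    g = graph_attr.set_shape_inputs(g, {"x": (2, 4)})
    g = graph_attr.set_dtype_inputs(g, "float32")  # a plain str applies to all inputs
    g = g.apply("InferShape").apply("InferType")

Keying the attributes by input name rather than by node entry id keeps the setters independent of the graph's internal numbering, which is what lets the passes in graph_pass.py below stay small.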
55 changes: 55 additions & 0 deletions nnvm/python/nnvm/compiler/graph_pass.py
@@ -1,7 +1,62 @@
 # pylint: disable=invalid-name
 """Namespace of graph pass.
 Principle:
 - Graph in, graph out: always takes in a graph as the first argument and returns a graph.
 - Composable API: break a graph transformation pass into segments of small transformations.
 """
 from __future__ import absolute_import as _abs
 
+from . import graph_attr
+
+
+def infer_shape(graph, **shape):
+    """Infer the shapes of outputs given the shapes of inputs.
+
+    Parameters
+    ----------
+    graph : Graph
+        The graph to perform shape inference on
+
+    **shape : dict of str to tuple
+        The input shapes, keyed by input name.
+
+    Returns
+    -------
+    in_shape : list of tuple
+        Shape of inputs
+
+    out_shape : list of tuple
+        Shape of outputs
+    """
+    graph = graph_attr.set_shape_inputs(graph, shape)
+    graph = graph.apply("InferShape")
+    shape = graph.json_attr("shape")
+    index = graph.index
+    input_shape = [shape[index.entry_id(x)] for x in index.input_names]
+    output_shape = [shape[index.entry_id(x)] for x in index.output_entries]
+    return input_shape, output_shape


+def infer_dtype(graph, **dtype):
+    """Infer the dtypes of outputs given the dtypes of inputs.
+
+    Parameters
+    ----------
+    graph : Graph
+        The graph to perform type inference on
+
+    **dtype : dict of str to str
+        The input dtypes, keyed by input name.
+
+    Returns
+    -------
+    in_dtype : list of str
+        Dtype of inputs
+
+    out_dtype : list of str
+        Dtype of outputs
+    """
+    graph = graph_attr.set_dtype_inputs(graph, dtype)
+    graph = graph.apply("InferType")
+    dtype = graph.json_attr("dtype")
+    index = graph.index
+    input_dtype = [graph_attr.TCODE_TO_DTYPE[dtype[index.entry_id(x)]]
+                   for x in index.input_names]
+    output_dtype = [graph_attr.TCODE_TO_DTYPE[dtype[index.entry_id(x)]]
+                    for x in index.output_entries]
+    return input_dtype, output_dtype
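
A sketch of the composable pass API in use (graph, names, and shapes are hypothetical):

    import nnvm.symbol as sym
    import nnvm.graph as nnvm_graph
    from nnvm.compiler import graph_pass

    g = nnvm_graph.create(sym.exp(sym.Variable("x")))
    in_shapes, out_shapes = graph_pass.infer_shape(g, x=(2, 4))
    in_dtypes, out_dtypes = graph_pass.infer_dtype(g, x="float32")
    # exp preserves shape and dtype, so out_shapes should be [[2, 4]]
    # and out_dtypes should be ["float32"].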
6 changes: 6 additions & 0 deletions nnvm/python/nnvm/graph.py
@@ -24,6 +24,8 @@ def __init__(self, graph):
         self.nodes = jgraph["nodes"]
         self.entry_ptr = jgraph["node_row_ptr"]
         self._name2nodeid = {n["name"]: i for i, n in enumerate(self.nodes)}
+        self.input_names = graph.symbol.list_input_names()
+        self.output_entries = jgraph["heads"]

     @property
     def num_nodes(self):
@@ -66,6 +68,10 @@ def entry_id(self, key, value_index=0):
         index : int
             The entry index
         """
+        if isinstance(key, (list, tuple)):
+            if len(key) != 3:
+                raise ValueError("Expect entry index to be tuple of 3 elems")
+            key, value_index, _ = key
         idx = self.node_id(key) if isinstance(key, str) else key
         assert value_index < self.entry_ptr[idx + 1]
         return self.entry_ptr[idx] + value_index
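
With the two new index fields, entry_id can now resolve both plain input names and the 3-element entry records stored in the graph JSON. A small sketch, assuming a graph g with an input named "x":

    idx = g.index
    eid_in = idx.entry_id("x")                     # by name; value_index defaults to 0
    eid_out = idx.entry_id(idx.output_entries[0])  # a "heads" record: [node_id, index, version]

This is the form infer_shape and infer_dtype rely on when they look up index.output_entries.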
15 changes: 15 additions & 0 deletions nnvm/python/nnvm/top/attr_dict.py
@@ -68,6 +68,21 @@ def get_int(self, key):
         """
         return int(self[key])
 
+    def get_float(self, key):
+        """Get a float from the attr dict
+
+        Parameters
+        ----------
+        key : str
+            The attr key
+
+        Returns
+        -------
+        value : float
+            The result value
+        """
+        return float(self[key])
+
     def get_bool(self, key):
         """Get bool from attr dict
17 changes: 17 additions & 0 deletions nnvm/python/nnvm/top/tensor.py
@@ -17,6 +17,17 @@ def _schedule_broadcast(_, outs, target):
     tvm.schedule.AutoInlineInjective(s)
     return s
 
+def _compute_binary_scalar(f):
+    """Auxiliary wrapper: lift a binary function f(elem, scalar) into an
+    elementwise compute that reads its scalar operand from the op attrs."""
+    @tvm.tag_scope("ewise")
+    def _compute(attrs, x):
+        x = x[0]
+        scalar = attrs.get_float("scalar")
+        scalar = tvm.const(scalar, x.dtype)
+        return tvm.compute(x.shape, lambda *i: f(x(*i), scalar))
+    return _compute


 _fschedule_broadcast = tvm.convert(_schedule_broadcast)
 
 # exp
@@ -25,6 +36,12 @@ def _schedule_broadcast(_, outs, target):
 reg.register_pattern("exp", OpPattern.ELEM_WISE)
 reg.register_schedule("exp", _fschedule_broadcast)
 
+# add scalar
+reg.register_compute("__add_scalar__",
+                     _compute_binary_scalar(lambda x, y: x + y))
+reg.register_pattern("__add_scalar__", OpPattern.ELEM_WISE)
+reg.register_schedule("__add_scalar__", _fschedule_broadcast)
+
 # broadcast_add
 reg.register_compute("broadcast_add",
                      lambda _, x: topi.broadcast_add(x[0], x[1]))
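
For context, __add_scalar__ is the op name the NNVM symbol frontend emits when a Python scalar is added to a symbol, so the registration above is what makes expressions like the following compile end to end (a sketch; the shape is arbitrary):

    import nnvm.symbol as sym
    import nnvm.graph as nnvm_graph
    import nnvm.compiler

    x = sym.Variable("x")
    y = x + 1.0  # emitted as __add_scalar__ with attrs {"scalar": "1.0"}
    graph, libmod = nnvm.compiler.build(nnvm_graph.create(y), "llvm", {"x": (2, 4)})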
14 changes: 14 additions & 0 deletions nnvm/src/compiler/packed_func_ext.cc
@@ -104,5 +104,19 @@ TVM_REGISTER_GLOBAL("nnvm._register_pattern")
     Op& op = ::dmlc::Registry<nnvm::Op>::Get()->__REGISTER_OR_GET__(args[0]);
     op.set_attr<TOpPattern>("TOpPattern", args[1].operator int(), args[2]);
   });

TVM_REGISTER_GLOBAL("nnvm.graph_attr._move_module")
.set_body([](TVMArgs args, TVMRetValue *rv) {
const nnvm::Graph& g = args[0].AsExtension<Graph>();
*rv = const_cast<nnvm::Graph*>(&g)->
MoveCopyAttr<tvm::runtime::Module>(args[1]);
});

TVM_REGISTER_GLOBAL("nnvm.graph_attr._move_graph")
.set_body([](TVMArgs args, TVMRetValue *rv) {
const nnvm::Graph& g = args[0].AsExtension<Graph>();
*rv = const_cast<nnvm::Graph*>(&g)->
MoveCopyAttr<nnvm::Graph>(args[1]);
});
} // namespace compiler
} // namespace nnvm
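
On the Python side these packed functions are looked up by name, as graph_attr.py does above. Unlike the removed nnvm.compiler._move_module (see graph_fuse.cc below), the attribute key is now passed as an argument rather than hard-coded to "module"; a sketch of the retrieval:

    import tvm

    move_out_module = tvm.get_global_func("nnvm.graph_attr._move_module")
    move_out_graph = tvm.get_global_func("nnvm.graph_attr._move_graph")
    # e.g. libmod = move_out_module(graph, "module") after GraphFuse,
    # or pre_graph = move_out_graph(graph, "precompute_graph") after PrecomputePrune.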
8 changes: 0 additions & 8 deletions nnvm/src/compiler/pass/graph_fuse.cc
@@ -381,13 +381,5 @@ nnvm::Graph GraphFuse(nnvm::Graph g) {

 NNVM_REGISTER_PASS(GraphFuse)
 .set_body(GraphFuse);
-
-
-TVM_REGISTER_GLOBAL("nnvm.compiler._move_module")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
-    const nnvm::Graph& g = args[0].AsExtension<Graph>();
-    *rv = const_cast<nnvm::Graph*>(&g)->
-        MoveCopyAttr<tvm::runtime::Module>("module");
-  });
} // namespace compiler
} // namespace nnvm