Skip to content

Commit

Permalink
[Autoscheduler][Sparse] Add sparse dense end to end model tuning supp…
Browse files Browse the repository at this point in the history
…ort for x86/arm cpu & Some bug fix (apache#7635)

* Add sparse dense end to end model tuning support

* Add sparse tuning for arm network

* Bug fix for tflite frontend dense with layout rewrite

* Move the random_bsr_matrix to sparse.utils
  • Loading branch information
jcf94 authored and Trevor Morris committed May 6, 2021
1 parent 3b0a912 commit a48cb45
Show file tree
Hide file tree
Showing 10 changed files with 243 additions and 47 deletions.
6 changes: 3 additions & 3 deletions python/tvm/auto_scheduler/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,7 +775,7 @@ def register(myf):
return register


def _prepare_input_map(args):
def prepare_input_map(args):
"""This function deals with special task inputs. Map the input Tensor of a TVM subgraph
to a specific buffer name in the global buffer map.
Expand Down Expand Up @@ -861,7 +861,7 @@ def _timed_eval_func(
random_fill = tvm.get_global_func("tvm.contrib.random.random_fill", True)
assert random_fill, "Please make sure USE_RANDOM is ON in the config.cmake"

tensor_input_map = _prepare_input_map(build_res.args) if task_input_names else {}
tensor_input_map = prepare_input_map(build_res.args) if task_input_names else {}
args = []
task_inputs_count = 0
for arg in build_res.args:
Expand Down Expand Up @@ -1076,7 +1076,7 @@ def _timed_rpc_run(
random_fill
), "Please make sure USE_RANDOM is ON in the config.cmake on the remote devices"

tensor_input_map = _prepare_input_map(build_res.args) if task_input_names else {}
tensor_input_map = prepare_input_map(build_res.args) if task_input_names else {}
args = []
task_inputs_count = 0
for arg in build_res.args:
Expand Down
30 changes: 28 additions & 2 deletions python/tvm/auto_scheduler/relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,12 @@ def extract_tasks(
# When auto scheduler is used in end to end network, try to apply layout rewrite
# to improve the overall performance
layout_rewrite_option=LayoutRewriteOption.get_target_default(target, True),
task_inputs=(
env.wkl_key_to_input_names[wkl_key]
if wkl_key in env.wkl_key_to_input_names
else None
),
task_inputs_save_to_file=True,
)
)
weights.append(weight)
Expand All @@ -166,6 +172,7 @@ def __init__(self, tracing_mode):
self.tracing_mode = tracing_mode
self.relay_disable_build_cache = "false"
self.wkl_key_to_weight = {}
self.wkl_key_to_input_names = {}

def __enter__(self):
TracingEnvironment.current = self
Expand All @@ -175,17 +182,30 @@ def __exit__(self, exc_type, exc_val, exc_tb):
TracingEnvironment.current = None

def add_workload_key(self, workload_key):
"""Add the workload key of a search task
"""Add the workload key of a search task.
Parameters
----------
workload_key: str
The workload key of a task
The workload key of a task.
"""
if workload_key not in self.wkl_key_to_weight:
self.wkl_key_to_weight[workload_key] = 0
self.wkl_key_to_weight[workload_key] += 1

def add_workload_input_names(self, workload_key, input_names):
"""Add special task inputs to this workload.
Parameters
----------
workload_key : str
The workload key of a task.
input_names : List[str]
A list of input names.
"""
self.wkl_key_to_input_names[workload_key] = input_names


@tvm._ffi.register_func("auto_scheduler.enter_layout_rewrite")
def enter_layout_rewrite():
Expand Down Expand Up @@ -277,6 +297,9 @@ def auto_schedule_topi(func_name, outs):
None in the tracing mode so that the fallback topi schedule will be used.
"""
# pylint: disable=import-outside-toplevel
from tvm.auto_scheduler.measure import (
prepare_input_map,
) # lazily import to avoid recursive dependency

io_tensors, has_layout_free, has_complex_op = traverse_to_get_io_tensors(outs)
if not io_tensors: # The compute includes dynamic shapes which are not supported yet.
Expand Down Expand Up @@ -308,6 +331,9 @@ def auto_schedule_topi(func_name, outs):
# in the task extraction mode
if has_complex_op or env.tracing_mode == TracingMode.EXTRACT_TASK:
env.add_workload_key(key)
input_map = prepare_input_map(io_tensors)
if input_map:
env.add_workload_input_names(key, list(input_map.values()))
elif env.tracing_mode == TracingMode.PREPARE_LAYOUT_REWRITE:
# in prepare_layout_rewrite mode
if (
Expand Down
9 changes: 7 additions & 2 deletions python/tvm/auto_scheduler/search_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,13 +299,18 @@ def get_task_input_buffer(workload_key, input_name):
TASK_INPUT_BUFFER_TABLE[workload_key] = {}
input_table = TASK_INPUT_BUFFER_TABLE[workload_key]

if input_name not in input_table.keys():
if input_name not in input_table:
# Try to load buffer data from local file
tensor_from_file = _try_load_buffer_from_file(input_name)
if tensor_from_file:
input_table[input_name] = tensor_from_file

if input_name in input_table.keys():
# Then check for the default table, the input names extracted from a relay model will be
# stored here for we're not able to get the workload_key at that time
if input_name not in input_table:
input_table = TASK_INPUT_BUFFER_TABLE["default"]

if input_name in input_table:
return input_table[input_name]

raise ValueError(
Expand Down
23 changes: 23 additions & 0 deletions python/tvm/relay/analysis/sparse_dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ def process_params(expr, params, block_size, sparsity_threshold):
ret : Namedtuple[weight_name: Array[String], weight_shape: Array[Array[IntImm]]]
return names of qualified dense weight and the shape in BSR format
"""

# pylint: disable=import-outside-toplevel
from tvm.auto_scheduler.search_task import (
register_task_input_buffer,
) # lazily import to avoid recursive dependency

memo = SparseAnalysisResult(weight_name=[], weight_shape=[])
weight_names = _search_dense_op_weight(expr)
for name in weight_names:
Expand All @@ -92,6 +98,23 @@ def process_params(expr, params, block_size, sparsity_threshold):
params[name + ".data"] = tvm.nd.array(sparse_weight.data)
params[name + ".indices"] = tvm.nd.array(sparse_weight.indices)
params[name + ".indptr"] = tvm.nd.array(sparse_weight.indptr)

prefix = "sparse_dense_bsr_%d_%d_%d_%d_%.2f_" % (
w_np.shape[0],
w_np.shape[1],
block_size[0],
block_size[1],
1 - sparsity,
)
register_task_input_buffer(
"default", prefix + "W_data", tvm.runtime.ndarray.array(sparse_weight.data)
)
register_task_input_buffer(
"default", prefix + "W_indices", tvm.runtime.ndarray.array(sparse_weight.indices)
)
register_task_input_buffer(
"default", prefix + "W_indptr", tvm.runtime.ndarray.array(sparse_weight.indptr)
)
ret = SparseAnalysisResult(
weight_name=tvm.runtime.convert(memo.weight_name),
weight_shape=tvm.runtime.convert(memo.weight_shape),
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -1872,7 +1872,7 @@ def convert_fully_connected(self, op):
out_dtype="int32",
)
else:
out = _op.nn.dense(in_expr, weight_expr)
out = _op.nn.dense(in_expr, weight_expr, units=weight_shape[0])

# if we have bias
if len(input_tensors) == 3:
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/topi/nn/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ def _process_inputs(input_tensors, m, n, prefix_init):
density *= i
density /= k * n
density = density.value
sparse_prefix = "%s_%d_%d_%d_%d_%d_%.2f_" % (prefix_init, m, n, k, bs_r, bs_c, density)
sparse_prefix = "%s_%d_%d_%d_%d_%.2f_" % (prefix_init, n, k, bs_r, bs_c, density)

visited = set()

Expand Down
126 changes: 126 additions & 0 deletions python/tvm/topi/sparse/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Some utils for Sparse operation."""
import tvm
from tvm import relay


def random_bsr_matrix(m, n, bs_r, bs_c, density, dtype):
"""Generate a random sparse matrix in bsr format.
Returns
-------
scipy.sparse.bsr_matrix
"""
# pylint: disable=import-outside-toplevel
import numpy as np
import itertools
import scipy.sparse as sp

y = np.zeros((m, n), dtype=dtype)
assert m % bs_r == 0
assert n % bs_c == 0
nnz = int(density * m * n)
num_blocks = int(nnz / (bs_r * bs_c)) + 1
candidate_blocks = np.asarray(list(itertools.product(range(0, m, bs_r), range(0, n, bs_c))))
assert candidate_blocks.shape[0] == m // bs_r * n // bs_c
chosen_blocks = candidate_blocks[
np.random.choice(candidate_blocks.shape[0], size=num_blocks, replace=False)
]
# pylint: disable=invalid-name
for (r, c) in chosen_blocks:
y[r : r + bs_r, c : c + bs_c] = np.random.randn(bs_r, bs_c)
s = sp.bsr_matrix(y, blocksize=(bs_r, bs_c))
assert s.data.shape == (num_blocks, bs_r, bs_c)
assert s.indices.shape == (num_blocks,)
assert s.indptr.shape == (m // bs_r + 1,)
return s


def random_sparse_dense_params(func, params, bs_r, bs_c, density):
"""Replace the dense parameters with random sparse parameters. Mainly used for testing.
Parameters
----------
func : tvm.relay.Expr
Expr will be optimized to sparse operation.
params : Dict[Srting, tvm.nd.array]
Parameters of the Expr.
bs_r : int
The row of BSR matrix block.
bs_c : int
The column of BSR matrix block.
density : float
The density of the random sparse parameters.
Returns
-------
Dict[Srting, tvm.nd.array]
The generated random parameters.
"""

def deepcopy(param_dic):
ret = {}
for k, v in param_dic.items():
ret[k] = tvm.nd.array(v.asnumpy())
return ret

new_params = deepcopy(params)
dense_weight_names = relay.analysis.sparse_dense._search_dense_op_weight(func)
for item in dense_weight_names:
name = str(item)
shape = new_params[name].shape
if shape[0] % bs_r == 0 and shape[1] % bs_c == 0:
new_w = random_bsr_matrix(shape[0], shape[1], bs_r, bs_c, density, "float32").todense()
new_params[name] = tvm.nd.array(new_w)
return new_params


def convert_model_dense_to_sparse(mod, params, random_params=False, bs_r=1, bs_c=1, sparsity=0.85):
"""Convert a dense model to sparse model.
Parameters
----------
mod : tvm.Module
The dense model.
params : Dict[Srting, tvm.nd.array]
Parameters of the dense model.
random_params : Bool = False
True to replace the parameters of the dense model with some random sparse tensors.
This is mainly used for testing.
bs_r : int
The row of BSR matrix block.
bs_c : int
The column of BSR matrix block.
sparsity : float
The sparsity of the random sparse parameters.
Returns
-------
tvm.Module
The updated sparse model.
Dict[Srting, tvm.nd.array]
The updated parameters.
"""
mod, params = ddo.simplify_fc_transpose.convert(mod["main"], params)
if random_params:
# Manually replace the parameters of dense model to sparse tensors
params = random_sparse_dense_params(mod, params, bs_r=bs_r, bs_c=bs_c, density=1 - sparsity)
# Currently we only support to conver dense matmul to sparse dense matmul
mod, params = ddo.bsr_dense.convert(mod, params, (bs_r, bs_c), sparsity_threshold=0.8)

return tvm.IRModule.from_expr(mod), params
Loading

0 comments on commit a48cb45

Please sign in to comment.