Skip to content

Commit

Permalink
[Blocksparse] Pipeline for lowering dense model to sparse-dense (apac…
Browse files Browse the repository at this point in the history
  • Loading branch information
antinucleon authored and dhruvaray committed Apr 28, 2020
1 parent f7ca70d commit a52ab12
Show file tree
Hide file tree
Showing 14 changed files with 798 additions and 0 deletions.
1 change: 1 addition & 0 deletions python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def get_package_data_files():
zip_safe=False,
install_requires=[
'numpy',
'scipy',
'decorator',
'attrs',
'psutil',
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from . import frontend
from . import backend
from . import quantize
from . import data_dep_optimization

# Dialects
from . import qnn
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,4 @@

# Feature
from . import feature
from . import sparse_dense
18 changes: 18 additions & 0 deletions python/tvm/relay/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,3 +333,21 @@ def extract_fused_functions(mod):
for hash_, func in ret_mod.functions.items():
ret[hash_] = func
return ret


def search_fc_transpose(expr):
"""Search fc weight name in the patten: y = nn.dense(x, transpose(w, [1, 0]))
This function is used in the data_dep_optimization.simplify_fc_transpose method
Parameters
----------
expr : tvm.relay.Expr
Returns
-------
ret : Array[String]
Array of weight variable name in pattern y = nn.dense(x, transpose(w, [1, 0]))
"""
ret = _ffi_api.search_fc_transpose(expr)
return ret
93 changes: 93 additions & 0 deletions python/tvm/relay/analysis/sparse_dense.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return
# pylint: disable=unidiomatic-typecheck
"""
This file contains helper functions for convert dense model
to block sparse model
"""
from collections import namedtuple
import numpy as np
import scipy.sparse as sp
import tvm
from . import _ffi_api


SparseAnalysisResult = namedtuple("SparseAnalysisResult", [
"weight_name",
"weight_shape",
])

def _search_dense_op_weight(expr):
"""Search name of weight in all ```nn.dense``` operator
This is a helpful function to determine which param need
to be converted to sparse
Parameters
----------
expr : relay.Expr
Expr will be searched
Returns
-------
ret : Array[String]
name of weight in all ``nn.dense``` operator
"""
return _ffi_api.search_dense_op_weight(expr)


def process_params(expr, params, block_size, sparsity_threshold):
"""[summary]
Parameters
----------
expr : Relay.Expr
Expr of the network
params : Dict[String, tvm.nd.array]
parameters of the network
block_size : Tuple(int, int)
Blocksize in BSR matrix
sparsity_threshold : float
Minimal sparsity requirement for converting to sparse operation
Returns
-------
ret : Namedtuple[weight_name: Array[String], weight_shape: Array[Array[IntImm]]]
return names of qualified dense weight and the shape in BSR format
"""
memo = SparseAnalysisResult(weight_name=[], weight_shape=[])
weight_names = _search_dense_op_weight(expr)
for name in weight_names:
name = str(name)
w_np = params[name].asnumpy()
sparsity = 1.0 - (np.count_nonzero(w_np) / w_np.size)
if sparsity >= sparsity_threshold:
sparse_weight = sp.bsr_matrix(w_np, blocksize=block_size)
# remove dense weight
del params[name]
memo.weight_name.append(name)
memo.weight_shape.append(list(sparse_weight.data.shape) +
list(sparse_weight.indices.shape) +
list(sparse_weight.indptr.shape))
params[name + ".data"] = tvm.nd.array(sparse_weight.data)
params[name + ".indices"] = tvm.nd.array(sparse_weight.indices)
params[name + ".indptr"] = tvm.nd.array(sparse_weight.indptr)
ret = SparseAnalysisResult(
weight_name=tvm.runtime.convert(memo.weight_name),
weight_shape=tvm.runtime.convert(memo.weight_shape)
)
return ret
21 changes: 21 additions & 0 deletions python/tvm/relay/data_dep_optimization/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#pylint: disable=unused-argument, not-context-manager
"""Optimizations involves changing of paramters"""

from . import bsr_dense
from . import simplify_fc_transpose
57 changes: 57 additions & 0 deletions python/tvm/relay/data_dep_optimization/bsr_dense.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#pylint: disable=unused-argument, not-context-manager
"""Automatic convert model from dense to block sparse"""

from tvm import relay
from tvm.relay.analysis.sparse_dense import process_params

from .utils import _run_opt_pass

def convert(func, params, blocksize, sparsity_threshold):
"""Convert a dense func and according parameters to block sparse
Parameters
----------
func : relay.Expr
Expr will be optimized to sparse operation
params : Dict[Srting, tvm.nd.array]
Parameters of the Expr
blocksize : Tuple(int, int)
Blocksize for BSR matrix
sparsity_threshold : float
Minimal sparsity requirement for converting.
If weight sparsity is lower than this threshold,
the dense operation will be kept.
Returns
-------
new_func: relay.Expr
Mutated Expr with sparse operations
params: Dict[Srting, tvm.nd.array]
New params with BSR matrix for mutated Expr
"""
weight_info = process_params(func, params, blocksize, sparsity_threshold)
new_func = _run_opt_pass(
func,
relay.transform.DenseToSparse(
weight_info.weight_name,
weight_info.weight_shape
)
)
return new_func, params
60 changes: 60 additions & 0 deletions python/tvm/relay/data_dep_optimization/simplify_fc_transpose.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#pylint: disable=unused-argument, not-context-manager
"""Automatic optimize fc tranpose"""
import numpy as np

import tvm
from tvm import relay
from tvm.relay.analysis import search_fc_transpose

from .utils import _run_opt_pass


def convert(func, params):
"""convert all ```y = nn.dense(x, transpose(w, [1, 0]))``` to
```y = nn.dense(x, wt)```
Parameters
----------
func : relay.Expr
Expr will be optimized
params : Dict[String, tvm.nd.array]
Parameters of Expr
Returns
-------
new_func : relay.Expr
Mutated Expr from ```y = nn.dense(x, transpose(w, [1, 0]))``` to
```y = nn.dense(x, wt)```
params: Dict[String, tvm.nd.array]
Parameters of mutated Expr, with weights pre-transposed
"""
weight_info = search_fc_transpose(func)
for item in weight_info:
name = str(item)
w_np = params[name].asnumpy()
new_w = np.transpose(w_np, axes=[1, 0])
params[name + ".T"] = tvm.nd.array(new_w)
del params[name]
new_func = _run_opt_pass(
func,
relay.transform.SimplifyFCTranspose(
weight_info,
)
)
return new_func, params
40 changes: 40 additions & 0 deletions python/tvm/relay/data_dep_optimization/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#pylint: disable=unused-argument, not-context-manager
"""Utils functions for optimizations"""

import tvm

def _run_opt_pass(expr, opt_pass):
"""Helper function to run pass
Parameters
----------
expr : relay.Expr
Expr will be optimized
opt_pass : relay.Pass
Optimization pass
Returns
-------
ret: relay.Expr
Optimized Expr by running opt_pass
"""
assert isinstance(opt_pass, tvm.transform.Pass)
mod = tvm.IRModule.from_expr(expr)
mod = opt_pass(mod)
return mod["main"]
40 changes: 40 additions & 0 deletions python/tvm/relay/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,3 +839,43 @@ def visit_var(self, var):
return relay.Var(var.name_hint, relay.TensorType(new_shape, ty.dtype))
return var
return ChangeBatchMutator().visit(func)


def DenseToSparse(weight_name, weight_shape):
"""
Rewrite qualified ```nn.dense operation``` to ```nn.sparse_dense```
This pass is used in ```data_dep_optimization.bsr_dense```
Parameters of this pass is generated by ```analysis.sparse_dense.process_params```
Parameters
----------
weight_name: Array[String]
Names of weights which qualified sparse contrains
weight_shape: Array[Array[IntImm]]
Weights shape in BSR format.
Returns
-------
ret : tvm.transform.Pass
The registered DenseToSparse pass.
"""
return _ffi_api.DenseToSparse(weight_name, weight_shape)

def SimplifyFCTranspose(target_weight_name):
"""
Rewrite ```y = nn.dense(x, transpose(w, [1, 0]))``` to ```y = nn.dense(x, wt)```
This pass is used in ```data_dep_optimization.simplify_fc_transpose```
Parameters
----------
weight_name: Array[String]
Names of weights which qualified ```y = nn.dense(x, transpose(w, [1, 0]))```
This parameter is generated by ```analysis.search_fc_transpose``` function
Returns
-------
ret : tvm.transform.Pass
The registered SimplifyFCTranspose pass.
"""
return _ffi_api.SimplifyFCTranspose(target_weight_name)
Loading

0 comments on commit a52ab12

Please sign in to comment.