Skip to content

Commit

Permalink
[TIR][TVMScript] specialize (apache#8354)
Browse files Browse the repository at this point in the history
  • Loading branch information
Hzfengsy authored and ylc committed Jan 13, 2022
1 parent 1f53b90 commit a7c3b3c
Show file tree
Hide file tree
Showing 6 changed files with 630 additions and 2 deletions.
2 changes: 1 addition & 1 deletion include/tvm/tir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ TVM_DLL Array<Var> UndefinedVars(const PrimExpr& expr);
TVM_DLL CallEffectKind SideEffect(const PrimExpr& expr);

/*!
* \brief Whether e expression used any var in variable set..
* \brief Whether e expression used any var in variable set.
* \param expr The expression to be checked.
* \param vset_contains The check function to see if var is in the vset.
* \return Whether e uses vset.
Expand Down
1 change: 1 addition & 0 deletions include/tvm/tir/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ class Buffer : public ObjectRef {
TVM_DLL Stmt vstore(Array<PrimExpr> begin, PrimExpr value) const;

TVM_DEFINE_OBJECT_REF_METHODS(Buffer, ObjectRef, BufferNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferNode);
};

/*!
Expand Down
38 changes: 38 additions & 0 deletions include/tvm/tir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,44 @@ class LinkedParam : public ObjectRef {
TVM_DEFINE_OBJECT_REF_COW_METHOD(LinkedParamNode);
};

/*!
* \brief Specialize parameters of PrimFunc.
* \param func The PrimFunc to be specialized.
* \param param_map The mapping from function params to the instance.
* \return The new function with parameter specialized.
* \note We can define a Meta TIR function with symbolic shape:
*
* \code
* @tvm.script.tir
* def mem_copy(a: ty.handle, b: ty.handle, m: ty.int32, n: ty.int32) -> None:
* A = tir.match_buffer(a, (m, n), "float32")
* B = tir.match_buffer(b, (m, n), "float32")
*
* with tir.block([m, n], "") as [vi, vj]:
* B[vi, vj] = A[vi, vj]
* \endcode
*
* Then we can make it specialized with given shapes or buffers.
*
* \code
* a, _, m, n = mem_copy.params
* func = mem_copy.specialize({a: tir.decl_buffer((16, 16))})
* # or
* func = mem_copy.specialize({n: 16, m: 16})
* \endcode
*
* \code {.language-id}
* @tvm.script.tir
* def mem_copy_16_16(a: ty.handle, b: ty.handle) -> None:
* A = tir.match_buffer(a, (16, 16), "float32")
* B = tir.match_buffer(b, (16, 16), "float32")
*
* with tir.block([16, 16], "") as [vi, vj]:
* B[vi, vj] = A[vi, vj]
* \endcode
*/
PrimFunc Specialize(PrimFunc func, const Map<Var, ObjectRef>& param_map);

/*!
* \brief PrimFunc specific attribute names.
*
Expand Down
55 changes: 54 additions & 1 deletion python/tvm/tir/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@
# under the License.
"""Function data types."""

from typing import Mapping, Union

import tvm._ffi
import tvm.runtime
from tvm.runtime import Object
from tvm.ir import BaseFunc
from .buffer import Buffer
from .expr import Var
from .expr import Var, PrimExpr
from . import _ffi_api


Expand Down Expand Up @@ -85,3 +87,54 @@ def with_body(self, new_body, span=None):
The created new function.
"""
return PrimFunc(self.params, new_body, self.ret_type, self.buffer_map, self.attrs, span)

def specialize(self, param_map: Mapping[Var, Union[PrimExpr, Buffer]]):
"""Specialize parameters of PrimFunc
Parameters
----------
param_map : Mapping[Var, Union[PrimExpr, Buffer]]
The mapping from function params to the instance
Examples
--------
We can define a Meta TIR function with symbolic shape:
.. code-block:: python
@tvm.script.tir
def mem_copy(a: ty.handle, b: ty.handle, m: ty.int32, n: ty.int32) -> None:
A = tir.match_buffer(a, (m, n), "float32")
B = tir.match_buffer(b, (m, n), "float32")
with tir.block([m, n], "") as [vi, vj]:
B[vi, vj] = A[vi, vj]
Then we can make it specialized with given shapes or buffers.
.. code-block:: python
a, _, m, n = mem_copy.params
func = mem_copy.specialize({a: tir.decl_buffer((16, 16))})
# or
func = mem_copy.specialize({n: 16, m: 16})
The specialized function:
.. code-block:: python
@tvm.script.tir
def mem_copy_16_16(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, (16, 16), "float32")
B = tir.match_buffer(b, (16, 16), "float32")
with tir.block([16, 16], "") as [vi, vj]:
B[vi, vj] = A[vi, vj]
Returns
-------
func : PrimFunc
The new function with parameter specialized
"""
return _ffi_api.Specialize(self, param_map)
Loading

0 comments on commit a7c3b3c

Please sign in to comment.