Skip to content

Commit

Permalink
[TIR][Schedule] Transform layout quality of life (apache#11269)
Browse files Browse the repository at this point in the history
* [TIR][Schedule] Added Schedule.transform_layout_sugared

* [TE][TIR] Reduced duplication in TE/TIR layout transformations

Previously, the implementations of `tir.IndexMap.from_func` and
`te.Stage.transform_layout` had significant duplication to handle
argument parsing.  This commit extracts the shared logic into
`tir.IndexMap`.

* Enabled *args in Schedule.transform_layout_sugared

* Fix lint error

* Allow Schedule.transform_layout_sugared to set axis separators

* Merged transform_layout_sugared functionality into transform_layout

* Fix lint errors

* Fix lint error

* Fixed docstring errors

* Updated/tested TransformatLayoutTraits::UnpackedAsPython

* Disabled exec-used check for running trace.as_python()

* Updated SetAxisSeparatorTraits::UnpackedAsPython

* Updated unit test that was added in merge commit

* Fixed the argument name for TensorizeTraits

This wasn't checked before, but was the only other issue caught by the
updates to verify_trace_roundtrip.

* Re-enable type checks of transform_layout/set_axis_separator

Disabled while waiting for apache#11289,
which was required for the `Tuple` argument.

* Updated a few additional transform_layout usages from main
  • Loading branch information
Lunderberg authored and juda committed Jun 21, 2022
1 parent e0a4e81 commit d8b1fd1
Show file tree
Hide file tree
Showing 9 changed files with 385 additions and 132 deletions.
70 changes: 9 additions & 61 deletions python/tvm/te/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from tvm.runtime import Object, convert
from tvm.ir import container as _container
from tvm.tir import IterVar, Buffer, Var
from tvm.tir import IterVar, Buffer, Var, IndexMap

from . import tensor as _tensor
from . import _ffi_api
Expand Down Expand Up @@ -599,65 +599,12 @@ def transform_layout(self, mapping_function: Callable[..., List[tvm.tir.PrimExpr
"""

args = []
var_arg_name = None
kwargs = collections.OrderedDict()
default_index_dtype = "int32"

# Make a dummy variable for each explicitly named input index.
# We may have some keyword-only arguments, if the function has
# *args before the last argument.
params = inspect.signature(mapping_function).parameters
for name, param in params.items():
if param.kind in [
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
]:
args.append(tvm.tir.Var(name, default_index_dtype))

elif param.kind == inspect.Parameter.VAR_POSITIONAL:
var_arg_name = name

elif param.kind == inspect.Parameter.KEYWORD_ONLY:
kwargs[name] = tvm.tir.Var(name, default_index_dtype)

elif param.kind in [inspect.Parameter.VAR_KEYWORD]:
raise ValueError("transform_layout mapping may not have **kwargs")

ndim = len(self.op.output(0).shape)
index_map, axis_separators = IndexMap.from_func_with_separators(mapping_function, ndim=ndim)

# Now that all the named arguments have been collected,
# everything that remains should go to the *args, if
# specified.
if var_arg_name is not None:
num_var_args = ndim - len(args) - len(kwargs)
for i in range(num_var_args):
args.append(tvm.tir.Var(f"{var_arg_name}[{i}]", default_index_dtype))

initial_indices = args + list(kwargs.values())
if len(initial_indices) != ndim:
raise ValueError(
f"transform_layout mapping accepts {len(params)} initial indices, "
f"but {self.op.name} is {len(self.op.shape)}-dimensional"
)

mapping = mapping_function(*args, **kwargs)

final_indices = []
axis_separators = []
for val in mapping:
if isinstance(val, tvm.ir.PrimExpr):
final_indices.append(val)
elif val is AXIS_SEPARATOR:
axis_separators.append(len(final_indices))
else:
raise TypeError(
"Expected mapping function to return list of "
"either tvm.ir.PrimExpr or tvm.te.AXIS_SEPARATOR. "
"Instead received {val} of type {type(val)}."
)

new_iter_vars = _ffi_api.StageTransformLayout(self, initial_indices, final_indices)
new_iter_vars = _ffi_api.StageTransformLayout(
self, index_map.initial_indices, index_map.final_indices
)
_ffi_api.StageSetAxisSeparators(self, axis_separators)

return new_iter_vars or None
Expand Down Expand Up @@ -700,9 +647,10 @@ def __exit__(self, ptype, value, trace):


# Sentinel value used to indicate which groups of pre-flattening axes
# should be used to post-flattening axes axes. See
# Stage.transform_layout for more details.
AXIS_SEPARATOR = "axis_separator"
# should be used to post-flattening axes axes. Moved from
# te.AXIS_SEPARATOR to tir.IndexMap.AXIS_SEPARATOR for general use,
# maintained here for backwards compatibility.
AXIS_SEPARATOR = IndexMap.AXIS_SEPARATOR


tvm._ffi._init_api("schedule", __name__)
103 changes: 96 additions & 7 deletions python/tvm/tir/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
# under the License.
"""Function data types."""

from typing import Callable, List, Mapping, Optional, Union, Tuple
import collections
import inspect
from typing import Callable, List, Mapping, Optional, Union, Tuple

import tvm
import tvm._ffi
Expand Down Expand Up @@ -258,6 +259,11 @@ class IndexMap(Object):
initial_indices: List[Var]
final_indices: List[PrimExpr]

# Sentinel value used to indicate which groups of pre-flattening axes
# should be used to post-flattening axes axes. See
# Stage.transform_layout for more details.
AXIS_SEPARATOR = "axis_separator"

def __init__(self, initial_indices, final_indices):
self.__init_handle_by_constructor__(_ffi_api.IndexMap, initial_indices, final_indices)

Expand All @@ -268,34 +274,117 @@ def from_func(mapping_function: Callable, ndim: Optional[int] = None):
Parameters
----------
mapping_function : Callable
The function to map from source indices to target indices
The function to map from source indices to target indices.
The function should accept `tir.Var` parameters and return
a list. Each element of the returned list should be a
`tir.PrimExpr`.
ndim: Optional[int]
The dimensionality of the buffer to which this
transformation should be applied. If mapping_function uses
variadic argument `*args`, `ndim` must be specified. If
mapping_function does not use variadic arguments, ndim is
optional.
Returns
-------
index_map: IndexMap
Returns an IndexMap representing the `mapping_function`.
"""
index_map, axis_separators = IndexMap.from_func_with_separators(mapping_function, ndim)
assert not axis_separators, (
"The mapping_function provided to IndexMap.from_func "
"may not return IndexMap.AXIS_SEPARATOR. "
"If required, please use IndexMap.from_func_with_separators instead."
)
return index_map

@staticmethod
def from_func_with_separators(mapping_function: Callable, ndim: Optional[int] = None):
"""Create an index map from a function
Parameters
----------
mapping_function : Callable
The function to map from source indices to target indices.
The function should accept tir.Var parameters and return a
list. Each element of the returned list should be either a
`tir.PrimExpr` or the object `IndexMap.AXIS_SEPARATOR`.
ndim: Optional[int]
The dimensionality of the buffer to which this
transformation should be applied. If mapping_function uses
variadic argument `*args`, ndim must be specified. If
mapping_function does not use variadic arguments, ndim is
optional.
Returns
-------
ret: Tuple[IndexMap, List[int]]
Returns a tuple whose first element is an IndexMap
representing the `mapping_function`, and whose second index
is a list of indices at which `IndexMap.AXIS_SEPARATOR`
occurred.
"""
params = inspect.signature(mapping_function).parameters
default_index_dtype = "int32"

args = []
var_arg_name = None
kwargs = collections.OrderedDict()
default_index_dtype = "int32"

for name, param in params.items():
if param.kind in [
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
]:
args.append(tvm.tir.Var(name, default_index_dtype))

elif param.kind == inspect.Parameter.VAR_POSITIONAL:
var_arg_name = name

elif param.kind == inspect.Parameter.KEYWORD_ONLY:
kwargs[name] = tvm.tir.Var(name, default_index_dtype)

else:
raise ValueError("transform_layout mapping may not have *args or **kwargs")
raise ValueError("transform_layout mapping may not have *args")

# Now that all the named arguments have been collected,
# everything that remains should go to the *args, if
# specified.
if var_arg_name is not None:
assert ndim is not None, "ndim must be specified when *args is used"
num_var_args = ndim - len(args)
num_var_args = ndim - len(args) - len(kwargs)
for i in range(num_var_args):
args.append(tvm.tir.Var(f"{var_arg_name}_{i}", default_index_dtype))

final_indices = mapping_function(*args)
return IndexMap(args, final_indices)
mapping = mapping_function(*args, **kwargs)

initial_indices = args + list(kwargs.values())

final_indices = []
axis_separators = []
for val in mapping:
if isinstance(val, tvm.ir.PrimExpr):
final_indices.append(val)
elif val is IndexMap.AXIS_SEPARATOR:
axis_separators.append(len(final_indices))
else:
raise TypeError(
"Expected mapping function to return list of "
"either tvm.ir.PrimExpr or IndexMap.AXIS_SEPARATOR. "
"Instead received {val} of type {type(val)}."
)

return IndexMap(initial_indices, final_indices), axis_separators

def is_equivalent_to(self, other_map: "IndexMap") -> bool:
"""Return if the index maps are equivalent.
Expand Down
Loading

0 comments on commit d8b1fd1

Please sign in to comment.