Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR][Schedule] Transform layout quality of life #11269

Merged
merged 23 commits into from
May 26, 2022
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
e401037
[TIR][Schedule] Added Schedule.transform_layout_sugared
Lunderberg May 10, 2022
7ef2312
[TE][TIR] Reduced duplication in TE/TIR layout transformations
Lunderberg May 10, 2022
bb215a1
Enabled *args in Schedule.transform_layout_sugared
Lunderberg May 10, 2022
3740d75
Fix lint error
Lunderberg May 11, 2022
b4195cd
Allow Schedule.transform_layout_sugared to set axis separators
Lunderberg May 11, 2022
c048d5c
Merged transform_layout_sugared functionality into transform_layout
Lunderberg May 11, 2022
7a78102
Fix lint errors
Lunderberg May 11, 2022
3229175
Fix lint error
Lunderberg May 12, 2022
ec66ff1
Fixed docstring errors
Lunderberg May 16, 2022
cefee79
Updated/tested TransformatLayoutTraits::UnpackedAsPython
Lunderberg May 17, 2022
db0f5ca
Merge branch 'main' into transform_layout_quality_of_life
Lunderberg May 17, 2022
5caf6f4
Disabled exec-used check for running trace.as_python()
Lunderberg May 17, 2022
2f63e96
Updated SetAxisSeparatorTraits::UnpackedAsPython
Lunderberg May 17, 2022
bd57ea2
Merge branch 'main' into transform_layout_quality_of_life
Lunderberg May 18, 2022
015ce32
Updated unit test that was added in merge commit
Lunderberg May 18, 2022
accf8ff
Fixed the argument name for TensorizeTraits
Lunderberg May 18, 2022
062a0c2
Merge branch 'main' into transform_layout_quality_of_life
Lunderberg May 20, 2022
99fb775
Re-enable type checks of transform_layout/set_axis_separator
Lunderberg May 20, 2022
90e0798
Merge branch 'main' into transform_layout_quality_of_life
Lunderberg May 20, 2022
a513f18
Updated a few additional transform_layout usages from main
Lunderberg May 23, 2022
5cc3e90
Merge branch 'main' into transform_layout_quality_of_life
Lunderberg May 24, 2022
00e7eb6
Merge branch 'main' into transform_layout_quality_of_life
Lunderberg May 25, 2022
d4b03fc
Merge branch 'main' into transform_layout_quality_of_life
Lunderberg May 25, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 9 additions & 61 deletions python/tvm/te/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from tvm.runtime import Object, convert
from tvm.ir import container as _container
from tvm.tir import IterVar, Buffer, Var
from tvm.tir import IterVar, Buffer, Var, IndexMap

from . import tensor as _tensor
from . import _ffi_api
Expand Down Expand Up @@ -599,65 +599,12 @@ def transform_layout(self, mapping_function: Callable[..., List[tvm.tir.PrimExpr

"""

args = []
var_arg_name = None
kwargs = collections.OrderedDict()
default_index_dtype = "int32"

# Make a dummy variable for each explicitly named input index.
# We may have some keyword-only arguments, if the function has
# *args before the last argument.
params = inspect.signature(mapping_function).parameters
for name, param in params.items():
if param.kind in [
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
]:
args.append(tvm.tir.Var(name, default_index_dtype))

elif param.kind == inspect.Parameter.VAR_POSITIONAL:
var_arg_name = name

elif param.kind == inspect.Parameter.KEYWORD_ONLY:
kwargs[name] = tvm.tir.Var(name, default_index_dtype)

elif param.kind in [inspect.Parameter.VAR_KEYWORD]:
raise ValueError("transform_layout mapping may not have **kwargs")

ndim = len(self.op.output(0).shape)
index_map, axis_separators = IndexMap.from_func_with_separators(mapping_function, ndim=ndim)

# Now that all the named arguments have been collected,
# everything that remains should go to the *args, if
# specified.
if var_arg_name is not None:
num_var_args = ndim - len(args) - len(kwargs)
for i in range(num_var_args):
args.append(tvm.tir.Var(f"{var_arg_name}[{i}]", default_index_dtype))

initial_indices = args + list(kwargs.values())
if len(initial_indices) != ndim:
raise ValueError(
f"transform_layout mapping accepts {len(params)} initial indices, "
f"but {self.op.name} is {len(self.op.shape)}-dimensional"
)

mapping = mapping_function(*args, **kwargs)

final_indices = []
axis_separators = []
for val in mapping:
if isinstance(val, tvm.ir.PrimExpr):
final_indices.append(val)
elif val is AXIS_SEPARATOR:
axis_separators.append(len(final_indices))
else:
raise TypeError(
"Expected mapping function to return list of "
"either tvm.ir.PrimExpr or tvm.te.AXIS_SEPARATOR. "
"Instead received {val} of type {type(val)}."
)

new_iter_vars = _ffi_api.StageTransformLayout(self, initial_indices, final_indices)
new_iter_vars = _ffi_api.StageTransformLayout(
self, index_map.initial_indices, index_map.final_indices
)
_ffi_api.StageSetAxisSeparators(self, axis_separators)

return new_iter_vars or None
Expand Down Expand Up @@ -700,9 +647,10 @@ def __exit__(self, ptype, value, trace):


# Sentinel value used to indicate which groups of pre-flattening axes
# should be used to post-flattening axes axes. See
# Stage.transform_layout for more details.
AXIS_SEPARATOR = "axis_separator"
# should be used to post-flattening axes axes. Moved from
# te.AXIS_SEPARATOR to tir.IndexMap.AXIS_SEPARATOR for general use,
# maintained here for backwards compatibility.
AXIS_SEPARATOR = IndexMap.AXIS_SEPARATOR


tvm._ffi._init_api("schedule", __name__)
103 changes: 96 additions & 7 deletions python/tvm/tir/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
# under the License.
"""Function data types."""

from typing import Callable, List, Mapping, Optional, Union, Tuple
import collections
import inspect
from typing import Callable, List, Mapping, Optional, Union, Tuple

import tvm
import tvm._ffi
Expand Down Expand Up @@ -258,6 +259,11 @@ class IndexMap(Object):
initial_indices: List[Var]
final_indices: List[PrimExpr]

# Sentinel value used to indicate which groups of pre-flattening axes
# should be used to post-flattening axes axes. See
# Stage.transform_layout for more details.
AXIS_SEPARATOR = "axis_separator"

def __init__(self, initial_indices, final_indices):
self.__init_handle_by_constructor__(_ffi_api.IndexMap, initial_indices, final_indices)

Expand All @@ -268,34 +274,117 @@ def from_func(mapping_function: Callable, ndim: Optional[int] = None):
Parameters
----------
mapping_function : Callable
The function to map from source indices to target indices

The function to map from source indices to target indices.
The function should accept `tir.Var` parameters and return
a list. Each element of the returned list should be a
`tir.PrimExpr`.

ndim: Optional[int]

The dimensionality of the buffer to which this
transformation should be applied. If mapping_function uses
variadic argument `*args`, `ndim` must be specified. If
mapping_function does not use variadic arguments, ndim is
optional.

Returns
-------
index_map: IndexMap

Returns an IndexMap representing the `mapping_function`.

"""
index_map, axis_separators = IndexMap.from_func_with_separators(mapping_function, ndim)
assert not axis_separators, (
"The mapping_function provided to IndexMap.from_func "
"may not return IndexMap.AXIS_SEPARATOR. "
"If required, please use IndexMap.from_func_with_separators instead."
)
return index_map

@staticmethod
def from_func_with_separators(mapping_function: Callable, ndim: Optional[int] = None):
"""Create an index map from a function

Parameters
----------
mapping_function : Callable

The function to map from source indices to target indices.
The function should accept tir.Var parameters and return a
list. Each element of the returned list should be either a
`tir.PrimExpr` or the object `IndexMap.AXIS_SEPARATOR`.

ndim: Optional[int]

The dimensionality of the buffer to which this
transformation should be applied. If mapping_function uses
variadic argument `*args`, ndim must be specified. If
mapping_function does not use variadic arguments, ndim is
optional.

Returns
-------
ret: Tuple[IndexMap, List[int]]

Returns a tuple whose first element is an IndexMap
representing the `mapping_function`, and whose second index
is a list of indices at which `IndexMap.AXIS_SEPARATOR`
occurred.

"""
params = inspect.signature(mapping_function).parameters
default_index_dtype = "int32"

args = []
var_arg_name = None
kwargs = collections.OrderedDict()
default_index_dtype = "int32"

for name, param in params.items():
if param.kind in [
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
]:
args.append(tvm.tir.Var(name, default_index_dtype))

elif param.kind == inspect.Parameter.VAR_POSITIONAL:
var_arg_name = name

elif param.kind == inspect.Parameter.KEYWORD_ONLY:
kwargs[name] = tvm.tir.Var(name, default_index_dtype)

else:
raise ValueError("transform_layout mapping may not have *args or **kwargs")
raise ValueError("transform_layout mapping may not have *args")

# Now that all the named arguments have been collected,
# everything that remains should go to the *args, if
# specified.
if var_arg_name is not None:
assert ndim is not None, "ndim must be specified when *args is used"
num_var_args = ndim - len(args)
num_var_args = ndim - len(args) - len(kwargs)
for i in range(num_var_args):
args.append(tvm.tir.Var(f"{var_arg_name}_{i}", default_index_dtype))

final_indices = mapping_function(*args)
return IndexMap(args, final_indices)
mapping = mapping_function(*args, **kwargs)

initial_indices = args + list(kwargs.values())

final_indices = []
axis_separators = []
for val in mapping:
if isinstance(val, tvm.ir.PrimExpr):
final_indices.append(val)
elif val is IndexMap.AXIS_SEPARATOR:
axis_separators.append(len(final_indices))
else:
raise TypeError(
"Expected mapping function to return list of "
"either tvm.ir.PrimExpr or IndexMap.AXIS_SEPARATOR. "
"Instead received {val} of type {type(val)}."
)

return IndexMap(initial_indices, final_indices), axis_separators

def is_equivalent_to(self, other_map: "IndexMap") -> bool:
"""Return if the index maps are equivalent.
Expand Down
Loading