Skip to content

Commit

Permalink
[TIR][Schedule] Transform layout (apache#10538)
Browse files Browse the repository at this point in the history
* [TIR][Schedule] Transform layout

* address comments

* fix

* doc

* Address comments

* remove unused

* Use BufferIndexType enum

* lint

* support *args

* lint

* lint
  • Loading branch information
vinx13 committed Mar 23, 2022
1 parent 9849b89 commit 5679bd2
Show file tree
Hide file tree
Showing 18 changed files with 696 additions and 44 deletions.
16 changes: 16 additions & 0 deletions include/tvm/tir/index_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,18 +106,34 @@ class IndexMapNode : public Object {
*/
Array<PrimExpr> MapShape(const Array<PrimExpr>& shape) const;

/*!
* \brief Convert to string representation in Python.
* \return The stringified lambda expression in Python.
*/
String ToPythonString() const;

void VisitAttrs(AttrVisitor* v) {
v->Visit("initial_indices", &initial_indices);
v->Visit("final_indices", &final_indices);
}

static constexpr const char* _type_key = "tir.IndexMap";

TVM_DECLARE_FINAL_OBJECT_INFO(IndexMapNode, Object);
};

class IndexMap : public ObjectRef {
public:
IndexMap(Array<Var> initial_indices, Array<PrimExpr> final_indices);

/*!
* \brief Create an index map from a packed function
* \param ndim The number of dimensions
* \param func The function to be applied
* \return The created index map
*/
static IndexMap FromFunc(int ndim, runtime::TypedPackedFunc<Array<PrimExpr>(Array<Var>)> func);

/*! \brief Generate the inverse mapping.
*
* The range of the input indices is required in order to ensure
Expand Down
24 changes: 24 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#define TVM_TIR_SCHEDULE_SCHEDULE_H_

#include <tvm/support/random_engine.h>
#include <tvm/tir/index_map.h>
#include <tvm/tir/schedule/state.h>
#include <tvm/tir/schedule/trace.h>

Expand All @@ -36,6 +37,14 @@ enum class ScheduleErrorRenderLevel : int32_t {
kNone = 2,
};

/*! \brief Type of buffer index */
enum class BufferIndexType : int32_t {
/*! \brief Index of a read buffer */
kRead = 0,
/*! \brief Index of a written buffer */
kWrite = 1,
};

/**************** Random variable: BlockRV ****************/

/*! \brief A random variable that evaluates to a TensorIR block */
Expand Down Expand Up @@ -521,6 +530,21 @@ class ScheduleNode : public runtime::Object {
*/
virtual void Unannotate(const BlockRV& block_rv, const String& ann_key) = 0;

/******** Schedule: Layout transformation ********/
/*!
* \brief Apply a transformation represented by IndexMap to buffer
* \details The indices and the access region to the target buffer is transformed by the given
* index_map. The index_map is used to infer the new shape of the buffer. Buffer must be either
* a function parameter, or allocated in a block (it cannot be a buffer subregion created via
* 'match_buffer').
* \param block_rv The block that accesses the target buffer.
* \param buffer_index The index of the buffer in block's read or write region.
* \param buffer_index_type The type of the buffer index, kRead or kWrite.
* \param index_map The transformation to apply.
*/
virtual void TransformLayout(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type, const IndexMap& index_map) = 0;

/******** Schedule: Misc ********/
/*! \brief A no-op that marks the start of postprocessing phase of scheduling */
virtual void EnterPostproc() = 0;
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
from .op import comm_reducer, min, max, sum
from .op import q_multiply_shift

from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError
from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError, BufferType

from . import schedule
from . import ir_builder
Expand Down
58 changes: 57 additions & 1 deletion python/tvm/tir/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
# under the License.
"""Function data types."""

from typing import Mapping, Union
from typing import Callable, List, Mapping, Optional, Union
import inspect

import tvm._ffi
import tvm.runtime
Expand Down Expand Up @@ -239,3 +240,58 @@ def get(name: str):
The TensorIntrin with the specified name.
"""
return _ffi_api.TensorIntrinGet(name) # pylint: type: ignore


@tvm._ffi.register_object("tir.IndexMap")
class IndexMap(Object):
"""A mapping from multi-dimensional indices to another set of multi-dimensional indices
Parameters
----------
initial_indices : List[Var]
Variables representing the indices prior to remapping.
final_indices : List[PrimExpr]
Expressions defining the indices after remapping.
"""

initial_indices: List[Var]
final_indices: List[PrimExpr]

def __init__(self, initial_indices, final_indices):
self.__init_handle_by_constructor__(_ffi_api.IndexMap, initial_indices, final_indices)

@staticmethod
def from_func(mapping_function: Callable, ndim: Optional[int] = None):
"""Create an index map from a function
Parameters
----------
mapping_function : Callable
The function to map from source indices to target indices
"""
params = inspect.signature(mapping_function).parameters
default_index_dtype = "int32"
args = []
var_arg_name = None
for name, param in params.items():
if param.kind in [
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
]:
args.append(tvm.tir.Var(name, default_index_dtype))
elif param.kind == inspect.Parameter.VAR_POSITIONAL:
var_arg_name = name
else:
raise ValueError("transform_layout mapping may not have *args or **kwargs")

# Now that all the named arguments have been collected,
# everything that remains should go to the *args, if
# specified.
if var_arg_name is not None:
assert ndim is not None, "ndim must be specified when *args is used"
num_var_args = ndim - len(args)
for i in range(num_var_args):
args.append(tvm.tir.Var(f"{var_arg_name}_{i}", default_index_dtype))

final_indices = mapping_function(*args)
return IndexMap(args, final_indices)
2 changes: 1 addition & 1 deletion python/tvm/tir/schedule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@

from .block_scope import BlockScope, Dependency, DepKind, StmtSRef
from .instruction import Instruction, InstructionKind
from .schedule import BlockRV, ExprRV, LoopRV, Schedule, ScheduleError
from .schedule import BlockRV, ExprRV, LoopRV, Schedule, ScheduleError, BufferType
from .state import ScheduleDebugMask, ScheduleState
from .trace import Trace
87 changes: 86 additions & 1 deletion python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@
# specific language governing permissions and limitations
# under the License.
"""The TensorIR schedule class"""
from typing import Dict, List, Optional, Union
import enum
from typing import Callable, Dict, List, Optional, Union

from tvm._ffi import register_object as _register_object
from tvm.error import TVMError, register_error
from tvm.ir import IRModule, PrimExpr
from tvm.runtime import Object, String
from tvm.tir import Block, FloatImm, For, IntImm, PrimFunc
from ..function import IndexMap

from . import _ffi_api
from .state import ScheduleState, StmtSRef, _parse_debug_mask, _parse_mod
Expand Down Expand Up @@ -71,6 +73,13 @@ def __init__(self) -> None:
}


class BufferType(enum.IntEnum):
"""Type of buffer in access regions of a block"""

READ = 0
WRITE = 1


def _parse_error_render_level(error_render_level: str) -> int:
if error_render_level not in _ERROR_RENDER_LEVEL:
raise ValueError(
Expand Down Expand Up @@ -2111,6 +2120,82 @@ def after_unannotate(a: T.handle, b: T.handle) -> None:
self, block_or_loop, ann_key
)

########## Schedule: Layout transformation ##########

@type_checked
def transform_layout(
self,
block: BlockRV,
buffer_index: int,
buffer_type: BufferType,
index_map: Union[IndexMap, Callable],
) -> None:
"""Apply a transformation represented by IndexMap to buffer
Parameters
----------
block_rv : BlockRV
The block that accesses the target buffer
buffer_index: int
The index of the buffer in block's read or write region
buffer_type : BufferType
Type of the buffer, READ or WRITE.
index_map : Union[IndexMap, Callable]
The transformation to apply
Examples
--------
Before transform_layout, in TensorIR, the IR is:
.. code-block:: python
@T.prim_func
def before_transform_layout(a: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (128, 128), "float32")
B = T.alloc_buffer((128, 128), "float32")
C = T.match_buffer(c, (128, 128), "float32")
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj] * 2.0
for i, j in T.grid(128, 128):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j])
C[vi, vj] = B[vi, vj] + 1.0
Create the schedule and do transform_layout:
.. code-block:: python
sch = tir.Schedule(before_storage_align)
sch.transform_layout(sch.get_block("B"), buffer_index=0, BufferType.WRITE,
index_map=lambda m, n: (m // 16, n // 16, m % 16, n % 16))
print(sch.mod["main"].script())
After applying transform_layout, the IR becomes:
.. code-block:: python
@T.prim_func
def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (128, 128), "float32")
B = T.alloc_buffer((8, 8, 16, 16), "float32")
C = T.match_buffer(c, (128, 128), "float32")
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi // 16, vj // 16, vi % 16, vj % 16] = A[vi, vj] * 2.0
for i, j in T.grid(128, 128):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j])
C[vi, vj] = B[vi // 16, vj // 16, vi % 16, vj % 16] + 1.0
"""
if callable(index_map):
index_map = IndexMap.from_func(index_map)
_ffi_api.ScheduleTransformLayout( # type: ignore # pylint: disable=no-member
self, block, buffer_index, buffer_type, index_map
)

########## Schedule: Misc ##########

@type_checked
Expand Down
51 changes: 50 additions & 1 deletion src/tir/ir/index_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <tvm/arith/int_set.h>
#include <tvm/arith/iter_affine_map.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>

#include <sstream>

Expand All @@ -40,6 +41,15 @@ IndexMap::IndexMap(Array<Var> initial_indices, Array<PrimExpr> final_indices) {
data_ = std::move(n);
}

IndexMap IndexMap::FromFunc(int ndim, runtime::TypedPackedFunc<Array<PrimExpr>(Array<Var>)> func) {
Array<Var> initial_indices;
initial_indices.reserve(ndim);
for (int i = 0; i < ndim; ++i) {
initial_indices.push_back(Var("i" + std::to_string(i), DataType::Int(32)));
}
return IndexMap(initial_indices, func(initial_indices));
}

IndexMap IndexMap::Inverse(Array<Range> initial_ranges) const {
// Dummy variables to represent the inverse's inputs.
Array<Var> output_vars;
Expand Down Expand Up @@ -142,13 +152,52 @@ Array<PrimExpr> IndexMapNode::MapShape(const Array<PrimExpr>& shape) const {
return output;
}

String IndexMapNode::ToPythonString() const {
std::unordered_set<std::string> used_names;
Map<Var, PrimExpr> var_remap;
for (const Var& initial_index : initial_indices) {
if (used_names.count(initial_index->name_hint)) {
std::string new_name = initial_index->name_hint + std::to_string(used_names.size());
used_names.insert(new_name);
var_remap.Set(initial_index, Var(new_name));
} else {
used_names.insert(initial_index->name_hint);
}
}
std::ostringstream oss;
oss << "lambda ";
for (size_t i = 0; i < initial_indices.size(); ++i) {
if (i != 0) {
oss << ", ";
}
auto it = var_remap.find(initial_indices[i]);
if (it != var_remap.end()) {
oss << (*it).second;
} else {
oss << initial_indices[i];
}
}
oss << ": (";
for (size_t i = 0; i < final_indices.size(); ++i) {
oss << Substitute(final_indices[i], var_remap);
oss << ", ";
}
oss << ")";
return String(oss.str());
}

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<IndexMapNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const IndexMapNode*>(node.get());
p->stream << "index_map(" << op->initial_indices << ", " << op->final_indices << ")";
p->stream << "index_map(" << op->ToPythonString() << ")";
});

TVM_REGISTER_NODE_TYPE(IndexMapNode);

TVM_REGISTER_GLOBAL("tir.IndexMap")
.set_body_typed([](Array<Var> initial_indices, Array<PrimExpr> final_indices) {
return IndexMap(initial_indices, final_indices);
});

} // namespace tir
} // namespace tvm
10 changes: 10 additions & 0 deletions src/tir/schedule/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,16 @@ struct ProducerConsumerSplit {
*/
Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n, bool is_write);

/*!
* \brief Find the defining site of the buffer in the given block and its ancestors
* \param block_sref The block sref
* \param buffer The buffer
* \return The defining site of the buffer and whether the buffer is allocated (otherwise the
* buffer is from match_buffer).
*/
std::pair<Optional<StmtSRef>, bool> GetBufferDefiningSite(const StmtSRef& block_sref,
const Buffer& buffer);

/******** Reduction Block Related ********/

/*!
Expand Down
Loading

0 comments on commit 5679bd2

Please sign in to comment.