Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TensorIR][M2a] Verification of cached flags #8114

Merged
merged 3 commits into from
May 25, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 46 additions & 2 deletions include/tvm/arith/int_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ using tir::IterVar;
using tir::Var;
using tir::VarNode;

class Analyzer;

//-----------------------------------------------
// Integer set data structure.
//
Expand Down Expand Up @@ -190,6 +192,14 @@ IntSet EvalSet(IntSet s, const std::unordered_map<const VarNode*, IntSet>& dom_m
* \return An integer set that can cover all the possible values of e.
*/
IntSet EvalSet(Range r, const std::unordered_map<const VarNode*, IntSet>& dom_map);
/*!
* \brief Same as EvalSet, but takes Array<Range>
*
* \param region The range to be evaluated.
* \param dom_map The domain of each variable.
* \return An array of integer sets that can cover all the possible values.
*/
Array<IntSet> EvalSet(const Array<Range>& region, const Map<Var, IntSet>& dom_map);
/*! \brief Map from Expr to IntSet */
using ExprIntSetMap = std::unordered_map<PrimExpr, IntSet, ObjectPtrHash, ObjectPtrEqual>;
/*!
Expand All @@ -204,19 +214,53 @@ ExprIntSetMap EvalSetForEachSubExpr(PrimExpr e,
const std::unordered_map<const VarNode*, IntSet>& dom_map);

/*!
* \brief Create an union set of all sets
* \param sets The sets to be unioned
* \brief Create a union set of all sets, possibly relaxed
* \param sets The sets to be combined
* \return the set after union
*/
IntSet Union(const Array<IntSet>& sets);

/*!
* \brief The union of N-dimensional integer sets
* \param nd_int_sets A list of N-dimensional integer sets
* \return An N-dimensional integer set as the result of union
*/
Array<IntSet> UnionRegion(const Array<Array<IntSet>>& nd_int_sets);

/*!
* \brief Create a lower-bound of union set, where some of the segments may be dropped
* \param sets The sets to be combined
* \return the set after union
*/
IntSet UnionLowerBound(const Array<IntSet>& sets);

/*!
* \brief Create a lower-bound of the union of N-dimensional integer sets
* \param nd_int_sets A list of N-dimensional integer sets
* \return An N-dimensional integer set as a lower bound of the union
*/
Array<IntSet> UnionRegionLowerBound(const Array<Array<IntSet>>& nd_int_sets);

/*!
* \brief Create an intersected set of all sets
* \param sets The sets to be intersected
* \return the set after intersected
*/
IntSet Intersect(const Array<IntSet>& sets);

/*!
* \brief Analyze the region with affine map, given the domain of variables and their predicate
* \param region The region to be analyzed
* \param var_dom The ranges of the variables
* \param predicate The predicate for the affine map
* \param analyzer The analyzer used
* \return NullOpt if the detection fails, or an array of arith::IntSet as the result of analysis
*/
TVM_DLL Optional<Array<IntSet>> EstimateRegionLowerBound(const Array<Range>& region,
const Map<Var, Range>& var_dom,
const PrimExpr& predicate,
arith::Analyzer* analyzer);

} // namespace arith
} // namespace tvm
#endif // TVM_ARITH_INT_SET_H_
4 changes: 3 additions & 1 deletion include/tvm/tir/schedule/block_scope.h
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,9 @@ class BlockScopeNode : public Object {
* equivalent to of a stage pipeline. Under the following conditions:
*
* 1) The region cover property holds for each of its child blocks
* 2) No write-after-read dependency
* 2) No write-after-read dependency or opaque dependency, only read-after-write and
* write-after-write are allowed
* 3) All the statements in the scope are schedulable statements, i.e. Block and For
*/
bool stage_pipeline{false};

Expand Down
4 changes: 1 addition & 3 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,7 @@ class Schedule : public runtime::ObjectRef {
* \sa ScheduleDebugMask
* \note The checks performed include:
* 1) VerifySRefTree
* 2) VerifyAffineBinding
* 3) VerifyRegionCover
* 4) VerifyStagePipeline
* 2) VerifyCachedFlags
*/
TVM_DLL static Schedule Concrete(IRModule mod, int debug_mode);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Schedule, runtime::ObjectRef, ScheduleNode);
Expand Down
15 changes: 5 additions & 10 deletions include/tvm/tir/schedule/state.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,11 @@ struct BlockInfo {
* \brief The bitmask of the debug flag in the ScheduleStateNode.
* \sa ScheduleStateNode
*/
enum class ScheduleDebugMask : int32_t {
enum ScheduleDebugMask : uint32_t {
/*! \brief Verify the correctness of the sref tree */
kVerifySRefTree = 1,
/*! \brief Verify the correctness of affine_binding */
kVerifyAffineBinding = 2,
/*! \brief Verify the correctness of region_cover */
kVerifyRegionCover = 4,
/*! \brief Verify the correctness of stage_pipeline */
kVerifyStagePipeline = 8,
/*! \brief Verify the correctness of affine_binding, region_cover and stage_pipeline */
kVerifyCachedFlags = 2,
};

/*!
Expand Down Expand Up @@ -135,9 +131,8 @@ class ScheduleStateNode : public Object {
/*!
* \brief Trigger the verification according to the `debug_mode` bitmask.
* 1) If the bitmask `kVerifySRefTree` is on, verify the correctness of the sref tree.
* 2) If the bitmask `kVerifyAffineBinding` is on, verify the correctness of `affine_binding`
* 3) If the bitmask `kVerifyRegionCover` is on, verify the correctness of `region_cover`
* 4) If the bitmask `kVerifyStagePipeline` is on, verify the correctness of `stage_pipeline`
* 2) If the bitmask `kVerifyCachedFlags` is on, verify the correctness of `affine_binding`,
* `region_cover` and `stage_pipeline`
*/
TVM_DLL void DebugVerify() const;

Expand Down
8 changes: 8 additions & 0 deletions include/tvm/tir/stmt_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,14 @@ TVM_DLL Stmt Substitute(Stmt stmt, std::function<Optional<PrimExpr>(const Var& v
*/
TVM_DLL PrimExpr Substitute(PrimExpr expr, std::function<Optional<PrimExpr>(const Var& var)> vmap);

/*!
* \brief Substitute the var specified by vmap.
* \param region The object whose vars are to be substituted
* \param vmap The map of new values.
* \return The result.
*/
TVM_DLL Array<Range> Substitute(const Array<Range>& region, const Map<Var, PrimExpr>& vmap);

/*!
* \brief Sugar for substitute via a given map.
* \param input The input to be updated.
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/arith/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
"""Integer bound analysis, simplification and pattern detection."""

from .int_set import IntSet, IntervalSet
from .int_set import IntSet, IntervalSet, estimate_region_lower_bound
from .analyzer import ModularSet, ConstIntBound, Analyzer
from .bound import deduce_bound
from .pattern import detect_linear_equation, detect_clip_bound
Expand Down
60 changes: 60 additions & 0 deletions python/tvm/arith/int_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,63 @@ class IntervalSet(IntSet):

def __init__(self, min_value, max_value):
self.__init_handle_by_constructor__(_ffi_api.IntervalSet, min_value, max_value)


def estimate_region_lower_bound(region, var_dom, predicate):
    """Estimate a lower bound of a region accessed through an affine map.

    The analysis uses the domain of each variable together with the
    predicate that constrains the affine map.

    Parameters
    ----------
    region : List[Range]
        The region to be analyzed.

    var_dom : Dict[Var, Range]
        The range of each variable.

    predicate : PrimExpr
        The predicate for the affine map.

    Returns
    ----------
    region_int_set : Optional[List[IntSet]]
        None if the detection fails, otherwise one IntSet per entry of
        the region as the result of the analysis.
    """
    # All of the work happens on the C++ side; this is only a thin FFI shim.
    region_int_set = _ffi_api.EstimateRegionLowerBound(region, var_dom, predicate)
    return region_int_set


def pos_inf():
    """Get the symbol that stands for positive infinity.

    Returns
    ----------
    pos_inf : Var
        A symbolic var that indicates positive infinity
    """
    # The symbol is produced by the C++ side and handed back unchanged.
    positive_infinity = _ffi_api.PosInf()
    return positive_infinity


def neg_inf():
    """Returns the symbolic negative infinity

    Returns
    ----------
    neg_inf : Var
        A symbolic var that indicates negative infinity
    """
    # Thin FFI wrapper: the symbol itself is created on the C++ side.
    return _ffi_api.NegInf()


def union_lower_bound(sets):
    """Create a lower-bound of union set, where some of the segments may be dropped

    Parameters
    ----------
    sets : List[IntSet]
        The sets to be combined

    Returns
    ----------
    union_lower_bound : IntSet
        A single integer set, the lower bound of the union of all input sets
    """
    # Thin FFI wrapper: the C++ UnionLowerBound returns one IntSet, not a list.
    return _ffi_api.UnionLowerBound(sets)
4 changes: 1 addition & 3 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,7 @@ def __init__(
----------
The checks performed include:
1) VerifySRefTree
2) VerifyAffineBinding
3) VerifyRegionCover
4) VerifyStagePipeline
2) VerifyCachedFlags
"""
if isinstance(debug_mode, bool):
if debug_mode:
Expand Down
45 changes: 36 additions & 9 deletions python/tvm/tir/schedule/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""This file defines ScheduleState, the core data structure of TensorIR scheduling."""
from collections import namedtuple
from enum import IntEnum
from typing import Dict, Optional, Union

Expand All @@ -26,6 +27,8 @@
from . import _ffi_api_schedule
from .block_scope import BlockScope, StmtSRef

CachedFlags = namedtuple("CachedFlags", ["affine_binding", "region_cover", "stage_pipeline"])


class ScheduleDebugMask(IntEnum):
"""The bitmask of the `debug_mode` flag in the ScheduleState class.
Expand All @@ -38,18 +41,12 @@ class ScheduleDebugMask(IntEnum):
----------
VERIFY_SREF_TREE : int = 1
Verify the correctness of the sref tree
VERIFY_AFFINE_BINDING : int = 2
Verify the correctness of affine_binding
VERIFY_REGION_COVER : int = 4
Verify the correctness of region_cover
VERIFY_STAGE_PIPELINE: int = 8
Verify the correctness of stage_pipeline
VERIFY_CACHED_FLAGS : int = 2
Verify the correctness of affine_binding, region_cover and stage_pipeline
"""

VERIFY_SREF_TREE = 1
VERIFY_AFFINE_BINDING = 2
VERIFY_REGION_COVER = 4
VERIFY_STAGE_PIPELINE = 8
VERIFY_CACHED_FLAGS = 2


@register_object("tir.ScheduleState")
Expand Down Expand Up @@ -140,6 +137,36 @@ def get_block_scope(self, block_sref: StmtSRef) -> BlockScope:
self, block_sref
)

def _get_cached_flags(self, block_sref: StmtSRef) -> CachedFlags:
    """Look up the flags cached for the block that `block_sref` points to.

    Parameters
    ----------
    block_sref : StmtSRef
        The block sref whose cached flags are retrieved

    Returns
    -------
    flags : CachedFlags
        Three flags: affine_binding, region_cover, stage_pipeline

    Note
    -------
    It is an API intended for internal testing use.
    """
    raw_flags = _ffi_api_schedule.ScheduleStateGetCachedFlags(  # pylint: disable=no-member
        self, block_sref
    )
    # The FFI call yields exactly three boxed values, in this fixed order.
    affine_flag, cover_flag, pipeline_flag = raw_flags
    # Each entry exposes its payload through `.value`; coerce to plain Python bools.
    return CachedFlags(
        affine_binding=bool(affine_flag.value),
        region_cover=bool(cover_flag.value),
        stage_pipeline=bool(pipeline_flag.value),
    )

def replace(
self,
src_sref: StmtSRef,
Expand Down
Loading