Skip to content

Commit

Permalink
[Schedule] Allowed string argument as block arg
Browse files Browse the repository at this point in the history
This has previously been implemented for `Schedule.transform_layout`
in apache#11296, extending to allow for
block arguments in all `Schedule` methods.

This change was only made for arguments that must be a `BlockRV`.  For
arguments that may be either a `BlockRV` or another
type (e.g. `Schedule.get_child_blocks` accepts either `BlockRV` or
`LoopRV`), this sugar is not implemented, to avoid ambiguity.
  • Loading branch information
Lunderberg committed Jun 8, 2022
1 parent 609d6af commit c6a43f9
Show file tree
Hide file tree
Showing 11 changed files with 255 additions and 213 deletions.
82 changes: 52 additions & 30 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,14 +373,14 @@ def sample_perfect_tile(
@type_checked
def sample_compute_location(
self,
block: BlockRV,
block: Union[BlockRV, str],
decision: Optional[int] = None,
) -> LoopRV:
"""Sample a compute-at location of the given block
Parameters
----------
block : BlockRV
block : Union[BlockRV, str]
The block whose compute-at location is to be sampled
decision : Optional[int]
The sampling decision
Expand All @@ -390,6 +390,8 @@ def sample_compute_location(
result : LoopRV
The sampled loop where the input block is to be computed at
"""
block = self._normalize_block_arg(block)

return _ffi_api.ScheduleSampleComputeLocation( # type: ignore # pylint: disable=no-member
self,
block,
Expand Down Expand Up @@ -425,19 +427,20 @@ def get_block(
)

@type_checked
def get_loops(self, block: BlockRV) -> List[LoopRV]:
def get_loops(self, block: Union[BlockRV, str]) -> List[LoopRV]:
"""Get the parent loops of the block in its scope, from outer to inner
Parameters
----------
block : BlockRV
block : Union[BlockRV, str]
The query block
Returns
-------
loops : List[LoopRV]
A list of loops above the given block in its scope, from outer to inner
"""
block = self._normalize_block_arg(block)
return list(_ffi_api.ScheduleGetLoops(self, block)) # type: ignore # pylint: disable=no-member

@type_checked
Expand All @@ -457,35 +460,37 @@ def get_child_blocks(self, block_or_loop: Union[BlockRV, LoopRV]) -> List[BlockR
return list(_ffi_api.ScheduleGetChildBlocks(self, block_or_loop)) # type: ignore # pylint: disable=no-member

@type_checked
def get_producers(self, block: BlockRV) -> List[BlockRV]:
def get_producers(self, block: Union[BlockRV, str]) -> List[BlockRV]:
"""Get the producers of a specific block
Parameters
----------
block : BlockRV
block : Union[BlockRV, str]
The block in the query
Returns
-------
producers : List[BlockRV]
A list of producers of the given block
"""
block = self._normalize_block_arg(block)
return list(_ffi_api.ScheduleGetProducers(self, block)) # type: ignore # pylint: disable=no-member

@type_checked
def get_consumers(self, block: BlockRV) -> List[BlockRV]:
def get_consumers(self, block: Union[BlockRV, str]) -> List[BlockRV]:
"""Get the consumers of a specific block
Parameters
----------
block : BlockRV
block : Union[BlockRV, str]
The block in the query
Returns
-------
consumers : List[BlockRV]
A list of consumers of the given block
"""
block = self._normalize_block_arg(block)
return list(_ffi_api.ScheduleGetConsumers(self, block)) # type: ignore # pylint: disable=no-member

########## Schedule: Transform loops ##########
Expand Down Expand Up @@ -970,7 +975,9 @@ def after_unroll(a: T.handle, b: T.handle) -> None:
########## Schedule: Insert cache stages ##########

@type_checked
def cache_read(self, block: BlockRV, read_buffer_index: int, storage_scope: str) -> BlockRV:
def cache_read(
self, block: Union[BlockRV, str], read_buffer_index: int, storage_scope: str
) -> BlockRV:
"""Create a block that reads a buffer region into a read cache. It requires:
1) There is at most one block who write the buffer in the scope.
Expand All @@ -979,7 +986,7 @@ def cache_read(self, block: BlockRV, read_buffer_index: int, storage_scope: str)
Parameters
----------
block : BlockRV
block : Union[BlockRV, str]
The consumer block of the target buffer.
read_buffer_index: int
Expand Down Expand Up @@ -1036,12 +1043,15 @@ def after_cache_read(a: T.handle, b: T.handle) -> None:
B[vi, vj] = A_local[vi, vj] * 2.0
"""
block = self._normalize_block_arg(block)
return _ffi_api.ScheduleCacheRead( # type: ignore # pylint: disable=no-member
self, block, read_buffer_index, storage_scope
)

@type_checked
def cache_write(self, block: BlockRV, write_buffer_index: int, storage_scope: str) -> BlockRV:
def cache_write(
self, block: Union[BlockRV, str], write_buffer_index: int, storage_scope: str
) -> BlockRV:
"""Create a block that reads a buffer region into a write cache. It requires:
1) There is only one block who write the buffer in the scope.
Expand All @@ -1050,7 +1060,7 @@ def cache_write(self, block: BlockRV, write_buffer_index: int, storage_scope: st
Parameters
----------
block : BlockRV
block : Union[BlockRV, str]
The producer block of the target buffer.
write_buffer_index: int
Expand Down Expand Up @@ -1108,12 +1118,15 @@ def after_cache_write(a: T.handle, b: T.handle) -> None:
B[vi, vj] = B_local[vi, vj]
"""
block = self._normalize_block_arg(block)
return _ffi_api.ScheduleCacheWrite( # type: ignore # pylint: disable=no-member
self, block, write_buffer_index, storage_scope
)

@type_checked
def reindex(self, block: BlockRV, buffer_index: int, buffer_index_type: str) -> BlockRV:
def reindex(
self, block: Union[BlockRV, str], buffer_index: int, buffer_index_type: str
) -> BlockRV:
"""Create a block that read/write a buffer region into a read/write cache with reindexing.
The layout of the cache will be the same as by the iterators of the block that reads/writes
the buffer. It requires:
Expand All @@ -1122,7 +1135,7 @@ def reindex(self, block: BlockRV, buffer_index: int, buffer_index_type: str) ->
Parameters
----------
block: BlockRV
block: Union[BlockRV, str]
The block that accesses the target buffer
buffer_index: int
The index of the buffer in block's read or write region
Expand Down Expand Up @@ -1179,6 +1192,7 @@ def after_reindex(
B[vi, vj] = A_reindex[vi, vj] * 2.0
"""
block = self._normalize_block_arg(block)
assert buffer_index_type in ["read", "write"], "Invalid buffer_index_type"
buffer_index_type_enum = 0 if buffer_index_type == "read" else 1
return _ffi_api.ScheduleReIndex( # type: ignore # pylint: disable=no-member
Expand All @@ -1190,7 +1204,7 @@ def after_reindex(
@type_checked
def compute_at(
self,
block: BlockRV,
block: Union[BlockRV, str],
loop: LoopRV,
preserve_unit_loops: bool = False,
) -> None:
Expand All @@ -1213,7 +1227,7 @@ def compute_at(
Parameters
----------
block : BlockRV
block : Union[BlockRV, str]
The block to be moved
loop: LoopRV
Expand Down Expand Up @@ -1273,6 +1287,7 @@ def after_compute_at(a: T.handle, c: T.handle) -> None:
C[vi, vj] = B[vi, vj] + 1.0
"""
block = self._normalize_block_arg(block)
_ffi_api.ScheduleComputeAt( # type: ignore # pylint: disable=no-member
self,
block,
Expand All @@ -1283,7 +1298,7 @@ def after_compute_at(a: T.handle, c: T.handle) -> None:
@type_checked
def reverse_compute_at(
self,
block: BlockRV,
block: Union[BlockRV, str],
loop: LoopRV,
preserve_unit_loops: bool = False,
) -> None:
Expand All @@ -1303,7 +1318,7 @@ def reverse_compute_at(
Parameters
----------
block : BlockRV
block : Union[BlockRV, str]
The block to be moved
loop: LoopRV
Expand Down Expand Up @@ -1363,6 +1378,7 @@ def after_reverse_compute_at(a: T.handle, c: T.handle) -> None:
C[vi, vj] = B[vi, vj] + 1.0
"""
block = self._normalize_block_arg(block)
_ffi_api.ScheduleReverseComputeAt( # type: ignore # pylint: disable=no-member
self,
block,
Expand All @@ -1371,7 +1387,7 @@ def after_reverse_compute_at(a: T.handle, c: T.handle) -> None:
)

@type_checked
def compute_inline(self, block: BlockRV) -> None:
def compute_inline(self, block: Union[BlockRV, str]) -> None:
"""Inline a block into its consumer(s). It requires:
1) The block is a complete non-root block, which only produces one buffer
Expand All @@ -1386,7 +1402,7 @@ def compute_inline(self, block: BlockRV) -> None:
Parameters
----------
block : BlockRV
block : Union[BlockRV, str]
The block to be inlined to its consumer(s)
Examples
Expand Down Expand Up @@ -1432,10 +1448,11 @@ def after_inline(a: T.handle, c: T.handle) -> None:
C[vi, vj] = A[vi, vj] * 2.0 + 1.0
"""
block = self._normalize_block_arg(block)
_ffi_api.ScheduleComputeInline(self, block) # type: ignore # pylint: disable=no-member

@type_checked
def reverse_compute_inline(self, block: BlockRV) -> None:
def reverse_compute_inline(self, block: Union[BlockRV, str]) -> None:
"""Inline a block into its only producer. It requires:
1) The block is a complete non-root block, which only produces and consumes one buffer
Expand All @@ -1453,7 +1470,7 @@ def reverse_compute_inline(self, block: BlockRV) -> None:
Parameters
----------
block : BlockRV
block : Union[BlockRV, str]
The block to be inlined to its producer
Examples
Expand Down Expand Up @@ -1499,12 +1516,13 @@ def after_inline(a: T.handle, c: T.handle) -> None:
C[vi, vj] = A[vi, vj] * 2.0 + 1.0
"""
block = self._normalize_block_arg(block)
_ffi_api.ScheduleReverseComputeInline(self, block) # type: ignore # pylint: disable=no-member

########## Schedule: Reduction ##########

@type_checked
def decompose_reduction(self, block: BlockRV, loop: LoopRV) -> BlockRV:
def decompose_reduction(self, block: Union[BlockRV, str], loop: LoopRV) -> BlockRV:
"""Decompose a reduction block into two separate blocks.
a) The init block, which is translated from the init statement of the reduction block;
Expand All @@ -1523,7 +1541,7 @@ def decompose_reduction(self, block: BlockRV, loop: LoopRV) -> BlockRV:
Parameters
----------
block : BlockRV
block : Union[BlockRV, str]
The reduction block to be decomposed
loop : LoopRV
The loop above which the init block is inserted before.
Expand Down Expand Up @@ -1578,6 +1596,7 @@ def after_decompose(a: ty.handle, c: ty.handle) -> None:
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
"""
block = self._normalize_block_arg(block)
return _ffi_api.ScheduleDecomposeReduction(self, block, loop) # type: ignore # pylint: disable=no-member

@type_checked
Expand Down Expand Up @@ -1734,7 +1753,7 @@ def after_rfactor(a: T.handle, b: T.handle) -> None:
@type_checked
def storage_align( # pylint: disable=too-many-arguments
self,
block: BlockRV,
block: Union[BlockRV, str],
buffer_index: int,
axis: int,
factor: int,
Expand All @@ -1747,7 +1766,7 @@ def storage_align( # pylint: disable=too-many-arguments
Parameters
----------
block : BlockRV
block : Union[BlockRV, str]
The producer block of the buffer.
buffer_index : int
The index of the buffer in block's write region.
Expand Down Expand Up @@ -1812,18 +1831,19 @@ def after_storage_align(a: T.handle, c: T.handle) -> None:
----
Storage_align requires the buffer to be an intermediate buffer defined via `alloc_buffer`.
"""
block = self._normalize_block_arg(block)
_ffi_api.ScheduleStorageAlign( # type: ignore # pylint: disable=no-member
self, block, buffer_index, axis, factor, offset
)

@type_checked
def set_scope(self, block: BlockRV, buffer_index: int, storage_scope: str) -> None:
def set_scope(self, block: Union[BlockRV, str], buffer_index: int, storage_scope: str) -> None:
"""Set the storage scope of a buffer, where the buffer is
specified by the a block and a write-index
Parameters
----------
block : BlockRV
block : Union[BlockRV, str]
The producer block of the buffer
buffer_index : int
The index of the buffer in block's write region
Expand Down Expand Up @@ -1883,6 +1903,7 @@ def after_set_scope(
----
Set_scope requires the buffer to be an intermediate buffer defined via `alloc_buffer`.
"""
block = self._normalize_block_arg(block)
_ffi_api.ScheduleSetScope( # type: ignore # pylint: disable=no-member
self, block, buffer_index, storage_scope
)
Expand Down Expand Up @@ -2418,14 +2439,14 @@ def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) ->
@type_checked
def transform_block_layout(
self,
block: BlockRV,
block: Union[BlockRV, str],
index_map: Union[IndexMap, Callable],
) -> None:
"""Apply a transformation represented by IndexMap to block
Parameters
----------
block : BlockRV
block : Union[BlockRV, str]
The block to be transformed
index_map : Union[IndexMap, Callable]
Expand Down Expand Up @@ -2470,6 +2491,7 @@ def after_transform_block_layout(
vi, = T.axis.remap("S", [i])
B[vi // 16, vi % 16] = A[vi // 16, vi % 16] * 2.0
"""
block = self._normalize_block_arg(block)
if callable(index_map):
index_map = IndexMap.from_func(index_map)
_ffi_api.ScheduleTransformBlockLayout( # type: ignore # pylint: disable=no-member
Expand Down
Loading

0 comments on commit c6a43f9

Please sign in to comment.