Skip to content

Commit

Permalink
[TIR][Schedule] Allow named block and buffer arguments in Schedule (#…
Browse files Browse the repository at this point in the history
…11624)

* [Schedule] Allowed string argument as block arg

This has previously been implemented for `Schedule.transform_layout`
in #11296, extending to allow for
block arguments in all `Schedule` methods.

This change was only made for arguments that must be a `BlockRV`.  For
arguments that may be either a `BlockRV` or another
type (e.g. `Schedule.get_child_blocks` accepts either `BlockRV` or
`LoopRV`), this sugar is not implemented, to avoid ambiguity.

* [Schedule] Allowed string argument to Schedule.reindex

Similar to #11269, which added this
functionality to `Schedule.transform_layout`.

* CI test update
  • Loading branch information
Lunderberg committed Jun 9, 2022
1 parent 81b42e6 commit af01281
Show file tree
Hide file tree
Showing 12 changed files with 291 additions and 227 deletions.
112 changes: 76 additions & 36 deletions python/tvm/tir/schedule/schedule.py

Large diffs are not rendered by default.

9 changes: 4 additions & 5 deletions src/tir/schedule/primitive/cache_read_write.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1241,11 +1241,10 @@ struct ReIndexTraits : public UnpackedInstTraits<ReIndexTraits> {
Integer buffer_index_type) {
PythonAPICall py("reindex");
py.Input("block", block);
py.Input("buffer_index", buffer_index);
py.Input("buffer_index_type", '"' +
std::string(BufferIndexType2Str(
static_cast<BufferIndexType>(buffer_index_type->value))) +
'"');
std::ostringstream os;
os << "(\"" << BufferIndexType2Str(static_cast<BufferIndexType>(buffer_index_type->value))
<< "\", " << buffer_index << ")";
py.Input("buffer", os.str());
py.SingleOutput(outputs);
return py.Str();
}
Expand Down
94 changes: 48 additions & 46 deletions tests/python/unittest/test_tir_schedule_cache_read_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,13 +741,15 @@ def block_predicate_cache_write_output_buf() -> None:

########## Testcases for cache_read ##########

use_block_name = tvm.testing.parameter(by_dict={"block_obj": False, "block_name": True})

def test_cache_read_elementwise():

def test_cache_read_elementwise(use_block_name):
sch = tir.Schedule(elementwise, debug_mask="all")
block_b = sch.get_block("B")
block_c = sch.get_block("C")
cached_a = sch.cache_read(block_b, 0, "global")
cached_b = sch.cache_read(block_c, 0, "local")
cached_a = sch.cache_read("B" if use_block_name else block_b, 0, "global")
cached_b = sch.cache_read("C" if use_block_name else block_c, 0, "local")
assert sch.get(cached_a) == sch.get(sch.get_block("A_global"))
assert sch.get(cached_b) == sch.get(sch.get_block("B_local"))
assert sch.get(block_b) == sch.get(sch.get_block("B"))
Expand All @@ -756,87 +758,87 @@ def test_cache_read_elementwise():
verify_trace_roundtrip(sch=sch, mod=elementwise)


def test_cache_read_under_scope():
def test_cache_read_under_scope(use_block_name):
sch = tir.Schedule(access_under_scope, debug_mask="all")
block_b = sch.get_block("B")
block_c = sch.get_block("C")
block_b = "B" if use_block_name else sch.get_block("B")
block_c = "C" if use_block_name else sch.get_block("C")
sch.cache_read(block_b, 0, "local")
sch.cache_read(block_c, 0, "global")
tvm.ir.assert_structural_equal(cache_read_under_scope, sch.mod["main"])
verify_trace_roundtrip(sch=sch, mod=access_under_scope)


def test_cache_read_opaque_access():
def test_cache_read_opaque_access(use_block_name):
sch = tir.Schedule(opaque_access, debug_mask="all")
block = sch.get_block("load_store")
block = "load_store" if use_block_name else sch.get_block("load_store")
sch.cache_read(block, 0, "global")
tvm.ir.assert_structural_equal(cache_read_opaque_access, sch.mod["main"])
verify_trace_roundtrip(sch=sch, mod=opaque_access)


def test_cache_read_location():
def test_cache_read_location(use_block_name):
sch = tir.Schedule(func_multi_consumer, debug_mask="all")
block_b = sch.get_block("B")
block_b = "B" if use_block_name else sch.get_block("B")
sch.cache_read(block_b, 0, "global")
tvm.ir.assert_structural_equal(cache_read_multi_consumer, sch.mod["main"])
verify_trace_roundtrip(sch=sch, mod=func_multi_consumer)


def test_continuous_cache_read():
def test_continuous_cache_read(use_block_name):
sch = tir.Schedule(elementwise, debug_mask="all")
block_c = sch.get_block("C")
block_c = "C" if use_block_name else sch.get_block("C")
sch.cache_read(block_c, 0, "shared")
sch.cache_read(block_c, 0, "local")
tvm.ir.assert_structural_equal(continuous_cache_read, sch.mod["main"])
verify_trace_roundtrip(sch=sch, mod=elementwise)


def test_cache_read_with_block_predicate():
def test_cache_read_with_block_predicate(use_block_name):
sch = tir.Schedule(func_with_block_predicate, debug_mask="all")
block = sch.get_block("consumer")
block = "consumer" if use_block_name else sch.get_block("consumer")
sch.cache_read(block, 0, "shared")
tvm.ir.assert_structural_equal(block_predicate_cache_read, sch.mod["main"])
verify_trace_roundtrip(sch=sch, mod=func_with_block_predicate)


def test_cache_read_non_int32_shape():
def test_cache_read_non_int32_shape(use_block_name):
sch = tir.Schedule(elementwise_shape_int64, debug_mask="all")
block_b = sch.get_block("B")
block_b = "B" if use_block_name else sch.get_block("B")
sch.cache_read(block_b, 0, "global")
tvm.ir.assert_structural_equal(cache_read_shape_int64, sch.mod["main"])
verify_trace_roundtrip(sch=sch, mod=elementwise_shape_int64)


def test_cache_read_fail_multi_producer():
def test_cache_read_fail_multi_producer(use_block_name):
sch = tir.Schedule(func_multi_producer, debug_mask="all")
block_b = sch.get_block("B")
block_b = "B" if use_block_name else sch.get_block("B")
with pytest.raises(tvm.tir.ScheduleError):
sch.cache_read(block_b, 0, "global")


def test_cache_read_fail_index_out_of_bound():
def test_cache_read_fail_index_out_of_bound(use_block_name):
sch = tir.Schedule(elementwise, debug_mask="all")
block_b = sch.get_block("B")
block_b = "B" if use_block_name else sch.get_block("B")
with pytest.raises(tvm.tir.ScheduleError):
sch.cache_read(block_b, 1, "global")


def test_cache_read_fail_invalid_storage_scope():
def test_cache_read_fail_invalid_storage_scope(use_block_name):
sch = tir.Schedule(elementwise, debug_mask="all")
block_b = sch.get_block("B")
block_b = "B" if use_block_name else sch.get_block("B")
with pytest.raises(tvm.tir.ScheduleError):
sch.cache_read(block_b, 0, "test_scope")


########## Testcases for cache_write ##########


def test_cache_write_elementwise():
def test_cache_write_elementwise(use_block_name):
sch = tir.Schedule(elementwise, debug_mask="all")
block_b = sch.get_block("B")
block_c = sch.get_block("C")
cached_b = sch.cache_write(block_b, 0, "local")
cached_c = sch.cache_write(block_c, 0, "global")
cached_b = sch.cache_write("B" if use_block_name else block_b, 0, "local")
cached_c = sch.cache_write("C" if use_block_name else block_c, 0, "global")
assert sch.get(cached_b) == sch.get(sch.get_block("B_local"))
assert sch.get(cached_c) == sch.get(sch.get_block("C_global"))
assert sch.get(block_b) == sch.get(sch.get_block("B"))
Expand All @@ -845,10 +847,10 @@ def test_cache_write_elementwise():
verify_trace_roundtrip(sch=sch, mod=elementwise)


def test_cache_write_under_scope():
def test_cache_write_under_scope(use_block_name):
sch = tir.Schedule(access_under_scope, debug_mask="all")
block_a = sch.get_block("A")
block_b = sch.get_block("B")
block_a = "A" if use_block_name else sch.get_block("A")
block_b = "B" if use_block_name else sch.get_block("B")
block_scope = sch.get_block("scope")
sch.cache_write(block_a, 0, "local")
sch.cache_write(block_b, 0, "global")
Expand All @@ -857,70 +859,70 @@ def test_cache_write_under_scope():
verify_trace_roundtrip(sch=sch, mod=access_under_scope)


def test_cache_write_opaque_access():
def test_cache_write_opaque_access(use_block_name):
sch = tir.Schedule(opaque_access, debug_mask="all")
block_store = sch.get_block("load_store")
block_opaque = sch.get_block("opaque")
block_match_buffer = sch.get_block("match_buffer")
block_store = "load_store" if use_block_name else sch.get_block("load_store")
block_opaque = "opaque" if use_block_name else sch.get_block("opaque")
block_match_buffer = "match_buffer" if use_block_name else sch.get_block("match_buffer")
sch.cache_write(block_store, 0, "global")
sch.cache_write(block_opaque, 0, "global")
sch.cache_write(block_match_buffer, 0, "global")
tvm.ir.assert_structural_equal(cache_write_opaque_access, sch.mod["main"])
verify_trace_roundtrip(sch=sch, mod=opaque_access)


def test_cache_write_location():
def test_cache_write_location(use_block_name):
sch = tir.Schedule(func_multi_consumer, debug_mask="all")
block_a = sch.get_block("A")
block_a = "A" if use_block_name else sch.get_block("A")
sch.cache_write(block_a, 0, "global")
tvm.ir.assert_structural_equal(cache_write_multi_consumer, sch.mod["main"])
verify_trace_roundtrip(sch=sch, mod=func_multi_consumer)


def test_continuous_cache_write():
def test_continuous_cache_write(use_block_name):
sch = tir.Schedule(elementwise, debug_mask="all")
block_b = sch.get_block("B")
block_b = "B" if use_block_name else sch.get_block("B")
sch.cache_write(block_b, 0, "shared")
sch.cache_write(block_b, 0, "local")
tvm.ir.assert_structural_equal(continuous_cache_write, sch.mod["main"])
verify_trace_roundtrip(sch=sch, mod=elementwise)


def test_cache_write_with_block_predicate():
def test_cache_write_with_block_predicate(use_block_name):
# cache write for intermediate buffer
sch = tir.Schedule(func_with_block_predicate, debug_mask="all")
block = sch.get_block("producer")
block = "producer" if use_block_name else sch.get_block("producer")
sch.cache_write(block, 0, "shared")
tvm.ir.assert_structural_equal(block_predicate_cache_write_intermediate_buf, sch.mod["main"])
verify_trace_roundtrip(sch=sch, mod=func_with_block_predicate)
# cache write for external buffer
sch = tir.Schedule(func_with_block_predicate, debug_mask="all")
block = sch.get_block("consumer")
block = "consumer" if use_block_name else sch.get_block("consumer")
sch.cache_write(block, 0, "shared")
tvm.ir.assert_structural_equal(block_predicate_cache_write_output_buf, sch.mod["main"])
verify_trace_roundtrip(sch=sch, mod=func_with_block_predicate)


def test_cache_write_fail_multi_producer():
def test_cache_write_fail_multi_producer(use_block_name):
sch = tir.Schedule(func_multi_producer, debug_mask="all")
block_a0 = sch.get_block("A0")
block_a1 = sch.get_block("A1")
block_a0 = "A0" if use_block_name else sch.get_block("A0")
block_a1 = "A1" if use_block_name else sch.get_block("A1")
with pytest.raises(tvm.tir.ScheduleError):
sch.cache_write(block_a0, 0, "global")
with pytest.raises(tvm.tir.ScheduleError):
sch.cache_write(block_a1, 0, "global")


def test_cache_write_fail_index_out_of_bound():
def test_cache_write_fail_index_out_of_bound(use_block_name):
sch = tir.Schedule(elementwise, debug_mask="all")
block_b = sch.get_block("B")
block_b = "B" if use_block_name else sch.get_block("B")
with pytest.raises(tvm.tir.ScheduleError):
sch.cache_write(block_b, 1, "global")


def test_cache_write_fail_invalid_storage_scope():
def test_cache_write_fail_invalid_storage_scope(use_block_name):
sch = tir.Schedule(elementwise, debug_mask="all")
block_b = sch.get_block("B")
block_b = "B" if use_block_name else sch.get_block("B")
with pytest.raises(tvm.tir.ScheduleError):
sch.cache_write(block_b, 0, "test_scope")

Expand Down
Loading

0 comments on commit af01281

Please sign in to comment.