From d9d498110f8e5d996e89f5611439772d4c655ccb Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 19 Sep 2022 13:02:53 -0500 Subject: [PATCH] [TIR][Bugfix] Correct handling of buffer argument when scheduling (#12816) Follow-up from https://github.com/apache/tvm/pull/11269, which allowed schedule arguments of the buffer to be transformed to be specified as a string, or as a `tir::Buffer`. The string handling worked correctly, but the `tir::Buffer` object was handled incorrectly. This commit corrects the handling of `tir::Buffer` arguments when scheduling, and adds a unit test to validate this behavior. --- python/tvm/tir/schedule/schedule.py | 6 +-- .../test_tir_schedule_set_axis_separator.py | 41 +++++++++++++------ 2 files changed, 32 insertions(+), 15 deletions(-) diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index b8f696b7a134..27171aca411b 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -2390,7 +2390,7 @@ def iter_buffers(): if isinstance(buffer, str): possible_buffers = {} # String lookup requires ensuring that the name is unique - for buffer_index, buffer_index_type, buf in iter_buffers(): + for buffer_index_type, buffer_index, buf in iter_buffers(): if buf.name == buffer: possible_buffers[buf] = (buffer_index_type, buffer_index) @@ -2398,12 +2398,12 @@ def iter_buffers(): assert ( len(possible_buffers) == 1 ), f"Multiple buffers named '{buffer}' in block '{block_name}'" - buffer_obj, (buffer_index, buffer_index_type) = next(iter(possible_buffers.items())) + buffer_obj, (buffer_index_type, buffer_index) = next(iter(possible_buffers.items())) elif isinstance(buffer, Buffer): # Buffer lookup has unique id, can break out early found = False - for buffer_index, buffer_index_type, buffer_obj in iter_buffers(): + for buffer_index_type, buffer_index, buffer_obj in iter_buffers(): if buffer_obj.same_as(buffer): found = True break diff --git a/tests/python/unittest/test_tir_schedule_set_axis_separator.py b/tests/python/unittest/test_tir_schedule_set_axis_separator.py index b432fbb61066..327df33408f2 100644 --- a/tests/python/unittest/test_tir_schedule_set_axis_separator.py +++ b/tests/python/unittest/test_tir_schedule_set_axis_separator.py @@ -102,18 +102,25 @@ def element_wise_subregion_match_set_axis_separator(A: T.Buffer[(128, 128), "flo # pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg -use_sugared_transform = tvm.testing.parameter( - by_dict={"set_axis_separators": False, "transform_layout_sugared": True} -) +argument_style = tvm.testing.parameter('set_axis_separators', + 'transform_layout_named', + 'transform_layout_buffer_object', + ) -def test_set_axis_separator(use_sugared_transform): + +def test_set_axis_separator(argument_style): func = element_wise s = tir.Schedule(func, debug_mask='all') - if use_sugared_transform: + if argument_style=='set_axis_separators': s.set_axis_separator(s.get_block("B"), ("write",0), [1]) - else: + elif argument_style=='transform_layout_named': s.transform_layout(block='B', buffer='B', index_map=lambda i,j: [i,IndexMap.AXIS_SEPARATOR,j]) + elif argument_style =='transform_layout_buffer_object': + B = s.get(s.get_block('B')).writes[0].buffer + s.transform_layout(block='B', buffer=B, index_map=lambda i,j: [i,IndexMap.AXIS_SEPARATOR,j]) + else: + raise ValueError(f'Unexpected argument_style: {argument_style}') tvm.ir.assert_structural_equal(element_wise_set_axis_separator, s.mod["main"]) verify_trace_roundtrip(sch=s, mod=func) @@ -128,28 +135,38 @@ def test_set_scope_fail_on_index_out_of_bound(): s.set_axis_separator(s.get_block("B"), ("read",-1),[1]) -def test_set_axis_separator_input_buffer(use_sugared_transform): +def test_set_axis_separator_input_buffer(argument_style): func = element_wise s = tir.Schedule(func, debug_mask='all') - if use_sugared_transform: + if argument_style=='set_axis_separators': + s.set_axis_separator(s.get_block("B"), ("read",0), [1]) + elif argument_style=='transform_layout_named': s.transform_layout(block='B', buffer='A', index_map=lambda i,j: [i,IndexMap.AXIS_SEPARATOR,j]) + elif argument_style =='transform_layout_buffer_object': + A = s.get(s.get_block('B')).reads[0].buffer + s.transform_layout(block='B', buffer=A, index_map=lambda i,j: [i,IndexMap.AXIS_SEPARATOR,j]) else: - s.set_axis_separator(s.get_block("B"), ("read",0), [1]) + raise ValueError(f'Unexpected argument_style: {argument_style}') tvm.ir.assert_structural_equal(element_wise_set_axis_separator_input_buffer, s.mod["main"]) verify_trace_roundtrip(sch=s, mod=func) -def test_set_axis_separator_subregion(use_sugared_transform): +def test_set_axis_separator_subregion(argument_style): func = element_wise_subregion_match s = tir.Schedule(func, debug_mask='all') - if use_sugared_transform: + if argument_style=='set_axis_separators': + s.set_axis_separator(s.get_block("B"), ("write",0), [1]) + elif argument_style=='transform_layout_named': s.transform_layout(block='B', buffer='B', index_map=lambda i,j: [i,IndexMap.AXIS_SEPARATOR,j]) + elif argument_style =='transform_layout_buffer_object': + B = s.get(s.get_block('B')).writes[0].buffer + s.transform_layout(block='B', buffer=B, index_map=lambda i,j: [i,IndexMap.AXIS_SEPARATOR,j]) else: - s.set_axis_separator(s.get_block("B"), ("write",0), [1]) + raise ValueError(f'Unexpected argument_style: {argument_style}') tvm.ir.assert_structural_equal(element_wise_subregion_match_set_axis_separator, s.mod["main"]) verify_trace_roundtrip(sch=s, mod=func)