Skip to content
This repository has been archived by the owner on Nov 25, 2022. It is now read-only.

Commit

Permalink
[TIR][Bugfix] Correct handling of buffer argument when scheduling (ap…
Browse files Browse the repository at this point in the history
…ache#12816)

Follow-up from apache#11269, which allowed
schedule arguments of the buffer to be transformed to be specified as
a string, or as a `tir::Buffer`.  The string handling worked
correctly, but the `tir::Buffer` object was handled incorrectly.  This
commit corrects the handling of `tir::Buffer` arguments when
scheduling, and adds a unit test to validate this behavior.
  • Loading branch information
Lunderberg authored and xinetzone committed Nov 25, 2022
1 parent f2d6043 commit d9d4981
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 15 deletions.
6 changes: 3 additions & 3 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2390,20 +2390,20 @@ def iter_buffers():
if isinstance(buffer, str):
possible_buffers = {}
# String lookup requires ensuring that the name is unique
for buffer_index, buffer_index_type, buf in iter_buffers():
for buffer_index_type, buffer_index, buf in iter_buffers():
if buf.name == buffer:
possible_buffers[buf] = (buffer_index_type, buffer_index)

assert possible_buffers, f"Could not find buffer '{buffer}' in block '{block_name}'"
assert (
len(possible_buffers) == 1
), f"Multiple buffers named '{buffer}' in block '{block_name}'"
buffer_obj, (buffer_index, buffer_index_type) = next(iter(possible_buffers.items()))
buffer_obj, (buffer_index_type, buffer_index) = next(iter(possible_buffers.items()))

elif isinstance(buffer, Buffer):
# Buffer lookup has unique id, can break out early
found = False
for buffer_index, buffer_index_type, buffer_obj in iter_buffers():
for buffer_index_type, buffer_index, buffer_obj in iter_buffers():
if buffer_obj.same_as(buffer):
found = True
break
Expand Down
41 changes: 29 additions & 12 deletions tests/python/unittest/test_tir_schedule_set_axis_separator.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,18 +102,25 @@ def element_wise_subregion_match_set_axis_separator(A: T.Buffer[(128, 128), "flo

# pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg

use_sugared_transform = tvm.testing.parameter(
by_dict={"set_axis_separators": False, "transform_layout_sugared": True}
)
argument_style = tvm.testing.parameter('set_axis_separators',
'transform_layout_named',
'transform_layout_buffer_object',
)

def test_set_axis_separator(use_sugared_transform):

def test_set_axis_separator(argument_style):
func = element_wise
s = tir.Schedule(func, debug_mask='all')

if use_sugared_transform:
if argument_style=='set_axis_separators':
s.set_axis_separator(s.get_block("B"), ("write",0), [1])
else:
elif argument_style=='transform_layout_named':
s.transform_layout(block='B', buffer='B', index_map=lambda i,j: [i,IndexMap.AXIS_SEPARATOR,j])
elif argument_style =='transform_layout_buffer_object':
B = s.get(s.get_block('B')).writes[0].buffer
s.transform_layout(block='B', buffer=B, index_map=lambda i,j: [i,IndexMap.AXIS_SEPARATOR,j])
else:
raise ValueError(f'Unexpected argument_style: {argument_style}')

tvm.ir.assert_structural_equal(element_wise_set_axis_separator, s.mod["main"])
verify_trace_roundtrip(sch=s, mod=func)
Expand All @@ -128,28 +135,38 @@ def test_set_scope_fail_on_index_out_of_bound():
s.set_axis_separator(s.get_block("B"), ("read",-1),[1])


def test_set_axis_separator_input_buffer(use_sugared_transform):
def test_set_axis_separator_input_buffer(argument_style):
func = element_wise
s = tir.Schedule(func, debug_mask='all')

if use_sugared_transform:
if argument_style=='set_axis_separators':
s.set_axis_separator(s.get_block("B"), ("read",0), [1])
elif argument_style=='transform_layout_named':
s.transform_layout(block='B', buffer='A', index_map=lambda i,j: [i,IndexMap.AXIS_SEPARATOR,j])
elif argument_style =='transform_layout_buffer_object':
A = s.get(s.get_block('B')).reads[0].buffer
s.transform_layout(block='B', buffer=A, index_map=lambda i,j: [i,IndexMap.AXIS_SEPARATOR,j])
else:
s.set_axis_separator(s.get_block("B"), ("read",0), [1])
raise ValueError(f'Unexpected argument_style: {argument_style}')


tvm.ir.assert_structural_equal(element_wise_set_axis_separator_input_buffer, s.mod["main"])
verify_trace_roundtrip(sch=s, mod=func)


def test_set_axis_separator_subregion(use_sugared_transform):
def test_set_axis_separator_subregion(argument_style):
func = element_wise_subregion_match
s = tir.Schedule(func, debug_mask='all')

if use_sugared_transform:
if argument_style=='set_axis_separators':
s.set_axis_separator(s.get_block("B"), ("write",0), [1])
elif argument_style=='transform_layout_named':
s.transform_layout(block='B', buffer='B', index_map=lambda i,j: [i,IndexMap.AXIS_SEPARATOR,j])
elif argument_style =='transform_layout_buffer_object':
B = s.get(s.get_block('B')).writes[0].buffer
s.transform_layout(block='B', buffer=B, index_map=lambda i,j: [i,IndexMap.AXIS_SEPARATOR,j])
else:
s.set_axis_separator(s.get_block("B"), ("write",0), [1])
raise ValueError(f'Unexpected argument_style: {argument_style}')

tvm.ir.assert_structural_equal(element_wise_subregion_match_set_axis_separator, s.mod["main"])
verify_trace_roundtrip(sch=s, mod=func)
Expand Down

0 comments on commit d9d4981

Please sign in to comment.