
Commit

add more comment
masahi committed Mar 28, 2022
1 parent 183b4cf commit eeb4a6d
Showing 1 changed file with 58 additions and 12 deletions.

tests/python/unittest/test_meta_schedule_tune_relay.py
@@ -25,6 +25,7 @@
from tvm import relay, tir
from tvm.contrib import graph_executor
from tvm.ir import IRModule
from tvm.tir.schedule import BlockRV, Schedule
from tvm.tir.schedule.trace import Trace
from tvm.meta_schedule import ReplayTraceConfig
from tvm.meta_schedule.database import PyDatabase, TuningRecord, Workload, JSONDatabase
@@ -325,6 +326,10 @@ def get_output(data, lib):
assert np.allclose(actual_output, expected_output, rtol=1e-4, atol=2e-4)


# Tensorized intrinsic description and VNNI-specific implementation.
# Equivalent to the ones in topi/x86/tensor_intrin.py


@T.prim_func
def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
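# The rest of the description is collapsed in this view. Given the buffer shapes
# (A: (4,) uint8, B: (16, 4) int8, C: (16,) int32, as in the implementation below),
# it is expected to spell out sixteen 4-element dot products:
#     C[i] += cast(int32, A[k]) * cast(int32, B[i, k])  for i in 0..16, k in 0..4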
@@ -344,7 +349,7 @@ def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None:


@T.prim_func
-def dot_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
def dot_product_vnni(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
C = T.match_buffer(c, (16,), "int32", offset_factor=1)
@@ -372,10 +377,16 @@ def dot_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
VNNI_INTRIN = "dot_16x1x16_uint8_int8_int32_cascadelake"


-def schedule_dense(block, M, do_tune, sch):
-post_blocks = sch.get_consumers(block)
def schedule_dense(dense_block, M, do_tune, sch):
"""
Manually schedule a dense block, created from a TE compute op via CreatePrimFunc,
using the VNNI instruction.
"""
post_blocks = sch.get_consumers(dense_block)

if len(post_blocks) > 0:
# Fuse all intermediate post ops into the last op.
# This is equivalent to the traverse_inline function used in TE schedules.
while True:
next_post_blocks = []
for post_block in post_blocks:
@@ -394,8 +405,8 @@ def schedule_dense(block, M, do_tune, sch):

post_blocks = next_post_blocks
else:
-a_y, a_x, _ = sch.get_loops(block)[-3:]
-outer_block = block
a_y, a_x, _ = sch.get_loops(dense_block)[-3:]
outer_block = dense_block

if do_tune:
y_factors = sch.sample_perfect_tile(a_y, n=2, max_innermost_factor=128)
@@ -407,16 +418,19 @@
sch.reorder(a_yo, a_xo, a_yi, a_xi)
fused = sch.fuse(a_yo, a_xo)

-if outer_block != block:
if outer_block != dense_block:
# Handle the case when dense is fused with post ops.
sch.vectorize(a_xi)
-sch.compute_at(block, a_yi)
sch.compute_at(dense_block, a_yi)

-a_xi, a_k = sch.get_loops(block)[-2:]
a_xi, a_k = sch.get_loops(dense_block)[-2:]
a_ko, a_ki = sch.split(a_k, factors=[None, 4])
sch.reorder(a_ko, a_xi, a_ki)

# We need to parallelize before decompose_reduction, otherwise the so-called "Compact dataflow"
# condition is violated.
sch.parallel(fused)
-dec = sch.decompose_reduction(block, a_ko)
dec = sch.decompose_reduction(dense_block, a_ko)

init_loop = sch.get_loops(dec)[-1]
sch.vectorize(init_loop)
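# Not shown in this hunk: the 16x4 loop nest prepared above is then expected to be
# mapped onto the registered intrinsic. A minimal sketch, assuming a_xi is the
# 16-lane output loop produced by the split above:
#
#     sch.tensorize(a_xi, VNNI_INTRIN)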
@@ -462,6 +476,7 @@ def manual_tir_common(do_tune=False):

extracted_tasks = extract_task_from_relay(relay_mod, target, params)

# Filter out tasks that we don't intend to schedule / tune with TIR.
tune_tasks = list(
filter(
lambda task: "dense" in task.task_name,
@@ -475,6 +490,8 @@
num_trials_per_iter=64,
num_trials_total=64,
)
# postprocs=lambda: [] is important to prevent default post processors from
# tampering with the manual schedule.
database = tune_extracted_tasks(
tune_tasks, target, config, work_dir=work_dir, postprocs=lambda: []
)
@@ -490,11 +507,15 @@ def manual_tir_common(do_tune=False):

sch = tvm.tir.Schedule(mod)
block = sch.get_block("compute")

# Look up the schedule_rule annotation. See the comment in test_tune_relay_manual_tir_vnni().
schedule_rule = sch.get(block).annotations["schedule_rule"]

if "dense_vnni" in schedule_rule:
schedule_dense(block, M, False, sch)

# [0.0] is a dummy measurement value. There is only one tuning record, so ApplyHistoryBest
# will always have only one option.
tune_rec = TuningRecord(sch.trace, [0.0], workload, tvm.target.Target(target), [])

database.commit_tuning_record(tune_rec)
@@ -528,12 +549,37 @@

@pytest.mark.skip("Requires cascadelake")
def test_tune_relay_manual_tir_vnni():
-tir.TensorIntrin.register(VNNI_INTRIN, dot_product_desc, dot_product_intrin)
# Register the pairing of an intrinsic description for the 16x4 dot product with its
# VNNI-specific implementation.
tir.TensorIntrin.register(VNNI_INTRIN, dot_product_desc, dot_product_vnni)

manual_tir_common(do_tune=False)

-def schedule_rule_dense_vnni(sch, block):
-schedule_dense(block, None, True, sch)
"""
We can inject and apply a custom TIR schedule to a TE compute of interest, using
the "schedule_rule" annotation. For example, in topi/x86/dense.py we have the following
declaration for int8 dense targeting the VNNI instruction:
C = te.compute(
...
attrs={"schedule_rule": "meta_schedule.dense_vnni"},
)
When the meta scheduler encounters a TensorIR block with the "schedule_rule" annotation,
it looks up the packed func registry for a function associated with the given schedule rule
key ("meta_schedule.dense_vnni" in this example). The signature of such custom schedule functions
must be
(tir.schedule.Schedule, tir.schedule.BlockRV) -> [tir.schedule.Schedule].
The BlockRV argument corresponds to the TE compute annotated with "schedule_rule".
The relevant code is in meta_schedule/space_generator/post_order_apply.cc.
"""

def schedule_rule_dense_vnni(sch: Schedule, dense_block: BlockRV):
schedule_dense(dense_block, None, True, sch)
return [sch]

register_func("meta_schedule.dense_vnni", schedule_rule_dense_vnni)
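# For context, a minimal sketch of the TE side that such a rule matches on. The
# names, shapes, and dtypes here are illustrative assumptions, not taken from this
# test; only the "schedule_rule" attribute key is significant:
#
#     k = te.reduce_axis((0, K), "k")
#     C = te.compute(
#         (M, N),
#         lambda i, j: te.sum(
#             A[i, k].astype("int32") * B[j, k].astype("int32"), axis=k
#         ),
#         attrs={"schedule_rule": "meta_schedule.dense_vnni"},
#     )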
