Addressing comments -3
Change-Id: I207872757473210681d9db04bfdcd2c5e6deaa05
Giuseppe Rossini committed Dec 15, 2020
1 parent c59d4b7 commit 881dde0
Showing 2 changed files with 137 additions and 11 deletions.
47 changes: 36 additions & 11 deletions python/tvm/driver/tvmc/autotuner.py
@@ -186,6 +186,16 @@ def add_tune_parser(subparsers):
        default=0,
        help="the thread numbers of a warp",
    )
    auto_scheduler_group.add_argument(
        "--include-simple-tasks",
        help="Whether to extract simple tasks that do not include complicated ops",
        action="store_true",
    )
    auto_scheduler_group.add_argument(
        "--log-estimated-latency",
        help="Whether to log the estimated latency to the file after tuning a task",
        action="store_true",
    )
    auto_tuning_group = parser.add_argument_group(
        "Autotuning options",
        "Autotuning options, used when the autoscheduler is not enabled",
@@ -279,6 +289,7 @@ def drive_tune(args):
            target_host=args.target_host,
            alter_layout=args.desired_layout,
            hardware_params=hardware_params,
            include_simple_tasks=args.include_simple_tasks,
        )

# Create the autoscheduler tuning options
@@ -291,10 +302,7 @@

        # Schedule the tasks (i.e., produce a schedule for each task)
        schedule_tasks(
-            tasks,
-            weights,
-            tuning_options,
-            args.tuning_records,
+            tasks, weights, tuning_options, args.tuning_records, args.log_estimated_latency
        )
    else:
        tasks = autotvm_get_tuning_tasks(
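Taken together, the two flags thread from the CLI into task extraction and scheduling. A hedged end-to-end sketch of the same path through tvmc's Python API (the model path is a placeholder; assumes a TVM build with the auto-scheduler available):

from tvm import auto_scheduler
from tvm.driver import tvmc

# Placeholder model file; any frontend tvmc can load would do
mod, params = tvmc.frontends.load_model("my_model.onnx")

# Mirrors drive_tune with --include-simple-tasks and --log-estimated-latency set
tasks, weights = tvmc.autotuner.autoscheduler_get_tuning_tasks(
    mod, params, "llvm", include_simple_tasks=True
)
tuning_options = auto_scheduler.TuningOptions(num_measure_trials=64)
tvmc.autotuner.schedule_tasks(
    tasks, weights, tuning_options, tuning_records=None, log_estimated_latency=True
)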
@@ -356,7 +364,13 @@ def autotvm_get_tuning_tasks(mod, params, target, target_host=None, alter_layout


def autoscheduler_get_tuning_tasks(
-    mod, params, target, target_host=None, alter_layout=None, hardware_params=None
+    mod,
+    params,
+    target,
+    target_host=None,
+    alter_layout=None,
+    hardware_params=None,
+    include_simple_tasks=False,
):
    """Get the autoscheduler tuning tasks for a given relay module.
@@ -389,17 +403,19 @@ def autoscheduler_get_tuning_tasks(

    # Extract the tasks
    tasks, task_weights = auto_scheduler.extract_tasks(
-        mod["main"], params, target=target, target_host=target_host, hardware_params=hardware_params
+        mod["main"],
+        params,
+        target=target,
+        target_host=target_host,
+        hardware_params=hardware_params,
+        include_simple_tasks=include_simple_tasks,
    )

    return tasks, task_weights
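
The returned tasks pair one-to-one with their weights (how often each subgraph occurs in the module); per the help text above, include_simple_tasks=True additionally extracts the simple subgraphs that would otherwise be skipped. A small sketch of inspecting the result, assuming tasks and task_weights from the call above:

for task, weight in zip(tasks, task_weights):
    # workload_key identifies the extracted subgraph; weight is its occurrence count
    print(task.workload_key, weight)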


def schedule_tasks(
-    tasks,
-    task_weights,
-    tuning_options,
-    tuning_records=None,
+    tasks, task_weights, tuning_options, tuning_records=None, log_estimated_latency=False
):
    """Generate the schedules for the different tasks (i.e., subgraphs) contained in the module.
    Store the schedules in a json file that will be used later by the compiler.
@@ -415,9 +431,18 @@
    tuning_records : str, optional
        The json file used to preload the autoscheduler
    log_estimated_latency : bool, optional
        Whether to log the estimated latency to 'total_latency.tsv' after tuning each task
    """
    if not log_estimated_latency:
        callbacks = [auto_scheduler.task_scheduler.PrintTableInfo()]
    else:
        callbacks = [
            auto_scheduler.task_scheduler.PrintTableInfo(),
            auto_scheduler.task_scheduler.LogEstimatedLatency(("total_latency.tsv")),
        ]

    # Create the scheduler
-    tuner = auto_scheduler.TaskScheduler(tasks, task_weights, load_log_file=tuning_records)
+    tuner = auto_scheduler.TaskScheduler(
+        tasks, task_weights, load_log_file=tuning_records, callbacks=callbacks
+    )

    # Tune the tasks
    tuner.tune(tuning_options)
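With log_estimated_latency enabled, the extra LogEstimatedLatency callback writes the task scheduler's running latency estimate to total_latency.tsv alongside the usual progress table. A hedged usage sketch (single measurement trial, local builder/runner, mirroring the tests below; assumes tasks and weights from an earlier extraction):

from tvm import auto_scheduler

tuning_options = auto_scheduler.TuningOptions(
    num_measure_trials=1,
    measure_callbacks=[auto_scheduler.RecordToFile("autoscheduler.json")],
    builder="local",
    runner="local",
)
schedule_tasks(tasks, weights, tuning_options, log_estimated_latency=True)
# Best schedules land in autoscheduler.json; latency estimates in total_latency.tsv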
101 changes: 101 additions & 0 deletions tests/python/driver/tvmc/test_autoscheduler.py
@@ -0,0 +1,101 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
import pytest
import os
import tarfile

from os import path

from tvm import auto_scheduler
from tvm.driver import tvmc


def _get_tasks(model):
    mod, params = tvmc.frontends.load_model(model)
    tasks, weights = tvmc.autotuner.autoscheduler_get_tuning_tasks(mod, params, "llvm")
    return (tasks, weights)


def _autoscheduler_test_helper(
    model, tmpdir_name, tasks_weights=None, early_stopping=1, tuning_records=None
):
    tasks, weights = tasks_weights if tasks_weights else _get_tasks(model)
    log_file = os.path.join(tmpdir_name, "autoscheduler.json")

    tuning_options = auto_scheduler.TuningOptions(
        num_measure_trials=1,
        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
        runner="local",
        builder="local",
        verbose=0,
        early_stopping=early_stopping,
    )

    tvmc.autotuner.schedule_tasks(tasks[:1], weights[:1], tuning_options, tuning_records)

    # testing whether the log file was produced
    assert path.exists(log_file), "autoscheduler log file should exist"

    with auto_scheduler.ApplyHistoryBest(log_file) as best:
        assert isinstance(
            best, auto_scheduler.dispatcher.ApplyHistoryBest
        ), "unable to load the best results of tuning"

    return log_file


def test_get_tuning_tasks(onnx_resnet50):
    pytest.importorskip("onnx")

    tasks, weights = _get_tasks(onnx_resnet50)
    expected_task_type = auto_scheduler.SearchTask

    assert type(tasks) is list
    assert len(tasks) > 0
    assert all([type(x) is expected_task_type for x in tasks]) is True


def test_tune_tasks(onnx_resnet50, tmpdir_factory):
    pytest.importorskip("onnx")

    tmpdir_name = tmpdir_factory.mktemp("data")
    _autoscheduler_test_helper(onnx_resnet50, tmpdir_name)


def test_tune_tasks__tuning_records(onnx_resnet50, tmpdir_factory):
    pytest.importorskip("onnx")

    tmpdir_name = tmpdir_factory.mktemp("data")
    output_log_phase_1 = _autoscheduler_test_helper(onnx_resnet50, tmpdir_name)

    # Exercises transfer learning by making sure a previous log exists
    _autoscheduler_test_helper(onnx_resnet50, tmpdir_name, tuning_records=output_log_phase_1)


def test_tune_tasks__no_early_stopping(onnx_resnet50, tmpdir_factory):
    pytest.importorskip("onnx")

    tmpdir_name = tmpdir_factory.mktemp("data")
    _autoscheduler_test_helper(onnx_resnet50, tmpdir_name, tasks_weights=None, early_stopping=None)


def test_tune_tasks__no_tuning_records(onnx_resnet50, tmpdir_factory):
    pytest.importorskip("onnx")

    tmpdir_name = tmpdir_factory.mktemp("data")
    _autoscheduler_test_helper(onnx_resnet50, tmpdir_name, tasks_weights=None, tuning_records=None)
