From 340753fa70e3f6bba6dddbb0a08d58ba223ac271 Mon Sep 17 00:00:00 2001 From: "Kathryn (Jinqi) Chen" <65606304+Kathryn-cat@users.noreply.github.com> Date: Fri, 17 Jun 2022 11:55:39 -0700 Subject: [PATCH] [MetaSchedule] Distributed Measurement (#11683) This PR includes the distributed measurement of tuning candidates using builder and async runner, as well as some auxiliary functions. It enables multiple builders and multiple runners with a tracker connecting in between. The hierarchy of files in the database can be further compacted to make the database more concise. --- include/tvm/meta_schedule/database.h | 27 +++ python/tvm/meta_schedule/database/database.py | 34 +++ .../meta_schedule/database/memory_database.py | 3 + .../testing/dataset_sample_candidates.py | 23 +- .../testing/distributed_measure_candidates.py | 198 ++++++++++++++++++ python/tvm/meta_schedule/tune_context.py | 44 ++++ src/meta_schedule/database/database.cc | 22 +- src/meta_schedule/database/json_database.cc | 9 + src/meta_schedule/tune_context.cc | 14 +- 9 files changed, 361 insertions(+), 13 deletions(-) create mode 100644 python/tvm/meta_schedule/testing/distributed_measure_candidates.py diff --git a/include/tvm/meta_schedule/database.h b/include/tvm/meta_schedule/database.h index 37a315bf744e..b22d8beddbab 100644 --- a/include/tvm/meta_schedule/database.h +++ b/include/tvm/meta_schedule/database.h @@ -98,6 +98,9 @@ struct WorkloadEqual { } }; +/*! \brief The class of measure candidates. */ +class MeasureCandidate; + /*! \brief The class of tuning records. */ class TuningRecordNode : public runtime::Object { public: @@ -123,6 +126,9 @@ class TuningRecordNode : public runtime::Object { static constexpr const char* _type_key = "meta_schedule.TuningRecord"; TVM_DECLARE_FINAL_OBJECT_INFO(TuningRecordNode, runtime::Object); + /*! \brief Construct the measure candidate given the initial IR module and trace + * stored in the tuning record. */ + MeasureCandidate AsMeasureCandidate() const; /*! 
* \brief Export the tuning record to a JSON string. * \return An array containing the trace, running secs, serialized target, and @@ -187,6 +193,11 @@ class DatabaseNode : public runtime::Object { * \return An array of top K tuning records for the given workload. */ virtual Array GetTopK(const Workload& workload, int top_k) = 0; + /*! + * \brief Get all tuning records from the database. + * \return An Array of all the tuning records in the database. + */ + virtual Array GetAllTuningRecords() = 0; /*! * \brief Get the size of the database. * \return The size of the database. @@ -224,6 +235,11 @@ class PyDatabaseNode : public DatabaseNode { * \return An array of top K tuning records for the given workload. */ using FGetTopK = runtime::TypedPackedFunc(const Workload&, int)>; + /*! + * \brief The function type of `GetAllTuningRecords` method. + * \return An Array of all the tuning records in the database. + */ + using FGetAllTuningRecords = runtime::TypedPackedFunc()>; /*! * \brief The function type of `Size` method. * \return The size of the database. @@ -238,6 +254,8 @@ class PyDatabaseNode : public DatabaseNode { FCommitTuningRecord f_commit_tuning_record; /*! \brief The packed function to the `GetTopK` function. */ FGetTopK f_get_top_k; + /*! \brief The packed function to the `GetAllTuningRecords` function. */ + FGetAllTuningRecords f_get_all_tuning_records; /*! \brief The packed function to the `Size` function. 
*/ FSize f_size; @@ -249,6 +267,7 @@ class PyDatabaseNode : public DatabaseNode { // `f_commit_workload` is not visited // `f_commit_tuning_record` is not visited // `f_get_top_k` is not visited + // `f_get_all_tuning_records` is not visited // `f_size` is not visited } @@ -273,6 +292,12 @@ class PyDatabaseNode : public DatabaseNode { return f_get_top_k(workload, top_k); } + Array GetAllTuningRecords() final { + ICHECK(f_get_all_tuning_records != nullptr) + << "PyDatabase's GetAllTuningRecords method not implemented!"; + return f_get_all_tuning_records(); + } + int64_t Size() final { ICHECK(f_size != nullptr) << "PyDatabase's Size method not implemented!"; return f_size(); @@ -302,6 +327,7 @@ class Database : public runtime::ObjectRef { * \param f_commit_workload The packed function of `CommitWorkload`. * \param f_commit_tuning_record The packed function of `CommitTuningRecord`. * \param f_get_top_k The packed function of `GetTopK`. + * \param f_get_all_tuning_records The packed function of `GetAllTuningRecords`. * \param f_size The packed function of `Size`. * \return The created database. 
*/ @@ -309,6 +335,7 @@ class Database : public runtime::ObjectRef { PyDatabaseNode::FCommitWorkload f_commit_workload, PyDatabaseNode::FCommitTuningRecord f_commit_tuning_record, PyDatabaseNode::FGetTopK f_get_top_k, + PyDatabaseNode::FGetAllTuningRecords f_get_all_tuning_records, PyDatabaseNode::FSize f_size); TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Database, runtime::ObjectRef, DatabaseNode); }; diff --git a/python/tvm/meta_schedule/database/database.py b/python/tvm/meta_schedule/database/database.py index 802a739e6958..0c11f77591cc 100644 --- a/python/tvm/meta_schedule/database/database.py +++ b/python/tvm/meta_schedule/database/database.py @@ -115,6 +115,17 @@ def __init__( # type: ignore # pylint: disable=too-many-arguments args_info, ) + def as_measure_candidate(self) -> Any: + """Generate a measure candidate given an initial IR module and a trace + stored in the tuning record. + + Returns + ------- + candidate : MeasureCandidate + A generated candidate. + """ + return _ffi_api.TuningRecordAsMeasureCandidate(self) # type: ignore # pylint: disable=no-member + def as_json(self) -> Any: """Export the tuning record to a JSON string. @@ -203,6 +214,16 @@ def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]: """ return _ffi_api.DatabaseGetTopK(self, workload, top_k) # type: ignore # pylint: disable=no-member + def get_all_tuning_records(self) -> List[TuningRecord]: + """Get all the tuning records from the database. + + Returns + ------- + tuning_records : List[TuningRecord] + All tuning records from the database. + """ + return _ffi_api.DatabaseGetAllTuningRecords(self) # type: ignore # pylint: disable=no-member + def __len__(self) -> int: """Get the number of records in the database. 
@@ -229,6 +250,7 @@ def __init__( f_commit_workload: Callable = None, f_commit_tuning_record: Callable = None, f_get_top_k: Callable = None, + f_get_all_tuning_records: Callable = None, f_size: Callable = None, ): """Constructor.""" @@ -239,6 +261,7 @@ def __init__( f_commit_workload, f_commit_tuning_record, f_get_top_k, + f_get_all_tuning_records, f_size, ) @@ -258,6 +281,7 @@ class PyDatabase: "commit_workload", "commit_tuning_record", "get_top_k", + "get_all_tuning_records", "__len__", ], } @@ -317,6 +341,16 @@ def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]: """ raise NotImplementedError + def get_all_tuning_records(self) -> List[TuningRecord]: + """Get all the tuning records from the database. + + Returns + ------- + tuning_records : List[TuningRecord] + All tuning records from the database. + """ + raise NotImplementedError + def __len__(self) -> int: """Get the number of records in the database. diff --git a/python/tvm/meta_schedule/database/memory_database.py b/python/tvm/meta_schedule/database/memory_database.py index 6d10e4b5272a..95d937cc77aa 100644 --- a/python/tvm/meta_schedule/database/memory_database.py +++ b/python/tvm/meta_schedule/database/memory_database.py @@ -56,6 +56,9 @@ def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]: ) )[: int(top_k)] + def get_all_tuning_records(self) -> List[TuningRecord]: + return self.records + def __len__(self) -> int: return len(self.records) diff --git a/python/tvm/meta_schedule/testing/dataset_sample_candidates.py b/python/tvm/meta_schedule/testing/dataset_sample_candidates.py index c80d78173e2e..35b872e7351e 100644 --- a/python/tvm/meta_schedule/testing/dataset_sample_candidates.py +++ b/python/tvm/meta_schedule/testing/dataset_sample_candidates.py @@ -103,6 +103,14 @@ def sample_candidates(task, task_name, model_name): ------- None """ + candidate_path = os.path.join( + args.candidate_cache_dir, model_name, task_name + "_candidates.json" + ) + workload_path = 
os.path.join(args.candidate_cache_dir, model_name, task_name + "_workload.json") + database = ms.database.JSONDatabase( + path_workload=workload_path, + path_tuning_record=candidate_path, + ) sample_init_population = tvm.get_global_func( "meta_schedule.SearchStrategyEvolutionarySearchSampleInitPopulation" ) @@ -128,7 +136,7 @@ def sample_candidates(task, task_name, model_name): context.initialize() context.pre_tuning( context.generate_design_space(), - database=ms.database.MemoryDatabase(), # type: ignore + database=database, cost_model=ms.cost_model.RandomModel(), # type: ignore ) @@ -148,16 +156,9 @@ def sample_candidates(task, task_name, model_name): all_states = all_states[: args.num_samples_per_task] workload = ms.database.Workload(context.mod) - file_path = os.path.join(args.candidate_cache_dir, model_name, task_name + ".json") - with open(file_path, "w", encoding="utf8") as file: - for i, state in enumerate(all_states): - tuning_record = ms.database.TuningRecord(state.trace, workload) - json_str = json.dumps(tuning_record.as_json()) - assert "\n" not in json_str, "Failed to generate single line string." - if i == len(all_states) - 1: - file.write(json_str) - else: - file.write(json_str + "\n") + database.commit_workload(context.mod) + for state in all_states: + database.commit_tuning_record(ms.database.TuningRecord(state.trace, workload)) args = _parse_args() # pylint: disable=invalid-name diff --git a/python/tvm/meta_schedule/testing/distributed_measure_candidates.py b/python/tvm/meta_schedule/testing/distributed_measure_candidates.py new file mode 100644 index 000000000000..8e646c484672 --- /dev/null +++ b/python/tvm/meta_schedule/testing/distributed_measure_candidates.py @@ -0,0 +1,198 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+
+import argparse
+import glob
+import os
+
+from tqdm import tqdm  # type: ignore
+from tvm import meta_schedule as ms
+from tvm.target import Target
+
+
+def _parse_args():
+    """Parse the command line options of the distributed measurement script.
+
+    Returns
+    -------
+    args : argparse.Namespace
+        The parsed command line options.
+    """
+
+    def _str2bool(text):
+        # `type=bool` would turn every non-empty string, "False" included, into True,
+        # so the value of the flag is parsed explicitly instead.
+        lowered = text.strip().lower()
+        if lowered in ("1", "true", "t", "yes", "y", "on"):
+            return True
+        if lowered in ("", "0", "false", "f", "no", "n", "off"):
+            return False
+        raise argparse.ArgumentTypeError(f"Expected a boolean value, but got {text!r}.")
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--candidate_cache_dir", type=str, help="Please provide the full path to the candidates."
+    )
+    parser.add_argument(
+        "--result_cache_dir", type=str, help="Please provide the full path to the result database."
+    )
+    parser.add_argument(
+        "--target",
+        type=str,
+        default="nvidia/nvidia-v100",
+        help="Please specify the target hardware for tuning context.",
+    )
+    parser.add_argument(
+        "--rpc_host", type=str, help="Please provide the private IPv4 address for the tracker."
+    )
+    parser.add_argument(
+        "--rpc_port", type=int, default=4445, help="Please provide the port for the tracker."
+    )
+    parser.add_argument(
+        "--rpc_key",
+        type=str,
+        default="p3.2xlarge",
+        help="Please provide the key for the rpc servers.",
+    )
+    parser.add_argument(
+        "--builder_timeout_sec",
+        type=int,
+        default=10,
+        help="The time for the builder session to time out.",
+    )
+    parser.add_argument(
+        "--min_repeat_ms", type=int, default=100, help="The time for preheating the gpu."
+    )
+    parser.add_argument(
+        "--runner_timeout_sec",
+        type=int,
+        default=100,
+        help="The time for the runner session to time out.",
+    )
+    parser.add_argument(
+        "--cpu_flush",
+        type=_str2bool,
+        default=False,
+        help="Whether to enable cpu cache flush or not.",
+    )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=128,
+        help="The batch size of candidates sent to builder and runner each time.",
+    )
+    return parser.parse_args()
+
+
+# pylint: disable=too-many-locals
+def measure_candidates(database, builder, runner):
+    """Send the candidates to builder and runner for distributed measurement,
+    and save the results in a new json database.
+
+    The target, the batch size and the result directory are read from the
+    module-level `args`. The indices of the candidates whose runner result
+    carries an error message are written, separated by spaces, to a
+    `*_failed_indices.txt` file next to the new database.
+
+    Parameters
+    ----------
+    database : JSONDatabase
+        The database for candidates to be measured.
+    builder : Builder
+        The builder for building the candidates.
+    runner : Runner
+        The runner for measuring the candidates.
+
+    Returns
+    -------
+    None
+    """
+    tuning_records = database.get_all_tuning_records()
+    if not tuning_records:
+        # Nothing to measure; the workload is taken from the first record further below.
+        print(f"No tuning records found in {database.path_tuning_record}, skipping.")
+        return
+    runner_results, build_fail_indices, run_fail_indices = [], [], []
+    # The target does not change between records, so it is constructed only once.
+    target = Target(args.target)
+    context = ms.TuneContext(target=target)
+    candidates = [record.as_measure_candidate() for record in tuning_records]
+    with ms.Profiler() as profiler:
+        for idx in range(0, len(candidates), args.batch_size):
+            batch_candidates = candidates[idx : idx + args.batch_size]
+            context._set_measure_candidates(batch_candidates)  # pylint: disable=protected-access
+            with ms.Profiler.timeit("build"):
+                context._send_to_builder(builder)  # pylint: disable=protected-access
+            with ms.Profiler.timeit("run"):
+                context._send_to_runner(runner)  # pylint: disable=protected-access
+                batch_runner_results = context._join()  # pylint: disable=protected-access
+            runner_results.extend(batch_runner_results)
+            for i, result in enumerate(context.builder_results):
+                if result.error_msg is None:
+                    ms.utils.remove_build_dir(result.artifact_path)
+                else:
+                    # `i` is local to the batch, `idx` is the offset of the batch.
+                    build_fail_indices.append(i + idx)
+            context._clear_measure_state()  # pylint: disable=protected-access
+
+    # Mirror the `<model_name>/<file_name>` layout of the candidate database.
+    workload_dir, workload_name = os.path.split(database.path_workload)
+    model_name = os.path.basename(workload_dir)
+    record_name = os.path.basename(database.path_tuning_record)
+    new_database = ms.database.JSONDatabase(
+        path_workload=os.path.join(args.result_cache_dir, model_name, workload_name),
+        path_tuning_record=os.path.join(args.result_cache_dir, model_name, record_name),
+    )
+    workload = tuning_records[0].workload
+    new_database.commit_workload(workload.mod)
+    for i, (record, result) in enumerate(zip(tuning_records, runner_results)):
+        if result.error_msg is None:
+            new_database.commit_tuning_record(
+                ms.database.TuningRecord(
+                    trace=record.trace,
+                    workload=workload,
+                    run_secs=[v.value for v in result.run_secs],
+                    target=target,
+                )
+            )
+        else:
+            run_fail_indices.append(i)
+    fail_indices_name = workload_name.replace("_workload.json", "_failed_indices.txt")
+    with open(
+        os.path.join(args.result_cache_dir, model_name, fail_indices_name), "w", encoding="utf8"
+    ) as file:
+        file.write(" ".join([str(n) for n in run_fail_indices]))
+    # Adjacent literals are used instead of backslash continuations, which would
+    # embed the indentation of the source lines in the printed message.
+    print(
+        f"Builder time: {profiler.get()['build']}, Runner time: {profiler.get()['run']}\n"
+        f"Failed number of builds: {len(build_fail_indices)}, "
+        f"Failed number of runs: {len(run_fail_indices)}"
+    )
+
+
+args = _parse_args()  # pylint: disable=invalid-name
+
+
+def main():
+    builder = ms.builder.LocalBuilder(timeout_sec=args.builder_timeout_sec)
+    runner = ms.runner.RPCRunner(
+        rpc_config=ms.runner.RPCConfig(
+            tracker_host=args.rpc_host,
+            tracker_port=args.rpc_port,
+            tracker_key=args.rpc_key,
+            session_timeout_sec=args.runner_timeout_sec,
+        ),
+        evaluator_config=ms.runner.EvaluatorConfig(
+            number=3,
+            repeat=1,
+            min_repeat_ms=args.min_repeat_ms,
+            enable_cpu_cache_flush=args.cpu_flush,
+        ),
+        max_workers=os.cpu_count(),
+    )
+    if not os.path.isdir(args.candidate_cache_dir):
+        raise Exception("Please provide a correct candidate cache dir.")
+    try:
+        os.makedirs(args.result_cache_dir, exist_ok=True)
+    except OSError:
+        
print(f"Directory {args.result_cache_dir} cannot be created successfully.") + model_dirs = glob.glob(os.path.join(args.candidate_cache_dir, "*")) + for model_dir in model_dirs: + model_name = model_dir.split("/")[-1] + os.makedirs(os.path.join(args.result_cache_dir, model_name), exist_ok=True) + all_tasks = glob.glob(os.path.join(model_dir, "*.json")) + workload_paths = [] + for path in all_tasks: + if path.endswith("_workload.json"): + workload_paths.append(path) + for workload_path in tqdm(workload_paths): + candidate_path = workload_path.replace("_workload.json", "_candidates.json") + database = ms.database.JSONDatabase( + path_workload=workload_path, + path_tuning_record=candidate_path, + ) + measure_candidates(database, builder, runner) + + +if __name__ == "__main__": + main() diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py index b7975e7b2c4e..30c726ded25b 100644 --- a/python/tvm/meta_schedule/tune_context.py +++ b/python/tvm/meta_schedule/tune_context.py @@ -171,6 +171,50 @@ def __init__( ) _ffi_api.TuneContextInitialize(self) # type: ignore # pylint: disable=no-member + def _set_measure_candidates(self, candidates): + """Set candidates in a tuning context. + + Parameters + ---------- + candidates : List[MeasureCandidate] + A list of measure candidates for the tuning context. + """ + _ffi_api.TuneContextSetMeasureCandidates(self, candidates) # type: ignore # pylint: disable=no-member + + def _send_to_builder(self, builder): + """Send candidates to builder. + + Parameters + ---------- + builder : Builder + The builder for building the candidates. + """ + _ffi_api.TuneContextSendToBuilder(self, builder) # type: ignore # pylint: disable=no-member + + def _send_to_runner(self, runner): + """Send candidates to runner. + + Parameters + ---------- + runner : Runner + The runner for running the candidates. 
+ """ + _ffi_api.TuneContextSendToRunner(self, runner) # type: ignore # pylint: disable=no-member + + def _join(self): + """Join the runner processes. + + Returns + ------- + result : List[RunnerResult] + The runner results. + """ + return _ffi_api.TuneContextJoin(self) # type: ignore # pylint: disable=no-member + + def _clear_measure_state(self): + """Clear the measure states.""" + _ffi_api.TuneContextClearMeasureState(self) # type: ignore # pylint: disable=no-member + def generate_design_space(self) -> List[Schedule]: """Generate design spaces given a module. diff --git a/src/meta_schedule/database/database.cc b/src/meta_schedule/database/database.cc index 9905ff73c792..5adff4998494 100644 --- a/src/meta_schedule/database/database.cc +++ b/src/meta_schedule/database/database.cc @@ -85,6 +85,19 @@ TuningRecord::TuningRecord(tir::Trace trace, Workload workload, Optionaldata_ = n; } +MeasureCandidate TuningRecordNode::AsMeasureCandidate() const { + tir::Schedule sch = + tir::Schedule::Traced(workload->mod, -1, 0, tir::ScheduleErrorRenderLevel::kDetail); + trace->ApplyToSchedule(sch, false, nullptr); + tir::PrimFunc func; + for (const auto& kv : sch->mod()->functions) { + func = Downcast(kv.second); + } + Array args_info = ArgInfo::FromPrimFunc(func); + MeasureCandidate candidate = MeasureCandidate(sch, args_info); + return candidate; +} + ObjectRef TuningRecordNode::AsJSON() const { Optional> json_args_info{nullptr}; Optional json_target{nullptr}; @@ -152,12 +165,15 @@ TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload& w Database Database::PyDatabase(PyDatabaseNode::FHasWorkload f_has_workload, PyDatabaseNode::FCommitWorkload f_commit_workload, PyDatabaseNode::FCommitTuningRecord f_commit_tuning_record, - PyDatabaseNode::FGetTopK f_get_top_k, PyDatabaseNode::FSize f_size) { + PyDatabaseNode::FGetTopK f_get_top_k, + PyDatabaseNode::FGetAllTuningRecords f_get_all_tuning_records, + PyDatabaseNode::FSize f_size) { ObjectPtr n = make_object(); 
n->f_has_workload = f_has_workload; n->f_commit_workload = f_commit_workload; n->f_commit_tuning_record = f_commit_tuning_record; n->f_get_top_k = f_get_top_k; + n->f_get_all_tuning_records = f_get_all_tuning_records; n->f_size = f_size; return Database(n); } @@ -179,6 +195,8 @@ TVM_REGISTER_GLOBAL("meta_schedule.TuningRecord") Optional target, Optional> args_info) { return TuningRecord(trace, workload, run_secs, target, args_info); }); +TVM_REGISTER_GLOBAL("meta_schedule.TuningRecordAsMeasureCandidate") + .set_body_method(&TuningRecordNode::AsMeasureCandidate); TVM_REGISTER_GLOBAL("meta_schedule.TuningRecordAsJSON") .set_body_method(&TuningRecordNode::AsJSON); TVM_REGISTER_GLOBAL("meta_schedule.TuningRecordFromJSON").set_body_typed(TuningRecord::FromJSON); @@ -190,6 +208,8 @@ TVM_REGISTER_GLOBAL("meta_schedule.DatabaseCommitTuningRecord") .set_body_method(&DatabaseNode::CommitTuningRecord); TVM_REGISTER_GLOBAL("meta_schedule.DatabaseGetTopK") .set_body_method(&DatabaseNode::GetTopK); +TVM_REGISTER_GLOBAL("meta_schedule.DatabaseGetAllTuningRecords") + .set_body_method(&DatabaseNode::GetAllTuningRecords); TVM_REGISTER_GLOBAL("meta_schedule.DatabaseSize").set_body_method(&DatabaseNode::Size); TVM_REGISTER_GLOBAL("meta_schedule.DatabasePyDatabase").set_body_typed(Database::PyDatabase); diff --git a/src/meta_schedule/database/json_database.cc b/src/meta_schedule/database/json_database.cc index 4f5bd9b13613..9bb7ee1027b9 100644 --- a/src/meta_schedule/database/json_database.cc +++ b/src/meta_schedule/database/json_database.cc @@ -156,6 +156,15 @@ class JSONDatabaseNode : public DatabaseNode { return results; } + Array GetAllTuningRecords() { + Array results; + results.reserve(Size()); + for (const TuningRecord& record : this->tuning_records_) { + results.push_back(record); + } + return results; + } + int64_t Size() { return tuning_records_.size(); } }; diff --git a/src/meta_schedule/tune_context.cc b/src/meta_schedule/tune_context.cc index 0c70dcf5c406..57b2344c6f8d 
100644 --- a/src/meta_schedule/tune_context.cc +++ b/src/meta_schedule/tune_context.cc @@ -142,7 +142,9 @@ Array TuneContextNode::_Join() { results.push_back(future->Result()); } } - this->search_strategy.value()->NotifyRunnerResults(this->measure_candidates.value(), results); + if (this->search_strategy.defined()) { + this->search_strategy.value()->NotifyRunnerResults(this->measure_candidates.value(), results); + } ICHECK(this->measure_candidates.defined()); ICHECK(this->builder_results.defined()); ICHECK_EQ(results.size(), this->measure_candidates.value().size()); @@ -177,6 +179,16 @@ TVM_REGISTER_GLOBAL("meta_schedule.TuneContext") TVM_REGISTER_GLOBAL("meta_schedule._SHash2Hex").set_body_typed(SHash2Hex); TVM_REGISTER_GLOBAL("meta_schedule.TuneContextInitialize") .set_body_method(&TuneContextNode::Initialize); +TVM_REGISTER_GLOBAL("meta_schedule.TuneContextSetMeasureCandidates") + .set_body_method(&TuneContextNode::_SetMeasureCandidates); +TVM_REGISTER_GLOBAL("meta_schedule.TuneContextSendToBuilder") + .set_body_method(&TuneContextNode::_SendToBuilder); +TVM_REGISTER_GLOBAL("meta_schedule.TuneContextSendToRunner") + .set_body_method(&TuneContextNode::_SendToRunner); +TVM_REGISTER_GLOBAL("meta_schedule.TuneContextJoin") + .set_body_method(&TuneContextNode::_Join); +TVM_REGISTER_GLOBAL("meta_schedule.TuneContextClearMeasureState") + .set_body_method(&TuneContextNode::_ClearMeasureState); } // namespace meta_schedule } // namespace tvm