diff --git a/include/tvm/meta_schedule/database.h b/include/tvm/meta_schedule/database.h new file mode 100644 index 000000000000..7ba3c207e349 --- /dev/null +++ b/include/tvm/meta_schedule/database.h @@ -0,0 +1,279 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_META_SCHEDULE_DATABASE_H_ +#define TVM_META_SCHEDULE_DATABASE_H_ + +#include +#include +#include + +namespace tvm { +namespace meta_schedule { + +/*! \brief A workload, i.e. an IRModule and its structural hash. */ +class WorkloadNode : public runtime::Object { + public: + /*! \brief The type of structural hash */ + using THashCode = size_t; + /*! \brief The workload's IRModule. */ + IRModule mod; + /*! \brief The workload's structural hash. */ + THashCode shash; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("mod", &mod); + // `shash` is not visited because TVM FFI doesn't support uint64_t + } + + static constexpr const char* _type_key = "meta_schedule.Workload"; + TVM_DECLARE_FINAL_OBJECT_INFO(WorkloadNode, runtime::Object); + + /*! + * \brief Export the workload to a JSON string. + * \return An array containing the structural hash and the base64 json string. + */ + ObjectRef AsJSON() const; +}; + +/*! 
+ * \brief Managed reference to WorkloadNode. + * \sa WorkloadNode + */ +class Workload : public runtime::ObjectRef { + public: + using THashCode = WorkloadNode::THashCode; + /*! + * \brief Constructor of Workload. + * \param mod The workload's IRModule. + */ + TVM_DLL explicit Workload(IRModule mod); + /*! + * \brief Constructor of Workload. + * \param mod The workload's IRModule. + * \param shash The workload's structural hash. + */ + TVM_DLL explicit Workload(IRModule mod, THashCode shash); + /*! + * \brief Create a workload from a json object. + * \param json_obj The json object. + * \return The created workload. + */ + TVM_DLL static Workload FromJSON(const ObjectRef& json_obj); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Workload, runtime::ObjectRef, WorkloadNode); +}; + +/*! \brief The hash method for Workload */ +struct WorkloadHash { + size_t operator()(const Workload& a) const { return a->shash; } +}; + +/*! \brief The equality check for Workload */ +struct WorkloadEqual { + bool operator()(const Workload& a, const Workload& b) const { + return a->shash == b->shash && tvm::StructuralEqual()(a->mod, b->mod); + } +}; + +/*! \brief The class of tuning records. */ +class TuningRecordNode : public runtime::Object { + public: + /*! \brief The trace tuned. */ + tir::Trace trace; + /*! \brief The profiling result in seconds. */ + Array run_secs; + /*! \brief The workload. */ + Workload workload{nullptr}; + /*! \brief The target for tuning. */ + Target target; + /*! \brief The argument information. */ + Array args_info; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("trace", &trace); + v->Visit("run_secs", &run_secs); + v->Visit("workload", &workload); + v->Visit("target", &target); + v->Visit("args_info", &args_info); + } + + static constexpr const char* _type_key = "meta_schedule.TuningRecord"; + TVM_DECLARE_FINAL_OBJECT_INFO(TuningRecordNode, runtime::Object); + + /*! + * \brief Export the tuning record to a JSON string. 
+ * \return An array containing the trace, running secs, serialized target, and + * argument information. + */ + ObjectRef AsJSON() const; +}; + +/*! + * \brief The managed reference of TuningRecordNode. + * \sa TuningRecordNode + */ +class TuningRecord : public runtime::ObjectRef { + public: + /*! + \brief Constructor of a tuning record. + \param trace The trace of the tuning record. + \param run_secs The running time of the tuning record. + \param workload The workload of the tuning record. + \param target The target of the tuning record. + \param args_info The argument information of the tuning record. + */ + TVM_DLL explicit TuningRecord(tir::Trace trace, Array run_secs, Workload workload, + Target target, Array args_info); + /*! + * \brief Create a tuning record from a json object. + * \param json_obj The json object. + * \param workload The workload. + * \return The tuning record created. + */ + TVM_DLL static TuningRecord FromJSON(const ObjectRef& json_obj, const Workload& workload); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TuningRecord, runtime::ObjectRef, TuningRecordNode); +}; + +/* \brief The abstract interface of database. */ +class DatabaseNode : public runtime::Object { + public: + /*! \brief Default destructor */ + virtual ~DatabaseNode() = default; + /*! + * \brief Look up or add workload to the database if missing. + * \param mod The IRModule to be searched for or added. + * \return The workload corresponding to the given IRModule. + */ + virtual Workload CommitWorkload(const IRModule& mod) = 0; + /*! + * \brief Add a tuning record to the database. + * \param record The tuning record to be added. + */ + virtual void CommitTuningRecord(const TuningRecord& record) = 0; + /*! + * \brief Get the top K tuning records of given workload from the database. + * \param workload The workload to be searched for. + * \param top_k The number of top records to be returned. + * \return An array of top K tuning records for the given workload. 
+ */ + virtual Array GetTopK(const Workload& workload, int top_k) = 0; + /*! + * \brief Get the size of the database. + * \return The size of the database. + */ + virtual int64_t Size() = 0; + + static constexpr const char* _type_key = "meta_schedule.Database"; + TVM_DECLARE_BASE_OBJECT_INFO(DatabaseNode, runtime::Object); +}; + +/*! \brief The database with customized methods on the python-side. */ +class PyDatabaseNode : public DatabaseNode { + public: + /*! + * \brief The function type of `CommitWorkload` method. + * \param mod The IRModule to be searched for or added. + * \return The workload corresponding to the given IRModule. + */ + using FCommitWorkload = runtime::TypedPackedFunc; + /*! + * \brief The function type of `CommitTuningRecord` method. + * \param record The tuning record to be added. + */ + using FCommitTuningRecord = runtime::TypedPackedFunc; + /*! + * \brief The function type of `GetTopK` method. + * \param workload The workload to be searched for. + * \param top_k The number of top records to be returned. + * \return An array of top K tuning records for the given workload. + */ + using FGetTopK = runtime::TypedPackedFunc(const Workload&, int)>; + /*! + * \brief The function type of `Size` method. + * \return The size of the database. + */ + using FSize = runtime::TypedPackedFunc; + + /*! \brief The packed function to the `CommitWorkload` function. */ + FCommitWorkload f_commit_workload; + /*! \brief The packed function to the `CommitTuningRecord` function. */ + FCommitTuningRecord f_commit_tuning_record; + /*! \brief The packed function to the `GetTopK` function. */ + FGetTopK f_get_top_k; + /*! \brief The packed function to the `Size` function. */ + FSize f_size; + + void VisitAttrs(tvm::AttrVisitor* v) { + // PackedFuncs are all not visited, because the reflection system doesn't take care of them, + // so it cannot be accessible on the python side. 
If there is such a need in the future,
+ */ + TVM_DLL static Database PyDatabase(PyDatabaseNode::FCommitWorkload f_commit_workload, + PyDatabaseNode::FCommitTuningRecord f_commit_tuning_record, + PyDatabaseNode::FGetTopK f_get_top_k, + PyDatabaseNode::FSize f_size); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Database, runtime::ObjectRef, DatabaseNode); +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_DATABASE_H_ diff --git a/include/tvm/runtime/container/string.h b/include/tvm/runtime/container/string.h index 664d19818be1..bb9e7ff65adc 100644 --- a/include/tvm/runtime/container/string.h +++ b/include/tvm/runtime/container/string.h @@ -31,6 +31,7 @@ #include #include +#include #include #include #include @@ -149,6 +150,12 @@ class String : public ObjectRef { String(const char* other) // NOLINT(*) : String(std::string(other)) {} + /*! + * \brief Construct a new null object + */ + String(std::nullptr_t) // NOLINT(*) + : ObjectRef(nullptr) {} + /*! * \brief Change the value the reference object points to. * diff --git a/python/tvm/meta_schedule/__init__.py b/python/tvm/meta_schedule/__init__.py index c07b28b4fc9f..f8b2b026c83b 100644 --- a/python/tvm/meta_schedule/__init__.py +++ b/python/tvm/meta_schedule/__init__.py @@ -15,7 +15,9 @@ # specific language governing permissions and limitations # under the License. """Package `tvm.meta_schedule`. The meta schedule infrastructure.""" -from . import builder from . import arg_info +from . import builder +from . import database from . import space_generator +from .database import TuningRecord from .tune_context import TuneContext diff --git a/python/tvm/meta_schedule/database/__init__.py b/python/tvm/meta_schedule/database/__init__.py new file mode 100644 index 000000000000..dcd430d39407 --- /dev/null +++ b/python/tvm/meta_schedule/database/__init__.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +The tvm.meta_schedule.database package. +The database that stores serialized tuning records and workloads +""" +from .database import Database, PyDatabase, TuningRecord +from .json_database import JSONDatabase diff --git a/python/tvm/meta_schedule/database/database.py b/python/tvm/meta_schedule/database/database.py new file mode 100644 index 000000000000..3d05441fe22b --- /dev/null +++ b/python/tvm/meta_schedule/database/database.py @@ -0,0 +1,240 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Tuning record database""" +from typing import Any, List + +from tvm._ffi import register_object +from tvm.ir.module import IRModule +from tvm.runtime import Object +from tvm.target import Target +from tvm.tir.schedule import Trace + +from .. import _ffi_api +from ..arg_info import ArgInfo +from ..utils import _json_de_tvm + + +@register_object("meta_schedule.Workload") +class Workload(Object): + """A workload, i.e. an IRModule and its structural hash. + + Parameters + ---------- + mod : IRModule + The workload's IRModule + """ + + mod: IRModule + + def __init__(self, mod: IRModule) -> None: + self.__init_handle_by_constructor__( + _ffi_api.Workload, # type: ignore # pylint: disable=no-member + mod, + ) + + def as_json(self) -> Any: + """Export the workload to a JSON string. + + Returns + ------- + json_str : str + The JSON string exported. + """ + return _json_de_tvm(_ffi_api.WorkloadAsJSON(self)) # type: ignore # pylint: disable=no-member + + @staticmethod + def from_json(json_obj: Any) -> "Workload": + """Create a workload from a json object. + + Parameters + ---------- + json_obj : Any + The json object to parse. + + Returns + ------- + tuning_record : TuningRecord + The parsed tuning record. + """ + return _ffi_api.WorkloadFromJSON(json_obj) # type: ignore # pylint: disable=no-member + + +@register_object("meta_schedule.TuningRecord") +class TuningRecord(Object): + """The class of tuning records. + + Parameters + ---------- + trace : tvm.ir.Trace + The trace of the tuning record. + run_secs : List[float] + The run time of the tuning record. + workload : Workload + The workload of the tuning record. + target : Target + The target of the tuning record. + args_info : List[ArgInfo] + The argument information of the tuning record. 
+ """ + + trace: Trace + run_secs: List[float] + workload: Workload + target: Target + args_info: List[ArgInfo] + + def __init__( + self, + trace: Trace, + run_secs: List[float], + workload: Workload, + target: Target, + args_info: List[ArgInfo], + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.TuningRecord, # type: ignore # pylint: disable=no-member + trace, + run_secs, + workload, + target, + args_info, + ) + + def as_json(self) -> Any: + """Export the tuning record to a JSON string. + + Returns + ------- + json_str : str + The JSON string exported. + """ + return _json_de_tvm(_ffi_api.TuningRecordAsJSON(self)) # type: ignore # pylint: disable=no-member + + @staticmethod + def from_json(json_obj: Any, workload: Workload) -> "TuningRecord": + """Create a tuning record from a json object. + + Parameters + ---------- + json_obj : Any + The json object to parse. + workload : Workload + The workload. + + Returns + ------- + tuning_record : TuningRecord + The parsed tuning record. + """ + return _ffi_api.TuningRecordFromJSON(json_obj, workload) # type: ignore # pylint: disable=no-member + + +@register_object("meta_schedule.Database") +class Database(Object): + """The abstract database interface.""" + + def commit_workload(self, mod: IRModule) -> Workload: + """Commit a workload to the database if missing. + + Parameters + ---------- + mod : IRModule + The IRModule to be searched for or added. + + Returns + ------- + workload : Workload + The workload corresponding to the given IRModule. + """ + return _ffi_api.DatabaseCommitWorkload(self, mod) # type: ignore # pylint: disable=no-member + + def commit_tuning_record(self, record: TuningRecord) -> None: + """Commit a tuning record to the database. + + Parameters + ---------- + record : TuningRecord + The tuning record to add. 
+ """ + _ffi_api.DatabaseCommitTuningRecord(self, record) # type: ignore # pylint: disable=no-member + + def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]: + """Get the top K tuning records of given workload from the database. + + Parameters + ---------- + workload : Workload + The workload to be searched for. + top_k : int + The number of top records to get. + + Returns + ------- + top_k_records : List[TuningRecord] + The top K records. + """ + return _ffi_api.DatabaseGetTopK(self, workload, top_k) # type: ignore # pylint: disable=no-member + + def __len__(self) -> int: + """Get the number of records in the database. + + Returns + ------- + num_records : int + The number of records in the database + """ + return _ffi_api.DatabaseSize(self) # type: ignore # pylint: disable=no-member + + +@register_object("meta_schedule.PyDatabase") +class PyDatabase(Database): + """An abstract Database with customized methods on the python-side.""" + + def __init__(self): + """Constructor.""" + + def f_commit_workload(mod: IRModule) -> Workload: + return self.commit_workload(mod) + + def f_commit_tuning_record(record: TuningRecord) -> None: + self.commit_tuning_record(record) + + def f_get_top_k(workload: Workload, top_k: int) -> List[TuningRecord]: + return self.get_top_k(workload, top_k) + + def f_size() -> int: + return len(self) + + self.__init_handle_by_constructor__( + _ffi_api.DatabasePyDatabase, # type: ignore # pylint: disable=no-member + f_commit_workload, + f_commit_tuning_record, + f_get_top_k, + f_size, + ) + + def commit_workload(self, mod: IRModule) -> Workload: + raise NotImplementedError + + def commit_tuning_record(self, record: TuningRecord) -> None: + raise NotImplementedError + + def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]: + raise NotImplementedError + + def __len__(self) -> int: + raise NotImplementedError diff --git a/python/tvm/meta_schedule/database/json_database.py 
b/python/tvm/meta_schedule/database/json_database.py new file mode 100644 index 000000000000..6897b82d9888 --- /dev/null +++ b/python/tvm/meta_schedule/database/json_database.py @@ -0,0 +1,61 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""The default database that uses a JSON File to store tuning records""" +from tvm._ffi import register_object + +from .. import _ffi_api +from .database import Database + + +@register_object("meta_schedule.JSONDatabase") +class JSONDatabase(Database): + """The class of tuning records. + + Parameters + ---------- + path_workload : str + The path to the workload table. + path_tuning_record : str + The path to the tuning record table. + """ + + path_workload: str + path_tuning_record: str + + def __init__( + self, + path_workload: str, + path_tuning_record: str, + allow_missing: bool = True, + ) -> None: + """Constructor. + + Parameters + ---------- + path_workload : str + The path to the workload table. + path_tuning_record : str + The path to the tuning record table. + allow_missing : bool + Whether to create new file when the given path is not found. 
+ """ + self.__init_handle_by_constructor__( + _ffi_api.DatabaseJSONDatabase, # type: ignore # pylint: disable=no-member + path_workload, + path_tuning_record, + allow_missing, + ) diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py index abde198cf6ec..e710b0ed06f3 100644 --- a/python/tvm/meta_schedule/utils.py +++ b/python/tvm/meta_schedule/utils.py @@ -15,9 +15,10 @@ # specific language governing permissions and limitations # under the License. """Utilities for meta schedule""" +import json import os import shutil -from typing import Any, Callable, Union +from typing import Any, Callable, List, Union import psutil @@ -126,3 +127,28 @@ def _json_de_tvm(obj: Any) -> Any: if isinstance(obj, Map): return {_json_de_tvm(k): _json_de_tvm(v) for k, v in obj.items()} raise TypeError("Not supported type: " + str(type(obj))) + + +@register_func("meta_schedule.json_obj2str") +def json_obj2str(json_obj: Any) -> str: + json_obj = _json_de_tvm(json_obj) + return json.dumps(json_obj) + + +@register_func("meta_schedule.batch_json_str2obj") +def batch_json_str2obj(json_strs: List[str]) -> List[Any]: + """Covert a list of JSON strings to a list of json objects. + Parameters + ---------- + json_strs : List[str] + The list of JSON strings + Returns + ------- + result : List[Any] + The list of json objects + """ + return [ + json.loads(json_str) + for json_str in map(str.strip, json_strs) + if json_str and (not json_str.startswith("#")) and (not json_str.startswith("//")) + ] diff --git a/src/meta_schedule/database/database.cc b/src/meta_schedule/database/database.cc new file mode 100644 index 000000000000..e67b3d1ab9b6 --- /dev/null +++ b/src/meta_schedule/database/database.cc @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +/******** Workload ********/ + +Workload::Workload(IRModule mod) { + ObjectPtr n = runtime::make_object(); + n->shash = tvm::StructuralHash()(mod); + n->mod = mod; + data_ = std::move(n); +} + +Workload::Workload(IRModule mod, Workload::THashCode shash) { + ObjectPtr n = runtime::make_object(); + n->mod = mod; + n->shash = shash; + data_ = std::move(n); +} + +ObjectRef WorkloadNode::AsJSON() const { + // Convert `this->mod` to JSON + std::string json_mod = tvm::SaveJSON(this->mod); + // Dump the JSON string to base64 + std::string b64_mod = Base64Encode(json_mod); + // Output + return Array{SHash2Str(this->shash), String(b64_mod)}; +} + +Workload Workload::FromJSON(const ObjectRef& json_obj) { + IRModule mod{nullptr}; + THashCode shash = 0; + try { + const ArrayNode* json_array = json_obj.as(); + CHECK(json_array && json_array->size() == 2); + // Load json[0] => shash + String str_shash = Downcast(json_array->at(0)); + // Load json[1] => mod + { + String b64_mod = Downcast(json_array->at(1)); + std::string json_mod = Base64Decode(b64_mod); + mod = Downcast(LoadJSON(json_mod)); + } + // Verify SHash(mod) == shash + shash = tvm::StructuralHash()(mod); + String recalc_shash = SHash2Str(shash); + CHECK_EQ(recalc_shash, str_shash) << "ValueError: Structural hash changed. 
Given: " << str_shash + << "; Recalculated: " << recalc_shash; + } catch (const std::runtime_error& e) { // includes tvm::Error and dmlc::Error + LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj + << "\nThe error is: " << e.what(); + } + return Workload(mod, shash); +} + +/******** TuningRecord ********/ + +TuningRecord::TuningRecord(tir::Trace trace, Array run_secs, Workload workload, + Target target, Array args_info) { + ObjectPtr n = make_object(); + n->trace = trace; + n->run_secs = run_secs; + n->workload = workload; + n->target = target; + n->args_info = args_info; + this->data_ = n; +} + +ObjectRef TuningRecordNode::AsJSON() const { + Array json_args_info; + json_args_info.reserve(args_info.size()); + for (const ArgInfo& arg_info : args_info) { + json_args_info.push_back(arg_info->AsJSON()); + } + return Array{trace->AsJSON(false), // + run_secs, // + target->Export(), // + json_args_info}; +} + +TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload& workload) { + tir::Trace trace{nullptr}; + Array run_secs{nullptr}; + Target target{nullptr}; + Array args_info; + try { + const ArrayNode* json_array = json_obj.as(); + CHECK(json_array && json_array->size() == 4); + // Load json[1] => run_secs + run_secs = Downcast>(json_array->at(1)); + // Load json[2] => target + target = Target(Downcast>(json_array->at(2))); + // Load json[3] => args_info + { + const ArrayNode* json_args_info = json_array->at(3).as(); + args_info.reserve(json_args_info->size()); + for (const ObjectRef& json_arg_info : *json_args_info) { + args_info.push_back(ArgInfo::FromJSON(json_arg_info)); + } + } + // Load json[0] => trace + { + const ObjectRef& json_trace = json_array->at(0); + tir::Schedule sch = + tir::Schedule::Traced(workload->mod, /*seed=*/-1, /*debug_mask=*/0, + /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone); + tir::Trace::ApplyJSONToSchedule(json_trace, sch); + trace = sch->trace().value(); + } + } catch (const 
std::runtime_error& e) { // includes tvm::Error and dmlc::Error + LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj + << "\nThe error is: " << e.what(); + } + return TuningRecord(trace, run_secs, workload, target, args_info); +} + +/******** PyDatabase ********/ + +Database Database::PyDatabase(PyDatabaseNode::FCommitWorkload f_commit_workload, + PyDatabaseNode::FCommitTuningRecord f_commit_tuning_record, + PyDatabaseNode::FGetTopK f_get_top_k, PyDatabaseNode::FSize f_size) { + ObjectPtr n = make_object(); + n->f_commit_workload = f_commit_workload; + n->f_commit_tuning_record = f_commit_tuning_record; + n->f_get_top_k = f_get_top_k; + n->f_size = f_size; + return Database(n); +} + +/******** FFI ********/ + +TVM_REGISTER_NODE_TYPE(WorkloadNode); +TVM_REGISTER_NODE_TYPE(TuningRecordNode); +TVM_REGISTER_OBJECT_TYPE(DatabaseNode); +TVM_REGISTER_NODE_TYPE(PyDatabaseNode); +TVM_REGISTER_GLOBAL("meta_schedule.Workload").set_body_typed([](IRModule mod) { + return Workload(mod); +}); +TVM_REGISTER_GLOBAL("meta_schedule.WorkloadAsJSON") + .set_body_method(&WorkloadNode::AsJSON); +TVM_REGISTER_GLOBAL("meta_schedule.WorkloadFromJSON").set_body_typed(&Workload::FromJSON); +TVM_REGISTER_GLOBAL("meta_schedule.TuningRecord") + .set_body_typed([](tir::Trace trace, Array run_secs, Workload workload, Target target, + Array args_info) { + return TuningRecord(trace, run_secs, workload, target, args_info); + }); +TVM_REGISTER_GLOBAL("meta_schedule.TuningRecordAsJSON") + .set_body_method(&TuningRecordNode::AsJSON); +TVM_REGISTER_GLOBAL("meta_schedule.TuningRecordFromJSON").set_body_typed(TuningRecord::FromJSON); +TVM_REGISTER_GLOBAL("meta_schedule.DatabaseCommitWorkload") + .set_body_method(&DatabaseNode::CommitWorkload); +TVM_REGISTER_GLOBAL("meta_schedule.DatabaseCommitTuningRecord") + .set_body_method(&DatabaseNode::CommitTuningRecord); +TVM_REGISTER_GLOBAL("meta_schedule.DatabaseGetTopK") + .set_body_method(&DatabaseNode::GetTopK); 
+TVM_REGISTER_GLOBAL("meta_schedule.DatabaseSize").set_body_method(&DatabaseNode::Size); +TVM_REGISTER_GLOBAL("meta_schedule.DatabasePyDatabase").set_body_typed(Database::PyDatabase); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/database/json_database.cc b/src/meta_schedule/database/json_database.cc new file mode 100644 index 000000000000..3efb72e2fa74 --- /dev/null +++ b/src/meta_schedule/database/json_database.cc @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include + +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +/*! \brief The struct defining comparison function of sorting by mean run seconds. */ +struct SortTuningRecordByMeanRunSecs { + static const constexpr double kMaxMeanTime = 1e10; + + static double Mean(const Array& a) { + if (a.empty()) { + return kMaxMeanTime; + } + double sum = 0.0; + for (const FloatImm& i : a) { + sum += i->value; + } + return sum / a.size(); + } + + bool operator()(const TuningRecord& a, const TuningRecord& b) const { + double a_time = Mean(a->run_secs); + double b_time = Mean(b->run_secs); + return a_time < b_time; + } +}; + +/*! 
\brief The default database implementation, which mimics two database tables with two files. */ +class JSONDatabaseNode : public DatabaseNode { + public: + /*! \brief The path to the workload table */ + String path_workload; + /*! \brief The path to the tuning record table */ + String path_tuning_record; + /*! \brief All the workloads in the database */ + std::unordered_map workloads2idx_; + /*! \brief All the tuning records in the database */ + std::multiset tuning_records_; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("path_workload", &path_workload); + v->Visit("path_tuning_record", &path_tuning_record); + // `workloads2idx_` is not visited + // `tuning_records_` is not visited + } + + static constexpr const char* _type_key = "meta_schedule.JSONDatabase"; + TVM_DECLARE_FINAL_OBJECT_INFO(JSONDatabaseNode, DatabaseNode); + + public: + Workload CommitWorkload(const IRModule& mod) { + // Try to insert `mod` into `workloads_` + decltype(this->workloads2idx_)::iterator it; + bool inserted = false; + std::tie(it, inserted) = + this->workloads2idx_.emplace(Workload(mod, tvm::StructuralHash()(mod)), -1); + Workload workload = it->first; + // If `mod` is new in `workloads2idx_`, append it to the workload file + if (inserted) { + it->second = static_cast(this->workloads2idx_.size()) - 1; + JSONFileAppendLine(this->path_workload, JSONObj2Str(workload->AsJSON())); + } + return it->first; + } + + void CommitTuningRecord(const TuningRecord& record) { + this->tuning_records_.insert(record); + JSONFileAppendLine(this->path_tuning_record, + JSONObj2Str(Array{ + /*workload_index=*/Integer(this->workloads2idx_.at(record->workload)), + /*tuning_record=*/record->AsJSON() // + })); + } + + Array GetTopK(const Workload& workload, int top_k) { + CHECK_GE(top_k, 0) << "ValueError: top_k must be non-negative"; + if (top_k == 0) { + return {}; + } + Array results; + results.reserve(top_k); + int counter = 0; + for (const TuningRecord& record : this->tuning_records_) { + if 
(WorkloadEqual()(record->workload, workload)) { + results.push_back(record); + if (++counter == top_k) { + break; + } + } + } + return results; + } + + int64_t Size() { return tuning_records_.size(); } +}; + +Database Database::JSONDatabase(String path_workload, String path_tuning_record, + bool allow_missing) { + ObjectPtr n = make_object(); + // Load `n->workloads2idx_` from `path_workload` + std::vector workloads; + { + Array json_objs = JSONStr2Obj(JSONFileReadLines(path_workload, allow_missing)); + int n_objs = json_objs.size(); + n->workloads2idx_.reserve(n_objs); + workloads.reserve(n_objs); + for (int i = 0; i < n_objs; ++i) { + Workload workload = Workload::FromJSON(json_objs[i]); + n->workloads2idx_.emplace(workload, i); + workloads.push_back(workload); + } + } + // Load `n->tuning_records_` from `path_tuning_record` + { + Array json_objs = JSONStr2Obj(JSONFileReadLines(path_tuning_record, allow_missing)); + for (const ObjectRef& json_obj : json_objs) { + int workload_index = -1; + ObjectRef tuning_record{nullptr}; + try { + const ArrayNode* arr = json_obj.as(); + ICHECK_EQ(arr->size(), 2); + workload_index = Downcast(arr->at(0)); + tuning_record = arr->at(1); + } catch (std::runtime_error& e) { + LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj + << "\nThe error is: " << e.what(); + } + n->tuning_records_.insert(TuningRecord::FromJSON(tuning_record, workloads[workload_index])); + } + } + n->path_workload = path_workload; + n->path_tuning_record = path_tuning_record; + return Database(n); +} + +TVM_REGISTER_NODE_TYPE(JSONDatabaseNode); +TVM_REGISTER_GLOBAL("meta_schedule.DatabaseJSONDatabase").set_body_typed(Database::JSONDatabase); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index a2b5ac4d3184..4c9e1e2c10a1 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -19,15 +19,119 @@ #ifndef TVM_META_SCHEDULE_UTILS_H_ #define 
TVM_META_SCHEDULE_UTILS_H_ +#include #include #include +#include #include #include +#include +#include +#include + +#include #include "../support/array.h" +#include "../support/base64.h" namespace tvm { -namespace meta_schedule {} // namespace meta_schedule +namespace meta_schedule { + +/*! + * \brief Read lines from a json file. + * \param path The path to the json file. + * \param allow_missing Whether to create new file when the given path is not found. + * \return An array containing lines read from the json file. + */ +inline Array JSONFileReadLines(const String& path, bool allow_missing) { + std::ifstream is(path); + if (is.good()) { + Array results; + for (std::string str; std::getline(is, str);) { + results.push_back(str); + } + return results; + } + CHECK(allow_missing) << "ValueError: File doesn't exist: " << path; + std::ofstream os(path); + CHECK(os.good()) << "ValueError: Cannot create new file: " << path; + return {}; +} + +/*! + * \brief Append a line to a json file. + * \param path The path to the json file. + * \param line The line to append. + */ +inline void JSONFileAppendLine(const String& path, const std::string& line) { + std::ofstream os(path, std::ofstream::app); + CHECK(os.good()) << "ValueError: Cannot open the file to write: " << path; + os << line << std::endl; +} + +/*! + * \brief Get the base64 encoded result of a string. + * \param str The string to encode. + * \return The base64 encoded string. + */ +inline std::string Base64Encode(std::string str) { + std::string result; + dmlc::MemoryStringStream m_stream(&result); + support::Base64OutStream b64stream(&m_stream); + static_cast(&b64stream)->Write(str); + b64stream.Finish(); + return result; +} + +/*! + * \brief Get the base64 decoded result of a string. + * \param str The string to decode. + * \return The base64 decoded string. 
+ */ +inline std::string Base64Decode(std::string str) { + std::string result; + dmlc::MemoryStringStream m_stream(&str); + support::Base64InStream b64stream(&m_stream); + b64stream.InitPosition(); + static_cast(&b64stream)->Read(&result); + return result; +} + +/*! + * \brief Parse lines of json string into a json object. + * \param lines The lines of json string. + * \return Array of json objects parsed. + * \note The function calls the python-side json parser in runtime registry. + */ +inline Array JSONStr2Obj(const Array& lines) { + static const runtime::PackedFunc* f_to_obj = + runtime::Registry::Get("meta_schedule.batch_json_str2obj"); + ICHECK(f_to_obj) << "IndexError: Cannot find the packed function " + "`meta_schedule.batch_json_str2obj` in the global registry"; + return (*f_to_obj)(lines); +} + +/*! + * \brief Serialize a json object into a json string. + * \param json_obj The json object to serialize. + * \return A string containing the serialized json object. + * \note The function calls the python-side json obj serializer in runtime registry. + */ +inline String JSONObj2Str(const ObjectRef& json_obj) { + static const runtime::PackedFunc* f_to_str = runtime::Registry::Get("meta_schedule.json_obj2str"); + ICHECK(f_to_str) << "IndexError: Cannot find the packed function " + "`meta_schedule.json_obj2str` in the global registry"; + return (*f_to_str)(json_obj); +} + +/*! 
+ * \brief Converts a structural hash code to string + * \param hash_code The hash code + * \return The string representation of the hash code + */ +inline String SHash2Str(Workload::THashCode hash_code) { return std::to_string(hash_code); } + +} // namespace meta_schedule } // namespace tvm #endif // TVM_META_SCHEDULE_UTILS_H_ diff --git a/tests/python/unittest/test_meta_schedule_database.py b/tests/python/unittest/test_meta_schedule_database.py new file mode 100644 index 000000000000..feef023675b0 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_database.py @@ -0,0 +1,274 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
"""Test Meta Schedule Database"""
import os.path as osp
import sys
import tempfile
from typing import Callable, List

import pytest

import tvm
from tvm import tir
from tvm.ir.module import IRModule
from tvm.meta_schedule.arg_info import ArgInfo
from tvm.meta_schedule.database import JSONDatabase, TuningRecord
from tvm.script import ty
from tvm.tir import Schedule

# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument
# fmt: off

@tvm.script.tir
class Matmul:
    def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
        tir.func_attr({"global_symbol": "main"})
        A = tir.match_buffer(a, (1024, 1024), "float32")
        B = tir.match_buffer(b, (1024, 1024), "float32")
        C = tir.match_buffer(c, (1024, 1024), "float32")
        with tir.block([1024, 1024, tir.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]:
            with tir.init():
                C[vi, vj] = 0.0
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]


@tvm.script.tir
class MatmulRelu:
    def main(a: ty.handle, b: ty.handle, d: ty.handle) -> None:  # pylint: disable=no-self-argument
        tir.func_attr({"global_symbol": "main", "tir.noalias": True})
        A = tir.match_buffer(a, (16, 16), "float32")
        B = tir.match_buffer(b, (16, 16), "float32")
        D = tir.match_buffer(d, (16, 16), "float32")
        C = tir.alloc_buffer((16, 16), "float32")
        with tir.block([16, 16, tir.reduce_axis(0, 16)], "matmul") as [vi, vj, vk]:
            with tir.init():
                C[vi, vj] = 0.0
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
        with tir.block([16, 16], "relu") as [vi, vj]:
            D[vi, vj] = tir.max(C[vi, vj], 0.0)


# fmt: on
# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument


def _schedule_matmul(sch: Schedule):
    # Apply a fixed tiling schedule to the "matmul" block so that every test
    # produces the same deterministic trace.
    block = sch.get_block("matmul")
    i, j, k = sch.get_loops(block=block)
    i_tiles = [1, 1, 2, 512]
    j_tiles = [1, 512, 1, 2]
    k_tiles = [256, 4]
    i_0, i_1, i_2, i_3 = sch.split(loop=i, factors=i_tiles)
    j_0, j_1, j_2, j_3 = sch.split(loop=j, factors=j_tiles)
    k_0, k_1 = sch.split(loop=k, factors=k_tiles)
    sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3)


def _create_schedule(mod: IRModule, sch_fn: Callable[[Schedule], None]) -> Schedule:
    sch = tir.Schedule(mod=mod, debug_mask="all")
    sch_fn(sch)
    return sch


def _create_tmp_database(tmpdir: str) -> JSONDatabase:
    path_workload = osp.join(tmpdir, "workloads.json")
    path_tuning_record = osp.join(tmpdir, "tuning_records.json")
    return JSONDatabase(path_workload, path_tuning_record)


def _make_record(trace, run_secs: List[float], workload, mod: IRModule) -> TuningRecord:
    # Factored-out TuningRecord construction: every test targets llvm and
    # derives the arg info from the module's "main" PrimFunc.
    return TuningRecord(
        trace,
        run_secs,
        workload,
        tvm.target.Target("llvm"),
        ArgInfo.from_prim_func(func=mod["main"]),  # pylint: disable=unsubscriptable-object
    )


def _equal_record(a: TuningRecord, b: TuningRecord):
    assert str(a.trace) == str(b.trace)
    assert str(a.run_secs) == str(b.run_secs)
    # AWAIT(@zxybazh): change to export after fixing "(bool)0"
    assert str(a.target) == str(b.target)
    assert tvm.ir.structural_equal(a.workload.mod, b.workload.mod)
    for arg0, arg1 in zip(a.args_info, b.args_info):
        assert str(arg0.as_json()) == str(arg1.as_json())


def test_meta_schedule_tuning_record_round_trip():
    mod: IRModule = Matmul()
    with tempfile.TemporaryDirectory() as tmpdir:
        database = _create_tmp_database(tmpdir)
        workload = database.commit_workload(mod)
        record = _make_record(
            _create_schedule(mod, _schedule_matmul).trace, [1.5, 2.5, 1.8], workload, mod
        )
        database.commit_tuning_record(record)
        new_record = TuningRecord.from_json(record.as_json(), workload)
        _equal_record(record, new_record)


def test_meta_schedule_database_create():
    with tempfile.TemporaryDirectory() as tmpdir:
        database = _create_tmp_database(tmpdir)
        assert osp.exists(database.path_workload)
        assert osp.exists(database.path_tuning_record)


def test_meta_schedule_database_add_entry():
    mod: IRModule = Matmul()
    with tempfile.TemporaryDirectory() as tmpdir:
        database = _create_tmp_database(tmpdir)
        workload = database.commit_workload(mod)
        record = _make_record(
            _create_schedule(mod, _schedule_matmul).trace, [1.5, 2.5, 1.8], workload, mod
        )
        database.commit_tuning_record(record)
        assert len(database) == 1
        (ret,) = database.get_top_k(workload, 3)
        _equal_record(ret, record)


def test_meta_schedule_database_missing():
    mod: IRModule = Matmul()
    mod_2: IRModule = MatmulRelu()
    with tempfile.TemporaryDirectory() as tmpdir:
        database = _create_tmp_database(tmpdir)
        workload = database.commit_workload(mod)
        workload_2 = database.commit_workload(mod_2)
        record = _make_record(
            _create_schedule(mod, _schedule_matmul).trace, [1.5, 2.5, 1.8], workload, mod
        )
        database.commit_tuning_record(record)
        # Records committed for `workload` must not leak into `workload_2`.
        ret = database.get_top_k(workload_2, 3)
        assert len(ret) == 0


def test_meta_schedule_database_sorting():
    mod: IRModule = Matmul()
    with tempfile.TemporaryDirectory() as tmpdir:
        database = _create_tmp_database(tmpdir)
        token = database.commit_workload(mod)
        trace = _create_schedule(mod, _schedule_matmul).trace
        records = [
            _make_record(trace, run_secs, token, mod)
            for run_secs in [
                [7.0, 8.0, 9.0],
                [1.0, 2.0, 3.0],
                [4.0, 5.0, 6.0],
                [1.1, 1.2, 600.0],
                [1.0, 100.0, 6.0],
                [4.0, 9.0, 8.0],
            ]
        ]
        for record in records:
            database.commit_tuning_record(record)
        ret = database.get_top_k(token, 2)
        assert len(ret) == 2
        # records[1] (mean 2.0) and records[2] (mean 5.0) are the best two;
        # either order is accepted since ties in identity make order unspecified.
        try:
            _equal_record(ret[0], records[2])
            _equal_record(ret[1], records[1])
        except AssertionError:
            _equal_record(ret[0], records[1])
            _equal_record(ret[1], records[2])


def test_meta_schedule_database_reload():
    mod: IRModule = Matmul()
    with tempfile.TemporaryDirectory() as tmpdir:
        database = _create_tmp_database(tmpdir)
        token = database.commit_workload(mod)
        trace = _create_schedule(mod, _schedule_matmul).trace
        records = [
            _make_record(trace, run_secs, token, mod)
            for run_secs in [
                [7.0, 8.0, 9.0],
                [1.0, 2.0, 3.0],
                [4.0, 5.0, 6.0],
            ]
        ]
        for record in records:
            database.commit_tuning_record(record)
        # Re-open the same files to verify on-disk round-tripping.
        new_database = JSONDatabase(  # pylint: disable=unused-variable
            path_workload=database.path_workload,
            path_tuning_record=database.path_tuning_record,
        )
        token = new_database.commit_workload(mod)
        ret = new_database.get_top_k(token, 2)
        assert len(ret) == 2
        try:
            _equal_record(ret[0], records[2])
            _equal_record(ret[1], records[1])
        except AssertionError:
            _equal_record(ret[0], records[1])
            _equal_record(ret[1], records[2])


if __name__ == "__main__":
    sys.exit(pytest.main([__file__] + sys.argv[1:]))