[MetaSchedule] Distributed Measurement (apache#11683)
This PR adds distributed measurement of tuning candidates using a builder and an asynchronous runner, along with some auxiliary functions. It enables multiple builders and multiple runners connected through a tracker. The file hierarchy in the database can be further compacted in follow-up work to make the database more concise.
Kathryn-cat authored and blackkker committed Jul 7, 2022
1 parent 43abf16 commit 340753f
Showing 9 changed files with 361 additions and 13 deletions.
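For context, here is a minimal sketch (not part of this diff) of how such a distributed setup can be wired together: one local builder plus an asynchronous runner that reaches devices through an RPC tracker. The tracker address and key are placeholders, and it assumes a tracker and at least one RPC server are already running:

    # Start the tracker and server(s) beforehand, e.g.:
    #   python -m tvm.exec.rpc_tracker --host 0.0.0.0 --port 9190
    #   python -m tvm.exec.rpc_server --tracker 127.0.0.1:9190 --key example-key
    from tvm import meta_schedule as ms

    builder = ms.builder.LocalBuilder()  # compiles measure candidates locally
    runner = ms.runner.RPCRunner(        # measures compiled artifacts over RPC
        rpc_config=ms.runner.RPCConfig(
            tracker_host="127.0.0.1",    # placeholder tracker address
            tracker_port=9190,
            tracker_key="example-key",   # must match the server's --key
            session_timeout_sec=60,
        ),
    )
    # RPCRunner.run returns futures, so candidates are measured concurrently
    # across every server registered under the same key.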
27 changes: 27 additions & 0 deletions include/tvm/meta_schedule/database.h
@@ -98,6 +98,9 @@ struct WorkloadEqual {
}
};

/*! \brief The class of measure candidates. */
class MeasureCandidate;

/*! \brief The class of tuning records. */
class TuningRecordNode : public runtime::Object {
public:
@@ -123,6 +126,9 @@ class TuningRecordNode : public runtime::Object {
static constexpr const char* _type_key = "meta_schedule.TuningRecord";
TVM_DECLARE_FINAL_OBJECT_INFO(TuningRecordNode, runtime::Object);

/*! \brief Construct the measure candidate given the initial IR module and trace
* stored in the tuning record. */
MeasureCandidate AsMeasureCandidate() const;
/*!
* \brief Export the tuning record to a JSON string.
* \return An array containing the trace, running secs, serialized target, and
@@ -187,6 +193,11 @@ class DatabaseNode : public runtime::Object {
* \return An array of top K tuning records for the given workload.
*/
virtual Array<TuningRecord> GetTopK(const Workload& workload, int top_k) = 0;
/*!
* \brief Get all tuning records from the database.
* \return An Array of all the tuning records in the database.
*/
virtual Array<TuningRecord> GetAllTuningRecords() = 0;
/*!
* \brief Get the size of the database.
* \return The size of the database.
@@ -224,6 +235,11 @@ class PyDatabaseNode : public DatabaseNode {
* \return An array of top K tuning records for the given workload.
*/
using FGetTopK = runtime::TypedPackedFunc<Array<TuningRecord>(const Workload&, int)>;
/*!
* \brief The function type of `GetAllTuningRecords` method.
* \return An Array of all the tuning records in the database.
*/
using FGetAllTuningRecords = runtime::TypedPackedFunc<Array<TuningRecord>()>;
/*!
* \brief The function type of `Size` method.
* \return The size of the database.
@@ -238,6 +254,8 @@ class PyDatabaseNode : public DatabaseNode {
FCommitTuningRecord f_commit_tuning_record;
/*! \brief The packed function to the `GetTopK` function. */
FGetTopK f_get_top_k;
/*! \brief The packed function to the `GetAllTuningRecords` function. */
FGetAllTuningRecords f_get_all_tuning_records;
/*! \brief The packed function to the `Size` function. */
FSize f_size;

@@ -249,6 +267,7 @@ class PyDatabaseNode : public DatabaseNode {
// `f_commit_workload` is not visited
// `f_commit_tuning_record` is not visited
// `f_get_top_k` is not visited
// `f_get_all_tuning_records` is not visited
// `f_size` is not visited
}

@@ -273,6 +292,12 @@ class PyDatabaseNode : public DatabaseNode {
return f_get_top_k(workload, top_k);
}

Array<TuningRecord> GetAllTuningRecords() final {
ICHECK(f_get_all_tuning_records != nullptr)
<< "PyDatabase's GetAllTuningRecords method not implemented!";
return f_get_all_tuning_records();
}

int64_t Size() final {
ICHECK(f_size != nullptr) << "PyDatabase's Size method not implemented!";
return f_size();
@@ -302,13 +327,15 @@ class Database : public runtime::ObjectRef {
* \param f_commit_workload The packed function of `CommitWorkload`.
* \param f_commit_tuning_record The packed function of `CommitTuningRecord`.
* \param f_get_top_k The packed function of `GetTopK`.
* \param f_get_all_tuning_records The packed function of `GetAllTuningRecords`.
* \param f_size The packed function of `Size`.
* \return The created database.
*/
TVM_DLL static Database PyDatabase(PyDatabaseNode::FHasWorkload f_has_workload,
PyDatabaseNode::FCommitWorkload f_commit_workload,
PyDatabaseNode::FCommitTuningRecord f_commit_tuning_record,
PyDatabaseNode::FGetTopK f_get_top_k,
PyDatabaseNode::FGetAllTuningRecords f_get_all_tuning_records,
PyDatabaseNode::FSize f_size);
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Database, runtime::ObjectRef, DatabaseNode);
};
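The two methods added above, AsMeasureCandidate and GetAllTuningRecords, are exposed to Python through the bindings in database.py below. A minimal sketch of how they compose, assuming db is any ms.database.Database and target is a tvm.target.Target:

    from tvm import meta_schedule as ms

    records = db.get_all_tuning_records()          # every record, across all workloads
    candidate = records[0].as_measure_candidate()  # replay the stored trace
    # A MeasureCandidate bundles the replayed schedule with its argument info,
    # which is what the builder/runner measurement pipeline consumes:
    builder_input = ms.builder.BuilderInput(candidate.sch.mod, target)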
34 changes: 34 additions & 0 deletions python/tvm/meta_schedule/database/database.py
@@ -115,6 +115,17 @@ def __init__( # type: ignore # pylint: disable=too-many-arguments
args_info,
)

def as_measure_candidate(self) -> Any:
"""Generate a measure candidate given an initial IR module and a trace
    stored in the tuning record.

    Returns
-------
candidate : MeasureCandidate
A generated candidate.
"""
return _ffi_api.TuningRecordAsMeasureCandidate(self) # type: ignore # pylint: disable=no-member

def as_json(self) -> Any:
"""Export the tuning record to a JSON string.
@@ -203,6 +214,16 @@ def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
"""
return _ffi_api.DatabaseGetTopK(self, workload, top_k) # type: ignore # pylint: disable=no-member

def get_all_tuning_records(self) -> List[TuningRecord]:
"""Get all the tuning records from the database.
Returns
-------
tuning_records : List[TuningRecord]
All tuning records from the database.
"""
return _ffi_api.DatabaseGetAllTuningRecords(self) # type: ignore # pylint: disable=no-member

def __len__(self) -> int:
"""Get the number of records in the database.
@@ -229,6 +250,7 @@ def __init__(
f_commit_workload: Callable = None,
f_commit_tuning_record: Callable = None,
f_get_top_k: Callable = None,
f_get_all_tuning_records: Callable = None,
f_size: Callable = None,
):
"""Constructor."""
@@ -239,6 +261,7 @@
f_commit_workload,
f_commit_tuning_record,
f_get_top_k,
f_get_all_tuning_records,
f_size,
)

@@ -258,6 +281,7 @@ class PyDatabase:
"commit_workload",
"commit_tuning_record",
"get_top_k",
"get_all_tuning_records",
"__len__",
],
}
@@ -317,6 +341,16 @@ def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
"""
raise NotImplementedError

def get_all_tuning_records(self) -> List[TuningRecord]:
"""Get all the tuning records from the database.
Returns
-------
tuning_records : List[TuningRecord]
All tuning records from the database.
"""
raise NotImplementedError

def __len__(self) -> int:
"""Get the number of records in the database.
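Because GetAllTuningRecords is pure virtual on DatabaseNode, every custom database written against PyDatabase now has to implement it. A bare-bones in-memory sketch, closely mirroring the built-in MemoryDatabase changed in the next file (it assumes every committed record carries run_secs):

    from typing import List

    import tvm
    from tvm.ir import IRModule
    from tvm.meta_schedule.database import PyDatabase, TuningRecord, Workload
    from tvm.meta_schedule.utils import derived_object

    @derived_object
    class ListDatabase(PyDatabase):
        def __init__(self):
            super().__init__()
            self.records: List[TuningRecord] = []
            self.workloads: List[Workload] = []

        def has_workload(self, mod: IRModule) -> bool:
            return any(tvm.ir.structural_equal(w.mod, mod) for w in self.workloads)

        def commit_workload(self, mod: IRModule) -> Workload:
            for w in self.workloads:
                if tvm.ir.structural_equal(w.mod, mod):
                    return w
            workload = Workload(mod)
            self.workloads.append(workload)
            return workload

        def commit_tuning_record(self, record: TuningRecord) -> None:
            self.records.append(record)

        def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
            # Sort by mean run time, then keep records of the given workload.
            return [
                rec
                for rec in sorted(self.records, key=lambda r: sum(r.run_secs) / len(r.run_secs))
                if rec.workload == workload
            ][:top_k]

        def get_all_tuning_records(self) -> List[TuningRecord]:
            return self.records  # the newly required method

        def __len__(self) -> int:
            return len(self.records)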
3 changes: 3 additions & 0 deletions python/tvm/meta_schedule/database/memory_database.py
@@ -56,6 +56,9 @@ def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
)
)[: int(top_k)]

def get_all_tuning_records(self) -> List[TuningRecord]:
return self.records

def __len__(self) -> int:
return len(self.records)

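A quick usage sketch of the new method, assuming mod is an IRModule and sch is a traced tir.Schedule tuned for it:

    from tvm import meta_schedule as ms

    db = ms.database.MemoryDatabase()
    workload = db.commit_workload(mod)
    db.commit_tuning_record(
        ms.database.TuningRecord(sch.trace, workload, run_secs=[1e-3])  # placeholder timing
    )
    assert len(db.get_all_tuning_records()) == len(db)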
23 changes: 12 additions & 11 deletions python/tvm/meta_schedule/testing/dataset_sample_candidates.py
@@ -103,6 +103,14 @@ def sample_candidates(task, task_name, model_name):
-------
None
"""
+    candidate_path = os.path.join(
+        args.candidate_cache_dir, model_name, task_name + "_candidates.json"
+    )
+    workload_path = os.path.join(args.candidate_cache_dir, model_name, task_name + "_workload.json")
+    database = ms.database.JSONDatabase(
+        path_workload=workload_path,
+        path_tuning_record=candidate_path,
+    )
sample_init_population = tvm.get_global_func(
"meta_schedule.SearchStrategyEvolutionarySearchSampleInitPopulation"
)
@@ -128,7 +136,7 @@ def sample_candidates(task, task_name, model_name):
context.initialize()
context.pre_tuning(
context.generate_design_space(),
-        database=ms.database.MemoryDatabase(),  # type: ignore
+        database=database,
cost_model=ms.cost_model.RandomModel(), # type: ignore
)

@@ -148,16 +156,9 @@ def sample_candidates(task, task_name, model_name):
all_states = all_states[: args.num_samples_per_task]

workload = ms.database.Workload(context.mod)
-    file_path = os.path.join(args.candidate_cache_dir, model_name, task_name + ".json")
-    with open(file_path, "w", encoding="utf8") as file:
-        for i, state in enumerate(all_states):
-            tuning_record = ms.database.TuningRecord(state.trace, workload)
-            json_str = json.dumps(tuning_record.as_json())
-            assert "\n" not in json_str, "Failed to generate single line string."
-            if i == len(all_states) - 1:
-                file.write(json_str)
-            else:
-                file.write(json_str + "\n")
+    database.commit_workload(context.mod)
+    for state in all_states:
+        database.commit_tuning_record(ms.database.TuningRecord(state.trace, workload))


args = _parse_args() # pylint: disable=invalid-name
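Since the script now persists candidates through a JSONDatabase rather than hand-written JSON lines, the sampled candidates can be reloaded from the same pair of files later, e.g. for distributed measurement; a sketch with placeholder paths following the layout above:

    from tvm import meta_schedule as ms

    db = ms.database.JSONDatabase(
        path_workload="cache/model/task_workload.json",        # placeholder paths
        path_tuning_record="cache/model/task_candidates.json",
    )
    for record in db.get_all_tuning_records():
        candidate = record.as_measure_candidate()  # ready to build and run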
[5 more changed files not shown]
