Skip to content

Commit

Permalink
[MetaSchedule] Distributed Measurement
Browse files Browse the repository at this point in the history
  • Loading branch information
Kathryn-cat committed Jun 17, 2022
1 parent 5aabeb7 commit 6c8456c
Show file tree
Hide file tree
Showing 9 changed files with 361 additions and 13 deletions.
27 changes: 27 additions & 0 deletions include/tvm/meta_schedule/database.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ struct WorkloadEqual {
}
};

/*! \brief The class of measure candidates. */
class MeasureCandidate;

/*! \brief The class of tuning records. */
class TuningRecordNode : public runtime::Object {
public:
Expand All @@ -123,6 +126,9 @@ class TuningRecordNode : public runtime::Object {
static constexpr const char* _type_key = "meta_schedule.TuningRecord";
TVM_DECLARE_FINAL_OBJECT_INFO(TuningRecordNode, runtime::Object);

/*! \brief Construct the measure candidate given the initial IR module and trace
* stored in the tuning record. */
MeasureCandidate AsMeasureCandidate() const;
/*!
* \brief Export the tuning record to a JSON string.
* \return An array containing the trace, running secs, serialized target, and
Expand Down Expand Up @@ -187,6 +193,11 @@ class DatabaseNode : public runtime::Object {
* \return An array of top K tuning records for the given workload.
*/
virtual Array<TuningRecord> GetTopK(const Workload& workload, int top_k) = 0;
/*!
* \brief Get all tuning records from the database.
* \return An Array of all the tuning records in the database.
*/
virtual Array<TuningRecord> GetAllTuningRecords() = 0;
/*!
* \brief Get the size of the database.
* \return The size of the database.
Expand Down Expand Up @@ -224,6 +235,11 @@ class PyDatabaseNode : public DatabaseNode {
* \return An array of top K tuning records for the given workload.
*/
using FGetTopK = runtime::TypedPackedFunc<Array<TuningRecord>(const Workload&, int)>;
/*!
* \brief The function type of `GetAllTuningRecords` method.
* \return An Array of all the tuning records in the database.
*/
using FGetAllTuningRecords = runtime::TypedPackedFunc<Array<TuningRecord>()>;
/*!
* \brief The function type of `Size` method.
* \return The size of the database.
Expand All @@ -238,6 +254,8 @@ class PyDatabaseNode : public DatabaseNode {
FCommitTuningRecord f_commit_tuning_record;
/*! \brief The packed function to the `GetTopK` function. */
FGetTopK f_get_top_k;
/*! \brief The packed function to the `GetAllTuningRecords` function. */
FGetAllTuningRecords f_get_all_tuning_records;
/*! \brief The packed function to the `Size` function. */
FSize f_size;

Expand All @@ -249,6 +267,7 @@ class PyDatabaseNode : public DatabaseNode {
// `f_commit_workload` is not visited
// `f_commit_tuning_record` is not visited
// `f_get_top_k` is not visited
// `f_get_all_tuning_records` is not visited
// `f_size` is not visited
}

Expand All @@ -273,6 +292,12 @@ class PyDatabaseNode : public DatabaseNode {
return f_get_top_k(workload, top_k);
}

Array<TuningRecord> GetAllTuningRecords() final {
ICHECK(f_get_all_tuning_records != nullptr)
<< "PyDatabase's GetAllTuningRecords method not implemented!";
return f_get_all_tuning_records();
}

int64_t Size() final {
ICHECK(f_size != nullptr) << "PyDatabase's Size method not implemented!";
return f_size();
Expand Down Expand Up @@ -302,13 +327,15 @@ class Database : public runtime::ObjectRef {
* \param f_commit_workload The packed function of `CommitWorkload`.
* \param f_commit_tuning_record The packed function of `CommitTuningRecord`.
* \param f_get_top_k The packed function of `GetTopK`.
* \param f_get_all_tuning_records The packed function of `GetAllTuningRecords`.
* \param f_size The packed function of `Size`.
* \return The created database.
*/
TVM_DLL static Database PyDatabase(PyDatabaseNode::FHasWorkload f_has_workload,
PyDatabaseNode::FCommitWorkload f_commit_workload,
PyDatabaseNode::FCommitTuningRecord f_commit_tuning_record,
PyDatabaseNode::FGetTopK f_get_top_k,
PyDatabaseNode::FGetAllTuningRecords f_get_all_tuning_records,
PyDatabaseNode::FSize f_size);
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Database, runtime::ObjectRef, DatabaseNode);
};
Expand Down
34 changes: 34 additions & 0 deletions python/tvm/meta_schedule/database/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,17 @@ def __init__( # type: ignore # pylint: disable=too-many-arguments
args_info,
)

def as_measure_candidate(self) -> Any:
"""Generate a measure candidate given an initial IR module and a trace
stored in the tuning record.
Returns
-------
candidate : MeasureCandidate
A generated candidate.
"""
return _ffi_api.TuningRecordAsMeasureCandidate(self) # type: ignore # pylint: disable=no-member

def as_json(self) -> Any:
"""Export the tuning record to a JSON string.
Expand Down Expand Up @@ -203,6 +214,16 @@ def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
"""
return _ffi_api.DatabaseGetTopK(self, workload, top_k) # type: ignore # pylint: disable=no-member

def get_all_tuning_records(self) -> List[TuningRecord]:
"""Get all the tuning records from the database.
Returns
-------
tuning_records : List[TuningRecord]
All tuning records from the database.
"""
return _ffi_api.DatabaseGetAllTuningRecords(self) # type: ignore # pylint: disable=no-member

def __len__(self) -> int:
"""Get the number of records in the database.
Expand All @@ -229,6 +250,7 @@ def __init__(
f_commit_workload: Callable = None,
f_commit_tuning_record: Callable = None,
f_get_top_k: Callable = None,
f_get_all_tuning_records: Callable = None,
f_size: Callable = None,
):
"""Constructor."""
Expand All @@ -239,6 +261,7 @@ def __init__(
f_commit_workload,
f_commit_tuning_record,
f_get_top_k,
f_get_all_tuning_records,
f_size,
)

Expand All @@ -258,6 +281,7 @@ class PyDatabase:
"commit_workload",
"commit_tuning_record",
"get_top_k",
"get_all_tuning_records",
"__len__",
],
}
Expand Down Expand Up @@ -317,6 +341,16 @@ def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
"""
raise NotImplementedError

def get_all_tuning_records(self) -> List[TuningRecord]:
"""Get all the tuning records from the database.
Returns
-------
tuning_records : List[TuningRecord]
All tuning records from the database.
"""
raise NotImplementedError

def __len__(self) -> int:
"""Get the number of records in the database.
Expand Down
3 changes: 3 additions & 0 deletions python/tvm/meta_schedule/database/memory_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
)
)[: int(top_k)]

def get_all_tuning_records(self) -> List[TuningRecord]:
return self.records

def __len__(self) -> int:
return len(self.records)

Expand Down
23 changes: 12 additions & 11 deletions python/tvm/meta_schedule/testing/dataset_sample_candidates.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,14 @@ def sample_candidates(task, task_name, model_name):
-------
None
"""
candidate_path = os.path.join(
args.candidate_cache_dir, model_name, task_name + "_candidates.json"
)
workload_path = os.path.join(args.candidate_cache_dir, model_name, task_name + "_workload.json")
database = ms.database.JSONDatabase(
path_workload=workload_path,
path_tuning_record=candidate_path,
)
sample_init_population = tvm.get_global_func(
"meta_schedule.SearchStrategyEvolutionarySearchSampleInitPopulation"
)
Expand All @@ -128,7 +136,7 @@ def sample_candidates(task, task_name, model_name):
context.initialize()
context.pre_tuning(
context.generate_design_space(),
database=ms.database.MemoryDatabase(), # type: ignore
database=database,
cost_model=ms.cost_model.RandomModel(), # type: ignore
)

Expand All @@ -148,16 +156,9 @@ def sample_candidates(task, task_name, model_name):
all_states = all_states[: args.num_samples_per_task]

workload = ms.database.Workload(context.mod)
file_path = os.path.join(args.candidate_cache_dir, model_name, task_name + ".json")
with open(file_path, "w", encoding="utf8") as file:
for i, state in enumerate(all_states):
tuning_record = ms.database.TuningRecord(state.trace, workload)
json_str = json.dumps(tuning_record.as_json())
assert "\n" not in json_str, "Failed to generate single line string."
if i == len(all_states) - 1:
file.write(json_str)
else:
file.write(json_str + "\n")
database.commit_workload(context.mod)
for state in all_states:
database.commit_tuning_record(ms.database.TuningRecord(state.trace, workload))


args = _parse_args() # pylint: disable=invalid-name
Expand Down
Loading

0 comments on commit 6c8456c

Please sign in to comment.