Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MetaSchedule] Distributed Measurement #11683

Merged
merged 1 commit into from
Jun 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions include/tvm/meta_schedule/database.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ struct WorkloadEqual {
}
};

/*! \brief The class of measure candidates. */
class MeasureCandidate;

/*! \brief The class of tuning records. */
class TuningRecordNode : public runtime::Object {
public:
Expand All @@ -123,6 +126,9 @@ class TuningRecordNode : public runtime::Object {
static constexpr const char* _type_key = "meta_schedule.TuningRecord";
TVM_DECLARE_FINAL_OBJECT_INFO(TuningRecordNode, runtime::Object);

/*! \brief Construct the measure candidate given the initial IR module and trace
* stored in the tuning record. */
MeasureCandidate AsMeasureCandidate() const;
/*!
* \brief Export the tuning record to a JSON string.
* \return An array containing the trace, running secs, serialized target, and
Expand Down Expand Up @@ -187,6 +193,11 @@ class DatabaseNode : public runtime::Object {
* \return An array of top K tuning records for the given workload.
*/
virtual Array<TuningRecord> GetTopK(const Workload& workload, int top_k) = 0;
/*!
* \brief Get all tuning records from the database.
* \return An Array of all the tuning records in the database.
*/
virtual Array<TuningRecord> GetAllTuningRecords() = 0;
/*!
* \brief Get the size of the database.
* \return The size of the database.
Expand Down Expand Up @@ -224,6 +235,11 @@ class PyDatabaseNode : public DatabaseNode {
* \return An array of top K tuning records for the given workload.
*/
using FGetTopK = runtime::TypedPackedFunc<Array<TuningRecord>(const Workload&, int)>;
/*!
* \brief The function type of `GetAllTuningRecords` method.
* \return An Array of all the tuning records in the database.
*/
using FGetAllTuningRecords = runtime::TypedPackedFunc<Array<TuningRecord>()>;
/*!
* \brief The function type of `Size` method.
* \return The size of the database.
Expand All @@ -238,6 +254,8 @@ class PyDatabaseNode : public DatabaseNode {
FCommitTuningRecord f_commit_tuning_record;
/*! \brief The packed function to the `GetTopK` function. */
FGetTopK f_get_top_k;
/*! \brief The packed function to the `GetAllTuningRecords` function. */
FGetAllTuningRecords f_get_all_tuning_records;
/*! \brief The packed function to the `Size` function. */
FSize f_size;

Expand All @@ -249,6 +267,7 @@ class PyDatabaseNode : public DatabaseNode {
// `f_commit_workload` is not visited
// `f_commit_tuning_record` is not visited
// `f_get_top_k` is not visited
// `f_get_all_tuning_records` is not visited
// `f_size` is not visited
}

Expand All @@ -273,6 +292,12 @@ class PyDatabaseNode : public DatabaseNode {
return f_get_top_k(workload, top_k);
}

Array<TuningRecord> GetAllTuningRecords() final {
ICHECK(f_get_all_tuning_records != nullptr)
<< "PyDatabase's GetAllTuningRecords method not implemented!";
return f_get_all_tuning_records();
}

int64_t Size() final {
ICHECK(f_size != nullptr) << "PyDatabase's Size method not implemented!";
return f_size();
Expand Down Expand Up @@ -302,13 +327,15 @@ class Database : public runtime::ObjectRef {
* \param f_commit_workload The packed function of `CommitWorkload`.
* \param f_commit_tuning_record The packed function of `CommitTuningRecord`.
* \param f_get_top_k The packed function of `GetTopK`.
* \param f_get_all_tuning_records The packed function of `GetAllTuningRecords`.
* \param f_size The packed function of `Size`.
* \return The created database.
*/
TVM_DLL static Database PyDatabase(PyDatabaseNode::FHasWorkload f_has_workload,
PyDatabaseNode::FCommitWorkload f_commit_workload,
PyDatabaseNode::FCommitTuningRecord f_commit_tuning_record,
PyDatabaseNode::FGetTopK f_get_top_k,
PyDatabaseNode::FGetAllTuningRecords f_get_all_tuning_records,
PyDatabaseNode::FSize f_size);
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Database, runtime::ObjectRef, DatabaseNode);
};
Expand Down
34 changes: 34 additions & 0 deletions python/tvm/meta_schedule/database/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,17 @@ def __init__( # type: ignore # pylint: disable=too-many-arguments
args_info,
)

def as_measure_candidate(self) -> Any:
"""Generate a measure candidate given an initial IR module and a trace
stored in the tuning record.

Returns
-------
candidate : MeasureCandidate
A generated candidate.
"""
return _ffi_api.TuningRecordAsMeasureCandidate(self) # type: ignore # pylint: disable=no-member

def as_json(self) -> Any:
"""Export the tuning record to a JSON string.

Expand Down Expand Up @@ -203,6 +214,16 @@ def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
"""
return _ffi_api.DatabaseGetTopK(self, workload, top_k) # type: ignore # pylint: disable=no-member

def get_all_tuning_records(self) -> List[TuningRecord]:
"""Get all the tuning records from the database.

Returns
-------
tuning_records : List[TuningRecord]
All tuning records from the database.
"""
return _ffi_api.DatabaseGetAllTuningRecords(self) # type: ignore # pylint: disable=no-member

def __len__(self) -> int:
"""Get the number of records in the database.

Expand All @@ -229,6 +250,7 @@ def __init__(
f_commit_workload: Callable = None,
f_commit_tuning_record: Callable = None,
f_get_top_k: Callable = None,
f_get_all_tuning_records: Callable = None,
f_size: Callable = None,
):
"""Constructor."""
Expand All @@ -239,6 +261,7 @@ def __init__(
f_commit_workload,
f_commit_tuning_record,
f_get_top_k,
f_get_all_tuning_records,
f_size,
)

Expand All @@ -258,6 +281,7 @@ class PyDatabase:
"commit_workload",
"commit_tuning_record",
"get_top_k",
"get_all_tuning_records",
"__len__",
],
}
Expand Down Expand Up @@ -317,6 +341,16 @@ def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
"""
raise NotImplementedError

def get_all_tuning_records(self) -> List[TuningRecord]:
"""Get all the tuning records from the database.

Returns
-------
tuning_records : List[TuningRecord]
All tuning records from the database.
"""
raise NotImplementedError

def __len__(self) -> int:
"""Get the number of records in the database.

Expand Down
3 changes: 3 additions & 0 deletions python/tvm/meta_schedule/database/memory_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
)
)[: int(top_k)]

def get_all_tuning_records(self) -> List[TuningRecord]:
return self.records

def __len__(self) -> int:
return len(self.records)

Expand Down
23 changes: 12 additions & 11 deletions python/tvm/meta_schedule/testing/dataset_sample_candidates.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,14 @@ def sample_candidates(task, task_name, model_name):
-------
None
"""
candidate_path = os.path.join(
args.candidate_cache_dir, model_name, task_name + "_candidates.json"
)
workload_path = os.path.join(args.candidate_cache_dir, model_name, task_name + "_workload.json")
database = ms.database.JSONDatabase(
path_workload=workload_path,
path_tuning_record=candidate_path,
)
sample_init_population = tvm.get_global_func(
"meta_schedule.SearchStrategyEvolutionarySearchSampleInitPopulation"
)
Expand All @@ -128,7 +136,7 @@ def sample_candidates(task, task_name, model_name):
context.initialize()
context.pre_tuning(
context.generate_design_space(),
database=ms.database.MemoryDatabase(), # type: ignore
database=database,
cost_model=ms.cost_model.RandomModel(), # type: ignore
)

Expand All @@ -148,16 +156,9 @@ def sample_candidates(task, task_name, model_name):
all_states = all_states[: args.num_samples_per_task]

workload = ms.database.Workload(context.mod)
file_path = os.path.join(args.candidate_cache_dir, model_name, task_name + ".json")
with open(file_path, "w", encoding="utf8") as file:
for i, state in enumerate(all_states):
tuning_record = ms.database.TuningRecord(state.trace, workload)
json_str = json.dumps(tuning_record.as_json())
assert "\n" not in json_str, "Failed to generate single line string."
if i == len(all_states) - 1:
file.write(json_str)
else:
file.write(json_str + "\n")
database.commit_workload(context.mod)
for state in all_states:
database.commit_tuning_record(ms.database.TuningRecord(state.trace, workload))


args = _parse_args() # pylint: disable=invalid-name
Expand Down
Loading