[AutoTVM] Use PopenPool in XGBoostCostModel #8820

Merged: 7 commits, merged on Aug 24, 2021
python/tvm/autotvm/tuner/xgboost_cost_model.py (74 changes: 33 additions, 41 deletions)
@@ -17,12 +17,13 @@
 # pylint: disable=invalid-name
 """XGBoost as cost model"""
 
-import multiprocessing
 import logging
 import time
 
 import numpy as np
 
+from tvm.contrib.popen_pool import PopenPoolExecutor, StatusKind
+
 from .. import feature
 from ..utils import get_rank
 from .metric import max_curve, recall_curve, cover_curve
@@ -153,20 +154,14 @@ def _reset_pool(self, space, target, task):
 
         self._close_pool()
 
-        # Use global variable to pass common arguments. This is only used when
-        # new processes are started with fork. We have to set the globals
-        # before we create the pool, so that processes in the pool get the
-        # correct globals.
-        global _extract_space, _extract_target, _extract_task
-        _extract_space = space
-        _extract_target = target
-        _extract_task = task
-        self.pool = multiprocessing.Pool(self.num_threads)
+        self.pool = PopenPoolExecutor(
+            max_workers=self.num_threads,
+            initializer=_extract_popen_initializer,
+            initargs=(space, target, task),
+        )
 
     def _close_pool(self):
         if self.pool:
-            self.pool.terminate()
-            self.pool.join()
             self.pool = None
 
     def _get_pool(self):
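The initializer hook is what replaces the old fork-only globals trick: PopenPoolExecutor runs the given initializer once in every worker process, so worker-side setup no longer depends on the multiprocessing start method. A minimal, self-contained sketch of that hook (the helper names `_worker_init` and `_describe` are illustrative, not TVM API):

```python
# Sketch only: demonstrates that the initializer runs once per worker process.
import os
from tvm.contrib.popen_pool import PopenPoolExecutor, StatusKind


def _worker_init(tag):
    # Executed once per worker right after it starts; stash the shared argument
    # in the worker's environment just to make the effect observable per task.
    os.environ["DEMO_TAG"] = tag


def _describe(x):
    # Runs in the worker for every task and sees the setup done by _worker_init.
    return "%s:%d:%d" % (os.environ["DEMO_TAG"], os.getpid(), x)


pool = PopenPoolExecutor(max_workers=2, initializer=_worker_init, initargs=("demo",))
for res in pool.map_with_error_catching(_describe, range(4)):
    if res.status == StatusKind.COMPLETE:
        print(res.value)
```

Because the heavy arguments travel once through `initargs`, each mapped task only ships a small payload, which is why `_reset_pool` passes `(space, target, task)` there instead of attaching them to every work item.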
@@ -247,13 +242,16 @@ def fit_log(self, records, plan_size, min_seed_records=500):
             feature_extract_func = _extract_curve_feature_log
         else:
             raise RuntimeError("Invalid feature type: " + self.fea_type)
-        res = pool.map(feature_extract_func, data)
+        result = pool.map_with_error_catching(feature_extract_func, data)
 
         # filter out feature with different shapes
         fea_len = len(self._get_feature([0])[0])
 
         xs, ys = [], []
-        for x, y in res:
+        for res in result:
+            if res.status != StatusKind.COMPLETE:
+                continue
+            x, y = res.value
             if len(x) == fea_len:
                 xs.append(x)
                 ys.append(y)
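Since failures no longer raise out of the map call, the loop above filters on the result status instead. A self-contained sketch of that skip-on-failure pattern (`_maybe_fail` is an illustrative stand-in for a feature-extraction task):

```python
from tvm.contrib.popen_pool import PopenPoolExecutor, StatusKind


def _maybe_fail(x):
    # Illustrative task: odd inputs raise inside the worker.
    if x % 2:
        raise ValueError("odd input")
    return x * 10


pool = PopenPoolExecutor(max_workers=2)
kept = []
for res in pool.map_with_error_catching(_maybe_fail, range(6)):
    if res.status != StatusKind.COMPLETE:
        continue  # same skip-on-failure policy as the fit_log loop above
    kept.append(res.value)
print(kept)  # expected: [0, 20, 40]
```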
@@ -327,14 +325,9 @@ def _get_feature(self, indexes):
 
         if need_extract:
             pool = self._get_pool()
-            # If we are forking, we can pass arguments in globals for better performance
-            if multiprocessing.get_start_method(False) == "fork":
-                feas = pool.map(self.feature_extract_func, need_extract)
-            else:
-                args = [(self.space.get(x), self.target, self.task) for x in need_extract]
-                feas = pool.map(self.feature_extract_func, args)
+            feas = pool.map_with_error_catching(self.feature_extract_func, need_extract)
             for i, fea in zip(need_extract, feas):
-                fea_cache[i] = fea
+                fea_cache[i] = fea.value if fea.status == StatusKind.COMPLETE else None
 
         feature_len = None
         for idx in indexes:
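`_get_feature` applies the complementary policy: a failed extraction is cached as `None`, so the cache stays index-aligned with `need_extract`. A short sketch of that placeholder pattern (`_risky` is illustrative):

```python
from tvm.contrib.popen_pool import PopenPoolExecutor, StatusKind


def _risky(i):
    # Illustrative stand-in for a feature-extraction task that can fail.
    if i == 2:
        raise RuntimeError("extraction failed")
    return [i, i + 1]


pool = PopenPoolExecutor(max_workers=2)
need = [0, 1, 2, 3]
cache = {}
for i, res in zip(need, pool.map_with_error_catching(_risky, need)):
    cache[i] = res.value if res.status == StatusKind.COMPLETE else None
print(cache)  # expected: {0: [0, 1], 1: [1, 2], 2: None, 3: [3, 4]}
```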
@@ -358,17 +351,20 @@ def __del__(self):
 _extract_task = None
 
 
+def _extract_popen_initializer(space, target, task):
+    global _extract_space, _extract_target, _extract_task
+    _extract_space = space
+    _extract_target = target
+    _extract_task = task
+
+
 def _extract_itervar_feature_index(args):
     """extract iteration var feature for an index in extract_space"""
     try:
-        if multiprocessing.get_start_method(False) == "fork":
-            config = _extract_space.get(args)
-            with _extract_target:
-                sch, fargs = _extract_task.instantiate(config)
-        else:
-            config, target, task = args
-            with target:
-                sch, fargs = task.instantiate(config)
+        config = _extract_space.get(args)
+        with _extract_target:
+            sch, fargs = _extract_task.instantiate(config)
+
         fea = feature.get_itervar_feature_flatten(sch, fargs, take_log=True)
         fea = np.concatenate((fea, list(config.get_other_option().values())))
         return fea
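The new `_extract_popen_initializer` works because it lives in the same importable module as the extraction functions: inside each worker, the globals it assigns are ordinary module globals, visible to every task executed afterwards. A condensed sketch of that contract in a hypothetical module (everything here except the pool API is illustrative):

```python
# feature_workers.py (hypothetical module)
_big_table = None  # installed once per worker by the initializer


def init_worker(table):
    # Ships the heavy shared object a single time, via PopenPoolExecutor initargs.
    global _big_table
    _big_table = table


def feature_for(index):
    # Per-task payload is just a small index; the table is already in this process.
    return _big_table[index] * 2
```

A driver would then build the pool exactly as `_reset_pool` does above, e.g. `PopenPoolExecutor(max_workers=n, initializer=init_worker, initargs=(table,))`, and map `feature_for` over plain integer indices.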
@@ -398,10 +394,9 @@ def _extract_itervar_feature_log(arg):
 def _extract_knob_feature_index(args):
     """extract knob feature for an index in extract_space"""
     try:
-        if multiprocessing.get_start_method(False) == "fork":
-            config = _extract_space.get(args)
-        else:
-            config = args[0]
+
+        config = _extract_space.get(args)
+
         return config.get_flatten_feature()
     except Exception:  # pylint: disable=broad-except
         return None
@@ -428,14 +423,11 @@ def _extract_knob_feature_log(arg):
 def _extract_curve_feature_index(args):
     """extract sampled curve feature for an index in extract_space"""
     try:
-        if multiprocessing.get_start_method(False) == "fork":
-            config = _extract_space.get(args)
-            with _extract_target:
-                sch, fargs = _extract_task.instantiate(config)
-        else:
-            config, target, task = args
-            with target:
-                sch, fargs = task.instantiate(config)
+
+        config = _extract_space.get(args)
+        with _extract_target:
+            sch, fargs = _extract_task.instantiate(config)
+
         fea = feature.get_buffer_curve_sample_flatten(sch, fargs, sample_n=20)
         fea = np.concatenate((fea, list(config.get_other_option().values())))
         return np.array(fea)