Skip to content

Commit

Permalink
remove exception handling of autotvm xgboost extract functions (#10948)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuanjing Shi committed Apr 12, 2022
1 parent 98fc649 commit 856b5c6
Showing 1 changed file with 49 additions and 69 deletions.
118 changes: 49 additions & 69 deletions python/tvm/autotvm/tuner/xgboost_cost_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,98 +360,78 @@ def _extract_popen_initializer(space, target, task):

def _extract_itervar_feature_index(args):
"""extract iteration var feature for an index in extract_space"""
try:
config = _extract_space.get(args)
with _extract_target:
sch, fargs = _extract_task.instantiate(config)
config = _extract_space.get(args)
with _extract_target:
sch, fargs = _extract_task.instantiate(config)

fea = feature.get_itervar_feature_flatten(sch, fargs, take_log=True)
fea = np.concatenate((fea, list(config.get_other_option().values())))
return fea
except Exception: # pylint: disable=broad-except
return None
fea = feature.get_itervar_feature_flatten(sch, fargs, take_log=True)
fea = np.concatenate((fea, list(config.get_other_option().values())))
return fea


def _extract_itervar_feature_log(arg):
"""extract iteration var feature for log items"""
try:
inp, res = arg
config = inp.config
with inp.target:
sch, args = inp.task.instantiate(config)
fea = feature.get_itervar_feature_flatten(sch, args, take_log=True)
x = np.concatenate((fea, list(config.get_other_option().values())))

if res.error_no == 0:
y = inp.task.flop / np.mean(res.costs)
else:
y = 0.0
return x, y
except Exception: # pylint: disable=broad-except
return None
inp, res = arg
config = inp.config
with inp.target:
sch, args = inp.task.instantiate(config)
fea = feature.get_itervar_feature_flatten(sch, args, take_log=True)
x = np.concatenate((fea, list(config.get_other_option().values())))

if res.error_no == 0:
y = inp.task.flop / np.mean(res.costs)
else:
y = 0.0
return x, y


def _extract_knob_feature_index(args):
"""extract knob feature for an index in extract_space"""
try:

config = _extract_space.get(args)
config = _extract_space.get(args)

return config.get_flatten_feature()
except Exception: # pylint: disable=broad-except
return None
return config.get_flatten_feature()


def _extract_knob_feature_log(arg):
"""extract knob feature for log items"""
try:
inp, res = arg
config = inp.config
x = config.get_flatten_feature()

if res.error_no == 0:
with inp.target: # necessary, for calculating flops of this task
inp.task.instantiate(config)
y = inp.task.flop / np.mean(res.costs)
else:
y = 0.0
return x, y
except Exception: # pylint: disable=broad-except
return None
inp, res = arg
config = inp.config
x = config.get_flatten_feature()

if res.error_no == 0:
with inp.target: # necessary, for calculating flops of this task
inp.task.instantiate(config)
y = inp.task.flop / np.mean(res.costs)
else:
y = 0.0
return x, y


def _extract_curve_feature_index(args):
"""extract sampled curve feature for an index in extract_space"""
try:
config = _extract_space.get(args)
with _extract_target:
sch, fargs = _extract_task.instantiate(config)

config = _extract_space.get(args)
with _extract_target:
sch, fargs = _extract_task.instantiate(config)

fea = feature.get_buffer_curve_sample_flatten(sch, fargs, sample_n=20)
fea = np.concatenate((fea, list(config.get_other_option().values())))
return np.array(fea)
except Exception: # pylint: disable=broad-except
return None
fea = feature.get_buffer_curve_sample_flatten(sch, fargs, sample_n=20)
fea = np.concatenate((fea, list(config.get_other_option().values())))
return np.array(fea)


def _extract_curve_feature_log(arg):
"""extract sampled curve feature for log items"""
try:
inp, res = arg
config = inp.config
with inp.target:
sch, args = inp.task.instantiate(config)
fea = feature.get_buffer_curve_sample_flatten(sch, args, sample_n=20)
x = np.concatenate((fea, list(config.get_other_option().values())))

if res.error_no == 0:
y = inp.task.flop / np.mean(res.costs)
else:
y = 0.0
return x, y
except Exception: # pylint: disable=broad-except
return None
inp, res = arg
config = inp.config
with inp.target:
sch, args = inp.task.instantiate(config)
fea = feature.get_buffer_curve_sample_flatten(sch, args, sample_n=20)
x = np.concatenate((fea, list(config.get_other_option().values())))

if res.error_no == 0:
y = inp.task.flop / np.mean(res.costs)
else:
y = 0.0
return x, y


def custom_callback(
Expand Down

0 comments on commit 856b5c6

Please sign in to comment.