Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
shingjan committed Aug 23, 2021
1 parent be1b456 commit a5bb136
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions python/tvm/autotvm/tuner/xgboost_cost_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,12 @@
# pylint: disable=invalid-name
"""XGBoost as cost model"""

import multiprocessing
import logging
import time

import numpy as np

from tvm.contrib.popen_pool import PopenPoolExecutor
from tvm.contrib.popen_pool import PopenPoolExecutor, StatusKind

from .. import feature
from ..utils import get_rank
Expand Down Expand Up @@ -146,12 +145,6 @@ def __init__(
self._sample_size = 0
self._reset_pool(self.space, self.target, self.task)

def _initializer(space, target, task):
global _extract_space, _extract_target, _extract_task
_extract_space = space
_extract_target = target
_extract_task = task

def _reset_pool(self, space, target, task):
"""reset processing pool for feature extraction"""

Expand All @@ -163,7 +156,7 @@ def _reset_pool(self, space, target, task):

self.pool = PopenPoolExecutor(
max_workers=self.num_threads,
initializer=XGBoostCostModel._initializer,
initializer=_extract_initializer,
initargs=(space, target, task),
)

Expand Down Expand Up @@ -331,7 +324,7 @@ def _get_feature(self, indexes):
pool = self._get_pool()
feas = pool.map_with_error_catching(self.feature_extract_func, need_extract)
for i, fea in zip(need_extract, feas):
fea_cache[i] = fea.value
fea_cache[i] = fea.value if fea.status == StatusKind.COMPLETE else None

feature_len = None
for idx in indexes:
Expand All @@ -355,6 +348,13 @@ def __del__(self):
_extract_task = None


def _extract_initializer(space, target, task):
global _extract_space, _extract_target, _extract_task
_extract_space = space
_extract_target = target
_extract_task = task


def _extract_itervar_feature_index(args):
"""extract iteration var feature for an index in extract_space"""
try:
Expand Down

0 comments on commit a5bb136

Please sign in to comment.