[MetaSchedule] Support grouping in the cost model (#10811)
junrushao committed Mar 29, 2022
1 parent c2488ac commit ce28068
Showing 10 changed files with 325 additions and 139 deletions.
243 changes: 153 additions & 90 deletions python/tvm/meta_schedule/cost_model/xgb_model.py
@@ -17,26 +17,29 @@
"""
XGBoost-based cost model
"""
-from itertools import chain as itertools_chain
import logging
import os
import tempfile
-from typing import Any, Callable, Dict, List, NamedTuple, Optional, TYPE_CHECKING, Tuple
+from collections import OrderedDict
+from itertools import chain as itertools_chain
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, Optional, Tuple

import numpy as np  # type: ignore

from ...contrib.tar import tar, untar
from ...runtime import NDArray
from ..cost_model import PyCostModel
from ..feature_extractor import FeatureExtractor
from ..runner import RunnerResult
from ..search_strategy import MeasureCandidate
-from ..utils import cpu_count, derived_object
+from ..utils import cpu_count, derived_object, shash2hex
from .metric import max_curve

if TYPE_CHECKING:
-    from ..tune_context import TuneContext
+    import xgboost as xgb  # type: ignore
+
+    from ..tune_context import TuneContext


logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

@@ -75,8 +78,8 @@ class PackSum:

    def __init__(
        self,
-        xs: List[np.ndarray],
-        ys: Optional[np.ndarray],
+        xs: List[np.ndarray],  # pylint: disable=invalid-name
+        ys: Optional[np.ndarray],  # pylint: disable=invalid-name
    ):
        """Create PackSum format given a batch of samples
@@ -217,23 +220,63 @@ class XGBConfig(NamedTuple):
        Default is None, which means to use physical number of cores.
    """

-    max_depth: int = 10
-    gamma: float = 0.001
-    min_child_weight: float = 0
-    eta: float = 0.2
-    seed: int = 43
-    nthread: Optional[int] = None

    def to_dict(self):
-        xgb_params = {
+        return {
            "max_depth": self.max_depth,
            "gamma": self.gamma,
            "min_child_weight": self.min_child_weight,
            "eta": self.eta,
            "seed": self.seed,
            "nthread": self.nthread,
        }
-        return xgb_params

+    max_depth: int = 10
+    gamma: float = 0.001
+    min_child_weight: float = 0
+    eta: float = 0.2
+    seed: int = 43
+    nthread: Optional[int] = None
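A small usage sketch of `XGBConfig`; the `_replace` call mirrors what `__init__` below does when `nthread` is unset (NamedTuple instances are immutable, so fields cannot be assigned in place):

```python
config = XGBConfig()                 # all defaults: max_depth=10, eta=0.2, ...
config = config._replace(nthread=8)  # returns a new NamedTuple with nthread set
params = config.to_dict()
# {'max_depth': 10, 'gamma': 0.001, 'min_child_weight': 0,
#  'eta': 0.2, 'seed': 43, 'nthread': 8}
```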

+class FeatureGroup:
+    """Feature group
+
+    Parameters
+    ----------
+    group_hash : str
+        The hash of the group
+    features : List[np.ndarray]
+        The features
+    costs : List[float]
+        The costs
+    min_cost : float
+        The minimum cost
+    """
+
+    group_hash: str
+    features: List[np.ndarray]
+    costs: np.ndarray
+    min_cost: float
+
+    def __init__(
+        self,
+        group_hash: str,
+        features: List[np.ndarray],
+        costs: np.ndarray,
+    ) -> None:
+        self.group_hash = group_hash
+        self.features = features
+        self.costs = costs
+        self.min_cost = np.min(costs)
+
+    def append(
+        self,
+        features: List[np.ndarray],
+        costs: np.ndarray,
+    ) -> None:
+        self.features.extend(features)
+        self.costs = np.append(self.costs, costs)
+        self.min_cost = np.min(self.costs)

@derived_object
@@ -268,9 +311,8 @@ class XGBModel(PyCostModel):
    verbose_eval: int
    average_peak_n: int
    # states
-    cached_features: List[np.ndarray]
-    cached_mean_costs: np.ndarray
-    cached_normalizer: Optional[float]
+    data: Dict[str, FeatureGroup]
+    data_size: int
    booster: Optional["xgb.Booster"]

    def __init__(
@@ -293,7 +335,7 @@ def __init__(
        # model-related
        if config.nthread is None:
            # use physical core number
-            config = config._replace(nthread=cpu_count(logical=False))
+            config = config._replace(nthread=cpu_count(logical=True))
        self.config = config
        # behavior of randomness
        self.num_warmup_samples = num_warmup_samples
@@ -302,9 +344,8 @@ def __init__(
        self.verbose_eval = verbose_eval
        self.average_peak_n = average_peak_n
        # states
-        self.cached_features = []
-        self.cached_mean_costs = np.empty((0,), dtype="float64")
-        self.cached_normalizer = None
+        self.data = OrderedDict()
+        self.data_size = 0
        self.booster = None

    def load(self, path: str) -> None:
@@ -324,16 +365,29 @@ def load(self, path: str) -> None:
        import xgboost as xgb  # pylint: disable=import-outside-toplevel

        with tempfile.TemporaryDirectory() as tmp_dir:
+            model_path = os.path.join(tmp_dir, "model.bin")
+            data_path = os.path.join(tmp_dir, "data.npy")
+            # Step 1. Untar
            untar(path, tmp_dir)
-            self.booster = xgb.Booster()
-            self.booster.load_model(os.path.join(tmp_dir, "model.bin"))
-            self.cached_features = list(
-                np.load(os.path.join(tmp_dir, "cached_features.npy"), allow_pickle=True)
-            )
-            self.cached_mean_costs = np.load(
-                os.path.join(tmp_dir, "cached_mean_costs.npy"), allow_pickle=True
-            )
-        self._set_cached_normalizer()
+            # Step 2. Load data
+            data = OrderedDict()
+            data_size = 0
+            for group_hash, features, costs in np.load(data_path, allow_pickle=True):
+                data[group_hash] = FeatureGroup(
+                    group_hash=group_hash,
+                    features=list(features),
+                    costs=costs,
+                )
+                data_size += len(costs)
+            # Step 3. Load the model
+            if os.path.exists(model_path):
+                booster = xgb.Booster()
+                booster.load_model(model_path)
+            else:
+                booster = None
+        self.data = data
+        self.data_size = data_size
+        self.booster = booster

    def save(self, path: str) -> None:
        """Save the cost model to given file location.
@@ -349,26 +403,30 @@ def save(self, path: str) -> None:
        previously cached feature vectors and results, so that the subsequent training process could
        use all the existing data being stored on disk.
        """
-        import xgboost as xgb  # pylint: disable=import-outside-toplevel
-
-        if self.booster is None:
-            # save all the parameters
-            self.booster = xgb.Booster(self.config.to_dict())
        with tempfile.TemporaryDirectory() as tmp_dir:
-            self.booster.save_model(os.path.join(tmp_dir, "model.bin"))
+            model_path = os.path.join(tmp_dir, "model.bin")
+            data_path = os.path.join(tmp_dir, "data.npy")
+            # Step 1. Save the model
+            booster = self.booster
+            if booster is not None:
+                booster.save_model(model_path)
+            else:
+                model_path = None
+            # Step 2. Save data
+            data = [
+                (
+                    g.group_hash,
+                    g.features,
+                    g.costs,
+                )
+                for g in self.data.values()
+            ]
            np.save(
-                os.path.join(tmp_dir, "cached_features.npy"),
-                np.array(self.cached_features, dtype=object),
-            )
-            np.save(os.path.join(tmp_dir, "cached_mean_costs.npy"), self.cached_mean_costs)
-            tar(
-                path,
-                [
-                    os.path.join(tmp_dir, "model.bin"),
-                    os.path.join(tmp_dir, "cached_features.npy"),
-                    os.path.join(tmp_dir, "cached_mean_costs.npy"),
-                ],
+                file=data_path,
+                arr=np.array(data, dtype=object),
            )
+            # Step 3. Tar it
+            tar(path, [x for x in [model_path, data_path] if x is not None])
        logger.info("Saved XGBModel to %s", path)
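After this change the tarball holds at most two files: model.bin (omitted when no booster has been trained yet) and data.npy, an object array of (group_hash, features, costs) tuples. A sketch of the data.npy round trip under that reading, with illustrative contents:

```python
import numpy as np

data = [
    (
        "deadbeef",  # illustrative stand-in for shash2hex(mod)
        [np.zeros((4, 164), "float32"), np.zeros((6, 164), "float32")],
        np.array([2.0, 1.0]),
    ),
]
np.save("data.npy", np.array(data, dtype=object))  # allow_pickle defaults to True on save

for group_hash, features, costs in np.load("data.npy", allow_pickle=True):
    print(group_hash, len(features), float(np.min(costs)))
```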

    def update(
@@ -391,39 +449,55 @@ def update(
        assert len(candidates) == len(results)
        if len(candidates) == 0:
            return
-        # extract feature and do validation
+
+        # Step 1. Get the feature group
+        new_group_hash = shash2hex(context.mod)
+        group = self.data.get(new_group_hash, None)
+
+        # Step 2. Extract features
+        def _feature(x: NDArray) -> np.ndarray:
+            return x.numpy().astype("float32")

        def _mean_cost(x: RunnerResult) -> float:
            if not x.run_secs:
                return 1e10
            return float(np.median([float(s) for s in x.run_secs]))

-        new_features = [
-            x.numpy().astype("float32") for x in self.extractor.extract_from(context, candidates)
-        ]
-        new_mean_costs = np.asarray(
-            [_mean_cost(x) for x in results],
-            dtype="float32",
-        )
-        if self.booster is not None and self.cached_normalizer is not None:
+        new_features = [_feature(x) for x in self.extractor.extract_from(context, candidates)]
+        new_mean_costs = np.array([_mean_cost(x) for x in results]).astype("float32")
+
+        # Step 3. Run validation
+        if group is not None and self.booster is not None:
            logger.debug(
                "XGB validation: %s",
                "\t".join(
                    f"{key}: {score:.6f}"
                    for key, score in self._validate(
                        xs=new_features,
-                        ys=new_mean_costs,
+                        ys=group.min_cost / new_mean_costs,
                    )
                ),
            )
-        # use together with previous features
-        self.cached_features.extend(new_features)
-        self.cached_mean_costs = np.append(self.cached_mean_costs, new_mean_costs)
-        self._set_cached_normalizer()
-        # train xgb model
+
+        # Step 4. Add the features into the data points
+        if group is None:
+            group = FeatureGroup(
+                group_hash=new_group_hash,
+                features=new_features,
+                costs=new_mean_costs,
+            )
+        else:
+            group.append(new_features, new_mean_costs)
+        self.data[new_group_hash] = group
+        self.data_size += len(new_features)
+
+        # Step 5. Re-train the model
        self._train(
-            xs=self.cached_features,
-            ys=self.cached_mean_costs,
+            xs=list(itertools_chain.from_iterable([g.features for g in self.data.values()])),
+            ys=np.concatenate(
+                [g.min_cost / g.costs for g in self.data.values()],
+                axis=0,
+            ),
        )
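The point of Step 5, with illustrative numbers: each group is normalized by its own best cost before concatenation, so workloads with very different cost scales contribute comparable labels, whereas the old single `cached_normalizer` let the intrinsically fast workload dominate:

```python
import numpy as np

g1_costs = np.array([2.0, 4.0])   # slow workload, best = 2.0 s
g2_costs = np.array([0.1, 0.25])  # fast workload, best = 0.1 s
ys = np.concatenate([g1_costs.min() / g1_costs, g2_costs.min() / g2_costs])
# -> [1.0, 0.5, 1.0, 0.4]: every group's best-measured run scores 1.0
# With a single global normalizer min(all costs) = 0.1, the slow workload's
# labels would collapse to [0.05, 0.025] regardless of how good they are.
```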

    def predict(
@@ -445,10 +519,16 @@ def predict(
        result : np.ndarray
            The predicted normalized score.
        """
-        n_measured = len(self.cached_features)
-        if self.booster is not None and n_measured >= self.num_warmup_samples:
-            features = self.extractor.extract_from(context, candidates)
-            ret = self._predict(xs=[x.numpy().astype("float32") for x in features])
+        if self.data_size >= self.num_warmup_samples and self.booster is not None:
+            ret = self._predict(
+                xs=[
+                    x.numpy().astype("float32")
+                    for x in self.extractor.extract_from(
+                        context,
+                        candidates,
+                    )
+                ]
+            )
        else:
            ret = np.random.uniform(
                low=0,
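Behavioral note on the fallback branch: until `data_size` reaches `num_warmup_samples`, predictions are random so the search strategy explores rather than trusting an untrained booster. A self-contained sketch of that regime; the bounds are assumed, since the diff truncates the actual `high=` argument:

```python
import numpy as np

def warmup_scores(n_candidates: int) -> np.ndarray:
    # Random scores keep the search exploring before the model has seen
    # num_warmup_samples measurements; bounds here are illustrative.
    return np.random.uniform(low=0, high=1, size=n_candidates)

print(warmup_scores(4))
```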
@@ -464,20 +544,15 @@ def _train(  # type: ignore # pylint: disable=invalid-name
    ) -> None:
        import xgboost as xgb  # type: ignore # pylint: disable=import-outside-toplevel

-        self.d_train = PackSum(
-            xs=xs,
-            ys=self.cached_normalizer / ys,
-        )
+        self.d_train = PackSum(xs=xs, ys=ys)

        def obj(ys_pred: np.ndarray, d_train: "xgb.DMatrix"):  # type: ignore # pylint: disable = unused-argument
            return self.d_train.obj_square_error(ys_pred)

        def rmse(ys_pred: np.ndarray, d_train: "xgb.DMatrix"):  # type: ignore # pylint: disable = unused-argument
            return self.d_train.rmse(ys_pred)

-        def average_peak_score(
-            ys_pred: np.ndarray, d_train: "xgb.DMatrix"  # type: ignore # pylint: disable = unused-argument
-        ):
+        def avg_peak_score(ys_pred: np.ndarray, d_train: "xgb.DMatrix"):  # type: ignore # pylint: disable = unused-argument
            return self.d_train.average_peak_score(ys_pred, self.average_peak_n)

        self.booster = xgb.train(
@@ -491,7 +566,7 @@ def average_peak_score(
            verbose_eval=self.verbose_eval,
            fevals=[
                rmse,
-                average_peak_score,
+                avg_peak_score,
            ],
            evals=[(self.d_train.dmatrix, "tr")],
        )
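For reference, a sketch of what the average-peak-score metric computes, assuming the Ansor-style definition built on `max_curve` (imported at the top of this file): rank candidates by the model, walk the top-n in that order, and average the running-best true score normalized by the global best:

```python
import numpy as np

def a_peak_at_n(ys_true: np.ndarray, ys_pred: np.ndarray, n: int) -> float:
    top = np.argsort(ys_pred)[::-1][:n]                  # model's top-n candidates
    running_best = np.maximum.accumulate(ys_true[top])   # max_curve over true scores
    return float(np.mean(running_best / ys_true.max()))
```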
@@ -528,13 +603,9 @@ def _validate(  # type: ignore # pylint: disable=invalid-name
        scores: np.ndarray
            The predicted result for all inputs.
        """
-        if self.booster is None or self.cached_normalizer is None:
-            return []
+        assert self.booster is not None

-        d_valid = PackSum(
-            xs=xs,
-            ys=self.cached_normalizer / ys,
-        )
+        d_valid = PackSum(xs=xs, ys=ys)

        def average_peak_score(ys_pred: np.ndarray):
            return d_valid.average_peak_score(ys_pred, n=self.average_peak_n)
@@ -550,14 +621,6 @@ def average_peak_score(ys_pred: np.ndarray):
        eval_result.sort(key=make_metric_sorter("p-rmse"))
        return eval_result

-    def _set_cached_normalizer(self) -> None:
-        filtered = self.cached_mean_costs[self.cached_mean_costs > 0]
-        if filtered.size == 0:
-            self.cached_normalizer = 1.0
-        else:
-            self.cached_normalizer = np.min(filtered)
-        assert self.cached_normalizer > 0


def custom_callback(
    early_stopping_rounds: int,