Skip to content

Commit

Permalink
Fix linting & rebase.
Browse files Browse the repository at this point in the history
  • Loading branch information
zxybazh committed Mar 2, 2022
1 parent 98f5c54 commit 7c664b6
Show file tree
Hide file tree
Showing 9 changed files with 18 additions and 9 deletions.
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/cost_model/cost_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,4 +208,4 @@ def __str__(self) -> str:
result : str
Get the cost model as string with name.
"""
return f"meta_schedule.{self.__class__.__name__}({_get_hex_address(self.handle)})"
return f"meta_schedule.{self.__class__.__name__}({_get_hex_address(self.handle)})" # type: ignore
11 changes: 10 additions & 1 deletion python/tvm/meta_schedule/cost_model/random_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
"""
from typing import List, Optional, Tuple, Union

import numpy as np
from tvm.meta_schedule.utils import derived_object # type: ignore

from ..cost_model import PyCostModel
Expand All @@ -46,6 +45,8 @@ class RandomModel(PyCostModel):
https://numpy.org/doc/stable/reference/random/generated/numpy.random.get_state.html
"""

import numpy as np # pylint: disable=import-outside-toplevel

random_state: Union[Tuple[str, np.ndarray, int, int, float], dict]
path: Optional[str]

Expand All @@ -56,6 +57,8 @@ def __init__(
path: Optional[str] = None,
max_range: Optional[int] = 100,
):
import numpy as np # pylint: disable=import-outside-toplevel

super().__init__()
if path is not None:
self.load(path)
Expand All @@ -72,6 +75,8 @@ def load(self, path: str) -> None:
path : str
The file path.
"""
import numpy as np # pylint: disable=import-outside-toplevel

self.random_state = tuple(np.load(path, allow_pickle=True)) # type: ignore

def save(self, path: str) -> None:
Expand All @@ -82,6 +87,8 @@ def save(self, path: str) -> None:
path : str
The file path.
"""
import numpy as np # pylint: disable=import-outside-toplevel

np.save(path, np.array(self.random_state, dtype=object), allow_pickle=True)

def update(
Expand Down Expand Up @@ -117,6 +124,8 @@ def predict(self, context: TuneContext, candidates: List[MeasureCandidate]) -> n
result : np.ndarray
The predicted running results.
"""
import numpy as np # pylint: disable=import-outside-toplevel

np.random.set_state(self.random_state)
# TODO(@zxybazh): Use numpy's RandState object:
# https://numpy.org/doc/1.16/reference/generated/numpy.random.RandomState.html#numpy.random.RandomState
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,4 +105,4 @@ def extract_from(
raise NotImplementedError

def __str__(self) -> str:
return f"meta_schedule.{self.__class__.__name__}({_get_hex_address(self.handle)})"
return f"meta_schedule.{self.__class__.__name__}({_get_hex_address(self.handle)})" # type: ignore
Original file line number Diff line number Diff line change
Expand Up @@ -125,4 +125,4 @@ def apply(
raise NotImplementedError

def __str__(self) -> str:
return f"meta_schedule.{self.__class__.__name__}({_get_hex_address(self.handle)})"
return f"meta_schedule.{self.__class__.__name__}({_get_hex_address(self.handle)})" # type: ignore
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/mutator/mutator.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,4 +122,4 @@ def __str__(self) -> str:
result : str
Get the mutator as string with name.
"""
return f"meta_schedule.{self.__class__.__name__}({_get_hex_address(self.handle)})"
return f"meta_schedule.{self.__class__.__name__}({_get_hex_address(self.handle)})" # type: ignore
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/postproc/postproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,4 +124,4 @@ def __str__(self) -> str:
result : str
Get the post processor as string with name.
"""
return f"meta_schedule.{self.__class__.__name__}({_get_hex_address(self.handle)})"
return f"meta_schedule.{self.__class__.__name__}({_get_hex_address(self.handle)})" # type: ignore
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/schedule_rule/schedule_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,4 +137,4 @@ def __str__(self) -> str:
result : str
Get the schedule rule as string with name.
"""
return f"meta_schedule.{self.__class__.__name__}({_get_hex_address(self.handle)})"
return f"meta_schedule.{self.__class__.__name__}({_get_hex_address(self.handle)})" # type: ignore
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def _callbacks(
@staticmethod
def _cost_model(cost_model: Optional[CostModel]) -> CostModel:
if cost_model is None:
return XGBModel(extractor=PerStoreFeature())
return XGBModel(extractor=PerStoreFeature()) # type: ignore
if not isinstance(cost_model, CostModel):
raise TypeError(f"Expected `cost_model` to be CostModel, but gets: {cost_model}")
return cost_model
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def __setattr__(self, name, value):
else:
super(TVMDerivedObject, self).__setattr__(name, value)

functools.update_wrapper(TVMDerivedObject.__init__, cls.__init__)
functools.update_wrapper(TVMDerivedObject.__init__, cls.__init__) # type: ignore
TVMDerivedObject.__name__ = cls.__name__
TVMDerivedObject.__doc__ = cls.__doc__
TVMDerivedObject.__module__ = cls.__module__
Expand Down

0 comments on commit 7c664b6

Please sign in to comment.