-
Notifications
You must be signed in to change notification settings - Fork 729
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add initial version of auto estimator #3731
Changes from 4 commits
83bf8f6
f15d4f9
4b145ae
d65e267
64af414
66890dd
8a275aa
d5058ca
2245741
81a94dc
cf28a38
210954d
3e2fe1a
d1e1f9e
26771e1
e595d6e
f336529
7af862c
5e10f1a
e0b4066
2ea2808
cb834f7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -40,15 +40,28 @@ def __init__(self, model_creator, optimizer_creator, loss_creator, | |
self.onnx_model = None | ||
self.onnx_model_built = False | ||
|
||
def _create_loss(self): | ||
if isinstance(self.loss_creator, torch.nn.modules.loss._Loss): | ||
self.criterion = self.loss_creator | ||
else: | ||
self.criterion = self.loss_creator(self.config) | ||
|
||
def _create_optimizer(self): | ||
if issubclass(self.optimizer_creator, torch.optim.Optimizer): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is it an instance or class? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It may need to be a class, since a torch optimizer is instantiated as below (the parameters need to be passed during instantiation while the parameters for loss is passed after instantiation). |
||
# use torch default parameter values if user pass optimizer name or optimizer class. | ||
self.optimizer = self.optimizer_creator(self.model.parameters()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the user may want to tune learning rate even if he or she specifies, say, "Adam"? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
else: | ||
self.optimizer = self.optimizer_creator(self.model, self.config) | ||
|
||
def build(self, config): | ||
# check config and update | ||
self._check_config(**config) | ||
self.config = config | ||
# build model | ||
self.model = self.model_creator(config) | ||
self.model_built = True | ||
self.optimizer = self.optimizer_creator(self.model, self.config) | ||
self.criterion = self.loss_creator(self.config) | ||
self._create_loss() | ||
self._create_optimizer() | ||
|
||
def _reshape_input(self, x): | ||
if x.ndim == 1: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
# | ||
# Copyright 2018 Analytics Zoo Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
class AutoEstimator: | ||
def __init__(self, model_builder, searcher): | ||
self.model_builder = model_builder | ||
self.searcher = searcher | ||
|
||
@staticmethod | ||
def from_torch(*, | ||
model_creator, | ||
optimizer, | ||
loss, | ||
logs_dir, | ||
resources_per_trial, | ||
name, | ||
): | ||
""" | ||
Create an AutoEstimator for torch. | ||
|
||
:param model_creator: PyTorch model creator function. | ||
:param optimizer: PyTorch optimizer creator function or pytorch optimizer name (string). | ||
:param loss: PyTorch loss instance or PyTorch loss creator function | ||
or pytorch loss name (string). | ||
:param logs_dir: Local directory to save logs and results. | ||
:param resources_per_trial: Dict. resources for each trial. e.g. {"cpu": 2}. | ||
:param name: Name of the auto estimator. | ||
:return: an AutoEstimator object. | ||
""" | ||
from zoo.orca.automl.pytorch_model_utils import validate_pytorch_loss, \ | ||
validate_pytorch_optim | ||
from zoo.automl.model import ModelBuilder | ||
from zoo.automl.search import SearchEngineFactory | ||
loss = validate_pytorch_loss(loss) | ||
optimizer = validate_pytorch_optim(optimizer) | ||
model_builder = ModelBuilder.from_pytorch(model_creator=model_creator, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shall we test the behavior if user passes a model_creator that creates a non-pytorch model by mistake? It seems ModelBuilder or PyTorchBaseModel does not check whether it is pytorch model or not. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure, I will add check in |
||
optimizer_creator=optimizer, | ||
loss_creator=loss) | ||
searcher = SearchEngineFactory.create_engine(backend="ray", | ||
logs_dir=logs_dir, | ||
resources_per_trial=resources_per_trial, | ||
name=name) | ||
return AutoEstimator(model_builder=model_builder, searcher=searcher) | ||
|
||
@staticmethod | ||
def from_keras(*, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. shall we use the same name "from_tfkeras" as in ModelBuilder class? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is to align the usage in orca tf2 estimator https://github.com/intel-analytics/analytics-zoo/blob/master/pyzoo/zoo/orca/learn/tf2/estimator.py#L36 And I did some refactor in ModelBuilder and removed the factory method ( |
||
model_creator, | ||
logs_dir, | ||
resources_per_trial, | ||
name, | ||
): | ||
""" | ||
Create an AutoEstimator for tensorflow keras. | ||
|
||
:param model_creator: Tensorflow keras model creator function. | ||
:param logs_dir: Local directory to save logs and results. | ||
:param resources_per_trial: Dict. resources for each trial. e.g. {"cpu": 2}. | ||
:param name: Name of the auto estimator. | ||
:return: an AutoEstimator object. | ||
""" | ||
from zoo.automl.model import ModelBuilder | ||
from zoo.automl.search import SearchEngineFactory | ||
model_builder = ModelBuilder.from_tfkeras(model_creator=model_creator) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure, I will add check in new PR |
||
searcher = SearchEngineFactory.create_engine(backend="ray", | ||
logs_dir=logs_dir, | ||
resources_per_trial=resources_per_trial, | ||
name=name) | ||
return AutoEstimator(model_builder=model_builder, searcher=searcher) | ||
|
||
def fit(self, | ||
data, | ||
recipe=None, | ||
metric=None, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What if this There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I will add check for the metric value. |
||
search_alg=None, | ||
search_alg_params=None, | ||
scheduler=None, | ||
scheduler_params=None, | ||
): | ||
|
||
self.searcher.compile(data=data, | ||
shanyu-sys marked this conversation as resolved.
Show resolved
Hide resolved
|
||
model_create_func=self.model_builder, | ||
recipe=recipe, | ||
metric=metric, | ||
search_alg=search_alg, | ||
search_alg_params=search_alg_params, | ||
scheduler=scheduler, | ||
scheduler_params=scheduler_params) | ||
analysis = self.searcher.run() | ||
return analysis | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# | ||
# Copyright 2018 Analytics Zoo Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
import torch | ||
|
||
PYTORCH_LOSS_NAMES = {s for s in dir(torch.nn.modules) if s.endswith("Loss")} | ||
PYTORCH_OPTIM_NAMES = {s for s in dir(torch.optim) if any(c.isupper() for c in s)} - {'Optimizer'} | ||
|
||
|
||
def validate_pytorch_loss(loss): | ||
import types | ||
if isinstance(loss, str): | ||
if loss in PYTORCH_LOSS_NAMES: | ||
return getattr(torch.nn.modules, loss)() | ||
raise ValueError(f'Must provide a valid torch loss name among {PYTORCH_LOSS_NAMES}') | ||
|
||
if isinstance(loss, torch.nn.modules.loss._Loss) or \ | ||
isinstance(loss, types.FunctionType): | ||
return loss | ||
|
||
raise ValueError("Must provide a valid pytorch loss name or a pytorch loss instance" | ||
"or a pytorch loss creator function ") | ||
|
||
|
||
def validate_pytorch_optim(optim): | ||
import types | ||
if isinstance(optim, str): | ||
if optim in PYTORCH_OPTIM_NAMES: | ||
return getattr(torch.optim, optim) | ||
raise ValueError(f'Must provide a valid torch optimizer name among {PYTORCH_OPTIM_NAMES}') | ||
|
||
if issubclass(optim, torch.optim.Optimizer) or\ | ||
isinstance(optim, types.FunctionType): | ||
return optim | ||
|
||
raise ValueError("Must provide a valid pytorch optimizer name or a pytorch optimizer class" | ||
"or a pytorch optimizer creator function ") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is it an instance or class?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It could be an instance, since we may expect user to pass
loss=nn.BCELoss()
as orca estimator does. (orca UT)