Skip to content
This repository has been archived by the owner on Jul 1, 2024. It is now read-only.

Commit

Permalink
Add hook builder / registry
Browse files Browse the repository at this point in the history
Summary: Add a hook builder / registry, similar to other classy abstractions.

Differential Revision: D19755147

fbshipit-source-id: f963031957395a2c9979da6c98c3d290af6d3e16
  • Loading branch information
Aaron Adcock authored and facebook-github-bot committed Feb 21, 2020
1 parent 79ffd50 commit 04cab0e
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 1 deletion.
62 changes: 62 additions & 0 deletions classy_vision/hooks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

from pathlib import Path
from typing import Any, Dict, List

from classy_vision.generic.registry_utils import import_all_modules

Expand Down Expand Up @@ -44,5 +45,66 @@

FILE_ROOT = Path(__file__).parent

HOOK_REGISTRY = {}
HOOK_CLASS_NAMES = set()


def register_hook(name):
"""Registers a :class:`ClassyHook` subclass.
This decorator allows Classy Vision to instantiate a subclass of
:class:`ClassyHook` from a configuration file, even if the class
itself is not part of the base Classy Vision framework. To use it,
apply this decorator to a ClassyHook subclass, like this:
.. code-block:: python
@register_model('resnet')
class CustomHook(ClassyHook):
...
To instantiate a hook from a configuration file, see
:func:`build_model`.
"""

def register_hook_cls(cls):
if name in HOOK_REGISTRY:
raise ValueError("Cannot register duplicate hook ({})".format(name))
if not issubclass(cls, ClassyHook):
raise ValueError(
"Hook ({}: {}) must extend ClassyHook".format(name, cls.__name__)
)
if cls.__name__ in HOOK_CLASS_NAMES:
raise ValueError(
"Cannot register model with duplicate class name ({})".format(
cls.__name__
)
)
HOOK_REGISTRY[name] = cls
HOOK_CLASS_NAMES.add(cls.__name__)
return cls

return register_hook_cls


def build_hooks(hook_configs: List[Dict[str, Any]]):
return [build_hook(config) for config in hook_configs]


def build_hook(hook_config: Dict[str, Any]):
"""Builds a ClassyHook from a config.
This assumes a 'name' key in the config which is used to determine
what model class to instantiate. For instance, a config `{"name":
"my_hook", "foo": "bar"}` will find a class that was registered as
"my_hook" (see :func:`register_hook`) and call .from_config on
it."""
assert hook_config["name"] in HOOK_REGISTRY, (
"Unregistered hook. Did you make sure to use the register_hook decorator "
"AND import the hook file before calling this function??"
)
return HOOK_REGISTRY[hook_config["name"]].from_config(hook_config)


# automatically import any Python files in the hooks/ directory
import_all_modules(FILE_ROOT, "classy_vision.hooks")
22 changes: 21 additions & 1 deletion test/hooks_classy_hook_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import copy
import unittest

from classy_vision.hooks import ClassyHook
from classy_vision.hooks import ClassyHook, build_hook, build_hooks, register_hook


@register_hook("test_hook")
class TestHook(ClassyHook):
on_rendezvous = ClassyHook._noop
on_start = ClassyHook._noop
Expand All @@ -26,8 +28,26 @@ def __init__(self, a, b):
self.state.a = a
self.state.b = b

@classmethod
def from_config(cls, config):
return TestHook(config["a"], config["b"])


class TestClassyHook(unittest.TestCase):
def test_hook_registry_and_builder(self):
config = {"name": "test_hook", "a": 1, "b": 2}
hook1 = build_hook(hook_config=config)
self.assertTrue(isinstance(hook1, TestHook))
self.assertTrue(hook1.state.a == 1)
self.assertTrue(hook1.state.b == 2)

hook_configs = [copy.deepcopy(config), copy.deepcopy(config)]
hooks = build_hooks(hook_configs=hook_configs)
for hook in hooks:
self.assertTrue(isinstance(hook, TestHook))
self.assertTrue(hook.state.a == 1)
self.assertTrue(hook.state.b == 2)

def test_state_dict(self):
a = 0
b = {1: 2, 3: [4]}
Expand Down

0 comments on commit 04cab0e

Please sign in to comment.