Skip to content
This repository has been archived by the owner on Jul 1, 2024. It is now read-only.

Commit

Permalink
Add from_config to checkpoint hook
Browse files Browse the repository at this point in the history
Differential Revision: D19755146

fbshipit-source-id: dc74d12c1714d80ce0c6adcc3445307aa2187f7c
  • Loading branch information
Aaron Adcock authored and facebook-github-bot committed Feb 21, 2020
1 parent 9c836f8 commit 79ffd50
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 1 deletion.
9 changes: 8 additions & 1 deletion classy_vision/hooks/checkpoint_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class CheckpointHook(ClassyHook):
def __init__(
self,
checkpoint_folder: str,
input_args: Any,
input_args: Any = None,
phase_types: Optional[Collection[str]] = None,
checkpoint_period: int = 1,
) -> None:
Expand Down Expand Up @@ -61,6 +61,13 @@ def __init__(
self.checkpoint_period: int = checkpoint_period
self.phase_counter: int = 0

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "CheckpointHook":
assert isinstance(
config["checkpoint_folder"], str
), "checkpoint_folder must be a string specifying the checkpoint directory"
return CheckpointHook(**config)

def _save_checkpoint(self, task, filename):
if getattr(task, "test_only", False):
return
Expand Down
24 changes: 24 additions & 0 deletions test/hooks_checkpoint_hook_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import copy
import os
import shutil
import tempfile
Expand All @@ -24,6 +25,29 @@ def setUp(self) -> None:
def tearDown(self) -> None:
shutil.rmtree(self.base_dir)

def test_constructors(self) -> None:
"""
Test that the hooks are constructed correctly.
"""
config = {
"checkpoint_folder": "/test/",
"input_args": {"foo": "bar"},
"phase_types": ["train"],
"checkpoint_period": 2,
}

hook1 = CheckpointHook(**config)
hook2 = CheckpointHook.from_config(config)

self.assertTrue(isinstance(hook1, CheckpointHook))
self.assertTrue(isinstance(hook2, CheckpointHook))

# Verify assert logic works correctly
with self.assertRaises(AssertionError):
bad_config = copy.deepcopy(config)
bad_config["checkpoint_folder"] = 12
CheckpointHook.from_config(bad_config)

def test_state_checkpointing(self) -> None:
"""
Test that the state gets checkpointed without any errors, but only on the
Expand Down

0 comments on commit 79ffd50

Please sign in to comment.