Skip to content
This repository has been archived by the owner on Jul 1, 2024. It is now read-only.

Commit

Permalink
Remove local_variables from on_phase_start (#416)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #416

This is part of a series of diffs to eliminate local_variables (see D20171981).
This diff proceeds by removing local_variables from on_phase_start.

Reviewed By: mannatsingh

Differential Revision: D20178268

fbshipit-source-id: 09f78810228b2fec9faa2205d92b108aea30aff9
  • Loading branch information
vreis authored and facebook-github-bot committed Mar 6, 2020
1 parent ce31b99 commit 0b2368e
Show file tree
Hide file tree
Showing 13 changed files with 18 additions and 25 deletions.
4 changes: 1 addition & 3 deletions classy_vision/hooks/classy_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,7 @@ def on_start(self, task: "tasks.ClassyTask") -> None:
pass

@abstractmethod
def on_phase_start(
self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
) -> None:
def on_phase_start(self, task: "tasks.ClassyTask") -> None:
"""Called at the start of each phase."""
pass

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def on_start(self, task: ClassyTask) -> None:
self._save_current_model_state(task.base_model, self.state.model_state)
self._save_current_model_state(task.base_model, self.state.ema_model_state)

def on_phase_start(self, task: ClassyTask, local_variables: Dict[str, Any]) -> None:
def on_phase_start(self, task: ClassyTask) -> None:
# restore the right state depending on the phase type
self.set_model_state(task, use_ema=not task.train)

Expand Down
4 changes: 1 addition & 3 deletions classy_vision/hooks/progress_bar_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,7 @@ def __init__(self) -> None:
self.bar_size: int = 0
self.batches: int = 0

def on_phase_start(
self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
) -> None:
def on_phase_start(self, task: "tasks.ClassyTask") -> None:
"""Create and display a progress bar with 0 progress."""
if not progressbar_available:
raise RuntimeError(
Expand Down
4 changes: 1 addition & 3 deletions classy_vision/hooks/tensorboard_plot_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,7 @@ def __init__(self, tb_writer) -> None:
self.wall_times: Optional[List[float]] = None
self.num_steps_global: Optional[List[int]] = None

def on_phase_start(
self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
) -> None:
def on_phase_start(self, task: "tasks.ClassyTask") -> None:
"""Initialize losses and learning_rates."""
self.learning_rates = []
self.wall_times = []
Expand Down
4 changes: 1 addition & 3 deletions classy_vision/hooks/time_metrics_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,7 @@ def __init__(self, log_freq: Optional[int] = None) -> None:
self.log_freq: Optional[int] = log_freq
self.start_time: Optional[float] = None

def on_phase_start(
self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
) -> None:
def on_phase_start(self, task: "tasks.ClassyTask") -> None:
"""
Initialize start time and reset perf stats
"""
Expand Down
5 changes: 3 additions & 2 deletions classy_vision/tasks/classification_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -858,12 +858,13 @@ def on_start(self):
for hook in self.hooks:
hook.on_start(self)

def on_phase_start(self, local_variables):
def on_phase_start(self):
self.phase_start_time_total = time.perf_counter()

self.advance_phase()

self.run_hooks(local_variables, ClassyHookFunctions.on_phase_start.name)
for hook in self.hooks:
hook.on_phase_start(self)

self.phase_start_time_train = time.perf_counter()

Expand Down
2 changes: 1 addition & 1 deletion classy_vision/tasks/classy_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def on_start(self):
pass

@abstractmethod
def on_phase_start(self, local_variables):
def on_phase_start(self):
"""
Epoch start.
Expand Down
2 changes: 1 addition & 1 deletion classy_vision/trainer/classy_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def train(self, task: ClassyTask):

task.on_start()
while not task.done_training():
task.on_phase_start(local_variables)
task.on_phase_start()
while True:
try:
task.step(self.use_gpu)
Expand Down
2 changes: 1 addition & 1 deletion classy_vision/trainer/elastic_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def _run_step(self, state, local_variables, use_gpu):
if state.advance_to_next_phase:
self.elastic_coordinator.barrier()
self.elastic_coordinator._log_event("on_phase_start")
state.task.on_phase_start(local_variables)
state.task.on_phase_start()

state.advance_to_next_phase = False

Expand Down
6 changes: 3 additions & 3 deletions test/hooks_exponential_moving_average_model_hook_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def _test_exponential_moving_average_hook(self, model_device, hook_device):
)

exponential_moving_average_hook.on_start(task)
exponential_moving_average_hook.on_phase_start(task, local_variables)
exponential_moving_average_hook.on_phase_start(task)
# set the weights to all ones and simulate 10 updates
task.base_model.update_fc_weight()
fc_weight = model.fc.weight.clone()
Expand All @@ -60,7 +60,7 @@ def _test_exponential_moving_average_hook(self, model_device, hook_device):

# simulate a test phase now
task.train = False
exponential_moving_average_hook.on_phase_start(task, local_variables)
exponential_moving_average_hook.on_phase_start(task)
exponential_moving_average_hook.on_phase_end(task, local_variables)

# the model weights should be updated to the ema weights
Expand All @@ -72,7 +72,7 @@ def _test_exponential_moving_average_hook(self, model_device, hook_device):

# simulate a train phase again
task.train = True
exponential_moving_average_hook.on_phase_start(task, local_variables)
exponential_moving_average_hook.on_phase_start(task)

# the model weights should be back to the old value
self.assertTrue(torch.allclose(model.fc.weight, fc_weight))
Expand Down
2 changes: 1 addition & 1 deletion test/hooks_time_metrics_hook_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_time_metrics(
# on_phase_start() should set the start time and perf_stats
start_time = 1.2
mock_time.return_value = start_time
time_metrics_hook.on_phase_start(task, local_variables)
time_metrics_hook.on_phase_start(task)
self.assertEqual(time_metrics_hook.start_time, start_time)
self.assertTrue(isinstance(task.perf_stats, PerfStats))

Expand Down
4 changes: 2 additions & 2 deletions test/manual/hooks_progress_bar_hook_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_progress_bar(
progress_bar_hook = ProgressBarHook()

# progressbar.ProgressBar should be init-ed with num_batches
progress_bar_hook.on_phase_start(task, local_variables)
progress_bar_hook.on_phase_start(task)
mock_progressbar_pkg.ProgressBar.assert_called_once_with(num_batches)
mock_progress_bar.start.assert_called_once_with()
mock_progress_bar.start.reset_mock()
Expand Down Expand Up @@ -80,7 +80,7 @@ def test_progress_bar(
mock_is_master.return_value = False
progress_bar_hook = ProgressBarHook()
try:
progress_bar_hook.on_phase_start(task, local_variables)
progress_bar_hook.on_phase_start(task)
progress_bar_hook.on_step(task)
progress_bar_hook.on_phase_end(task, local_variables)
except Exception as e:
Expand Down
2 changes: 1 addition & 1 deletion test/manual/hooks_tensorboard_plot_hook_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def test_writer(self, mock_is_master_func: mock.MagicMock) -> None:
summary_writer.add_scalar.reset_mock()

# run the hook in the correct order
tensorboard_plot_hook.on_phase_start(task, local_variables)
tensorboard_plot_hook.on_phase_start(task)

for loss in losses:
task.losses.append(loss)
Expand Down

0 comments on commit 0b2368e

Please sign in to comment.