
Remove local_variables from on_step (#411)
Summary:
Pull Request resolved: #411

local_variables makes the code in train_step really hard to read. Removing it
from all hooks will take time, so start with a single hook (on_step).

Reviewed By: mannatsingh

Differential Revision: D20171981

fbshipit-source-id: 5e2a4f9d105ce42cf53e375452a66bb2e747f5a1
vreis authored and facebook-github-bot committed Mar 3, 2020
1 parent 7e51ad9 commit e29cc72
Showing 14 changed files with 73 additions and 60 deletions.
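
For hook authors, the practical effect is a narrower on_step signature plus a new home for per-batch data. The sketch below is illustrative only and is not code from this commit: OldStyleHook and NewStyleHook are hypothetical names, and the LastBatchInfo stand-in mirrors the NamedTuple added to classification_task.py further down. Only on_step changes here; the other hook methods keep their local_variables argument for now.

# Illustrative sketch only (not part of the commit). OldStyleHook and
# NewStyleHook are hypothetical; LastBatchInfo mirrors the NamedTuple
# introduced in classification_task.py in this diff.
from typing import Any, Dict, NamedTuple

import torch


class LastBatchInfo(NamedTuple):
    loss: torch.Tensor
    output: torch.Tensor
    target: torch.Tensor
    sample: Dict[str, Any]


class OldStyleHook:
    # Before this commit: every hook received the step's local_variables dict.
    def on_step(self, task, local_variables: Dict[str, Any]) -> None:
        print("loss:", local_variables["loss"].item())


class NewStyleHook:
    # After this commit: on_step receives only the task; per-batch data is
    # published on task.last_batch instead of being passed around in a dict.
    def on_step(self, task) -> None:
        if task.last_batch is None:
            return
        print("loss:", task.last_batch.loss.item())
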
6 changes: 2 additions & 4 deletions classy_vision/hooks/classy_hook.py
@@ -51,7 +51,7 @@ def __init__(self, a, b):
     def __init__(self):
         self.state = ClassyHookState()
 
-    def _noop(self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]) -> None:
+    def _noop(self, *args, **kwargs) -> None:
         """Derived classes can set their hook functions to this.
         This is useful if they want those hook functions to not do anything.
@@ -79,9 +79,7 @@ def on_phase_start(
         pass
 
     @abstractmethod
-    def on_step(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def on_step(self, task: "tasks.ClassyTask") -> None:
         """Called each time after parameters have been updated by the optimizer."""
         pass

@@ -103,7 +103,7 @@ def on_phase_end(self, task: ClassyTask, local_variables: Dict[str, Any]) -> None:
         # state in the test phase
         self._save_current_model_state(task.base_model, self.state.model_state)
 
-    def on_step(self, task: ClassyTask, local_variables: Dict[str, Any]) -> None:
+    def on_step(self, task: ClassyTask) -> None:
         if not task.train:
             return

20 changes: 7 additions & 13 deletions classy_vision/hooks/loss_lr_meter_logging_hook.py
@@ -45,36 +45,30 @@ def on_phase_end(
         # trainer to implement an unsynced end of phase meter or
         # for meters to not provide a sync function.
         logging.info("End of phase metric values:")
-        self._log_loss_meters(task, local_variables)
+        self._log_loss_meters(task)
         if task.train:
-            self._log_lr(task, local_variables)
+            self._log_lr(task)
 
-    def on_step(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def on_step(self, task: "tasks.ClassyTask") -> None:
         """
         Log the LR every log_freq batches, if log_freq is not None.
         """
         if self.log_freq is None or not task.train:
             return
         batches = len(task.losses)
         if batches and batches % self.log_freq == 0:
-            self._log_lr(task, local_variables)
+            self._log_lr(task)
             logging.info("Local unsynced metric values:")
-            self._log_loss_meters(task, local_variables)
+            self._log_loss_meters(task)
 
-    def _log_lr(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def _log_lr(self, task: "tasks.ClassyTask") -> None:
         """
         Compute and log the optimizer LR.
         """
         optimizer_lr = task.optimizer.parameters.lr
         logging.info("Learning Rate: {}\n".format(optimizer_lr))
 
-    def _log_loss_meters(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def _log_loss_meters(self, task: "tasks.ClassyTask") -> None:
         """
         Compute and log the loss and meters.
         """
4 changes: 1 addition & 3 deletions classy_vision/hooks/progress_bar_hook.py
@@ -51,9 +51,7 @@ def on_phase_start(
             self.progress_bar = progressbar.ProgressBar(self.bar_size)
             self.progress_bar.start()
 
-    def on_step(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def on_step(self, task: "tasks.ClassyTask") -> None:
         """Update the progress bar with the batch size."""
         if task.train and is_master() and self.progress_bar is not None:
             self.batches += 1
4 changes: 1 addition & 3 deletions classy_vision/hooks/tensorboard_plot_hook.py
@@ -64,9 +64,7 @@ def on_phase_start(
         self.wall_times = []
         self.num_steps_global = []
 
-    def on_step(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def on_step(self, task: "tasks.ClassyTask") -> None:
         """Store the observed learning rates."""
         if self.learning_rates is None:
             logging.warning("learning_rates is not initialized")
20 changes: 8 additions & 12 deletions classy_vision/hooks/time_metrics_hook.py
@@ -40,19 +40,17 @@ def on_phase_start(
         Initialize start time and reset perf stats
         """
         self.start_time = time.time()
-        local_variables["perf_stats"] = PerfStats()
+        task.perf_stats = PerfStats()
 
-    def on_step(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def on_step(self, task: "tasks.ClassyTask") -> None:
         """
         Log metrics every log_freq batches, if log_freq is not None.
         """
         if self.log_freq is None:
             return
         batches = len(task.losses)
         if batches and batches % self.log_freq == 0:
-            self._log_performance_metrics(task, local_variables)
+            self._log_performance_metrics(task)
 
     def on_phase_end(
         self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
@@ -62,11 +60,9 @@ def on_phase_end(
         """
         batches = len(task.losses)
         if batches:
-            self._log_performance_metrics(task, local_variables)
+            self._log_performance_metrics(task)
 
-    def _log_performance_metrics(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def _log_performance_metrics(self, task: "tasks.ClassyTask") -> None:
         """
         Compute and log performance metrics.
         """
@@ -85,11 +81,11 @@ def _log_performance_metrics(
         )
 
         # Train step time breakdown
-        if local_variables.get("perf_stats") is None:
-            logging.warning('"perf_stats" not set in local_variables')
+        if not hasattr(task, "perf_stats") or task.perf_stats is None:
+            logging.warning('"perf_stats" not set in task')
         elif task.train:
             logging.info(
                 "Train step time breakdown (rank {}):\n{}".format(
-                    get_rank(), local_variables["perf_stats"].report_str()
+                    get_rank(), task.perf_stats.report_str()
                 )
            )
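
A related change in this file: the perf stats that TimeMetricsHook used to stash in local_variables["perf_stats"] now hang off the task itself. Below is a minimal sketch of reading them, assuming on_phase_start has already attached task.perf_stats; the report_step_times helper name is made up for illustration.

# Hypothetical helper (not from the commit): reads the PerfStats object that
# TimeMetricsHook.on_phase_start now attaches to the task.
def report_step_times(task) -> None:
    perf_stats = getattr(task, "perf_stats", None)
    if perf_stats is None:
        # on_phase_start has not run yet, so there is nothing to report
        return
    print(perf_stats.report_str())
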
30 changes: 29 additions & 1 deletion classy_vision/tasks/classification_task.py
@@ -8,7 +8,7 @@
 import enum
 import logging
 import time
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, NamedTuple, Optional, Union
 
 import torch
 from classy_vision.dataset import ClassyDataset, build_dataset
@@ -54,6 +54,13 @@ class BroadcastBuffersMode(enum.Enum):
     BEFORE_EVAL = enum.auto()
 
 
+class LastBatchInfo(NamedTuple):
+    loss: torch.Tensor
+    output: torch.Tensor
+    target: torch.Tensor
+    sample: Dict[str, Any]
+
+
 @register_task("classification_task")
 class ClassificationTask(ClassyTask):
     """Basic classification training task.
@@ -126,6 +133,7 @@ def __init__(self):
         )
         self.amp_opt_level = None
         self.perf_log = []
+        self.last_batch = None
 
     def set_checkpoint(self, checkpoint):
         """Sets checkpoint on task.
@@ -634,6 +642,8 @@ def eval_step(self, use_gpu, local_variables=None):
         if local_variables is None:
             local_variables = {}
 
+        self.last_batch = None
+
         # Process next sample
         sample = next(self.get_data_iterator())
         local_variables["sample"] = sample
@@ -672,6 +682,14 @@ def eval_step(self, use_gpu, local_variables=None):
 
         self.update_meters(local_variables["output"], local_variables["sample"])
 
+        # Move some data to the task so hooks get a chance to access it
+        self.last_batch = LastBatchInfo(
+            loss=local_variables["loss"],
+            output=local_variables["output"],
+            target=local_variables["target"],
+            sample=local_variables["sample"],
+        )
+
     def train_step(self, use_gpu, local_variables=None):
         """Train step to be executed in train loop
@@ -684,6 +702,8 @@ def train_step(self, use_gpu, local_variables=None):
         if local_variables is None:
             local_variables = {}
 
+        self.last_batch = None
+
         # Process next sample
         sample = next(self.get_data_iterator())
         local_variables["sample"] = sample
@@ -738,6 +758,14 @@ def train_step(self, use_gpu, local_variables=None):
 
         self.num_updates += self.get_global_batchsize()
 
+        # Move some data to the task so hooks get a chance to access it
+        self.last_batch = LastBatchInfo(
+            loss=local_variables["loss"],
+            output=local_variables["output"],
+            target=local_variables["target"],
+            sample=local_variables["sample"],
+        )
 
     def compute_loss(self, model_output, sample):
         return self.loss(model_output, sample["target"])
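
Both train_step and eval_step now publish the batch they just processed via task.last_batch, which is what lets hooks drop the local_variables argument. A rough consumption sketch follows; batch_accuracy is a hypothetical helper and assumes integer class-index targets, which this commit does not guarantee.

# Hypothetical helper (not from the commit): computes top-1 accuracy from the
# LastBatchInfo published by train_step/eval_step, assuming class-index targets.
def batch_accuracy(task) -> float:
    batch = task.last_batch
    if batch is None:
        return 0.0
    predictions = batch.output.argmax(dim=1)
    return (predictions == batch.target).float().mean().item()
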
3 changes: 2 additions & 1 deletion classy_vision/tasks/classy_task.py
@@ -178,7 +178,8 @@ def step(self, use_gpu, local_variables: Optional[Dict] = None) -> None:
         else:
             self.eval_step(use_gpu, local_variables)
 
-        self.run_hooks(local_variables, ClassyHookFunctions.on_step.name)
+        for hook in self.hooks:
+            hook.on_step(self)
 
     def run_hooks(self, local_variables: Dict[str, Any], hook_function: str) -> None:
         """
2 changes: 1 addition & 1 deletion test/hooks_exponential_moving_average_model_hook_test.py
@@ -53,7 +53,7 @@ def _test_exponential_moving_average_hook(self, model_device, hook_device):
         task.base_model.update_fc_weight()
         fc_weight = model.fc.weight.clone()
         for _ in range(num_updates):
-            exponential_moving_average_hook.on_step(task, local_variables)
+            exponential_moving_average_hook.on_step(task)
         exponential_moving_average_hook.on_phase_end(task, local_variables)
         # the model weights shouldn't have changed
         self.assertTrue(torch.allclose(model.fc.weight, fc_weight))
16 changes: 8 additions & 8 deletions test/hooks_loss_lr_meter_logging_hook_test.py
@@ -51,27 +51,27 @@ def test_logging(self, mock_get_rank: mock.MagicMock) -> None:
 
         for i in range(num_batches):
             task.losses = list(range(i))
-            loss_lr_meter_hook.on_step(task, local_variables)
+            loss_lr_meter_hook.on_step(task)
             if log_freq is not None and i and i % log_freq == 0:
-                mock_fn.assert_called_with(task, local_variables)
+                mock_fn.assert_called_with(task)
                 mock_fn.reset_mock()
-                mock_lr_fn.assert_called_with(task, local_variables)
+                mock_lr_fn.assert_called_with(task)
                 mock_lr_fn.reset_mock()
                 continue
             mock_fn.assert_not_called()
             mock_lr_fn.assert_not_called()
 
         loss_lr_meter_hook.on_phase_end(task, local_variables)
-        mock_fn.assert_called_with(task, local_variables)
+        mock_fn.assert_called_with(task)
         if task.train:
-            mock_lr_fn.assert_called_with(task, local_variables)
+            mock_lr_fn.assert_called_with(task)
 
         # test _log_loss_lr_meters()
         task.losses = losses
 
         with self.assertLogs():
-            loss_lr_meter_hook._log_loss_meters(task, local_variables)
-            loss_lr_meter_hook._log_lr(task, local_variables)
+            loss_lr_meter_hook._log_loss_meters(task)
+            loss_lr_meter_hook._log_lr(task)
 
         task.phase_idx += 1
 
@@ -95,7 +95,7 @@ def scheduler_mock(where):
         lr_order = [0.0, 1 / 6, 1 / 6, 2 / 6, 3 / 6, 3 / 6, 4 / 6, 5 / 6, 5 / 6]
         lr_list = []
 
-        def mock_log_lr(task: ClassyTask, local_variables) -> None:
+        def mock_log_lr(task: ClassyTask) -> None:
             lr_list.append(task.optimizer.parameters.lr)
 
         with mock.patch.object(
12 changes: 6 additions & 6 deletions test/hooks_time_metrics_hook_test.py
@@ -49,7 +49,7 @@ def test_time_metrics(
         mock_time.return_value = start_time
         time_metrics_hook.on_phase_start(task, local_variables)
         self.assertEqual(time_metrics_hook.start_time, start_time)
-        self.assertTrue(isinstance(local_variables.get("perf_stats"), PerfStats))
+        self.assertTrue(isinstance(task.perf_stats, PerfStats))
 
         # test that the code doesn't raise an exception if losses is empty
         try:
@@ -66,15 +66,15 @@
 
         for i in range(num_batches):
             task.losses = list(range(i))
-            time_metrics_hook.on_step(task, local_variables)
+            time_metrics_hook.on_step(task)
             if log_freq is not None and i and i % log_freq == 0:
-                mock_fn.assert_called_with(task, local_variables)
+                mock_fn.assert_called_with(task)
                 mock_fn.reset_mock()
                 continue
             mock_fn.assert_not_called()
 
         time_metrics_hook.on_phase_end(task, local_variables)
-        mock_fn.assert_called_with(task, local_variables)
+        mock_fn.assert_called_with(task)
 
         task.losses = [0.23, 0.45, 0.34, 0.67]
 
@@ -84,7 +84,7 @@ def test_time_metrics(
 
         # test _log_performance_metrics()
         with self.assertLogs() as log_watcher:
-            time_metrics_hook._log_performance_metrics(task, local_variables)
+            time_metrics_hook._log_performance_metrics(task)
 
         # there should 2 be info logs for train and 1 for test
         self.assertEqual(len(log_watcher.output), 2 if train else 1)
@@ -112,7 +112,7 @@ def test_time_metrics(
 
         # if on_phase_start() is not called, 2 warnings should be logged
         # create a new time metrics hook
-        local_variables = {}
+        task.perf_stats = None
         time_metrics_hook_new = TimeMetricsHook()
 
         with self.assertLogs() as log_watcher:
8 changes: 4 additions & 4 deletions test/manual/hooks_progress_bar_hook_test.py
@@ -48,14 +48,14 @@ def test_progress_bar(
 
         # on_step should update the progress bar correctly
         for i in range(num_batches):
-            progress_bar_hook.on_step(task, local_variables)
+            progress_bar_hook.on_step(task)
             mock_progress_bar.update.assert_called_once_with(i + 1)
             mock_progress_bar.update.reset_mock()
 
         # check that even if on_step is called again, the progress bar is
         # only updated with num_batches
         for _ in range(num_batches):
-            progress_bar_hook.on_step(task, local_variables)
+            progress_bar_hook.on_step(task)
             mock_progress_bar.update.assert_called_once_with(num_batches)
             mock_progress_bar.update.reset_mock()
 
@@ -68,7 +68,7 @@ def test_progress_bar(
         # crash
         progress_bar_hook = ProgressBarHook()
         try:
-            progress_bar_hook.on_step(task, local_variables)
+            progress_bar_hook.on_step(task)
             progress_bar_hook.on_phase_end(task, local_variables)
         except Exception as e:
             self.fail(
@@ -81,7 +81,7 @@ def test_progress_bar(
         progress_bar_hook = ProgressBarHook()
         try:
             progress_bar_hook.on_phase_start(task, local_variables)
-            progress_bar_hook.on_step(task, local_variables)
+            progress_bar_hook.on_step(task)
             progress_bar_hook.on_phase_end(task, local_variables)
         except Exception as e:
             self.fail("Received Exception when is_master() is False: {}".format(e))
4 changes: 2 additions & 2 deletions test/manual/hooks_tensorboard_plot_hook_test.py
@@ -62,7 +62,7 @@ def test_writer(self, mock_is_master_func: mock.MagicMock) -> None:
         # the writer if on_phase_start() is not called for initialization
         # before on_step() is called.
         with self.assertLogs() as log_watcher:
-            tensorboard_plot_hook.on_step(task, local_variables)
+            tensorboard_plot_hook.on_step(task)
 
         self.assertTrue(
             len(log_watcher.records) == 1
@@ -88,7 +88,7 @@ def test_writer(self, mock_is_master_func: mock.MagicMock) -> None:
 
         for loss in losses:
             task.losses.append(loss)
-            tensorboard_plot_hook.on_step(task, local_variables)
+            tensorboard_plot_hook.on_step(task)
 
         tensorboard_plot_hook.on_phase_end(task, local_variables)
 
2 changes: 1 addition & 1 deletion test/optim_param_scheduler_test.py
@@ -207,7 +207,7 @@ class TestHook(ClassyHook):
            on_phase_end = ClassyHook._noop
            on_end = ClassyHook._noop
 
-            def on_step(self, task: ClassyTask, local_variables) -> None:
+            def on_step(self, task: ClassyTask) -> None:
                if not task.train:
                    return
 
