This repository has been archived by the owner on Jul 1, 2024. It is now read-only.

Remove local_variables from on_step #411

Closed · wants to merge 1 commit
6 changes: 2 additions & 4 deletions classy_vision/hooks/classy_hook.py
@@ -51,7 +51,7 @@ def __init__(self, a, b):
     def __init__(self):
         self.state = ClassyHookState()
 
-    def _noop(self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]) -> None:
+    def _noop(self, *args, **kwargs) -> None:
         """Derived classes can set their hook functions to this.
 
         This is useful if they want those hook functions to not do anything.
@@ -79,9 +79,7 @@ def on_phase_start(
         pass
 
     @abstractmethod
-    def on_step(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def on_step(self, task: "tasks.ClassyTask") -> None:
        """Called each time after parameters have been updated by the optimizer."""
         pass

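Under the new signature a hook receives only the task. A minimal sketch of a custom hook against the updated base class (the hook name and print logic are illustrative; any other events the base class defines would be stubbed with _noop the same way):

    from classy_vision.hooks.classy_hook import ClassyHook

    class LossPrinterHook(ClassyHook):
        """Illustrative hook: log the latest loss after each optimizer step."""

        # point unused events at the now-variadic _noop, as the docstring above suggests
        on_start = ClassyHook._noop
        on_phase_start = ClassyHook._noop
        on_phase_end = ClassyHook._noop
        on_end = ClassyHook._noop

        def on_step(self, task) -> None:
            # task.losses holds the per-batch losses accumulated in the phase
            if task.train and task.losses:
                print("batch {}: loss={}".format(len(task.losses), task.losses[-1]))
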
classy_vision/hooks/exponential_moving_average_model_hook.py (path not captured; inferred from the test file below)
@@ -103,7 +103,7 @@ def on_phase_end(self, task: ClassyTask, local_variables: Dict[str, Any]) -> None:
             # state in the test phase
             self._save_current_model_state(task.base_model, self.state.model_state)
 
-    def on_step(self, task: ClassyTask, local_variables: Dict[str, Any]) -> None:
+    def on_step(self, task: ClassyTask) -> None:
         if not task.train:
             return

20 changes: 7 additions & 13 deletions classy_vision/hooks/loss_lr_meter_logging_hook.py
@@ -45,36 +45,30 @@ def on_phase_end(
         # trainer to implement an unsynced end of phase meter or
         # for meters to not provide a sync function.
         logging.info("End of phase metric values:")
-        self._log_loss_meters(task, local_variables)
+        self._log_loss_meters(task)
         if task.train:
-            self._log_lr(task, local_variables)
+            self._log_lr(task)
 
-    def on_step(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def on_step(self, task: "tasks.ClassyTask") -> None:
         """
         Log the LR every log_freq batches, if log_freq is not None.
         """
         if self.log_freq is None or not task.train:
             return
         batches = len(task.losses)
         if batches and batches % self.log_freq == 0:
-            self._log_lr(task, local_variables)
+            self._log_lr(task)
             logging.info("Local unsynced metric values:")
-            self._log_loss_meters(task, local_variables)
+            self._log_loss_meters(task)
 
-    def _log_lr(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def _log_lr(self, task: "tasks.ClassyTask") -> None:
         """
         Compute and log the optimizer LR.
         """
         optimizer_lr = task.optimizer.parameters.lr
         logging.info("Learning Rate: {}\n".format(optimizer_lr))
 
-    def _log_loss_meters(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def _log_loss_meters(self, task: "tasks.ClassyTask") -> None:
         """
         Compute and log the loss and meters.
         """

4 changes: 1 addition & 3 deletions classy_vision/hooks/progress_bar_hook.py
@@ -51,9 +51,7 @@ def on_phase_start(
         self.progress_bar = progressbar.ProgressBar(self.bar_size)
         self.progress_bar.start()
 
-    def on_step(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def on_step(self, task: "tasks.ClassyTask") -> None:
         """Update the progress bar with the batch size."""
         if task.train and is_master() and self.progress_bar is not None:
             self.batches += 1

4 changes: 1 addition & 3 deletions classy_vision/hooks/tensorboard_plot_hook.py
@@ -64,9 +64,7 @@ def on_phase_start(
         self.wall_times = []
         self.num_steps_global = []
 
-    def on_step(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def on_step(self, task: "tasks.ClassyTask") -> None:
         """Store the observed learning rates."""
         if self.learning_rates is None:
             logging.warning("learning_rates is not initialized")

20 changes: 8 additions & 12 deletions classy_vision/hooks/time_metrics_hook.py
@@ -40,19 +40,17 @@ def on_phase_start(
         Initialize start time and reset perf stats
         """
         self.start_time = time.time()
-        local_variables["perf_stats"] = PerfStats()
+        task.perf_stats = PerfStats()
 
-    def on_step(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def on_step(self, task: "tasks.ClassyTask") -> None:
         """
         Log metrics every log_freq batches, if log_freq is not None.
         """
         if self.log_freq is None:
             return
         batches = len(task.losses)
         if batches and batches % self.log_freq == 0:
-            self._log_performance_metrics(task, local_variables)
+            self._log_performance_metrics(task)
 
     def on_phase_end(
         self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
@@ -62,11 +60,9 @@ def on_phase_end(
         """
         batches = len(task.losses)
         if batches:
-            self._log_performance_metrics(task, local_variables)
+            self._log_performance_metrics(task)
 
-    def _log_performance_metrics(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def _log_performance_metrics(self, task: "tasks.ClassyTask") -> None:
         """
         Compute and log performance metrics.
         """
@@ -85,11 +81,11 @@ def _log_performance_metrics(
         )
 
         # Train step time breakdown
-        if local_variables.get("perf_stats") is None:
-            logging.warning('"perf_stats" not set in local_variables')
+        if not hasattr(task, "perf_stats") or task.perf_stats is None:
+            logging.warning('"perf_stats" not set in task')
         elif task.train:
             logging.info(
                 "Train step time breakdown (rank {}):\n{}".format(
-                    get_rank(), local_variables["perf_stats"].report_str()
+                    get_rank(), task.perf_stats.report_str()
                 )
             )

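Since perf_stats now lives on the task rather than in local_variables, code that resets or reads the stats targets the task attribute. A standalone sketch, assuming PerfStats is importable from classy_vision.generic.perf_stats and using a stand-in task object:

    from classy_vision.generic.perf_stats import PerfStats

    class _Task:
        """Stand-in for a ClassyTask; only the attribute used here."""
        perf_stats = None

    task = _Task()

    # on_phase_start() now resets the stats on the task itself
    task.perf_stats = PerfStats()

    # downstream code guards the read the same way the hook above does
    if getattr(task, "perf_stats", None) is not None:
        print(task.perf_stats.report_str())  # report_str() as used in the hook
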
27 changes: 26 additions & 1 deletion classy_vision/tasks/classification_task.py
@@ -8,7 +8,7 @@
 import enum
 import logging
 import time
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, NamedTuple, Optional, Union
 
 import torch
 from classy_vision.dataset import ClassyDataset, build_dataset
@@ -54,6 +54,13 @@ class BroadcastBuffersMode(enum.Enum):
     BEFORE_EVAL = enum.auto()
 
 
+class LastBatchInfo(NamedTuple):
+    loss: torch.Tensor
+    output: torch.Tensor
+    target: torch.Tensor
+    sample: Dict[str, Any]
+
+
 @register_task("classification_task")
 class ClassificationTask(ClassyTask):
     """Basic classification training task.
@@ -672,6 +679,14 @@ def eval_step(self, use_gpu, local_variables=None):
 
         self.update_meters(local_variables["output"], local_variables["sample"])
 
+        # Move some data to the task so hooks get a chance to access it
+        self.last_batch = LastBatchInfo(
+            loss=local_variables["loss"],
+            output=local_variables["output"],
+            target=local_variables["target"],
+            sample=local_variables["sample"],
+        )
+
     def train_step(self, use_gpu, local_variables=None):
         """Train step to be executed in train loop
 
@@ -684,6 +699,8 @@ def train_step(self, use_gpu, local_variables=None):
         if local_variables is None:
             local_variables = {}
 
+        self.last_batch = None
+
         # Process next sample
         sample = next(self.get_data_iterator())
         local_variables["sample"] = sample
@@ -738,6 +755,14 @@ def train_step(self, use_gpu, local_variables=None):
 
         self.num_updates += self.get_global_batchsize()
 
+        # Move some data to the task so hooks get a chance to access it
+        self.last_batch = LastBatchInfo(
+            loss=local_variables["loss"],
+            output=local_variables["output"],
+            target=local_variables["target"],
+            sample=local_variables["sample"],
+        )
+
     def compute_loss(self, model_output, sample):
         return self.loss(model_output, sample["target"])
 
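With LastBatchInfo stored on the task, hooks read the most recent batch from task.last_batch instead of digging through local_variables. A sketch of a hypothetical hook (field names follow the NamedTuple above; note that train_step() resets last_batch to None before processing each sample):

    from classy_vision.hooks.classy_hook import ClassyHook

    class LastBatchInspectorHook(ClassyHook):
        """Illustrative hook: inspect the batch the step just processed."""

        on_start = ClassyHook._noop
        on_phase_start = ClassyHook._noop
        on_phase_end = ClassyHook._noop
        on_end = ClassyHook._noop

        def on_step(self, task) -> None:
            batch = task.last_batch  # a LastBatchInfo, set by train_step/eval_step
            if batch is None:
                return
            print("loss={:.4f}, batch size={}".format(batch.loss.item(), batch.output.shape[0]))
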
3 changes: 2 additions & 1 deletion classy_vision/tasks/classy_task.py
@@ -178,7 +178,8 @@ def step(self, use_gpu, local_variables: Optional[Dict] = None) -> None:
         else:
             self.eval_step(use_gpu, local_variables)
 
-        self.run_hooks(local_variables, ClassyHookFunctions.on_step.name)
+        for hook in self.hooks:
+            hook.on_step(self)
 
     def run_hooks(self, local_variables: Dict[str, Any], hook_function: str) -> None:
         """

2 changes: 1 addition & 1 deletion test/hooks_exponential_moving_average_model_hook_test.py
@@ -53,7 +53,7 @@ def _test_exponential_moving_average_hook(self, model_device, hook_device):
         task.base_model.update_fc_weight()
         fc_weight = model.fc.weight.clone()
         for _ in range(num_updates):
-            exponential_moving_average_hook.on_step(task, local_variables)
+            exponential_moving_average_hook.on_step(task)
         exponential_moving_average_hook.on_phase_end(task, local_variables)
         # the model weights shouldn't have changed
         self.assertTrue(torch.allclose(model.fc.weight, fc_weight))

16 changes: 8 additions & 8 deletions test/hooks_loss_lr_meter_logging_hook_test.py
@@ -51,27 +51,27 @@ def test_logging(self, mock_get_rank: mock.MagicMock) -> None:
 
         for i in range(num_batches):
             task.losses = list(range(i))
-            loss_lr_meter_hook.on_step(task, local_variables)
+            loss_lr_meter_hook.on_step(task)
             if log_freq is not None and i and i % log_freq == 0:
-                mock_fn.assert_called_with(task, local_variables)
+                mock_fn.assert_called_with(task)
                 mock_fn.reset_mock()
-                mock_lr_fn.assert_called_with(task, local_variables)
+                mock_lr_fn.assert_called_with(task)
                 mock_lr_fn.reset_mock()
                 continue
             mock_fn.assert_not_called()
             mock_lr_fn.assert_not_called()
 
         loss_lr_meter_hook.on_phase_end(task, local_variables)
-        mock_fn.assert_called_with(task, local_variables)
+        mock_fn.assert_called_with(task)
         if task.train:
-            mock_lr_fn.assert_called_with(task, local_variables)
+            mock_lr_fn.assert_called_with(task)
 
         # test _log_loss_lr_meters()
         task.losses = losses
 
         with self.assertLogs():
-            loss_lr_meter_hook._log_loss_meters(task, local_variables)
-            loss_lr_meter_hook._log_lr(task, local_variables)
+            loss_lr_meter_hook._log_loss_meters(task)
+            loss_lr_meter_hook._log_lr(task)
 
         task.phase_idx += 1
 
@@ -95,7 +95,7 @@ def scheduler_mock(where):
         lr_order = [0.0, 1 / 6, 1 / 6, 2 / 6, 3 / 6, 3 / 6, 4 / 6, 5 / 6, 5 / 6]
         lr_list = []
 
-        def mock_log_lr(task: ClassyTask, local_variables) -> None:
+        def mock_log_lr(task: ClassyTask) -> None:
             lr_list.append(task.optimizer.parameters.lr)
 
         with mock.patch.object(

12 changes: 6 additions & 6 deletions test/hooks_time_metrics_hook_test.py
@@ -49,7 +49,7 @@ def test_time_metrics(
         mock_time.return_value = start_time
         time_metrics_hook.on_phase_start(task, local_variables)
         self.assertEqual(time_metrics_hook.start_time, start_time)
-        self.assertTrue(isinstance(local_variables.get("perf_stats"), PerfStats))
+        self.assertTrue(isinstance(task.perf_stats, PerfStats))
 
         # test that the code doesn't raise an exception if losses is empty
         try:
@@ -66,15 +66,15 @@ def test_time_metrics(
 
         for i in range(num_batches):
             task.losses = list(range(i))
-            time_metrics_hook.on_step(task, local_variables)
+            time_metrics_hook.on_step(task)
             if log_freq is not None and i and i % log_freq == 0:
-                mock_fn.assert_called_with(task, local_variables)
+                mock_fn.assert_called_with(task)
                 mock_fn.reset_mock()
                 continue
             mock_fn.assert_not_called()
 
         time_metrics_hook.on_phase_end(task, local_variables)
-        mock_fn.assert_called_with(task, local_variables)
+        mock_fn.assert_called_with(task)
 
         task.losses = [0.23, 0.45, 0.34, 0.67]
 
@@ -84,7 +84,7 @@ def test_time_metrics(
 
         # test _log_performance_metrics()
         with self.assertLogs() as log_watcher:
-            time_metrics_hook._log_performance_metrics(task, local_variables)
+            time_metrics_hook._log_performance_metrics(task)
 
         # there should 2 be info logs for train and 1 for test
         self.assertEqual(len(log_watcher.output), 2 if train else 1)
@@ -112,7 +112,7 @@ def test_time_metrics(
 
         # if on_phase_start() is not called, 2 warnings should be logged
         # create a new time metrics hook
-        local_variables = {}
+        task.perf_stats = None
         time_metrics_hook_new = TimeMetricsHook()
 
         with self.assertLogs() as log_watcher:

8 changes: 4 additions & 4 deletions test/manual/hooks_progress_bar_hook_test.py
@@ -48,14 +48,14 @@ def test_progress_bar(
 
         # on_step should update the progress bar correctly
         for i in range(num_batches):
-            progress_bar_hook.on_step(task, local_variables)
+            progress_bar_hook.on_step(task)
             mock_progress_bar.update.assert_called_once_with(i + 1)
             mock_progress_bar.update.reset_mock()
 
         # check that even if on_step is called again, the progress bar is
         # only updated with num_batches
         for _ in range(num_batches):
-            progress_bar_hook.on_step(task, local_variables)
+            progress_bar_hook.on_step(task)
             mock_progress_bar.update.assert_called_once_with(num_batches)
             mock_progress_bar.update.reset_mock()
 
@@ -68,7 +68,7 @@ def test_progress_bar(
         # crash
         progress_bar_hook = ProgressBarHook()
         try:
-            progress_bar_hook.on_step(task, local_variables)
+            progress_bar_hook.on_step(task)
             progress_bar_hook.on_phase_end(task, local_variables)
         except Exception as e:
             self.fail(
@@ -81,7 +81,7 @@ def test_progress_bar(
         progress_bar_hook = ProgressBarHook()
         try:
             progress_bar_hook.on_phase_start(task, local_variables)
-            progress_bar_hook.on_step(task, local_variables)
+            progress_bar_hook.on_step(task)
             progress_bar_hook.on_phase_end(task, local_variables)
         except Exception as e:
             self.fail("Received Exception when is_master() is False: {}".format(e))

4 changes: 2 additions & 2 deletions test/manual/hooks_tensorboard_plot_hook_test.py
@@ -62,7 +62,7 @@ def test_writer(self, mock_is_master_func: mock.MagicMock) -> None:
         # the writer if on_phase_start() is not called for initialization
         # before on_step() is called.
         with self.assertLogs() as log_watcher:
-            tensorboard_plot_hook.on_step(task, local_variables)
+            tensorboard_plot_hook.on_step(task)
 
         self.assertTrue(
             len(log_watcher.records) == 1
@@ -88,7 +88,7 @@ def test_writer(self, mock_is_master_func: mock.MagicMock) -> None:
 
         for loss in losses:
             task.losses.append(loss)
-            tensorboard_plot_hook.on_step(task, local_variables)
+            tensorboard_plot_hook.on_step(task)
 
         tensorboard_plot_hook.on_phase_end(task, local_variables)
 
2 changes: 1 addition & 1 deletion test/optim_param_scheduler_test.py
@@ -207,7 +207,7 @@ class TestHook(ClassyHook):
            on_phase_end = ClassyHook._noop
            on_end = ClassyHook._noop
 
-            def on_step(self, task: ClassyTask, local_variables) -> None:
+            def on_step(self, task: ClassyTask) -> None:
                if not task.train:
                    return
 
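Taken together, the migration for downstream hooks is mechanical. A before/after sketch (illustrative; the loss field comes from LastBatchInfo in classification_task.py above):

    # before this PR
    def on_step(self, task, local_variables):
        loss = local_variables["loss"]

    # after this PR
    def on_step(self, task):
        loss = task.last_batch.loss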