Skip to content

Commit

Permalink
Fix how prefix/posfix works in MultitaskWrapper (#2722)
Browse files Browse the repository at this point in the history
* implementation
* tests
* changelog
* fix mypy
  • Loading branch information
SkafteNicki committed Sep 10, 2024
1 parent 144f6d6 commit eecc55b
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 40 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Correct the padding related calculation errors in SSIM ([#2721](https://github.com/Lightning-AI/torchmetrics/pull/2721))


- Fixed how `prefix`/`postfix` works in `MultitaskWrapper` ([#2722](https://github.com/Lightning-AI/torchmetrics/pull/2722))


## [1.4.1] - 2024-08-02

### Changed
Expand Down
102 changes: 71 additions & 31 deletions src/torchmetrics/wrappers/multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,27 @@ class MultitaskWrapper(WrapperMetric):
task_metrics:
Dictionary associating each task to a Metric or a MetricCollection. The keys of the dictionary represent the
names of the tasks, and the values represent the metrics to use for each task.
prefix:
A string to append in front of the metric keys. If not provided, will default to an empty string.
postfix:
A string to append after the keys of the output dict. If not provided, will default to an empty string.
.. note::
The use pre prefix and postfix allows for easily creating task wrappers for training, validation and test.
The arguments are only changing the output keys of the computed metrics and not the input keys. This means
that a ``MultitaskWrapper`` initialized as ``MultitaskWrapper({"task": Metric()}, prefix="train_")`` will
still expect the input to be a dictionary with the key "task", but the output will be a dictionary with the key
"train_task".
Raises:
TypeError:
If argument `task_metrics` is not an dictionary
TypeError:
If not all values in the `task_metrics` dictionary is instances of `Metric` or `MetricCollection`
ValueError:
If `prefix` is not a string
ValueError:
If `postfix` is not a string
Example (with a single metric per class):
>>> import torch
Expand Down Expand Up @@ -91,18 +106,59 @@ class MultitaskWrapper(WrapperMetric):
{'Classification': {'BinaryAccuracy': tensor(0.3333), 'BinaryF1Score': tensor(0.)},
'Regression': {'MeanSquaredError': tensor(0.8333), 'MeanAbsoluteError': tensor(0.6667)}}
Example (with a prefix and postfix):
>>> import torch
>>> from torchmetrics.wrappers import MultitaskWrapper
>>> from torchmetrics.regression import MeanSquaredError
>>> from torchmetrics.classification import BinaryAccuracy
>>>
>>> classification_target = torch.tensor([0, 1, 0])
>>> regression_target = torch.tensor([2.5, 5.0, 4.0])
>>> targets = {"Classification": classification_target, "Regression": regression_target}
>>> classification_preds = torch.tensor([0, 0, 1])
>>> regression_preds = torch.tensor([3.0, 5.0, 2.5])
>>> preds = {"Classification": classification_preds, "Regression": regression_preds}
>>>
>>> metrics = MultitaskWrapper({
... "Classification": BinaryAccuracy(),
... "Regression": MeanSquaredError()
... }, prefix="train_")
>>> metrics.update(preds, targets)
>>> metrics.compute()
{'train_Classification': tensor(0.3333), 'train_Regression': tensor(0.8333)}
"""

is_differentiable = False
is_differentiable: bool = False

def __init__(
self,
task_metrics: Dict[str, Union[Metric, MetricCollection]],
prefix: Optional[str] = None,
postfix: Optional[str] = None,
) -> None:
self._check_task_metrics_type(task_metrics)
super().__init__()

if not isinstance(task_metrics, dict):
raise TypeError(f"Expected argument `task_metrics` to be a dict. Found task_metrics = {task_metrics}")

for metric in task_metrics.values():
if not (isinstance(metric, (Metric, MetricCollection))):
raise TypeError(
"Expected each task's metric to be a Metric or a MetricCollection. "
f"Found a metric of type {type(metric)}"
)

self.task_metrics = nn.ModuleDict(task_metrics)

if prefix is not None and not isinstance(prefix, str):
raise ValueError(f"Expected argument `prefix` to either be `None` or a string but got {prefix}")
self._prefix = prefix or ""

if postfix is not None and not isinstance(postfix, str):
raise ValueError(f"Expected argument `postfix` to either be `None` or a string but got {postfix}")
self._postfix = postfix or ""

def items(self, flatten: bool = True) -> Iterable[Tuple[str, nn.Module]]:
"""Iterate over task and task metrics.
Expand All @@ -114,9 +170,9 @@ def items(self, flatten: bool = True) -> Iterable[Tuple[str, nn.Module]]:
for task_name, metric in self.task_metrics.items():
if flatten and isinstance(metric, MetricCollection):
for sub_metric_name, sub_metric in metric.items():
yield f"{task_name}_{sub_metric_name}", sub_metric
yield f"{self._prefix}{task_name}_{sub_metric_name}{self._postfix}", sub_metric
else:
yield task_name, metric
yield f"{self._prefix}{task_name}{self._postfix}", metric

def keys(self, flatten: bool = True) -> Iterable[str]:
"""Iterate over task names.
Expand All @@ -129,9 +185,9 @@ def keys(self, flatten: bool = True) -> Iterable[str]:
for task_name, metric in self.task_metrics.items():
if flatten and isinstance(metric, MetricCollection):
for sub_metric_name in metric:
yield f"{task_name}_{sub_metric_name}"
yield f"{self._prefix}{task_name}_{sub_metric_name}{self._postfix}"
else:
yield task_name
yield f"{self._prefix}{task_name}{self._postfix}"

def values(self, flatten: bool = True) -> Iterable[nn.Module]:
"""Iterate over task metrics.
Expand All @@ -147,18 +203,6 @@ def values(self, flatten: bool = True) -> Iterable[nn.Module]:
else:
yield metric

@staticmethod
def _check_task_metrics_type(task_metrics: Dict[str, Union[Metric, MetricCollection]]) -> None:
if not isinstance(task_metrics, dict):
raise TypeError(f"Expected argument `task_metrics` to be a dict. Found task_metrics = {task_metrics}")

for metric in task_metrics.values():
if not (isinstance(metric, (Metric, MetricCollection))):
raise TypeError(
"Expected each task's metric to be a Metric or a MetricCollection. "
f"Found a metric of type {type(metric)}"
)

def update(self, task_preds: Dict[str, Any], task_targets: Dict[str, Any]) -> None:
"""Update each task's metric with its corresponding pred and target.
Expand All @@ -179,20 +223,24 @@ def update(self, task_preds: Dict[str, Any], task_targets: Dict[str, Any]) -> No
target = task_targets[task_name]
metric.update(pred, target)

def _convert_output(self, output: Dict[str, Any]) -> Dict[str, Any]:
"""Convert the output of the underlying metrics to a dictionary with the task names as keys."""
return {f"{self._prefix}{task_name}{self._postfix}": task_output for task_name, task_output in output.items()}

def compute(self) -> Dict[str, Any]:
"""Compute metrics for all tasks."""
return {task_name: metric.compute() for task_name, metric in self.task_metrics.items()}
return self._convert_output({task_name: metric.compute() for task_name, metric in self.task_metrics.items()})

def forward(self, task_preds: Dict[str, Tensor], task_targets: Dict[str, Tensor]) -> Dict[str, Any]:
"""Call underlying forward methods for all tasks and return the result as a dictionary."""
# This method is overridden because we do not need the complex version defined in Metric, that relies on the
# value of full_state_update, and that also accumulates the results. Here, all computations are handled by the
# underlying metrics, which all have their own value of full_state_update, and which all accumulate the results
# by themselves.
return {
return self._convert_output({
task_name: metric(task_preds[task_name], task_targets[task_name])
for task_name, metric in self.task_metrics.items()
}
})

def reset(self) -> None:
"""Reset all underlying metrics."""
Expand All @@ -215,16 +263,8 @@ def clone(self, prefix: Optional[str] = None, postfix: Optional[str] = None) ->
"""
multitask_copy = deepcopy(self)
if prefix is not None:
prefix = self._check_arg(prefix, "prefix")
multitask_copy.task_metrics = nn.ModuleDict({
prefix + key: value for key, value in multitask_copy.task_metrics.items()
})
if postfix is not None:
postfix = self._check_arg(postfix, "postfix")
multitask_copy.task_metrics = nn.ModuleDict({
key + postfix: value for key, value in multitask_copy.task_metrics.items()
})
multitask_copy._prefix = self._check_arg(prefix, "prefix") or ""
multitask_copy._postfix = self._check_arg(postfix, "prefix") or ""
return multitask_copy

def plot(
Expand Down
28 changes: 19 additions & 9 deletions tests/unittests/wrappers/test_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,14 +248,24 @@ def test_key_value_items_method(method, flatten):

def test_clone_with_prefix_and_postfix():
"""Check that the clone method works with prefix and postfix arguments."""
multitask_metrics = MultitaskWrapper({"Classification": BinaryAccuracy(), "Regression": MeanSquaredError()})
cloned_metrics_with_prefix = multitask_metrics.clone(prefix="prefix_")
cloned_metrics_with_postfix = multitask_metrics.clone(postfix="_postfix")
multitask_metrics = MultitaskWrapper(
{"Classification": BinaryAccuracy(), "Regression": MeanSquaredError()},
prefix="prefix_",
postfix="_postfix",
)
assert set(multitask_metrics.keys()) == {"prefix_Classification_postfix", "prefix_Regression_postfix"}

# Check if the cloned metrics have the expected keys
assert set(cloned_metrics_with_prefix.task_metrics.keys()) == {"prefix_Classification", "prefix_Regression"}
assert set(cloned_metrics_with_postfix.task_metrics.keys()) == {"Classification_postfix", "Regression_postfix"}
output = multitask_metrics(
{"Classification": _classification_preds, "Regression": _regression_preds},
{"Classification": _classification_target, "Regression": _regression_target},
)
assert set(output.keys()) == {"prefix_Classification_postfix", "prefix_Regression_postfix"}

# Check if the cloned metrics have the expected values
assert isinstance(cloned_metrics_with_prefix.task_metrics["prefix_Classification"], BinaryAccuracy)
assert isinstance(cloned_metrics_with_prefix.task_metrics["prefix_Regression"], MeanSquaredError)
cloned_metrics = multitask_metrics.clone(prefix="new_prefix_", postfix="_new_postfix")
assert set(cloned_metrics.keys()) == {"new_prefix_Classification_new_postfix", "new_prefix_Regression_new_postfix"}

output = cloned_metrics(
{"Classification": _classification_preds, "Regression": _regression_preds},
{"Classification": _classification_target, "Regression": _regression_target},
)
assert set(output.keys()) == {"new_prefix_Classification_new_postfix", "new_prefix_Regression_new_postfix"}

0 comments on commit eecc55b

Please sign in to comment.