Skip to content

Commit

Permalink
fix ndim
Browse files Browse the repository at this point in the history
  • Loading branch information
hkvision committed Nov 2, 2020
1 parent 53cd78a commit a9618ea
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions pyzoo/zoo/orca/learn/pytorch/training_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,13 +350,13 @@ def validate_batch(self, batch, batch_info):
# unpack features into list to support multiple inputs model
*features, target = batch
_target = target
if _target.ndim > 1:
if len(_target.size()) > 1:
# Can't directly call torch.squeeze() in case batch size is 1.
for i in reversed(range(1, _target.ndim)):
for i in reversed(range(1, len(_target.size()))):
_target = torch.squeeze(_target, i)
error_msg = "Currently in validate, only accuracy for classification with zero-based label is supported " \
"by default. You can override validate_batch in TrainingOperator for other validation metrics"
assert _target.ndim == 1, error_msg
assert len(_target.size()) == 1, error_msg

if self.use_gpu:
features = [
Expand All @@ -369,19 +369,20 @@ def validate_batch(self, batch, batch_info):
output = self.model(*features)
loss = self.criterion(output, target)
_output = output
if _output.ndim > 2:
if len(_output.size()) > 2:
# In case there is extra trailing dimensions.
for i in reversed(range(1, _output.ndim)):
for i in reversed(range(1, len(_output.size()))):
_output = torch.squeeze(_output, i)
assert _output.ndim == 1 or _output.ndim == 2, error_msg

np_output = _output.detach().numpy()
np_target = _target.detach().numpy()
import numpy as np
if len(np_output.shape) == 1: # Binary classification
np_output = np.round(np_output, 0)
else: # Multi-class classification
elif len(np_output.shape) == 2: # Multi-class classification
np_output = np.argmax(np_output, axis=1)
else:
raise Exception(error_msg)

num_correct = np.sum((np_output == np_target).astype(np.uint8))
num_samples = target.size(0)
Expand Down

0 comments on commit a9618ea

Please sign in to comment.