Fix validate in Orca PyTorch Estimator (#3012)
* fix validate

* rename

* fix

* fix ut

* update

* minor

* fix ut
hkvision authored and yangw1234 committed Sep 27, 2021
1 parent 43cf9ff commit a2e057e
Showing 3 changed files with 68 additions and 19 deletions.
python/orca/src/bigdl/orca/learn/pytorch/training_operator.py (34 changes: 29 additions & 5 deletions)
@@ -326,7 +326,7 @@ def validate(self, val_iterator, info):
        return metric_meters.summary()

    def validate_batch(self, batch, batch_info):
        """Calcuates the loss and accuracy over a given batch.
        """Calculates the loss and accuracy over a given batch.
        You can override this method to provide arbitrary metrics.
@@ -349,20 +349,44 @@ def validate_batch(self, batch, batch_info):
"""
# unpack features into list to support multiple inputs model
*features, target = batch
if len(target.size()) > 1:
# Can't directly call torch.squeeze() in case batch size is 1.
for i in reversed(range(1, len(target.size()))):
target = torch.squeeze(target, i)

if self.use_gpu:
features = [
feature.cuda(non_blocking=True) for feature in features
]
target = target.cuda(non_blocking=True)

# compute output

with self.timers.record("eval_fwd"):
output = self.model(*features)
loss = self.criterion(output, target)
_, predicted = torch.max(output.data, 1)

num_correct = (predicted == target).sum().item()
if len(output.size()) > 1:
# In case there is extra trailing dimensions.
for i in reversed(range(1, len(output.size()))):
output = torch.squeeze(output, i)

np_output = output.detach().numpy()
np_target = target.detach().numpy()
# validate will be called by TCMF to get val_loss for regression tasks.
# In this case, accuracy is calculated but not used and the result is wrong.
# So do not directly raise an Exception here to avoid errors in TCMF.
# TODO: Support other validation metrics.
if len(np_target.shape) != 1 or len(np_output.shape) > 2:
import warnings
warnings.warn("Currently in validate, only accuracy for classification with "
"zero-based label is supported by default. You can override "
"validate_batch in TrainingOperator for other validation metrics")
import numpy as np
if len(np_output.shape) == 1: # Binary classification
np_output = np.round(np_output, 0)
else: # Multi-class classification
np_output = np.argmax(np_output, axis=1)

num_correct = np.sum(np_output == np_target)
num_samples = target.size(0)
return {
"val_loss": loss.item(),
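The new accuracy path inspects the output shape with NumPy, so single-output (sigmoid) binary classifiers and multi-class classifiers both get a sensible val_accuracy instead of an argmax over a one-column output. A minimal sketch of that metric logic in isolation (batch_accuracy is an illustrative helper, not part of the patch):

import numpy as np

def batch_accuracy(np_output, np_target):
    # Single-column output: binary classification, threshold at 0.5 via rounding.
    if len(np_output.shape) == 1:
        predictions = np.round(np_output, 0)
    # Otherwise: multi-class classification, pick the highest-scoring class.
    else:
        predictions = np.argmax(np_output, axis=1)
    num_correct = np.sum(predictions == np_target)
    return num_correct / len(np_target)

# For example, batch_accuracy(np.array([0.2, 0.8, 0.6]), np.array([0., 1., 0.]))
# rounds the outputs to [0., 1., 1.] and returns 2/3.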
Second changed file (PyTorch Estimator tests):
@@ -27,7 +27,7 @@
from tempfile import TemporaryDirectory


class TestPyTorchTrainer(TestCase):
class TestPyTorchEstimator(TestCase):
    def test_train(self):
        estimator = Estimator.from_torch(
            model=model_creator,
Third changed file (PyTorch Estimator tests):
@@ -28,10 +28,13 @@
class LinearDataset(torch.utils.data.Dataset):
    """y = a * x + b"""

    def __init__(self, a, b, size=1000):
        x = np.arange(0, 10, 10 / size, dtype=np.float32)
        self.x = torch.from_numpy(x)
        self.y = torch.from_numpy(a * x + b)
    def __init__(self, size=1000):
        X1 = torch.randn(size // 2, 50)
        X2 = torch.randn(size // 2, 50) + 1.5
        self.x = torch.cat([X1, X2], dim=0)
        Y1 = torch.zeros(size // 2, 1)
        Y2 = torch.ones(size // 2, 1)
        self.y = torch.cat([Y1, Y2], dim=0)

    def __getitem__(self, index):
        return self.x[index, None], self.y[index, None]
@@ -40,8 +43,30 @@ def __len__(self):
        return len(self.x)


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(50, 50)
        self.relu1 = nn.ReLU()
        self.dout = nn.Dropout(0.2)
        self.fc2 = nn.Linear(50, 100)
        self.prelu = nn.PReLU(1)
        self.out = nn.Linear(100, 1)
        self.out_act = nn.Sigmoid()

    def forward(self, input_):
        a1 = self.fc1(input_)
        h1 = self.relu1(a1)
        dout = self.dout(h1)
        a2 = self.fc2(dout)
        h2 = self.prelu(a2)
        a3 = self.out(h2)
        y = self.out_act(a3)
        return y


def train_data_loader(config):
    train_dataset = LinearDataset(2, 5, size=config.get("data_size", 1000))
    train_dataset = LinearDataset(size=config.get("data_size", 1000))
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.get("batch_size", 32),
@@ -50,34 +75,34 @@ def train_data_loader(config):


def val_data_loader(config):
    val_dataset = LinearDataset(2, 5, size=config.get("val_size", 400))
    val_dataset = LinearDataset(size=config.get("val_size", 400))
    validation_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=config.get("batch_size", 32))
    return validation_loader


def get_model(config):
    return nn.Linear(1, config.get("hidden_size", 1))
    return Net()


def get_optimizer(model, config):
    return torch.optim.SGD(model.parameters(), lr=config.get("lr", 1e-2))


class TestPyTorchTrainer(TestCase):
class TestPyTorchEstimator(TestCase):
    def test_linear(self):
        estimator = Estimator.from_torch(model=get_model,
                                         optimizer=get_optimizer,
                                         loss=nn.MSELoss(),
                                         config={"lr": 1e-2, "hidden_size": 1,
                                         loss=nn.BCELoss(),
                                         config={"lr": 1e-2,
                                                 "batch_size": 128},
                                         backend="pytorch")
        train_stats = estimator.fit(train_data_loader, epochs=2)
        print(train_stats)
        # it seems validate on regression model is not supported
        # val_stats = estimator.evaluate(val_data_loader)
        # print(val_stats)
        val_stats = estimator.evaluate(val_data_loader)
        print(val_stats)
        assert 0 < val_stats["val_accuracy"] < 1
        assert estimator.get_model()
        estimator.shutdown()

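The warning added in validate_batch points users at overriding it for metrics other than zero-based classification accuracy. A hedged sketch of such an override for a regression workload follows; the RegressionOperator name and the MAE metric are illustrative only, the import path simply mirrors the file changed above, and the sketch skips the GPU transfer and timer bookkeeping that the real operator performs:

import numpy as np

from bigdl.orca.learn.pytorch.training_operator import TrainingOperator


class RegressionOperator(TrainingOperator):
    # Hypothetical override: report mean absolute error instead of accuracy.
    def validate_batch(self, batch, batch_info):
        # Unpack the batch the same way the default operator does.
        *features, target = batch
        output = self.model(*features)
        loss = self.criterion(output, target)
        # Mean absolute error over the batch, computed on detached NumPy arrays.
        mae = np.mean(np.abs(output.detach().numpy() - target.detach().numpy()))
        return {"val_loss": loss.item(),
                "val_mae": float(mae),
                "num_samples": target.size(0)}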
