
Commit

PyTorch training and inference for model with multiple output and multiple input (intel-analytics#1544)

* inference with table output

* add unit test

* multi and unit test

* remove duplicate ut

* clear caching data

* support multiple shape

* multi input ut

* release

* remove empty line

* update so
hhbyyh committed Aug 9, 2019
1 parent f825700 commit 00b11de
Showing 1 changed file with 151 additions and 0 deletions.
151 changes: 151 additions & 0 deletions test_torch_net.py
@@ -218,6 +218,157 @@ def lossFunc(input, target):
        assert np.allclose(torch_loss.tolist(), az_loss_output)
        assert np.allclose(torch_grad, az_grad.tolist(), atol=1.e-5, rtol=1.e-3)

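    # Forward-only check: a PyTorch module that returns a tuple should come back
    # from TorchNet.forward as a table (list) of outputs.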
    def test_model_inference_with_multiple_output(self):
        class TwoOutputModel(nn.Module):
            def __init__(self):
                super(TwoOutputModel, self).__init__()
                self.dense1 = nn.Linear(2, 1)

            def forward(self, x):
                x1 = self.dense1(x)
                return x, x1

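        # Wrap the module for inference; [1, 2] gives the expected input shape
        # (the shape-based constructor form also appears in test_torchnet_constructor).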
        torch_model = TwoOutputModel()
        az_net = TorchNet.from_pytorch(torch_model, [1, 2])

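        # A batch of two samples should produce two outputs with matching shapes.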
        az_input = np.array([[0.5, 1.], [-0.3, 1.2]])
        az_output = az_net.forward(az_input)
        assert (len(az_output) == 2)
        assert (az_output[0].shape == (2, 2))
        assert (az_output[1].shape == (2, 1))

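    # Training path for a tuple output: the loss and dense1's gradients computed in
    # plain PyTorch should match the values obtained through TorchNet/TorchCriterion.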
    def test_model_train_with_multiple_output(self):
        class TwoOutputModel(nn.Module):
            def __init__(self):
                super(TwoOutputModel, self).__init__()
                self.dense1 = nn.Linear(2, 1)

            def forward(self, x):
                x1 = self.dense1(x)
                return x, x1

        input = [[0.5, 1.], [-0.3, 1.2]]
        torch_input = torch.tensor(input)
        torch_label = (torch.ones(2, 2), torch.ones(2, 1))

        model = TwoOutputModel()
        criterion = nn.MSELoss()

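        # Combine the per-output MSE losses, weighting the second output by 0.4.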
        def lossFunc(input, label):
            loss1 = criterion(input[0], label[0])
            loss2 = criterion(input[1], label[1])
            loss = loss1 + 0.4 * loss2
            return loss

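        # Reference pass in plain PyTorch; collect dense1's weight and bias gradients.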
        torch_output = model.forward(torch_input)
        torch_loss = lossFunc(torch_output, torch_label)
        torch_loss.backward()
        torch_grad = model.dense1.weight.grad.tolist()[0] + model.dense1.bias.grad.tolist()

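        # Build the zoo-side net and criterion; the criterion is traced from sample
        # tensors matching the (output, label) structure used by lossFunc.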
        az_net = TorchNet.from_pytorch(model, [1, 2])
        az_criterion = TorchCriterion.from_pytorch(
            loss=lossFunc,
            sample_input=(torch.ones(2, 2), torch.ones(2, 1)),
            sample_label=(torch.ones(2, 2), torch.ones(2, 1)))

        az_input = np.array(input)
        az_label = [np.ones([2, 2]), np.ones([2, 1])]
        az_output = az_net.forward(az_input)
        az_loss_output = az_criterion.forward(az_output, az_label)
        az_loss_backward = az_criterion.backward(az_output, az_label)
        az_model_backward = az_net.backward(az_input, az_loss_backward)

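        # 'gradWeight' appears to hold the wrapped module's parameter gradients, flattened.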
        az_grad = list(az_net.parameters().values())[0]['gradWeight']

        assert np.allclose(torch_loss.tolist(), az_loss_output)
        assert np.allclose(torch_grad, az_grad.tolist())

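    # TorchNet.from_pytorch accepts either sample input tensors or plain shape lists
    # for a module with multiple inputs.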
    def test_torchnet_constructor(self):
        class TwoInputModel(nn.Module):
            def __init__(self):
                super(TwoInputModel, self).__init__()
                self.dense1 = nn.Linear(2, 2)
                self.dense2 = nn.Linear(3, 1)

            def forward(self, x1, x2):
                x1 = self.dense1(x1)
                x2 = self.dense2(x2)
                return x1, x2

        az_net = TorchNet.from_pytorch(
            TwoInputModel(), sample_input=(torch.ones(2, 2), torch.ones(2, 3)))
        az_net = TorchNet.from_pytorch(TwoInputModel(), ([2, 2], [2, 3]))

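    # TorchCriterion.from_pytorch likewise accepts sample tensors or shape lists
    # for both input and label.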
    def test_torchcriterion_constructor(self):
        criterion = nn.MSELoss()

        def lossFunc(input, label):
            loss1 = criterion(input[0], label[0])
            loss2 = criterion(input[1], label[1])
            loss = loss1 + 0.4 * loss2
            return loss

        az_criterion = TorchCriterion.from_pytorch(
            lossFunc,
            sample_input=(torch.ones(2, 2), torch.ones(2, 3)),
            sample_label=(torch.ones(2, 2), torch.ones(2, 3)))
        az_criterion = TorchCriterion.from_pytorch(lossFunc, ([2, 2], [2, 3]), ([2, 2], [2, 3]))

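    # End-to-end training check for a two-input, two-output module, including the
    # gradients with respect to both inputs.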
    def test_model_train_with_multiple_input(self):
        class TwoInputModel(nn.Module):
            def __init__(self):
                super(TwoInputModel, self).__init__()
                self.dense1 = nn.Linear(2, 2)
                self.dense2 = nn.Linear(2, 1)

            def forward(self, x1, x2):
                x1 = self.dense1(x1)
                x2 = self.dense2(x2)
                return x1, x2

        input = [[0.5, 1.], [-0.3, 1.2]]
        torch_input1 = torch.tensor(input, requires_grad=True)
        torch_input2 = torch.tensor(input, requires_grad=True)
        torch_label = (torch.ones(2, 2), torch.ones(2, 1))

        model = TwoInputModel()
        criterion = nn.MSELoss()

        def lossFunc(input, label):
            loss1 = criterion(input[0], label[0])
            loss2 = criterion(input[1], label[1])
            loss = loss1 + 0.4 * loss2
            return loss

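        # Reference pass in plain PyTorch; flatten both layers' gradients into one list.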
        torch_output = model.forward(torch_input1, torch_input2)
        torch_loss = lossFunc(torch_output, torch_label)
        torch_loss.backward()
        torch_grad = model.dense1.weight.grad.tolist()[0] + \
            model.dense1.weight.grad.tolist()[1] + \
            model.dense1.bias.grad.tolist() + \
            model.dense2.weight.grad.tolist()[0] + \
            model.dense2.bias.grad.tolist()

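        # Sample tensors define the two input shapes on the net and the
        # (output, label) shapes on the criterion.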
        az_net = TorchNet.from_pytorch(model, sample_input=(torch.ones(2, 2), torch.ones(2, 2)))
        az_criterion = TorchCriterion.from_pytorch(
            loss=lossFunc,
            sample_input=(torch.ones(2, 2), torch.ones(2, 1)),
            sample_label=(torch.ones(2, 2), torch.ones(2, 1)))

        az_input = [np.array(input), np.array(input)]
        az_label = [np.ones([2, 2]), np.ones([2, 1])]
        az_output = az_net.forward(az_input)
        az_loss_output = az_criterion.forward(az_output, az_label)
        az_loss_backward = az_criterion.backward(az_output, az_label)
        az_model_backward = az_net.backward(az_input, az_loss_backward)

        az_grad = list(az_net.parameters().values())[0]['gradWeight']

        assert np.allclose(torch_loss.tolist(), az_loss_output)
        assert np.allclose(torch_grad, az_grad.tolist())
        assert np.allclose(az_model_backward[0], torch_input1.grad)
        assert np.allclose(az_model_backward[1], torch_input2.grad)


if __name__ == "__main__":
    pytest.main([__file__])
