Skip to content

Commit

Permalink
Pytorch training and inference for model with multiple output and mul…
Browse files Browse the repository at this point in the history
…tiple input (#1544)

* inference with table output

* add unit test

* multi and unit test

* remove duplicate ut

* clear caching data

* support multiple shape

* multi input ut

* release

*  remove empty line

* update so
  • Loading branch information
hhbyyh committed Aug 9, 2019
1 parent 78afc61 commit 7ce4af6
Show file tree
Hide file tree
Showing 9 changed files with 527 additions and 162 deletions.
151 changes: 151 additions & 0 deletions pyzoo/test/zoo/pipeline/api/test_torch_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,157 @@ def lossFunc(input, target):
assert np.allclose(torch_loss.tolist(), az_loss_output)
assert np.allclose(torch_grad, az_grad.tolist(), atol=1.e-5, rtol=1.e-3)

def test_model_inference_with_multiple_output(self):
    """Check that a TorchNet built from a two-output module returns two outputs.

    The module returns its input unchanged together with a ``Linear(2, 1)``
    projection of it, so a batch of two rows must yield outputs of shape
    ``(2, 2)`` and ``(2, 1)``.
    """
    class TwoOutputModel(nn.Module):
        """Toy module returning its input alongside a linear projection."""

        def __init__(self):
            super(TwoOutputModel, self).__init__()
            self.dense1 = nn.Linear(2, 1)

        def forward(self, x):
            x1 = self.dense1(x)
            return x, x1

    # Trace the instance created here rather than building a second,
    # throwaway model and leaving this one unused.
    torch_model = TwoOutputModel()
    az_net = TorchNet.from_pytorch(torch_model, [1, 2])

    az_input = np.array([[0.5, 1.], [-0.3, 1.2]])
    az_output = az_net.forward(az_input)
    assert len(az_output) == 2
    assert az_output[0].shape == (2, 2)
    assert az_output[1].shape == (2, 1)

def test_model_train_with_multiple_output(self):
    """Compare loss and parameter gradients of a two-output model.

    The reference values are computed with plain PyTorch; the same model and
    loss are then wrapped as TorchNet / TorchCriterion and must reproduce the
    loss value and the flattened weight-plus-bias gradient.
    """
    class TwoOutputModel(nn.Module):
        """Toy module returning its input alongside a linear projection."""

        def __init__(self):
            super(TwoOutputModel, self).__init__()
            self.dense1 = nn.Linear(2, 1)

        def forward(self, x):
            return x, self.dense1(x)

    def _ones_pair():
        # Fresh (2, 2) and (2, 1) tensors of ones, one per model output.
        return torch.ones(2, 2), torch.ones(2, 1)

    rows = [[0.5, 1.], [-0.3, 1.2]]
    x_torch = torch.tensor(rows)
    y_torch = _ones_pair()

    model = TwoOutputModel()
    criterion = nn.MSELoss()

    def lossFunc(input, label):
        # Weighted sum of the per-output MSE losses.
        return criterion(input[0], label[0]) + 0.4 * criterion(input[1], label[1])

    # Reference pass in plain PyTorch.
    reference_out = model.forward(x_torch)
    reference_loss = lossFunc(reference_out, y_torch)
    reference_loss.backward()
    dense = model.dense1
    reference_grad = dense.weight.grad.tolist()[0]
    reference_grad.extend(dense.bias.grad.tolist())

    # Same model and loss wrapped for Analytics Zoo.
    az_net = TorchNet.from_pytorch(model, [1, 2])
    az_criterion = TorchCriterion.from_pytorch(
        loss=lossFunc,
        sample_input=_ones_pair(),
        sample_label=_ones_pair())

    x_az = np.array(rows)
    y_az = [np.ones(shape) for shape in ([2, 2], [2, 1])]
    out_az = az_net.forward(x_az)
    loss_az = az_criterion.forward(out_az, y_az)
    grad_out_az = az_criterion.backward(out_az, y_az)
    # NOTE(review): presumably this backward pass is what fills in the
    # 'gradWeight' entry read below -- confirm against TorchNet.
    az_net.backward(x_az, grad_out_az)

    layer_params = list(az_net.parameters().values())
    grad_az = layer_params[0]['gradWeight']

    assert np.allclose(reference_loss.tolist(), loss_az)
    assert np.allclose(reference_grad, grad_az.tolist())

def test_torchnet_constructor(self):
    """Build a TorchNet for a two-input module from samples and from shapes.

    There are no assertions; the test passes when neither construction raises.
    """
    class TwoInputModel(nn.Module):
        """Toy module applying an independent linear layer to each input."""

        def __init__(self):
            super(TwoInputModel, self).__init__()
            self.dense1 = nn.Linear(2, 2)
            self.dense2 = nn.Linear(3, 1)

        def forward(self, x1, x2):
            return self.dense1(x1), self.dense2(x2)

    # Construction from explicit sample tensors.
    sample_tensors = (torch.ones(2, 2), torch.ones(2, 3))
    net = TorchNet.from_pytorch(TwoInputModel(), sample_input=sample_tensors)

    # Construction from a tuple of per-input shapes.
    per_input_shapes = ([2, 2], [2, 3])
    net = TorchNet.from_pytorch(TwoInputModel(), per_input_shapes)

def test_torchcriterion_constructor(self):
    """Build a TorchCriterion for a two-part loss from samples and from shapes.

    There are no assertions; the test passes when neither construction raises.
    """
    criterion = nn.MSELoss()

    def lossFunc(input, label):
        # Weighted sum of the per-output MSE losses.
        return criterion(input[0], label[0]) + 0.4 * criterion(input[1], label[1])

    # Construction from explicit sample tensors.
    input_samples = (torch.ones(2, 2), torch.ones(2, 3))
    label_samples = (torch.ones(2, 2), torch.ones(2, 3))
    built = TorchCriterion.from_pytorch(
        lossFunc,
        sample_input=input_samples,
        sample_label=label_samples)

    # Construction from tuples of per-tensor shapes.
    input_shapes = ([2, 2], [2, 3])
    label_shapes = ([2, 2], [2, 3])
    built = TorchCriterion.from_pytorch(lossFunc, input_shapes, label_shapes)

def test_model_train_with_multiple_input(self):
    """Compare loss, parameter and input gradients of a two-input model.

    The reference values are computed with plain PyTorch; the same model and
    loss are then wrapped as TorchNet / TorchCriterion and must reproduce the
    loss, the flattened parameter gradient and the gradient of each input.
    """
    class TwoInputModel(nn.Module):
        """Toy module applying an independent linear layer to each input."""

        def __init__(self):
            super(TwoInputModel, self).__init__()
            self.dense1 = nn.Linear(2, 2)
            self.dense2 = nn.Linear(2, 1)

        def forward(self, x1, x2):
            return self.dense1(x1), self.dense2(x2)

    def _ones_pair():
        # Fresh (2, 2) and (2, 1) tensors of ones, one per model output.
        return torch.ones(2, 2), torch.ones(2, 1)

    rows = [[0.5, 1.], [-0.3, 1.2]]
    # Two separate leaf tensors so each input accumulates its own gradient.
    leaf_a = torch.tensor(rows, requires_grad=True)
    leaf_b = torch.tensor(rows, requires_grad=True)
    y_torch = _ones_pair()

    model = TwoInputModel()
    criterion = nn.MSELoss()

    def lossFunc(input, label):
        # Weighted sum of the per-output MSE losses.
        return criterion(input[0], label[0]) + 0.4 * criterion(input[1], label[1])

    # Reference pass in plain PyTorch.
    reference_out = model.forward(leaf_a, leaf_b)
    reference_loss = lossFunc(reference_out, y_torch)
    reference_loss.backward()

    # Flatten gradients in the order: dense1 weight rows, dense1 bias,
    # dense2 weight row, dense2 bias.
    dense1_weight_grad = model.dense1.weight.grad.tolist()
    grad_pieces = (
        dense1_weight_grad[0],
        dense1_weight_grad[1],
        model.dense1.bias.grad.tolist(),
        model.dense2.weight.grad.tolist()[0],
        model.dense2.bias.grad.tolist(),
    )
    reference_grad = []
    for piece in grad_pieces:
        reference_grad.extend(piece)

    # Same model and loss wrapped for Analytics Zoo.
    az_net = TorchNet.from_pytorch(
        model, sample_input=(torch.ones(2, 2), torch.ones(2, 2)))
    az_criterion = TorchCriterion.from_pytorch(
        loss=lossFunc,
        sample_input=_ones_pair(),
        sample_label=_ones_pair())

    x_az = [np.array(rows) for _ in range(2)]
    y_az = [np.ones(shape) for shape in ([2, 2], [2, 1])]
    out_az = az_net.forward(x_az)
    loss_az = az_criterion.forward(out_az, y_az)
    grad_out_az = az_criterion.backward(out_az, y_az)
    grad_in_az = az_net.backward(x_az, grad_out_az)

    layer_params = list(az_net.parameters().values())
    grad_az = layer_params[0]['gradWeight']

    assert np.allclose(reference_loss.tolist(), loss_az)
    assert np.allclose(reference_grad, grad_az.tolist())
    # Gradients with respect to each input must match as well.
    for position, leaf in enumerate((leaf_a, leaf_b)):
        assert np.allclose(grad_in_az[position], leaf.grad)


# Allow running this test module directly as a script: hand this file to
# pytest so that only the tests defined here are collected and run.
if __name__ == "__main__":
pytest.main([__file__])
22 changes: 18 additions & 4 deletions pyzoo/zoo/pipeline/api/net/torch_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,16 @@
#
import torch
import torch.nn as nn
import sys
import os
import tempfile
import shutil
from bigdl.nn.criterion import Criterion
from .torch_net import TorchNet

# Python 3 has no ``long`` or ``unicode`` builtins; alias them to their
# Python 3 equivalents. The numeric major version is compared instead of the
# ``sys.version`` string, because lexicographic string ordering would
# misjudge a two-digit major version (e.g. '10...' sorts before '3').
if sys.version_info[0] >= 3:
    long = int
    unicode = str


class LossWrapperModule(nn.Module):
Expand All @@ -44,9 +50,13 @@ def __init__(self, path, bigdl_type="float"):
super(TorchCriterion, self).__init__(None, bigdl_type, path)

@staticmethod
def from_pytorch(loss, input_shape, label_shape=None, sample_input=None, sample_label=None):
def from_pytorch(loss, input_shape=None, label_shape=None,
sample_input=None, sample_label=None):
"""
Create a TorchCriterion directly from PyTorch function
Create a TorchCriterion directly from a PyTorch function. A sample input and label are
needed to trace the loss function. Users may specify just the input and label shapes.
For a specific data type or for multiple-input models, users can pass sample_input and
sample_label instead.
:param loss: this can be a torch loss (e.g. nn.MSELoss()) or
a function that take two Tensor parameters: input and label. E.g.
def lossFunc(input, target):
Expand All @@ -57,14 +67,18 @@ def lossFunc(input, target):
:param sample_input: a sample of input.
:param sample_label: a sample of label.
"""
if not input_shape and not label_shape and not sample_input and not sample_label:
raise Exception("please specify input_shape and label_shape, or sample_input"
" and sample_label")

temp = tempfile.mkdtemp()

# use input_shape as label shape when label_shape is not specified
if not label_shape:
label_shape = input_shape

sample_input = sample_input if sample_input else torch.rand(input_shape)
sample_label = sample_label if sample_label else torch.rand(label_shape)
sample_input = TorchNet.get_sample_input(input_shape, sample_input)
sample_label = TorchNet.get_sample_input(label_shape, sample_label)

traced_script_loss = torch.jit.trace(LossWrapperModule(loss),
(sample_input, sample_label))
Expand Down
31 changes: 27 additions & 4 deletions pyzoo/zoo/pipeline/api/net/torch_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import tempfile
import shutil
import numpy as np
import sys

from pyspark import RDD
from bigdl.nn.layer import Layer
Expand All @@ -26,6 +27,10 @@
from bigdl.util.common import callBigDlFunc
from zoo.pipeline.api.net.tfnet import to_sample_rdd

# Python 3 has no ``long`` or ``unicode`` builtins; alias them to their
# Python 3 equivalents. The numeric major version is compared instead of the
# ``sys.version`` string, because lexicographic string ordering would
# misjudge a two-digit major version (e.g. '10...' sorts before '3').
if sys.version_info[0] >= 3:
    long = int
    unicode = str


class TorchNet(Layer):
"""
Expand All @@ -38,16 +43,23 @@ def __init__(self, path, bigdl_type="float"):
super(TorchNet, self).__init__(None, bigdl_type, path)

@staticmethod
def from_pytorch(module, input_shape):
def from_pytorch(module, input_shape=None, sample_input=None):
"""
Create a TorchNet directly from PyTorch model, e.g. model in torchvision.models
Create a TorchNet directly from PyTorch model, e.g. model in torchvision.models.
Users need to specify sample_input or input_shape.
:param module: a PyTorch model
:param input_shape: list of integers. E.g. for ResNet, this may be [1, 3, 224, 224]
:param input_shape: list of integers, or a tuple of lists for multiple-input models. E.g.
for ResNet, this may be [1, 3, 224, 224]
:param sample_input: a sample Torch Tensor, or a tuple of Tensors, used to trace the model.
"""
if not input_shape and not sample_input:
raise Exception("please specify input_shape or sample_input")

sample = TorchNet.get_sample_input(input_shape, sample_input)
temp = tempfile.mkdtemp()

# save model
traced_script_module = torch.jit.trace(module, torch.rand(input_shape))
traced_script_module = torch.jit.trace(module, sample)
path = os.path.join(temp, "model.pt")
traced_script_module.save(path)

Expand All @@ -56,6 +68,17 @@ def from_pytorch(module, input_shape):

return net

@staticmethod
def get_sample_input(shape, sample):
    """
    Pick the sample used for tracing: the given sample, or random data of the given shape.

    :param shape: list of ints, or a tuple of such lists (one per input). Only used
                  when no sample is supplied.
    :param sample: a Torch Tensor or a tuple of Tensors, or None.
    :return: ``sample`` itself when one is supplied; otherwise ``torch.rand(shape)`` for
             a list, or a tuple of ``torch.rand(s)`` for each entry of a tuple.
    :raises Exception: if no sample is supplied and ``shape`` is neither a list nor a tuple.
    """
    # A Tensor always counts as a supplied sample. Evaluating its truth value
    # directly would raise for tensors with more than one element and would
    # discard a single zero-valued element; the short-circuit avoids both.
    if isinstance(sample, torch.Tensor) or sample:
        return sample
    if isinstance(shape, list):
        return torch.rand(shape)
    if isinstance(shape, tuple):
        return tuple(torch.rand(s) for s in shape)
    raise Exception("please specify shape as list of ints or tuples of int lists")

def predict(self, x, batch_per_thread=1, distributed=True):
"""
Use a model to do prediction.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,16 +92,17 @@ private static File extract(String path) {

static native long loadLossNative(String lossPath);

static native JTensor modelForwardNative(
long nativeRef, boolean isTraining, float[] storage, int offset, int[] shape);
static native JTensor[] modelForwardNative(
long nativeRef, boolean isTraining, float[][] storage, int[] offset, int[][] shape);

static native JTensor modelBackwardNative(long nativeRef, float[] storage, int offset, int[] shape);
static native JTensor[] modelBackwardNative(
long nativeRef, float[][] storage, int[] offset, int[][] shape);

static native JTensor lossForwardNative(
long nativeRef, float[] input_storage, int input_offset, int[] input_shape,
float[] label_storage, int label_offset, int[] label_shape);
long nativeRef, float[][] input_storage, int[] input_offset, int[][] input_shape,
float[][] label_storage, int[] label_offset, int[][] label_shape);

static native JTensor lossBackwardNative(long nativeRef);
static native JTensor[] lossBackwardNative(long nativeRef);

static native float[] getGradientNative(long nativeRef);

Expand Down
Loading

0 comments on commit 7ce4af6

Please sign in to comment.