diff --git a/examples/onnx/arcface.py b/examples/onnx/arcface.py index e1cfa180e..a89e51601 100644 --- a/examples/onnx/arcface.py +++ b/examples/onnx/arcface.py @@ -24,10 +24,9 @@ from singa import device from singa import tensor -from singa import autograd from singa import sonnx import onnx -from utils import download_model, update_batch_size, check_exist_or_download +from utils import download_model, check_exist_or_download import logging logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s') @@ -54,18 +53,17 @@ def get_image(): return img1, img2 -class Infer: +class MyModel(sonnx.SONNXModel): - def __init__(self, sg_ir): - self.sg_ir = sg_ir - for idx, tens in sg_ir.tensor_map.items(): - # allow the tensors to be updated - tens.requires_grad = True - tens.stores_grad = True - sg_ir.tensor_map[idx] = tens + def __init__(self, onnx_model): + super(MyModel, self).__init__(onnx_model) - def forward(self, x): - return sg_ir.run([x])[0] + def forward(self, *x): + y = super(MyModel, self).forward(*x) + return y + + def train_one_batch(self, x, y): + pass if __name__ == "__main__": @@ -78,35 +76,30 @@ def forward(self, x): download_model(url) onnx_model = onnx.load(model_path) - # set batch size - onnx_model = update_batch_size(onnx_model, 2) + # inference demo + logging.info("preprocessing...") + img1, img2 = get_image() + img1 = preprocess(img1) + img2 = preprocess(img2) + # sg_ir = sonnx.prepare(onnx_model) # run without graph + # y = sg_ir.run([img1, img2]) - # prepare the model - logging.info("prepare model...") + logging.info("model compling...") dev = device.create_cuda_gpu() - sg_ir = sonnx.prepare(onnx_model, device=dev) - autograd.training = False - model = Infer(sg_ir) + x = tensor.Tensor(device=dev, data=np.concatenate((img1, img2), axis=0)) + m = MyModel(onnx_model) + m.compile([x], is_train=False, use_graph=True, sequential=True) - # verifty the test dataset + # verifty the test # from utils import load_dataset - # inputs, ref_outputs = load_dataset( - # os.path.join('/tmp', 'resnet100', 'test_data_set_0')) + # inputs, ref_outputs = load_dataset(os.path.join('/tmp', 'resnet100', 'test_data_set_0')) # x_batch = tensor.Tensor(device=dev, data=inputs[0]) - # outputs = model.forward(x_batch) + # outputs = sg_ir.run([x_batch]) # for ref_o, o in zip(ref_outputs, outputs): # np.testing.assert_almost_equal(ref_o, tensor.to_numpy(o), 4) - # inference demo - logging.info("preprocessing...") - img1, img2 = get_image() - img1 = preprocess(img1) - img2 = preprocess(img2) - - x_batch = tensor.Tensor(device=dev, - data=np.concatenate((img1, img2), axis=0)) logging.info("model running...") - y = model.forward(x_batch) + y = m.forward(*[x])[0] logging.info("postprocessing...") embedding = tensor.to_numpy(y) @@ -120,4 +113,4 @@ def forward(self, x): sim = np.dot(embedding1, embedding2.T) # logging.info predictions logging.info('Distance = %f' % (dist)) - logging.info('Similarity = %f' % (sim)) + logging.info('Similarity = %f' % (sim)) \ No newline at end of file diff --git a/examples/onnx/bert/bert-squad.py b/examples/onnx/bert/bert-squad.py index e4a848801..936968ead 100644 --- a/examples/onnx/bert/bert-squad.py +++ b/examples/onnx/bert/bert-squad.py @@ -24,14 +24,13 @@ from singa import device from singa import tensor from singa import sonnx -from singa import autograd import onnx import tokenization from run_onnx_squad import read_squad_examples, convert_examples_to_features, RawResult, write_predictions import sys sys.path.append(os.path.dirname(__file__) + '/..') -from utils import download_model, update_batch_size, check_exist_or_download +from utils import download_model, check_exist_or_download import logging logging.basicConfig(level=logging.INFO, format='%(asctime)-15s %(message)s') @@ -54,15 +53,6 @@ def load_vocab(): return filename -class Infer: - - def __init__(self, sg_ir): - self.sg_ir = sg_ir - - def forward(self, x): - return sg_ir.run(x) - - def preprocess(): vocab_file = load_vocab() tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file, @@ -96,6 +86,19 @@ def postprocess(eval_examples, extra_data, all_results): print("The result is:", json.dumps(test_data, indent=2)) +class MyModel(sonnx.SONNXModel): + + def __init__(self, onnx_model): + super(MyModel, self).__init__(onnx_model) + + def forward(self, *x): + y = super(MyModel, self).forward(*x) + return y + + def train_one_batch(self, x, y): + pass + + if __name__ == "__main__": url = 'https://media.githubusercontent.com/media/onnx/models/master/text/machine_comprehension/bert-squad/model/bertsquad-10.tar.gz' @@ -107,16 +110,12 @@ def postprocess(eval_examples, extra_data, all_results): download_model(url) onnx_model = onnx.load(model_path) - # set batch size - onnx_model = update_batch_size(onnx_model, batch_size) - dev = device.create_cuda_gpu() - autograd.training = False - # inference logging.info("preprocessing...") input_ids, input_mask, segment_ids, extra_data, eval_examples = preprocess() - sg_ir = None + m = None + dev = device.create_cuda_gpu() n = len(input_ids) bs = batch_size all_results = [] @@ -132,23 +131,20 @@ def postprocess(eval_examples, extra_data, all_results): input_ids[idx:idx + bs].astype(np.int32), ] - if sg_ir is None: - # prepare the model - logging.info("model is none, prepare model...") - sg_ir = sonnx.prepare(onnx_model, - device=dev, - init_inputs=inputs, - keep_initializers_as_inputs=False) - model = Infer(sg_ir) - x_batch = [] for inp in inputs: tmp_tensor = tensor.from_numpy(inp) tmp_tensor.to_device(dev) x_batch.append(tmp_tensor) + # prepare the model + if m is None: + logging.info("model compling...") + m = MyModel(onnx_model) + # m.compile(x_batch, is_train=False, use_graph=True, sequential=True) + logging.info("model running for sample {}...".format(idx)) - outputs = model.forward(x_batch) + outputs = m.forward(*x_batch) logging.info("hanlde the result of sample {}...".format(idx)) result = [] diff --git a/examples/onnx/bert/tokenization.py b/examples/onnx/bert/tokenization.py index 4dd0a3128..09b9b4fd7 100644 --- a/examples/onnx/bert/tokenization.py +++ b/examples/onnx/bert/tokenization.py @@ -86,8 +86,6 @@ def convert_to_unicode(text): elif six.PY2: if isinstance(text, str): return text.decode("utf-8", "ignore") - elif isinstance(text, unicode): - return text else: raise ValueError("Unsupported string type: %s" % (type(text))) else: @@ -109,8 +107,6 @@ def printable_text(text): elif six.PY2: if isinstance(text, str): return text - elif isinstance(text, unicode): - return text.encode("utf-8") else: raise ValueError("Unsupported string type: %s" % (type(text))) else: diff --git a/examples/onnx/fer_emotion.py b/examples/onnx/fer_emotion.py index 46c0142c8..bbd762f3f 100644 --- a/examples/onnx/fer_emotion.py +++ b/examples/onnx/fer_emotion.py @@ -22,10 +22,9 @@ from singa import device from singa import tensor -from singa import autograd from singa import sonnx import onnx -from utils import download_model, update_batch_size, check_exist_or_download +from utils import download_model, check_exist_or_download import logging logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s') @@ -51,18 +50,17 @@ def get_image_labe(): return img, labels -class Infer: +class MyModel(sonnx.SONNXModel): - def __init__(self, sg_ir): - self.sg_ir = sg_ir - for idx, tens in sg_ir.tensor_map.items(): - # allow the tensors to be updated - tens.requires_grad = True - tens.stores_grad = True - sg_ir.tensor_map[idx] = tens + def __init__(self, onnx_model): + super(MyModel, self).__init__(onnx_model) - def forward(self, x): - return sg_ir.run([x])[0] + def forward(self, *x): + y = super(MyModel, self).forward(*x) + return y + + def train_one_batch(self, x, y): + pass if __name__ == "__main__": @@ -75,33 +73,30 @@ def forward(self, x): download_model(url) onnx_model = onnx.load(model_path) - # set batch size - onnx_model = update_batch_size(onnx_model, 1) + # inference + logging.info("preprocessing...") + img, labels = get_image_labe() + img = preprocess(img) + # sg_ir = sonnx.prepare(onnx_model) # run without graph + # y = sg_ir.run([img]) - # prepare the model - logging.info("prepare model...") + logging.info("model compling...") dev = device.create_cuda_gpu() - sg_ir = sonnx.prepare(onnx_model, device=dev) - autograd.training = False - model = Infer(sg_ir) + x = tensor.PlaceHolder(img.shape, device=dev) + m = MyModel(onnx_model) + m.compile([x], is_train=False, use_graph=True, sequential=True) # verifty the test # from utils import load_dataset # inputs, ref_outputs = load_dataset(os.path.join('/tmp', 'emotion_ferplus', 'test_data_set_0')) # x_batch = tensor.Tensor(device=dev, data=inputs[0]) - # outputs = model.forward(x_batch) + # outputs = sg_ir.run([x_batch]) # for ref_o, o in zip(ref_outputs, outputs): # np.testing.assert_almost_equal(ref_o, tensor.to_numpy(o), 4) - # inference - logging.info("preprocessing...") - img, labels = get_image_labe() - img = preprocess(img) - - x_batch = tensor.Tensor(device=dev, data=img) - logging.info("model running...") - y = model.forward(x_batch) + x = tensor.Tensor(device=dev, data=img) + y = m.forward(*[x])[0] logging.info("postprocessing...") y = tensor.softmax(y) diff --git a/examples/onnx/mnist.py b/examples/onnx/mnist.py deleted file mode 100644 index 80208df1e..000000000 --- a/examples/onnx/mnist.py +++ /dev/null @@ -1,321 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under th - -import os -import gzip -import numpy as np -import codecs - -from singa import device -from singa import tensor -from singa import opt -from singa import autograd -from singa import layer -from singa import sonnx -import onnx -from utils import check_exist_or_download - -import logging -logging.basicConfig(level=logging.INFO, format='%(asctime)-15s %(message)s') - - -def load_dataset(): - train_x_url = 'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz' - train_y_url = 'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz' - valid_x_url = 'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz' - valid_y_url = 'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz' - train_x = read_image_file(check_exist_or_download(train_x_url)).astype( - np.float32) - train_y = read_label_file(check_exist_or_download(train_y_url)).astype( - np.float32) - valid_x = read_image_file(check_exist_or_download(valid_x_url)).astype( - np.float32) - valid_y = read_label_file(check_exist_or_download(valid_y_url)).astype( - np.float32) - return train_x, train_y, valid_x, valid_y - - -def read_label_file(path): - with gzip.open(path, 'rb') as f: - data = f.read() - assert get_int(data[:4]) == 2049 - length = get_int(data[4:8]) - parsed = np.frombuffer(data, dtype=np.uint8, offset=8).reshape((length)) - return parsed - - -def get_int(b): - return int(codecs.encode(b, 'hex'), 16) - - -def read_image_file(path): - with gzip.open(path, 'rb') as f: - data = f.read() - assert get_int(data[:4]) == 2051 - length = get_int(data[4:8]) - num_rows = get_int(data[8:12]) - num_cols = get_int(data[12:16]) - parsed = np.frombuffer(data, dtype=np.uint8, offset=16).reshape( - (length, 1, num_rows, num_cols)) - return parsed - - -def to_categorical(y, num_classes): - y = np.array(y, dtype="int") - n = y.shape[0] - categorical = np.zeros((n, num_classes)) - categorical[np.arange(n), y] = 1 - categorical = categorical.astype(np.float32) - return categorical - - -class CNN: - - def __init__(self): - self.conv1 = layer.Conv2d(1, 20, 5, padding=0) - self.conv2 = layer.Conv2d(20, 50, 5, padding=0) - self.linear1 = layer.Linear(4 * 4 * 50, 500, bias=False) - self.linear2 = layer.Linear(500, 10, bias=False) - self.pooling1 = layer.MaxPool2d(2, 2, padding=0) - self.pooling2 = layer.MaxPool2d(2, 2, padding=0) - - def forward(self, x): - y = self.conv1(x) - y = autograd.relu(y) - y = self.pooling1(y) - y = self.conv2(y) - y = autograd.relu(y) - y = self.pooling2(y) - y = autograd.flatten(y) - y = self.linear1(y) - y = autograd.relu(y) - y = self.linear2(y) - return y - - -def accuracy(pred, target): - y = np.argmax(pred, axis=1) - t = np.argmax(target, axis=1) - a = y == t - return np.array(a, "int").sum() / float(len(t)) - - -def train(model, - x, - y, - epochs=1, - batch_size=64, - dev=device.get_default_device()): - batch_number = x.shape[0] // batch_size - - for i in range(epochs): - for b in range(batch_number): - l_idx = b * batch_size - r_idx = (b + 1) * batch_size - - x_batch = tensor.Tensor(device=dev, data=x[l_idx:r_idx]) - target_batch = tensor.Tensor(device=dev, data=y[l_idx:r_idx]) - - output_batch = model.forward(x_batch) - - loss = autograd.softmax_cross_entropy(output_batch, target_batch) - accuracy_rate = accuracy(tensor.to_numpy(output_batch), - tensor.to_numpy(target_batch)) - - sgd = opt.SGD(lr=0.001) - for p, gp in autograd.backward(loss): - sgd.update(p, gp) - sgd.step() - - if b % 1e2 == 0: - logging.info("acc %6.2f loss, %6.2f" % - (accuracy_rate, tensor.to_numpy(loss)[0])) - logging.info("training completed") - return x_batch, output_batch - - -def make_onnx(x, y): - return sonnx.to_onnx([x], [y]) - - -class Infer: - - def __init__(self, sg_ir): - self.sg_ir = sg_ir - for idx, tens in sg_ir.tensor_map.items(): - # allow the tensors to be updated - tens.requires_grad = True - tens.stores_grad = True - - def forward(self, x): - return sg_ir.run([x])[0] - - -def re_train(sg_ir, - x, - y, - epochs=1, - batch_size=64, - dev=device.get_default_device()): - batch_number = x.shape[0] // batch_size - - new_model = Infer(sg_ir) - - for i in range(epochs): - for b in range(batch_number): - l_idx = b * batch_size - r_idx = (b + 1) * batch_size - - x_batch = tensor.Tensor(device=dev, data=x[l_idx:r_idx]) - target_batch = tensor.Tensor(device=dev, data=y[l_idx:r_idx]) - - output_batch = new_model.forward(x_batch) - - loss = autograd.softmax_cross_entropy(output_batch, target_batch) - accuracy_rate = accuracy(tensor.to_numpy(output_batch), - tensor.to_numpy(target_batch)) - - sgd = opt.SGD(lr=0.01) - for p, gp in autograd.backward(loss): - sgd.update(p, gp) - sgd.step() - - if b % 1e2 == 0: - logging.info("acc %6.2f loss, %6.2f" % - (accuracy_rate, tensor.to_numpy(loss)[0])) - logging.info("re-training completed") - return new_model - - -class Trans: - - def __init__(self, sg_ir, last_layers): - self.sg_ir = sg_ir - self.last_layers = last_layers - self.append_linear1 = autograd.Linear(500, 128, bias=False) - self.append_linear2 = autograd.Linear(128, 32, bias=False) - self.append_linear3 = autograd.Linear(32, 10, bias=False) - - def forward(self, x): - y = sg_ir.run([x], last_layers=self.last_layers)[0] - y = self.append_linear1(y) - y = autograd.relu(y) - y = self.append_linear2(y) - y = autograd.relu(y) - y = self.append_linear3(y) - y = autograd.relu(y) - return y - - -def transfer_learning(sg_ir, - x, - y, - epochs=1, - batch_size=64, - dev=device.get_default_device()): - batch_number = x.shape[0] // batch_size - - trans_model = Trans(sg_ir, -1) - - for i in range(epochs): - for b in range(batch_number): - l_idx = b * batch_size - r_idx = (b + 1) * batch_size - - x_batch = tensor.Tensor(device=dev, data=x[l_idx:r_idx]) - target_batch = tensor.Tensor(device=dev, data=y[l_idx:r_idx]) - output_batch = trans_model.forward(x_batch) - - loss = autograd.softmax_cross_entropy(output_batch, target_batch) - accuracy_rate = accuracy(tensor.to_numpy(output_batch), - tensor.to_numpy(target_batch)) - - sgd = opt.SGD(lr=0.07) - for p, gp in autograd.backward(loss): - sgd.update(p, gp) - sgd.step() - - if b % 1e2 == 0: - logging.info("acc %6.2f loss, %6.2f" % - (accuracy_rate, tensor.to_numpy(loss)[0])) - logging.info("transfer-learning completed") - return trans_model - - -def test(model, x, y, batch_size=64, dev=device.get_default_device()): - batch_number = x.shape[0] // batch_size - - result = 0 - for b in range(batch_number): - l_idx = b * batch_size - r_idx = (b + 1) * batch_size - - x_batch = tensor.Tensor(device=dev, data=x[l_idx:r_idx]) - target_batch = tensor.Tensor(device=dev, data=y[l_idx:r_idx]) - - output_batch = model.forward(x_batch) - result += accuracy(tensor.to_numpy(output_batch), - tensor.to_numpy(target_batch)) - - logging.info("testing acc %6.2f" % (result / batch_number)) - - -if __name__ == "__main__": - # create device - dev = device.create_cuda_gpu() - #dev = device.get_default_device() - # create model - model = CNN() - # load data - train_x, train_y, valid_x, valid_y = load_dataset() - # normalization - train_x = train_x / 255 - valid_x = valid_x / 255 - train_y = to_categorical(train_y, 10) - valid_y = to_categorical(valid_y, 10) - # do training - autograd.training = True - x, y = train(model, train_x, train_y, dev=dev) - onnx_model = make_onnx(x, y) - # logging.info('The model is:\n{}'.format(onnx_model)) - - # Save the ONNX model - model_path = os.path.join('/', 'tmp', 'mnist.onnx') - onnx.save(onnx_model, model_path) - logging.info('The model is saved.') - - # load the ONNX model - onnx_model = onnx.load(model_path) - sg_ir = sonnx.prepare(onnx_model, device=dev) - - # inference - autograd.training = False - logging.info('The inference result is:') - test(Infer(sg_ir), valid_x, valid_y, dev=dev) - - # re-training - autograd.training = True - new_model = re_train(sg_ir, train_x, train_y, dev=dev) - autograd.training = False - test(new_model, valid_x, valid_y, dev=dev) - - # transfer-learning - autograd.training = True - new_model = transfer_learning(sg_ir, train_x, train_y, dev=dev) - autograd.training = False - test(new_model, valid_x, valid_y, dev=dev) diff --git a/examples/onnx/mobilenet.py b/examples/onnx/mobilenet.py index 75758f1f2..085a6c3c0 100644 --- a/examples/onnx/mobilenet.py +++ b/examples/onnx/mobilenet.py @@ -22,10 +22,9 @@ from singa import device from singa import tensor -from singa import autograd from singa import sonnx import onnx -from utils import download_model, update_batch_size, check_exist_or_download +from utils import download_model, check_exist_or_download import logging logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s') @@ -56,18 +55,17 @@ def get_image_labe(): return img, labels -class Infer: +class MyModel(sonnx.SONNXModel): - def __init__(self, sg_ir): - self.sg_ir = sg_ir - for idx, tens in sg_ir.tensor_map.items(): - # allow the tensors to be updated - tens.requires_grad = True - tens.stores_grad = True - sg_ir.tensor_map[idx] = tens + def __init__(self, onnx_model): + super(MyModel, self).__init__(onnx_model) - def forward(self, x): - return sg_ir.run([x])[0] + def forward(self, *x): + y = super(MyModel, self).forward(*x) + return y + + def train_one_batch(self, x, y): + pass if __name__ == "__main__": @@ -81,32 +79,30 @@ def forward(self, x): download_model(url) onnx_model = onnx.load(model_path) - # set batch size - onnx_model = update_batch_size(onnx_model, 1) + # inference + logging.info("preprocessing...") + img, labels = get_image_labe() + img = preprocess(img) + # sg_ir = sonnx.prepare(onnx_model) # run without graph + # y = sg_ir.run([img]) - # prepare the model - logging.info("prepare model...") + logging.info("model compling...") dev = device.create_cuda_gpu() - sg_ir = sonnx.prepare(onnx_model, device=dev) - autograd.training = False - model = Infer(sg_ir) + x = tensor.PlaceHolder(img.shape, device=dev) + m = MyModel(onnx_model) + m.compile([x], is_train=False, use_graph=True, sequential=True) - # verifty the test dataset + # verifty the test # from utils import load_dataset # inputs, ref_outputs = load_dataset(os.path.join('/tmp', 'mobilenetv2-1.0', 'test_data_set_0')) # x_batch = tensor.Tensor(device=dev, data=inputs[0]) - # outputs = model.forward(x_batch) + # outputs = sg_ir.run([x_batch]) # for ref_o, o in zip(ref_outputs, outputs): # np.testing.assert_almost_equal(ref_o, tensor.to_numpy(o), 4) - # inference - logging.info("preprocessing...") - img, labels = get_image_labe() - img = preprocess(img) - logging.info("model running...") - x_batch = tensor.Tensor(device=dev, data=img) - y = model.forward(x_batch) + x = tensor.Tensor(device=dev, data=img) + y = m.forward(*[x])[0] logging.info("postprocessing...") y = tensor.softmax(y) diff --git a/examples/onnx/resnet18.py b/examples/onnx/resnet18.py index b3381c023..c8c4480e0 100644 --- a/examples/onnx/resnet18.py +++ b/examples/onnx/resnet18.py @@ -22,10 +22,9 @@ from singa import device from singa import tensor -from singa import autograd from singa import sonnx import onnx -from utils import download_model, update_batch_size, check_exist_or_download +from utils import download_model, check_exist_or_download import logging logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s') @@ -55,19 +54,17 @@ def get_image_labe(): img = Image.open(check_exist_or_download(image_url)) return img, labels +class MyModel(sonnx.SONNXModel): -class Infer: + def __init__(self, onnx_model): + super(MyModel, self).__init__(onnx_model) - def __init__(self, sg_ir): - self.sg_ir = sg_ir - for idx, tens in sg_ir.tensor_map.items(): - # allow the tensors to be updated - tens.requires_grad = True - tens.stores_grad = True - sg_ir.tensor_map[idx] = tens + def forward(self, *x): + y = super(MyModel, self).forward(*x) + return y - def forward(self, x): - return sg_ir.run([x])[0] + def train_one_batch(self, x, y): + pass if __name__ == "__main__": @@ -80,32 +77,30 @@ def forward(self, x): download_model(url) onnx_model = onnx.load(model_path) - # set batch size - onnx_model = update_batch_size(onnx_model, 1) + # inference + logging.info("preprocessing...") + img, labels = get_image_labe() + img = preprocess(img) + # sg_ir = sonnx.prepare(onnx_model) # run without graph + # y = sg_ir.run([img]) - # prepare the model - logging.info("prepare model...") + logging.info("model compling...") dev = device.create_cuda_gpu() - sg_ir = sonnx.prepare(onnx_model, device=dev) - autograd.training = False - model = Infer(sg_ir) + x = tensor.PlaceHolder(img.shape, device=dev) + m = MyModel(onnx_model) + m.compile([x], is_train=False, use_graph=True, sequential=True) # verifty the test # from utils import load_dataset # inputs, ref_outputs = load_dataset(os.path.join('/tmp', 'resnet18v1', 'test_data_set_0')) # x_batch = tensor.Tensor(device=dev, data=inputs[0]) - # outputs = model.forward(x_batch) + # outputs = sg_ir.run([x_batch]) # for ref_o, o in zip(ref_outputs, outputs): # np.testing.assert_almost_equal(ref_o, tensor.to_numpy(o), 4) - # inference - logging.info("preprocessing...") - img, labels = get_image_labe() - img = preprocess(img) - logging.info("model running...") - x_batch = tensor.Tensor(device=dev, data=img) - y = model.forward(x_batch) + x = tensor.Tensor(device=dev, data=img) + y = m.forward(*[x])[0] logging.info("postprocessing...") y = tensor.softmax(y) diff --git a/examples/onnx/tiny_yolov2.py b/examples/onnx/tiny_yolov2.py index e883117ae..bac6688cb 100644 --- a/examples/onnx/tiny_yolov2.py +++ b/examples/onnx/tiny_yolov2.py @@ -22,10 +22,9 @@ from singa import device from singa import tensor -from singa import autograd from singa import sonnx import onnx -from utils import download_model, update_batch_size, check_exist_or_download +from utils import download_model, check_exist_or_download import logging logging.basicConfig(level=logging.INFO, format='%(asctime)-15s %(message)s') @@ -45,20 +44,6 @@ def get_image(): return img -class Infer: - - def __init__(self, sg_ir): - self.sg_ir = sg_ir - for idx, tens in sg_ir.tensor_map.items(): - # allow the tensors to be updated - tens.requires_grad = True - tens.stores_grad = True - sg_ir.tensor_map[idx] = tens - - def forward(self, x): - return sg_ir.run([x])[0] - - def postprcess(out): numClasses = 20 anchors = [1.08, 1.19, 3.42, 4.41, 6.63, 11.38, 9.42, 5.11, 16.62, 10.52] @@ -124,6 +109,19 @@ def softmax(x): img.save("result.png") +class MyModel(sonnx.SONNXModel): + + def __init__(self, onnx_model): + super(MyModel, self).__init__(onnx_model) + + def forward(self, *x): + y = super(MyModel, self).forward(*x) + return y + + def train_one_batch(self, x, y): + pass + + if __name__ == "__main__": url = 'https://onnxzoo.blob.core.windows.net/models/opset_8/tiny_yolov2/tiny_yolov2.tar.gz' @@ -134,33 +132,31 @@ def softmax(x): download_model(url) onnx_model = onnx.load(model_path) - # set batch size - onnx_model = update_batch_size(onnx_model, 1) + # inference + logging.info("preprocessing...") + img = get_image() + img = preprocess(img) + # sg_ir = sonnx.prepare(onnx_model) # run without graph + # y = sg_ir.run([img]) - # prepare the model - logging.info("prepare model...") + logging.info("model compling...") dev = device.create_cuda_gpu() - sg_ir = sonnx.prepare(onnx_model, device=dev) - autograd.training = False - model = Infer(sg_ir) + x = tensor.PlaceHolder(img.shape, device=dev) + m = MyModel(onnx_model) + m.compile([x], is_train=False, use_graph=True, sequential=True) - # verifty the test dataset + # verifty the test # from utils import load_dataset # inputs, ref_outputs = load_dataset(os.path.join('/tmp', 'tiny_yolov2', 'test_data_set_0')) # x_batch = tensor.Tensor(device=dev, data=inputs[0]) - # outputs = model.forward(x_batch) + # outputs = sg_ir.run([x_batch]) # for ref_o, o in zip(ref_outputs, outputs): # np.testing.assert_almost_equal(ref_o, tensor.to_numpy(o), 4) - # inference - logging.info("preprocessing...") - img = get_image() - img = preprocess(img) - logging.info("model running...") - x_batch = tensor.Tensor(device=dev, data=img) - y = model.forward(x_batch) + x = tensor.Tensor(device=dev, data=img) + y = m.forward(*[x])[0] logging.info("postprocessing...") out = tensor.to_numpy(y)[0] - postprcess(out) + postprcess(out) \ No newline at end of file diff --git a/examples/onnx/training/model.json b/examples/onnx/training/model.json new file mode 100644 index 000000000..24d8aae7e --- /dev/null +++ b/examples/onnx/training/model.json @@ -0,0 +1,89 @@ +{ + "mobilenet": { + "name": "MobileNet v2-1.0", + "url": "https://s3.amazonaws.com/onnx-model-zoo/mobilenet/mobilenetv2-1.0/mobilenetv2-1.0.tar.gz", + "path": "mobilenetv2-1.0/mobilenetv2-1.0.onnx" + }, + "resnet18v1": { + "name": "ResNet-18 Version 1", + "description": "ResNet v1 uses post-activation for the residual blocks", + "url": "https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet18v1/resnet18v1.tar.gz", + "path": "resnet18v1/resnet18v1.onnx" + }, + "resnet34v1": { + "name": "ResNet-34 Version 1", + "description": "ResNet v1 uses post-activation for the residual blocks", + "url": "https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet34v1/resnet34v1.tar.gz", + "path": "resnet34v1/resnet34v1.onnx" + }, + "resnet50v1": { + "name": "ResNet-50 Version 1", + "description": "ResNet v1 uses post-activation for the residual blocks", + "url": "https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet50v1/resnet50v1.tar.gz", + "path": "resnet50v1/resnet50v1.onnx" + }, + "resnet101v1": { + "name": "ResNet-101 Version 1", + "description": "ResNet v1 uses post-activation for the residual blocks", + "url": "https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet101v1/resnet101v1.tar.gz", + "path": "resnet101v1/resnet101v1.onnx" + }, + "resnet152v1": { + "name": "ResNet-152 Version 1", + "description": "ResNet v1 uses post-activation for the residual blocks", + "url": "https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet152v1/resnet152v1.tar.gz", + "path": "resnet152v1/resnet152v1.onnx" + }, + "resnet18v2": { + "name": "ResNet-18 Version 2", + "description": "ResNet v2 uses pre-activation for the residual blocks", + "url": "https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet18v2/resnet18v2.tar.gz", + "path": "resnet18v2/resnet18v2.onnx" + }, + "resnet34v2": { + "name": "ResNet-34 Version 2", + "description": "ResNet v2 uses pre-activation for the residual blocks", + "url": "https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet34v2/resnet34v2.tar.gz", + "path": "resnet34v2/resnet34v2.onnx" + }, + "resnet50v2": { + "name": "ResNet-50 Version 2", + "description": "ResNet v2 uses pre-activation for the residual blocks", + "url": "https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet50v2/resnet50v2.tar.gz", + "path": "resnet50v2/resnet50v2.onnx" + }, + "resnet101v2": { + "name": "ResNet-101 Version 2", + "description": "ResNet v2 uses pre-activation for the residual blocks", + "url": "https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet101v2/resnet101v2.tar.gz", + "path": "resnet101v2/resnet101v2.onnx" + }, + "resnet152v2": { + "name": "ResNet-152 Version 2", + "description": "ResNet v2 uses pre-activation for the residual blocks", + "url": "https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet152v2/resnet152v2.tar.gz", + "path": "resnet152v2/resnet152v2.onnx" + }, + "vgg16": { + "name": "VGG-16", + "url": "https://github.com/onnx/models/raw/master/vision/classification/vgg/model/vgg16-7.tar.gz", + "path": "vgg16/vgg16.onnx" + }, + "vgg16bn": { + "name": "VGG-16 with batch normalization", + "description": "VGG have batch normalization applied after each convolutional layer", + "url": "https://github.com/onnx/models/raw/master/vision/classification/vgg/model/vgg16-bn-7.tar.gz", + "path": "vgg16-bn/vgg16-bn.onnx" + }, + "vgg19": { + "name": "VGG-19", + "url": "https://github.com/onnx/models/raw/master/vision/classification/vgg/model/vgg16-9.tar.gz", + "path": "vgg19/vgg19.onnx" + }, + "vgg19bn": { + "name": "VGG-19 with batch normalization", + "description": "VGG have batch normalization applied after each convolutional layer", + "url": "https://github.com/onnx/models/raw/master/vision/classification/vgg/model/vgg16-bn-9.tar.gz", + "path": "vgg19-bn/vgg19-bn.onnx" + } +} \ No newline at end of file diff --git a/examples/onnx/training/train.py b/examples/onnx/training/train.py new file mode 100644 index 000000000..6e12d2232 --- /dev/null +++ b/examples/onnx/training/train.py @@ -0,0 +1,348 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import sys, os +import json +from singa import singa_wrap as singa +from singa import opt +from singa import device +from singa import tensor +from singa import sonnx +from singa import autograd +import numpy as np +import time +import argparse +from PIL import Image +import onnx +import logging +from tqdm import tqdm + +logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s') +sys.path.append(os.path.dirname(__file__) + '/../../cnn') +sys.path.append(os.path.dirname(__file__) + '/..') +from utils import download_model + +# Data Augmentation +def augmentation(x, batch_size): + xpad = np.pad(x, [[0, 0], [0, 0], [4, 4], [4, 4]], 'symmetric') + for data_num in range(0, batch_size): + offset = np.random.randint(8, size=2) + x[data_num, :, :, :] = xpad[data_num, :, + offset[0]:offset[0] + x.shape[2], + offset[1]:offset[1] + x.shape[2]] + if_flip = np.random.randint(2) + if (if_flip): + x[data_num, :, :, :] = x[data_num, :, :, ::-1] + return x + + +# Calculate Accuracy +def accuracy(pred, target): + # y is network output to be compared with ground truth (int) + y = np.argmax(pred, axis=1) + a = y == target + correct = np.array(a, "int").sum() + # print(correct) + return correct + + +# Data partition according to the rank +def partition(global_rank, world_size, train_x, train_y, val_x, val_y): + # Partition training data + data_per_rank = train_x.shape[0] // world_size + idx_start = global_rank * data_per_rank + idx_end = (global_rank + 1) * data_per_rank + train_x = train_x[idx_start:idx_end] + train_y = train_y[idx_start:idx_end] + # Partition evaluation data + data_per_rank = val_x.shape[0] // world_size + idx_start = global_rank * data_per_rank + idx_end = (global_rank + 1) * data_per_rank + val_x = val_x[idx_start:idx_end] + val_y = val_y[idx_start:idx_end] + return train_x, train_y, val_x, val_y + + +# Function to all reduce NUMPY Accuracy and Loss from Multiple Devices +def reduce_variable(variable, dist_opt, reducer): + reducer.copy_from_numpy(variable) + dist_opt.all_reduce(reducer.data) + dist_opt.wait() + output = tensor.to_numpy(reducer) + return output + + +def resize_dataset(x, image_size): + num_data = x.shape[0] + dim = x.shape[1] + X = np.zeros(shape=(num_data, dim, image_size, image_size), + dtype=np.float32) + for n in range(0, num_data): + for d in range(0, dim): + X[n, d, :, :] = np.array(Image.fromarray(x[n, d, :, :]).resize( + (image_size, image_size), Image.BILINEAR), + dtype=np.float32) + return X + + +class MyModel(sonnx.SONNXModel): + + def __init__(self, onnx_model, num_classes=10, num_channels=3): + super(MyModel, self).__init__(onnx_model) + self.num_classes = num_classes + self.input_size = 224 + self.dimension = 4 + self.num_channels = num_channels + self.num_classes = num_classes + + def forward(self, *x): + y = super(MyModel, self).forward(*x) + return y + + def train_one_batch(self, x, y, dist_option, spars): + out = self.forward(x)[0] + loss = autograd.softmax_cross_entropy(out, y) + if dist_option == 'fp32': + self.optimizer.backward_and_update(loss) + elif dist_option == 'fp16': + self.optimizer.backward_and_update_half(loss) + elif dist_option == 'partialUpdate': + self.optimizer.backward_and_partial_update(loss) + elif dist_option == 'sparseTopK': + self.optimizer.backward_and_sparse_update(loss, + topK=True, + spars=spars) + elif dist_option == 'sparseThreshold': + self.optimizer.backward_and_sparse_update(loss, + topK=False, + spars=spars) + return out, loss + + def set_optimizer(self, optimizer): + self.optimizer = optimizer + + +def run(global_rank, + world_size, + local_rank, + max_epoch, + batch_size, + model_config, + data, + sgd, + graph, + verbosity, + dist_option='fp32', + spars=None): + dev = device.create_cuda_gpu_on(local_rank) + dev.SetRandSeed(0) + np.random.seed(0) + + if data == 'cifar10': + from data import cifar10 + train_x, train_y, val_x, val_y = cifar10.load() + elif data == 'cifar100': + from data import cifar100 + train_x, train_y, val_x, val_y = cifar100.load() + + num_channels = train_x.shape[1] + image_size = train_x.shape[2] + data_size = np.prod(train_x.shape[1:train_x.ndim]).item() + num_classes = (np.max(train_y) + 1).item() + + # read and make onnx model + download_model(model_config['url']) + onnx_model = onnx.load(os.path.join('/tmp', model_config['path'])) + model = MyModel(onnx_model, + num_channels=num_channels, + num_classes=num_classes) + + # For distributed training, sequential gives better performance + if hasattr(sgd, "communicator"): + DIST = True + sequential = True + else: + DIST = False + sequential = False + + if DIST: + train_x, train_y, val_x, val_y = partition(global_rank, world_size, + train_x, train_y, val_x, + val_y) + ''' + # check dataset shape correctness + if global_rank == 0: + print("Check the shape of dataset:") + print(train_x.shape) + print(train_y.shape) + ''' + + if model.dimension == 4: + tx = tensor.Tensor( + (batch_size, num_channels, model.input_size, model.input_size), dev, + tensor.float32) + elif model.dimension == 2: + tx = tensor.Tensor((batch_size, data_size), dev, tensor.float32) + np.reshape(train_x, (train_x.shape[0], -1)) + np.reshape(val_x, (val_x.shape[0], -1)) + + ty = tensor.Tensor((batch_size,), dev, tensor.int32) + num_train_batch = train_x.shape[0] // batch_size + num_val_batch = val_x.shape[0] // batch_size + idx = np.arange(train_x.shape[0], dtype=np.int32) + + # attached model to graph + model.set_optimizer(sgd) + model.compile([tx], is_train=True, use_graph=graph, sequential=sequential) + dev.SetVerbosity(verbosity) + + # Training and Evaluation Loop + for epoch in range(max_epoch): + start_time = time.time() + np.random.shuffle(idx) + + if global_rank == 0: + print('Starting Epoch %d:' % (epoch)) + + # Training Phase + train_correct = np.zeros(shape=[1], dtype=np.float32) + test_correct = np.zeros(shape=[1], dtype=np.float32) + train_loss = np.zeros(shape=[1], dtype=np.float32) + + model.train() + for b in tqdm(range(num_train_batch)): + # Generate the patch data in this iteration + x = train_x[idx[b * batch_size:(b + 1) * batch_size]] + if model.dimension == 4: + x = augmentation(x, batch_size) + if (image_size != model.input_size): + x = resize_dataset(x, model.input_size) + y = train_y[idx[b * batch_size:(b + 1) * batch_size]] + + # Copy the patch data into input tensors + tx.copy_from_numpy(x) + ty.copy_from_numpy(y) + + # Train the model + out, loss = model(tx, ty, dist_option, spars) + train_correct += accuracy(tensor.to_numpy(out), y) + train_loss += tensor.to_numpy(loss)[0] + + if DIST: + # Reduce the Evaluation Accuracy and Loss from Multiple Devices + reducer = tensor.Tensor((1,), dev, tensor.float32) + train_correct = reduce_variable(train_correct, sgd, reducer) + train_loss = reduce_variable(train_loss, sgd, reducer) + + if global_rank == 0: + print('Training loss = %f, training accuracy = %f' % + (train_loss, train_correct / + (num_train_batch * batch_size * world_size)), + flush=True) + + # Evaluation Phase + model.eval() + for b in tqdm(range(num_val_batch)): + x = val_x[b * batch_size:(b + 1) * batch_size] + if model.dimension == 4: + if (image_size != model.input_size): + x = resize_dataset(x, model.input_size) + y = val_y[b * batch_size:(b + 1) * batch_size] + tx.copy_from_numpy(x) + ty.copy_from_numpy(y) + out_test = model(tx)[0] + test_correct += accuracy(tensor.to_numpy(out_test), y) + + if DIST: + # Reduce the Evaulation Accuracy from Multiple Devices + test_correct = reduce_variable(test_correct, sgd, reducer) + + # Output the Evaluation Accuracy + if global_rank == 0: + print('Evaluation accuracy = %f, Elapsed Time = %fs' % + (test_correct / (num_val_batch * batch_size * world_size), + time.time() - start_time), + flush=True) + + dev.PrintTimeProfiling() + + +def loss(out, y): + return autograd.softmax_cross_entropy(out, y) + + +if __name__ == '__main__': + + with open(os.path.join(os.path.dirname(__file__), + 'model.json')) as json_file: + model_config = json.load(json_file) + + # use argparse to get command config: max_epoch, model, data, etc. for single gpu training + parser = argparse.ArgumentParser( + description='Training using the autograd and graph.') + parser.add_argument('--model', + choices=list(model_config.keys()), + help='please refer to the models.json for more details', + default='mobilenet') + parser.add_argument('--data', + choices=['cifar10', 'cifar100'], + default='cifar10') + parser.add_argument('--epoch', + '--max-epoch', + default=10, + type=int, + help='maximum epochs', + dest='max_epoch') + parser.add_argument('--bs', + '--batch-size', + default=32, + type=int, + help='batch size', + dest='batch_size') + parser.add_argument('--lr', + '--learning-rate', + default=0.005, + type=float, + help='initial learning rate', + dest='lr') + # determine which gpu to use + parser.add_argument('--id', + '--device-id', + default=0, + type=int, + help='which GPU to use', + dest='device_id') + parser.add_argument('--no-graph', + '--disable-graph', + default='True', + action='store_false', + help='disable graph', + dest='graph') + parser.add_argument('--verbosity', + '--log-verbosity', + default=1, + type=int, + help='logging verbosity', + dest='verbosity') + + args = parser.parse_args() + + sgd = opt.SGD(lr=args.lr, momentum=0.9, weight_decay=1e-5) + run(0, 1, args.device_id, args.max_epoch, args.batch_size, model_config[args.model], + args.data, sgd, args.graph, args.verbosity) diff --git a/examples/onnx/utils.py b/examples/onnx/utils.py index 71d1ef4ee..75b352afa 100644 --- a/examples/onnx/utils.py +++ b/examples/onnx/utils.py @@ -64,10 +64,3 @@ def check_exist_or_download(url): urllib.request.urlretrieve(url, filename) return filename - -def update_batch_size(onnx_model, batch_size): - model_input = onnx_model.graph.input[0] - model_input.type.tensor_type.shape.dim[0].dim_value = batch_size - model_output = onnx_model.graph.output[0] - model_output.type.tensor_type.shape.dim[0].dim_value = batch_size - return onnx_model diff --git a/examples/onnx/vgg16.py b/examples/onnx/vgg16.py index b26ea9477..02a2eb8d0 100644 --- a/examples/onnx/vgg16.py +++ b/examples/onnx/vgg16.py @@ -22,10 +22,9 @@ from singa import device from singa import tensor -from singa import autograd from singa import sonnx import onnx -from utils import download_model, update_batch_size, check_exist_or_download +from utils import download_model, check_exist_or_download import logging logging.basicConfig(level=logging.INFO, format='%(asctime)-15s %(message)s') @@ -56,18 +55,17 @@ def get_image_labe(): return img, labels -class Infer: +class MyModel(sonnx.SONNXModel): - def __init__(self, sg_ir): - self.sg_ir = sg_ir - for idx, tens in sg_ir.tensor_map.items(): - # allow the tensors to be updated - tens.requires_grad = True - tens.stores_grad = True - sg_ir.tensor_map[idx] = tens + def __init__(self, onnx_model): + super(MyModel, self).__init__(onnx_model) - def forward(self, x): - return sg_ir.run([x])[0] + def forward(self, *x): + y = super(MyModel, self).forward(*x) + return y + + def train_one_batch(self, x, y): + pass if __name__ == "__main__": @@ -79,32 +77,30 @@ def forward(self, x): download_model(url) onnx_model = onnx.load(model_path) - # set batch size - onnx_model = update_batch_size(onnx_model, 1) + # inference + logging.info("preprocessing...") + img, labels = get_image_labe() + img = preprocess(img) + # sg_ir = sonnx.prepare(onnx_model) # run without graph + # y = sg_ir.run([img]) - # prepare the model - logging.info("prepare model...") + logging.info("model compling...") dev = device.create_cuda_gpu() - sg_ir = sonnx.prepare(onnx_model, device=dev) - autograd.training = False - model = Infer(sg_ir) + x = tensor.PlaceHolder(img.shape, device=dev) + m = MyModel(onnx_model) + m.compile([x], is_train=False, use_graph=True, sequential=True) # verifty the test # from utils import load_dataset # inputs, ref_outputs = load_dataset(os.path.join('/tmp', 'vgg16', 'test_data_set_0')) # x_batch = tensor.Tensor(device=dev, data=inputs[0]) - # outputs = model.forward(x_batch) + # outputs = sg_ir.run([x_batch]) # for ref_o, o in zip(ref_outputs, outputs): # np.testing.assert_almost_equal(ref_o, tensor.to_numpy(o), 4) - # inference - logging.info("preprocessing...") - img, labels = get_image_labe() - img = preprocess(img) - logging.info("model running...") - x_batch = tensor.Tensor(device=dev, data=img) - y = model.forward(x_batch) + x = tensor.Tensor(device=dev, data=img) + y = m.forward(*[x])[0] logging.info("postprocessing...") y = tensor.softmax(y)