Skip to content

Commit

Permalink
Add resnet101 to modelzoo (#1272)
Browse files Browse the repository at this point in the history
Change-Id: I39bb5c181d0c2722b524ff87760c4af11777764a
  • Loading branch information
frankfliu committed Oct 6, 2021
1 parent b94f0da commit 7f09958
Show file tree
Hide file tree
Showing 6 changed files with 136 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,11 @@ protected void configPreProcess(Map<String, ?> arguments) {
String[] tokens = resize.split("\\s*,\\s*");
if (tokens.length > 1) {
addTransform(
new Resize(Integer.parseInt(tokens[0]), Integer.parseInt(tokens[1])));
new Resize(
(int) Double.parseDouble(tokens[0]),
(int) Double.parseDouble(tokens[1])));
} else {
addTransform(new Resize(Integer.parseInt(tokens[0])));
addTransform(new Resize((int) Double.parseDouble(tokens[0])));
}
}
if (getBooleanValue(arguments, "centerCrop", false)) {
Expand Down
2 changes: 2 additions & 0 deletions api/src/main/java/ai/djl/repository/AbstractRepository.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.apache.commons.compress.archivers.tar.TarArchiveEntry;
import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream;
import org.apache.commons.compress.utils.CloseShieldFilterInputStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -238,6 +239,7 @@ private void untar(InputStream is, Path dir, boolean gzip) throws IOException {
} else {
bis = new BufferedInputStream(is);
}
bis = new CloseShieldFilterInputStream(bis);
try (TarArchiveInputStream tis = new TarArchiveInputStream(bis)) {
TarArchiveEntry entry;
while ((entry = tis.getNextTarEntry()) != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
"arguments": {
"width": 224,
"height": 224,
"resize": 256,
"centerCrop": true,
"resize": true,
"applySoftmax": true
},
"files": {
Expand Down Expand Up @@ -60,8 +60,8 @@
"arguments": {
"width": 224,
"height": 224,
"resize": 256,
"centerCrop": true,
"resize": true,
"applySoftmax": true
},
"files": {
Expand All @@ -82,6 +82,31 @@
}
}
},
{
"version": "0.0.1",
"snapshot": false,
"name": "resnet101_v1",
"properties": {
"layers": "101",
"dataset": "imagenet"
},
"arguments": {
"width": 224,
"height": 224,
"resize": 256,
"centerCrop": true,
"normalize": true,
"applySoftmax": true
},
"files": {
"model": {
"uri": "0.0.1/resnet101_v1.tar.gz",
"name": "",
"sha1Hash": "3f2f3a216520311388d17d24c798b0c87740024a",
"size": 102734558
}
}
},
{
"version": "0.0.1",
"snapshot": false,
Expand All @@ -94,8 +119,8 @@
"arguments": {
"width": 224,
"height": 224,
"resize": 256,
"centerCrop": true,
"resize": true,
"applySoftmax": true
},
"files": {
Expand Down Expand Up @@ -128,8 +153,8 @@
"arguments": {
"width": 224,
"height": 224,
"resize": 256,
"centerCrop": true,
"resize": true,
"applySoftmax": true
},
"files": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
*/
package ai.djl.onnxruntime.zoo;

import ai.djl.Application.CV;
import ai.djl.Application.Tabular;
import ai.djl.onnxruntime.engine.OrtEngine;
import ai.djl.repository.MRL;
Expand All @@ -34,6 +35,9 @@ public class OrtModelZoo extends ModelZoo {
private static final List<ModelLoader> MODEL_LOADERS = new ArrayList<>();

static {
MRL resnet = REPOSITORY.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "resnet", "0.0.1");
MODEL_LOADERS.add(new BaseModelLoader(resnet));

MRL irisFlower =
REPOSITORY.model(Tabular.SOFTMAX_REGRESSION, GROUP_ID, "iris_flowers", "0.0.1");
MODEL_LOADERS.add(new BaseModelLoader(irisFlower));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
{
"metadataVersion": "0.2",
"resourceType": "model",
"application": "cv/image_classification",
"groupId": "ai.djl.onnxruntime",
"artifactId": "resnet",
"name": "Resnet",
"description": "Resnet image classification",
"website": "http://www.djl.ai/engines/onnxruntime/model-zoo",
"licenses": {
"license": {
"name": "The Apache License, Version 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0"
}
},
"artifacts": [
{
"version": "0.0.1",
"snapshot": false,
"name": "resnet18_v1-7",
"properties": {
"layers": "18",
"dataset": "imagenet"
},
"arguments": {
"width": 224,
"height": 224,
"resize": 256,
"centerCrop": true,
"normalize": true,
"applySoftmax": true
},
"files": {
"model": {
"uri": "0.0.1/resnet18_v1-7.tar.gz",
"name": "",
"sha1Hash": "12159f62b06d303097bccf4362b0355863b68213",
"size": 43287819
}
}
},
{
"version": "0.0.1",
"snapshot": false,
"name": "resnet101_v1",
"properties": {
"layers": "101",
"dataset": "imagenet"
},
"arguments": {
"width": 224,
"height": 224,
"resize": 256,
"centerCrop": true,
"normalize": true,
"applySoftmax": true
},
"files": {
"model": {
"uri": "0.0.1/resnet101_v1.tar.gz",
"name": "",
"sha1Hash": "7d5a2c748386a94af868a95da38728ff3b6ea8b1",
"size": 165675340
}
}
}
]
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
"arguments": {
"width": 224,
"height": 224,
"resize": true,
"resize": 256,
"centerCrop": true,
"normalize": true,
"applySoftmax": true
},
Expand Down Expand Up @@ -53,7 +54,8 @@
"arguments": {
"width": 224,
"height": 224,
"resize": true,
"resize": 256,
"centerCrop": true,
"normalize": true,
"applySoftmax": true
},
Expand All @@ -69,6 +71,31 @@
"size": 43482267
}
}
},
{
"version": "0.0.1",
"snapshot": false,
"name": "resnet101_v1",
"properties": {
"layers": "101",
"dataset": "imagenet"
},
"arguments": {
"width": 224,
"height": 224,
"resize": 256,
"centerCrop": true,
"normalize": true,
"applySoftmax": true
},
"files": {
"model": {
"uri": "0.0.1/resnet101_v1.tar.gz",
"name": "",
"sha1Hash": "e3d85c47c751ab23662b6f3eb04cc6215727994c",
"size": 166376847
}
}
}
]
}

0 comments on commit 7f09958

Please sign in to comment.