-
Notifications
You must be signed in to change notification settings - Fork 2
/
shufflenetv2.cpp
171 lines (137 loc) · 4.27 KB
/
shufflenetv2.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
#include <stdio.h>
#include <algorithm>
#include <vector>
#include <iostream>
#include <fstream>
#include <string>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include "platform.h"
#include "net.h"
#if NCNN_VULKAN
#include "gpu.h"
#endif // NCNN_VULKAN
// Reads a text file line by line and appends every non-empty line to Names.
//
// fileName: path of the text file to read.
// Names:    output vector; lines are appended, existing entries are kept.
// Returns false if the file could not be opened, true otherwise.
//
// A trailing '\r' is stripped from each line so that files saved with CRLF
// line endings produce the same entries as LF files. Without this, blank CRLF
// lines would be stored as "\r" (shifting every following entry) and each
// name would carry a stray carriage return.
static bool getFileContent(const std::string& fileName, std::vector<std::string> & Names)
{
    std::ifstream in(fileName.c_str());
    if (!in.is_open())
    {
        return false;
    }

    std::string line;
    // Read the next line from the file until it reaches the end.
    while (std::getline(in, line))
    {
        // Drop the carriage return left behind by CRLF line endings.
        if (!line.empty() && line[line.size() - 1] == '\r')
            line.erase(line.size() - 1);

        // Only lines with actual content are stored.
        if (!line.empty())
            Names.push_back(line);
    }

    // The stream is closed by its destructor.
    return true;
}
// Runs the ShuffleNet V2 classifier on one image and stores the per-class
// probabilities in cls_scores.
//
// bgr:        input image; its pixel buffer is interpreted as BGR and resized
//             to 224x224 before inference.
// cls_scores: output; cleared first, then filled with one softmax probability
//             per element of the "fc" output blob.
// Returns 0 on success and -1 on any failure (empty image, model files not
// loadable, inference error). On failure cls_scores is left empty.
//
// The model is loaded from "shufflenet.param" / "shufflenet.bin" relative to
// the current working directory on every call.
static int detect_shufflenetv2(const cv::Mat& bgr, std::vector<float>& cls_scores)
{
    cls_scores.clear();

    if (bgr.empty())
    {
        fprintf(stderr, "detect_shufflenetv2: empty input image\n");
        return -1;
    }

    ncnn::Net shufflenetv2;

#if NCNN_VULKAN
    shufflenetv2.opt.use_vulkan_compute = true;
#endif // NCNN_VULKAN

    // https://github.com/miaow1988/ShuffleNet_V2_pytorch_caffe
    // models can be downloaded from https://github.com/miaow1988/ShuffleNet_V2_pytorch_caffe/releases
    if (shufflenetv2.load_param("shufflenet.param") != 0)
    {
        fprintf(stderr, "loading shufflenet.param failed\n");
        return -1;
    }
    if (shufflenetv2.load_model("shufflenet.bin") != 0)
    {
        fprintf(stderr, "loading shufflenet.bin failed\n");
        return -1;
    }

    ncnn::Mat in = ncnn::Mat::from_pixels_resize(bgr.data, ncnn::Mat::PIXEL_BGR, bgr.cols, bgr.rows, 224, 224);

    // No mean subtraction; only scale the 8-bit pixel values by 1/255.
    const float norm_vals[3] = {1/255.f, 1/255.f, 1/255.f};
    in.substract_mean_normalize(0, norm_vals);

    ncnn::Extractor ex = shufflenetv2.create_extractor();

    if (ex.input("data", in) != 0)
    {
        fprintf(stderr, "setting input blob data failed\n");
        return -1;
    }

    ncnn::Mat out;
    if (ex.extract("fc", out) != 0 || out.empty())
    {
        fprintf(stderr, "extracting output blob fc failed\n");
        return -1;
    }

    // manually call softmax on the fc output
    // convert result into probability
    // skip if your model already has softmax operation
    {
        ncnn::Layer* softmax = ncnn::create_layer("Softmax");
        if (!softmax)
        {
            fprintf(stderr, "creating Softmax layer failed\n");
            return -1;
        }

        ncnn::ParamDict pd;
        softmax->load_param(pd);

        const int softmax_ret = softmax->forward_inplace(out);

        // Release the layer before any early return so it is never leaked.
        delete softmax;

        if (softmax_ret != 0)
        {
            fprintf(stderr, "Softmax forward failed\n");
            return -1;
        }
    }

    // Flatten to a 1-D blob so every score can be addressed by a single index.
    out = out.reshape(out.w * out.h * out.c);

    cls_scores.resize(out.w);
    for (int j = 0; j < out.w; j++)
    {
        cls_scores[j] = out[j];
    }

    return 0;
}
// Prints the topk highest scores to stderr, best first, one per line in the
// form "<class index> = <score> <label>".
//
// cls_scores: score per class; the position in the vector is the class index.
// topk:       number of entries to print; clamped to [0, cls_scores.size()]
//             so that requesting more entries than exist is safe.
// Names:      label per class index; a class index without a matching entry
//             is printed with the placeholder "(no label)" instead of
//             raising std::out_of_range.
// Always returns 0.
static int print_topk(const std::vector<float>& cls_scores, int topk, std::vector<std::string> & Names)
{
    const int size = static_cast<int>(cls_scores.size());

    // vec.begin() + topk must stay inside [begin, end] for partial_sort.
    if (topk > size)
        topk = size;
    if (topk < 0)
        topk = 0;

    // partial sort topk with index
    std::vector< std::pair<float, int> > vec(size);
    for (int i = 0; i < size; i++)
    {
        vec[i] = std::make_pair(cls_scores[i], i);
    }

    std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(),
                      std::greater< std::pair<float, int> >());

    // print topk and score
    for (int i = 0; i < topk; i++)
    {
        const float score = vec[i].first;
        const int index = vec[i].second;

        const char* label = "(no label)";
        if (static_cast<size_t>(index) < Names.size())
            label = Names[index].c_str();

        fprintf(stderr, "%d = %f %s\n", index, score, label);
    }

    return 0;
}
// Entry point: classifies the image given as the only command-line argument
// and prints the three best classes with their labels to stderr.
//
// Class labels are read from "synset_words.txt" in the current working
// directory. Returns 0 on success and -1 on a usage error, when the labels
// or the image cannot be loaded, or when classification produces no scores.
int main(int argc, char** argv)
{
    if (argc != 2)
    {
        fprintf(stderr, "Usage: %s [imagepath]\n", argv[0]);
        return -1;
    }

    // Get the names
    std::vector<std::string> Names;
    if (!getFileContent("synset_words.txt", Names))
    {
        fprintf(stderr, "loading names failed\n");
        return -1;
    }

    const char* imagepath = argv[1];

    // Flag 1 asks OpenCV for a 3-channel colour image.
    cv::Mat m = cv::imread(imagepath, 1);
    if (m.empty())
    {
        fprintf(stderr, "cv::imread %s failed\n", imagepath);
        return -1;
    }

#if NCNN_VULKAN
    ncnn::create_gpu_instance();
#endif // NCNN_VULKAN

    std::vector<float> cls_scores;
    const int ret = detect_shufflenetv2(m, cls_scores);

#if NCNN_VULKAN
    // Tear the GPU instance down on both the success and the failure path.
    ncnn::destroy_gpu_instance();
#endif // NCNN_VULKAN

    // Without scores there is nothing to rank; report the failure instead of
    // asking print_topk for entries that do not exist.
    if (ret != 0 || cls_scores.empty())
    {
        fprintf(stderr, "detect_shufflenetv2 failed\n");
        return -1;
    }

    print_topk(cls_scores, 3, Names);

    return 0;
}