blob: 3b312301e167b76a6659fc2a491eb6ba1da98330 [file] [log] [blame]
agsantos5aa39652020-08-11 18:18:04 -04001/**
Sébastien Blincb783e32021-02-12 11:34:10 -05002 * Copyright (C) 2020-2021 Savoir-faire Linux Inc.
agsantos5aa39652020-08-11 18:18:04 -04003 *
4 * Author: Aline Gondim Santos <aline.gondimsantos@savoirfairelinux.com>
5 *
6 * This program is free software; you can redistribute it and/or modify
7 * it under the terms of the GNU General Public License as published by
8 * the Free Software Foundation; either version 3 of the License, or
9 * (at your option) any later version.
10 *
11 * This program is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 * GNU General Public License for more details.
15 *
16 * You should have received a copy of the GNU General Public License
17 * along with this program; if not, write to the Free Software
agsantosac1940d2020-09-17 10:18:40 -040018 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301
19 * USA.
agsantos5aa39652020-08-11 18:18:04 -040020 */
21
22#include "pluginInference.h"
23// Std libraries
agsantosac1940d2020-09-17 10:18:40 -040024#include "pluglog.h"
agsantos5aa39652020-08-11 18:18:04 -040025#include <cstring>
26#include <numeric>
agsantos5aa39652020-08-11 18:18:04 -040027
28const char sep = separator();
29const std::string TAG = "FORESEG";
30
31namespace jami {
32
agsantosac1940d2020-09-17 10:18:40 -040033PluginInference::PluginInference(TFModel model)
34 : TensorflowInference(model)
agsantos5aa39652020-08-11 18:18:04 -040035{
36#ifndef TFLITE
agsantosac1940d2020-09-17 10:18:40 -040037 // Initialize TENSORFLOW_CC lib
38 static const char* kFakeName = "fake program name";
39 int argc = 1;
40 char* fake_name_copy = strdup(kFakeName);
41 char** argv = &fake_name_copy;
42 tensorflow::port::InitMain(kFakeName, &argc, &argv);
43 if (argc > 1) {
44 Plog::log(Plog::LogPriority::INFO, "TENSORFLOW INIT", "Unknown argument ");
45 }
46 free(fake_name_copy);
47#endif // TFLITE
agsantos5aa39652020-08-11 18:18:04 -040048}
49
agsantosac1940d2020-09-17 10:18:40 -040050PluginInference::~PluginInference() {}
agsantos5aa39652020-08-11 18:18:04 -040051
52#ifdef TFLITE
53std::pair<uint8_t*, std::vector<int>>
54PluginInference::getInput()
55{
agsantosac1940d2020-09-17 10:18:40 -040056 // We assume that we have only one input
57 // Get the input index
58 int input = interpreter->inputs()[0];
agsantos5aa39652020-08-11 18:18:04 -040059
agsantosac1940d2020-09-17 10:18:40 -040060 uint8_t* inputDataPointer = interpreter->typed_tensor<uint8_t>(input);
61 // Get the input dimensions vector
62 std::vector<int> dims = getTensorDimensions(input);
agsantos5aa39652020-08-11 18:18:04 -040063
agsantosac1940d2020-09-17 10:18:40 -040064 return std::make_pair(inputDataPointer, dims);
agsantos5aa39652020-08-11 18:18:04 -040065}
66
67// // Types returned by tensorflow
68// int type = interpreter->tensor(outputIndex)->type
69// typedef enum {
70// kTfLiteNoType = 0,
71// kTfLiteFloat32 = 1, float
72// kTfLiteInt32 = 2, int // int32_t
73// kTfLiteUInt8 = 3, uint8_t
74// kTfLiteInt64 = 4, int64_t
75// kTfLiteString = 5,
76// kTfLiteBool = 6,
77// kTfLiteInt16 = 7, int16_t
78// kTfLiteComplex64 = 8,
79// kTfLiteInt8 = 9, int8_t
80// kTfLiteFloat16 = 10, float16_t
81// } TfLiteType;
82
83std::vector<float>
84PluginInference::masksPredictions() const
85{
agsantosac1940d2020-09-17 10:18:40 -040086 int outputIndex = interpreter->outputs()[0];
87 std::vector<int> dims = getTensorDimensions(outputIndex);
88 int totalDimensions = 1;
89 for (size_t i = 0; i < dims.size(); i++) {
90 totalDimensions *= dims[i];
91 }
92 std::vector<float> out;
agsantos5aa39652020-08-11 18:18:04 -040093
agsantosac1940d2020-09-17 10:18:40 -040094 int type = interpreter->tensor(outputIndex)->type;
95 switch (type) {
96 case 1: {
97 float* outputDataPointer = interpreter->typed_tensor<float>(outputIndex);
98 std::vector<float> output(outputDataPointer, outputDataPointer + totalDimensions);
99 out = std::vector<float>(output.begin(), output.end());
100 break;
101 }
102 case 2: {
103 int* outputDataPointer = interpreter->typed_tensor<int>(outputIndex);
104 std::vector<int> output(outputDataPointer, outputDataPointer + totalDimensions);
105 out = std::vector<float>(output.begin(), output.end());
106 break;
107 }
108 case 4: {
109 int64_t* outputDataPointer = interpreter->typed_tensor<int64_t>(outputIndex);
110 std::vector<int64_t> output(outputDataPointer, outputDataPointer + totalDimensions);
111 out = std::vector<float>(output.begin(), output.end());
112 break;
113 }
114 }
agsantos5aa39652020-08-11 18:18:04 -0400115
agsantosac1940d2020-09-17 10:18:40 -0400116 return out;
agsantos5aa39652020-08-11 18:18:04 -0400117}
118
119void
120PluginInference::setExpectedImageDimensions()
121{
agsantosac1940d2020-09-17 10:18:40 -0400122 // We assume that we have only one input
123 // Get the input index
124 int input = interpreter->inputs()[0];
125 // Get the input dimensions vector
126 std::vector<int> dims = getTensorDimensions(input);
127
128 imageWidth = dims.at(1);
129 imageHeight = dims.at(2);
130 imageNbChannels = dims.at(3);
agsantos5aa39652020-08-11 18:18:04 -0400131}
agsantosac1940d2020-09-17 10:18:40 -0400132#else // TFLITE
agsantos5aa39652020-08-11 18:18:04 -0400133// Given an image file name, read in the data, try to decode it as an image,
134// resize it to the requested size, and then scale the values as desired.
135void
136PluginInference::ReadTensorFromMat(const cv::Mat& image)
137{
agsantosac1940d2020-09-17 10:18:40 -0400138 imageTensor = tensorflow::Tensor(tensorflow::DataType::DT_FLOAT,
139 tensorflow::TensorShape({1, image.cols, image.rows, 3}));
140 float* p = imageTensor.flat<float>().data();
141 cv::Mat temp(image.rows, image.cols, CV_32FC3, p);
142 image.convertTo(temp, CV_32FC3);
agsantos5aa39652020-08-11 18:18:04 -0400143}
144
145std::vector<float>
146PluginInference::masksPredictions() const
147{
agsantosac1940d2020-09-17 10:18:40 -0400148 std::vector<int> dims;
149 int flatSize = 1;
150 int num_dimensions = outputs[0].shape().dims();
151 for (int ii_dim = 0; ii_dim < num_dimensions; ii_dim++) {
152 dims.push_back(outputs[0].shape().dim_size(ii_dim));
153 flatSize *= outputs[0].shape().dim_size(ii_dim);
154 }
agsantos5aa39652020-08-11 18:18:04 -0400155
agsantosac1940d2020-09-17 10:18:40 -0400156 std::vector<float> out;
157 int type = outputs[0].dtype();
agsantos5aa39652020-08-11 18:18:04 -0400158
agsantosac1940d2020-09-17 10:18:40 -0400159 switch (type) {
160 case tensorflow::DataType::DT_FLOAT: {
161 for (int offset = 0; offset < flatSize; offset++) {
162 out.push_back(outputs[0].flat<float>()(offset));
163 }
164 break;
165 }
166 case tensorflow::DataType::DT_INT32: {
167 for (int offset = 0; offset < flatSize; offset++) {
168 out.push_back(static_cast<float>(outputs[0].flat<tensorflow::int32>()(offset)));
169 }
170 break;
171 }
172 case tensorflow::DataType::DT_INT64: {
173 for (int offset = 0; offset < flatSize; offset++) {
174 out.push_back(static_cast<float>(outputs[0].flat<tensorflow::int64>()(offset)));
175 }
176 break;
177 }
178 default: {
179 for (int offset = 0; offset < flatSize; offset++) {
180 out.push_back(0);
181 }
182 break;
183 }
184 }
185 return out;
agsantos5aa39652020-08-11 18:18:04 -0400186}
187
188void
189PluginInference::setExpectedImageDimensions()
190{
agsantosac1940d2020-09-17 10:18:40 -0400191 if (tfModel.dims[1] != 0)
192 imageWidth = tfModel.dims[1];
193 if (tfModel.dims[2] != 0)
194 imageHeight = tfModel.dims[2];
195 if (tfModel.dims[3] != 0)
196 imageNbChannels = tfModel.dims[3];
agsantos5aa39652020-08-11 18:18:04 -0400197}
198#endif
199
200int
201PluginInference::getImageWidth() const
202{
agsantosac1940d2020-09-17 10:18:40 -0400203 return imageWidth;
agsantos5aa39652020-08-11 18:18:04 -0400204}
205
206int
207PluginInference::getImageHeight() const
208{
agsantosac1940d2020-09-17 10:18:40 -0400209 return imageHeight;
agsantos5aa39652020-08-11 18:18:04 -0400210}
211
212int
213PluginInference::getImageNbChannels() const
214{
agsantosac1940d2020-09-17 10:18:40 -0400215 return imageNbChannels;
agsantos5aa39652020-08-11 18:18:04 -0400216}
217} // namespace jami