blob: 3589b129c28c0f0264c8600c029a7768d9b407f2 [file] [log] [blame]
/**
 *  Copyright (C) 2020-2021 Savoir-faire Linux Inc.
 *
 *  Author: Aline Gondim Santos <aline.gondimsantos@savoirfairelinux.com>
 *
 *  This program is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation; either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program; if not, write to the Free Software
 *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 */
20
21#pragma once
22
23// Library headers
24#include "TFModels.h"
25
26// STL
27#include <memory>
28#include <string>
29#include <vector>
30
31#ifdef TFLITE
32#include <tensorflow/lite/interpreter.h>
33#include <tensorflow/lite/delegates/nnapi/nnapi_delegate.h>
34
35namespace tflite {
36class FlatBufferModel;
37class Interpreter;
38class StatefulNnApiDelegate;
39} // namespace tflite
40
41#else
42#ifdef WIN32
43#define NOMINMAX
44#undef min
45#undef max
46#endif
47
48#include <tensorflow/core/lib/core/status.h>
49#include <tensorflow/core/public/session.h>
50#include <tensorflow/core/framework/tensor.h>
51#include <tensorflow/core/framework/types.pb.h>
52#include <tensorflow/core/platform/init_main.h>
53#include <tensorflow/core/protobuf/config.pb.h>
54
55namespace tensorflow {
56class Tensor;
57class Status;
58class GraphDef;
59class Session;
60struct SessionOptions;
61class TensorShape;
62class Env;
agsantosac1940d2020-09-17 10:18:40 -040063enum DataType : int;
64} // namespace tensorflow
agsantos5aa39652020-08-11 18:18:04 -040065
66#endif
67
agsantosac1940d2020-09-17 10:18:40 -040068namespace jami {
69class TensorflowInference
agsantos5aa39652020-08-11 18:18:04 -040070{
agsantos5aa39652020-08-11 18:18:04 -040071public:
72 /**
73 * @brief TensorflowInference
74 * Takes a supervised model where the model and labels files are defined
75 * @param model
76 */
77 TensorflowInference(TFModel model);
78 ~TensorflowInference();
79
80#ifdef TFLITE
81 /**
82 * @brief loadModel
83 * Load the model from the file described in the Supervised Model
84 */
85 void loadModel();
86 void buildInterpreter();
87 void setInterpreterSettings();
88
89 /**
90 * @brief allocateTensors
91 * Tries to allocate space for the tensors
92 * In case of success isAllocated() should return true
93 */
94 void allocateTensors();
95
96 // Debug methods
97 void describeModelTensors() const;
98 void describeTensor(std::string prefix, int index) const;
99
100#else
101 void LoadGraph();
102 tensorflow::Tensor imageTensor;
103
agsantosac1940d2020-09-17 10:18:40 -0400104#endif // TFLITE
agsantos5aa39652020-08-11 18:18:04 -0400105
106 /**
107 * @brief runGraph
108 * runs the underlaying graph model.numberOfRuns times
109 * Where numberOfRuns is defined in the model
110 */
111 void runGraph();
112
113 /**
114 * @brief init
115 * Inits the model, interpreter, allocates tensors and load the labels
116 */
117 void init();
118 // Getters
119 bool isAllocated() const;
120
121protected:
122#ifdef TFLITE
123 /**
124 * @brief getTensorDimensions
125 * Utility method to get Tensorflow Tensor dimensions
126 * Given the index of the tensor, the function gives back a vector
127 * Where each element is the dimension of the vector-space (finite dimension)
128 * Thus, vector.size() is the number of vector-space used by the tensor
129 * @param index
130 * @return
131 */
132 std::vector<int> getTensorDimensions(int index) const;
133
134 // Tensorflow model and interpreter
135 std::unique_ptr<tflite::FlatBufferModel> flatbufferModel;
136 std::unique_ptr<tflite::Interpreter> interpreter;
137#else
138 std::unique_ptr<tensorflow::Session> session;
139 std::vector<tensorflow::Tensor> outputs;
140#endif
141 TFModel tfModel;
142 std::vector<std::string> labels;
143
144 /**
145 * @brief nbLabels
146 * The real number of labels may not match the labels.size() because of padding
147 */
148 size_t nbLabels;
149
agsantos9dcf4302020-09-01 18:21:48 -0400150 bool allocated_ = false;
agsantos5aa39652020-08-11 18:18:04 -0400151};
agsantosac1940d2020-09-17 10:18:40 -0400152} // namespace jami