agsantos | 5aa3965 | 2020-08-11 18:18:04 -0400 | [diff] [blame] | 1 | /** |
Sébastien Blin | cb783e3 | 2021-02-12 11:34:10 -0500 | [diff] [blame] | 2 | * Copyright (C) 2020-2021 Savoir-faire Linux Inc. |
agsantos | 5aa3965 | 2020-08-11 18:18:04 -0400 | [diff] [blame] | 3 | * |
| 4 | * Author: Aline Gondim Santos <aline.gondimsantos@savoirfairelinux.com> |
| 5 | * |
| 6 | * This program is free software; you can redistribute it and/or modify |
| 7 | * it under the terms of the GNU General Public License as published by |
| 8 | * the Free Software Foundation; either version 3 of the License, or |
| 9 | * (at your option) any later version. |
| 10 | * |
| 11 | * This program is distributed in the hope that it will be useful, |
| 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of |
| 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
| 14 | * GNU General Public License for more details. |
| 15 | * |
| 16 | * You should have received a copy of the GNU General Public License |
| 17 | * along with this program; if not, write to the Free Software |
| 18 | * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. |
| 19 | */ |
| 20 | |
| 21 | #pragma once |
| 22 | |
| 23 | // Library headers |
| 24 | #include "TFModels.h" |
| 25 | |
| 26 | // STL |
| 27 | #include <memory> |
| 28 | #include <string> |
| 29 | #include <vector> |
| 30 | |
| 31 | #ifdef TFLITE |
| 32 | #include <tensorflow/lite/interpreter.h> |
| 33 | #include <tensorflow/lite/delegates/nnapi/nnapi_delegate.h> |
| 34 | |
// Forward declarations of the TensorFlow Lite types used by this header.
// FlatBufferModel and Interpreter are only held through std::unique_ptr in
// TensorflowInference, so their names are all that is required at this point.
namespace tflite {
class FlatBufferModel;
class Interpreter;
class StatefulNnApiDelegate;
} // namespace tflite
| 40 | |
| 41 | #else |
| 42 | #ifdef WIN32 |
| 43 | #define NOMINMAX |
| 44 | #undef min |
| 45 | #undef max |
| 46 | #endif |
| 47 | |
| 48 | #include <tensorflow/core/lib/core/status.h> |
| 49 | #include <tensorflow/core/public/session.h> |
| 50 | #include <tensorflow/core/framework/tensor.h> |
| 51 | #include <tensorflow/core/framework/types.pb.h> |
| 52 | #include <tensorflow/core/platform/init_main.h> |
| 53 | #include <tensorflow/core/protobuf/config.pb.h> |
| 54 | |
// Forward declarations of the full-TensorFlow types used when TFLITE is not
// defined. DataType is declared as an opaque enum with a fixed int underlying
// type, so it is a complete type here even without its enumerator list.
namespace tensorflow {
class Tensor;
class Status;
class GraphDef;
class Session;
struct SessionOptions;
class TensorShape;
class Env;
enum DataType : int;
} // namespace tensorflow
agsantos | 5aa3965 | 2020-08-11 18:18:04 -0400 | [diff] [blame] | 65 | |
| 66 | #endif |
| 67 | |
agsantos | ac1940d | 2020-09-17 10:18:40 -0400 | [diff] [blame] | 68 | namespace jami { |
| 69 | class TensorflowInference |
agsantos | 5aa3965 | 2020-08-11 18:18:04 -0400 | [diff] [blame] | 70 | { |
agsantos | 5aa3965 | 2020-08-11 18:18:04 -0400 | [diff] [blame] | 71 | public: |
| 72 | /** |
| 73 | * @brief TensorflowInference |
| 74 | * Takes a supervised model where the model and labels files are defined |
| 75 | * @param model |
| 76 | */ |
| 77 | TensorflowInference(TFModel model); |
| 78 | ~TensorflowInference(); |
| 79 | |
| 80 | #ifdef TFLITE |
| 81 | /** |
| 82 | * @brief loadModel |
| 83 | * Load the model from the file described in the Supervised Model |
| 84 | */ |
| 85 | void loadModel(); |
| 86 | void buildInterpreter(); |
| 87 | void setInterpreterSettings(); |
| 88 | |
| 89 | /** |
| 90 | * @brief allocateTensors |
| 91 | * Tries to allocate space for the tensors |
| 92 | * In case of success isAllocated() should return true |
| 93 | */ |
| 94 | void allocateTensors(); |
| 95 | |
| 96 | // Debug methods |
| 97 | void describeModelTensors() const; |
| 98 | void describeTensor(std::string prefix, int index) const; |
| 99 | |
| 100 | #else |
| 101 | void LoadGraph(); |
| 102 | tensorflow::Tensor imageTensor; |
| 103 | |
agsantos | ac1940d | 2020-09-17 10:18:40 -0400 | [diff] [blame] | 104 | #endif // TFLITE |
agsantos | 5aa3965 | 2020-08-11 18:18:04 -0400 | [diff] [blame] | 105 | |
| 106 | /** |
| 107 | * @brief runGraph |
| 108 | * runs the underlaying graph model.numberOfRuns times |
| 109 | * Where numberOfRuns is defined in the model |
| 110 | */ |
| 111 | void runGraph(); |
| 112 | |
| 113 | /** |
| 114 | * @brief init |
| 115 | * Inits the model, interpreter, allocates tensors and load the labels |
| 116 | */ |
| 117 | void init(); |
| 118 | // Getters |
| 119 | bool isAllocated() const; |
| 120 | |
| 121 | protected: |
| 122 | #ifdef TFLITE |
| 123 | /** |
| 124 | * @brief getTensorDimensions |
| 125 | * Utility method to get Tensorflow Tensor dimensions |
| 126 | * Given the index of the tensor, the function gives back a vector |
| 127 | * Where each element is the dimension of the vector-space (finite dimension) |
| 128 | * Thus, vector.size() is the number of vector-space used by the tensor |
| 129 | * @param index |
| 130 | * @return |
| 131 | */ |
| 132 | std::vector<int> getTensorDimensions(int index) const; |
| 133 | |
| 134 | // Tensorflow model and interpreter |
| 135 | std::unique_ptr<tflite::FlatBufferModel> flatbufferModel; |
| 136 | std::unique_ptr<tflite::Interpreter> interpreter; |
| 137 | #else |
| 138 | std::unique_ptr<tensorflow::Session> session; |
| 139 | std::vector<tensorflow::Tensor> outputs; |
| 140 | #endif |
| 141 | TFModel tfModel; |
| 142 | std::vector<std::string> labels; |
| 143 | |
| 144 | /** |
| 145 | * @brief nbLabels |
| 146 | * The real number of labels may not match the labels.size() because of padding |
| 147 | */ |
| 148 | size_t nbLabels; |
| 149 | |
agsantos | 9dcf430 | 2020-09-01 18:21:48 -0400 | [diff] [blame] | 150 | bool allocated_ = false; |
agsantos | 5aa3965 | 2020-08-11 18:18:04 -0400 | [diff] [blame] | 151 | }; |
agsantos | ac1940d | 2020-09-17 10:18:40 -0400 | [diff] [blame] | 152 | } // namespace jami |