blob: 37436afb4fb8ff9c2bc3342cf772d27c4ed53e14 [file] [log] [blame]
agsantos5aa39652020-08-11 18:18:04 -04001/**
2 * Copyright (C) 2020 Savoir-faire Linux Inc.
3 *
4 * Author: Aline Gondim Santos <aline.gondimsantos@savoirfairelinux.com>
5 *
6 * This program is free software; you can redistribute it and/or modify
7 * it under the terms of the GNU General Public License as published by
8 * the Free Software Foundation; either version 3 of the License, or
9 * (at your option) any later version.
10 *
11 * This program is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 * GNU General Public License for more details.
15 *
16 * You should have received a copy of the GNU General Public License
17 * along with this program; if not, write to the Free Software
18 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
19 */
20
21#pragma once
22
#include "TFInference.h"

// OpenCV headers
#include <opencv2/core.hpp>
// STL
#include <array>
#include <iostream>
#include <tuple>
#include <utility>
#include <vector>
32
33namespace jami {
34
35class PluginInference : public TensorflowInference {
36public:
37 /**
38 * @brief PluginInference
39 * Is a type of supervised learning where we detect objects in images
40 * Draw a bounding boxes around them
41 * @param model
42 */
43 PluginInference(TFModel model);
44 ~PluginInference();
45
46#ifdef TFLITE
47 /**
48 * @brief getInput
49 * Returns the input where to fill the data
50 * Use this method if you know what you are doing, all the necessary checks
51 * on dimensions must be done on your part
52 * @return std::tuple<uint8_t *, std::vector<int>>
53 * The first element in the tuple is the pointer to the storage location
54 * The second element is a dimensions vector that will helps you make
55 * The necessary checks to make your data size match the input one
56 */
57 std::pair<uint8_t*, std::vector<int>> getInput();
58
59#else
60 void ReadTensorFromMat(const cv::Mat& image);
61
62#endif //TFLITE
63
64 std::vector<float> masksPredictions() const;
65
66
67 /**
68 * @brief setExpectedImageDimensions
69 * Sets imageWidth and imageHeight from the sources
70 */
71 void setExpectedImageDimensions();
72
73 // Getters
74 int getImageWidth() const;
75 int getImageHeight() const;
76 int getImageNbChannels() const;
77
78
79private:
80 int imageWidth = 0;
81 int imageHeight = 0;
82 int imageNbChannels = 0;
83};
84} // namespace jami