GreenScreen: run segmentation with ONNX Runtime instead of TensorFlow

Change-Id: I179896b2414f35f0efc543738e7ecc943d5deb1d
diff --git a/GreenScreen/pluginProcessor.cpp b/GreenScreen/pluginProcessor.cpp
index 3b38830..37db9db 100644
--- a/GreenScreen/pluginProcessor.cpp
+++ b/GreenScreen/pluginProcessor.cpp
@@ -33,20 +33,29 @@
 extern "C" {
 #include <libavutil/display.h>
 }
-
 const char sep = separator();
 
 const std::string TAG = "FORESEG";
 
-PluginParameters* mPluginParameters = getGlobalPluginParameters();
-
 namespace jami {
 
-PluginProcessor::PluginProcessor(const std::string& dataPath)
-    : pluginInference {TFModel {dataPath + sep + "models" + sep + mPluginParameters->model}}
+// Loads the segmentation model from dataPath/model/<model> (joined with `sep`)
+// and the initial background image. `acc` asks initModel() for a hardware
+// execution provider (CUDA on NVIDIA builds, NNAPI on Android builds).
+PluginProcessor::PluginProcessor(const std::string& dataPath, const std::string& model, const std::string& backgroundImage, bool acc)
 {
-    initModel();
-    setBackgroundImage(mPluginParameters->image);
+    // initModel() reads activateAcc_, so it must be set first.
+    activateAcc_ = acc;
+    initModel(dataPath + sep + "model" + sep + model);
+    setBackgroundImage(backgroundImage);
+}
+
+// Releases the ONNX Runtime session allocated by initModel().
+PluginProcessor::~PluginProcessor()
+{
+    Plog::log(Plog::LogPriority::INFO, TAG, "~pluginprocessor");
+    // session_ is a raw owning pointer (TODO: std::unique_ptr); deleting a null session_ is a no-op.
+    delete session_;
 }
 
 void
@@ -75,37 +79,66 @@
 }
 
 void
-PluginProcessor::initModel()
+PluginProcessor::initModel(const std::string& modelPath)
 {
     try {
-        pluginInference.init();
+        // CreateTensor over a caller-supplied buffer does not copy it: input_tensor_ and
+        // output_tensor_ alias input_image_ and results_, which therefore must not be
+        // resized or reallocated while the tensors are in use.
+        auto allocator_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
+        input_tensor_ = Ort::Value::CreateTensor<float>(allocator_info, input_image_.data(), input_image_.size(), input_shape_.data(), input_shape_.size());
+        output_tensor_ = Ort::Value::CreateTensor<float>(allocator_info, results_.data(), results_.size(), output_shape_.data(), output_shape_.size());
+        sessOpt_ = Ort::SessionOptions();
+
+        // Optional hardware execution providers, only compiled in for the matching builds.
+#ifdef NVIDIA
+        if (activateAcc_)
+            Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(sessOpt_, 0));
+#endif
+#ifdef ANDROID
+        if (activateAcc_)
+            Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Nnapi(sessOpt_, 0));
+#endif
+
+        sessOpt_.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
+#ifdef WIN32
+        // ORT takes a wide-char path on Windows. NOTE(review): this byte-wise widening
+        // is only correct for ASCII paths; non-ASCII paths need a UTF-8 -> UTF-16 conversion.
+        std::wstring wsTmp(modelPath.begin(), modelPath.end());
+        session_ = new Ort::Session(env, wsTmp.c_str(), sessOpt_);
+#else
+        session_ = new Ort::Session(env, modelPath.c_str(), sessOpt_);
+#endif
+        // Only reached when every step above succeeded.
+        isAllocated_ = true;
     } catch (std::exception& e) {
         Plog::log(Plog::LogPriority::ERR, TAG, e.what());
     }
     std::ostringstream oss;
-    oss << "Model is allocated " << pluginInference.isAllocated();
+    oss << "Model is allocated " << isAllocated_;
     Plog::log(Plog::LogPriority::INFO, TAG, oss.str());
 }
 
-#ifdef TFLITE
+bool
+PluginProcessor::isAllocated()
+{
+    return isAllocated_;
+}
+
 void
 PluginProcessor::feedInput(const cv::Mat& frame)
 {
-    auto pair = pluginInference.getInput();
-    uint8_t* inputPointer = pair.first;
-
-    cv::Mat temp(frame.rows, frame.cols, CV_8UC3, inputPointer);
-    frame.convertTo(temp, CV_8UC3);
-
-    inputPointer = nullptr;
+    // temp is a header over input_image_, i.e. the memory input_tensor_ reads from.
+    // A frame with more pixels would be written past the end of that buffer, and a
+    // frame without 3 channels would make convertTo reallocate temp instead of
+    // filling input_image_, so reject both.
+    if (frame.channels() != 3 || frame.total() * 3 > input_image_.size()) {
+        Plog::log(Plog::LogPriority::ERR, TAG, "feedInput: frame does not fit the model input");
+        return;
+    }
+    cv::Mat temp(frame.rows, frame.cols, CV_32FC3, input_image_.data());
+    frame.convertTo(temp, CV_32FC3);
 }
-#else
-void
-PluginProcessor::feedInput(const cv::Mat& frame)
-{
-    pluginInference.ReadTensorFromMat(frame);
-}
-#endif // TFLITE
 
 int
 PluginProcessor::getBackgroundRotation()
@@ -127,11 +145,16 @@
 {
     if (count == 0) {
         // Run the graph
-        pluginInference.runGraph();
-        auto predictions = pluginInference.masksPredictions();
-
-        // Save the predictions
-        computedMask = predictions;
+        // initModel() may have failed, leaving no usable session_.
+        if (!isAllocated_)
+            return;
+        try {
+            session_->Run(Ort::RunOptions{nullptr}, input_names, &input_tensor_, 1, output_names, &output_tensor_, 1);
+            // output_tensor_ writes into results_; keep a copy as the current mask.
+            computedMask.assign(results_.begin(), results_.end());
+        } catch (const std::exception& e) {
+            Plog::log(Plog::LogPriority::ERR, TAG, e.what());
+        }
     }
 }
 
@@ -188,30 +203,12 @@
         return;
     }
 
-    int maskSize = static_cast<int>(std::sqrt(computedMask.size()));
-    cv::Mat maskImg(maskSize, maskSize, CV_32FC1, computedMask.data());
-    cv::Mat* applyMask = &frameReduced;
-    cv::Mat output;
-
     if (count == 0) {
+        // NOTE(review): assumes the flat prediction is square; maskImg views it in place (no copy).
+        int maskSize = static_cast<int>(std::sqrt(static_cast<double>(computedMask.size())));
+        cv::Mat maskImg(cv::Size(maskSize, maskSize), CV_32FC1, computedMask.data());
+        cv::Mat* applyMask = &frameReduced;
         rotateFrame(-angle, maskImg);
-#ifdef TFLITE
-        for (int i = 0; i < maskImg.cols; i++) {
-            for (int j = 0; j < maskImg.rows; j++) {
-                if (maskImg.at<float>(j, i) == 15)
-                    maskImg.at<float>(j, i) = 1.;
-                else
-                    maskImg.at<float>(j, i) = smoothFactors[0] * previousMasks[0].at<float>(j, i)
-                                              + smoothFactors[1] * previousMasks[1].at<float>(j, i);
-            }
-        }
-        cv::morphologyEx(maskImg,
-                         maskImg,
-                         cv::MORPH_CLOSE,
-                         cv::getStructuringElement(cv::MORPH_ELLIPSE, kSize),
-                         cv::Point(-1, -1),
-                         4);
-#else
         cv::resize(maskImg, maskImg, cv::Size(frameReduced.cols, frameReduced.rows));
 
         double m, M;
@@ -238,52 +235,40 @@
                 }
             }
         }
-#endif
         if (cv::countNonZero(maskImg) != 0) {
-#ifdef TFLITE
-            cv::Mat tfMask;
-            tfMask = maskImg.clone();
-            tfMask *= 255.;
-            tfMask.convertTo(tfMask, CV_8UC1);
-            cv::threshold(tfMask, tfMask, 127, 255, cv::THRESH_BINARY);
-            if (cv::countNonZero(tfMask) != 0) {
-#endif
-                cv::Mat dilate;
-                cv::dilate(maskImg,
-                           dilate,
-                           cv::getStructuringElement(cv::MORPH_ELLIPSE, kSize),
-                           cv::Point(-1, -1),
-                           2);
-                cv::erode(maskImg,
-                          maskImg,
-                          cv::getStructuringElement(cv::MORPH_ELLIPSE, kSize),
-                          cv::Point(-1, -1),
-                          2);
-                for (int i = 0; i < maskImg.cols; i++) {
-                    for (int j = 0; j < maskImg.rows; j++) {
-                        if (dilate.at<float>(j, i) != maskImg.at<float>(j, i))
-                            maskImg.at<float>(j, i) = grabcutClass;
-                    }
+            cv::Mat dilate;
+            cv::dilate(maskImg,
+                       dilate,
+                       cv::getStructuringElement(cv::MORPH_ELLIPSE, kSize),
+                       cv::Point(-1, -1),
+                       2);
+            cv::erode(maskImg,
+                      maskImg,
+                      cv::getStructuringElement(cv::MORPH_ELLIPSE, kSize),
+                      cv::Point(-1, -1),
+                      2);
+            for (int i = 0; i < maskImg.cols; i++) {
+                for (int j = 0; j < maskImg.rows; j++) {
+                    if (dilate.at<float>(j, i) != maskImg.at<float>(j, i))
+                        maskImg.at<float>(j, i) = grabcutClass;
                 }
-                maskImg.convertTo(maskImg, CV_8UC1);
-                applyMask->convertTo(*applyMask, CV_8UC1);
-                cv::Rect rect(1, 1, maskImg.rows, maskImg.cols);
-                cv::grabCut(*applyMask,
-                            maskImg,
-                            rect,
-                            bgdModel,
-                            fgdModel,
-                            grabCutIterations,
-                            grabCutMode);
-
-                grabCutMode = cv::GC_EVAL;
-                grabCutIterations = 1;
-
-                maskImg = maskImg & 1;
-#ifdef TFLITE
-                cv::bitwise_and(maskImg, tfMask, maskImg);
             }
-#endif
+            maskImg.convertTo(maskImg, CV_8UC1);
+            applyMask->convertTo(*applyMask, CV_8UC1);
+            // cv::Rect is (x, y, width, height); grabCut only consults it in GC_INIT_WITH_RECT mode.
+            cv::Rect rect(1, 1, maskImg.cols, maskImg.rows);
+            cv::grabCut(*applyMask,
+                        maskImg,
+                        rect,
+                        bgdModel,
+                        fgdModel,
+                        grabCutIterations,
+                        grabCutMode);
+
+            grabCutMode = cv::GC_EVAL;
+            grabCutIterations = 1;
+
+            maskImg = maskImg & 1;
             maskImg.convertTo(maskImg, CV_32FC1);
             maskImg *= 255.;
             GaussianBlur(maskImg, maskImg, cv::Size(7, 7), 0); // float mask from 0 to 255.
@@ -309,13 +293,11 @@
     cv::merge(channels, roiMaskImg);
     cv::merge(channelsComplementary, roiMaskImgComplementary);
 
-    int origType = frameReduced.type();
-    int roiMaskType = roiMaskImg.type();
-
-    frameReduced.convertTo(output, roiMaskType);
+    cv::Mat output;
+    frameReduced.convertTo(output, roiMaskImg.type());
     output = output.mul(roiMaskImg);
     output += backgroundImage.mul(roiMaskImgComplementary);
-    output.convertTo(output, origType);
+    output.convertTo(output, frameReduced.type());
 
     cv::resize(output, output, cv::Size(frame.cols, frame.rows));