/**
* Copyright (C) 2022 Savoir-faire Linux Inc.
*
* Author: Aline Gondim Santos <aline.gondimsantos@savoirfairelinux.com>
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301
* USA.
*/
#include "ModelProcessor.h"
#include <pluglog.h>
#include <common.h>
#include <limits.h>
const char sep = separator();
const std::string TAG = "Transcript";
namespace jami {
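// Loads the tokenizer vocabulary and the encoder / decoder / log-softmax
// graphs shipped with the plugin (.ort files on Android, .onnx elsewhere).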
ModelProcessor::ModelProcessor(const std::string& path, bool acc)
{
loadTokens(path + "/assets/tokenizer.bin", vocab_);
#ifdef __ANDROID__
initModels(path + "/assets/mModelEncoder.ort", path + "/assets/mModelDecoder.ort", path + "/assets/mLogSoftMax.ort", acc);
#else
initModels(path + "/assets/mModelEncoder.onnx", path + "/assets/mModelDecoder.onnx", path + "/assets/mLogSoftMax.onnx", acc);
#endif
}
ModelProcessor::~ModelProcessor()
{
endModels();
Plog::log(Plog::LogPriority::INFO, TAG, "~ModelProcessor");
}
void
ModelProcessor::endModels()
{
if (encoderSession_) {
delete encoderSession_;
encoderSession_ = nullptr;
}
if (decoderSession_) {
delete decoderSession_;
decoderSession_ = nullptr;
}
if (logSoftMaxSession_) {
delete logSoftMaxSession_;
logSoftMaxSession_ = nullptr;
}
if (env_)
env_.release();
env_ = NULL;
}
void
ModelProcessor::initModels(const std::string& encoderModelPath, const std::string& decoderModelPath, const std::string& logSoftMaxModelPath, bool activateAcc)
{
try {
sessOpt_ = Ort::SessionOptions();
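// When acceleration is requested, register the matching execution provider
// (CUDA on NVIDIA builds, NNAPI on Android) ahead of the default CPU provider.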
#ifdef NVIDIA
if (activateAcc)
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(sessOpt_, 0));
#endif
#ifdef __ANDROID__
if (activateAcc)
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Nnapi(sessOpt_, 0));
#endif
sessOpt_.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
#ifdef WIN32
encoderSession_ = new Ort::Session(env_, string_utils::to_wstring(encoderModelPath).c_str(), sessOpt_);
decoderSession_ = new Ort::Session(env_, string_utils::to_wstring(decoderModelPath).c_str(), sessOpt_);
logSoftMaxSession_ = new Ort::Session(env_, string_utils::to_wstring(logSoftMaxModelPath).c_str(), sessOpt_);
#else
encoderSession_ = new Ort::Session(env_, encoderModelPath.c_str(), sessOpt_);
decoderSession_ = new Ort::Session(env_, decoderModelPath.c_str(), sessOpt_);
logSoftMaxSession_ = new Ort::Session(env_, logSoftMaxModelPath.c_str(), sessOpt_);
#endif
isAllocated_ = true;
Plog::log(Plog::LogPriority::INFO, TAG, "Model is allocated");
} catch (const std::exception& e) {
Plog::log(Plog::LogPriority::ERR, TAG, e.what());
}
}
/* from whisper.cpp */
// the most basic sampling scheme - select the top token
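// Returns the sampled token and its probability (id, p) together with the most
// likely timestamp token and the timestamp probability statistics (tid, pt, ptsum).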
whisperTokenData
ModelProcessor::whisper_sample_best(const float * probs)
{
whisperTokenData result = {
0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
};
int n_logits = vocab_.id_to_token.size();
std::vector<std::pair<double, int64_t>> probs_id;
probs_id.reserve(n_logits);
for (int i = 0; i < n_logits; i++) {
probs_id.emplace_back(probs[i], i);
}
{
double sum_ts = 0.0;
double max_ts = -1.0;
double max_tx = -1.0;
for (int i = 0; i < vocab_.token_beg; i++) {
max_tx = std::max(max_tx, probs_id[i].first);
}
for (int i = vocab_.token_beg; i < n_logits; i++) {
sum_ts += probs_id[i].first;
if (probs_id[i].first > max_ts) {
max_ts = probs_id[i].first;
result.tid = probs_id[i].second;
}
}
// if the probability sum of all timestamp tokens is higher than the max probability of the text tokens - sample a
// timestamp token
if (sum_ts > max_tx) {
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438
for (int i = 0; i < vocab_.token_beg; i++) {
probs_id[i].first = -INT_MAX;
}
}
result.pt = max_ts/(sum_ts + 1e-10);
result.ptsum = sum_ts;
}
// find the top K tokens
const int top_k = 4;
std::partial_sort(
probs_id.begin(),
probs_id.begin() + top_k, probs_id.end(),
[](const std::pair<double, int64_t> & a, const std::pair<double, int64_t> & b) {
return a.first > b.first;
});
probs_id.resize(top_k);
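// Never emit the special tokens token_sot, token_solm or token_beg; fall back
// to the next entry of the top-k list instead.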
int res = 0;
while ((probs_id[res].second == vocab_.token_sot ||
probs_id[res].second == vocab_.token_solm ||
probs_id[res].second == vocab_.token_beg) &&
res < (int) probs_id.size() - 1) {
res++;
}
result.id = probs_id[res].second;
result.p = probs_id[res].first;
return result;
}
void
ModelProcessor::filterLogits(std::vector<float>& logits, int offset)
{
// Remove all no speech tokens
for (const auto idx : vocab_.noSpeechTokens) {
logits[idx] = (float)-INT_MAX;
}
}
void
ModelProcessor::filterLanguageLogits(std::vector<float>& logits)
{
// Leave only the language tokens
for (size_t i = 0; i < logits.size(); i++) {
if (vocab_.languageId2Tokens[i].empty())
logits[i] = (float)(-INT_MAX);
}
}
whisperTokenData
ModelProcessor::getToken(std::vector<float>& logits)
{
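// Normalize the raw logits with the standalone log-softmax graph and sample
// the best token from the probabilities it returns (second graph output).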
std::vector<Ort::Value> logSoftMaxInputs;
logSoftMaxInputs.emplace_back(Ort::Value::CreateTensor<float>(allocatorInfo_,
logits.data(),
logits.size(),
logitsShape_.data(),
logitsShape_.size()));
auto softmaxOutputs = logSoftMaxSession_->Run(Ort::RunOptions {nullptr},
logSoftMaxInputNames.data(),
logSoftMaxInputs.data(),
logSoftMaxInputNames.size(),
logSoftMaxOutputNames.data(),
logSoftMaxOutputNames.size());
float* probs = softmaxOutputs[1].GetTensorMutableData<float>();
return whisper_sample_best(probs);
}
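// Encodes the mel spectrogram, then decodes greedily for up to sampleLen steps
// and returns the detokenized transcript (an empty string on failure).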
std::string
ModelProcessor::feedInput(std::vector<float>& melInput, const std::string& preferenceLanguage)
{
std::lock_guard<std::mutex> l(mtx_);
try {
Ort::Value melInputTensor = Ort::Value::CreateTensor<float>(allocatorInfo_,
melInput.data(),
melInput.size(),
melInputShape_.data(),
melInputShape_.size());
audioFeaturesTensor_ = Ort::Value::CreateTensor<float>(allocatorInfo_,
audioFeatures_.data(),
audioFeatures_.size(),
audioFeaturesShape_.data(),
audioFeaturesShape_.size());
// Run the encoder graph
encoderSession_->Run(Ort::RunOptions {nullptr},
encoderInputNames,
&melInputTensor,
1,
encoderOutputNames,
&audioFeaturesTensor_,
1);
} catch (const Ort::Exception& e) {
Plog::log(Plog::LogPriority::ERR, TAG, e.what());
return "";
} catch (...) { return ""; }
try {
auto isMultilingual = vocab_.is_multilingual();
std::vector<int64_t> currentTokens {};
currentTokens.emplace_back(vocab_.token_sot);
std::array<int64_t, 1> offsetShape {1};
if (isMultilingual) {
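// Pick the language token: when the preference is "auto" or unknown, detect it
// with a single decoder pass over the start-of-transcript token; otherwise use
// the preferred language directly. The transcribe task token is appended after.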
if (preferenceLanguage == "auto"
|| vocab_.languageTokens2Id.find(preferenceLanguage) == vocab_.languageTokens2Id.end()) {
std::vector<float> currentKVCache(MODELKVCACHESHAPE * 1 * currentTokens.size() * MODELFEATURESHAPE, 0.0f);
std::array<int64_t, 2> tokenShape {1, 1};
int64_t offset = 0;
std::array<int64_t, 4> kvCacheShape { MODELKVCACHESHAPE, 1, 1, MODELFEATURESHAPE };
std::vector<int64_t> token = { currentTokens.back() };
// Run the decoder graph
std::vector<Ort::Value> inputsVector; // {audioFeaturesTensor_, tokensTensor_, kvCacheTensor_, offsetTensor_};
inputsVector.emplace_back(Ort::Value::CreateTensor<float>(allocatorInfo_,
audioFeatures_.data(),
audioFeatures_.size(),
audioFeaturesShape_.data(),
audioFeaturesShape_.size()));
inputsVector.emplace_back(Ort::Value::CreateTensor<int64_t>(allocatorInfo_,
token.data(),
token.size(),
tokenShape.data(),
tokenShape.size()));
inputsVector.emplace_back(Ort::Value::CreateTensor<float>(allocatorInfo_,
currentKVCache.data(),
currentKVCache.size(),
kvCacheShape.data(),
kvCacheShape.size()));
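// The decoding offset is fed as a scalar tensor (shape rank 0).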
inputsVector.emplace_back(Ort::Value::CreateTensor<int64_t>(allocatorInfo_,
&offset,
1,
offsetShape.data(),
0));
auto outputs = decoderSession_->Run(Ort::RunOptions {nullptr},
decoderInputNames.data(),
inputsVector.data(),
decoderInputNames.size(),
decoderOutputNames.data(),
decoderOutputNames.size());
auto logitsTensorInfo = outputs[0].GetTensorTypeAndShapeInfo();
auto logitsData = outputs[0].GetTensorMutableData<float>();
{
std::vector<float> logits(logitsData, logitsData + logitsTensorInfo.GetElementCount());
filterLanguageLogits(logits);
auto it = std::max_element(logits.begin(), logits.end());
currentTokens.emplace_back(std::distance(logits.begin(), it));
}
} else
currentTokens.emplace_back(vocab_.languageTokens2Id[preferenceLanguage]);
currentTokens.emplace_back(vocab_.token_transcribe);
}
std::vector<float> currentKVCache(MODELKVCACHESHAPE * 1 * currentTokens.size() * MODELFEATURESHAPE, 0.0f);
std::array<int64_t, 2> tokenShape {1, static_cast<int64_t>(currentTokens.size())};
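// Greedy decoding loop: the whole prompt is fed on the first pass, then one
// token per pass while the key/value cache carries the previous context.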
for (auto i = 0; i < sampleLen; i++) {
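// For multilingual models the first pass consumes the SOT, language and task
// tokens, so later passes resume at position i + 2.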
int64_t offset = isMultilingual ? ( i == 0 ? 0 : i + 2 ) : i;
std::array<int64_t, 4> kvCacheShape { MODELKVCACHESHAPE, 1, static_cast<int64_t>(currentTokens.size()), MODELFEATURESHAPE };
std::vector<int64_t> token = { currentTokens.back() };
if (i == 0) {
token = currentTokens;
tokenShape[1] = currentTokens.size();
} else {
tokenShape[1] = 1;
}
// Run the decoder graph
std::vector<Ort::Value> inputsVector; // {audioFeaturesTensor_, tokensTensor_, kvCacheTensor_, offsetTensor_};
inputsVector.emplace_back(Ort::Value::CreateTensor<float>(allocatorInfo_,
audioFeatures_.data(),
audioFeatures_.size(),
audioFeaturesShape_.data(),
audioFeaturesShape_.size()));
inputsVector.emplace_back(Ort::Value::CreateTensor<int64_t>(allocatorInfo_,
token.data(),
token.size(),
tokenShape.data(),
tokenShape.size()));
inputsVector.emplace_back(Ort::Value::CreateTensor<float>(allocatorInfo_,
currentKVCache.data(),
currentKVCache.size(),
kvCacheShape.data(),
kvCacheShape.size()));
inputsVector.emplace_back(Ort::Value::CreateTensor<int64_t>(allocatorInfo_,
&offset,
1,
offsetShape.data(),
0));
auto outputs = decoderSession_->Run(Ort::RunOptions {nullptr},
decoderInputNames.data(),
inputsVector.data(),
decoderInputNames.size(),
decoderOutputNames.data(),
decoderOutputNames.size());
auto logitsTensorInfo = outputs[0].GetTensorTypeAndShapeInfo();
auto logitsData = outputs[0].GetTensorMutableData<float>();
{
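// On the first multilingual pass the decoder emits one row of logits per
// prompt token; keep only the row of the last (task) token before sampling.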
std::vector<float> logits(logitsData, logitsData + logitsTensorInfo.GetElementCount());
if (isMultilingual && logits.size() > static_cast<size_t>(vocab_.n_vocab)) {
std::vector<float> lastLogits(logits.begin() + 2 * vocab_.n_vocab, logits.end());
std::swap(lastLogits, logits);
}
filterLogits(logits, offset);
auto tokenData = getToken(logits);
currentTokens.emplace_back(tokenData.id);
}
// Grab kvCache for next iteration
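// The decoder returns a cache entry per token fed so far; copy each layer's
// entries and append a zeroed row as the slot for the token of the next pass.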
auto kvCacheTensorInfo = outputs[1].GetTensorTypeAndShapeInfo();
auto nextKVCacheData = outputs[1].GetTensorMutableData<float>();
std::vector<float> nextKVCache;
std::vector<float> zeros(MODELFEATURESHAPE, 0.0f);
int delta = (currentTokens.size() - 1) * MODELFEATURESHAPE;
for (int currentKVIdx = 0; currentKVIdx < MODELKVCACHESHAPE; currentKVIdx++) {
nextKVCache.insert(nextKVCache.end(),
nextKVCacheData + (currentKVIdx * delta),
nextKVCacheData + ((currentKVIdx + 1) * delta));
nextKVCache.insert(nextKVCache.end(), zeros.begin(), zeros.end());
}
std::swap(currentKVCache, nextKVCache);
if (currentTokens.back() == vocab_.token_eot)
break;
}
std::swap(currentTokens, tokensOutput_);
} catch (const Ort::Exception& e) {
Plog::log(Plog::LogPriority::ERR, TAG, e.what());
return "";
} catch (...) {}
std::ostringstream oss;
for (const auto& token : tokensOutput_) {
if (token >= vocab_.token_eot)
continue;
oss << vocab_.id_to_token[token];
}
tokensOutput_.clear();
return oss.str();
}
} // namespace jami