/**
* Copyright (C) 2022 Savoir-faire Linux Inc.
*
* Author: Aline Gondim Santos <aline.gondimsantos@savoirfairelinux.com>
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
*/
#include "TranscriptAudioSubscriber.h"
#include <pluglog.h>
#include <frameUtils.h>
#include <bitset>
#include <iostream>
const std::string TAG = "Transcript";
const char sep = separator();
namespace jami {

TranscriptAudioSubscriber::TranscriptAudioSubscriber(const std::string& dataPath, TranscriptVideoSubscriber* videoSubscriber, bool acc)
    : path_ {dataPath}
    , modelProcessor_ {dataPath, acc}
    , mVS_ {videoSubscriber}
{
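    // Load the mel filterbank used below to compute log-mel spectrograms.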
    loadMelFilters(path_ + "/assets/mel_filters.bin", modelFilters_);
    /**
     * Waits for audio samples and then processes them.
     **/
    processFrameThread = std::thread([this] {
        while (running) {
            std::unique_lock<std::mutex> l(inputLock);
            inputCv.wait(l, [this] { return not running or newFrame; });
            if (not running) {
                break;
            }
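            // Convert the buffered samples into a log-mel spectrogram and
            // pad/trim it to the input length expected by the model.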
            logMelSpectrogram(currentModelInput_.data(), currentModelInput_.size(), 8, modelFilters_, melSpectrogram_);
            inputPadTrim(melSpectrogram_);
            newFrame = false;
            currentModelInput_.clear();
#ifndef __DEBUG__
            /** Unlock the mutex so that the other thread can copy new data
             * while we are still processing the old one.
             **/
            l.unlock();
#endif
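            // Run inference on the spectrogram and forward the recognized text
            // to the video subscriber.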
            modelProcessor_.feedInput(melSpectrogram_.data);
            auto text = modelProcessor_.getText();
            mVS_->setText(text);
        }
    });
}

TranscriptAudioSubscriber::~TranscriptAudioSubscriber()
{
    modelProcessor_.endModels();
    formatFilter_.clean();
    stop();
    processFrameThread.join();
    Plog::log(Plog::LogPriority::INFO, TAG, "~TranscriptMediaProcessor");
}

void
TranscriptAudioSubscriber::stop()
{
    running = false;
    inputCv.notify_all();
}

void
TranscriptAudioSubscriber::update(jami::Observable<AVFrame*>*, AVFrame* const& pluginFrame)
{
    if (!pluginFrame || modelFilters_.data.empty())
        return;
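    // On the first frame (and after a detach), reset the sample buffers and
    // (re)initialize the format filter for the incoming stream's audio format.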
    if (firstRun) {
        modelProcessor_.getText();
        count_ = 0;
        pastModelInput_.clear();
        currentModelInput_.clear();
        futureModelInput_.clear();
        formatFilter_.clean();
        AudioFormat afmt = AudioFormat(pluginFrame->sample_rate,
                                       pluginFrame->channels,
                                       static_cast<AVSampleFormat>(pluginFrame->format));
        MediaStream ms = MediaStream("input", afmt);
        formatFilter_.initialize(filterDescription_, {ms});
        firstRun = false;
    }

    if (!formatFilter_.initialized_)
        return;
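    // The filter output is consumed as interleaved signed 16-bit little-endian PCM;
    // each sample is reassembled from its two bytes and normalized to a float
    // in [-1, 1) before being buffered for the model.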
    if (formatFilter_.feedInput(pluginFrame, "input") == 0) {
        uniqueFramePtr filteredFrame = {formatFilter_.readOutput(), frameFree};
        if (filteredFrame) {
            for (size_t i = 0; i < filteredFrame->buf[0]->size; i += 2) {
                std::lock_guard<std::mutex> l(inputLock);
                int16_t rawValue = (filteredFrame->buf[0]->data[i + 1] << 8) | filteredFrame->buf[0]->data[i];
                // If the value is negative, perform the two's complement math on it
                if ((rawValue & 0x8000) != 0) {
                    rawValue = (~(rawValue - 0x0001)) * -1;
                }
                futureModelInput_.emplace_back(float(rawValue) / 32768.0f);
                // Past the chunk step, also keep samples as the overlap
                // carried into the next chunk.
                if (count_ > WHISPER_STREAM_SAMPLES_CHUNK_STEP)
                    overlapInput_.emplace_back(float(rawValue) / 32768.0f);
                count_++;
                // Trigger transcription when we have enough samples
                if (futureModelInput_.size() == WHISPER_STREAM_SAMPLES_CHUNK && !newFrame) {
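                    // Rotate the buffers: the chunk that just filled becomes the
                    // current model input, and the overlap tail collected above
                    // seeds the next chunk so consecutive chunks share context.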
                    pastModelInput_.clear();
                    std::swap(pastModelInput_, currentModelInput_);
                    std::swap(currentModelInput_, futureModelInput_);
                    std::swap(futureModelInput_, overlapInput_);
                    count_ = 0;
                    overlapInput_.clear();
                    newFrame = true;
                    inputCv.notify_all();
                }
            }
        }
    }
    // The audio frame itself is passed through unmodified.
}

void
TranscriptAudioSubscriber::attached(jami::Observable<AVFrame*>* observable)
{
    Plog::log(Plog::LogPriority::INFO, TAG, "::Attached ! ");
    observable_ = observable;
}

void
TranscriptAudioSubscriber::detached(jami::Observable<AVFrame*>*)
{
    modelProcessor_.getText();
    firstRun = true;
    observable_ = nullptr;
    Plog::log(Plog::LogPriority::INFO, TAG, "::Detached()");
}

void
TranscriptAudioSubscriber::detach()
{
    if (observable_) {
        firstRun = true;
        Plog::log(Plog::LogPriority::INFO, TAG, "::Calling detach()");
        observable_->detach(this);
    }
}
} // namespace jami