add initial project structure
Change-Id: I6a3fb080ff623b312e42d71754480a7ce00b81a0
diff --git a/src/connectionmanager.cpp b/src/connectionmanager.cpp
new file mode 100644
index 0000000..96bd6ba
--- /dev/null
+++ b/src/connectionmanager.cpp
@@ -0,0 +1,1656 @@
+/*
+ * Copyright (C) 2019-2023 Savoir-faire Linux Inc.
+ * Author: Sébastien Blin <sebastien.blin@savoirfairelinux.com>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+#include "connectionmanager.h"
+#include "peer_connection.h"
+#include "upnp/upnp_control.h"
+#include "certstore.h"
+#include "fileutils.h"
+#include "sip_utils.h"
+#include "string_utils.h"
+
+#include <opendht/crypto.h>
+#include <opendht/thread_pool.h>
+#include <opendht/value.h>
+#include <asio.hpp>
+
+#include <algorithm>
+#include <mutex>
+#include <map>
+#include <condition_variable>
+#include <set>
+#include <charconv>
+
+namespace jami {
+static constexpr std::chrono::seconds DHT_MSG_TIMEOUT {30};
+static constexpr uint64_t ID_MAX_VAL = 9007199254740992;
+
+using ValueIdDist = std::uniform_int_distribution<dht::Value::Id>;
+using CallbackId = std::pair<jami::DeviceId, dht::Value::Id>;
+
+struct ConnectionInfo
+{
+ ~ConnectionInfo()
+ {
+ if (socket_)
+ socket_->join();
+ }
+
+ std::mutex mutex_ {};
+ bool responseReceived_ {false};
+ PeerConnectionRequest response_ {};
+ std::unique_ptr<IceTransport> ice_ {nullptr};
+ // Used to store currently non ready TLS Socket
+ std::unique_ptr<TlsSocketEndpoint> tls_ {nullptr};
+ std::shared_ptr<MultiplexedSocket> socket_ {};
+ std::set<CallbackId> cbIds_ {};
+
+ std::function<void(bool)> onConnected_;
+ std::unique_ptr<asio::steady_timer> waitForAnswer_ {};
+};
+
+/**
+ * returns whether or not UPnP is enabled and active_
+ * ie: if it is able to make port mappings
+ */
+bool
+ConnectionManager::Config::getUPnPActive() const
+{
+ if (upnpCtrl)
+ return upnpCtrl->isReady();
+ return false;
+}
+
+class ConnectionManager::Impl : public std::enable_shared_from_this<ConnectionManager::Impl>
+{
+public:
+ explicit Impl(std::shared_ptr<ConnectionManager::Config> config_)
+ : config_ {std::move(config_)}
+ {}
+ ~Impl() {}
+
+ std::shared_ptr<dht::DhtRunner> dht() { return config_->dht; }
+ const dht::crypto::Identity& identity() const { return config_->id; }
+
+ void removeUnusedConnections(const DeviceId& deviceId = {})
+ {
+ std::vector<std::shared_ptr<ConnectionInfo>> unused {};
+
+ {
+ std::lock_guard<std::mutex> lk(infosMtx_);
+ for (auto it = infos_.begin(); it != infos_.end();) {
+ auto& [key, info] = *it;
+ if (info && (!deviceId || key.first == deviceId)) {
+ unused.emplace_back(std::move(info));
+ it = infos_.erase(it);
+ } else {
+ ++it;
+ }
+ }
+ }
+ for (auto& info: unused) {
+ if (info->tls_)
+ info->tls_->shutdown();
+ if (info->socket_)
+ info->socket_->shutdown();
+ if (info->waitForAnswer_)
+ info->waitForAnswer_->cancel();
+ }
+ if (!unused.empty())
+ dht::ThreadPool::io().run([infos = std::move(unused)]() mutable { infos.clear(); });
+ }
+
+ void shutdown()
+ {
+ if (isDestroying_.exchange(true))
+ return;
+ {
+ std::lock_guard<std::mutex> lk(connectCbsMtx_);
+ // Call all pending callbacks that channel is not ready
+ for (auto& [deviceId, pcbs] : pendingCbs_)
+ for (auto& pending : pcbs)
+ pending.cb(nullptr, deviceId);
+ pendingCbs_.clear();
+ }
+ removeUnusedConnections();
+ }
+
+ struct PendingCb
+ {
+ std::string name;
+ ConnectCallback cb;
+ dht::Value::Id vid;
+ };
+
+ void connectDeviceStartIce(const std::shared_ptr<dht::crypto::PublicKey>& devicePk,
+ const dht::Value::Id& vid,
+ const std::string& connType,
+ std::function<void(bool)> onConnected);
+ void onResponse(const asio::error_code& ec, const DeviceId& deviceId, const dht::Value::Id& vid);
+ bool connectDeviceOnNegoDone(const DeviceId& deviceId,
+ const std::string& name,
+ const dht::Value::Id& vid,
+ const std::shared_ptr<dht::crypto::Certificate>& cert);
+ void connectDevice(const DeviceId& deviceId,
+ const std::string& uri,
+ ConnectCallback cb,
+ bool noNewSocket = false,
+ bool forceNewSocket = false,
+ const std::string& connType = "");
+ void connectDevice(const std::shared_ptr<dht::crypto::Certificate>& cert,
+ const std::string& name,
+ ConnectCallback cb,
+ bool noNewSocket = false,
+ bool forceNewSocket = false,
+ const std::string& connType = "");
+ /**
+ * Send a ChannelRequest on the TLS socket. Triggers cb when ready
+ * @param sock socket used to send the request
+ * @param name channel's name
+ * @param vid channel's id
+ * @param deviceId to identify the linked ConnectCallback
+ */
+ void sendChannelRequest(std::shared_ptr<MultiplexedSocket>& sock,
+ const std::string& name,
+ const DeviceId& deviceId,
+ const dht::Value::Id& vid);
+ /**
+ * Triggered when a PeerConnectionRequest comes from the DHT
+ */
+ void answerTo(IceTransport& ice,
+ const dht::Value::Id& id,
+ const std::shared_ptr<dht::crypto::PublicKey>& fromPk);
+ bool onRequestStartIce(const PeerConnectionRequest& req);
+ bool onRequestOnNegoDone(const PeerConnectionRequest& req);
+ void onDhtPeerRequest(const PeerConnectionRequest& req,
+ const std::shared_ptr<dht::crypto::Certificate>& cert);
+
+ void addNewMultiplexedSocket(const CallbackId& id, const std::shared_ptr<ConnectionInfo>& info);
+ void onPeerResponse(const PeerConnectionRequest& req);
+ void onDhtConnected(const dht::crypto::PublicKey& devicePk);
+
+ const std::shared_future<tls::DhParams> dhParams() const;
+ tls::CertificateStore& certStore() const { return *config_->certStore; }
+
+ mutable std::mutex messageMutex_ {};
+ std::set<std::string, std::less<>> treatedMessages_ {};
+
+ void loadTreatedMessages();
+ void saveTreatedMessages() const;
+
+ /// \return true if the given DHT message identifier has been treated
+ /// \note if message has not been treated yet this method st/ore this id and returns true at
+ /// further calls
+ bool isMessageTreated(std::string_view id);
+
+ const std::shared_ptr<dht::log::Logger>& logger() const { return config_->logger; }
+
+ /**
+ * Published IPv4/IPv6 addresses, used only if defined by the user in account
+ * configuration
+ *
+ */
+ IpAddr publishedIp_[2] {};
+
+ // This will be stored in the configuration
+ std::string publishedIpAddress_ {};
+
+ /**
+ * Published port, used only if defined by the user
+ */
+ pj_uint16_t publishedPort_ {sip_utils::DEFAULT_SIP_PORT};
+
+ /**
+ * interface name on which this account is bound
+ */
+ std::string interface_ {"default"};
+
+ /**
+ * Get the local interface name on which this account is bound.
+ */
+ const std::string& getLocalInterface() const { return interface_; }
+
+ /**
+ * Get the published IP address, fallbacks to NAT if family is unspecified
+ * Prefers the usage of IPv4 if possible.
+ */
+ IpAddr getPublishedIpAddress(uint16_t family = PF_UNSPEC) const;
+
+ /**
+ * Set published IP address according to given family
+ */
+ void setPublishedAddress(const IpAddr& ip_addr);
+
+ /**
+ * Store the local/public addresses used to register
+ */
+ void storeActiveIpAddress(std::function<void()>&& cb = {});
+
+ /**
+ * Create and return ICE options.
+ */
+ void getIceOptions(std::function<void(IceTransportOptions&&)> cb) noexcept;
+ IceTransportOptions getIceOptions() const noexcept;
+
+ /**
+ * Inform that a potential peer device have been found.
+ * Returns true only if the device certificate is a valid device certificate.
+ * In that case (true is returned) the account_id parameter is set to the peer account ID.
+ */
+ static bool foundPeerDevice(const std::shared_ptr<dht::crypto::Certificate>& crt,
+ dht::InfoHash& account_id, const std::shared_ptr<Logger>& logger);
+
+ bool findCertificate(const dht::PkId& id,
+ std::function<void(const std::shared_ptr<dht::crypto::Certificate>&)>&& cb);
+
+ /**
+ * returns whether or not UPnP is enabled and active
+ * ie: if it is able to make port mappings
+ */
+ bool getUPnPActive() const;
+
+ /**
+ * Triggered when a new TLS socket is ready to use
+ * @param ok If succeed
+ * @param deviceId Related device
+ * @param vid vid of the connection request
+ * @param name non empty if TLS was created by connectDevice()
+ */
+ void onTlsNegotiationDone(bool ok,
+ const DeviceId& deviceId,
+ const dht::Value::Id& vid,
+ const std::string& name = "");
+
+ std::shared_ptr<ConnectionManager::Config> config_;
+
+ IceTransportFactory iceFactory_ {};
+
+ mutable std::mt19937_64 rand;
+
+ iOSConnectedCallback iOSConnectedCb_ {};
+
+ std::mutex infosMtx_ {};
+ // Note: Someone can ask multiple sockets, so to avoid any race condition,
+ // each device can have multiple multiplexed sockets.
+ std::map<CallbackId, std::shared_ptr<ConnectionInfo>> infos_ {};
+
+ std::shared_ptr<ConnectionInfo> getInfo(const DeviceId& deviceId, const dht::Value::Id& id)
+ {
+ std::lock_guard<std::mutex> lk(infosMtx_);
+ auto it = infos_.find({deviceId, id});
+ if (it != infos_.end())
+ return it->second;
+ return {};
+ }
+
+ std::shared_ptr<ConnectionInfo> getConnectedInfo(const DeviceId& deviceId)
+ {
+ std::lock_guard<std::mutex> lk(infosMtx_);
+ auto it = std::find_if(infos_.begin(), infos_.end(), [&](const auto& item) {
+ auto& [key, value] = item;
+ return key.first == deviceId && value && value->socket_;
+ });
+ if (it != infos_.end())
+ return it->second;
+ return {};
+ }
+
+ ChannelRequestCallback channelReqCb_ {};
+ ConnectionReadyCallback connReadyCb_ {};
+ onICERequestCallback iceReqCb_ {};
+
+ /**
+ * Stores callback from connectDevice
+ * @note: each device needs a vector because several connectDevice can
+ * be done in parallel and we only want one socket
+ */
+ std::mutex connectCbsMtx_ {};
+ std::map<DeviceId, std::vector<PendingCb>> pendingCbs_ {};
+
+ std::vector<PendingCb> extractPendingCallbacks(const DeviceId& deviceId,
+ const dht::Value::Id vid = 0)
+ {
+ std::vector<PendingCb> ret;
+ std::lock_guard<std::mutex> lk(connectCbsMtx_);
+ auto pendingIt = pendingCbs_.find(deviceId);
+ if (pendingIt == pendingCbs_.end())
+ return ret;
+ auto& pendings = pendingIt->second;
+ if (vid == 0) {
+ ret = std::move(pendings);
+ } else {
+ for (auto it = pendings.begin(); it != pendings.end(); ++it) {
+ if (it->vid == vid) {
+ ret.emplace_back(std::move(*it));
+ pendings.erase(it);
+ break;
+ }
+ }
+ }
+ if (pendings.empty())
+ pendingCbs_.erase(pendingIt);
+ return ret;
+ }
+
+ std::vector<PendingCb> getPendingCallbacks(const DeviceId& deviceId,
+ const dht::Value::Id vid = 0)
+ {
+ std::vector<PendingCb> ret;
+ std::lock_guard<std::mutex> lk(connectCbsMtx_);
+ auto pendingIt = pendingCbs_.find(deviceId);
+ if (pendingIt == pendingCbs_.end())
+ return ret;
+ auto& pendings = pendingIt->second;
+ if (vid == 0) {
+ ret = pendings;
+ } else {
+ std::copy_if(pendings.begin(),
+ pendings.end(),
+ std::back_inserter(ret),
+ [&](auto pending) { return pending.vid == vid; });
+ }
+ return ret;
+ }
+
+ std::shared_ptr<ConnectionManager::Impl> shared()
+ {
+ return std::static_pointer_cast<ConnectionManager::Impl>(shared_from_this());
+ }
+ std::shared_ptr<ConnectionManager::Impl const> shared() const
+ {
+ return std::static_pointer_cast<ConnectionManager::Impl const>(shared_from_this());
+ }
+ std::weak_ptr<ConnectionManager::Impl> weak()
+ {
+ return std::static_pointer_cast<ConnectionManager::Impl>(shared_from_this());
+ }
+ std::weak_ptr<ConnectionManager::Impl const> weak() const
+ {
+ return std::static_pointer_cast<ConnectionManager::Impl const>(shared_from_this());
+ }
+
+ std::atomic_bool isDestroying_ {false};
+};
+
+void
+ConnectionManager::Impl::connectDeviceStartIce(
+ const std::shared_ptr<dht::crypto::PublicKey>& devicePk,
+ const dht::Value::Id& vid,
+ const std::string& connType,
+ std::function<void(bool)> onConnected)
+{
+ auto deviceId = devicePk->getLongId();
+ auto info = getInfo(deviceId, vid);
+ if (!info) {
+ onConnected(false);
+ return;
+ }
+
+ std::unique_lock<std::mutex> lk(info->mutex_);
+ auto& ice = info->ice_;
+
+ if (!ice) {
+ if (config_->logger)
+ config_->logger->error("No ICE detected");
+ onConnected(false);
+ return;
+ }
+
+ auto iceAttributes = ice->getLocalAttributes();
+ std::ostringstream icemsg;
+ icemsg << iceAttributes.ufrag << "\n";
+ icemsg << iceAttributes.pwd << "\n";
+ for (const auto& addr : ice->getLocalCandidates(1)) {
+ icemsg << addr << "\n";
+ if (config_->logger)
+ config_->logger->debug("Added local ICE candidate {}", addr);
+ }
+
+ // Prepare connection request as a DHT message
+ PeerConnectionRequest val;
+
+ val.id = vid; /* Random id for the message unicity */
+ val.ice_msg = icemsg.str();
+ val.connType = connType;
+
+ auto value = std::make_shared<dht::Value>(std::move(val));
+ value->user_type = "peer_request";
+
+ // Send connection request through DHT
+ if (config_->logger)
+ config_->logger->debug("Request connection to {}", deviceId);
+ dht()->putEncrypted(dht::InfoHash::get(PeerConnectionRequest::key_prefix
+ + devicePk->getId().toString()),
+ devicePk,
+ value,
+ [l=config_->logger,deviceId](bool ok) {
+ if (l)
+ l->debug("Sent connection request to {:s}. Put encrypted {:s}",
+ deviceId,
+ (ok ? "ok" : "failed"));
+ });
+ // Wait for call to onResponse() operated by DHT
+ if (isDestroying_) {
+ onConnected(true); // This avoid to wait new negotiation when destroying
+ return;
+ }
+
+ info->onConnected_ = std::move(onConnected);
+ info->waitForAnswer_ = std::make_unique<asio::steady_timer>(*config_->ioContext,
+ std::chrono::steady_clock::now()
+ + DHT_MSG_TIMEOUT);
+ info->waitForAnswer_->async_wait(
+ std::bind(&ConnectionManager::Impl::onResponse, this, std::placeholders::_1, deviceId, vid));
+}
+
+void
+ConnectionManager::Impl::onResponse(const asio::error_code& ec,
+ const DeviceId& deviceId,
+ const dht::Value::Id& vid)
+{
+ if (ec == asio::error::operation_aborted)
+ return;
+ auto info = getInfo(deviceId, vid);
+ if (!info)
+ return;
+
+ std::unique_lock<std::mutex> lk(info->mutex_);
+ auto& ice = info->ice_;
+ if (isDestroying_) {
+ info->onConnected_(true); // The destructor can wake a pending wait here.
+ return;
+ }
+ if (!info->responseReceived_) {
+ if (config_->logger)
+ config_->logger->error("no response from DHT to E2E request.");
+ info->onConnected_(false);
+ return;
+ }
+
+ if (!info->ice_) {
+ info->onConnected_(false);
+ return;
+ }
+
+ auto sdp = ice->parseIceCandidates(info->response_.ice_msg);
+
+ if (not ice->startIce({sdp.rem_ufrag, sdp.rem_pwd}, std::move(sdp.rem_candidates))) {
+ if (config_->logger)
+ config_->logger->warn("start ICE failed");
+ info->onConnected_(false);
+ return;
+ }
+ info->onConnected_(true);
+}
+
+bool
+ConnectionManager::Impl::connectDeviceOnNegoDone(
+ const DeviceId& deviceId,
+ const std::string& name,
+ const dht::Value::Id& vid,
+ const std::shared_ptr<dht::crypto::Certificate>& cert)
+{
+ auto info = getInfo(deviceId, vid);
+ if (!info)
+ return false;
+
+ std::unique_lock<std::mutex> lk {info->mutex_};
+ if (info->waitForAnswer_) {
+ // Negotiation is done and connected, go to handshake
+ // and avoid any cancellation at this point.
+ info->waitForAnswer_->cancel();
+ }
+ auto& ice = info->ice_;
+ if (!ice || !ice->isRunning()) {
+ if (config_->logger)
+ config_->logger->error("No ICE detected or not running");
+ return false;
+ }
+
+ // Build socket
+ auto endpoint = std::make_unique<IceSocketEndpoint>(std::shared_ptr<IceTransport>(
+ std::move(ice)),
+ true);
+
+ // Negotiate a TLS session
+ if (config_->logger)
+ config_->logger->debug("Start TLS session - Initied by connectDevice(). Launched by channel: {} - device: {} - vid: {}", name, deviceId, vid);
+ info->tls_ = std::make_unique<TlsSocketEndpoint>(std::move(endpoint),
+ certStore(),
+ identity(),
+ dhParams(),
+ *cert);
+
+ info->tls_->setOnReady(
+ [w = weak(), deviceId = std::move(deviceId), vid = std::move(vid), name = std::move(name)](
+ bool ok) {
+ if (auto shared = w.lock())
+ shared->onTlsNegotiationDone(ok, deviceId, vid, name);
+ });
+ return true;
+}
+
+void
+ConnectionManager::Impl::connectDevice(const DeviceId& deviceId,
+ const std::string& name,
+ ConnectCallback cb,
+ bool noNewSocket,
+ bool forceNewSocket,
+ const std::string& connType)
+{
+ if (!dht()) {
+ cb(nullptr, deviceId);
+ return;
+ }
+ if (deviceId.toString() == identity().second->getLongId().toString()) {
+ cb(nullptr, deviceId);
+ return;
+ }
+ findCertificate(deviceId,
+ [w = weak(),
+ deviceId,
+ name,
+ cb = std::move(cb),
+ noNewSocket,
+ forceNewSocket,
+ connType](const std::shared_ptr<dht::crypto::Certificate>& cert) {
+ if (!cert) {
+ if (auto shared = w.lock())
+ if (shared->config_->logger)
+ shared->config_->logger->error(
+ "No valid certificate found for device {}",
+ deviceId);
+ cb(nullptr, deviceId);
+ return;
+ }
+ if (auto shared = w.lock()) {
+ shared->connectDevice(cert,
+ name,
+ std::move(cb),
+ noNewSocket,
+ forceNewSocket,
+ connType);
+ } else
+ cb(nullptr, deviceId);
+ });
+}
+
+void
+ConnectionManager::Impl::connectDevice(const std::shared_ptr<dht::crypto::Certificate>& cert,
+ const std::string& name,
+ ConnectCallback cb,
+ bool noNewSocket,
+ bool forceNewSocket,
+ const std::string& connType)
+{
+ // Avoid dht operation in a DHT callback to avoid deadlocks
+ dht::ThreadPool::computation().run([w = weak(),
+ name = std::move(name),
+ cert = std::move(cert),
+ cb = std::move(cb),
+ noNewSocket,
+ forceNewSocket,
+ connType] {
+ auto devicePk = cert->getSharedPublicKey();
+ auto deviceId = devicePk->getLongId();
+ auto sthis = w.lock();
+ if (!sthis || sthis->isDestroying_) {
+ cb(nullptr, deviceId);
+ return;
+ }
+ dht::Value::Id vid = ValueIdDist(1, ID_MAX_VAL)(sthis->rand);
+ auto isConnectingToDevice = false;
+ {
+ std::lock_guard<std::mutex> lk(sthis->connectCbsMtx_);
+ auto pendingsIt = sthis->pendingCbs_.find(deviceId);
+ if (pendingsIt != sthis->pendingCbs_.end()) {
+ const auto& pendings = pendingsIt->second;
+ while (std::find_if(pendings.begin(), pendings.end(), [&](const auto& it){ return it.vid == vid; }) != pendings.end()) {
+ vid = ValueIdDist(1, ID_MAX_VAL)(sthis->rand);
+ }
+ }
+ // Check if already connecting
+ isConnectingToDevice = pendingsIt != sthis->pendingCbs_.end();
+ // Save current request for sendChannelRequest.
+ // Note: do not return here, cause we can be in a state where first
+ // socket is negotiated and first channel is pending
+ // so return only after we checked the info
+ if (isConnectingToDevice)
+ pendingsIt->second.emplace_back(PendingCb {name, std::move(cb), vid});
+ else
+ sthis->pendingCbs_[deviceId] = {{name, std::move(cb), vid}};
+ }
+
+ // Check if already negotiated
+ CallbackId cbId(deviceId, vid);
+ if (auto info = sthis->getConnectedInfo(deviceId)) {
+ std::lock_guard<std::mutex> lk(info->mutex_);
+ if (info->socket_) {
+ if (sthis->config_->logger)
+ sthis->config_->logger->debug("Peer already connected to {}. Add a new channel", deviceId);
+ info->cbIds_.emplace(cbId);
+ sthis->sendChannelRequest(info->socket_, name, deviceId, vid);
+ return;
+ }
+ }
+
+ if (isConnectingToDevice && !forceNewSocket) {
+ if (sthis->config_->logger)
+ sthis->config_->logger->debug("Already connecting to {}, wait for the ICE negotiation", deviceId);
+ return;
+ }
+ if (noNewSocket) {
+ // If no new socket is specified, we don't try to generate a new socket
+ for (const auto& pending : sthis->extractPendingCallbacks(deviceId, vid))
+ pending.cb(nullptr, deviceId);
+ return;
+ }
+
+ // Note: used when the ice negotiation fails to erase
+ // all stored structures.
+ auto eraseInfo = [w, cbId] {
+ if (auto shared = w.lock()) {
+ // If no new socket is specified, we don't try to generate a new socket
+ for (const auto& pending : shared->extractPendingCallbacks(cbId.first, cbId.second))
+ pending.cb(nullptr, cbId.first);
+ std::lock_guard<std::mutex> lk(shared->infosMtx_);
+ shared->infos_.erase(cbId);
+ }
+ };
+
+ // If no socket exists, we need to initiate an ICE connection.
+ sthis->getIceOptions([w,
+ deviceId = std::move(deviceId),
+ devicePk = std::move(devicePk),
+ name = std::move(name),
+ cert = std::move(cert),
+ vid,
+ connType,
+ eraseInfo](auto&& ice_config) {
+ auto sthis = w.lock();
+ if (!sthis) {
+ dht::ThreadPool::io().run([eraseInfo = std::move(eraseInfo)] { eraseInfo(); });
+ return;
+ }
+ ice_config.tcpEnable = true;
+ ice_config.onInitDone = [w,
+ deviceId = std::move(deviceId),
+ devicePk = std::move(devicePk),
+ name = std::move(name),
+ cert = std::move(cert),
+ vid,
+ connType,
+ eraseInfo](bool ok) {
+ dht::ThreadPool::io().run([w = std::move(w),
+ devicePk = std::move(devicePk),
+ vid = std::move(vid),
+ eraseInfo,
+ connType, ok] {
+ auto sthis = w.lock();
+ if (!ok && sthis && sthis->config_->logger)
+ sthis->config_->logger->error("Cannot initialize ICE session.");
+ if (!sthis || !ok) {
+ eraseInfo();
+ return;
+ }
+ sthis->connectDeviceStartIce(devicePk, vid, connType, [=](bool ok) {
+ if (!ok) {
+ dht::ThreadPool::io().run([eraseInfo = std::move(eraseInfo)] { eraseInfo(); });
+ }
+ });
+ });
+ };
+ ice_config.onNegoDone = [w,
+ deviceId,
+ name,
+ cert = std::move(cert),
+ vid,
+ eraseInfo](bool ok) {
+ dht::ThreadPool::io().run([w = std::move(w),
+ deviceId = std::move(deviceId),
+ name = std::move(name),
+ cert = std::move(cert),
+ vid = std::move(vid),
+ eraseInfo = std::move(eraseInfo),
+ ok] {
+ auto sthis = w.lock();
+ if (!ok && sthis && sthis->config_->logger)
+ sthis->config_->logger->error("ICE negotiation failed.");
+ if (!sthis || !ok || !sthis->connectDeviceOnNegoDone(deviceId, name, vid, cert))
+ eraseInfo();
+ });
+ };
+
+ auto info = std::make_shared<ConnectionInfo>();
+ {
+ std::lock_guard<std::mutex> lk(sthis->infosMtx_);
+ sthis->infos_[{deviceId, vid}] = info;
+ }
+ std::unique_lock<std::mutex> lk {info->mutex_};
+ ice_config.master = false;
+ ice_config.streamsCount = 1;
+ ice_config.compCountPerStream = 1;
+ info->ice_ = sthis->iceFactory_.createUTransport("");
+ if (!info->ice_) {
+ if (sthis->config_->logger)
+ sthis->config_->logger->error("Cannot initialize ICE session.");
+ eraseInfo();
+ return;
+ }
+ // We need to detect any shutdown if the ice session is destroyed before going to the
+ // TLS session;
+ info->ice_->setOnShutdown([eraseInfo]() {
+ dht::ThreadPool::io().run([eraseInfo = std::move(eraseInfo)] { eraseInfo(); });
+ });
+ info->ice_->initIceInstance(ice_config);
+ });
+ });
+}
+
+void
+ConnectionManager::Impl::sendChannelRequest(std::shared_ptr<MultiplexedSocket>& sock,
+ const std::string& name,
+ const DeviceId& deviceId,
+ const dht::Value::Id& vid)
+{
+ auto channelSock = sock->addChannel(name);
+ channelSock->onShutdown([name, deviceId, vid, w = weak()] {
+ auto shared = w.lock();
+ if (shared)
+ for (const auto& pending : shared->extractPendingCallbacks(deviceId, vid))
+ pending.cb(nullptr, deviceId);
+ });
+ channelSock->onReady(
+ [wSock = std::weak_ptr<ChannelSocket>(channelSock), name, deviceId, vid, w = weak()]() {
+ auto shared = w.lock();
+ auto channelSock = wSock.lock();
+ if (shared)
+ for (const auto& pending : shared->extractPendingCallbacks(deviceId, vid))
+ pending.cb(channelSock, deviceId);
+ });
+
+ ChannelRequest val;
+ val.name = channelSock->name();
+ val.state = ChannelRequestState::REQUEST;
+ val.channel = channelSock->channel();
+ msgpack::sbuffer buffer(256);
+ msgpack::pack(buffer, val);
+
+ std::error_code ec;
+ int res = sock->write(CONTROL_CHANNEL,
+ reinterpret_cast<const uint8_t*>(buffer.data()),
+ buffer.size(),
+ ec);
+ if (res < 0) {
+ // TODO check if we should handle errors here
+ if (config_->logger)
+ config_->logger->error("sendChannelRequest failed - error: {}", ec.message());
+ }
+}
+
+void
+ConnectionManager::Impl::onPeerResponse(const PeerConnectionRequest& req)
+{
+ auto device = req.owner->getLongId();
+ if (config_->logger)
+ config_->logger->debug("New response received from {}", device);
+ if (auto info = getInfo(device, req.id)) {
+ std::lock_guard<std::mutex> lk {info->mutex_};
+ info->responseReceived_ = true;
+ info->response_ = std::move(req);
+ info->waitForAnswer_->expires_at(std::chrono::steady_clock::now());
+ info->waitForAnswer_->async_wait(std::bind(&ConnectionManager::Impl::onResponse,
+ this,
+ std::placeholders::_1,
+ device,
+ req.id));
+ } else {
+ if (config_->logger)
+ config_->logger->warn("Respond received, but cannot find request");
+ }
+}
+
+void
+ConnectionManager::Impl::onDhtConnected(const dht::crypto::PublicKey& devicePk)
+{
+ if (!dht())
+ return;
+ dht()->listen<PeerConnectionRequest>(
+ dht::InfoHash::get(PeerConnectionRequest::key_prefix + devicePk.getId().toString()),
+ [w = weak()](PeerConnectionRequest&& req) {
+ auto shared = w.lock();
+ if (!shared)
+ return false;
+ if (shared->isMessageTreated(to_hex_string(req.id))) {
+ // Message already treated. Just ignore
+ return true;
+ }
+ if (req.isAnswer) {
+ if (shared->config_->logger)
+ shared->config_->logger->debug("Received request answer from {}", req.owner->getLongId());
+ } else {
+ if (shared->config_->logger)
+ shared->config_->logger->debug("Received request from {}", req.owner->getLongId());
+ }
+ if (req.isAnswer) {
+ shared->onPeerResponse(req);
+ } else {
+ // Async certificate checking
+ shared->dht()->findCertificate(
+ req.from,
+ [w, req = std::move(req)](
+ const std::shared_ptr<dht::crypto::Certificate>& cert) mutable {
+ auto shared = w.lock();
+ if (!shared)
+ return;
+ dht::InfoHash peer_h;
+ if (foundPeerDevice(cert, peer_h, shared->config_->logger)) {
+#if TARGET_OS_IOS
+ if (shared->iOSConnectedCb_(req.connType, peer_h))
+ return;
+#endif
+ shared->onDhtPeerRequest(req, cert);
+ } else {
+ if (shared->config_->logger)
+ shared->config_->logger->warn(
+ "Received request from untrusted peer {}",
+ req.owner->getLongId());
+ }
+ });
+ }
+
+ return true;
+ },
+ dht::Value::UserTypeFilter("peer_request"));
+}
+
+void
+ConnectionManager::Impl::onTlsNegotiationDone(bool ok,
+ const DeviceId& deviceId,
+ const dht::Value::Id& vid,
+ const std::string& name)
+{
+ if (isDestroying_)
+ return;
+ // Note: only handle pendingCallbacks here for TLS initied by connectDevice()
+ // Note: if not initied by connectDevice() the channel name will be empty (because no channel
+ // asked yet)
+ auto isDhtRequest = name.empty();
+ if (!ok) {
+ if (isDhtRequest) {
+ if (config_->logger)
+ config_->logger->error("TLS connection failure for peer {} - Initied by DHT request. channel: {} - vid: {}",
+ deviceId,
+ name,
+ vid);
+ if (connReadyCb_)
+ connReadyCb_(deviceId, "", nullptr);
+ } else {
+ if (config_->logger)
+ config_->logger->error("TLS connection failure for peer {} - Initied by connectDevice. channel: {} - vid: {}",
+ deviceId,
+ name,
+ vid);
+ for (const auto& pending : extractPendingCallbacks(deviceId))
+ pending.cb(nullptr, deviceId);
+ }
+ } else {
+ // The socket is ready, store it
+ if (isDhtRequest) {
+ if (config_->logger)
+ config_->logger->debug("Connection to {} is ready - Initied by DHT request. Vid: {}",
+ deviceId,
+ vid);
+ } else {
+ if (config_->logger)
+ config_->logger->debug("Connection to {} is ready - Initied by connectDevice(). channel: {} - vid: {}",
+ deviceId,
+ name,
+ vid);
+ }
+
+ auto info = getInfo(deviceId, vid);
+ addNewMultiplexedSocket({deviceId, vid}, info);
+ // Finally, open the channel and launch pending callbacks
+ if (info->socket_) {
+ // Note: do not remove pending there it's done in sendChannelRequest
+ for (const auto& pending : getPendingCallbacks(deviceId)) {
+ if (config_->logger)
+ config_->logger->debug("Send request on TLS socket for channel {} to {}",
+ pending.name,
+ deviceId);
+ sendChannelRequest(info->socket_, pending.name, deviceId, pending.vid);
+ }
+ }
+ }
+}
+
+void
+ConnectionManager::Impl::answerTo(IceTransport& ice,
+ const dht::Value::Id& id,
+ const std::shared_ptr<dht::crypto::PublicKey>& from)
+{
+ // NOTE: This is a shortest version of a real SDP message to save some bits
+ auto iceAttributes = ice.getLocalAttributes();
+ std::ostringstream icemsg;
+ icemsg << iceAttributes.ufrag << "\n";
+ icemsg << iceAttributes.pwd << "\n";
+ for (const auto& addr : ice.getLocalCandidates(1)) {
+ icemsg << addr << "\n";
+ }
+
+ // Send PeerConnection response
+ PeerConnectionRequest val;
+ val.id = id;
+ val.ice_msg = icemsg.str();
+ val.isAnswer = true;
+ auto value = std::make_shared<dht::Value>(std::move(val));
+ value->user_type = "peer_request";
+
+ if (config_->logger)
+ config_->logger->debug("Connection accepted, DHT reply to {}", from->getLongId());
+ dht()->putEncrypted(dht::InfoHash::get(PeerConnectionRequest::key_prefix
+ + from->getId().toString()),
+ from,
+ value,
+ [from,l=config_->logger](bool ok) {
+ if (l)
+ l->debug("Answer to connection request from {:s}. Put encrypted {:s}",
+ from->getLongId(),
+ (ok ? "ok" : "failed"));
+ });
+}
+
+bool
+ConnectionManager::Impl::onRequestStartIce(const PeerConnectionRequest& req)
+{
+ auto deviceId = req.owner->getLongId();
+ auto info = getInfo(deviceId, req.id);
+ if (!info)
+ return false;
+
+ std::unique_lock<std::mutex> lk {info->mutex_};
+ auto& ice = info->ice_;
+ if (!ice) {
+ if (config_->logger)
+ config_->logger->error("No ICE detected");
+ if (connReadyCb_)
+ connReadyCb_(deviceId, "", nullptr);
+ return false;
+ }
+
+ auto sdp = ice->parseIceCandidates(req.ice_msg);
+ answerTo(*ice, req.id, req.owner);
+ if (not ice->startIce({sdp.rem_ufrag, sdp.rem_pwd}, std::move(sdp.rem_candidates))) {
+ if (config_->logger)
+ config_->logger->error("Start ICE failed - fallback to TURN");
+ ice = nullptr;
+ if (connReadyCb_)
+ connReadyCb_(deviceId, "", nullptr);
+ return false;
+ }
+ return true;
+}
+
+bool
+ConnectionManager::Impl::onRequestOnNegoDone(const PeerConnectionRequest& req)
+{
+ auto deviceId = req.owner->getLongId();
+ auto info = getInfo(deviceId, req.id);
+ if (!info)
+ return false;
+
+ std::unique_lock<std::mutex> lk {info->mutex_};
+ auto& ice = info->ice_;
+ if (!ice) {
+ if (config_->logger)
+ config_->logger->error("No ICE detected");
+ return false;
+ }
+
+ // Build socket
+ auto endpoint = std::make_unique<IceSocketEndpoint>(std::shared_ptr<IceTransport>(
+ std::move(ice)),
+ false);
+
+ // init TLS session
+ auto ph = req.from;
+ if (config_->logger)
+ config_->logger->debug("Start TLS session - Initied by DHT request. Device: {} - vid: {}",
+ req.from,
+ req.id);
+ info->tls_ = std::make_unique<TlsSocketEndpoint>(
+ std::move(endpoint),
+ certStore(),
+ identity(),
+ dhParams(),
+ [ph, w = weak()](const dht::crypto::Certificate& cert) {
+ auto shared = w.lock();
+ if (!shared)
+ return false;
+ auto crt = shared->certStore().getCertificate(cert.getLongId().toString());
+ if (!crt)
+ return false;
+ return crt->getPacked() == cert.getPacked();
+ });
+
+ info->tls_->setOnReady(
+ [w = weak(), deviceId = std::move(deviceId), vid = std::move(req.id)](bool ok) {
+ if (auto shared = w.lock())
+ shared->onTlsNegotiationDone(ok, deviceId, vid);
+ });
+ return true;
+}
+
+void
+ConnectionManager::Impl::onDhtPeerRequest(const PeerConnectionRequest& req,
+ const std::shared_ptr<dht::crypto::Certificate>& /*cert*/)
+{
+ auto deviceId = req.owner->getLongId();
+ if (config_->logger)
+ config_->logger->debug("New connection request from {}", deviceId);
+ if (!iceReqCb_ || !iceReqCb_(deviceId)) {
+ if (config_->logger)
+ config_->logger->debug("Refuse connection from {}", deviceId);
+ return;
+ }
+
+ // Because the connection is accepted, create an ICE socket.
+ getIceOptions([w = weak(), req, deviceId](auto&& ice_config) {
+ auto shared = w.lock();
+ if (!shared)
+ return;
+ // Note: used when the ice negotiation fails to erase
+ // all stored structures.
+ auto eraseInfo = [w, id = req.id, deviceId] {
+ if (auto shared = w.lock()) {
+ // If no new socket is specified, we don't try to generate a new socket
+ for (const auto& pending : shared->extractPendingCallbacks(deviceId, id))
+ pending.cb(nullptr, deviceId);
+ if (shared->connReadyCb_)
+ shared->connReadyCb_(deviceId, "", nullptr);
+ std::lock_guard<std::mutex> lk(shared->infosMtx_);
+ shared->infos_.erase({deviceId, id});
+ }
+ };
+
+ ice_config.tcpEnable = true;
+ ice_config.onInitDone = [w, req, eraseInfo](bool ok) {
+ auto shared = w.lock();
+ if (!shared)
+ return;
+ if (!ok) {
+ if (shared->config_->logger)
+ shared->config_->logger->error("Cannot initialize ICE session.");
+ dht::ThreadPool::io().run([eraseInfo = std::move(eraseInfo)] { eraseInfo(); });
+ return;
+ }
+
+ dht::ThreadPool::io().run(
+ [w = std::move(w), req = std::move(req), eraseInfo = std::move(eraseInfo)] {
+ auto shared = w.lock();
+ if (!shared)
+ return;
+ if (!shared->onRequestStartIce(req))
+ eraseInfo();
+ });
+ };
+
+ ice_config.onNegoDone = [w, req, eraseInfo](bool ok) {
+ auto shared = w.lock();
+ if (!shared)
+ return;
+ if (!ok) {
+ if (shared->config_->logger)
+ shared->config_->logger->error("ICE negotiation failed.");
+ dht::ThreadPool::io().run([eraseInfo = std::move(eraseInfo)] { eraseInfo(); });
+ return;
+ }
+
+ dht::ThreadPool::io().run(
+ [w = std::move(w), req = std::move(req), eraseInfo = std::move(eraseInfo)] {
+ if (auto shared = w.lock())
+ if (!shared->onRequestOnNegoDone(req))
+ eraseInfo();
+ });
+ };
+
+ // Negotiate a new ICE socket
+ auto info = std::make_shared<ConnectionInfo>();
+ {
+ std::lock_guard<std::mutex> lk(shared->infosMtx_);
+ shared->infos_[{deviceId, req.id}] = info;
+ }
+ if (shared->config_->logger)
+ shared->config_->logger->debug("Accepting connection from {}", deviceId);
+ std::unique_lock<std::mutex> lk {info->mutex_};
+ ice_config.streamsCount = 1;
+ ice_config.compCountPerStream = 1; // TCP
+ ice_config.master = true;
+ info->ice_ = shared->iceFactory_.createUTransport("");
+ if (not info->ice_) {
+ if (shared->config_->logger)
+ shared->config_->logger->error("Cannot initialize ICE session");
+ eraseInfo();
+ return;
+ }
+ // We need to detect any shutdown if the ice session is destroyed before going to the TLS session;
+ info->ice_->setOnShutdown([eraseInfo]() {
+ dht::ThreadPool::io().run([eraseInfo = std::move(eraseInfo)] { eraseInfo(); });
+ });
+ info->ice_->initIceInstance(ice_config);
+ });
+}
+
+void
+ConnectionManager::Impl::addNewMultiplexedSocket(const CallbackId& id, const std::shared_ptr<ConnectionInfo>& info)
+{
+ info->socket_ = std::make_shared<MultiplexedSocket>(config_->ioContext, id.first, std::move(info->tls_));
+ info->socket_->setOnReady(
+ [w = weak()](const DeviceId& deviceId, const std::shared_ptr<ChannelSocket>& socket) {
+ if (auto sthis = w.lock())
+ if (sthis->connReadyCb_)
+ sthis->connReadyCb_(deviceId, socket->name(), socket);
+ });
+ info->socket_->setOnRequest([w = weak()](const std::shared_ptr<dht::crypto::Certificate>& peer,
+ const uint16_t&,
+ const std::string& name) {
+ if (auto sthis = w.lock())
+ if (sthis->channelReqCb_)
+ return sthis->channelReqCb_(peer, name);
+ return false;
+ });
+ info->socket_->onShutdown([w = weak(), deviceId=id.first, vid=id.second]() {
+ // Cancel current outgoing connections
+ dht::ThreadPool::io().run([w, deviceId, vid] {
+ auto sthis = w.lock();
+ if (!sthis)
+ return;
+
+ std::set<CallbackId> ids;
+ if (auto info = sthis->getInfo(deviceId, vid)) {
+ std::lock_guard<std::mutex> lk(info->mutex_);
+ if (info->socket_) {
+ ids = std::move(info->cbIds_);
+ info->socket_->shutdown();
+ }
+ }
+ for (const auto& cbId : ids)
+ for (const auto& pending : sthis->extractPendingCallbacks(cbId.first, cbId.second))
+ pending.cb(nullptr, deviceId);
+
+ std::lock_guard<std::mutex> lk(sthis->infosMtx_);
+ sthis->infos_.erase({deviceId, vid});
+ });
+ });
+}
+
+const std::shared_future<tls::DhParams>
+ConnectionManager::Impl::dhParams() const
+{
+ return dht::ThreadPool::computation().get<tls::DhParams>(
+ std::bind(tls::DhParams::loadDhParams, config_->cachePath + DIR_SEPARATOR_STR "dhParams"));
+ ;
+}
+
+template<typename ID = dht::Value::Id>
+std::set<ID, std::less<>>
+loadIdList(const std::string& path)
+{
+ std::set<ID, std::less<>> ids;
+ std::ifstream file = fileutils::ifstream(path);
+ if (!file.is_open()) {
+ //JAMI_DBG("Could not load %s", path.c_str());
+ return ids;
+ }
+ std::string line;
+ while (std::getline(file, line)) {
+ if constexpr (std::is_same<ID, std::string>::value) {
+ ids.emplace(std::move(line));
+ } else if constexpr (std::is_integral<ID>::value) {
+ ID vid;
+ if (auto [p, ec] = std::from_chars(line.data(), line.data() + line.size(), vid, 16);
+ ec == std::errc()) {
+ ids.emplace(vid);
+ }
+ }
+ }
+ return ids;
+}
+
+template<typename List = std::set<dht::Value::Id>>
+void
+saveIdList(const std::string& path, const List& ids)
+{
+ std::ofstream file = fileutils::ofstream(path, std::ios::trunc | std::ios::binary);
+ if (!file.is_open()) {
+ //JAMI_ERR("Could not save to %s", path.c_str());
+ return;
+ }
+ for (auto& c : ids)
+ file << std::hex << c << "\n";
+}
+
+void
+ConnectionManager::Impl::loadTreatedMessages()
+{
+ std::lock_guard<std::mutex> lock(messageMutex_);
+ auto path = config_->cachePath + DIR_SEPARATOR_STR "treatedMessages";
+ treatedMessages_ = loadIdList<std::string>(path);
+ if (treatedMessages_.empty()) {
+ auto messages = loadIdList(path);
+ for (const auto& m : messages)
+ treatedMessages_.emplace(to_hex_string(m));
+ }
+}
+
+void
+ConnectionManager::Impl::saveTreatedMessages() const
+{
+ dht::ThreadPool::io().run([w = weak()]() {
+ if (auto sthis = w.lock()) {
+ auto& this_ = *sthis;
+ std::lock_guard<std::mutex> lock(this_.messageMutex_);
+ fileutils::check_dir(this_.config_->cachePath.c_str());
+ saveIdList<decltype(this_.treatedMessages_)>(this_.config_->cachePath
+ + DIR_SEPARATOR_STR "treatedMessages",
+ this_.treatedMessages_);
+ }
+ });
+}
+
+bool
+ConnectionManager::Impl::isMessageTreated(std::string_view id)
+{
+ std::lock_guard<std::mutex> lock(messageMutex_);
+ auto res = treatedMessages_.emplace(id);
+ if (res.second) {
+ saveTreatedMessages();
+ return false;
+ }
+ return true;
+}
+
+/**
+ * returns whether or not UPnP is enabled and active_
+ * ie: if it is able to make port mappings
+ */
+bool
+ConnectionManager::Impl::getUPnPActive() const
+{
+ return config_->getUPnPActive();
+}
+
+IpAddr
+ConnectionManager::Impl::getPublishedIpAddress(uint16_t family) const
+{
+ if (family == AF_INET)
+ return publishedIp_[0];
+ if (family == AF_INET6)
+ return publishedIp_[1];
+
+ assert(family == AF_UNSPEC);
+
+ // If family is not set, prefere IPv4 if available. It's more
+ // likely to succeed behind NAT.
+ if (publishedIp_[0])
+ return publishedIp_[0];
+ if (publishedIp_[1])
+ return publishedIp_[1];
+ return {};
+}
+
+void
+ConnectionManager::Impl::setPublishedAddress(const IpAddr& ip_addr)
+{
+ if (ip_addr.getFamily() == AF_INET) {
+ publishedIp_[0] = ip_addr;
+ } else {
+ publishedIp_[1] = ip_addr;
+ }
+}
+
+void
+ConnectionManager::Impl::storeActiveIpAddress(std::function<void()>&& cb)
+{
+ dht()->getPublicAddress([this, cb = std::move(cb)](std::vector<dht::SockAddr>&& results) {
+ bool hasIpv4 {false}, hasIpv6 {false};
+ for (auto& result : results) {
+ auto family = result.getFamily();
+ if (family == AF_INET) {
+ if (not hasIpv4) {
+ hasIpv4 = true;
+ if (config_->logger)
+ config_->logger->debug("Store DHT public IPv4 address: {}", result);
+ //JAMI_DBG("Store DHT public IPv4 address : %s", result.toString().c_str());
+ setPublishedAddress(*result.get());
+ if (config_->upnpCtrl) {
+ config_->upnpCtrl->setPublicAddress(*result.get());
+ }
+ }
+ } else if (family == AF_INET6) {
+ if (not hasIpv6) {
+ hasIpv6 = true;
+ if (config_->logger)
+ config_->logger->debug("Store DHT public IPv6 address: {}", result);
+ setPublishedAddress(*result.get());
+ }
+ }
+ if (hasIpv4 and hasIpv6)
+ break;
+ }
+ if (cb)
+ cb();
+ });
+}
+
+void
+ConnectionManager::Impl::getIceOptions(std::function<void(IceTransportOptions&&)> cb) noexcept
+{
+ storeActiveIpAddress([this, cb = std::move(cb)] {
+ IceTransportOptions opts = ConnectionManager::Impl::getIceOptions();
+ auto publishedAddr = getPublishedIpAddress();
+
+ if (publishedAddr) {
+ auto interfaceAddr = ip_utils::getInterfaceAddr(getLocalInterface(),
+ publishedAddr.getFamily());
+ if (interfaceAddr) {
+ opts.accountLocalAddr = interfaceAddr;
+ opts.accountPublicAddr = publishedAddr;
+ }
+ }
+ if (cb)
+ cb(std::move(opts));
+ });
+}
+
+IceTransportOptions
+ConnectionManager::Impl::getIceOptions() const noexcept
+{
+ IceTransportOptions opts;
+ opts.upnpEnable = getUPnPActive();
+
+ if (config_->stunEnabled)
+ opts.stunServers.emplace_back(StunServerInfo().setUri(config_->stunServer));
+ if (config_->turnEnabled) {
+ auto cached = false;
+ std::lock_guard<std::mutex> lk(config_->cachedTurnMutex);
+ cached = config_->cacheTurnV4 || config_->cacheTurnV6;
+ if (config_->cacheTurnV4) {
+ opts.turnServers.emplace_back(TurnServerInfo()
+ .setUri(config_->cacheTurnV4.toString())
+ .setUsername(config_->turnServerUserName)
+ .setPassword(config_->turnServerPwd)
+ .setRealm(config_->turnServerRealm));
+ }
+ // NOTE: first test with ipv6 turn was not concluant and resulted in multiple
+ // co issues. So this needs some debug. for now just disable
+ // if (cacheTurnV6 && *cacheTurnV6) {
+ // opts.turnServers.emplace_back(TurnServerInfo()
+ // .setUri(cacheTurnV6->toString(true))
+ // .setUsername(turnServerUserName_)
+ // .setPassword(turnServerPwd_)
+ // .setRealm(turnServerRealm_));
+ //}
+ // Nothing cached, so do the resolution
+ if (!cached) {
+ opts.turnServers.emplace_back(TurnServerInfo()
+ .setUri(config_->turnServer)
+ .setUsername(config_->turnServerUserName)
+ .setPassword(config_->turnServerPwd)
+ .setRealm(config_->turnServerRealm));
+ }
+ }
+ return opts;
+}
+
+bool
+ConnectionManager::Impl::foundPeerDevice(const std::shared_ptr<dht::crypto::Certificate>& crt,
+ dht::InfoHash& account_id,
+ const std::shared_ptr<Logger>& logger)
+{
+ if (not crt)
+ return false;
+
+ auto top_issuer = crt;
+ while (top_issuer->issuer)
+ top_issuer = top_issuer->issuer;
+
+ // Device certificate can't be self-signed
+ if (top_issuer == crt) {
+ if (logger)
+ logger->warn("Found invalid peer device: {}", crt->getLongId());
+ return false;
+ }
+
+ // Check peer certificate chain
+ // Trust store with top issuer as the only CA
+ dht::crypto::TrustList peer_trust;
+ peer_trust.add(*top_issuer);
+ if (not peer_trust.verify(*crt)) {
+ if (logger)
+ logger->warn("Found invalid peer device: {}", crt->getLongId());
+ return false;
+ }
+
+ // Check cached OCSP response
+ if (crt->ocspResponse and crt->ocspResponse->getCertificateStatus() != GNUTLS_OCSP_CERT_GOOD) {
+ if (logger)
+ logger->error("Certificate %s is disabled by cached OCSP response", crt->getLongId());
+ return false;
+ }
+
+ account_id = crt->issuer->getId();
+ if (logger)
+ logger->warn("Found peer device: {} account:{} CA:{}",
+ crt->getLongId(),
+ account_id,
+ top_issuer->getId());
+ return true;
+}
+
+bool
+ConnectionManager::Impl::findCertificate(
+ const dht::PkId& id, std::function<void(const std::shared_ptr<dht::crypto::Certificate>&)>&& cb)
+{
+ if (auto cert = certStore().getCertificate(id.toString())) {
+ if (cb)
+ cb(cert);
+ } else if (cb)
+ cb(nullptr);
+ return true;
+}
+
+ConnectionManager::ConnectionManager(std::shared_ptr<ConnectionManager::Config> config_)
+ : pimpl_ {std::make_shared<Impl>(config_)}
+{}
+
+ConnectionManager::~ConnectionManager()
+{
+ if (pimpl_)
+ pimpl_->shutdown();
+}
+
+void
+ConnectionManager::connectDevice(const DeviceId& deviceId,
+ const std::string& name,
+ ConnectCallback cb,
+ bool noNewSocket,
+ bool forceNewSocket,
+ const std::string& connType)
+{
+ pimpl_->connectDevice(deviceId, name, std::move(cb), noNewSocket, forceNewSocket, connType);
+}
+
+void
+ConnectionManager::connectDevice(const std::shared_ptr<dht::crypto::Certificate>& cert,
+ const std::string& name,
+ ConnectCallback cb,
+ bool noNewSocket,
+ bool forceNewSocket,
+ const std::string& connType)
+{
+ pimpl_->connectDevice(cert, name, std::move(cb), noNewSocket, forceNewSocket, connType);
+}
+
+bool
+ConnectionManager::isConnecting(const DeviceId& deviceId, const std::string& name) const
+{
+ auto pending = pimpl_->getPendingCallbacks(deviceId);
+ return std::find_if(pending.begin(), pending.end(), [&](auto p) { return p.name == name; })
+ != pending.end();
+}
+
+void
+ConnectionManager::closeConnectionsWith(const std::string& peerUri)
+{
+ std::vector<std::shared_ptr<ConnectionInfo>> connInfos;
+ std::set<DeviceId> peersDevices;
+ {
+ std::lock_guard<std::mutex> lk(pimpl_->infosMtx_);
+ for (auto iter = pimpl_->infos_.begin(); iter != pimpl_->infos_.end();) {
+ auto const& [key, value] = *iter;
+ auto deviceId = key.first;
+ auto cert = pimpl_->certStore().getCertificate(deviceId.toString());
+ if (cert && cert->issuer && peerUri == cert->issuer->getId().toString()) {
+ connInfos.emplace_back(value);
+ peersDevices.emplace(deviceId);
+ iter = pimpl_->infos_.erase(iter);
+ } else {
+ iter++;
+ }
+ }
+ }
+ // Stop connections to all peers devices
+ for (const auto& deviceId : peersDevices) {
+ for (const auto& pending : pimpl_->extractPendingCallbacks(deviceId))
+ pending.cb(nullptr, deviceId);
+ // This will close the TLS Session
+ pimpl_->removeUnusedConnections(deviceId);
+ }
+ for (auto& info : connInfos) {
+ if (info->socket_)
+ info->socket_->shutdown();
+ if (info->waitForAnswer_)
+ info->waitForAnswer_->cancel();
+ if (info->ice_) {
+ std::unique_lock<std::mutex> lk {info->mutex_};
+ dht::ThreadPool::io().run(
+ [ice = std::shared_ptr<IceTransport>(std::move(info->ice_))] {});
+ }
+ }
+}
+
+void
+ConnectionManager::onDhtConnected(const dht::crypto::PublicKey& devicePk)
+{
+ pimpl_->onDhtConnected(devicePk);
+}
+
+void
+ConnectionManager::onICERequest(onICERequestCallback&& cb)
+{
+ pimpl_->iceReqCb_ = std::move(cb);
+}
+
+void
+ConnectionManager::onChannelRequest(ChannelRequestCallback&& cb)
+{
+ pimpl_->channelReqCb_ = std::move(cb);
+}
+
+void
+ConnectionManager::onConnectionReady(ConnectionReadyCallback&& cb)
+{
+ pimpl_->connReadyCb_ = std::move(cb);
+}
+
+void
+ConnectionManager::oniOSConnected(iOSConnectedCallback&& cb)
+{
+ pimpl_->iOSConnectedCb_ = std::move(cb);
+}
+
+std::size_t
+ConnectionManager::activeSockets() const
+{
+ std::lock_guard<std::mutex> lk(pimpl_->infosMtx_);
+ return pimpl_->infos_.size();
+}
+
+void
+ConnectionManager::monitor() const
+{
+ std::lock_guard<std::mutex> lk(pimpl_->infosMtx_);
+ auto logger = pimpl_->config_->logger;
+ if (!logger)
+ return;
+ logger->debug("ConnectionManager current status:");
+ for (const auto& [_, ci] : pimpl_->infos_) {
+ if (ci->socket_)
+ ci->socket_->monitor();
+ }
+ logger->debug("ConnectionManager end status.");
+}
+
+void
+ConnectionManager::connectivityChanged()
+{
+ std::lock_guard<std::mutex> lk(pimpl_->infosMtx_);
+ for (const auto& [_, ci] : pimpl_->infos_) {
+ if (ci->socket_)
+ ci->socket_->sendBeacon();
+ }
+}
+
+void
+ConnectionManager::getIceOptions(std::function<void(IceTransportOptions&&)> cb) noexcept
+{
+ return pimpl_->getIceOptions(std::move(cb));
+}
+
+IceTransportOptions
+ConnectionManager::getIceOptions() const noexcept
+{
+ return pimpl_->getIceOptions();
+}
+
+IpAddr
+ConnectionManager::getPublishedIpAddress(uint16_t family) const
+{
+ return pimpl_->getPublishedIpAddress(family);
+}
+
+void
+ConnectionManager::setPublishedAddress(const IpAddr& ip_addr)
+{
+ return pimpl_->setPublishedAddress(ip_addr);
+}
+
+void
+ConnectionManager::storeActiveIpAddress(std::function<void()>&& cb)
+{
+ return pimpl_->storeActiveIpAddress(std::move(cb));
+}
+
+std::shared_ptr<ConnectionManager::Config>
+ConnectionManager::getConfig()
+{
+ return pimpl_->config_;
+}
+
+} // namespace jami
diff --git a/src/fileutils.cpp b/src/fileutils.cpp
new file mode 100644
index 0000000..be911a6
--- /dev/null
+++ b/src/fileutils.cpp
@@ -0,0 +1,878 @@
+/*
+ * Copyright (C) 2004-2023 Savoir-faire Linux Inc.
+ *
+ * Author: Rafaël Carré <rafael.carre@savoirfairelinux.com>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+//#include "logger.h"
+#include "fileutils.h"
+//#include "archiver.h"
+//#include "compiler_intrinsics.h"
+#include <opendht/crypto.h>
+
+#ifdef RING_UWP
+#include <io.h> // for access and close
+#include "ring_signal.h"
+#endif
+
+#ifdef __APPLE__
+#include <TargetConditionals.h>
+#endif
+
+#if defined(__ANDROID__) || (defined(TARGET_OS_IOS) && TARGET_OS_IOS)
+#include "client/ring_signal.h"
+#endif
+
+#ifdef _WIN32
+#include <windows.h>
+#include "string_utils.h"
+#endif
+
+#include <sys/types.h>
+#include <sys/stat.h>
+
+#ifndef _MSC_VER
+#include <libgen.h>
+#endif
+
+#ifdef _MSC_VER
+#include "windirent.h"
+#else
+#include <dirent.h>
+#endif
+
+#include <signal.h>
+#include <unistd.h>
+#include <fcntl.h>
+#ifndef _WIN32
+#include <pwd.h>
+#else
+#include <shlobj.h>
+#define NAME_MAX 255
+#endif
+#if !defined __ANDROID__ && !defined _WIN32
+#include <wordexp.h>
+#endif
+
+#include <nettle/sha3.h>
+
+#include <sstream>
+#include <fstream>
+#include <iostream>
+#include <stdexcept>
+#include <limits>
+#include <array>
+
+#include <cstdlib>
+#include <cstring>
+#include <cerrno>
+#include <cstddef>
+#include <ciso646>
+
+#include <pj/ctype.h>
+#include <pjlib-util/md5.h>
+
+#include <filesystem>
+
+#define PIDFILE ".ring.pid"
+#define ERASE_BLOCK 4096
+
+namespace jami {
+namespace fileutils {
+
+// returns true if directory exists
+bool
+check_dir(const char* path, [[maybe_unused]] mode_t dirmode, mode_t parentmode)
+{
+ DIR* dir = opendir(path);
+
+ if (!dir) { // doesn't exist
+ if (not recursive_mkdir(path, parentmode)) {
+ perror(path);
+ return false;
+ }
+#ifndef _WIN32
+ if (chmod(path, dirmode) < 0) {
+ //JAMI_ERR("fileutils::check_dir(): chmod() failed on '%s', %s", path, strerror(errno));
+ return false;
+ }
+#endif
+ } else
+ closedir(dir);
+ return true;
+}
+
+std::string
+expand_path(const std::string& path)
+{
+#if defined __ANDROID__ || defined _MSC_VER || defined WIN32 || defined __APPLE__
+ //JAMI_ERR("Path expansion not implemented, returning original");
+ return path;
+#else
+
+ std::string result;
+
+ wordexp_t p;
+ int ret = wordexp(path.c_str(), &p, 0);
+
+ switch (ret) {
+ case WRDE_BADCHAR:
+ JAMI_ERR("Illegal occurrence of newline or one of |, &, ;, <, >, "
+ "(, ), {, }.");
+ return result;
+ case WRDE_BADVAL:
+ JAMI_ERR("An undefined shell variable was referenced");
+ return result;
+ case WRDE_CMDSUB:
+ JAMI_ERR("Command substitution occurred");
+ return result;
+ case WRDE_SYNTAX:
+ JAMI_ERR("Shell syntax error");
+ return result;
+ case WRDE_NOSPACE:
+ JAMI_ERR("Out of memory.");
+ // This is the only error where we must call wordfree
+ break;
+ default:
+ if (p.we_wordc > 0)
+ result = std::string(p.we_wordv[0]);
+ break;
+ }
+
+ wordfree(&p);
+
+ return result;
+#endif
+}
+
+std::mutex&
+getFileLock(const std::string& path)
+{
+ static std::mutex fileLockLock {};
+ static std::map<std::string, std::mutex> fileLocks {};
+
+ std::lock_guard<std::mutex> l(fileLockLock);
+ return fileLocks[path];
+}
+
+bool
+isFile(const std::string& path, bool resolveSymlink)
+{
+ if (path.empty())
+ return false;
+#ifdef _WIN32
+ if (resolveSymlink) {
+ struct _stat64i32 s;
+ if (_wstat(jami::to_wstring(path).c_str(), &s) == 0)
+ return S_ISREG(s.st_mode);
+ } else {
+ DWORD attr = GetFileAttributes(jami::to_wstring(path).c_str());
+ if ((attr != INVALID_FILE_ATTRIBUTES) && !(attr & FILE_ATTRIBUTE_DIRECTORY)
+ && !(attr & FILE_ATTRIBUTE_REPARSE_POINT))
+ return true;
+ }
+#else
+ if (resolveSymlink) {
+ struct stat s;
+ if (stat(path.c_str(), &s) == 0)
+ return S_ISREG(s.st_mode);
+ } else {
+ struct stat s;
+ if (lstat(path.c_str(), &s) == 0)
+ return S_ISREG(s.st_mode);
+ }
+#endif
+
+ return false;
+}
+
+bool
+isDirectory(const std::string& path)
+{
+ struct stat s;
+ if (stat(path.c_str(), &s) == 0)
+ return s.st_mode & S_IFDIR;
+ return false;
+}
+
+bool
+isDirectoryWritable(const std::string& directory)
+{
+ return accessFile(directory, W_OK) == 0;
+}
+
+bool
+hasHardLink(const std::string& path)
+{
+#ifndef _WIN32
+ struct stat s;
+ if (lstat(path.c_str(), &s) == 0)
+ return s.st_nlink > 1;
+#endif
+ return false;
+}
+
+bool
+isSymLink(const std::string& path)
+{
+#ifndef _WIN32
+ struct stat s;
+ if (lstat(path.c_str(), &s) == 0)
+ return S_ISLNK(s.st_mode);
+#elif !defined(_MSC_VER)
+ DWORD attr = GetFileAttributes(jami::to_wstring(path).c_str());
+ if (attr & FILE_ATTRIBUTE_REPARSE_POINT)
+ return true;
+#endif
+ return false;
+}
+
+std::chrono::system_clock::time_point
+writeTime(const std::string& path)
+{
+#ifndef _WIN32
+ struct stat s;
+ auto ret = stat(path.c_str(), &s);
+ if (ret)
+ throw std::runtime_error("Can't check write time for: " + path);
+ return std::chrono::system_clock::from_time_t(s.st_mtime);
+#else
+#if RING_UWP
+ _CREATEFILE2_EXTENDED_PARAMETERS ext_params = {0};
+ ext_params.dwSize = sizeof(CREATEFILE2_EXTENDED_PARAMETERS);
+ ext_params.dwFileAttributes = FILE_ATTRIBUTE_NORMAL;
+ ext_params.dwFileFlags = FILE_FLAG_NO_BUFFERING;
+ ext_params.dwSecurityQosFlags = SECURITY_ANONYMOUS;
+ ext_params.lpSecurityAttributes = nullptr;
+ ext_params.hTemplateFile = nullptr;
+ HANDLE h = CreateFile2(jami::to_wstring(path).c_str(),
+ GENERIC_READ,
+ FILE_SHARE_READ,
+ OPEN_EXISTING,
+ &ext_params);
+#elif _WIN32
+ HANDLE h = CreateFileW(jami::to_wstring(path).c_str(),
+ GENERIC_READ,
+ FILE_SHARE_READ,
+ nullptr,
+ OPEN_EXISTING,
+ FILE_ATTRIBUTE_NORMAL,
+ nullptr);
+#endif
+ if (h == INVALID_HANDLE_VALUE)
+ throw std::runtime_error("Can't open: " + path);
+ FILETIME lastWriteTime;
+ if (!GetFileTime(h, nullptr, nullptr, &lastWriteTime))
+ throw std::runtime_error("Can't check write time for: " + path);
+ CloseHandle(h);
+ SYSTEMTIME sTime;
+ if (!FileTimeToSystemTime(&lastWriteTime, &sTime))
+ throw std::runtime_error("Can't check write time for: " + path);
+ struct tm tm
+ {};
+ tm.tm_year = sTime.wYear - 1900;
+ tm.tm_mon = sTime.wMonth - 1;
+ tm.tm_mday = sTime.wDay;
+ tm.tm_hour = sTime.wHour;
+ tm.tm_min = sTime.wMinute;
+ tm.tm_sec = sTime.wSecond;
+ tm.tm_isdst = -1;
+ return std::chrono::system_clock::from_time_t(mktime(&tm));
+#endif
+}
+
+bool
+createSymlink(const std::string& linkFile, const std::string& target)
+{
+ try {
+ std::filesystem::create_symlink(target, linkFile);
+ } catch (const std::exception& e) {
+ //JAMI_ERR("Couldn't create soft link: %s", e.what());
+ return false;
+ }
+ return true;
+}
+
+bool
+createHardlink(const std::string& linkFile, const std::string& target)
+{
+ try {
+ std::filesystem::create_hard_link(target, linkFile);
+ } catch (const std::exception& e) {
+ //JAMI_ERR("Couldn't create hard link: %s", e.what());
+ return false;
+ }
+ return true;
+}
+
+void
+createFileLink(const std::string& linkFile, const std::string& target, bool hard)
+{
+ if (not hard or not createHardlink(linkFile, target))
+ createSymlink(linkFile, target);
+}
+
+std::string_view
+getFileExtension(std::string_view filename)
+{
+ std::string_view result;
+ auto sep = filename.find_last_of('.');
+ if (sep != std::string_view::npos && sep != filename.size() - 1)
+ result = filename.substr(sep + 1);
+ if (result.size() >= 8)
+ return {};
+ return result;
+}
+
+bool
+isPathRelative(const std::string& path)
+{
+#ifndef _WIN32
+ return not path.empty() and not(path[0] == '/');
+#else
+ return not path.empty() and path.find(":") == std::string::npos;
+#endif
+}
+
+std::string
+getCleanPath(const std::string& base, const std::string& path)
+{
+ if (base.empty() or path.size() < base.size())
+ return path;
+ auto base_sep = base + DIR_SEPARATOR_STR;
+ if (path.compare(0, base_sep.size(), base_sep) == 0)
+ return path.substr(base_sep.size());
+ else
+ return path;
+}
+
+std::string
+getFullPath(const std::string& base, const std::string& path)
+{
+ bool isRelative {not base.empty() and isPathRelative(path)};
+ return isRelative ? base + DIR_SEPARATOR_STR + path : path;
+}
+
+std::vector<uint8_t>
+loadFile(const std::string& path, const std::string& default_dir)
+{
+ std::vector<uint8_t> buffer;
+ std::ifstream file = ifstream(getFullPath(default_dir, path), std::ios::binary);
+ if (!file)
+ throw std::runtime_error("Can't read file: " + path);
+ file.seekg(0, std::ios::end);
+ auto size = file.tellg();
+ if (size > std::numeric_limits<unsigned>::max())
+ throw std::runtime_error("File is too big: " + path);
+ buffer.resize(size);
+ file.seekg(0, std::ios::beg);
+ if (!file.read((char*) buffer.data(), size))
+ throw std::runtime_error("Can't load file: " + path);
+ return buffer;
+}
+
+std::string
+loadTextFile(const std::string& path, const std::string& default_dir)
+{
+ std::string buffer;
+ std::ifstream file = ifstream(getFullPath(default_dir, path));
+ if (!file)
+ throw std::runtime_error("Can't read file: " + path);
+ file.seekg(0, std::ios::end);
+ auto size = file.tellg();
+ if (size > std::numeric_limits<unsigned>::max())
+ throw std::runtime_error("File is too big: " + path);
+ buffer.resize(size);
+ file.seekg(0, std::ios::beg);
+ if (!file.read((char*) buffer.data(), size))
+ throw std::runtime_error("Can't load file: " + path);
+ return buffer;
+}
+
+void
+saveFile(const std::string& path, const uint8_t* data, size_t data_size, [[maybe_unused]] mode_t mode)
+{
+ std::ofstream file = fileutils::ofstream(path, std::ios::trunc | std::ios::binary);
+ if (!file.is_open()) {
+ //JAMI_ERR("Could not write data to %s", path.c_str());
+ return;
+ }
+ file.write((char*) data, data_size);
+#ifndef _WIN32
+ if (chmod(path.c_str(), mode) < 0)
+ /*JAMI_WARN("fileutils::saveFile(): chmod() failed on '%s', %s",
+ path.c_str(),
+ strerror(errno))*/;
+#endif
+}
+
+std::vector<uint8_t>
+loadCacheFile(const std::string& path, std::chrono::system_clock::duration maxAge)
+{
+ // writeTime throws exception if file doesn't exist
+ auto duration = std::chrono::system_clock::now() - writeTime(path);
+ if (duration > maxAge)
+ throw std::runtime_error("file too old");
+
+ //JAMI_DBG("Loading cache file '%.*s'", (int) path.size(), path.c_str());
+ return loadFile(path);
+}
+
+std::string
+loadCacheTextFile(const std::string& path, std::chrono::system_clock::duration maxAge)
+{
+ // writeTime throws exception if file doesn't exist
+ auto duration = std::chrono::system_clock::now() - writeTime(path);
+ if (duration > maxAge)
+ throw std::runtime_error("file too old");
+
+ //JAMI_DBG("Loading cache file '%.*s'", (int) path.size(), path.c_str());
+ return loadTextFile(path);
+}
+
+static size_t
+dirent_buf_size([[maybe_unused]] DIR* dirp)
+{
+ long name_max;
+#if defined(HAVE_FPATHCONF) && defined(HAVE_DIRFD) && defined(_PC_NAME_MAX)
+ name_max = fpathconf(dirfd(dirp), _PC_NAME_MAX);
+ if (name_max == -1)
+#if defined(NAME_MAX)
+ name_max = (NAME_MAX > 255) ? NAME_MAX : 255;
+#else
+ return (size_t) (-1);
+#endif
+#else
+#if defined(NAME_MAX)
+ name_max = (NAME_MAX > 255) ? NAME_MAX : 255;
+#else
+#error "buffer size for readdir_r cannot be determined"
+#endif
+#endif
+ size_t name_end = (size_t) offsetof(struct dirent, d_name) + name_max + 1;
+ return name_end > sizeof(struct dirent) ? name_end : sizeof(struct dirent);
+}
+
+std::vector<std::string>
+readDirectory(const std::string& dir)
+{
+ DIR* dp = opendir(dir.c_str());
+ if (!dp)
+ return {};
+
+ size_t size = dirent_buf_size(dp);
+ if (size == (size_t) (-1))
+ return {};
+ std::vector<uint8_t> buf(size);
+ dirent* entry;
+
+ std::vector<std::string> files;
+#ifndef _WIN32
+ while (!readdir_r(dp, reinterpret_cast<dirent*>(buf.data()), &entry) && entry) {
+#else
+ while ((entry = readdir(dp)) != nullptr) {
+#endif
+ std::string fname {entry->d_name};
+ if (fname == "." || fname == "..")
+ continue;
+ files.emplace_back(std::move(fname));
+ }
+ closedir(dp);
+ return files;
+} // namespace fileutils
+
+/*
+std::vector<uint8_t>
+readArchive(const std::string& path, const std::string& pwd)
+{
+ JAMI_DBG("Reading archive from %s", path.c_str());
+
+ auto isUnencryptedGzip = [](const std::vector<uint8_t>& data) {
+ // NOTE: some webserver modify gzip files and this can end with a gunzip in a gunzip
+ // file. So, to make the readArchive more robust, we can support this case by detecting
+ // gzip header via 1f8b 08
+ // We don't need to support more than 2 level, else somebody may be able to send
+ // gunzip in loops and abuse.
+ return data.size() > 3 && data[0] == 0x1f && data[1] == 0x8b && data[2] == 0x08;
+ };
+
+ auto decompress = [](std::vector<uint8_t>& data) {
+ try {
+ data = archiver::decompress(data);
+ } catch (const std::exception& e) {
+ JAMI_ERR("Error decrypting archive: %s", e.what());
+ throw e;
+ }
+ };
+
+ std::vector<uint8_t> data;
+ // Read file
+ try {
+ data = loadFile(path);
+ } catch (const std::exception& e) {
+ JAMI_ERR("Error loading archive: %s", e.what());
+ throw e;
+ }
+
+ if (isUnencryptedGzip(data)) {
+ if (!pwd.empty())
+ JAMI_WARN() << "A gunzip in a gunzip is detected. A webserver may have a bad config";
+
+ decompress(data);
+ }
+
+ if (!pwd.empty()) {
+ // Decrypt
+ try {
+ data = dht::crypto::aesDecrypt(data, pwd);
+ } catch (const std::exception& e) {
+ JAMI_ERR("Error decrypting archive: %s", e.what());
+ throw e;
+ }
+ decompress(data);
+ } else if (isUnencryptedGzip(data)) {
+ JAMI_WARN() << "A gunzip in a gunzip is detected. A webserver may have a bad config";
+ decompress(data);
+ }
+ return data;
+}
+
+void
+writeArchive(const std::string& archive_str, const std::string& path, const std::string& password)
+{
+ JAMI_DBG("Writing archive to %s", path.c_str());
+
+ if (not password.empty()) {
+ // Encrypt using provided password
+ std::vector<uint8_t> data = dht::crypto::aesEncrypt(archiver::compress(archive_str),
+ password);
+ // Write
+ try {
+ saveFile(path, data);
+ } catch (const std::runtime_error& ex) {
+ JAMI_ERR("Export failed: %s", ex.what());
+ return;
+ }
+ } else {
+ JAMI_WARN("Unsecured archiving (no password)");
+ archiver::compressGzip(archive_str, path);
+ }
+}*/
+
+bool
+recursive_mkdir(const std::string& path, mode_t mode)
+{
+#ifndef _WIN32
+ if (mkdir(path.data(), mode) != 0) {
+#else
+ if (_wmkdir(jami::to_wstring(path.data()).c_str()) != 0) {
+#endif
+ if (errno == ENOENT) {
+ recursive_mkdir(path.substr(0, path.find_last_of(DIR_SEPARATOR_CH)), mode);
+#ifndef _WIN32
+ if (mkdir(path.data(), mode) != 0) {
+#else
+ if (_wmkdir(jami::to_wstring(path.data()).c_str()) != 0) {
+#endif
+ //JAMI_ERR("Could not create directory.");
+ return false;
+ }
+ }
+ } // namespace jami
+ return true;
+}
+
+#ifdef _WIN32
+bool
+eraseFile_win32(const std::string& path, bool dosync)
+{
+ HANDLE h
+ = CreateFileA(path.c_str(), GENERIC_WRITE, 0, 0, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, 0);
+ if (h == INVALID_HANDLE_VALUE) {
+ JAMI_WARN("Can not open file %s for erasing.", path.c_str());
+ return false;
+ }
+
+ LARGE_INTEGER size;
+ if (!GetFileSizeEx(h, &size)) {
+ JAMI_WARN("Can not erase file %s: GetFileSizeEx() failed.", path.c_str());
+ CloseHandle(h);
+ return false;
+ }
+ if (size.QuadPart == 0) {
+ CloseHandle(h);
+ return false;
+ }
+
+ uint64_t size_blocks = size.QuadPart / ERASE_BLOCK;
+ if (size.QuadPart % ERASE_BLOCK)
+ size_blocks++;
+
+ char* buffer;
+ try {
+ buffer = new char[ERASE_BLOCK];
+ } catch (std::bad_alloc& ba) {
+ JAMI_WARN("Can not allocate buffer for erasing %s.", path.c_str());
+ CloseHandle(h);
+ return false;
+ }
+ memset(buffer, 0x00, ERASE_BLOCK);
+
+ OVERLAPPED ovlp;
+ if (size.QuadPart < (1024 - 42)) { // a small file can be stored in the MFT record
+ ovlp.Offset = 0;
+ ovlp.OffsetHigh = 0;
+ WriteFile(h, buffer, (DWORD) size.QuadPart, 0, &ovlp);
+ FlushFileBuffers(h);
+ }
+ for (uint64_t i = 0; i < size_blocks; i++) {
+ uint64_t offset = i * ERASE_BLOCK;
+ ovlp.Offset = offset & 0x00000000FFFFFFFF;
+ ovlp.OffsetHigh = offset >> 32;
+ WriteFile(h, buffer, ERASE_BLOCK, 0, &ovlp);
+ }
+
+ delete[] buffer;
+
+ if (dosync)
+ FlushFileBuffers(h);
+
+ CloseHandle(h);
+ return true;
+}
+
+#else
+
+bool
+eraseFile_posix(const std::string& path, bool dosync)
+{
+ struct stat st;
+ if (stat(path.c_str(), &st) == -1) {
+ //JAMI_WARN("Can not erase file %s: fstat() failed.", path.c_str());
+ return false;
+ }
+ // Remove read-only flag if possible
+ chmod(path.c_str(), st.st_mode | (S_IWGRP+S_IWUSR) );
+
+ int fd = open(path.c_str(), O_WRONLY);
+ if (fd == -1) {
+ //JAMI_WARN("Can not open file %s for erasing.", path.c_str());
+ return false;
+ }
+
+ if (st.st_size == 0) {
+ close(fd);
+ return false;
+ }
+
+ lseek(fd, 0, SEEK_SET);
+
+ std::array<char, ERASE_BLOCK> buffer;
+ buffer.fill(0);
+ decltype(st.st_size) written(0);
+ while (written < st.st_size) {
+ auto ret = write(fd, buffer.data(), buffer.size());
+ if (ret < 0) {
+ //JAMI_WARNING("Error while overriding file with zeros.");
+ break;
+ } else
+ written += ret;
+ }
+
+ if (dosync)
+ fsync(fd);
+
+ close(fd);
+ return written >= st.st_size;
+}
+#endif
+
+bool
+eraseFile(const std::string& path, bool dosync)
+{
+#ifdef _WIN32
+ return eraseFile_win32(path, dosync);
+#else
+ return eraseFile_posix(path, dosync);
+#endif
+}
+
+int
+remove(const std::string& path, bool erase)
+{
+ if (erase and isFile(path, false) and !hasHardLink(path))
+ eraseFile(path, true);
+
+#ifdef _WIN32
+ // use Win32 api since std::remove will not unlink directory in use
+ if (isDirectory(path))
+ return !RemoveDirectory(jami::to_wstring(path).c_str());
+#endif
+
+ return std::remove(path.c_str());
+}
+
+int
+removeAll(const std::string& path, bool erase)
+{
+ if (path.empty())
+ return -1;
+ if (isDirectory(path) and !isSymLink(path)) {
+ auto dir = path;
+ if (dir.back() != DIR_SEPARATOR_CH)
+ dir += DIR_SEPARATOR_CH;
+ for (auto& entry : fileutils::readDirectory(dir))
+ removeAll(dir + entry, erase);
+ }
+ return remove(path, erase);
+}
+
+void
+openStream(std::ifstream& file, const std::string& path, std::ios_base::openmode mode)
+{
+#ifdef _WIN32
+ file.open(jami::to_wstring(path), mode);
+#else
+ file.open(path, mode);
+#endif
+}
+
+void
+openStream(std::ofstream& file, const std::string& path, std::ios_base::openmode mode)
+{
+#ifdef _WIN32
+ file.open(jami::to_wstring(path), mode);
+#else
+ file.open(path, mode);
+#endif
+}
+
+std::ifstream
+ifstream(const std::string& path, std::ios_base::openmode mode)
+{
+#ifdef _WIN32
+ return std::ifstream(jami::to_wstring(path), mode);
+#else
+ return std::ifstream(path, mode);
+#endif
+}
+
+std::ofstream
+ofstream(const std::string& path, std::ios_base::openmode mode)
+{
+#ifdef _WIN32
+ return std::ofstream(jami::to_wstring(path), mode);
+#else
+ return std::ofstream(path, mode);
+#endif
+}
+
+int64_t
+size(const std::string& path)
+{
+ int64_t size = 0;
+ try {
+ std::ifstream file;
+ openStream(file, path, std::ios::binary | std::ios::in);
+ file.seekg(0, std::ios_base::end);
+ size = file.tellg();
+ file.close();
+ } catch (...) {
+ }
+ return size;
+}
+
+std::string
+sha3File(const std::string& path)
+{
+ sha3_512_ctx ctx;
+ sha3_512_init(&ctx);
+
+ std::ifstream file;
+ try {
+ if (!fileutils::isFile(path))
+ return {};
+ openStream(file, path, std::ios::binary | std::ios::in);
+ if (!file)
+ return {};
+ std::vector<char> buffer(8192, 0);
+ while (!file.eof()) {
+ file.read(buffer.data(), buffer.size());
+ std::streamsize readSize = file.gcount();
+ sha3_512_update(&ctx, readSize, (const uint8_t*) buffer.data());
+ }
+ file.close();
+ } catch (...) {
+ return {};
+ }
+
+ unsigned char digest[SHA3_512_DIGEST_SIZE];
+ sha3_512_digest(&ctx, SHA3_512_DIGEST_SIZE, digest);
+
+ char hash[SHA3_512_DIGEST_SIZE * 2];
+
+ for (int i = 0; i < SHA3_512_DIGEST_SIZE; ++i)
+ pj_val_to_hex_digit(digest[i], &hash[2 * i]);
+
+ return {hash, SHA3_512_DIGEST_SIZE * 2};
+}
+
+std::string
+sha3sum(const std::vector<uint8_t>& buffer)
+{
+ sha3_512_ctx ctx;
+ sha3_512_init(&ctx);
+ sha3_512_update(&ctx, buffer.size(), (const uint8_t*) buffer.data());
+
+ unsigned char digest[SHA3_512_DIGEST_SIZE];
+ sha3_512_digest(&ctx, SHA3_512_DIGEST_SIZE, digest);
+
+ char hash[SHA3_512_DIGEST_SIZE * 2];
+
+ for (int i = 0; i < SHA3_512_DIGEST_SIZE; ++i)
+ pj_val_to_hex_digit(digest[i], &hash[2 * i]);
+
+ return {hash, SHA3_512_DIGEST_SIZE * 2};
+}
+
+int
+accessFile(const std::string& file, int mode)
+{
+#ifdef _WIN32
+ return _waccess(jami::to_wstring(file).c_str(), mode);
+#else
+ return access(file.c_str(), mode);
+#endif
+}
+
+uint64_t
+lastWriteTime(const std::string& p)
+{
+#if USE_STD_FILESYSTEM
+ return std::chrono::duration_cast<std::chrono::milliseconds>(
+ std::filesystem::last_write_time(std::filesystem::path(p)).time_since_epoch())
+ .count();
+#else
+ struct stat result;
+ if (stat(p.c_str(), &result) == 0)
+ return result.st_mtime;
+ return 0;
+#endif
+}
+
+} // namespace fileutils
+} // namespace jami
diff --git a/src/ice_socket.h b/src/ice_socket.h
new file mode 100644
index 0000000..795185d
--- /dev/null
+++ b/src/ice_socket.h
@@ -0,0 +1,58 @@
+/*
+ * Copyright (C) 2004-2023 Savoir-faire Linux Inc.
+ *
+ * Author: Guillaume Roguez <guillaume.roguez@savoirfairelinux.com>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+#pragma once
+
+#include "generic_io.h"
+
+#include <memory>
+#include <functional>
+
+#if defined(_MSC_VER)
+#include <BaseTsd.h>
+using ssize_t = SSIZE_T;
+#endif
+
+namespace jami {
+
+class IceTransport;
+using IceRecvCb = std::function<ssize_t(unsigned char* buf, size_t len)>;
+
+class IceSocket
+{
+private:
+ std::shared_ptr<IceTransport> ice_transport_ {};
+ int compId_ = -1;
+
+public:
+ IceSocket(std::shared_ptr<IceTransport> iceTransport, int compId)
+ : ice_transport_(std::move(iceTransport))
+ , compId_(compId)
+ {}
+
+ void close();
+ ssize_t send(const unsigned char* buf, size_t len);
+ ssize_t waitForData(std::chrono::milliseconds timeout);
+ void setOnRecv(IceRecvCb cb);
+ uint16_t getTransportOverhead();
+ void setDefaultRemoteAddress(const IpAddr& addr);
+ int getCompId() const { return compId_; };
+};
+
+}; // namespace jami
diff --git a/src/ice_transport.cpp b/src/ice_transport.cpp
new file mode 100644
index 0000000..12c0122
--- /dev/null
+++ b/src/ice_transport.cpp
@@ -0,0 +1,1902 @@
+/*
+ * Copyright (C) 2004-2023 Savoir-faire Linux Inc.
+ *
+ * Author: Guillaume Roguez <guillaume.roguez@savoirfairelinux.com>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#include "ice_transport.h"
+#include "ice_socket.h"
+#include "sip_utils.h"
+#include "string_utils.h"
+#include "upnp/upnp_control.h"
+#include "transport/peer_channel.h"
+#include "tracepoint/tracepoint.h"
+
+#include <opendht/logger.h>
+#include <opendht/utils.h>
+
+#include <pjlib.h>
+
+#include <map>
+#include <atomic>
+#include <queue>
+#include <mutex>
+#include <condition_variable>
+#include <thread>
+#include <utility>
+#include <tuple>
+#include <algorithm>
+#include <sstream>
+#include <chrono>
+#include <thread>
+#include <cerrno>
+
+#include "pj/limits.h"
+
+#define TRY(ret) \
+ do { \
+ if ((ret) != PJ_SUCCESS) \
+ throw std::runtime_error(#ret " failed"); \
+ } while (0)
+
+// Validate that the component ID is within the expected range
+#define ASSERT_COMP_ID(compId, compCount) \
+ do { \
+ if ((compId) == 0 or (compId) > (compCount)) \
+ throw std::runtime_error("Invalid component ID " + (std::to_string(compId))); \
+ } while (0)
+
+namespace jami {
+
+static constexpr unsigned STUN_MAX_PACKET_SIZE {8192};
+static constexpr uint16_t IPV6_HEADER_SIZE = 40; ///< Size in bytes of IPV6 packet header
+static constexpr uint16_t IPV4_HEADER_SIZE = 20; ///< Size in bytes of IPV4 packet header
+static constexpr int MAX_CANDIDATES {32};
+static constexpr int MAX_DESTRUCTION_TIMEOUT {3000};
+static constexpr int HANDLE_EVENT_DURATION {500};
+
+//==============================================================================
+
+using MutexGuard = std::lock_guard<std::mutex>;
+using MutexLock = std::unique_lock<std::mutex>;
+using namespace upnp;
+
+//==============================================================================
+
+class IceLock
+{
+ pj_grp_lock_t* lk_;
+
+public:
+ IceLock(pj_ice_strans* strans)
+ : lk_(pj_ice_strans_get_grp_lock(strans))
+ {
+ lock();
+ }
+
+ ~IceLock() { unlock(); }
+
+ void lock() { if (lk_) pj_grp_lock_acquire(lk_); }
+
+ void unlock() { if (lk_) pj_grp_lock_release(lk_); }
+};
+
+class IceTransport::Impl
+{
+public:
+ Impl(std::string_view name);
+ ~Impl();
+
+ void initIceInstance(const IceTransportOptions& options);
+
+ void onComplete(pj_ice_strans* ice_st, pj_ice_strans_op op, pj_status_t status);
+
+ void onReceiveData(unsigned comp_id, void* pkt, pj_size_t size);
+
+ /**
+ * Set/change transport role as initiator.
+ * Should be called before start method.
+ */
+ bool setInitiatorSession();
+
+ /**
+ * Set/change transport role as slave.
+ * Should be called before start method.
+ */
+ bool setSlaveSession();
+ bool createIceSession(pj_ice_sess_role role);
+
+ void getUFragPwd();
+
+ std::string link() const;
+
+ bool _isInitialized() const;
+ bool _isStarted() const;
+ bool _isRunning() const;
+ bool _isFailed() const;
+ bool _waitForInitialization(std::chrono::milliseconds timeout);
+
+ const pj_ice_sess_cand* getSelectedCandidate(unsigned comp_id, bool remote) const;
+ IpAddr getLocalAddress(unsigned comp_id) const;
+ IpAddr getRemoteAddress(unsigned comp_id) const;
+ static const char* getCandidateType(const pj_ice_sess_cand* cand);
+ bool isTcpEnabled() const { return config_.protocol == PJ_ICE_TP_TCP; }
+ bool addStunConfig(int af);
+ void requestUpnpMappings();
+ bool hasUpnp() const;
+ // Take a list of address pairs (local/public) and add them as
+ // reflexive candidates using STUN config.
+ void addServerReflexiveCandidates(const std::vector<std::pair<IpAddr, IpAddr>>& addrList);
+ // Generate server reflexive candidates using the published (DHT/Account) address
+ std::vector<std::pair<IpAddr, IpAddr>> setupGenericReflexiveCandidates();
+ // Generate server reflexive candidates using UPNP mappings.
+ std::vector<std::pair<IpAddr, IpAddr>> setupUpnpReflexiveCandidates();
+ void setDefaultRemoteAddress(unsigned comp_id, const IpAddr& addr);
+ IpAddr getDefaultRemoteAddress(unsigned comp_id) const;
+ bool handleEvents(unsigned max_msec);
+ int flushTimerHeapAndIoQueue();
+ int checkEventQueue(int maxEventToPoll);
+
+ std::shared_ptr<dht::log::Logger> logger_ {};
+
+ std::condition_variable_any iceCV_ {};
+
+ std::string sessionName_ {};
+ std::unique_ptr<pj_pool_t, decltype(&pj_pool_release)> pool_ {nullptr, pj_pool_release};
+ bool isTcp_ {false};
+ bool upnpEnabled_ {false};
+ IceTransportCompleteCb on_initdone_cb_ {};
+ IceTransportCompleteCb on_negodone_cb_ {};
+ pj_ice_strans* icest_ {nullptr};
+ unsigned streamsCount_ {0};
+ unsigned compCountPerStream_ {0};
+ unsigned compCount_ {0};
+ std::string local_ufrag_ {};
+ std::string local_pwd_ {};
+ pj_sockaddr remoteAddr_ {};
+ pj_ice_strans_cfg config_ {};
+ //std::string last_errmsg_ {};
+
+ std::atomic_bool is_stopped_ {false};
+
+ struct Packet
+ {
+ Packet(void* pkt, pj_size_t size)
+ : data {reinterpret_cast<char*>(pkt), reinterpret_cast<char*>(pkt) + size}
+ {}
+ std::vector<char> data {};
+ };
+
+ struct ComponentIO
+ {
+ std::mutex mutex;
+ std::condition_variable cv;
+ std::deque<Packet> queue;
+ IceRecvCb recvCb;
+ };
+
+ // NOTE: Component IDs start from 1, while these three vectors
+ // are indexed from 0. Conversion from ID to vector index must
+ // be done properly.
+ std::vector<ComponentIO> compIO_ {};
+ std::vector<PeerChannel> peerChannels_ {};
+ std::vector<IpAddr> iceDefaultRemoteAddr_;
+
+ // ICE controlling role. True for controller agents and false for
+ // controlled agents
+ std::atomic_bool initiatorSession_ {true};
+
+ // Local/Public addresses used by the account owning the ICE instance.
+ IpAddr accountLocalAddr_ {};
+ IpAddr accountPublicAddr_ {};
+
+ // STUN and TURN servers
+ std::vector<StunServerInfo> stunServers_;
+ std::vector<TurnServerInfo> turnServers_;
+
+ /**
+ * Returns the IP of each candidate for a given component in the ICE session
+ */
+ struct LocalCandidate
+ {
+ IpAddr addr;
+ pj_ice_cand_transport transport;
+ };
+
+ std::shared_ptr<upnp::Controller> upnp_ {};
+ std::mutex upnpMutex_ {};
+ std::map<Mapping::key_t, Mapping> upnpMappings_;
+ std::mutex upnpMappingsMutex_ {};
+
+ bool onlyIPv4Private_ {true};
+
+ // IO/Timer events are handled by following thread
+ std::thread thread_ {};
+ std::atomic_bool threadTerminateFlags_ {false};
+
+ // Wait data on components
+ mutable std::mutex sendDataMutex_ {};
+ std::condition_variable waitDataCv_ = {};
+ pj_size_t lastSentLen_ {0};
+ bool destroying_ {false};
+ onShutdownCb scb {};
+
+ void cancelOperations()
+ {
+ for (auto& c : peerChannels_)
+ c.stop();
+ std::lock_guard<std::mutex> lk(sendDataMutex_);
+ destroying_ = true;
+ waitDataCv_.notify_all();
+ }
+};
+
+//==============================================================================
+
+/**
+ * Add stun/turn configuration or default host as candidates
+ */
+
+static void
+add_stun_server(pj_pool_t& pool, pj_ice_strans_cfg& cfg, const StunServerInfo& info)
+{
+ if (cfg.stun_tp_cnt >= PJ_ICE_MAX_STUN)
+ throw std::runtime_error("Too many STUN configurations");
+
+ IpAddr ip {info.uri};
+
+ // Given URI cannot be DNS resolved or not IPv4 or IPv6?
+ // This prevents a crash into PJSIP when ip.toString() is called.
+ if (ip.getFamily() == AF_UNSPEC) {
+ /*JAMI_DBG("[ice (%s)] STUN server '%s' not used, unresolvable address",
+ (cfg.protocol == PJ_ICE_TP_TCP ? "TCP" : "UDP"),
+ info.uri.c_str());*/
+ return;
+ }
+
+ auto& stun = cfg.stun_tp[cfg.stun_tp_cnt++];
+ pj_ice_strans_stun_cfg_default(&stun);
+ pj_strdup2_with_null(&pool, &stun.server, ip.toString().c_str());
+ stun.af = ip.getFamily();
+ if (!(stun.port = ip.getPort()))
+ stun.port = PJ_STUN_PORT;
+ stun.cfg.max_pkt_size = STUN_MAX_PACKET_SIZE;
+ stun.conn_type = cfg.stun.conn_type;
+ /*JAMI_DBG("[ice (%s)] added stun server '%s', port %u",
+ (cfg.protocol == PJ_ICE_TP_TCP ? "TCP" : "UDP"),
+ pj_strbuf(&stun.server),
+ stun.port);*/
+}
+
+static void
+add_turn_server(pj_pool_t& pool, pj_ice_strans_cfg& cfg, const TurnServerInfo& info)
+{
+ if (cfg.turn_tp_cnt >= PJ_ICE_MAX_TURN)
+ throw std::runtime_error("Too many TURN servers");
+
+ IpAddr ip {info.uri};
+
+ // Same comment as add_stun_server()
+ if (ip.getFamily() == AF_UNSPEC) {
+ /*JAMI_DBG("[ice (%s)] TURN server '%s' not used, unresolvable address",
+ (cfg.protocol == PJ_ICE_TP_TCP ? "TCP" : "UDP"),
+ info.uri.c_str());*/
+ return;
+ }
+
+ auto& turn = cfg.turn_tp[cfg.turn_tp_cnt++];
+ pj_ice_strans_turn_cfg_default(&turn);
+ pj_strdup2_with_null(&pool, &turn.server, ip.toString().c_str());
+ turn.af = ip.getFamily();
+ if (!(turn.port = ip.getPort()))
+ turn.port = PJ_STUN_PORT;
+ turn.cfg.max_pkt_size = STUN_MAX_PACKET_SIZE;
+ turn.conn_type = cfg.turn.conn_type;
+
+ // Authorization (only static plain password supported yet)
+ if (not info.password.empty()) {
+ turn.auth_cred.type = PJ_STUN_AUTH_CRED_STATIC;
+ turn.auth_cred.data.static_cred.data_type = PJ_STUN_PASSWD_PLAIN;
+ pj_strset(&turn.auth_cred.data.static_cred.realm,
+ (char*) info.realm.c_str(),
+ info.realm.size());
+ pj_strset(&turn.auth_cred.data.static_cred.username,
+ (char*) info.username.c_str(),
+ info.username.size());
+ pj_strset(&turn.auth_cred.data.static_cred.data,
+ (char*) info.password.c_str(),
+ info.password.size());
+ }
+
+ /*JAMI_DBG("[ice (%s)] added turn server '%s', port %u",
+ (cfg.protocol == PJ_ICE_TP_TCP ? "TCP" : "UDP"),
+ pj_strbuf(&turn.server),
+ turn.port);*/
+}
+
+//==============================================================================
+
+IceTransport::Impl::Impl(std::string_view name)
+ : sessionName_(name)
+{
+ if (logger_)
+ logger_->debug("[ice:{}] Creating IceTransport session for \"{:s}\"", fmt::ptr(this), name);
+}
+
+IceTransport::Impl::~Impl()
+{
+ if (logger_)
+ logger_->debug("[ice:{}] destroying {}", fmt::ptr(this), fmt::ptr(icest_));
+
+ threadTerminateFlags_ = true;
+
+ if (thread_.joinable()) {
+ thread_.join();
+ }
+
+ if (icest_) {
+ pj_ice_strans* strans = nullptr;
+
+ std::swap(strans, icest_);
+
+ // must be done before ioqueue/timer destruction
+ if (logger_)
+ logger_->debug("[ice:{}] Destroying ice_strans {}", pj_ice_strans_get_user_data(strans), fmt::ptr(strans));
+
+ pj_ice_strans_stop_ice(strans);
+ pj_ice_strans_destroy(strans);
+
+ // NOTE: This last timer heap and IO queue polling is necessary to close
+ // TURN socket.
+ // Because when destroying the TURN session pjproject creates a pj_timer
+ // to postpone the TURN destruction. This timer is only called if we poll
+ // the event queue.
+
+ int ret = flushTimerHeapAndIoQueue();
+
+ if (ret < 0) {
+ if (logger_)
+ logger_->error("[ice:{}] IO queue polling failed", fmt::ptr(this));
+ } else if (ret > 0) {
+ if (logger_)
+ logger_->error("[ice:{}] Unexpected left timer in timer heap. "
+ "Please report the bug",
+ fmt::ptr(this));
+ }
+
+ if (checkEventQueue(1) > 0) {
+ if (logger_)
+ logger_->warn("[ice:{}] Unexpected left events in IO queue", fmt::ptr(this));
+ }
+
+ if (config_.stun_cfg.ioqueue)
+ pj_ioqueue_destroy(config_.stun_cfg.ioqueue);
+
+ if (config_.stun_cfg.timer_heap)
+ pj_timer_heap_destroy(config_.stun_cfg.timer_heap);
+ }
+
+ if (logger_)
+ logger_->debug("[ice:%p] done destroying", fmt::ptr(this));
+ if (scb)
+ scb();
+}
+
+void
+IceTransport::Impl::initIceInstance(const IceTransportOptions& options)
+{
+ isTcp_ = options.tcpEnable;
+ upnpEnabled_ = options.upnpEnable;
+ on_initdone_cb_ = options.onInitDone;
+ on_negodone_cb_ = options.onNegoDone;
+ streamsCount_ = options.streamsCount;
+ compCountPerStream_ = options.compCountPerStream;
+ compCount_ = streamsCount_ * compCountPerStream_;
+ compIO_ = std::vector<ComponentIO>(compCount_);
+ peerChannels_ = std::vector<PeerChannel>(compCount_);
+ iceDefaultRemoteAddr_.resize(compCount_);
+ initiatorSession_ = options.master;
+ accountLocalAddr_ = std::move(options.accountLocalAddr);
+ accountPublicAddr_ = std::move(options.accountPublicAddr);
+ stunServers_ = std::move(options.stunServers);
+ turnServers_ = std::move(options.turnServers);
+
+ if (logger_)
+ logger_->debug("[ice:{}] Initializing the session - comp count {} - as a {}",
+ fmt::ptr(this),
+ compCount_,
+ initiatorSession_ ? "master" : "slave");
+
+ if (upnpEnabled_)
+ upnp_.reset(new upnp::Controller());
+
+ config_ = options.factory->getIceCfg(); // config copy
+ if (isTcp_) {
+ config_.protocol = PJ_ICE_TP_TCP;
+ config_.stun.conn_type = PJ_STUN_TP_TCP;
+ config_.turn.conn_type = PJ_TURN_TP_TCP;
+ } else {
+ config_.protocol = PJ_ICE_TP_UDP;
+ config_.stun.conn_type = PJ_STUN_TP_UDP;
+ config_.turn.conn_type = PJ_TURN_TP_UDP;
+ }
+
+ pool_.reset(
+ pj_pool_create(options.factory->getPoolFactory(), "IceTransport.pool", 512, 512, NULL));
+ if (not pool_)
+ throw std::runtime_error("pj_pool_create() failed");
+
+ // Note: For server reflexive candidates, UPNP mappings will
+ // be used if available. Then, the public address learnt during
+ // the account registration process will be added only if it
+ // differs from the UPNP public address.
+ // Also note that UPNP candidates should be added first in order
+ // to have a higher priority when performing the connectivity
+ // checks.
+ // STUN configs layout:
+ // - index 0 : host IPv4
+ // - index 1 : host IPv6
+ // - index 2 : upnp/generic srflx IPv4.
+ // - index 3 : generic srflx (if upnp exists and different)
+
+ config_.stun_tp_cnt = 0;
+
+ if (logger_)
+ logger_->debug("[ice:{}] Add host candidates", fmt::ptr(this));
+ addStunConfig(pj_AF_INET());
+ addStunConfig(pj_AF_INET6());
+
+ std::vector<std::pair<IpAddr, IpAddr>> upnpSrflxCand;
+ if (upnp_) {
+ requestUpnpMappings();
+ upnpSrflxCand = setupUpnpReflexiveCandidates();
+ if (not upnpSrflxCand.empty()) {
+ addServerReflexiveCandidates(upnpSrflxCand);
+ if (logger_)
+ logger_->debug("[ice:{}] Added UPNP srflx candidates:", fmt::ptr(this));
+ }
+ }
+
+ auto genericSrflxCand = setupGenericReflexiveCandidates();
+
+ if (not genericSrflxCand.empty()) {
+ // Generic srflx candidates will be added only if different
+ // from upnp candidates.
+ if (upnpSrflxCand.empty()
+ or (upnpSrflxCand[0].second.toString() != genericSrflxCand[0].second.toString())) {
+ addServerReflexiveCandidates(genericSrflxCand);
+ if (logger_)
+ logger_->debug("[ice:{}] Added generic srflx candidates:", fmt::ptr(this));
+ }
+ }
+
+ if (upnpSrflxCand.empty() and genericSrflxCand.empty()) {
+ if (logger_)
+ logger_->warn("[ice:{}] No server reflexive candidates added", fmt::ptr(this));
+ }
+
+ pj_ice_strans_cb icecb;
+ pj_bzero(&icecb, sizeof(icecb));
+
+ icecb.on_rx_data = [](pj_ice_strans* ice_st,
+ unsigned comp_id,
+ void* pkt,
+ pj_size_t size,
+ const pj_sockaddr_t* /*src_addr*/,
+ unsigned /*src_addr_len*/) {
+ if (auto* tr = static_cast<Impl*>(pj_ice_strans_get_user_data(ice_st)))
+ tr->onReceiveData(comp_id, pkt, size);
+ };
+
+ icecb.on_ice_complete = [](pj_ice_strans* ice_st, pj_ice_strans_op op, pj_status_t status) {
+ if (auto* tr = static_cast<Impl*>(pj_ice_strans_get_user_data(ice_st)))
+ tr->onComplete(ice_st, op, status);
+ };
+
+ if (isTcp_) {
+ icecb.on_data_sent = [](pj_ice_strans* ice_st, pj_ssize_t size) {
+ if (auto* tr = static_cast<Impl*>(pj_ice_strans_get_user_data(ice_st))) {
+ std::lock_guard lk(tr->sendDataMutex_);
+ tr->lastSentLen_ += size;
+ tr->waitDataCv_.notify_all();
+ }
+ };
+ }
+
+ icecb.on_destroy = [](pj_ice_strans* ice_st) {
+ if (auto* tr = static_cast<Impl*>(pj_ice_strans_get_user_data(ice_st)))
+ tr->cancelOperations(); // Avoid upper layer to manage this ; Stop read operations
+ };
+
+ // Add STUN servers
+ for (auto& server : stunServers_)
+ add_stun_server(*pool_, config_, server);
+
+ // Add TURN servers
+ for (auto& server : turnServers_)
+ add_turn_server(*pool_, config_, server);
+
+ static constexpr auto IOQUEUE_MAX_HANDLES = std::min(PJ_IOQUEUE_MAX_HANDLES, 64);
+ TRY(pj_timer_heap_create(pool_.get(), 100, &config_.stun_cfg.timer_heap));
+ TRY(pj_ioqueue_create(pool_.get(), IOQUEUE_MAX_HANDLES, &config_.stun_cfg.ioqueue));
+ std::ostringstream sessionName {};
+ // We use the instance pointer as the PJNATH session name in order
+ // to easily identify the logs reported by PJNATH.
+ sessionName << this;
+ pj_status_t status = pj_ice_strans_create(sessionName.str().c_str(),
+ &config_,
+ compCount_,
+ this,
+ &icecb,
+ &icest_);
+
+ if (status != PJ_SUCCESS || icest_ == nullptr) {
+ throw std::runtime_error("pj_ice_strans_create() failed");
+ }
+
+ // Must be created after any potential failure
+ thread_ = std::thread([this] {
+ while (not threadTerminateFlags_) {
+ // NOTE: handleEvents can return false in this case
+ // but here we don't care if there is event or not.
+ handleEvents(HANDLE_EVENT_DURATION);
+ }
+ });
+}
+
+bool
+IceTransport::Impl::_isInitialized() const
+{
+ if (auto *icest = icest_) {
+ auto state = pj_ice_strans_get_state(icest);
+ return state >= PJ_ICE_STRANS_STATE_SESS_READY and state != PJ_ICE_STRANS_STATE_FAILED;
+ }
+ return false;
+}
+
+bool
+IceTransport::Impl::_isStarted() const
+{
+ if (auto *icest = icest_) {
+ auto state = pj_ice_strans_get_state(icest);
+ return state >= PJ_ICE_STRANS_STATE_NEGO and state != PJ_ICE_STRANS_STATE_FAILED;
+ }
+ return false;
+}
+
+bool
+IceTransport::Impl::_isRunning() const
+{
+ if (auto *icest = icest_) {
+ auto state = pj_ice_strans_get_state(icest);
+ return state >= PJ_ICE_STRANS_STATE_RUNNING and state != PJ_ICE_STRANS_STATE_FAILED;
+ }
+ return false;
+}
+
+bool
+IceTransport::Impl::_isFailed() const
+{
+ if (auto *icest = icest_)
+ return pj_ice_strans_get_state(icest) == PJ_ICE_STRANS_STATE_FAILED;
+ return false;
+}
+
+bool
+IceTransport::Impl::handleEvents(unsigned max_msec)
+{
+ // By tests, never seen more than two events per 500ms
+ static constexpr auto MAX_NET_EVENTS = 2;
+
+ pj_time_val max_timeout = {0, static_cast<long>(max_msec)};
+ pj_time_val timeout = {0, 0};
+ unsigned net_event_count = 0;
+
+ pj_timer_heap_poll(config_.stun_cfg.timer_heap, &timeout);
+ auto hasActiveTimer = timeout.sec != PJ_MAXINT32 || timeout.msec != PJ_MAXINT32;
+
+ // timeout limitation
+ if (hasActiveTimer)
+ pj_time_val_normalize(&timeout);
+
+ if (PJ_TIME_VAL_GT(timeout, max_timeout)) {
+ timeout = max_timeout;
+ }
+
+ do {
+ auto n_events = pj_ioqueue_poll(config_.stun_cfg.ioqueue, &timeout);
+
+ // timeout
+ if (not n_events)
+ return hasActiveTimer;
+
+ // error
+ if (n_events < 0) {
+ const auto err = pj_get_os_error();
+ // Kept as debug as some errors are "normal" in regular context
+ if (logger_)
+ logger_->debug("[ice:{}] ioqueue error {:d}: {:s}", fmt::ptr(this), err, sip_utils::sip_strerror(err));
+ std::this_thread::sleep_for(std::chrono::milliseconds(PJ_TIME_VAL_MSEC(timeout)));
+ return hasActiveTimer;
+ }
+
+ net_event_count += n_events;
+ timeout.sec = timeout.msec = 0;
+ } while (net_event_count < MAX_NET_EVENTS);
+ return hasActiveTimer;
+}
+
+int
+IceTransport::Impl::flushTimerHeapAndIoQueue()
+{
+ pj_time_val timerTimeout = {0, 0};
+ pj_time_val defaultWaitTime = {0, HANDLE_EVENT_DURATION};
+ bool hasActiveTimer = false;
+ std::chrono::milliseconds totalWaitTime {0};
+ auto const start = std::chrono::steady_clock::now();
+ // We try to process pending events as fast as possible to
+ // speed-up the release.
+ int maxEventToProcess = 10;
+
+ do {
+ if (checkEventQueue(maxEventToProcess) < 0)
+ return -1;
+
+ pj_timer_heap_poll(config_.stun_cfg.timer_heap, &timerTimeout);
+ hasActiveTimer = !(timerTimeout.sec == PJ_MAXINT32 && timerTimeout.msec == PJ_MAXINT32);
+
+ if (hasActiveTimer) {
+ pj_time_val_normalize(&timerTimeout);
+ auto waitTime = std::chrono::milliseconds(
+ std::min(PJ_TIME_VAL_MSEC(timerTimeout), PJ_TIME_VAL_MSEC(defaultWaitTime)));
+ std::this_thread::sleep_for(waitTime);
+ totalWaitTime += waitTime;
+ }
+ } while (hasActiveTimer && totalWaitTime < std::chrono::milliseconds(MAX_DESTRUCTION_TIMEOUT));
+
+ auto duration = std::chrono::steady_clock::now() - start;
+ if (logger_)
+ logger_->debug("[ice:{}] Timer heap flushed after {}", fmt::ptr(this), dht::print_duration(duration));
+
+ return static_cast<int>(pj_timer_heap_count(config_.stun_cfg.timer_heap));
+}
+
+int
+IceTransport::Impl::checkEventQueue(int maxEventToPoll)
+{
+ pj_time_val timeout = {0, 0};
+ int eventCount = 0;
+ int events = 0;
+
+ do {
+ events = pj_ioqueue_poll(config_.stun_cfg.ioqueue, &timeout);
+ if (events < 0) {
+ const auto err = pj_get_os_error();
+ if (logger_)
+ logger_->error("[ice:{}] ioqueue error {:d}: {:s}", fmt::ptr(this), err, sip_utils::sip_strerror(err));
+ return events;
+ }
+
+ eventCount += events;
+
+ } while (events > 0 && eventCount < maxEventToPoll);
+
+ return eventCount;
+}
+
+void
+IceTransport::Impl::onComplete(pj_ice_strans*, pj_ice_strans_op op, pj_status_t status)
+{
+ const char* opname = op == PJ_ICE_STRANS_OP_INIT ? "initialization"
+ : op == PJ_ICE_STRANS_OP_NEGOTIATION ? "negotiation"
+ : "unknown_op";
+
+ const bool done = status == PJ_SUCCESS;
+ if (done) {
+ if (logger_)
+ logger_->debug("[ice:{}] {:s} {:s} success",
+ fmt::ptr(this),
+ (config_.protocol == PJ_ICE_TP_TCP ? "TCP" : "UDP"),
+ opname);
+ } else {
+ if (logger_)
+ logger_->error("[ice:{}] {:s} {:s} failed: {:s}",
+ fmt::ptr(this),
+ (config_.protocol == PJ_ICE_TP_TCP ? "TCP" : "UDP"),
+ opname,
+ sip_utils::sip_strerror(status));
+ }
+
+ if (done and op == PJ_ICE_STRANS_OP_INIT) {
+ if (initiatorSession_)
+ setInitiatorSession();
+ else
+ setSlaveSession();
+ }
+
+ if (op == PJ_ICE_STRANS_OP_INIT and on_initdone_cb_)
+ on_initdone_cb_(done);
+ else if (op == PJ_ICE_STRANS_OP_NEGOTIATION) {
+ if (done) {
+ // Dump of connection pairs
+ if (logger_)
+ logger_->debug("[ice:{}] {:s} connection pairs ([comp id] local [type] <-> remote [type]):\n{:s}",
+ fmt::ptr(this),
+ (config_.protocol == PJ_ICE_TP_TCP ? "TCP" : "UDP"),
+ link());
+ }
+ if (on_negodone_cb_)
+ on_negodone_cb_(done);
+ }
+
+ iceCV_.notify_all();
+}
+
+std::string
+IceTransport::Impl::link() const
+{
+ std::ostringstream out;
+ for (unsigned strm = 1; strm <= streamsCount_ * compCountPerStream_; strm++) {
+ auto absIdx = strm;
+ auto comp = (strm + 1) / compCountPerStream_;
+ auto laddr = getLocalAddress(absIdx);
+ auto raddr = getRemoteAddress(absIdx);
+
+ if (laddr and laddr.getPort() != 0 and raddr and raddr.getPort() != 0) {
+ out << " [" << comp << "] " << laddr.toString(true, true) << " ["
+ << getCandidateType(getSelectedCandidate(absIdx, false)) << "] "
+ << " <-> " << raddr.toString(true, true) << " ["
+ << getCandidateType(getSelectedCandidate(absIdx, true)) << "] " << '\n';
+ } else {
+ out << " [" << comp << "] disabled\n";
+ }
+ }
+ return out.str();
+}
+
+bool
+IceTransport::Impl::setInitiatorSession()
+{
+ if (logger_)
+ logger_->debug("[ice:{}] as master", fmt::ptr(this));
+ initiatorSession_ = true;
+ if (_isInitialized()) {
+ auto status = pj_ice_strans_change_role(icest_, PJ_ICE_SESS_ROLE_CONTROLLING);
+ if (status != PJ_SUCCESS) {
+ if (logger_)
+ logger_->error("[ice:{}] role change failed: {:s}", fmt::ptr(this), sip_utils::sip_strerror(status));
+ return false;
+ }
+ return true;
+ }
+ return createIceSession(PJ_ICE_SESS_ROLE_CONTROLLING);
+}
+
+bool
+IceTransport::Impl::setSlaveSession()
+{
+ if (logger_)
+ logger_->debug("[ice:{}] as slave", fmt::ptr(this));
+ initiatorSession_ = false;
+ if (_isInitialized()) {
+ auto status = pj_ice_strans_change_role(icest_, PJ_ICE_SESS_ROLE_CONTROLLED);
+ if (status != PJ_SUCCESS) {
+ if (logger_)
+ logger_->error("[ice:{}] role change failed: {:s}", fmt::ptr(this), sip_utils::sip_strerror(status));
+ return false;
+ }
+ return true;
+ }
+ return createIceSession(PJ_ICE_SESS_ROLE_CONTROLLED);
+}
+
+const pj_ice_sess_cand*
+IceTransport::Impl::getSelectedCandidate(unsigned comp_id, bool remote) const
+{
+ ASSERT_COMP_ID(comp_id, compCount_);
+
+ // Return the selected candidate pair. Might not be the nominated pair if
+ // ICE has not concluded yet, but should be the nominated pair afterwards.
+ if (not _isRunning()) {
+ if (logger_)
+ logger_->error("[ice:{}] ICE transport is not running", fmt::ptr(this));
+ return nullptr;
+ }
+
+ const auto* sess = pj_ice_strans_get_valid_pair(icest_, comp_id);
+ if (sess == nullptr) {
+ if (logger_)
+ logger_->warn("[ice:{}] Component {} has no valid pair (disabled)", fmt::ptr(this), comp_id);
+ return nullptr;
+ }
+
+ if (remote)
+ return sess->rcand;
+ else
+ return sess->lcand;
+}
+
+IpAddr
+IceTransport::Impl::getLocalAddress(unsigned comp_id) const
+{
+ ASSERT_COMP_ID(comp_id, compCount_);
+
+ if (auto cand = getSelectedCandidate(comp_id, false))
+ return cand->addr;
+
+ return {};
+}
+
+IpAddr
+IceTransport::Impl::getRemoteAddress(unsigned comp_id) const
+{
+ ASSERT_COMP_ID(comp_id, compCount_);
+
+ if (auto cand = getSelectedCandidate(comp_id, true))
+ return cand->addr;
+
+ return {};
+}
+
+const char*
+IceTransport::Impl::getCandidateType(const pj_ice_sess_cand* cand)
+{
+ auto name = cand ? pj_ice_get_cand_type_name(cand->type) : nullptr;
+ return name ? name : "?";
+}
+
+void
+IceTransport::Impl::getUFragPwd()
+{
+ if (icest_) {
+ pj_str_t local_ufrag, local_pwd;
+
+ pj_ice_strans_get_ufrag_pwd(icest_, &local_ufrag, &local_pwd, nullptr, nullptr);
+ local_ufrag_.assign(local_ufrag.ptr, local_ufrag.slen);
+ local_pwd_.assign(local_pwd.ptr, local_pwd.slen);
+ }
+}
+
+bool
+IceTransport::Impl::createIceSession(pj_ice_sess_role role)
+{
+ if (not icest_) {
+ return false;
+ }
+
+ if (pj_ice_strans_init_ice(icest_, role, nullptr, nullptr) != PJ_SUCCESS) {
+ if (logger_)
+ logger_->error("[ice:{}] pj_ice_strans_init_ice() failed", fmt::ptr(this));
+ return false;
+ }
+
+ // Fetch some information on local configuration
+ getUFragPwd();
+
+ if (logger_)
+ logger_->debug("[ice:{}] (local) ufrag=%s, pwd=%s", fmt::ptr(this), local_ufrag_.c_str(), local_pwd_.c_str());
+
+ return true;
+}
+
+bool
+IceTransport::Impl::addStunConfig(int af)
+{
+ if (config_.stun_tp_cnt >= PJ_ICE_MAX_STUN) {
+ if (logger_)
+ logger_->error("Max number of STUN configurations reached (%i)", PJ_ICE_MAX_STUN);
+ return false;
+ }
+
+ if (af != pj_AF_INET() and af != pj_AF_INET6()) {
+ if (logger_)
+ logger_->error("Invalid address familly (%i)", af);
+ return false;
+ }
+
+ auto& stun = config_.stun_tp[config_.stun_tp_cnt++];
+
+ pj_ice_strans_stun_cfg_default(&stun);
+ stun.cfg.max_pkt_size = STUN_MAX_PACKET_SIZE;
+ stun.af = af;
+ stun.conn_type = config_.stun.conn_type;
+
+ if (logger_)
+ logger_->debug("[ice:{}] added host stun config for {:s} transport",
+ fmt::ptr(this),
+ config_.protocol == PJ_ICE_TP_TCP ? "TCP" : "UDP");
+
+ return true;
+}
+
+void
+IceTransport::Impl::requestUpnpMappings()
+{
+ // Must be called once !
+
+ std::lock_guard<std::mutex> lock(upnpMutex_);
+
+ if (not upnp_)
+ return;
+
+ auto transport = isTcpEnabled() ? PJ_CAND_TCP_PASSIVE : PJ_CAND_UDP;
+ auto portType = transport == PJ_CAND_UDP ? PortType::UDP : PortType::TCP;
+
+ // Request upnp mapping for each component.
+ for (unsigned id = 1; id <= compCount_; id++) {
+ // Set port number to 0 to get any available port.
+ Mapping requestedMap(portType);
+
+ // Request the mapping
+ Mapping::sharedPtr_t mapPtr = upnp_->reserveMapping(requestedMap);
+
+ // To use a mapping, it must be valid, open and has valid host address.
+ if (mapPtr and mapPtr->getMapKey() and (mapPtr->getState() == MappingState::OPEN)
+ and mapPtr->hasValidHostAddress()) {
+ std::lock_guard<std::mutex> lock(upnpMappingsMutex_);
+ auto ret = upnpMappings_.emplace(mapPtr->getMapKey(), *mapPtr);
+ if (ret.second) {
+ if (logger_)
+ logger_->debug("[ice:{}] UPNP mapping {:s} successfully allocated",
+ fmt::ptr(this),
+ mapPtr->toString(true));
+ } else {
+ if (logger_)
+ logger_->warn("[ice:{}] UPNP mapping {:s} already in the list!",
+ fmt::ptr(this),
+ mapPtr->toString());
+ }
+ } else {
+ if (logger_)
+ logger_->warn("[ice:{}] UPNP mapping request failed!", fmt::ptr(this));
+ upnp_->releaseMapping(requestedMap);
+ }
+ }
+}
+
+bool
+IceTransport::Impl::hasUpnp() const
+{
+ return upnp_ and upnpMappings_.size() == compCount_;
+}
+
+void
+IceTransport::Impl::addServerReflexiveCandidates(
+ const std::vector<std::pair<IpAddr, IpAddr>>& addrList)
+{
+ if (addrList.size() != compCount_) {
+ if (logger_)
+ logger_->warn("[ice:{}] Provided addr list size {} does not match component count {}",
+ fmt::ptr(this),
+ addrList.size(),
+ compCount_);
+ return;
+ }
+ if (compCount_ > PJ_ICE_MAX_COMP) {
+ if (logger_)
+ logger_->error("[ice:{}] Too many components", fmt::ptr(this));
+ return;
+ }
+
+ // Add config for server reflexive candidates (UPNP or from DHT).
+ if (not addStunConfig(pj_AF_INET()))
+ return;
+
+ assert(config_.stun_tp_cnt > 0 && config_.stun_tp_cnt < PJ_ICE_MAX_STUN);
+ auto& stun = config_.stun_tp[config_.stun_tp_cnt - 1];
+
+ for (unsigned id = 1; id <= compCount_; id++) {
+ auto idx = id - 1;
+ auto& localAddr = addrList[idx].first;
+ auto& publicAddr = addrList[idx].second;
+
+ if (logger_)
+ logger_->debug("[ice:{}] Add srflx reflexive candidates [{:s} : {:s}] for comp {:d}",
+ fmt::ptr(this),
+ localAddr.toString(true),
+ publicAddr.toString(true),
+ id);
+
+ pj_sockaddr_cp(&stun.cfg.user_mapping[idx].local_addr, localAddr.pjPtr());
+ pj_sockaddr_cp(&stun.cfg.user_mapping[idx].mapped_addr, publicAddr.pjPtr());
+
+ if (isTcpEnabled()) {
+ if (publicAddr.getPort() == 9) {
+ stun.cfg.user_mapping[idx].tp_type = PJ_CAND_TCP_ACTIVE;
+ } else {
+ stun.cfg.user_mapping[idx].tp_type = PJ_CAND_TCP_PASSIVE;
+ }
+ } else {
+ stun.cfg.user_mapping[idx].tp_type = PJ_CAND_UDP;
+ }
+ }
+
+ stun.cfg.user_mapping_cnt = compCount_;
+}
+
+std::vector<std::pair<IpAddr, IpAddr>>
+IceTransport::Impl::setupGenericReflexiveCandidates()
+{
+ if (not accountLocalAddr_) {
+ if (logger_)
+ logger_->warn("[ice:{}] Missing local address, generic srflx candidates wont be generated!",
+ fmt::ptr(this));
+ return {};
+ }
+
+ if (not accountPublicAddr_) {
+ if (logger_)
+ logger_->warn("[ice:{}] Missing public address, generic srflx candidates wont be generated!",
+ fmt::ptr(this));
+ return {};
+ }
+
+ std::vector<std::pair<IpAddr, IpAddr>> addrList;
+ auto isTcp = isTcpEnabled();
+
+ addrList.reserve(compCount_);
+ for (unsigned id = 1; id <= compCount_; id++) {
+ // For TCP, the type is set to active, because most likely the incoming
+ // connection will be blocked by the NAT.
+ // For UDP use random port number.
+ uint16_t port = isTcp ? 9
+ : upnp::Controller::generateRandomPort(isTcp ? PortType::TCP
+ : PortType::UDP);
+
+ accountLocalAddr_.setPort(port);
+ accountPublicAddr_.setPort(port);
+ addrList.emplace_back(accountLocalAddr_, accountPublicAddr_);
+ }
+
+ return addrList;
+}
+
+std::vector<std::pair<IpAddr, IpAddr>>
+IceTransport::Impl::setupUpnpReflexiveCandidates()
+{
+ // Add UPNP server reflexive candidates if available.
+ if (not hasUpnp())
+ return {};
+
+ std::lock_guard<std::mutex> lock(upnpMappingsMutex_);
+
+ if (upnpMappings_.size() < (size_t)compCount_) {
+ if (logger_)
+ logger_->warn("[ice:{}] Not enough mappings {:d}. Expected {:d}",
+ fmt::ptr(this),
+ upnpMappings_.size(),
+ compCount_);
+ return {};
+ }
+
+ std::vector<std::pair<IpAddr, IpAddr>> addrList;
+
+ addrList.reserve(upnpMappings_.size());
+ for (auto const& [_, map] : upnpMappings_) {
+ assert(map.getMapKey());
+ IpAddr localAddr {map.getInternalAddress()};
+ localAddr.setPort(map.getInternalPort());
+ IpAddr publicAddr {map.getExternalAddress()};
+ publicAddr.setPort(map.getExternalPort());
+ addrList.emplace_back(localAddr, publicAddr);
+ }
+
+ return addrList;
+}
+
+void
+IceTransport::Impl::setDefaultRemoteAddress(unsigned compId, const IpAddr& addr)
+{
+ ASSERT_COMP_ID(compId, compCount_);
+
+ iceDefaultRemoteAddr_[compId - 1] = addr;
+ // The port does not matter. Set it 0 to avoid confusion.
+ iceDefaultRemoteAddr_[compId - 1].setPort(0);
+}
+
+IpAddr
+IceTransport::Impl::getDefaultRemoteAddress(unsigned compId) const
+{
+ ASSERT_COMP_ID(compId, compCount_);
+ return iceDefaultRemoteAddr_[compId - 1];
+}
+
+void
+IceTransport::Impl::onReceiveData(unsigned comp_id, void* pkt, pj_size_t size)
+{
+ ASSERT_COMP_ID(comp_id, compCount_);
+
+ jami_tracepoint_if_enabled(ice_transport_recv,
+ reinterpret_cast<uint64_t>(this),
+ comp_id,
+ size,
+ getRemoteAddress(comp_id).toString().c_str());
+ if (size == 0)
+ return;
+
+ {
+ auto& io = compIO_[comp_id - 1];
+ std::lock_guard<std::mutex> lk(io.mutex);
+
+ if (io.recvCb) {
+ io.recvCb((uint8_t*) pkt, size);
+ return;
+ }
+ }
+
+ std::error_code ec;
+ auto err = peerChannels_.at(comp_id - 1).write((const char*) pkt, size, ec);
+ if (err < 0) {
+ if (logger_)
+ logger_->error("[ice:{}] rx: channel is closed", fmt::ptr(this));
+ }
+}
+
+bool
+IceTransport::Impl::_waitForInitialization(std::chrono::milliseconds timeout)
+{
+ IceLock lk(icest_);
+
+ if (not iceCV_.wait_for(lk, timeout, [this] {
+ return threadTerminateFlags_ or _isInitialized() or _isFailed();
+ })) {
+ if (logger_)
+ logger_->warn("[ice:{}] waitForInitialization: timeout", fmt::ptr(this));
+ return false;
+ }
+
+ return _isInitialized();
+}
+
+//==============================================================================
+
+IceTransport::IceTransport(std::string_view name)
+ : pimpl_ {std::make_unique<Impl>(name)}
+{}
+
+IceTransport::~IceTransport()
+{
+ cancelOperations();
+}
+
+const std::shared_ptr<dht::log::Logger>&
+IceTransport::logger() const
+{
+ return pimpl_->logger_;
+}
+
+void
+IceTransport::initIceInstance(const IceTransportOptions& options)
+{
+ pimpl_->initIceInstance(options);
+ jami_tracepoint(ice_transport_context, reinterpret_cast<uint64_t>(this));
+}
+
+bool
+IceTransport::isInitialized() const
+{
+ IceLock lk(pimpl_->icest_);
+ return pimpl_->_isInitialized();
+}
+
+bool
+IceTransport::isStarted() const
+{
+ IceLock lk(pimpl_->icest_);
+ return pimpl_->_isStarted();
+}
+
+bool
+IceTransport::isRunning() const
+{
+ if (!pimpl_->icest_)
+ return false;
+ IceLock lk(pimpl_->icest_);
+ return pimpl_->_isRunning();
+}
+
+bool
+IceTransport::isFailed() const
+{
+ return pimpl_->_isFailed();
+}
+
+unsigned
+IceTransport::getComponentCount() const
+{
+ return pimpl_->compCount_;
+}
+
+bool
+IceTransport::setSlaveSession()
+{
+ return pimpl_->setSlaveSession();
+}
+bool
+IceTransport::setInitiatorSession()
+{
+ return pimpl_->setInitiatorSession();
+}
+
+bool
+IceTransport::isInitiator() const
+{
+ if (isInitialized()) {
+ return pj_ice_strans_get_role(pimpl_->icest_) == PJ_ICE_SESS_ROLE_CONTROLLING;
+ }
+ return pimpl_->initiatorSession_;
+}
+
+bool
+IceTransport::startIce(const Attribute& rem_attrs, std::vector<IceCandidate>&& rem_candidates)
+{
+ if (not isInitialized()) {
+ if (pimpl_->logger_)
+ pimpl_->logger_->error("[ice:{}] not initialized transport", fmt::ptr(pimpl_.get()));
+ pimpl_->is_stopped_ = true;
+ return false;
+ }
+
+ // pj_ice_strans_start_ice crashes if remote candidates array is empty
+ if (rem_candidates.empty()) {
+ if (pimpl_->logger_)
+ pimpl_->logger_->error("[ice:{}] start failed: no remote candidates", fmt::ptr(pimpl_.get()));
+ pimpl_->is_stopped_ = true;
+ return false;
+ }
+
+ auto comp_cnt = std::max(1u, getComponentCount());
+ if (rem_candidates.size() / comp_cnt > PJ_ICE_ST_MAX_CAND - 1) {
+ std::vector<IceCandidate> rcands;
+ rcands.reserve(PJ_ICE_ST_MAX_CAND - 1);
+ if (pimpl_->logger_)
+ pimpl_->logger_->warn("[ice:{}] too much candidates detected, trim list.", fmt::ptr(pimpl_.get()));
+ // Just trim some candidates. To avoid to only take host candidates, iterate
+ // through the whole list and select some host, some turn and peer reflexives
+ // It should give at least enough infos to negotiate.
+ auto maxHosts = 8;
+ auto maxRelays = PJ_ICE_MAX_TURN;
+ for (auto& c : rem_candidates) {
+ if (c.type == PJ_ICE_CAND_TYPE_HOST) {
+ if (maxHosts == 0)
+ continue;
+ maxHosts -= 1;
+ } else if (c.type == PJ_ICE_CAND_TYPE_RELAYED) {
+ if (maxRelays == 0)
+ continue;
+ maxRelays -= 1;
+ }
+ if (rcands.size() == PJ_ICE_ST_MAX_CAND - 1)
+ break;
+ rcands.emplace_back(std::move(c));
+ }
+ rem_candidates = std::move(rcands);
+ }
+
+ pj_str_t ufrag, pwd;
+ if (pimpl_->logger_)
+ pimpl_->logger_->debug("[ice:{}] negotiation starting ({:d} remote candidates)",
+ fmt::ptr(pimpl_),
+ rem_candidates.size());
+
+ auto status = pj_ice_strans_start_ice(pimpl_->icest_,
+ pj_strset(&ufrag,
+ (char*) rem_attrs.ufrag.c_str(),
+ rem_attrs.ufrag.size()),
+ pj_strset(&pwd,
+ (char*) rem_attrs.pwd.c_str(),
+ rem_attrs.pwd.size()),
+ rem_candidates.size(),
+ rem_candidates.data());
+ if (status != PJ_SUCCESS) {
+ if (pimpl_->logger_)
+ pimpl_->logger_->error("[ice:{}] start failed: {:s}", fmt::ptr(pimpl_.get()), sip_utils::sip_strerror(status));
+ pimpl_->is_stopped_ = true;
+ return false;
+ }
+
+ return true;
+}
+
+bool
+IceTransport::startIce(const SDP& sdp)
+{
+ if (pimpl_->streamsCount_ != 1) {
+ if (pimpl_->logger_)
+ pimpl_->logger_->error(FMT_STRING("Expected exactly one stream per SDP (found {:u} streams)"), pimpl_->streamsCount_);
+ return false;
+ }
+
+ if (not isInitialized()) {
+ if (pimpl_->logger_)
+ pimpl_->logger_->error(FMT_STRING("[ice:{}] not initialized transport"), fmt::ptr(pimpl_));
+ pimpl_->is_stopped_ = true;
+ return false;
+ }
+
+ for (unsigned id = 1; id <= getComponentCount(); id++) {
+ auto candVec = getLocalCandidates(id);
+ for (auto const& cand : candVec) {
+ if (pimpl_->logger_)
+ pimpl_->logger_->debug("[ice:{}] Using local candidate {:s} for comp {:d}",
+ fmt::ptr(pimpl_), cand, id);
+ }
+ }
+
+ if (pimpl_->logger_)
+ pimpl_->logger_->debug("[ice:{}] negotiation starting ({:u} remote candidates)",
+ fmt::ptr(pimpl_), sdp.candidates.size());
+ pj_str_t ufrag, pwd;
+
+ std::vector<IceCandidate> rem_candidates;
+ rem_candidates.reserve(sdp.candidates.size());
+ IceCandidate cand;
+ for (const auto& line : sdp.candidates) {
+ if (parseIceAttributeLine(0, line, cand))
+ rem_candidates.emplace_back(cand);
+ }
+
+ auto status = pj_ice_strans_start_ice(pimpl_->icest_,
+ pj_strset(&ufrag,
+ (char*) sdp.ufrag.c_str(),
+ sdp.ufrag.size()),
+ pj_strset(&pwd, (char*) sdp.pwd.c_str(), sdp.pwd.size()),
+ rem_candidates.size(),
+ rem_candidates.data());
+ if (status != PJ_SUCCESS) {
+ if (pimpl_->logger_)
+ pimpl_->logger_->error("[ice:{}] start failed: {:s}", fmt::ptr(pimpl_), sip_utils::sip_strerror(status));
+ pimpl_->is_stopped_ = true;
+ return false;
+ }
+
+ return true;
+}
+
+void
+IceTransport::cancelOperations()
+{
+ pimpl_->cancelOperations();
+}
+
+IpAddr
+IceTransport::getLocalAddress(unsigned comp_id) const
+{
+ return pimpl_->getLocalAddress(comp_id);
+}
+
+IpAddr
+IceTransport::getRemoteAddress(unsigned comp_id) const
+{
+ // Return the default remote address if set.
+ // Note that the default remote addresses are the addresses
+ // set in the 'c=' and 'a=rtcp' lines of the received SDP.
+ // See pj_ice_strans_sendto2() for more details.
+ if (auto defaultAddr = pimpl_->getDefaultRemoteAddress(comp_id)) {
+ return defaultAddr;
+ }
+
+ return pimpl_->getRemoteAddress(comp_id);
+}
+
+const IceTransport::Attribute
+IceTransport::getLocalAttributes() const
+{
+ return {pimpl_->local_ufrag_, pimpl_->local_pwd_};
+}
+
+std::vector<std::string>
+IceTransport::getLocalCandidates(unsigned comp_id) const
+{
+ ASSERT_COMP_ID(comp_id, getComponentCount());
+ std::vector<std::string> res;
+ pj_ice_sess_cand cand[MAX_CANDIDATES];
+ unsigned cand_cnt = PJ_ARRAY_SIZE(cand);
+
+ if (!isInitialized()) {
+ return res;
+ }
+
+ if (pj_ice_strans_enum_cands(pimpl_->icest_, comp_id, &cand_cnt, cand) != PJ_SUCCESS) {
+ if (pimpl_->logger_)
+ pimpl_->logger_->error("[ice:{}] pj_ice_strans_enum_cands() failed", fmt::ptr(pimpl_));
+ return res;
+ }
+
+ res.reserve(cand_cnt);
+ for (unsigned i = 0; i < cand_cnt; ++i) {
+ /** Section 4.5, RFC 6544 (https://tools.ietf.org/html/rfc6544)
+ * candidate-attribute = "candidate" ":" foundation SP component-id
+ * SP "TCP" SP priority SP connection-address SP port SP cand-type [SP
+ * rel-addr] [SP rel-port] SP tcp-type-ext
+ * *(SP extension-att-name SP
+ * extension-att-value)
+ *
+ * tcp-type-ext = "tcptype" SP tcp-type
+ * tcp-type = "active" / "passive" / "so"
+ */
+ char ipaddr[PJ_INET6_ADDRSTRLEN];
+ std::string tcp_type;
+ if (cand[i].transport != PJ_CAND_UDP) {
+ tcp_type += " tcptype";
+ switch (cand[i].transport) {
+ case PJ_CAND_TCP_ACTIVE:
+ tcp_type += " active";
+ break;
+ case PJ_CAND_TCP_PASSIVE:
+ tcp_type += " passive";
+ break;
+ case PJ_CAND_TCP_SO:
+ default:
+ tcp_type += " so";
+ break;
+ }
+ }
+ res.emplace_back(
+ fmt::format("{} {} {} {} {} {} typ {}{}",
+ sip_utils::as_view(cand[i].foundation),
+ cand[i].comp_id,
+ (cand[i].transport == PJ_CAND_UDP ? "UDP" : "TCP"),
+ cand[i].prio,
+ pj_sockaddr_print(&cand[i].addr, ipaddr, sizeof(ipaddr), 0),
+ pj_sockaddr_get_port(&cand[i].addr),
+ pj_ice_get_cand_type_name(cand[i].type),
+ tcp_type));
+ }
+
+ return res;
+}
+std::vector<std::string>
+IceTransport::getLocalCandidates(unsigned streamIdx, unsigned compId) const
+{
+ ASSERT_COMP_ID(compId, getComponentCount());
+
+ std::vector<std::string> res;
+ pj_ice_sess_cand cand[MAX_CANDIDATES];
+ unsigned cand_cnt = MAX_CANDIDATES;
+
+ if (not isInitialized()) {
+ return res;
+ }
+
+ // In the implementation, the component IDs are enumerated globally
+ // (per SDP: 1, 2, 3, 4, ...). This is simpler because we create
+ // only one pj_ice_strans instance. However, the component IDs are
+ // enumerated per stream in the generated SDP (1, 2, 1, 2, ...) in
+ // order to be compliant with the spec.
+
+ auto globalCompId = streamIdx * 2 + compId;
+ if (pj_ice_strans_enum_cands(pimpl_->icest_, globalCompId, &cand_cnt, cand) != PJ_SUCCESS) {
+ if (pimpl_->logger_)
+ pimpl_->logger_->error("[ice:{}] pj_ice_strans_enum_cands() failed", fmt::ptr(pimpl_));
+ return res;
+ }
+
+ res.reserve(cand_cnt);
+ // Build ICE attributes according to RFC 6544, section 4.5.
+ for (unsigned i = 0; i < cand_cnt; ++i) {
+ char ipaddr[PJ_INET6_ADDRSTRLEN];
+ std::string tcp_type;
+ if (cand[i].transport != PJ_CAND_UDP) {
+ tcp_type += " tcptype";
+ switch (cand[i].transport) {
+ case PJ_CAND_TCP_ACTIVE:
+ tcp_type += " active";
+ break;
+ case PJ_CAND_TCP_PASSIVE:
+ tcp_type += " passive";
+ break;
+ case PJ_CAND_TCP_SO:
+ default:
+ tcp_type += " so";
+ break;
+ }
+ }
+ res.emplace_back(
+ fmt::format("{} {} {} {} {} {} typ {}{}",
+ sip_utils::as_view(cand[i].foundation),
+ compId,
+ (cand[i].transport == PJ_CAND_UDP ? "UDP" : "TCP"),
+ cand[i].prio,
+ pj_sockaddr_print(&cand[i].addr, ipaddr, sizeof(ipaddr), 0),
+ pj_sockaddr_get_port(&cand[i].addr),
+ pj_ice_get_cand_type_name(cand[i].type),
+ tcp_type));
+ }
+
+ return res;
+}
+
+bool
+IceTransport::parseIceAttributeLine(unsigned streamIdx,
+ const std::string& line,
+ IceCandidate& cand) const
+{
+ // Silently ignore empty lines
+ if (line.empty())
+ return false;
+
+ if (streamIdx >= pimpl_->streamsCount_) {
+ throw std::runtime_error(fmt::format("Stream index {:d} is invalid!", streamIdx));
+ }
+
+ int af, cnt;
+ char foundation[32], transport[12], ipaddr[80], type[32], tcp_type[32];
+ pj_str_t tmpaddr;
+ unsigned comp_id, prio, port;
+ pj_status_t status;
+ pj_bool_t is_tcp = PJ_FALSE;
+
+ // Parse ICE attribute line according to RFC-6544 section 4.5.
+ // TODO/WARNING: There is no fail-safe in case of malformed attributes.
+ cnt = sscanf(line.c_str(),
+ "%31s %u %11s %u %79s %u typ %31s tcptype %31s\n",
+ foundation,
+ &comp_id,
+ transport,
+ &prio,
+ ipaddr,
+ &port,
+ type,
+ tcp_type);
+ if (cnt != 7 && cnt != 8) {
+ if (pimpl_->logger_)
+ pimpl_->logger_->error("[ice:{}] Invalid ICE candidate line: {:s}", fmt::ptr(pimpl_), line);
+ return false;
+ }
+
+ if (strcmp(transport, "TCP") == 0) {
+ is_tcp = PJ_TRUE;
+ }
+
+ pj_bzero(&cand, sizeof(IceCandidate));
+
+ if (strcmp(type, "host") == 0)
+ cand.type = PJ_ICE_CAND_TYPE_HOST;
+ else if (strcmp(type, "srflx") == 0)
+ cand.type = PJ_ICE_CAND_TYPE_SRFLX;
+ else if (strcmp(type, "prflx") == 0)
+ cand.type = PJ_ICE_CAND_TYPE_PRFLX;
+ else if (strcmp(type, "relay") == 0)
+ cand.type = PJ_ICE_CAND_TYPE_RELAYED;
+ else {
+ if (pimpl_->logger_)
+ pimpl_->logger_->warn("[ice:{}] invalid remote candidate type '{:s}'", fmt::ptr(pimpl_), type);
+ return false;
+ }
+
+ if (is_tcp) {
+ if (strcmp(tcp_type, "active") == 0)
+ cand.transport = PJ_CAND_TCP_ACTIVE;
+ else if (strcmp(tcp_type, "passive") == 0)
+ cand.transport = PJ_CAND_TCP_PASSIVE;
+ else if (strcmp(tcp_type, "so") == 0)
+ cand.transport = PJ_CAND_TCP_SO;
+ else {
+ if (pimpl_->logger_)
+ pimpl_->logger_->warn("[ice:{}] invalid transport type type '{:s}'", fmt::ptr(pimpl_), tcp_type);
+ return false;
+ }
+ } else {
+ cand.transport = PJ_CAND_UDP;
+ }
+
+ // If the component Id is enumerated relative to media, convert
+ // it to absolute enumeration.
+ if (comp_id <= pimpl_->compCountPerStream_) {
+ comp_id += pimpl_->compCountPerStream_ * streamIdx;
+ }
+ cand.comp_id = (pj_uint8_t) comp_id;
+
+ cand.prio = prio;
+
+ if (strchr(ipaddr, ':'))
+ af = pj_AF_INET6();
+ else {
+ af = pj_AF_INET();
+ pimpl_->onlyIPv4Private_ &= IpAddr(ipaddr).isPrivate();
+ }
+
+ tmpaddr = pj_str(ipaddr);
+ pj_sockaddr_init(af, &cand.addr, NULL, 0);
+ status = pj_sockaddr_set_str_addr(af, &cand.addr, &tmpaddr);
+ if (status != PJ_SUCCESS) {
+ if (pimpl_->logger_)
+ pimpl_->logger_->warn("[ice:{}] invalid IP address '{:s}'", fmt::ptr(pimpl_), ipaddr);
+ return false;
+ }
+
+ pj_sockaddr_set_port(&cand.addr, (pj_uint16_t) port);
+ pj_strdup2(pimpl_->pool_.get(), &cand.foundation, foundation);
+
+ return true;
+}
+
+ssize_t
+IceTransport::recv(unsigned compId, unsigned char* buf, size_t len, std::error_code& ec)
+{
+ ASSERT_COMP_ID(compId, getComponentCount());
+ auto& io = pimpl_->compIO_[compId - 1];
+ std::lock_guard<std::mutex> lk(io.mutex);
+
+ if (io.queue.empty()) {
+ ec = std::make_error_code(std::errc::resource_unavailable_try_again);
+ return -1;
+ }
+
+ auto& packet = io.queue.front();
+ const auto count = std::min(len, packet.data.size());
+ std::copy_n(packet.data.begin(), count, buf);
+ if (count == packet.data.size()) {
+ io.queue.pop_front();
+ } else {
+ packet.data.erase(packet.data.begin(), packet.data.begin() + count);
+ }
+
+ ec.clear();
+ return count;
+}
+
+ssize_t
+IceTransport::recvfrom(unsigned compId, char* buf, size_t len, std::error_code& ec)
+{
+ ASSERT_COMP_ID(compId, getComponentCount());
+ return pimpl_->peerChannels_.at(compId - 1).read(buf, len, ec);
+}
+
+void
+IceTransport::setOnRecv(unsigned compId, IceRecvCb cb)
+{
+ ASSERT_COMP_ID(compId, getComponentCount());
+
+ auto& io = pimpl_->compIO_[compId - 1];
+ std::lock_guard<std::mutex> lk(io.mutex);
+ io.recvCb = std::move(cb);
+
+ if (io.recvCb) {
+ // Flush existing queue using the callback
+ for (const auto& packet : io.queue)
+ io.recvCb((uint8_t*) packet.data.data(), packet.data.size());
+ io.queue.clear();
+ }
+}
+
+void
+IceTransport::setOnShutdown(onShutdownCb&& cb)
+{
+ pimpl_->scb = cb;
+}
+
+ssize_t
+IceTransport::send(unsigned compId, const unsigned char* buf, size_t len)
+{
+ ASSERT_COMP_ID(compId, getComponentCount());
+
+ auto remote = getRemoteAddress(compId);
+
+ if (!remote) {
+ if (pimpl_->logger_)
+ pimpl_->logger_->error("[ice:{}] can't find remote address for component {:d}", fmt::ptr(pimpl_), compId);
+ errno = EINVAL;
+ return -1;
+ }
+
+ std::unique_lock dlk(pimpl_->sendDataMutex_, std::defer_lock);
+ if (isTCPEnabled())
+ dlk.lock();
+
+ jami_tracepoint(ice_transport_send,
+ reinterpret_cast<uint64_t>(this),
+ compId,
+ len,
+ remote.toString().c_str());
+
+ auto status = pj_ice_strans_sendto2(pimpl_->icest_,
+ compId,
+ buf,
+ len,
+ remote.pjPtr(),
+ remote.getLength());
+
+ jami_tracepoint(ice_transport_send_status, status);
+
+ if (status == PJ_EPENDING && isTCPEnabled()) {
+ // NOTE; because we are in TCP, the sent size will count the header (2
+ // bytes length).
+ pimpl_->waitDataCv_.wait(dlk, [&] {
+ return pimpl_->lastSentLen_ >= static_cast<pj_size_t>(len) or pimpl_->destroying_;
+ });
+ pimpl_->lastSentLen_ = 0;
+ } else if (status != PJ_SUCCESS && status != PJ_EPENDING) {
+ if (status == PJ_EBUSY) {
+ errno = EAGAIN;
+ } else {
+ if (pimpl_->logger_)
+ pimpl_->logger_->error("[ice:{}] ice send failed: {:s}", fmt::ptr(pimpl_), sip_utils::sip_strerror(status));
+ errno = EIO;
+ }
+ return -1;
+ }
+
+ return len;
+}
+
+bool
+IceTransport::waitForInitialization(std::chrono::milliseconds timeout)
+{
+ return pimpl_->_waitForInitialization(timeout);
+}
+
+ssize_t
+IceTransport::waitForData(unsigned compId, std::chrono::milliseconds timeout, std::error_code& ec)
+{
+ ASSERT_COMP_ID(compId, getComponentCount());
+ return pimpl_->peerChannels_.at(compId - 1).wait(timeout, ec);
+}
+
+bool
+IceTransport::isTCPEnabled()
+{
+ return pimpl_->isTcpEnabled();
+}
+
+ICESDP
+IceTransport::parseIceCandidates(std::string_view sdp_msg)
+{
+ if (pimpl_->streamsCount_ != 1) {
+ if (pimpl_->logger_)
+ pimpl_->logger_->error("Expected exactly one stream per SDP (found %u streams)", pimpl_->streamsCount_);
+ return {};
+ }
+
+ ICESDP res;
+ int nr = 0;
+ for (std::string_view line; jami::getline(sdp_msg, line); nr++) {
+ if (nr == 0) {
+ res.rem_ufrag = line;
+ } else if (nr == 1) {
+ res.rem_pwd = line;
+ } else {
+ IceCandidate cand;
+ if (parseIceAttributeLine(0, std::string(line), cand)) {
+ if (pimpl_->logger_)
+ pimpl_->logger_->debug("[ice:{}] Add remote candidate: {}",
+ fmt::ptr(pimpl_),
+ line);
+ res.rem_candidates.emplace_back(cand);
+ }
+ }
+ }
+ return res;
+}
+
+void
+IceTransport::setDefaultRemoteAddress(unsigned comp_id, const IpAddr& addr)
+{
+ pimpl_->setDefaultRemoteAddress(comp_id, addr);
+}
+
+std::string
+IceTransport::link() const
+{
+ return pimpl_->link();
+}
+
+//==============================================================================
+
+IceTransportFactory::IceTransportFactory()
+ : cp_(new pj_caching_pool(),
+ [](pj_caching_pool* p) {
+ pj_caching_pool_destroy(p);
+ delete p;
+ })
+ , ice_cfg_()
+{
+ pj_caching_pool_init(cp_.get(), NULL, 0);
+
+ pj_ice_strans_cfg_default(&ice_cfg_);
+ ice_cfg_.stun_cfg.pf = &cp_->factory;
+
+ // v2.4.5 of PJNATH has a default of 100ms but RFC 5389 since version 14 requires
+ // a minimum of 500ms on fixed-line links. Our usual case is wireless links.
+ // This solves too long ICE exchange by DHT.
+ // Using 500ms with default PJ_STUN_MAX_TRANSMIT_COUNT (7) gives around 33s before timeout.
+ ice_cfg_.stun_cfg.rto_msec = 500;
+
+ // See https://tools.ietf.org/html/rfc5245#section-8.1.1.2
+ // If enabled, it may help speed-up the connectivity, but may cause
+ // the nomination of sub-optimal pairs.
+ ice_cfg_.opt.aggressive = PJ_FALSE;
+}
+
+IceTransportFactory::~IceTransportFactory() {}
+
+std::shared_ptr<IceTransport>
+IceTransportFactory::createTransport(std::string_view name)
+{
+ try {
+ return std::make_shared<IceTransport>(name);
+ } catch (const std::exception& e) {
+ //JAMI_ERR("%s", e.what());
+ return nullptr;
+ }
+}
+
+std::unique_ptr<IceTransport>
+IceTransportFactory::createUTransport(std::string_view name)
+{
+ try {
+ return std::make_unique<IceTransport>(name);
+ } catch (const std::exception& e) {
+ //JAMI_ERR("%s", e.what());
+ return nullptr;
+ }
+}
+
+//==============================================================================
+
+void
+IceSocket::close()
+{
+ if (ice_transport_)
+ ice_transport_->setOnRecv(compId_, {});
+ ice_transport_.reset();
+}
+
+ssize_t
+IceSocket::send(const unsigned char* buf, size_t len)
+{
+ if (not ice_transport_)
+ return -1;
+ return ice_transport_->send(compId_, buf, len);
+}
+
+ssize_t
+IceSocket::waitForData(std::chrono::milliseconds timeout)
+{
+ if (not ice_transport_)
+ return -1;
+
+ std::error_code ec;
+ return ice_transport_->waitForData(compId_, timeout, ec);
+}
+
+void
+IceSocket::setOnRecv(IceRecvCb cb)
+{
+ if (ice_transport_)
+ ice_transport_->setOnRecv(compId_, cb);
+}
+
+uint16_t
+IceSocket::getTransportOverhead()
+{
+ if (not ice_transport_)
+ return 0;
+
+ return (ice_transport_->getRemoteAddress(compId_).getFamily() == AF_INET) ? IPV4_HEADER_SIZE
+ : IPV6_HEADER_SIZE;
+}
+
+void
+IceSocket::setDefaultRemoteAddress(const IpAddr& addr)
+{
+ if (ice_transport_)
+ ice_transport_->setDefaultRemoteAddress(compId_, addr);
+}
+
+} // namespace jami
diff --git a/src/ice_transport.h b/src/ice_transport.h
new file mode 100644
index 0000000..0bf6432
--- /dev/null
+++ b/src/ice_transport.h
@@ -0,0 +1,219 @@
+/*
+ * Copyright (C) 2004-2023 Savoir-faire Linux Inc.
+ *
+ * Author: Guillaume Roguez <guillaume.roguez@savoirfairelinux.com>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#pragma once
+
+#include "ice_options.h"
+#include "ice_socket.h"
+#include "ip_utils.h"
+
+#include <pjnath.h>
+#include <pjlib.h>
+#include <pjlib-util.h>
+
+#include <functional>
+#include <memory>
+#include <msgpack.hpp>
+#include <vector>
+
+namespace dht {
+namespace log {
+class Logger;
+}
+}
+
+namespace jami {
+
+using Logger = dht::log::Logger;
+
+namespace upnp {
+class Controller;
+}
+
+class IceTransport;
+
+using IceRecvCb = std::function<ssize_t(unsigned char* buf, size_t len)>;
+using IceCandidate = pj_ice_sess_cand;
+using onShutdownCb = std::function<void(void)>;
+
+struct ICESDP
+{
+ std::vector<IceCandidate> rem_candidates;
+ std::string rem_ufrag;
+ std::string rem_pwd;
+};
+
+struct SDP
+{
+ std::string ufrag;
+ std::string pwd;
+
+ std::vector<std::string> candidates;
+ MSGPACK_DEFINE(ufrag, pwd, candidates)
+};
+
+class IceTransport
+{
+public:
+ using Attribute = struct
+ {
+ std::string ufrag;
+ std::string pwd;
+ };
+
+ /**
+ * Constructor
+ */
+ IceTransport(std::string_view name);
+ ~IceTransport();
+
+ const std::shared_ptr<Logger>& logger() const;
+
+ void initIceInstance(const IceTransportOptions& options);
+
+ /**
+ * Get current state
+ */
+ bool isInitiator() const;
+
+ /**
+ * Start transport negotiation between local candidates and given remote
+ * to find the right candidate pair.
+ * This function doesn't block, the callback on_negodone_cb will be called
+ * with the negotiation result when operation is really done.
+ * Return false if negotiation cannot be started else true.
+ */
+ bool startIce(const Attribute& rem_attrs, std::vector<IceCandidate>&& rem_candidates);
+ bool startIce(const SDP& sdp);
+
+ /**
+ * Cancel operations
+ */
+ void cancelOperations();
+
+ /**
+ * Returns true if ICE transport has been initialized
+ * [mutex protected]
+ */
+ bool isInitialized() const;
+
+ /**
+ * Returns true if ICE negotiation has been started
+ * [mutex protected]
+ */
+ bool isStarted() const;
+
+ /**
+ * Returns true if ICE negotiation has completed with success
+ * [mutex protected]
+ */
+ bool isRunning() const;
+
+ /**
+ * Returns true if ICE transport is in failure state
+ * [mutex protected]
+ */
+ bool isFailed() const;
+
+ IpAddr getLocalAddress(unsigned comp_id) const;
+
+ IpAddr getRemoteAddress(unsigned comp_id) const;
+
+ IpAddr getDefaultLocalAddress() const { return getLocalAddress(1); }
+
+ /**
+ * Return ICE session attributes
+ */
+ const Attribute getLocalAttributes() const;
+
+ /**
+ * Return ICE session attributes
+ */
+ std::vector<std::string> getLocalCandidates(unsigned comp_id) const;
+
+ /**
+ * Return ICE session attributes
+ */
+ std::vector<std::string> getLocalCandidates(unsigned streamIdx, unsigned compId) const;
+
+ bool parseIceAttributeLine(unsigned streamIdx,
+ const std::string& line,
+ IceCandidate& cand) const;
+
+ bool getCandidateFromSDP(const std::string& line, IceCandidate& cand) const;
+
+ // I/O methods
+
+ void setOnRecv(unsigned comp_id, IceRecvCb cb);
+ void setOnShutdown(onShutdownCb&& cb);
+
+ ssize_t recv(unsigned comp_id, unsigned char* buf, size_t len, std::error_code& ec);
+ ssize_t recvfrom(unsigned comp_id, char* buf, size_t len, std::error_code& ec);
+
+ ssize_t send(unsigned comp_id, const unsigned char* buf, size_t len);
+
+ bool waitForInitialization(std::chrono::milliseconds timeout);
+
+ int waitForNegotiation(std::chrono::milliseconds timeout);
+
+ ssize_t waitForData(unsigned comp_id, std::chrono::milliseconds timeout, std::error_code& ec);
+
+ unsigned getComponentCount() const;
+
+ // Set session state
+ bool setSlaveSession();
+ bool setInitiatorSession();
+
+ bool isTCPEnabled();
+
+ ICESDP parseIceCandidates(std::string_view sdp_msg);
+
+ void setDefaultRemoteAddress(unsigned comp_id, const IpAddr& addr);
+
+ std::string link() const;
+
+private:
+ class Impl;
+ std::unique_ptr<Impl> pimpl_;
+};
+
+class IceTransportFactory
+{
+public:
+ IceTransportFactory();
+ ~IceTransportFactory();
+
+ std::shared_ptr<IceTransport> createTransport(std::string_view name);
+
+ std::unique_ptr<IceTransport> createUTransport(std::string_view name);
+
+ /**
+ * PJSIP specifics
+ */
+ pj_ice_strans_cfg getIceCfg() const { return ice_cfg_; }
+ pj_pool_factory* getPoolFactory() { return &cp_->factory; }
+ std::shared_ptr<pj_caching_pool> getPoolCaching() { return cp_; }
+
+private:
+ std::shared_ptr<pj_caching_pool> cp_;
+ pj_ice_strans_cfg ice_cfg_;
+};
+
+}; // namespace jami
diff --git a/src/ip_utils.cpp b/src/ip_utils.cpp
new file mode 100644
index 0000000..494dc9b
--- /dev/null
+++ b/src/ip_utils.cpp
@@ -0,0 +1,501 @@
+/*
+ * Copyright (C) 2004-2023 Savoir-faire Linux Inc.
+ *
+ * Author: Adrien Béraud <adrien.beraud@savoirfairelinux.com>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#include "ip_utils.h"
+#include "logger.h"
+
+#include "connectivity/sip_utils.h"
+
+#include <sys/types.h>
+#include <unistd.h>
+#include <limits.h>
+
+#ifdef _WIN32
+#define InetPtonA inet_pton
+WINSOCK_API_LINKAGE INT WSAAPI InetPtonA(INT Family, LPCSTR pStringBuf, PVOID pAddr);
+#else
+#include <arpa/inet.h>
+#include <arpa/nameser.h>
+#include <resolv.h>
+#include <netdb.h>
+#include <netinet/ip.h>
+#include <net/if.h>
+#include <ifaddrs.h>
+#include <sys/ioctl.h>
+#endif
+
+#ifndef HOST_NAME_MAX
+#ifdef MAX_COMPUTERNAME_LENGTH
+#define HOST_NAME_MAX MAX_COMPUTERNAME_LENGTH
+#else
+// Max 255 chars as per RFC 1035
+#define HOST_NAME_MAX 255
+#endif
+#endif
+
+namespace jami {
+
+std::string_view
+sip_strerror(pj_status_t code)
+{
+ thread_local char err_msg[PJ_ERR_MSG_SIZE];
+ return as_view(pj_strerror(code, err_msg, sizeof err_msg));
+}
+
+
+std::string
+ip_utils::getHostname()
+{
+ char hostname[HOST_NAME_MAX];
+ if (gethostname(hostname, HOST_NAME_MAX))
+ return {};
+ return hostname;
+}
+
+int
+ip_utils::getHostName(char* out, size_t out_len)
+{
+ char tempstr[INET_ADDRSTRLEN];
+ const char* p = NULL;
+#ifdef _WIN32
+ struct hostent* h = NULL;
+ struct sockaddr_in localAddr;
+ memset(&localAddr, 0, sizeof(localAddr));
+ gethostname(out, out_len);
+ h = gethostbyname(out);
+ if (h != NULL) {
+ memcpy(&localAddr.sin_addr, h->h_addr_list[0], 4);
+ p = inet_ntop(AF_INET, &localAddr.sin_addr, tempstr, sizeof(tempstr));
+ if (p)
+ strncpy(out, p, out_len);
+ else
+ return -1;
+ } else {
+ return -1;
+ }
+#elif (defined(BSD) && BSD >= 199306) || defined(__FreeBSD_kernel__)
+ int retVal = 0;
+ struct ifaddrs* ifap;
+ struct ifaddrs* ifa;
+ if (getifaddrs(&ifap) != 0)
+ return -1;
+ // Cycle through available interfaces.
+ for (ifa = ifap; ifa != NULL; ifa = ifa->ifa_next) {
+ // Skip loopback, point-to-point and down interfaces.
+ // except don't skip down interfaces if we're trying to get
+ // a list of configurable interfaces.
+ if ((ifa->ifa_flags & IFF_LOOPBACK) || (!(ifa->ifa_flags & IFF_UP)))
+ continue;
+ if (ifa->ifa_addr->sa_family == AF_INET) {
+ if (((struct sockaddr_in*) (ifa->ifa_addr))->sin_addr.s_addr == htonl(INADDR_LOOPBACK)) {
+ // We don't want the loopback interface. Go to next one.
+ continue;
+ }
+ p = inet_ntop(AF_INET,
+ &((struct sockaddr_in*) (ifa->ifa_addr))->sin_addr,
+ tempstr,
+ sizeof(tempstr));
+ if (p)
+ strncpy(out, p, out_len);
+ else
+ retVal = -1;
+ break;
+ }
+ }
+ freeifaddrs(ifap);
+ retVal = ifa ? 0 : -1;
+ return retVal;
+#else
+ struct ifconf ifConf;
+ struct ifreq ifReq;
+ struct sockaddr_in localAddr;
+ char szBuffer[MAX_INTERFACE * sizeof(struct ifreq)];
+ int nResult;
+ int localSock;
+ memset(&ifConf, 0, sizeof(ifConf));
+ memset(&ifReq, 0, sizeof(ifReq));
+ memset(szBuffer, 0, sizeof(szBuffer));
+ memset(&localAddr, 0, sizeof(localAddr));
+ // Create an unbound datagram socket to do the SIOCGIFADDR ioctl on.
+ localSock = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
+ if (localSock == INVALID_SOCKET)
+ return -1;
+ /* Get the interface configuration information... */
+ ifConf.ifc_len = (int) sizeof szBuffer;
+ ifConf.ifc_ifcu.ifcu_buf = (caddr_t) szBuffer;
+ nResult = ioctl(localSock, SIOCGIFCONF, &ifConf);
+ if (nResult < 0) {
+ close(localSock);
+ return -1;
+ }
+ unsigned int i;
+ unsigned int j = 0;
+ // Cycle through the list of interfaces looking for IP addresses.
+ for (i = 0u; i < (unsigned int) ifConf.ifc_len && j < MIN_INTERFACE;) {
+ struct ifreq* pifReq = (struct ifreq*) ((caddr_t) ifConf.ifc_req + i);
+ i += sizeof *pifReq;
+ // See if this is the sort of interface we want to deal with.
+ memset(ifReq.ifr_name, 0, sizeof(ifReq.ifr_name));
+ strncpy(ifReq.ifr_name, pifReq->ifr_name, sizeof(ifReq.ifr_name));
+ ioctl(localSock, SIOCGIFFLAGS, &ifReq);
+ // Skip loopback, point-to-point and down interfaces.
+ // except don't skip down interfaces if we're trying to get
+ // a list of configurable interfaces.
+ if ((ifReq.ifr_flags & IFF_LOOPBACK) || (!(ifReq.ifr_flags & IFF_UP)))
+ continue;
+ if (pifReq->ifr_addr.sa_family == AF_INET) {
+ memcpy(&localAddr, &pifReq->ifr_addr, sizeof pifReq->ifr_addr);
+ if (localAddr.sin_addr.s_addr == htonl(INADDR_LOOPBACK)) {
+ // We don't want the loopback interface. Go to the next one.
+ continue;
+ }
+ }
+ j++; // Increment j if we found an address which is not loopback and is up.
+ }
+ close(localSock);
+ p = inet_ntop(AF_INET, &localAddr.sin_addr, tempstr, sizeof(tempstr));
+ if (p)
+ strncpy(out, p, out_len);
+ else
+ return -1;
+#endif
+ return 0;
+}
+std::string
+ip_utils::getGateway(char* localHost, ip_utils::subnet_mask prefix)
+{
+ std::string_view localHostStr(localHost);
+ if (prefix == ip_utils::subnet_mask::prefix_32bit)
+ return std::string(localHostStr);
+ std::string defaultGw {};
+ // Make a vector of each individual number in the ip address.
+ std::vector<std::string_view> tokens = split_string(localHostStr, '.');
+ // Build a gateway address from the individual ip components.
+ for (unsigned i = 0; i <= (unsigned) prefix; i++)
+ defaultGw += tokens[i] + ".";
+ for (unsigned i = (unsigned) ip_utils::subnet_mask::prefix_32bit;
+ i > (unsigned) prefix + 1;
+ i--)
+ defaultGw += "0.";
+ defaultGw += "1";
+ return defaultGw;
+}
+
+IpAddr
+ip_utils::getLocalGateway()
+{
+ char localHostBuf[INET_ADDRSTRLEN];
+ if (ip_utils::getHostName(localHostBuf, INET_ADDRSTRLEN) < 0) {
+ JAMI_WARN("Couldn't find local host");
+ return {};
+ } else {
+ return IpAddr(ip_utils::getGateway(localHostBuf, ip_utils::subnet_mask::prefix_24bit));
+ }
+}
+
+std::vector<IpAddr>
+ip_utils::getAddrList(std::string_view name, pj_uint16_t family)
+{
+ std::vector<IpAddr> ipList;
+ if (name.empty())
+ return ipList;
+ if (IpAddr::isValid(name, family)) {
+ ipList.emplace_back(name);
+ return ipList;
+ }
+
+ static constexpr unsigned MAX_ADDR_NUM = 128;
+ pj_addrinfo res[MAX_ADDR_NUM];
+ unsigned addr_num = MAX_ADDR_NUM;
+ const pj_str_t pjname(sip_utils::CONST_PJ_STR(name));
+ auto status = pj_getaddrinfo(family, &pjname, &addr_num, res);
+ if (status != PJ_SUCCESS) {
+ JAMI_ERR("Error resolving %.*s : %s",
+ (int) name.size(),
+ name.data(),
+ sip_utils::sip_strerror(status).c_str());
+ return ipList;
+ }
+
+ for (unsigned i = 0; i < addr_num; i++) {
+ bool found = false;
+ for (const auto& ip : ipList)
+ if (!pj_sockaddr_cmp(&ip, &res[i].ai_addr)) {
+ found = true;
+ break;
+ }
+ if (!found)
+ ipList.emplace_back(res[i].ai_addr);
+ }
+
+ return ipList;
+}
+
+bool
+ip_utils::haveCommonAddr(const std::vector<IpAddr>& a, const std::vector<IpAddr>& b)
+{
+ for (const auto& i : a) {
+ for (const auto& j : b) {
+ if (i == j)
+ return true;
+ }
+ }
+ return false;
+}
+
+IpAddr
+ip_utils::getLocalAddr(pj_uint16_t family)
+{
+ IpAddr ip_addr {};
+ pj_status_t status = pj_gethostip(family, ip_addr.pjPtr());
+ if (status == PJ_SUCCESS) {
+ return ip_addr;
+ }
+ JAMI_WARN("Could not get preferred address familly (%s)",
+ (family == pj_AF_INET6()) ? "IPv6" : "IPv4");
+ family = (family == pj_AF_INET()) ? pj_AF_INET6() : pj_AF_INET();
+ status = pj_gethostip(family, ip_addr.pjPtr());
+ if (status == PJ_SUCCESS) {
+ return ip_addr;
+ }
+ JAMI_ERR("Could not get local IP");
+ return ip_addr;
+}
+
+IpAddr
+ip_utils::getInterfaceAddr(const std::string& interface, pj_uint16_t family)
+{
+ if (interface == DEFAULT_INTERFACE)
+ return getLocalAddr(family);
+
+ IpAddr addr {};
+
+#ifndef _WIN32
+ const auto unix_family = family == pj_AF_INET() ? AF_INET : AF_INET6;
+
+ int fd = socket(unix_family, SOCK_DGRAM, 0);
+ if (fd < 0) {
+ JAMI_ERR("Could not open socket: %m");
+ return addr;
+ }
+
+ if (unix_family == AF_INET6) {
+ int val = family != pj_AF_UNSPEC();
+ if (setsockopt(fd, IPPROTO_IPV6, IPV6_V6ONLY, (void*) &val, sizeof(val)) < 0) {
+ JAMI_ERR("Could not setsockopt: %m");
+ close(fd);
+ return addr;
+ }
+ }
+
+ ifreq ifr;
+ strncpy(ifr.ifr_name, interface.c_str(), sizeof ifr.ifr_name);
+ // guarantee that ifr_name is NULL-terminated
+ ifr.ifr_name[sizeof(ifr.ifr_name) - 1] = '\0';
+
+ memset(&ifr.ifr_addr, 0, sizeof(ifr.ifr_addr));
+ ifr.ifr_addr.sa_family = unix_family;
+
+ ioctl(fd, SIOCGIFADDR, &ifr);
+ close(fd);
+
+ addr = ifr.ifr_addr;
+ if (addr.isUnspecified())
+ return getLocalAddr(addr.getFamily());
+#else // _WIN32
+ struct addrinfo hints;
+ struct addrinfo* result = NULL;
+ struct sockaddr_in* sockaddr_ipv4;
+ struct sockaddr_in6* sockaddr_ipv6;
+
+ ZeroMemory(&hints, sizeof(hints));
+
+ DWORD dwRetval = getaddrinfo(interface.c_str(), "0", &hints, &result);
+ if (dwRetval != 0) {
+ JAMI_ERR("getaddrinfo failed with error: %lu", dwRetval);
+ return addr;
+ }
+
+ switch (result->ai_family) {
+ sockaddr_ipv4 = (struct sockaddr_in*) result->ai_addr;
+ addr = sockaddr_ipv4->sin_addr;
+ break;
+ case AF_INET6:
+ sockaddr_ipv6 = (struct sockaddr_in6*) result->ai_addr;
+ addr = sockaddr_ipv6->sin6_addr;
+ break;
+ default:
+ break;
+ }
+
+ if (addr.isUnspecified())
+ return getLocalAddr(addr.getFamily());
+#endif // !_WIN32
+
+ return addr;
+}
+
+std::vector<std::string>
+ip_utils::getAllIpInterfaceByName()
+{
+ std::vector<std::string> ifaceList;
+ ifaceList.push_back("default");
+#ifndef _WIN32
+ static ifreq ifreqs[20];
+ ifconf ifconf;
+
+ ifconf.ifc_buf = (char*) (ifreqs);
+ ifconf.ifc_len = sizeof(ifreqs);
+
+ int sock = socket(AF_INET6, SOCK_STREAM, 0);
+
+ if (sock >= 0) {
+ if (ioctl(sock, SIOCGIFCONF, &ifconf) >= 0)
+ for (unsigned i = 0; i < ifconf.ifc_len / sizeof(ifreq); ++i)
+ ifaceList.push_back(std::string(ifreqs[i].ifr_name));
+
+ close(sock);
+ }
+
+#else
+ JAMI_ERR("Not implemented yet. (iphlpapi.h problem)");
+#endif
+ return ifaceList;
+}
+
+std::vector<std::string>
+ip_utils::getAllIpInterface()
+{
+ pj_sockaddr addrList[16];
+ unsigned addrCnt = PJ_ARRAY_SIZE(addrList);
+
+ std::vector<std::string> ifaceList;
+
+ if (pj_enum_ip_interface(pj_AF_UNSPEC(), &addrCnt, addrList) == PJ_SUCCESS) {
+ for (unsigned i = 0; i < addrCnt; i++) {
+ char addr[PJ_INET6_ADDRSTRLEN];
+ pj_sockaddr_print(&addrList[i], addr, sizeof(addr), 0);
+ ifaceList.push_back(std::string(addr));
+ }
+ }
+
+ return ifaceList;
+}
+
+std::vector<IpAddr>
+ip_utils::getLocalNameservers()
+{
+ std::vector<IpAddr> res;
+#if defined __ANDROID__ || defined _WIN32 || TARGET_OS_IPHONE
+#ifdef _MSC_VER
+#pragma message(__FILE__ "(" STR2(__LINE__) ") : -NOTE- " \
+ "Not implemented")
+#else
+#warning "Not implemented"
+#endif
+#else
+ if (not(_res.options & RES_INIT))
+ res_init();
+ res.insert(res.end(), _res.nsaddr_list, _res.nsaddr_list + _res.nscount);
+#endif
+ return res;
+}
+
+bool
+IpAddr::isValid(std::string_view address, pj_uint16_t family)
+{
+ const pj_str_t pjstring(sip_utils::CONST_PJ_STR(address));
+ pj_str_t ret_str;
+ pj_uint16_t ret_port;
+ int ret_family;
+ auto status = pj_sockaddr_parse2(pj_AF_UNSPEC(), 0, &pjstring, &ret_str, &ret_port, &ret_family);
+ if (status != PJ_SUCCESS || (family != pj_AF_UNSPEC() && ret_family != family))
+ return false;
+
+ char buf[PJ_INET6_ADDRSTRLEN];
+ pj_str_t addr_with_null = {buf, 0};
+ pj_strncpy_with_null(&addr_with_null, &ret_str, sizeof(buf));
+ struct sockaddr sa;
+ return inet_pton(ret_family == pj_AF_INET6() ? AF_INET6 : AF_INET, buf, &(sa.sa_data)) == 1;
+}
+
+bool
+IpAddr::isUnspecified() const
+{
+ switch (addr.addr.sa_family) {
+ case AF_INET:
+ return IN_IS_ADDR_UNSPECIFIED(&addr.ipv4.sin_addr);
+ case AF_INET6:
+ return IN6_IS_ADDR_UNSPECIFIED(reinterpret_cast<const in6_addr*>(&addr.ipv6.sin6_addr));
+ default:
+ return true;
+ }
+}
+
+bool
+IpAddr::isLoopback() const
+{
+ switch (addr.addr.sa_family) {
+ case AF_INET: {
+ auto addr_host = ntohl(addr.ipv4.sin_addr.s_addr);
+ uint8_t b1 = (uint8_t)(addr_host >> 24);
+ return b1 == 127;
+ }
+ case AF_INET6:
+ return IN6_IS_ADDR_LOOPBACK(reinterpret_cast<const in6_addr*>(&addr.ipv6.sin6_addr));
+ default:
+ return false;
+ }
+}
+
+bool
+IpAddr::isPrivate() const
+{
+ if (isLoopback()) {
+ return true;
+ }
+ switch (addr.addr.sa_family) {
+ case AF_INET: {
+ auto addr_host = ntohl(addr.ipv4.sin_addr.s_addr);
+ uint8_t b1, b2;
+ b1 = (uint8_t)(addr_host >> 24);
+ b2 = (uint8_t)((addr_host >> 16) & 0x0ff);
+ // 10.x.y.z
+ if (b1 == 10)
+ return true;
+ // 172.16.0.0 - 172.31.255.255
+ if ((b1 == 172) && (b2 >= 16) && (b2 <= 31))
+ return true;
+ // 192.168.0.0 - 192.168.255.255
+ if ((b1 == 192) && (b2 == 168))
+ return true;
+ return false;
+ }
+ case AF_INET6: {
+ const pj_uint8_t* addr6 = reinterpret_cast<const pj_uint8_t*>(&addr.ipv6.sin6_addr);
+ if (addr6[0] == 0xfc)
+ return true;
+ return false;
+ }
+ default:
+ return false;
+ }
+}
+} // namespace jami
diff --git a/src/multiplexed_socket.cpp b/src/multiplexed_socket.cpp
new file mode 100644
index 0000000..5abf742
--- /dev/null
+++ b/src/multiplexed_socket.cpp
@@ -0,0 +1,1208 @@
+/*
+ * Copyright (C) 2019-2023 Savoir-faire Linux Inc.
+ * Author: Sébastien Blin <sebastien.blin@savoirfairelinux.com>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+#include "multiplexed_socket.h"
+#include "peer_connection.h"
+#include "ice_transport.h"
+#include "certstore.h"
+
+#include <opendht/logger.h>
+#include <opendht/thread_pool.h>
+
+#include <asio/io_context.hpp>
+#include <asio/steady_timer.hpp>
+
+#include <deque>
+
+static constexpr std::size_t IO_BUFFER_SIZE {8192}; ///< Size of char buffer used by IO operations
+static constexpr int MULTIPLEXED_SOCKET_VERSION {1};
+
+struct ChanneledMessage
+{
+ uint16_t channel;
+ std::vector<uint8_t> data;
+ MSGPACK_DEFINE(channel, data)
+};
+
+struct BeaconMsg
+{
+ bool p;
+ MSGPACK_DEFINE_MAP(p)
+};
+
+struct VersionMsg
+{
+ int v;
+ MSGPACK_DEFINE_MAP(v)
+};
+
+namespace jami {
+
+using clock = std::chrono::steady_clock;
+using time_point = clock::time_point;
+
+class MultiplexedSocket::Impl
+{
+public:
+ Impl(MultiplexedSocket& parent,
+ std::shared_ptr<asio::io_context> ctx,
+ const DeviceId& deviceId,
+ std::unique_ptr<TlsSocketEndpoint> endpoint)
+ : parent_(parent)
+ , deviceId(deviceId)
+ , ctx_(std::move(ctx))
+ , beaconTimer_(*ctx_)
+ , endpoint(std::move(endpoint))
+ , eventLoopThread_ {[this] {
+ try {
+ eventLoop();
+ } catch (const std::exception& e) {
+ if (logger_)
+ logger_->error("[CNX] peer connection event loop failure: {}", e.what());
+ shutdown();
+ }
+ }}
+ {}
+
+ ~Impl() {}
+
+ void join()
+ {
+ if (!isShutdown_) {
+ if (endpoint)
+ endpoint->setOnStateChange({});
+ shutdown();
+ } else {
+ clearSockets();
+ }
+ if (eventLoopThread_.joinable())
+ eventLoopThread_.join();
+ }
+
+ void clearSockets()
+ {
+ decltype(sockets) socks;
+ {
+ std::lock_guard<std::mutex> lkSockets(socketsMutex);
+ socks = std::move(sockets);
+ }
+ for (auto& socket : socks) {
+ // Just trigger onShutdown() to make client know
+ // No need to write the EOF for the channel, the write will fail because endpoint is
+ // already shutdown
+ if (socket.second)
+ socket.second->stop();
+ }
+ }
+
+ void shutdown()
+ {
+ if (isShutdown_)
+ return;
+ stop.store(true);
+ isShutdown_ = true;
+ beaconTimer_.cancel();
+ if (onShutdown_)
+ onShutdown_();
+ if (endpoint) {
+ std::unique_lock<std::mutex> lk(writeMtx);
+ endpoint->shutdown();
+ }
+ clearSockets();
+ }
+
+ std::shared_ptr<ChannelSocket> makeSocket(const std::string& name,
+ uint16_t channel,
+ bool isInitiator = false)
+ {
+ auto& channelSocket = sockets[channel];
+ if (not channelSocket)
+ channelSocket = std::make_shared<ChannelSocket>(
+ parent_.weak(), name, channel, isInitiator, [w = parent_.weak(), channel]() {
+ // Remove socket in another thread to avoid any lock
+ dht::ThreadPool::io().run([w, channel]() {
+ if (auto shared = w.lock()) {
+ shared->eraseChannel(channel);
+ }
+ });
+ });
+ else {
+ if (logger_)
+ logger_->warn("A channel is already present on that socket, accepting "
+ "the request will close the previous one {}", name);
+ }
+ return channelSocket;
+ }
+
+ /**
+ * Handle packets on the TLS endpoint and parse RTP
+ */
+ void eventLoop();
+ /**
+ * Triggered when a new control packet is received
+ */
+ void handleControlPacket(std::vector<uint8_t>&& pkt);
+ void handleProtocolPacket(std::vector<uint8_t>&& pkt);
+ bool handleProtocolMsg(const msgpack::object& o);
+ /**
+ * Triggered when a new packet on a channel is received
+ */
+ void handleChannelPacket(uint16_t channel, std::vector<uint8_t>&& pkt);
+ void onRequest(const std::string& name, uint16_t channel);
+ void onAccept(const std::string& name, uint16_t channel);
+
+ void setOnReady(OnConnectionReadyCb&& cb) { onChannelReady_ = std::move(cb); }
+ void setOnRequest(OnConnectionRequestCb&& cb) { onRequest_ = std::move(cb); }
+
+ // Beacon
+ void sendBeacon(const std::chrono::milliseconds& timeout);
+ void handleBeaconRequest();
+ void handleBeaconResponse();
+ std::atomic_int beaconCounter_ {0};
+
+ bool writeProtocolMessage(const msgpack::sbuffer& buffer);
+
+ msgpack::unpacker pac_ {};
+
+ MultiplexedSocket& parent_;
+
+ std::shared_ptr<Logger> logger_;
+ std::shared_ptr<asio::io_context> ctx_;
+
+ OnConnectionReadyCb onChannelReady_ {};
+ OnConnectionRequestCb onRequest_ {};
+ OnShutdownCb onShutdown_ {};
+
+ DeviceId deviceId {};
+ // Main socket
+ std::unique_ptr<TlsSocketEndpoint> endpoint {};
+
+ std::mutex socketsMutex {};
+ std::map<uint16_t, std::shared_ptr<ChannelSocket>> sockets {};
+
+ // Main loop to parse incoming packets
+ std::atomic_bool stop {false};
+ std::thread eventLoopThread_ {};
+
+ std::atomic_bool isShutdown_ {false};
+
+ std::mutex writeMtx {};
+
+ time_point start_ {clock::now()};
+ //std::shared_ptr<Task> beaconTask_ {};
+ asio::steady_timer beaconTimer_;
+
+ // version related stuff
+ void sendVersion();
+ void onVersion(int version);
+ std::atomic_bool canSendBeacon_ {false};
+ std::atomic_bool answerBeacon_ {true};
+ int version_ {MULTIPLEXED_SOCKET_VERSION};
+ std::function<void(bool)> onBeaconCb_ {};
+ std::function<void(int)> onVersionCb_ {};
+};
+
+void
+MultiplexedSocket::Impl::eventLoop()
+{
+ endpoint->setOnStateChange([this](tls::TlsSessionState state) {
+ if (state == tls::TlsSessionState::SHUTDOWN && !isShutdown_) {
+ if (logger_)
+ logger_->debug("Tls endpoint is down, shutdown multiplexed socket");
+ shutdown();
+ return false;
+ }
+ return true;
+ });
+ sendVersion();
+ std::error_code ec;
+ while (!stop) {
+ if (!endpoint) {
+ shutdown();
+ return;
+ }
+ pac_.reserve_buffer(IO_BUFFER_SIZE);
+ int size = endpoint->read(reinterpret_cast<uint8_t*>(&pac_.buffer()[0]), IO_BUFFER_SIZE, ec);
+ if (size < 0) {
+ if (ec && logger_)
+ logger_->error("Read error detected: {}", ec.message());
+ break;
+ }
+ if (size == 0) {
+ // We can close the socket
+ shutdown();
+ break;
+ }
+
+ pac_.buffer_consumed(size);
+ msgpack::object_handle oh;
+ while (pac_.next(oh) && !stop) {
+ try {
+ auto msg = oh.get().as<ChanneledMessage>();
+ if (msg.channel == CONTROL_CHANNEL)
+ handleControlPacket(std::move(msg.data));
+ else if (msg.channel == PROTOCOL_CHANNEL)
+ handleProtocolPacket(std::move(msg.data));
+ else
+ handleChannelPacket(msg.channel, std::move(msg.data));
+ } catch (const std::exception& e) {
+ if (logger_)
+ logger_->warn("Failed to unpacked message of {:d} bytes: {:s}", size, e.what());
+ } catch (...) {
+ if (logger_)
+ logger_->error("Unknown exception catched while unpacking message of {:d} bytes", size);
+ }
+ }
+ }
+}
+
+void
+MultiplexedSocket::Impl::onAccept(const std::string& name, uint16_t channel)
+{
+ std::lock_guard<std::mutex> lkSockets(socketsMutex);
+ auto& socket = sockets[channel];
+ if (!socket) {
+ if (logger_)
+ logger_->error("Receiving an answer for a non existing channel. This is a bug.");
+ return;
+ }
+
+ onChannelReady_(deviceId, socket);
+ socket->ready();
+ // Due to the callbacks that can take some time, onAccept can arrive after
+ // receiving all the data. In this case, the socket should be removed here
+ // as handle by onChannelReady_
+ if (socket->isRemovable())
+ sockets.erase(channel);
+ else
+ socket->answered();
+}
+
+void
+MultiplexedSocket::Impl::sendBeacon(const std::chrono::milliseconds& timeout)
+{
+ if (!canSendBeacon_)
+ return;
+ beaconCounter_++;
+ if (logger_)
+ logger_->debug("Send beacon to peer {}", deviceId);
+
+ msgpack::sbuffer buffer(8);
+ msgpack::packer<msgpack::sbuffer> pk(&buffer);
+ pk.pack(BeaconMsg {true});
+ if (!writeProtocolMessage(buffer))
+ return;
+ beaconTimer_.expires_after(timeout);
+ beaconTimer_.async_wait([w = parent_.weak()](const asio::error_code& ec) {
+ if (ec == asio::error::operation_aborted)
+ return;
+ if (auto shared = w.lock()) {
+ if (shared->pimpl_->beaconCounter_ != 0) {
+ if (shared->pimpl_->logger_)
+ shared->pimpl_->logger_->error("Beacon doesn't get any response. Stopping socket");
+ shared->shutdown();
+ }
+ }
+ });
+}
+
+void
+MultiplexedSocket::Impl::handleBeaconRequest()
+{
+ if (!answerBeacon_)
+ return;
+ // Run this on dedicated thread because some callbacks can take time
+ dht::ThreadPool::io().run([w = parent_.weak()]() {
+ if (auto shared = w.lock()) {
+ msgpack::sbuffer buffer(8);
+ msgpack::packer<msgpack::sbuffer> pk(&buffer);
+ pk.pack(BeaconMsg {false});
+ if (shared->pimpl_->logger_)
+ shared->pimpl_->logger_->debug("Send beacon response to peer {}", shared->deviceId());
+ shared->pimpl_->writeProtocolMessage(buffer);
+ }
+ });
+}
+
+void
+MultiplexedSocket::Impl::handleBeaconResponse()
+{
+ if (logger_)
+ logger_->debug("Get beacon response from peer {}", deviceId);
+ beaconCounter_--;
+}
+
+bool
+MultiplexedSocket::Impl::writeProtocolMessage(const msgpack::sbuffer& buffer)
+{
+ std::error_code ec;
+ int wr = parent_.write(PROTOCOL_CHANNEL,
+ (const unsigned char*) buffer.data(),
+ buffer.size(),
+ ec);
+ return wr > 0;
+}
+
+void
+MultiplexedSocket::Impl::sendVersion()
+{
+ dht::ThreadPool::io().run([w = parent_.weak()]() {
+ if (auto shared = w.lock()) {
+ auto version = shared->pimpl_->version_;
+ msgpack::sbuffer buffer(8);
+ msgpack::packer<msgpack::sbuffer> pk(&buffer);
+ pk.pack(VersionMsg {version});
+ shared->pimpl_->writeProtocolMessage(buffer);
+ }
+ });
+}
+
+void
+MultiplexedSocket::Impl::onVersion(int version)
+{
+ // Check if version > 1
+ if (version >= 1) {
+ if (logger_)
+ logger_->debug("Peer {} supports beacon", deviceId);
+ canSendBeacon_ = true;
+ } else {
+ if (logger_)
+ logger_->warn("Peer {} uses version {:d} which doesn't support beacon",
+ deviceId,
+ version);
+ canSendBeacon_ = false;
+ }
+}
+
+void
+MultiplexedSocket::Impl::onRequest(const std::string& name, uint16_t channel)
+{
+ auto accept = onRequest_(endpoint->peerCertificate(), channel, name);
+ std::shared_ptr<ChannelSocket> channelSocket;
+ if (accept) {
+ std::lock_guard<std::mutex> lkSockets(socketsMutex);
+ channelSocket = makeSocket(name, channel);
+ }
+
+ // Answer to ChannelRequest if accepted
+ ChannelRequest val;
+ val.channel = channel;
+ val.name = name;
+ val.state = accept ? ChannelRequestState::ACCEPT : ChannelRequestState::DECLINE;
+ msgpack::sbuffer buffer(512);
+ msgpack::pack(buffer, val);
+ std::error_code ec;
+ int wr = parent_.write(CONTROL_CHANNEL,
+ reinterpret_cast<const uint8_t*>(buffer.data()),
+ buffer.size(),
+ ec);
+ if (wr < 0) {
+ if (ec && logger_)
+ logger_->error("The write operation failed with error: {:s}", ec.message());
+ stop.store(true);
+ return;
+ }
+
+ if (accept) {
+ onChannelReady_(deviceId, channelSocket);
+ channelSocket->ready();
+ }
+}
+
+void
+MultiplexedSocket::Impl::handleControlPacket(std::vector<uint8_t>&& pkt)
+{
+ // Run this on dedicated thread because some callbacks can take time
+ dht::ThreadPool::io().run([w = parent_.weak(), pkt = std::move(pkt)]() {
+ auto shared = w.lock();
+ if (!shared)
+ return;
+ auto& pimpl = *shared->pimpl_;
+ try {
+ size_t off = 0;
+ while (off != pkt.size()) {
+ msgpack::unpacked result;
+ msgpack::unpack(result, (const char*) pkt.data(), pkt.size(), off);
+ auto object = result.get();
+ if (pimpl.handleProtocolMsg(object))
+ continue;
+ auto req = object.as<ChannelRequest>();
+ if (req.state == ChannelRequestState::ACCEPT) {
+ pimpl.onAccept(req.name, req.channel);
+ } else if (req.state == ChannelRequestState::DECLINE) {
+ std::lock_guard<std::mutex> lkSockets(pimpl.socketsMutex);
+ auto channel = pimpl.sockets.find(req.channel);
+ if (channel != pimpl.sockets.end()) {
+ channel->second->stop();
+ pimpl.sockets.erase(channel);
+ }
+ } else if (pimpl.onRequest_) {
+ pimpl.onRequest(req.name, req.channel);
+ }
+ }
+ } catch (const std::exception& e) {
+ if (pimpl.logger_)
+ pimpl.logger_->error("Error on the control channel: {}", e.what());
+ }
+ });
+}
+
+void
+MultiplexedSocket::Impl::handleChannelPacket(uint16_t channel, std::vector<uint8_t>&& pkt)
+{
+ std::lock_guard<std::mutex> lkSockets(socketsMutex);
+ auto sockIt = sockets.find(channel);
+ if (channel > 0 && sockIt != sockets.end() && sockIt->second) {
+ if (pkt.size() == 0) {
+ sockIt->second->stop();
+ if (sockIt->second->isAnswered())
+ sockets.erase(sockIt);
+ else
+ sockIt->second->removable(); // This means that onAccept didn't happen yet, will be
+ // removed later.
+ } else {
+ sockIt->second->onRecv(std::move(pkt));
+ }
+ } else if (pkt.size() != 0) {
+ if (logger_)
+ logger_->warn("Non existing channel: {}", channel);
+ }
+}
+
+bool
+MultiplexedSocket::Impl::handleProtocolMsg(const msgpack::object& o)
+{
+ try {
+ if (o.type == msgpack::type::MAP && o.via.map.size > 0) {
+ auto key = o.via.map.ptr[0].key.as<std::string_view>();
+ if (key == "p") {
+ auto msg = o.as<BeaconMsg>();
+ if (msg.p)
+ handleBeaconRequest();
+ else
+ handleBeaconResponse();
+ if (onBeaconCb_)
+ onBeaconCb_(msg.p);
+ return true;
+ } else if (key == "v") {
+ auto msg = o.as<VersionMsg>();
+ onVersion(msg.v);
+ if (onVersionCb_)
+ onVersionCb_(msg.v);
+ return true;
+ } else {
+ if (logger_)
+ logger_->warn("Unknown message type");
+ }
+ }
+ } catch (const std::exception& e) {
+ if (logger_)
+ logger_->error("Error on the protocol channel: {}", e.what());
+ }
+ return false;
+}
+
+void
+MultiplexedSocket::Impl::handleProtocolPacket(std::vector<uint8_t>&& pkt)
+{
+ // Run this on dedicated thread because some callbacks can take time
+ dht::ThreadPool::io().run([w = parent_.weak(), pkt = std::move(pkt)]() {
+ auto shared = w.lock();
+ if (!shared)
+ return;
+ try {
+ size_t off = 0;
+ while (off != pkt.size()) {
+ msgpack::unpacked result;
+ msgpack::unpack(result, (const char*) pkt.data(), pkt.size(), off);
+ auto object = result.get();
+ if (shared->pimpl_->handleProtocolMsg(object))
+ return;
+ }
+ } catch (const std::exception& e) {
+ if (shared->pimpl_->logger_)
+ shared->pimpl_->logger_->error("Error on the protocol channel: {}", e.what());
+ }
+ });
+}
+
+MultiplexedSocket::MultiplexedSocket(std::shared_ptr<asio::io_context> ctx, const DeviceId& deviceId,
+ std::unique_ptr<TlsSocketEndpoint> endpoint)
+ : pimpl_(std::make_unique<Impl>(*this, ctx, deviceId, std::move(endpoint)))
+{}
+
+MultiplexedSocket::~MultiplexedSocket() {}
+
+std::shared_ptr<ChannelSocket>
+MultiplexedSocket::addChannel(const std::string& name)
+{
+ // Note: because both sides can request the same channel number at the same time
+ // it's better to use a random channel number instead of just incrementing the request.
+ thread_local dht::crypto::random_device rd;
+ std::uniform_int_distribution<uint16_t> dist;
+ auto offset = dist(rd);
+ std::lock_guard<std::mutex> lk(pimpl_->socketsMutex);
+ for (int i = 1; i < UINT16_MAX; ++i) {
+ auto c = (offset + i) % UINT16_MAX;
+ if (c == CONTROL_CHANNEL || c == PROTOCOL_CHANNEL
+ || pimpl_->sockets.find(c) != pimpl_->sockets.end())
+ continue;
+ auto channel = pimpl_->makeSocket(name, c, true);
+ return channel;
+ }
+ return {};
+}
+
+DeviceId
+MultiplexedSocket::deviceId() const
+{
+ return pimpl_->deviceId;
+}
+
+void
+MultiplexedSocket::setOnReady(OnConnectionReadyCb&& cb)
+{
+ pimpl_->onChannelReady_ = std::move(cb);
+}
+
+void
+MultiplexedSocket::setOnRequest(OnConnectionRequestCb&& cb)
+{
+ pimpl_->onRequest_ = std::move(cb);
+}
+
+bool
+MultiplexedSocket::isReliable() const
+{
+ return true;
+}
+
+bool
+MultiplexedSocket::isInitiator() const
+{
+ if (!pimpl_->endpoint) {
+ if (pimpl_->logger_)
+ pimpl_->logger_->warn("No endpoint found for socket");
+ return false;
+ }
+ return pimpl_->endpoint->isInitiator();
+}
+
+int
+MultiplexedSocket::maxPayload() const
+{
+ if (!pimpl_->endpoint) {
+ if (pimpl_->logger_)
+ pimpl_->logger_->warn("No endpoint found for socket");
+ return 0;
+ }
+ return pimpl_->endpoint->maxPayload();
+}
+
+std::size_t
+MultiplexedSocket::write(const uint16_t& channel,
+ const uint8_t* buf,
+ std::size_t len,
+ std::error_code& ec)
+{
+ assert(nullptr != buf);
+
+ if (pimpl_->isShutdown_) {
+ ec = std::make_error_code(std::errc::broken_pipe);
+ return -1;
+ }
+ if (len > UINT16_MAX) {
+ ec = std::make_error_code(std::errc::message_size);
+ return -1;
+ }
+ bool oneShot = len < 8192;
+ msgpack::sbuffer buffer(oneShot ? 16 + len : 16);
+ msgpack::packer<msgpack::sbuffer> pk(&buffer);
+ pk.pack_array(2);
+ pk.pack(channel);
+ pk.pack_bin(len);
+ if (oneShot)
+ pk.pack_bin_body((const char*) buf, len);
+
+ std::unique_lock<std::mutex> lk(pimpl_->writeMtx);
+ if (!pimpl_->endpoint) {
+ if (pimpl_->logger_)
+ pimpl_->logger_->warn("No endpoint found for socket");
+ ec = std::make_error_code(std::errc::broken_pipe);
+ return -1;
+ }
+ int res = pimpl_->endpoint->write((const unsigned char*) buffer.data(), buffer.size(), ec);
+ if (not oneShot and res >= 0)
+ res = pimpl_->endpoint->write(buf, len, ec);
+ lk.unlock();
+ if (res < 0) {
+ if (ec && pimpl_->logger_)
+ pimpl_->logger_->error("Error when writing on socket: {:s}", ec.message());
+ shutdown();
+ }
+ return res;
+}
+
+void
+MultiplexedSocket::shutdown()
+{
+ pimpl_->shutdown();
+}
+
+void
+MultiplexedSocket::join()
+{
+ pimpl_->join();
+}
+
+void
+MultiplexedSocket::onShutdown(OnShutdownCb&& cb)
+{
+ pimpl_->onShutdown_ = std::move(cb);
+ if (pimpl_->isShutdown_)
+ pimpl_->onShutdown_();
+}
+
+const std::shared_ptr<Logger>&
+MultiplexedSocket::logger()
+{
+ return pimpl_->logger_;
+}
+
+void
+MultiplexedSocket::monitor() const
+{
+ auto cert = peerCertificate();
+ if (!cert || !cert->issuer)
+ return;
+ auto now = clock::now();
+ if (!pimpl_->logger_)
+ return;
+ pimpl_->logger_->debug("- Socket with device: {:s} - account: {:s}", deviceId(), cert->issuer->getId());
+ pimpl_->logger_->debug("- Duration: {}", dht::print_duration(now - pimpl_->start_));
+ pimpl_->endpoint->monitor();
+ std::lock_guard<std::mutex> lk(pimpl_->socketsMutex);
+ for (const auto& [_, channel] : pimpl_->sockets) {
+ if (channel)
+ pimpl_->logger_->debug("\t\t- Channel {} (count: {}) with name {:s} Initiator: {}",
+ fmt::ptr(channel.get()),
+ channel.use_count(),
+ channel->name(),
+ channel->isInitiator());
+ }
+}
+
+void
+MultiplexedSocket::sendBeacon(const std::chrono::milliseconds& timeout)
+{
+ pimpl_->sendBeacon(timeout);
+}
+
+std::shared_ptr<dht::crypto::Certificate>
+MultiplexedSocket::peerCertificate() const
+{
+ return pimpl_->endpoint->peerCertificate();
+}
+
+#ifdef LIBJAMI_TESTABLE
+bool
+MultiplexedSocket::canSendBeacon() const
+{
+ return pimpl_->canSendBeacon_;
+}
+
+void
+MultiplexedSocket::answerToBeacon(bool value)
+{
+ pimpl_->answerBeacon_ = value;
+}
+
+void
+MultiplexedSocket::setVersion(int version)
+{
+ pimpl_->version_ = version;
+}
+
+void
+MultiplexedSocket::setOnBeaconCb(const std::function<void(bool)>& cb)
+{
+ pimpl_->onBeaconCb_ = cb;
+}
+
+void
+MultiplexedSocket::setOnVersionCb(const std::function<void(int)>& cb)
+{
+ pimpl_->onVersionCb_ = cb;
+}
+
+void
+MultiplexedSocket::sendVersion()
+{
+ pimpl_->sendVersion();
+}
+
+IpAddr
+MultiplexedSocket::getLocalAddress() const
+{
+ return pimpl_->endpoint->getLocalAddress();
+}
+
+IpAddr
+MultiplexedSocket::getRemoteAddress() const
+{
+ return pimpl_->endpoint->getRemoteAddress();
+}
+
+#endif
+
+void
+MultiplexedSocket::eraseChannel(uint16_t channel)
+{
+ std::lock_guard<std::mutex> lkSockets(pimpl_->socketsMutex);
+ auto itSocket = pimpl_->sockets.find(channel);
+ if (pimpl_->sockets.find(channel) != pimpl_->sockets.end())
+ pimpl_->sockets.erase(itSocket);
+}
+
+////////////////////////////////////////////////////////////////
+
+class ChannelSocket::Impl
+{
+public:
+ Impl(std::weak_ptr<MultiplexedSocket> endpoint,
+ const std::string& name,
+ const uint16_t& channel,
+ bool isInitiator,
+ std::function<void()> rmFromMxSockCb)
+ : name(name)
+ , channel(channel)
+ , endpoint(std::move(endpoint))
+ , isInitiator_(isInitiator)
+ , rmFromMxSockCb_(std::move(rmFromMxSockCb))
+ {}
+
+ ~Impl() {}
+
+ ChannelReadyCb readyCb_ {};
+ OnShutdownCb shutdownCb_ {};
+ std::atomic_bool isShutdown_ {false};
+ std::string name {};
+ uint16_t channel {};
+ std::weak_ptr<MultiplexedSocket> endpoint {};
+ bool isInitiator_ {false};
+ std::function<void()> rmFromMxSockCb_;
+
+ bool isAnswered_ {false};
+ bool isRemovable_ {false};
+
+ std::vector<uint8_t> buf {};
+ std::mutex mutex {};
+ std::condition_variable cv {};
+ GenericSocket<uint8_t>::RecvCb cb {};
+};
+
+ChannelSocketTest::ChannelSocketTest(std::shared_ptr<asio::io_context> ctx,
+ const DeviceId& deviceId,
+ const std::string& name,
+ const uint16_t& channel)
+ : pimpl_deviceId(deviceId)
+ , pimpl_name(name)
+ , pimpl_channel(channel)
+ , ioCtx_(*ctx)
+{}
+
+ChannelSocketTest::~ChannelSocketTest() {}
+
+void
+ChannelSocketTest::link(const std::shared_ptr<ChannelSocketTest>& socket1,
+ const std::shared_ptr<ChannelSocketTest>& socket2)
+{
+ socket1->remote = socket2;
+ socket2->remote = socket1;
+}
+
+DeviceId
+ChannelSocketTest::deviceId() const
+{
+ return pimpl_deviceId;
+}
+
+std::string
+ChannelSocketTest::name() const
+{
+ return pimpl_name;
+}
+
+uint16_t
+ChannelSocketTest::channel() const
+{
+ return pimpl_channel;
+}
+
+void
+ChannelSocketTest::shutdown()
+{
+ {
+ std::unique_lock<std::mutex> lk {mutex};
+ if (!isShutdown_.exchange(true)) {
+ lk.unlock();
+ shutdownCb_();
+ }
+ cv.notify_all();
+ }
+
+ if (auto peer = remote.lock()) {
+ if (!peer->isShutdown_.exchange(true)) {
+ peer->shutdownCb_();
+ }
+ peer->cv.notify_all();
+ }
+}
+
+std::size_t
+ChannelSocketTest::read(ValueType* buf, std::size_t len, std::error_code& ec)
+{
+ std::size_t size = std::min(len, this->rx_buf.size());
+
+ for (std::size_t i = 0; i < size; ++i)
+ buf[i] = this->rx_buf[i];
+
+ if (size == this->rx_buf.size()) {
+ this->rx_buf.clear();
+ } else
+ this->rx_buf.erase(this->rx_buf.begin(), this->rx_buf.begin() + size);
+ return size;
+}
+
+std::size_t
+ChannelSocketTest::write(const ValueType* buf, std::size_t len, std::error_code& ec)
+{
+ if (isShutdown_) {
+ ec = std::make_error_code(std::errc::broken_pipe);
+ return -1;
+ }
+ ec = {};
+ dht::ThreadPool::computation().run(
+ [r = remote, data = std::vector<uint8_t>(buf, buf + len)]() mutable {
+ if (auto peer = r.lock())
+ peer->onRecv(std::move(data));
+ });
+ return len;
+}
+
+int
+ChannelSocketTest::waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const
+{
+ std::unique_lock<std::mutex> lk {mutex};
+ cv.wait_for(lk, timeout, [&] { return !rx_buf.empty() or isShutdown_; });
+ return rx_buf.size();
+}
+
+void
+ChannelSocketTest::setOnRecv(RecvCb&& cb)
+{
+ std::lock_guard<std::mutex> lkSockets(mutex);
+ this->cb = std::move(cb);
+ if (!rx_buf.empty() && this->cb) {
+ this->cb(rx_buf.data(), rx_buf.size());
+ rx_buf.clear();
+ }
+}
+
+void
+ChannelSocketTest::onRecv(std::vector<uint8_t>&& pkt)
+{
+ std::lock_guard<std::mutex> lkSockets(mutex);
+ if (cb) {
+ cb(pkt.data(), pkt.size());
+ return;
+ }
+ rx_buf.insert(rx_buf.end(),
+ std::make_move_iterator(pkt.begin()),
+ std::make_move_iterator(pkt.end()));
+ cv.notify_all();
+}
+
+void
+ChannelSocketTest::onReady(ChannelReadyCb&& cb)
+{}
+
+void
+ChannelSocketTest::onShutdown(OnShutdownCb&& cb)
+{
+ std::unique_lock<std::mutex> lk {mutex};
+ shutdownCb_ = std::move(cb);
+
+ if (isShutdown_) {
+ lk.unlock();
+ shutdownCb_();
+ }
+}
+
+ChannelSocket::ChannelSocket(std::weak_ptr<MultiplexedSocket> endpoint,
+ const std::string& name,
+ const uint16_t& channel,
+ bool isInitiator,
+ std::function<void()> rmFromMxSockCb)
+ : pimpl_ {
+ std::make_unique<Impl>(endpoint, name, channel, isInitiator, std::move(rmFromMxSockCb))}
+{}
+
+ChannelSocket::~ChannelSocket() {}
+
+DeviceId
+ChannelSocket::deviceId() const
+{
+ if (auto ep = pimpl_->endpoint.lock()) {
+ return ep->deviceId();
+ }
+ return {};
+}
+
+std::string
+ChannelSocket::name() const
+{
+ return pimpl_->name;
+}
+
+uint16_t
+ChannelSocket::channel() const
+{
+ return pimpl_->channel;
+}
+
+bool
+ChannelSocket::isReliable() const
+{
+ if (auto ep = pimpl_->endpoint.lock()) {
+ return ep->isReliable();
+ }
+ return false;
+}
+
+bool
+ChannelSocket::isInitiator() const
+{
+ // Note. Is initiator here as not the same meaning of MultiplexedSocket.
+ // because a multiplexed socket can have sockets from accepted requests
+ // or made via connectDevice(). Here, isInitiator_ return if the socket
+ // is from connectDevice.
+ return pimpl_->isInitiator_;
+}
+
+int
+ChannelSocket::maxPayload() const
+{
+ if (auto ep = pimpl_->endpoint.lock()) {
+ return ep->maxPayload();
+ }
+ return -1;
+}
+
+void
+ChannelSocket::setOnRecv(RecvCb&& cb)
+{
+ std::lock_guard<std::mutex> lkSockets(pimpl_->mutex);
+ pimpl_->cb = std::move(cb);
+ if (!pimpl_->buf.empty() && pimpl_->cb) {
+ pimpl_->cb(pimpl_->buf.data(), pimpl_->buf.size());
+ pimpl_->buf.clear();
+ }
+}
+
+void
+ChannelSocket::onRecv(std::vector<uint8_t>&& pkt)
+{
+ std::lock_guard<std::mutex> lkSockets(pimpl_->mutex);
+ if (pimpl_->cb) {
+ pimpl_->cb(&pkt[0], pkt.size());
+ return;
+ }
+ pimpl_->buf.insert(pimpl_->buf.end(),
+ std::make_move_iterator(pkt.begin()),
+ std::make_move_iterator(pkt.end()));
+ pimpl_->cv.notify_all();
+}
+
+#ifdef LIBJAMI_TESTABLE
+std::shared_ptr<MultiplexedSocket>
+ChannelSocket::underlyingSocket() const
+{
+ if (auto mtx = pimpl_->endpoint.lock())
+ return mtx;
+ return {};
+}
+#endif
+
+void
+ChannelSocket::answered()
+{
+ pimpl_->isAnswered_ = true;
+}
+
+void
+ChannelSocket::removable()
+{
+ pimpl_->isRemovable_ = true;
+}
+
+bool
+ChannelSocket::isRemovable() const
+{
+ return pimpl_->isRemovable_;
+}
+
+bool
+ChannelSocket::isAnswered() const
+{
+ return pimpl_->isAnswered_;
+}
+
+void
+ChannelSocket::ready()
+{
+ if (pimpl_->readyCb_)
+ pimpl_->readyCb_();
+}
+
+void
+ChannelSocket::stop()
+{
+ if (pimpl_->isShutdown_)
+ return;
+ pimpl_->isShutdown_ = true;
+ if (pimpl_->shutdownCb_)
+ pimpl_->shutdownCb_();
+ pimpl_->cv.notify_all();
+ // stop() can be called by ChannelSocket::shutdown()
+ // In this case, the eventLoop is not used, but MxSock
+ // must remove the channel from its list (so that the
+ // channel can be destroyed and its shared_ptr invalidated).
+ if (pimpl_->rmFromMxSockCb_)
+ pimpl_->rmFromMxSockCb_();
+}
+
+void
+ChannelSocket::shutdown()
+{
+ if (pimpl_->isShutdown_)
+ return;
+ stop();
+ if (auto ep = pimpl_->endpoint.lock()) {
+ std::error_code ec;
+ const uint8_t dummy = '\0';
+ ep->write(pimpl_->channel, &dummy, 0, ec);
+ }
+}
+
+std::size_t
+ChannelSocket::read(ValueType* outBuf, std::size_t len, std::error_code& ec)
+{
+ std::lock_guard<std::mutex> lkSockets(pimpl_->mutex);
+ std::size_t size = std::min(len, pimpl_->buf.size());
+
+ for (std::size_t i = 0; i < size; ++i)
+ outBuf[i] = pimpl_->buf[i];
+
+ pimpl_->buf.erase(pimpl_->buf.begin(), pimpl_->buf.begin() + size);
+ return size;
+}
+
+std::size_t
+ChannelSocket::write(const ValueType* buf, std::size_t len, std::error_code& ec)
+{
+ if (pimpl_->isShutdown_) {
+ ec = std::make_error_code(std::errc::broken_pipe);
+ return -1;
+ }
+ if (auto ep = pimpl_->endpoint.lock()) {
+ std::size_t sent = 0;
+ do {
+ std::size_t toSend = std::min(static_cast<std::size_t>(UINT16_MAX), len - sent);
+ auto res = ep->write(pimpl_->channel, buf + sent, toSend, ec);
+ if (ec) {
+ if (ep->logger())
+ ep->logger()->error("Error when writing on channel: {}", ec.message());
+ return res;
+ }
+ sent += toSend;
+ } while (sent < len);
+ return sent;
+ }
+ ec = std::make_error_code(std::errc::broken_pipe);
+ return -1;
+}
+
+int
+ChannelSocket::waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const
+{
+ std::unique_lock<std::mutex> lk {pimpl_->mutex};
+ pimpl_->cv.wait_for(lk, timeout, [&] { return !pimpl_->buf.empty() or pimpl_->isShutdown_; });
+ return pimpl_->buf.size();
+}
+
+void
+ChannelSocket::onShutdown(OnShutdownCb&& cb)
+{
+ pimpl_->shutdownCb_ = std::move(cb);
+ if (pimpl_->isShutdown_) {
+ pimpl_->shutdownCb_();
+ }
+}
+
+void
+ChannelSocket::onReady(ChannelReadyCb&& cb)
+{
+ pimpl_->readyCb_ = std::move(cb);
+}
+
+void
+ChannelSocket::sendBeacon(const std::chrono::milliseconds& timeout)
+{
+ if (auto ep = pimpl_->endpoint.lock()) {
+ ep->sendBeacon(timeout);
+ } else {
+ shutdown();
+ }
+}
+
+std::shared_ptr<dht::crypto::Certificate>
+ChannelSocket::peerCertificate() const
+{
+ if (auto ep = pimpl_->endpoint.lock())
+ return ep->peerCertificate();
+ return {};
+}
+
+IpAddr
+ChannelSocket::getLocalAddress() const
+{
+ if (auto ep = pimpl_->endpoint.lock())
+ return ep->getLocalAddress();
+ return {};
+}
+
+IpAddr
+ChannelSocket::getRemoteAddress() const
+{
+ if (auto ep = pimpl_->endpoint.lock())
+ return ep->getRemoteAddress();
+ return {};
+}
+
+} // namespace jami
diff --git a/src/peer_connection.cpp b/src/peer_connection.cpp
new file mode 100644
index 0000000..0b4ede5
--- /dev/null
+++ b/src/peer_connection.cpp
@@ -0,0 +1,452 @@
+/*
+ * Copyright (C) 2004-2023 Savoir-faire Linux Inc.
+ *
+ * Author: Guillaume Roguez <guillaume.roguez@savoirfairelinux.com>
+ * Author: Sébastien Blin <sebastien.blin@savoirfairelinux.com>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#include "peer_connection.h"
+#include "tls_session.h"
+
+#include <opendht/thread_pool.h>
+#include <opendht/logger.h>
+
+#include <algorithm>
+#include <chrono>
+#include <future>
+#include <vector>
+#include <atomic>
+#include <stdexcept>
+#include <istream>
+#include <ostream>
+#include <unistd.h>
+#include <cstdio>
+
+#ifdef _WIN32
+#include <winsock2.h>
+#include <ws2tcpip.h>
+#else
+#include <sys/select.h>
+#endif
+
+#ifndef _MSC_VER
+#include <sys/time.h>
+#endif
+
+static constexpr int ICE_COMP_ID_SIP_TRANSPORT {1};
+
+namespace jami {
+
+int
+init_crt(gnutls_session_t session, dht::crypto::Certificate& crt)
+{
+ // Support only x509 format
+ if (gnutls_certificate_type_get(session) != GNUTLS_CRT_X509) {
+ return GNUTLS_E_CERTIFICATE_ERROR;
+ }
+
+ // Store verification status
+ unsigned int status = 0;
+ auto ret = gnutls_certificate_verify_peers2(session, &status);
+ if (ret < 0 or (status & GNUTLS_CERT_SIGNATURE_FAILURE) != 0) {
+ return GNUTLS_E_CERTIFICATE_ERROR;
+ }
+
+ unsigned int cert_list_size = 0;
+ auto cert_list = gnutls_certificate_get_peers(session, &cert_list_size);
+ if (cert_list == nullptr) {
+ return GNUTLS_E_CERTIFICATE_ERROR;
+ }
+
+ // Check if received peer certificate is awaited
+ std::vector<std::pair<uint8_t*, uint8_t*>> crt_data;
+ crt_data.reserve(cert_list_size);
+ for (unsigned i = 0; i < cert_list_size; i++)
+ crt_data.emplace_back(cert_list[i].data, cert_list[i].data + cert_list[i].size);
+ crt = dht::crypto::Certificate {crt_data};
+
+ return GNUTLS_E_SUCCESS;
+}
+
+using lock = std::lock_guard<std::mutex>;
+
+//==============================================================================
+
+IceSocketEndpoint::IceSocketEndpoint(std::shared_ptr<IceTransport> ice, bool isSender)
+ : ice_(std::move(ice))
+ , iceIsSender(isSender)
+{}
+
+IceSocketEndpoint::~IceSocketEndpoint()
+{
+ shutdown();
+ if (ice_)
+ dht::ThreadPool::io().run([ice = std::move(ice_)] {});
+}
+
+void
+IceSocketEndpoint::shutdown()
+{
+ // Sometimes the other peer never send any packet
+ // So, we cancel pending read to avoid to have
+ // any blocking operation.
+ if (ice_)
+ ice_->cancelOperations();
+}
+
+int
+IceSocketEndpoint::waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const
+{
+ if (ice_) {
+ if (!ice_->isRunning())
+ return -1;
+ return ice_->waitForData(compId_, timeout, ec);
+ }
+ return -1;
+}
+
+std::size_t
+IceSocketEndpoint::read(ValueType* buf, std::size_t len, std::error_code& ec)
+{
+ if (ice_) {
+ if (!ice_->isRunning())
+ return 0;
+ try {
+ auto res = ice_->recvfrom(compId_, reinterpret_cast<char*>(buf), len, ec);
+ if (res < 0)
+ shutdown();
+ return res;
+ } catch (const std::exception& e) {
+ if (auto logger = ice_->logger())
+ logger->error("IceSocketEndpoint::read exception: %s", e.what());
+ }
+ return 0;
+ }
+ return -1;
+}
+
+std::size_t
+IceSocketEndpoint::write(const ValueType* buf, std::size_t len, std::error_code& ec)
+{
+ if (ice_) {
+ if (!ice_->isRunning())
+ return 0;
+ auto res = 0;
+ res = ice_->send(compId_, reinterpret_cast<const unsigned char*>(buf), len);
+ if (res < 0) {
+ ec.assign(errno, std::generic_category());
+ shutdown();
+ } else {
+ ec.clear();
+ }
+ return res;
+ }
+ return -1;
+}
+
+//==============================================================================
+
+class TlsSocketEndpoint::Impl
+{
+public:
+ static constexpr auto TLS_TIMEOUT = std::chrono::seconds(40);
+
+ Impl(std::unique_ptr<IceSocketEndpoint>&& ep,
+ tls::CertificateStore& certStore,
+ const dht::crypto::Certificate& peer_cert,
+ const Identity& local_identity,
+ const std::shared_future<tls::DhParams>& dh_params)
+ : peerCertificate {peer_cert}
+ , ep_ {ep.get()}
+ {
+ tls::TlsSession::TlsSessionCallbacks tls_cbs
+ = {/*.onStateChange = */ [this](tls::TlsSessionState state) { onTlsStateChange(state); },
+ /*.onRxData = */ [this](std::vector<uint8_t>&& buf) { onTlsRxData(std::move(buf)); },
+ /*.onCertificatesUpdate = */
+ [this](const gnutls_datum_t* l, const gnutls_datum_t* r, unsigned int n) {
+ onTlsCertificatesUpdate(l, r, n);
+ },
+ /*.verifyCertificate = */
+ [this](gnutls_session_t session) {
+ return verifyCertificate(session);
+ }};
+ tls::TlsParams tls_param = {
+ /*.ca_list = */ "",
+ /*.peer_ca = */ nullptr,
+ /*.cert = */ local_identity.second,
+ /*.cert_key = */ local_identity.first,
+ /*.dh_params = */ dh_params,
+ /*.certStore = */ certStore,
+ /*.timeout = */ TLS_TIMEOUT,
+ /*.cert_check = */ nullptr,
+ };
+ tls = std::make_unique<tls::TlsSession>(std::move(ep), tls_param, tls_cbs);
+ }
+
+ Impl(std::unique_ptr<IceSocketEndpoint>&& ep,
+ tls::CertificateStore& certStore,
+ std::function<bool(const dht::crypto::Certificate&)>&& cert_check,
+ const Identity& local_identity,
+ const std::shared_future<tls::DhParams>& dh_params)
+ : peerCertificateCheckFunc {std::move(cert_check)}
+ , peerCertificate {null_cert}
+ , ep_ {ep.get()}
+ {
+ tls::TlsSession::TlsSessionCallbacks tls_cbs
+ = {/*.onStateChange = */ [this](tls::TlsSessionState state) { onTlsStateChange(state); },
+ /*.onRxData = */ [this](std::vector<uint8_t>&& buf) { onTlsRxData(std::move(buf)); },
+ /*.onCertificatesUpdate = */
+ [this](const gnutls_datum_t* l, const gnutls_datum_t* r, unsigned int n) {
+ onTlsCertificatesUpdate(l, r, n);
+ },
+ /*.verifyCertificate = */
+ [this](gnutls_session_t session) {
+ return verifyCertificate(session);
+ }};
+ tls::TlsParams tls_param = {
+ /*.ca_list = */ "",
+ /*.peer_ca = */ nullptr,
+ /*.cert = */ local_identity.second,
+ /*.cert_key = */ local_identity.first,
+ /*.dh_params = */ dh_params,
+ /*.certStore = */ certStore,
+ /*.timeout = */ std::chrono::duration_cast<decltype(tls::TlsParams::timeout)>(TLS_TIMEOUT),
+ /*.cert_check = */ nullptr,
+ };
+ tls = std::make_unique<tls::TlsSession>(std::move(ep), tls_param, tls_cbs);
+ }
+
+ ~Impl()
+ {
+ {
+ std::lock_guard<std::mutex> lk(cbMtx_);
+ onStateChangeCb_ = {};
+ onReadyCb_ = {};
+ }
+ tls.reset();
+ }
+
+ std::shared_ptr<IceTransport> underlyingICE() const
+ {
+ if (ep_)
+ if (const auto* iceSocket = reinterpret_cast<const IceSocketEndpoint*>(ep_))
+ return iceSocket->underlyingICE();
+ return {};
+ }
+
+ // TLS callbacks
+ int verifyCertificate(gnutls_session_t);
+ void onTlsStateChange(tls::TlsSessionState);
+ void onTlsRxData(std::vector<uint8_t>&&);
+ void onTlsCertificatesUpdate(const gnutls_datum_t*, const gnutls_datum_t*, unsigned int);
+
+ std::mutex cbMtx_ {};
+ OnStateChangeCb onStateChangeCb_;
+ dht::crypto::Certificate null_cert;
+ std::function<bool(const dht::crypto::Certificate&)> peerCertificateCheckFunc;
+ const dht::crypto::Certificate& peerCertificate;
+ std::atomic_bool isReady_ {false};
+ OnReadyCb onReadyCb_;
+ std::unique_ptr<tls::TlsSession> tls;
+ const IceSocketEndpoint* ep_;
+};
+
+int
+TlsSocketEndpoint::Impl::verifyCertificate(gnutls_session_t session)
+{
+ dht::crypto::Certificate crt;
+ auto verified = init_crt(session, crt);
+ if (verified != GNUTLS_E_SUCCESS)
+ return verified;
+ if (peerCertificateCheckFunc) {
+ if (!peerCertificateCheckFunc(crt)) {
+ if (const auto& logger = tls->logger())
+ logger->error("[TLS-SOCKET] Refusing peer certificate");
+ return GNUTLS_E_CERTIFICATE_ERROR;
+ }
+
+ null_cert = std::move(crt);
+ } else {
+ if (crt.getPacked() != peerCertificate.getPacked()) {
+ if (const auto& logger = tls->logger())
+ logger->error("[TLS-SOCKET] Unexpected peer certificate");
+ return GNUTLS_E_CERTIFICATE_ERROR;
+ }
+ }
+
+ return GNUTLS_E_SUCCESS;
+}
+
+void
+TlsSocketEndpoint::Impl::onTlsStateChange(tls::TlsSessionState state)
+{
+ std::lock_guard<std::mutex> lk(cbMtx_);
+ if ((state == tls::TlsSessionState::SHUTDOWN || state == tls::TlsSessionState::ESTABLISHED)
+ && !isReady_) {
+ isReady_ = true;
+ if (onReadyCb_)
+ onReadyCb_(state == tls::TlsSessionState::ESTABLISHED);
+ }
+ if (onStateChangeCb_ && !onStateChangeCb_(state))
+ onStateChangeCb_ = {};
+}
+
+void
+TlsSocketEndpoint::Impl::onTlsRxData([[maybe_unused]] std::vector<uint8_t>&& buf)
+{}
+
+void
+TlsSocketEndpoint::Impl::onTlsCertificatesUpdate([[maybe_unused]] const gnutls_datum_t* local_raw,
+ [[maybe_unused]] const gnutls_datum_t* remote_raw,
+ [[maybe_unused]] unsigned int remote_count)
+{}
+
+TlsSocketEndpoint::TlsSocketEndpoint(std::unique_ptr<IceSocketEndpoint>&& tr,
+ tls::CertificateStore& certStore,
+ const Identity& local_identity,
+ const std::shared_future<tls::DhParams>& dh_params,
+ const dht::crypto::Certificate& peer_cert)
+ : pimpl_ {std::make_unique<Impl>(std::move(tr), certStore, peer_cert, local_identity, dh_params)}
+{}
+
+TlsSocketEndpoint::TlsSocketEndpoint(
+ std::unique_ptr<IceSocketEndpoint>&& tr,
+ tls::CertificateStore& certStore,
+ const Identity& local_identity,
+ const std::shared_future<tls::DhParams>& dh_params,
+ std::function<bool(const dht::crypto::Certificate&)>&& cert_check)
+ : pimpl_ {
+ std::make_unique<Impl>(std::move(tr), certStore, std::move(cert_check), local_identity, dh_params)}
+{}
+
+TlsSocketEndpoint::~TlsSocketEndpoint() {}
+
+bool
+TlsSocketEndpoint::isInitiator() const
+{
+ if (!pimpl_->tls) {
+ return false;
+ }
+ return pimpl_->tls->isInitiator();
+}
+
+int
+TlsSocketEndpoint::maxPayload() const
+{
+ if (!pimpl_->tls) {
+ return -1;
+ }
+ return pimpl_->tls->maxPayload();
+}
+
+std::size_t
+TlsSocketEndpoint::read(ValueType* buf, std::size_t len, std::error_code& ec)
+{
+ if (!pimpl_->tls) {
+ ec = std::make_error_code(std::errc::broken_pipe);
+ return -1;
+ }
+ return pimpl_->tls->read(buf, len, ec);
+}
+
+std::size_t
+TlsSocketEndpoint::write(const ValueType* buf, std::size_t len, std::error_code& ec)
+{
+ if (!pimpl_->tls) {
+ ec = std::make_error_code(std::errc::broken_pipe);
+ return -1;
+ }
+ return pimpl_->tls->write(buf, len, ec);
+}
+
+std::shared_ptr<dht::crypto::Certificate>
+TlsSocketEndpoint::peerCertificate() const
+{
+ if (!pimpl_->tls)
+ return {};
+ return pimpl_->tls->peerCertificate();
+}
+
+void
+TlsSocketEndpoint::waitForReady(const std::chrono::milliseconds& timeout)
+{
+ if (!pimpl_->tls) {
+ return;
+ }
+ pimpl_->tls->waitForReady(timeout);
+}
+
+int
+TlsSocketEndpoint::waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const
+{
+ if (!pimpl_->tls) {
+ ec = std::make_error_code(std::errc::broken_pipe);
+ return -1;
+ }
+ return pimpl_->tls->waitForData(timeout, ec);
+}
+
+void
+TlsSocketEndpoint::setOnStateChange(std::function<bool(tls::TlsSessionState state)>&& cb)
+{
+ std::lock_guard<std::mutex> lk(pimpl_->cbMtx_);
+ pimpl_->onStateChangeCb_ = std::move(cb);
+}
+
+void
+TlsSocketEndpoint::setOnReady(std::function<void(bool ok)>&& cb)
+{
+ std::lock_guard<std::mutex> lk(pimpl_->cbMtx_);
+ pimpl_->onReadyCb_ = std::move(cb);
+}
+
+void
+TlsSocketEndpoint::shutdown()
+{
+ pimpl_->tls->shutdown();
+ if (pimpl_->ep_) {
+ const auto* iceSocket = reinterpret_cast<const IceSocketEndpoint*>(pimpl_->ep_);
+ if (iceSocket && iceSocket->underlyingICE())
+ iceSocket->underlyingICE()->cancelOperations();
+ }
+}
+
+void
+TlsSocketEndpoint::monitor() const
+{
+ if (auto ice = pimpl_->underlyingICE())
+ if (auto logger = ice->logger())
+ logger->debug("\t- Ice connection: {}", ice->link());
+}
+
+IpAddr
+TlsSocketEndpoint::getLocalAddress() const
+{
+ if (auto ice = pimpl_->underlyingICE())
+ return ice->getLocalAddress(ICE_COMP_ID_SIP_TRANSPORT);
+ return {};
+}
+
+IpAddr
+TlsSocketEndpoint::getRemoteAddress() const
+{
+ if (auto ice = pimpl_->underlyingICE())
+ return ice->getRemoteAddress(ICE_COMP_ID_SIP_TRANSPORT);
+ return {};
+}
+
+} // namespace jami
diff --git a/src/peer_connection.h b/src/peer_connection.h
new file mode 100644
index 0000000..3798f0c
--- /dev/null
+++ b/src/peer_connection.h
@@ -0,0 +1,142 @@
+/*
+ * Copyright (C) 2004-2023 Savoir-faire Linux Inc.
+ *
+ * Author: Guillaume Roguez <guillaume.roguez@savoirfairelinux.com>
+ * Author: Sébastien Blin <sebastien.blin@savoirfairelinux.com>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#pragma once
+
+#include "ip_utils.h"
+#include "certstore.h"
+#include "opendht/crypto.h"
+#include "ice_transport.h"
+#include "tls_session.h"
+
+#include <functional>
+#include <future>
+#include <limits>
+#include <map>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+namespace dht {
+namespace crypto {
+struct PrivateKey;
+struct Certificate;
+} // namespace crypto
+} // namespace dht
+
+namespace jami {
+namespace tls {
+class DhParams;
+}
+
+using OnStateChangeCb = std::function<bool(tls::TlsSessionState state)>;
+using OnReadyCb = std::function<void(bool ok)>;
+using onShutdownCb = std::function<void(void)>;
+
+//==============================================================================
+
+class IceSocketEndpoint : public GenericSocket<uint8_t>
+{
+public:
+ using SocketType = GenericSocket<uint8_t>;
+ explicit IceSocketEndpoint(std::shared_ptr<IceTransport> ice, bool isSender);
+ ~IceSocketEndpoint();
+
+ void shutdown() override;
+ bool isReliable() const override { return ice_ ? ice_->isRunning() : false; }
+ bool isInitiator() const override { return ice_ ? ice_->isInitiator() : true; }
+ int maxPayload() const override
+ {
+ return 65536 /* The max for a RTP packet used to wrap data here */;
+ }
+ int waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const override;
+ std::size_t read(ValueType* buf, std::size_t len, std::error_code& ec) override;
+ std::size_t write(const ValueType* buf, std::size_t len, std::error_code& ec) override;
+
+ std::shared_ptr<IceTransport> underlyingICE() const { return ice_; }
+
+ void setOnRecv(RecvCb&& cb) override
+ {
+ if (ice_)
+ ice_->setOnRecv(compId_, cb);
+ }
+
+private:
+ std::shared_ptr<IceTransport> ice_ {nullptr};
+ std::atomic_bool iceStopped {false};
+ std::atomic_bool iceIsSender {false};
+ uint8_t compId_ {1};
+};
+
+//==============================================================================
+
+/// Implement a TLS session IO over a system socket
+class TlsSocketEndpoint : public GenericSocket<uint8_t>
+{
+public:
+ using SocketType = GenericSocket<uint8_t>;
+ using Identity = std::pair<std::shared_ptr<dht::crypto::PrivateKey>,
+ std::shared_ptr<dht::crypto::Certificate>>;
+
+ TlsSocketEndpoint(std::unique_ptr<IceSocketEndpoint>&& tr,
+ tls::CertificateStore& certStore,
+ const Identity& local_identity,
+ const std::shared_future<tls::DhParams>& dh_params,
+ const dht::crypto::Certificate& peer_cert);
+ TlsSocketEndpoint(std::unique_ptr<IceSocketEndpoint>&& tr,
+ tls::CertificateStore& certStore,
+ const Identity& local_identity,
+ const std::shared_future<tls::DhParams>& dh_params,
+ std::function<bool(const dht::crypto::Certificate&)>&& cert_check);
+ ~TlsSocketEndpoint();
+
+ bool isReliable() const override { return true; }
+ bool isInitiator() const override;
+ int maxPayload() const override;
+ void shutdown() override;
+ std::size_t read(ValueType* buf, std::size_t len, std::error_code& ec) override;
+ std::size_t write(const ValueType* buf, std::size_t len, std::error_code& ec) override;
+
+ std::shared_ptr<dht::crypto::Certificate> peerCertificate() const;
+
+ void setOnRecv(RecvCb&&) override
+ {
+ throw std::logic_error("TlsSocketEndpoint::setOnRecv not implemented");
+ }
+ int waitForData(std::chrono::milliseconds timeout, std::error_code&) const override;
+
+ void waitForReady(const std::chrono::milliseconds& timeout = {});
+
+ void setOnStateChange(OnStateChangeCb&& cb);
+ void setOnReady(OnReadyCb&& cb);
+
+ IpAddr getLocalAddress() const;
+ IpAddr getRemoteAddress() const;
+
+ void monitor() const;
+
+private:
+ class Impl;
+ std::unique_ptr<Impl> pimpl_;
+};
+
+} // namespace jami
diff --git a/src/security/certstore.cpp b/src/security/certstore.cpp
new file mode 100644
index 0000000..acaa07d
--- /dev/null
+++ b/src/security/certstore.cpp
@@ -0,0 +1,673 @@
+/*
+ * Copyright (C) 2004-2023 Savoir-faire Linux Inc.
+ *
+ * Author: Adrien Béraud <adrien.beraud@savoirfairelinux.com>
+ * Author: Vsevolod Ivanov <vsevolod.ivanov@savoirfairelinux.com>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#include "certstore.h"
+#include "security_const.h"
+
+#include "fileutils.h"
+
+#include <opendht/thread_pool.h>
+#include <opendht/logger.h>
+
+#include <gnutls/ocsp.h>
+
+#include <thread>
+#include <sstream>
+#include <fmt/format.h>
+
+namespace jami {
+namespace tls {
+
+CertificateStore::CertificateStore(const std::string& path, std::shared_ptr<Logger> logger)
+ : logger_(std::move(logger))
+ , certPath_(fmt::format("{}/certificates", path))
+ , crlPath_(fmt::format("{}/crls", path))
+ , ocspPath_(fmt::format("{}/oscp", path))
+{
+ fileutils::check_dir(certPath_.c_str());
+ fileutils::check_dir(crlPath_.c_str());
+ fileutils::check_dir(ocspPath_.c_str());
+ loadLocalCertificates();
+}
+
+unsigned
+CertificateStore::loadLocalCertificates()
+{
+ std::lock_guard<std::mutex> l(lock_);
+
+ auto dir_content = fileutils::readDirectory(certPath_);
+ unsigned n = 0;
+ for (const auto& f : dir_content) {
+ try {
+ auto crt = std::make_shared<crypto::Certificate>(
+ fileutils::loadFile(certPath_ + DIR_SEPARATOR_CH + f));
+ auto id = crt->getId().toString();
+ auto longId = crt->getLongId().toString();
+ if (id != f && longId != f)
+ throw std::logic_error("Certificate id mismatch");
+ while (crt) {
+ id = crt->getId().toString();
+ longId = crt->getLongId().toString();
+ certs_.emplace(std::move(id), crt);
+ certs_.emplace(std::move(longId), crt);
+ loadRevocations(*crt);
+ crt = crt->issuer;
+ ++n;
+ }
+ } catch (const std::exception& e) {
+ if (logger_)
+ logger_->warn("Remove cert. {}", e.what());
+ remove(fmt::format("{}/{}", certPath_, f).c_str());
+ }
+ }
+ if (logger_)
+ logger_->debug("CertificateStore: loaded {} local certificates.", n);
+ return n;
+}
+
+void
+CertificateStore::loadRevocations(crypto::Certificate& crt) const
+{
+ auto dir = fmt::format("{:s}/{:s}", crlPath_, crt.getId().toString());
+ for (const auto& crl : fileutils::readDirectory(dir)) {
+ try {
+ crt.addRevocationList(std::make_shared<crypto::RevocationList>(
+ fileutils::loadFile(fmt::format("{}/{}", dir, crl))));
+ } catch (const std::exception& e) {
+ if (logger_)
+ logger_->warn("Can't load revocation list: %s", e.what());
+ }
+ }
+ auto ocsp_dir = ocspPath_ + DIR_SEPARATOR_CH + crt.getId().toString();
+ for (const auto& ocsp : fileutils::readDirectory(ocsp_dir)) {
+ try {
+ auto ocsp_filepath = fmt::format("{}/{}", ocsp_dir, ocsp);
+ if (logger_) logger_->debug("Found {:s}", ocsp_filepath);
+ auto serial = crt.getSerialNumber();
+ if (dht::toHex(serial.data(), serial.size()) != ocsp)
+ continue;
+ // Save the response
+ auto ocspBlob = fileutils::loadFile(ocsp_filepath);
+ crt.ocspResponse = std::make_shared<dht::crypto::OcspResponse>(ocspBlob.data(),
+ ocspBlob.size());
+ unsigned int status = crt.ocspResponse->getCertificateStatus();
+ if (status == GNUTLS_OCSP_CERT_GOOD) {
+ if (logger_) logger_->debug("Certificate {:s} has good OCSP status", crt.getId());
+ } else if (status == GNUTLS_OCSP_CERT_REVOKED) {
+ if (logger_) logger_->error("Certificate {:s} has revoked OCSP status", crt.getId());
+ } else if (status == GNUTLS_OCSP_CERT_UNKNOWN) {
+ if (logger_) logger_->error("Certificate {:s} has unknown OCSP status", crt.getId());
+ } else {
+ if (logger_) logger_->error("Certificate {:s} has invalid OCSP status", crt.getId());
+ }
+ } catch (const std::exception& e) {
+ if (logger_)
+ logger_->warn("Can't load OCSP revocation status: {:s}", e.what());
+ }
+ }
+}
+
+std::vector<std::string>
+CertificateStore::getPinnedCertificates() const
+{
+ std::lock_guard<std::mutex> l(lock_);
+
+ std::vector<std::string> certIds;
+ certIds.reserve(certs_.size());
+ for (const auto& crt : certs_)
+ certIds.emplace_back(crt.first);
+ return certIds;
+}
+
+std::shared_ptr<crypto::Certificate>
+CertificateStore::getCertificate(const std::string& k)
+{
+ auto getCertificate_ = [this](const std::string& k) -> std::shared_ptr<crypto::Certificate> {
+ auto cit = certs_.find(k);
+ if (cit == certs_.cend())
+ return {};
+ return cit->second;
+ };
+ std::unique_lock<std::mutex> l(lock_);
+ auto crt = getCertificate_(k);
+ // Check if certificate is complete
+ // If the certificate has been splitted, reconstruct it
+ auto top_issuer = crt;
+ while (top_issuer && top_issuer->getUID() != top_issuer->getIssuerUID()) {
+ if (top_issuer->issuer) {
+ top_issuer = top_issuer->issuer;
+ } else if (auto cert = getCertificate_(top_issuer->getIssuerUID())) {
+ top_issuer->issuer = cert;
+ top_issuer = cert;
+ } else {
+ // In this case, a certificate was not found
+ if (logger_)
+ logger_->warn("Incomplete certificate detected {:s}", k);
+ break;
+ }
+ }
+ return crt;
+}
+
+std::shared_ptr<crypto::Certificate>
+CertificateStore::getCertificateLegacy(const std::string& dataDir, const std::string& k)
+{
+ auto oldPath = fmt::format("{}/certificates/{}", dataDir, k);
+ if (fileutils::isFile(oldPath)) {
+ auto crt = std::make_shared<crypto::Certificate>(oldPath);
+ pinCertificate(crt, true);
+ return crt;
+ }
+ return {};
+}
+
+std::shared_ptr<crypto::Certificate>
+CertificateStore::findCertificateByName(const std::string& name, crypto::NameType type) const
+{
+ std::unique_lock<std::mutex> l(lock_);
+ for (auto& i : certs_) {
+ if (i.second->getName() == name)
+ return i.second;
+ if (type != crypto::NameType::UNKNOWN) {
+ for (const auto& alt : i.second->getAltNames())
+ if (alt.first == type and alt.second == name)
+ return i.second;
+ }
+ }
+ return {};
+}
+
+std::shared_ptr<crypto::Certificate>
+CertificateStore::findCertificateByUID(const std::string& uid) const
+{
+ std::unique_lock<std::mutex> l(lock_);
+ for (auto& i : certs_) {
+ if (i.second->getUID() == uid)
+ return i.second;
+ }
+ return {};
+}
+
+std::shared_ptr<crypto::Certificate>
+CertificateStore::findIssuer(const std::shared_ptr<crypto::Certificate>& crt) const
+{
+ std::shared_ptr<crypto::Certificate> ret {};
+ auto n = crt->getIssuerUID();
+ if (not n.empty()) {
+ if (crt->issuer and crt->issuer->getUID() == n)
+ ret = crt->issuer;
+ else
+ ret = findCertificateByUID(n);
+ }
+ if (not ret) {
+ n = crt->getIssuerName();
+ if (not n.empty())
+ ret = findCertificateByName(n);
+ }
+ if (not ret)
+ return ret;
+ unsigned verify_out = 0;
+ int err = gnutls_x509_crt_verify(crt->cert, &ret->cert, 1, 0, &verify_out);
+ if (err != GNUTLS_E_SUCCESS) {
+ if (logger_)
+ logger_->warn("gnutls_x509_crt_verify failed: {:s}", gnutls_strerror(err));
+ return {};
+ }
+ if (verify_out & GNUTLS_CERT_INVALID)
+ return {};
+ return ret;
+}
+
+static std::vector<crypto::Certificate>
+readCertificates(const std::string& path, const std::string& crl_path)
+{
+ std::vector<crypto::Certificate> ret;
+ if (fileutils::isDirectory(path)) {
+ auto files = fileutils::readDirectory(path);
+ for (const auto& file : files) {
+ auto certs = readCertificates(fmt::format("{}/{}", path, file), crl_path);
+ ret.insert(std::end(ret),
+ std::make_move_iterator(std::begin(certs)),
+ std::make_move_iterator(std::end(certs)));
+ }
+ } else {
+ try {
+ auto data = fileutils::loadFile(path);
+ const gnutls_datum_t dt {data.data(), (unsigned) data.size()};
+ gnutls_x509_crt_t* certs {nullptr};
+ unsigned cert_num {0};
+ gnutls_x509_crt_list_import2(&certs, &cert_num, &dt, GNUTLS_X509_FMT_PEM, 0);
+ for (unsigned i = 0; i < cert_num; i++)
+ ret.emplace_back(certs[i]);
+ gnutls_free(certs);
+ } catch (const std::exception& e) {
+ };
+ }
+ return ret;
+}
+
+void
+CertificateStore::pinCertificatePath(const std::string& path,
+ std::function<void(const std::vector<std::string>&)> cb)
+{
+ dht::ThreadPool::computation().run([&, path, cb]() {
+ auto certs = readCertificates(path, crlPath_);
+ std::vector<std::string> ids;
+ std::vector<std::weak_ptr<crypto::Certificate>> scerts;
+ ids.reserve(certs.size());
+ scerts.reserve(certs.size());
+ {
+ std::lock_guard<std::mutex> l(lock_);
+
+ for (auto& cert : certs) {
+ auto shared = std::make_shared<crypto::Certificate>(std::move(cert));
+ scerts.emplace_back(shared);
+ auto e = certs_.emplace(shared->getId().toString(), shared);
+ ids.emplace_back(e.first->first);
+ e = certs_.emplace(shared->getLongId().toString(), shared);
+ ids.emplace_back(e.first->first);
+ }
+ paths_.emplace(path, std::move(scerts));
+ }
+ if (logger_) logger_->d("CertificateStore: loaded %zu certificates from %s.", certs.size(), path.c_str());
+ if (cb)
+ cb(ids);
+ //emitSignal<libjami::ConfigurationSignal::CertificatePathPinned>(path, ids);
+ });
+}
+
+unsigned
+CertificateStore::unpinCertificatePath(const std::string& path)
+{
+ std::lock_guard<std::mutex> l(lock_);
+
+ auto certs = paths_.find(path);
+ if (certs == std::end(paths_))
+ return 0;
+ unsigned n = 0;
+ for (const auto& wcert : certs->second) {
+ if (auto cert = wcert.lock()) {
+ certs_.erase(cert->getId().toString());
+ ++n;
+ }
+ }
+ paths_.erase(certs);
+ return n;
+}
+
+std::vector<std::string>
+CertificateStore::pinCertificate(const std::vector<uint8_t>& cert, bool local) noexcept
+{
+ try {
+ return pinCertificate(crypto::Certificate(cert), local);
+ } catch (const std::exception& e) {
+ }
+ return {};
+}
+
+std::vector<std::string>
+CertificateStore::pinCertificate(crypto::Certificate&& cert, bool local)
+{
+ return pinCertificate(std::make_shared<crypto::Certificate>(std::move(cert)), local);
+}
+
+std::vector<std::string>
+CertificateStore::pinCertificate(const std::shared_ptr<crypto::Certificate>& cert, bool local)
+{
+ bool sig {false};
+ std::vector<std::string> ids {};
+ {
+ auto c = cert;
+ std::lock_guard<std::mutex> l(lock_);
+ while (c) {
+ bool inserted;
+ auto id = c->getId().toString();
+ auto longId = c->getLongId().toString();
+ decltype(certs_)::iterator it;
+ std::tie(it, inserted) = certs_.emplace(id, c);
+ if (not inserted)
+ it->second = c;
+ std::tie(it, inserted) = certs_.emplace(longId, c);
+ if (not inserted)
+ it->second = c;
+ if (local) {
+ for (const auto& crl : c->getRevocationLists())
+ pinRevocationList(id, *crl);
+ }
+ ids.emplace_back(longId);
+ ids.emplace_back(id);
+ c = c->issuer;
+ sig |= inserted;
+ }
+ if (local) {
+ if (sig)
+ fileutils::saveFile(certPath_ + DIR_SEPARATOR_CH + ids.front(), cert->getPacked());
+ }
+ }
+ //for (const auto& id : ids)
+ // emitSignal<libjami::ConfigurationSignal::CertificatePinned>(id);
+ return ids;
+}
+
+bool
+CertificateStore::unpinCertificate(const std::string& id)
+{
+ std::lock_guard<std::mutex> l(lock_);
+
+ certs_.erase(id);
+ return remove((certPath_ + DIR_SEPARATOR_CH + id).c_str()) == 0;
+}
+
+bool
+CertificateStore::setTrustedCertificate(const std::string& id, TrustStatus status)
+{
+ if (status == TrustStatus::TRUSTED) {
+ if (auto crt = getCertificate(id)) {
+ trustedCerts_.emplace_back(crt);
+ return true;
+ }
+ } else {
+ auto tc = std::find_if(trustedCerts_.begin(),
+ trustedCerts_.end(),
+ [&](const std::shared_ptr<crypto::Certificate>& crt) {
+ return crt->getId().toString() == id;
+ });
+ if (tc != trustedCerts_.end()) {
+ trustedCerts_.erase(tc);
+ return true;
+ }
+ }
+ return false;
+}
+
+std::vector<gnutls_x509_crt_t>
+CertificateStore::getTrustedCertificates() const
+{
+ std::vector<gnutls_x509_crt_t> crts;
+ crts.reserve(trustedCerts_.size());
+ for (auto& crt : trustedCerts_)
+ crts.emplace_back(crt->getCopy());
+ return crts;
+}
+
+void
+CertificateStore::pinRevocationList(const std::string& id,
+ const std::shared_ptr<dht::crypto::RevocationList>& crl)
+{
+ try {
+ if (auto c = getCertificate(id))
+ c->addRevocationList(crl);
+ pinRevocationList(id, *crl);
+ } catch (...) {
+ if (logger_)
+ logger_->warn("Can't add revocation list");
+ }
+}
+
+void
+CertificateStore::pinRevocationList(const std::string& id, const dht::crypto::RevocationList& crl)
+{
+ fileutils::check_dir((crlPath_ + DIR_SEPARATOR_CH + id).c_str());
+ fileutils::saveFile(crlPath_ + DIR_SEPARATOR_CH + id + DIR_SEPARATOR_CH
+ + dht::toHex(crl.getNumber()),
+ crl.getPacked());
+}
+
+void
+CertificateStore::pinOcspResponse(const dht::crypto::Certificate& cert)
+{
+ if (not cert.ocspResponse)
+ return;
+ try {
+ cert.ocspResponse->getCertificateStatus();
+ } catch (dht::crypto::CryptoException& e) {
+ if (logger_) logger_->error("Failed to read certificate status of OCSP response: {:s}", e.what());
+ return;
+ }
+ auto id = cert.getId().toString();
+ auto serial = cert.getSerialNumber();
+ auto serialhex = dht::toHex(serial);
+ auto dir = ocspPath_ + DIR_SEPARATOR_CH + id;
+
+ if (auto localCert = getCertificate(id)) {
+ // Update certificate in the local store if relevant
+ if (localCert.get() != &cert && serial == localCert->getSerialNumber()) {
+ if (logger_) logger_->d("Updating OCSP for certificate %s in the local store", id.c_str());
+ localCert->ocspResponse = cert.ocspResponse;
+ }
+ }
+
+ dht::ThreadPool::io().run([l=logger_,
+ path = dir + DIR_SEPARATOR_CH + serialhex,
+ dir = std::move(dir),
+ id = std::move(id),
+ serialhex = std::move(serialhex),
+ ocspResponse = cert.ocspResponse] {
+ if (l) l->d("Saving OCSP Response of device %s with serial %s", id.c_str(), serialhex.c_str());
+ std::lock_guard<std::mutex> lock(fileutils::getFileLock(path));
+ fileutils::check_dir(dir.c_str());
+ fileutils::saveFile(path, ocspResponse->pack());
+ });
+}
+
+TrustStore::PermissionStatus
+TrustStore::statusFromStr(const char* str)
+{
+ if (!std::strcmp(str, libjami::Certificate::Status::ALLOWED))
+ return PermissionStatus::ALLOWED;
+ if (!std::strcmp(str, libjami::Certificate::Status::BANNED))
+ return PermissionStatus::BANNED;
+ return PermissionStatus::UNDEFINED;
+}
+
+const char*
+TrustStore::statusToStr(TrustStore::PermissionStatus s)
+{
+ switch (s) {
+ case PermissionStatus::ALLOWED:
+ return libjami::Certificate::Status::ALLOWED;
+ case PermissionStatus::BANNED:
+ return libjami::Certificate::Status::BANNED;
+ case PermissionStatus::UNDEFINED:
+ default:
+ return libjami::Certificate::Status::UNDEFINED;
+ }
+}
+
+TrustStatus
+trustStatusFromStr(const char* str)
+{
+ if (!std::strcmp(str, libjami::Certificate::TrustStatus::TRUSTED))
+ return TrustStatus::TRUSTED;
+ return TrustStatus::UNTRUSTED;
+}
+
+const char*
+statusToStr(TrustStatus s)
+{
+ switch (s) {
+ case TrustStatus::TRUSTED:
+ return libjami::Certificate::TrustStatus::TRUSTED;
+ case TrustStatus::UNTRUSTED:
+ default:
+ return libjami::Certificate::TrustStatus::UNTRUSTED;
+ }
+}
+
+bool
+TrustStore::addRevocationList(dht::crypto::RevocationList&& crl)
+{
+ allowed_.add(crl);
+ return true;
+}
+
+bool
+TrustStore::setCertificateStatus(const std::string& cert_id,
+ const TrustStore::PermissionStatus status)
+{
+ return setCertificateStatus(nullptr, cert_id, status, false);
+}
+
+bool
+TrustStore::setCertificateStatus(const std::shared_ptr<crypto::Certificate>& cert,
+ const TrustStore::PermissionStatus status,
+ bool local)
+{
+ return setCertificateStatus(cert, cert->getId().toString(), status, local);
+}
+
+bool
+TrustStore::setCertificateStatus(std::shared_ptr<crypto::Certificate> cert,
+ const std::string& cert_id,
+ const TrustStore::PermissionStatus status,
+ bool local)
+{
+ if (cert)
+ certStore_.pinCertificate(cert, local);
+ std::lock_guard<std::recursive_mutex> lk(mutex_);
+ updateKnownCerts();
+ bool dirty {false};
+ if (status == PermissionStatus::UNDEFINED) {
+ unknownCertStatus_.erase(cert_id);
+ dirty = certStatus_.erase(cert_id);
+ } else {
+ bool allowed = (status == PermissionStatus::ALLOWED);
+ auto s = certStatus_.find(cert_id);
+ if (s == std::end(certStatus_)) {
+ // Certificate state is currently undefined
+ if (not cert)
+ cert = certStore_.getCertificate(cert_id);
+ if (cert) {
+ unknownCertStatus_.erase(cert_id);
+ auto& crt_status = certStatus_[cert_id];
+ if (not crt_status.first)
+ crt_status.first = cert;
+ crt_status.second.allowed = allowed;
+ setStoreCertStatus(*cert, allowed);
+ } else {
+ // Can't find certificate
+ unknownCertStatus_[cert_id].allowed = allowed;
+ }
+ } else {
+ // Certificate is already allowed or banned
+ if (s->second.second.allowed != allowed) {
+ s->second.second.allowed = allowed;
+ if (allowed) // Certificate is re-added after ban, rebuld needed
+ dirty = true;
+ else
+ allowed_.remove(*s->second.first, false);
+ }
+ }
+ }
+ if (dirty)
+ rebuildTrust();
+ return true;
+}
+
+TrustStore::PermissionStatus
+TrustStore::getCertificateStatus(const std::string& cert_id) const
+{
+ std::lock_guard<std::recursive_mutex> lk(mutex_);
+ auto s = certStatus_.find(cert_id);
+ if (s == std::end(certStatus_)) {
+ auto us = unknownCertStatus_.find(cert_id);
+ if (us == std::end(unknownCertStatus_))
+ return PermissionStatus::UNDEFINED;
+ return us->second.allowed ? PermissionStatus::ALLOWED : PermissionStatus::BANNED;
+ }
+ return s->second.second.allowed ? PermissionStatus::ALLOWED : PermissionStatus::BANNED;
+}
+
+std::vector<std::string>
+TrustStore::getCertificatesByStatus(TrustStore::PermissionStatus status) const
+{
+ std::lock_guard<std::recursive_mutex> lk(mutex_);
+ std::vector<std::string> ret;
+ for (const auto& i : certStatus_)
+ if (i.second.second.allowed == (status == TrustStore::PermissionStatus::ALLOWED))
+ ret.emplace_back(i.first);
+ for (const auto& i : unknownCertStatus_)
+ if (i.second.allowed == (status == TrustStore::PermissionStatus::ALLOWED))
+ ret.emplace_back(i.first);
+ return ret;
+}
+
+bool
+TrustStore::isAllowed(const crypto::Certificate& crt, bool allowPublic)
+{
+ // Match by certificate pinning
+ std::lock_guard<std::recursive_mutex> lk(mutex_);
+ bool allowed {allowPublic};
+ for (auto c = &crt; c; c = c->issuer.get()) {
+ auto status = getCertificateStatus(c->getId().toString()); // lock mutex_
+ if (status == PermissionStatus::ALLOWED)
+ allowed = true;
+ else if (status == PermissionStatus::BANNED)
+ return false;
+ }
+
+ // Match by certificate chain
+ updateKnownCerts();
+ auto ret = allowed_.verify(crt);
+ // Unknown issuer (only that) are accepted if allowPublic is true
+ if (not ret
+ and !(allowPublic and ret.result == (GNUTLS_CERT_INVALID | GNUTLS_CERT_SIGNER_NOT_FOUND))) {
+ if (certStore_.logger())
+ certStore_.logger()->warn("%s", ret.toString().c_str());
+ return false;
+ }
+
+ return allowed;
+}
+
+void
+TrustStore::updateKnownCerts()
+{
+ auto i = std::begin(unknownCertStatus_);
+ while (i != std::end(unknownCertStatus_)) {
+ if (auto crt = certStore_.getCertificate(i->first)) {
+ certStatus_.emplace(i->first, std::make_pair(crt, i->second));
+ setStoreCertStatus(*crt, i->second.allowed);
+ i = unknownCertStatus_.erase(i);
+ } else
+ ++i;
+ }
+}
+
+void
+TrustStore::setStoreCertStatus(const crypto::Certificate& crt, bool status)
+{
+ if (status)
+ allowed_.add(crt);
+ else
+ allowed_.remove(crt, false);
+}
+
+void
+TrustStore::rebuildTrust()
+{
+ allowed_ = {};
+ for (const auto& c : certStatus_)
+ setStoreCertStatus(*c.second.first, c.second.second.allowed);
+}
+
+} // namespace tls
+} // namespace jami
diff --git a/src/security/diffie-hellman.cpp b/src/security/diffie-hellman.cpp
new file mode 100644
index 0000000..bc0a854
--- /dev/null
+++ b/src/security/diffie-hellman.cpp
@@ -0,0 +1,139 @@
+/*
+ * Copyright (C) 2004-2023 Savoir-faire Linux Inc.
+ *
+ * Author: Guillaume Roguez <guillaume.roguez@savoirfairelinux.com>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#include "diffie-hellman.h"
+#include "logger.h"
+#include "fileutils.h"
+
+#include <chrono>
+#include <ciso646>
+
+namespace jami {
+namespace tls {
+
+DhParams::DhParams(const std::vector<uint8_t>& data)
+{
+ gnutls_dh_params_t new_params_;
+ int ret = gnutls_dh_params_init(&new_params_);
+ if (ret)
+ throw std::runtime_error(std::string("Error initializing DH params: ")
+ + gnutls_strerror(ret));
+ params_.reset(new_params_);
+ const gnutls_datum_t dat {(uint8_t*) data.data(), (unsigned) data.size()};
+ if (int ret_pem = gnutls_dh_params_import_pkcs3(params_.get(), &dat, GNUTLS_X509_FMT_PEM))
+ if (int ret_der = gnutls_dh_params_import_pkcs3(params_.get(), &dat, GNUTLS_X509_FMT_DER))
+ throw std::runtime_error(std::string("Error importing DH params: ")
+ + gnutls_strerror(ret_pem) + " " + gnutls_strerror(ret_der));
+}
+
+DhParams&
+DhParams::operator=(const DhParams& other)
+{
+ if (not params_) {
+ // We need a valid DH params pointer for the copy
+ gnutls_dh_params_t new_params_;
+ auto err = gnutls_dh_params_init(&new_params_);
+ if (err != GNUTLS_E_SUCCESS)
+ throw std::runtime_error(std::string("Error initializing DH params: ")
+ + gnutls_strerror(err));
+ params_.reset(new_params_);
+ }
+
+ auto err = gnutls_dh_params_cpy(params_.get(), other.get());
+ if (err != GNUTLS_E_SUCCESS)
+ throw std::runtime_error(std::string("Error copying DH params: ") + gnutls_strerror(err));
+
+ return *this;
+}
+
+std::vector<uint8_t>
+DhParams::serialize() const
+{
+ if (!params_) {
+ JAMI_WARN("serialize() called on an empty DhParams");
+ return {};
+ }
+ gnutls_datum_t out;
+ if (gnutls_dh_params_export2_pkcs3(params_.get(), GNUTLS_X509_FMT_PEM, &out))
+ return {};
+ std::vector<uint8_t> ret {out.data, out.data + out.size};
+ gnutls_free(out.data);
+ return ret;
+}
+
+DhParams
+DhParams::generate()
+{
+ using clock = std::chrono::high_resolution_clock;
+
+ auto bits = gnutls_sec_param_to_pk_bits(GNUTLS_PK_DH,
+ /* GNUTLS_SEC_PARAM_HIGH */ GNUTLS_SEC_PARAM_HIGH);
+ JAMI_DBG("Generating DH params with %u bits", bits);
+ auto start = clock::now();
+
+ gnutls_dh_params_t new_params_;
+ int ret = gnutls_dh_params_init(&new_params_);
+ if (ret != GNUTLS_E_SUCCESS) {
+ JAMI_ERR("Error initializing DH params: %s", gnutls_strerror(ret));
+ return {};
+ }
+ DhParams params {new_params_};
+
+ ret = gnutls_dh_params_generate2(params.get(), bits);
+ if (ret != GNUTLS_E_SUCCESS) {
+ JAMI_ERR("Error generating DH params: %s", gnutls_strerror(ret));
+ return {};
+ }
+
+ std::chrono::duration<double> time_span = clock::now() - start;
+ JAMI_DBG("Generated DH params with %u bits in %lfs", bits, time_span.count());
+ return params;
+}
+
+DhParams
+DhParams::loadDhParams(const std::string& path)
+{
+ std::lock_guard<std::mutex> l(fileutils::getFileLock(path));
+ try {
+ // writeTime throw exception if file doesn't exist
+ auto duration = std::chrono::system_clock::now() - fileutils::writeTime(path);
+ if (duration >= std::chrono::hours(24 * 3)) // file is valid only 3 days
+ throw std::runtime_error("file too old");
+
+ JAMI_DBG("Loading DhParams from file '%s'", path.c_str());
+ return {fileutils::loadFile(path)};
+ } catch (const std::exception& e) {
+ JAMI_DBG("Failed to load DhParams file '%s': %s", path.c_str(), e.what());
+ if (auto params = tls::DhParams::generate()) {
+ try {
+ fileutils::saveFile(path, params.serialize(), 0600);
+ JAMI_DBG("Saved DhParams to file '%s'", path.c_str());
+ } catch (const std::exception& ex) {
+ JAMI_WARN("Failed to save DhParams in file '%s': %s", path.c_str(), ex.what());
+ }
+ return params;
+ }
+ JAMI_ERR("Can't generate DH params.");
+ return {};
+ }
+}
+
+} // namespace tls
+} // namespace jami
diff --git a/src/security/security_const.h b/src/security/security_const.h
new file mode 100644
index 0000000..fb9541b
--- /dev/null
+++ b/src/security/security_const.h
@@ -0,0 +1,121 @@
+/*
+ * Copyright (C) 2004-2023 Savoir-faire Linux Inc.
+ *
+ * Author: Philippe Proulx <philippe.proulx@savoirfairelinux.com>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+#pragma once
+
+namespace libjami {
+
+namespace Certificate {
+
+namespace Status {
+constexpr static char UNDEFINED[] = "UNDEFINED";
+constexpr static char ALLOWED[] = "ALLOWED";
+constexpr static char BANNED[] = "BANNED";
+} // namespace Status
+
+namespace TrustStatus {
+constexpr static char UNTRUSTED[] = "UNTRUSTED";
+constexpr static char TRUSTED[] = "TRUSTED";
+} // namespace TrustStatus
+
+/**
+ * Those constantes are used by the ConfigurationManager.validateCertificate method
+ */
+namespace ChecksNames {
+constexpr static char HAS_PRIVATE_KEY[] = "HAS_PRIVATE_KEY";
+constexpr static char EXPIRED[] = "EXPIRED";
+constexpr static char STRONG_SIGNING[] = "STRONG_SIGNING";
+constexpr static char NOT_SELF_SIGNED[] = "NOT_SELF_SIGNED";
+constexpr static char KEY_MATCH[] = "KEY_MATCH";
+constexpr static char PRIVATE_KEY_STORAGE_PERMISSION[] = "PRIVATE_KEY_STORAGE_PERMISSION";
+constexpr static char PUBLIC_KEY_STORAGE_PERMISSION[] = "PUBLIC_KEY_STORAGE_PERMISSION";
+constexpr static char PRIVATE_KEY_DIRECTORY_PERMISSIONS[] = "PRIVATEKEY_DIRECTORY_PERMISSIONS";
+constexpr static char PUBLIC_KEY_DIRECTORY_PERMISSIONS[] = "PUBLICKEY_DIRECTORY_PERMISSIONS";
+constexpr static char PRIVATE_KEY_STORAGE_LOCATION[] = "PRIVATE_KEY_STORAGE_LOCATION";
+constexpr static char PUBLIC_KEY_STORAGE_LOCATION[] = "PUBLIC_KEY_STORAGE_LOCATION";
+constexpr static char PRIVATE_KEY_SELINUX_ATTRIBUTES[] = "PRIVATE_KEY_SELINUX_ATTRIBUTES";
+constexpr static char PUBLIC_KEY_SELINUX_ATTRIBUTES[] = "PUBLIC_KEY_SELINUX_ATTRIBUTES";
+constexpr static char EXIST[] = "EXIST";
+constexpr static char VALID[] = "VALID";
+constexpr static char VALID_AUTHORITY[] = "VALID_AUTHORITY";
+constexpr static char KNOWN_AUTHORITY[] = "KNOWN_AUTHORITY";
+constexpr static char NOT_REVOKED[] = "NOT_REVOKED";
+constexpr static char AUTHORITY_MISMATCH[] = "AUTHORITY_MISMATCH";
+constexpr static char UNEXPECTED_OWNER[] = "UNEXPECTED_OWNER";
+constexpr static char NOT_ACTIVATED[] = "NOT_ACTIVATED";
+} // namespace ChecksNames
+
+/**
+ * Those constants are used by the ConfigurationManager.getCertificateDetails method
+ */
+namespace DetailsNames {
+constexpr static char EXPIRATION_DATE[] = "EXPIRATION_DATE";
+constexpr static char ACTIVATION_DATE[] = "ACTIVATION_DATE";
+constexpr static char REQUIRE_PRIVATE_KEY_PASSWORD[] = "REQUIRE_PRIVATE_KEY_PASSWORD";
+constexpr static char PUBLIC_SIGNATURE[] = "PUBLIC_SIGNATURE";
+constexpr static char VERSION_NUMBER[] = "VERSION_NUMBER";
+constexpr static char SERIAL_NUMBER[] = "SERIAL_NUMBER";
+constexpr static char ISSUER[] = "ISSUER";
+constexpr static char SUBJECT_KEY_ALGORITHM[] = "SUBJECT_KEY_ALGORITHM";
+constexpr static char CN[] = "CN";
+constexpr static char N[] = "N";
+constexpr static char O[] = "O";
+constexpr static char SIGNATURE_ALGORITHM[] = "SIGNATURE_ALGORITHM";
+constexpr static char MD5_FINGERPRINT[] = "MD5_FINGERPRINT";
+constexpr static char SHA1_FINGERPRINT[] = "SHA1_FINGERPRINT";
+constexpr static char PUBLIC_KEY_ID[] = "PUBLIC_KEY_ID";
+constexpr static char ISSUER_DN[] = "ISSUER_DN";
+constexpr static char NEXT_EXPECTED_UPDATE_DATE[] = "NEXT_EXPECTED_UPDATE_DATE";
+constexpr static char OUTGOING_SERVER[] = "OUTGOING_SERVER";
+constexpr static char IS_CA[] = "IS_CA";
+} // namespace DetailsNames
+
+/**
+ * Those constants are used by the ConfigurationManager.getCertificateDetails and
+ * ConfigurationManager.validateCertificate methods
+ */
+namespace ChecksValuesTypesNames {
+constexpr static char BOOLEAN[] = "BOOLEAN";
+constexpr static char ISO_DATE[] = "ISO_DATE";
+constexpr static char CUSTOM[] = "CUSTOM";
+constexpr static char NUMBER[] = "NUMBER";
+} // namespace ChecksValuesTypesNames
+
+/**
+ * Those constantes are used by the ConfigurationManager.validateCertificate method
+ */
+namespace CheckValuesNames {
+constexpr static char PASSED[] = "PASSED";
+constexpr static char FAILED[] = "FAILED";
+constexpr static char UNSUPPORTED[] = "UNSUPPORTED";
+constexpr static char ISO_DATE[] = "ISO_DATE";
+constexpr static char CUSTOM[] = "CUSTOM";
+constexpr static char DATE[] = "DATE";
+} // namespace CheckValuesNames
+
+} // namespace Certificate
+
+namespace TlsTransport {
+constexpr static char TLS_PEER_CERT[] = "TLS_PEER_CERT";
+constexpr static char TLS_PEER_CA_NUM[] = "TLS_PEER_CA_NUM";
+constexpr static char TLS_PEER_CA_[] = "TLS_PEER_CA_";
+constexpr static char TLS_CIPHER[] = "TLS_CIPHER";
+} // namespace TlsTransport
+
+} // namespace libjami
diff --git a/src/security/threadloop.cpp b/src/security/threadloop.cpp
new file mode 100644
index 0000000..88db725
--- /dev/null
+++ b/src/security/threadloop.cpp
@@ -0,0 +1,135 @@
+/*
+ * Copyright (C) 2004-2023 Savoir-faire Linux Inc.
+ *
+ * Author: Guillaume Roguez <Guillaume.Roguez@savoirfairelinux.com>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#include "threadloop.h"
+
+#include <ciso646> // fix windows compiler bug
+
+namespace jami {
+
+void
+ThreadLoop::mainloop(std::thread::id& tid,
+ const std::function<bool()> setup,
+ const std::function<void()> process,
+ const std::function<void()> cleanup)
+{
+ tid = std::this_thread::get_id();
+ try {
+ if (setup()) {
+ while (state_ == ThreadState::RUNNING)
+ process();
+ cleanup();
+ } else {
+ throw std::runtime_error("setup failed");
+ }
+ } catch (const ThreadLoopException& e) {
+ if (logger_) logger_->e("[threadloop:{}] ThreadLoopException: {}", fmt::ptr(this), e.what());
+ } catch (const std::exception& e) {
+ if (logger_) logger_->e("[threadloop:{}] Unwaited exception: {}", fmt::ptr(this), e.what());
+ }
+ stop();
+}
+
+ThreadLoop::ThreadLoop(std::shared_ptr<dht::log::Logger> logger,
+ const std::function<bool()>& setup,
+ const std::function<void()>& process,
+ const std::function<void()>& cleanup)
+ : setup_(setup)
+ , process_(process)
+ , cleanup_(cleanup)
+ , thread_()
+ , logger_(std::move(logger))
+{}
+
+ThreadLoop::~ThreadLoop()
+{
+ if (isRunning()) {
+ if (logger_) logger_->error("join() should be explicitly called in owner's destructor");
+ join();
+ }
+}
+
+void
+ThreadLoop::start()
+{
+ const auto s = state_.load();
+
+ if (s == ThreadState::RUNNING) {
+ if (logger_) logger_->error("already started");
+ return;
+ }
+
+ // stop pending but not processed by thread yet?
+ if (s == ThreadState::STOPPING and thread_.joinable()) {
+ if (logger_) logger_->debug("stop pending");
+ thread_.join();
+ }
+
+ state_ = ThreadState::RUNNING;
+ thread_ = std::thread(&ThreadLoop::mainloop, this, std::ref(threadId_), setup_, process_, cleanup_);
+ threadId_ = thread_.get_id();
+}
+
+void
+ThreadLoop::stop()
+{
+ if (state_ == ThreadState::RUNNING)
+ state_ = ThreadState::STOPPING;
+}
+
+void
+ThreadLoop::join()
+{
+ stop();
+ if (thread_.joinable())
+ thread_.join();
+}
+
+void
+ThreadLoop::waitForCompletion()
+{
+ if (thread_.joinable())
+ thread_.join();
+}
+
+void
+ThreadLoop::exit()
+{
+ stop();
+ throw ThreadLoopException();
+}
+
+bool
+ThreadLoop::isRunning() const noexcept
+{
+#ifdef _WIN32
+ return state_ == ThreadState::RUNNING;
+#else
+ return thread_.joinable() and state_ == ThreadState::RUNNING;
+#endif
+}
+
+void
+InterruptedThreadLoop::stop()
+{
+ ThreadLoop::stop();
+ cv_.notify_one();
+}
+} // namespace jami
diff --git a/src/security/threadloop.h b/src/security/threadloop.h
new file mode 100644
index 0000000..8a7a0c6
--- /dev/null
+++ b/src/security/threadloop.h
@@ -0,0 +1,134 @@
+/*
+ * Copyright (C) 2004-2023 Savoir-faire Linux Inc.
+ *
+ * Author: Guillaume Roguez <Guillaume.Roguez@savoirfairelinux.com>
+ * Author: Eloi Bail <Eloi.Bail@savoirfairelinux.com>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#pragma once
+
+#include <atomic>
+#include <thread>
+#include <functional>
+#include <stdexcept>
+#include <condition_variable>
+#include <mutex>
+
+#include <opendht/logger.h>
+
+namespace jami {
+
+struct ThreadLoopException : public std::runtime_error
+{
+ ThreadLoopException()
+ : std::runtime_error("ThreadLoopException")
+ {}
+};
+
+class ThreadLoop
+{
+public:
+ enum class ThreadState { READY, RUNNING, STOPPING };
+
+ ThreadLoop(std::shared_ptr<dht::log::Logger> logger,
+ const std::function<bool()>& setup,
+ const std::function<void()>& process,
+ const std::function<void()>& cleanup);
+ virtual ~ThreadLoop();
+
+ void start();
+ void exit();
+ virtual void stop();
+ void join();
+ void waitForCompletion(); // thread will stop itself
+
+ bool isRunning() const noexcept;
+ bool isStopping() const noexcept { return state_ == ThreadState::STOPPING; }
+ std::thread::id get_id() const noexcept { return threadId_; }
+
+private:
+ ThreadLoop(const ThreadLoop&) = delete;
+ ThreadLoop(ThreadLoop&&) noexcept = delete;
+ ThreadLoop& operator=(const ThreadLoop&) = delete;
+ ThreadLoop& operator=(ThreadLoop&&) noexcept = delete;
+
+ // These must be provided by users of ThreadLoop
+ std::function<bool()> setup_;
+ std::function<void()> process_;
+ std::function<void()> cleanup_;
+
+ void mainloop(std::thread::id& tid,
+ const std::function<bool()> setup,
+ const std::function<void()> process,
+ const std::function<void()> cleanup);
+
+ std::atomic<ThreadState> state_ {ThreadState::READY};
+ std::thread::id threadId_;
+ std::thread thread_;
+ std::shared_ptr<dht::log::Logger> logger_;
+};
+
+class InterruptedThreadLoop : public ThreadLoop
+{
+public:
+ InterruptedThreadLoop(std::shared_ptr<dht::log::Logger> logger,
+ const std::function<bool()>& setup,
+ const std::function<void()>& process,
+ const std::function<void()>& cleanup)
+ : ThreadLoop::ThreadLoop(logger, setup, process, cleanup)
+ {}
+
+ void stop() override;
+
+ void interrupt() noexcept { cv_.notify_one(); }
+
+ template<typename Rep, typename Period>
+ void wait_for(const std::chrono::duration<Rep, Period>& rel_time)
+ {
+ if (std::this_thread::get_id() != get_id())
+ throw std::runtime_error("can not call wait_for outside thread context");
+
+ std::unique_lock<std::mutex> lk(mutex_);
+ cv_.wait_for(lk, rel_time, [this]() { return isStopping(); });
+ }
+
+ template<typename Rep, typename Period, typename Pred>
+ bool wait_for(const std::chrono::duration<Rep, Period>& rel_time, Pred&& pred)
+ {
+ if (std::this_thread::get_id() != get_id())
+ throw std::runtime_error("can not call wait_for outside thread context");
+
+ std::unique_lock<std::mutex> lk(mutex_);
+ return cv_.wait_for(lk, rel_time, [this, pred] { return isStopping() || pred(); });
+ }
+
+ template<typename Pred>
+ void wait(Pred&& pred)
+ {
+ if (std::this_thread::get_id() != get_id())
+ throw std::runtime_error("Can not call wait outside thread context");
+
+ std::unique_lock<std::mutex> lk(mutex_);
+ cv_.wait(lk, [this, p = std::forward<Pred>(pred)] { return isStopping() || p(); });
+ }
+
+private:
+ std::mutex mutex_;
+ std::condition_variable cv_;
+};
+
+} // namespace jami
diff --git a/src/security/tls_session.cpp b/src/security/tls_session.cpp
new file mode 100644
index 0000000..43f623d
--- /dev/null
+++ b/src/security/tls_session.cpp
@@ -0,0 +1,1789 @@
+/*
+ * Copyright (C) 2004-2023 Savoir-faire Linux Inc.
+ *
+ * Author: Adrien Béraud <adrien.beraud@savoirfairelinux.com>
+ * Author: Guillaume Roguez <guillaume.roguez@savoirfairelinux.com>
+ * Author: Sébastien Blin <sebastien.blin@savoirfairelinux.com>
+ * Author: Vsevolod Ivanov <vsevolod.ivanov@savoirfairelinux.com>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+#include "tls_session.h"
+#include "threadloop.h"
+#include "certstore.h"
+
+#include <gnutls/gnutls.h>
+#include <gnutls/dtls.h>
+#include <gnutls/abstract.h>
+
+#include <gnutls/crypto.h>
+#include <gnutls/ocsp.h>
+#include <opendht/http.h>
+#include <opendht/logger.h>
+
+#include <list>
+#include <mutex>
+#include <condition_variable>
+#include <utility>
+#include <map>
+#include <atomic>
+#include <iterator>
+#include <stdexcept>
+#include <algorithm>
+#include <cstring> // std::memset
+
+#include <cstdlib>
+#include <unistd.h>
+
+namespace jami {
+namespace tls {
+
+static constexpr const char* DTLS_CERT_PRIORITY_STRING {
+ "SECURE192:-VERS-TLS-ALL:+VERS-DTLS-ALL:-RSA:%SERVER_PRECEDENCE:%SAFE_RENEGOTIATION"};
+static constexpr const char* DTLS_FULL_PRIORITY_STRING {
+ "SECURE192:-KX-ALL:+ANON-ECDH:+ANON-DH:+SECURE192:-VERS-TLS-ALL:+VERS-DTLS-ALL:-RSA:%SERVER_"
+ "PRECEDENCE:%SAFE_RENEGOTIATION"};
+// Note: -GROUP-FFDHE4096:-GROUP-FFDHE6144:-GROUP-FFDHE8192:+GROUP-X25519:
+// is added after gnutls 3.6.7, because some safety checks were introduced for FFDHE resulting in a
+// performance drop for our usage (2/3s of delay) This performance drop is visible on mobiles devices.
+
+// Benchmark result (on a computer)
+// $gnutls-cli --benchmark-tls-kx
+// (TLS1.3)-(DHE-FFDHE3072)-(RSA-PSS-RSAE-SHA256)-(AES-128-GCM) 20.48 transactions/sec
+// (avg. handshake time: 48.45 ms, sample variance: 0.68)
+// (TLS1.3)-(ECDHE-SECP256R1)-(RSA-PSS-RSAE-SHA256)-(AES-128-GCM) 208.14 transactions/sec
+// (avg. handshake time: 4.01 ms, sample variance: 0.01)
+// (TLS1.3)-(ECDHE-X25519)-(RSA-PSS-RSAE-SHA256)-(AES-128-GCM) 240.93 transactions/sec
+// (avg. handshake time: 4.00 ms, sample variance: 0.00)
+static constexpr const char* TLS_CERT_PRIORITY_STRING {
+ "SECURE192:-RSA:-GROUP-FFDHE4096:-GROUP-FFDHE6144:-GROUP-FFDHE8192:+GROUP-X25519:%SERVER_"
+ "PRECEDENCE:%SAFE_RENEGOTIATION"};
+static constexpr const char* TLS_FULL_PRIORITY_STRING {
+ "SECURE192:-KX-ALL:+ANON-ECDH:+ANON-DH:+SECURE192:-RSA:-GROUP-FFDHE4096:-GROUP-FFDHE6144:-"
+ "GROUP-FFDHE8192:+GROUP-X25519:%SERVER_PRECEDENCE:%SAFE_RENEGOTIATION"};
+static constexpr uint32_t RX_MAX_SIZE {64 * 1024}; // 64k = max size of a UDP packet
+static constexpr std::size_t INPUT_MAX_SIZE {
+ 1000}; // Maximum number of packets to store before dropping (pkt size = DTLS_MTU)
+static constexpr ssize_t FLOOD_THRESHOLD {4 * 1024};
+static constexpr auto FLOOD_PAUSE = std::chrono::milliseconds(
+ 100); // Time to wait after an invalid cookie packet (anti flood attack)
+static constexpr size_t HANDSHAKE_MAX_RETRY {64};
+static constexpr auto DTLS_RETRANSMIT_TIMEOUT = std::chrono::milliseconds(
+ 1000); // Delay between two handshake request on DTLS
+static constexpr auto COOKIE_TIMEOUT = std::chrono::seconds(
+ 10); // Time to wait for a cookie packet from client
+static constexpr int MIN_MTU {
+ 512 - 20 - 8}; // minimal payload size of a DTLS packet carried by an IPv4 packet
+static constexpr uint8_t HEARTBEAT_TRIES = 1; // Number of tries at each heartbeat ping send
+static constexpr auto HEARTBEAT_RETRANS_TIMEOUT = std::chrono::milliseconds(
+ 700); // gnutls heartbeat retransmission timeout for each ping (in milliseconds)
+static constexpr auto HEARTBEAT_TOTAL_TIMEOUT
+ = HEARTBEAT_RETRANS_TIMEOUT
+ * HEARTBEAT_TRIES; // gnutls heartbeat time limit for heartbeat procedure (in milliseconds)
+static constexpr int MISS_ORDERING_LIMIT
+ = 32; // maximal accepted distance of out-of-order packet (note: must be a signed type)
+static constexpr auto RX_OOO_TIMEOUT = std::chrono::milliseconds(1500);
+static constexpr int ASYMETRIC_TRANSPORT_MTU_OFFSET
+ = 20; // when client, if your local IP is IPV4 and server is IPV6; you must reduce your MTU to
+ // avoid packet too big error on server side. the offset is the difference in size of IP headers
+static constexpr auto OCSP_REQUEST_TIMEOUT = std::chrono::seconds(
+ 2); // Time to wait for an ocsp-request
+
+// Helper to cast any duration into an integer number of milliseconds
+template<class Rep, class Period>
+static std::chrono::milliseconds::rep
+duration2ms(std::chrono::duration<Rep, Period> d)
+{
+ return std::chrono::duration_cast<std::chrono::milliseconds>(d).count();
+}
+
+static inline uint64_t
+array2uint(const std::array<uint8_t, 8>& a)
+{
+ uint64_t res = 0;
+ for (int i = 0; i < 8; ++i)
+ res = (res << 8) + a[i];
+ return res;
+}
+
+//==============================================================================
+
+namespace {
+
+class TlsCertificateCredendials
+{
+ using T = gnutls_certificate_credentials_t;
+
+public:
+ TlsCertificateCredendials()
+ {
+ int ret = gnutls_certificate_allocate_credentials(&creds_);
+ if (ret < 0) {
+ //if (params_.logger)
+ // params_.logger->e("gnutls_certificate_allocate_credentials() failed with ret=%d", ret);
+ throw std::bad_alloc();
+ }
+ }
+
+ ~TlsCertificateCredendials() { gnutls_certificate_free_credentials(creds_); }
+
+ operator T() { return creds_; }
+
+private:
+ TlsCertificateCredendials(const TlsCertificateCredendials&) = delete;
+ TlsCertificateCredendials& operator=(const TlsCertificateCredendials&) = delete;
+ T creds_;
+};
+
+class TlsAnonymousClientCredendials
+{
+ using T = gnutls_anon_client_credentials_t;
+
+public:
+ TlsAnonymousClientCredendials()
+ {
+ int ret = gnutls_anon_allocate_client_credentials(&creds_);
+ if (ret < 0) {
+ //if (params_.logger)
+ // params_.logger->e("gnutls_anon_allocate_client_credentials() failed with ret=%d", ret);
+ throw std::bad_alloc();
+ }
+ }
+
+ ~TlsAnonymousClientCredendials() { gnutls_anon_free_client_credentials(creds_); }
+
+ operator T() { return creds_; }
+
+private:
+ TlsAnonymousClientCredendials(const TlsAnonymousClientCredendials&) = delete;
+ TlsAnonymousClientCredendials& operator=(const TlsAnonymousClientCredendials&) = delete;
+ T creds_;
+};
+
+class TlsAnonymousServerCredendials
+{
+ using T = gnutls_anon_server_credentials_t;
+
+public:
+ TlsAnonymousServerCredendials()
+ {
+ int ret = gnutls_anon_allocate_server_credentials(&creds_);
+ if (ret < 0) {
+ //if (params_.logger)
+ // params_.logger->e("gnutls_anon_allocate_server_credentials() failed with ret=%d", ret);
+ throw std::bad_alloc();
+ }
+ }
+
+ ~TlsAnonymousServerCredendials() { gnutls_anon_free_server_credentials(creds_); }
+
+ operator T() { return creds_; }
+
+private:
+ TlsAnonymousServerCredendials(const TlsAnonymousServerCredendials&) = delete;
+ TlsAnonymousServerCredendials& operator=(const TlsAnonymousServerCredendials&) = delete;
+ T creds_;
+};
+
+} // namespace
+
+//==============================================================================
+
+class TlsSession::TlsSessionImpl
+{
+public:
+ using clock = std::chrono::steady_clock;
+ using StateHandler = std::function<TlsSessionState(TlsSessionState state)>;
+ using OcspVerification = std::function<void(const int status)>;
+ using HttpResponse = std::function<void(const dht::http::Response& response)>;
+
+ // Constants (ctor init.)
+ const bool isServer_;
+ const TlsParams params_;
+ const TlsSessionCallbacks callbacks_;
+ const bool anonymous_;
+
+ TlsSessionImpl(std::unique_ptr<SocketType>&& transport,
+ const TlsParams& params,
+ const TlsSessionCallbacks& cbs,
+ bool anonymous);
+
+ ~TlsSessionImpl();
+
+ const char* typeName() const;
+
+ std::unique_ptr<SocketType> transport_;
+
+ // State protectors
+ std::mutex stateMutex_;
+ std::condition_variable stateCondition_;
+
+ // State machine
+ TlsSessionState handleStateSetup(TlsSessionState state);
+ TlsSessionState handleStateCookie(TlsSessionState state);
+ TlsSessionState handleStateHandshake(TlsSessionState state);
+ TlsSessionState handleStateMtuDiscovery(TlsSessionState state);
+ TlsSessionState handleStateEstablished(TlsSessionState state);
+ TlsSessionState handleStateShutdown(TlsSessionState state);
+ std::map<TlsSessionState, StateHandler> fsmHandlers_ {};
+ std::atomic<TlsSessionState> state_ {TlsSessionState::SETUP};
+ std::atomic<TlsSessionState> newState_ {TlsSessionState::NONE};
+ std::atomic<int> maxPayload_ {-1};
+
+ // IO GnuTLS <-> ICE
+ std::mutex rxMutex_ {};
+ std::condition_variable rxCv_ {};
+ std::list<std::vector<ValueType>> rxQueue_ {};
+
+ bool flushProcessing_ {false}; ///< protect against recursive call to flushRxQueue
+ std::vector<ValueType> rawPktBuf_; ///< gnutls incoming packet buffer
+ uint64_t baseSeq_ {0}; ///< sequence number of first application data packet received
+ uint64_t lastRxSeq_ {0}; ///< last received and valid packet sequence number
+ uint64_t gapOffset_ {0}; ///< offset of first byte not received yet
+ clock::time_point lastReadTime_;
+ std::map<uint64_t, std::vector<ValueType>> reorderBuffer_ {};
+ std::list<clock::time_point> nextFlush_ {};
+
+ std::size_t send(const ValueType*, std::size_t, std::error_code&);
+ ssize_t sendRaw(const void*, size_t);
+ ssize_t sendRawVec(const giovec_t*, int);
+ ssize_t recvRaw(void*, size_t);
+ int waitForRawData(std::chrono::milliseconds);
+
+ bool initFromRecordState(int offset = 0);
+ void handleDataPacket(std::vector<ValueType>&&, uint64_t);
+ void flushRxQueue(std::unique_lock<std::mutex>&);
+
+ // Statistics
+ std::atomic<std::size_t> stRxRawPacketCnt_ {0};
+ std::atomic<std::size_t> stRxRawBytesCnt_ {0};
+ std::atomic<std::size_t> stRxRawPacketDropCnt_ {0};
+ std::atomic<std::size_t> stTxRawPacketCnt_ {0};
+ std::atomic<std::size_t> stTxRawBytesCnt_ {0};
+ void dump_io_stats() const;
+
+ std::unique_ptr<TlsAnonymousClientCredendials> cacred_; // ctor init.
+ std::unique_ptr<TlsAnonymousServerCredendials> sacred_; // ctor init.
+ std::unique_ptr<TlsCertificateCredendials> xcred_; // ctor init.
+ std::mutex sessionReadMutex_;
+ std::mutex sessionWriteMutex_;
+ gnutls_session_t session_ {nullptr};
+ gnutls_datum_t cookie_key_ {nullptr, 0};
+ gnutls_dtls_prestate_st prestate_ {};
+ ssize_t cookie_count_ {0};
+
+ TlsSessionState setupClient();
+ TlsSessionState setupServer();
+ void initAnonymous();
+ void initCredentials();
+ bool commonSessionInit();
+
+ std::shared_ptr<dht::crypto::Certificate> peerCertificate(gnutls_session_t session) const;
+
+ /*
+ * Implicit certificate validations.
+ */
+ int verifyCertificateWrapper(gnutls_session_t session);
+ /*
+ * Verify OCSP (Online Certificate Service Protocol):
+ */
+ void verifyOcsp(const std::string& url,
+ dht::crypto::Certificate& cert,
+ gnutls_x509_crt_t issuer,
+ OcspVerification cb);
+ /*
+ * Send OCSP Request to the specified URI.
+ */
+ void sendOcspRequest(const std::string& uri,
+ std::string body,
+ std::chrono::seconds timeout,
+ HttpResponse cb = {});
+
+ // FSM thread (TLS states)
+ ThreadLoop thread_; // ctor init.
+ bool setup();
+ void process();
+ void cleanup();
+
+ // Path mtu discovery
+ std::array<int, 3> MTUS_;
+ int mtuProbe_;
+ int hbPingRecved_ {0};
+ bool pmtudOver_ {false};
+ void pathMtuHeartbeat();
+
+ std::mutex requestsMtx_;
+ std::set<std::shared_ptr<dht::http::Request>> requests_;
+ std::shared_ptr<dht::crypto::Certificate> pCert_ {};
+};
+
+TlsSession::TlsSessionImpl::TlsSessionImpl(std::unique_ptr<SocketType>&& transport,
+ const TlsParams& params,
+ const TlsSessionCallbacks& cbs,
+ bool anonymous)
+ : isServer_(not transport->isInitiator())
+ , params_(params)
+ , callbacks_(cbs)
+ , anonymous_(anonymous)
+ , transport_ {std::move(transport)}
+ , cacred_(nullptr)
+ , sacred_(nullptr)
+ , xcred_(nullptr)
+ , thread_(params.logger, [this] { return setup(); }, [this] { process(); }, [this] { cleanup(); })
+{
+ if (not transport_->isReliable()) {
+ transport_->setOnRecv([this](const ValueType* buf, size_t len) {
+ std::lock_guard<std::mutex> lk {rxMutex_};
+ if (rxQueue_.size() == INPUT_MAX_SIZE) {
+ rxQueue_.pop_front(); // drop oldest packet if input buffer is full
+ ++stRxRawPacketDropCnt_;
+ }
+ rxQueue_.emplace_back(buf, buf + len);
+ ++stRxRawPacketCnt_;
+ stRxRawBytesCnt_ += len;
+ rxCv_.notify_one();
+ return len;
+ });
+ }
+
+ // Run FSM into dedicated thread
+ thread_.start();
+}
+
+TlsSession::TlsSessionImpl::~TlsSessionImpl()
+{
+ state_ = TlsSessionState::SHUTDOWN;
+ stateCondition_.notify_all();
+ rxCv_.notify_all();
+ {
+ std::lock_guard<std::mutex> lock(requestsMtx_);
+ // requests_ store a shared_ptr, so we need to cancel requests
+ // to not be stuck in verifyCertificateWrapper
+ for (auto& request : requests_)
+ request->cancel();
+ requests_.clear();
+ }
+ thread_.join();
+ if (not transport_->isReliable())
+ transport_->setOnRecv(nullptr);
+}
+
+const char*
+TlsSession::TlsSessionImpl::typeName() const
+{
+ return isServer_ ? "server" : "client";
+}
+
+void
+TlsSession::TlsSessionImpl::dump_io_stats() const
+{
+ if (params_.logger)
+ params_.logger->debug("[TLS] RxRawPkt={:d} ({:d} bytes) - TxRawPkt={:d} ({:d} bytes)",
+ stRxRawPacketCnt_.load(),
+ stRxRawBytesCnt_.load(),
+ stTxRawPacketCnt_.load(),
+ stTxRawBytesCnt_.load());
+}
+
+TlsSessionState
+TlsSession::TlsSessionImpl::setupClient()
+{
+ int ret;
+
+ if (not transport_->isReliable()) {
+ ret = gnutls_init(&session_, GNUTLS_CLIENT | GNUTLS_DATAGRAM);
+ // uncoment to reactivate PMTUD
+ // if (params_.logger)
+ params_.logger->d("[TLS] set heartbeat reception for retrocompatibility check on server");
+ // gnutls_heartbeat_enable(session_,GNUTLS_HB_PEER_ALLOWED_TO_SEND);
+ } else {
+ ret = gnutls_init(&session_, GNUTLS_CLIENT);
+ }
+
+ if (ret != GNUTLS_E_SUCCESS) {
+ if (params_.logger)
+ params_.logger->e("[TLS] session init failed: %s", gnutls_strerror(ret));
+ return TlsSessionState::SHUTDOWN;
+ }
+
+ if (not commonSessionInit()) {
+ return TlsSessionState::SHUTDOWN;
+ }
+
+ return TlsSessionState::HANDSHAKE;
+}
+
+TlsSessionState
+TlsSession::TlsSessionImpl::setupServer()
+{
+ int ret;
+
+ if (not transport_->isReliable()) {
+ ret = gnutls_init(&session_, GNUTLS_SERVER | GNUTLS_DATAGRAM);
+
+ // uncoment to reactivate PMTUD
+ // if (params_.logger)
+ params_.logger->d("[TLS] set heartbeat reception");
+ // gnutls_heartbeat_enable(session_, GNUTLS_HB_PEER_ALLOWED_TO_SEND);
+
+ gnutls_dtls_prestate_set(session_, &prestate_);
+ } else {
+ ret = gnutls_init(&session_, GNUTLS_SERVER);
+ }
+
+ if (ret != GNUTLS_E_SUCCESS) {
+ if (params_.logger)
+ params_.logger->e("[TLS] session init failed: %s", gnutls_strerror(ret));
+ return TlsSessionState::SHUTDOWN;
+ }
+
+ gnutls_certificate_server_set_request(session_, GNUTLS_CERT_REQUIRE);
+
+ if (not commonSessionInit())
+ return TlsSessionState::SHUTDOWN;
+
+ return TlsSessionState::HANDSHAKE;
+}
+
+void
+TlsSession::TlsSessionImpl::initAnonymous()
+{
+ // credentials for handshaking and transmission
+ if (isServer_)
+ sacred_.reset(new TlsAnonymousServerCredendials());
+ else
+ cacred_.reset(new TlsAnonymousClientCredendials());
+
+ // Setup DH-params for anonymous authentification
+ if (isServer_) {
+ if (const auto& dh_params = params_.dh_params.get().get())
+ gnutls_anon_set_server_dh_params(*sacred_, dh_params);
+ else
+ if (params_.logger)
+ params_.logger->w("[TLS] DH params unavailable");
+ }
+}
+
+void
+TlsSession::TlsSessionImpl::initCredentials()
+{
+ int ret;
+
+ // credentials for handshaking and transmission
+ xcred_.reset(new TlsCertificateCredendials());
+
+ gnutls_certificate_set_verify_function(*xcred_, [](gnutls_session_t session) -> int {
+ auto this_ = reinterpret_cast<TlsSessionImpl*>(gnutls_session_get_ptr(session));
+ return this_->verifyCertificateWrapper(session);
+ });
+
+ // Load user-given CA list
+ if (not params_.ca_list.empty()) {
+ // Try PEM format first
+ ret = gnutls_certificate_set_x509_trust_file(*xcred_,
+ params_.ca_list.c_str(),
+ GNUTLS_X509_FMT_PEM);
+
+ // Then DER format
+ if (ret < 0)
+ ret = gnutls_certificate_set_x509_trust_file(*xcred_,
+ params_.ca_list.c_str(),
+ GNUTLS_X509_FMT_DER);
+ if (ret < 0)
+ throw std::runtime_error("can't load CA " + params_.ca_list + ": "
+ + std::string(gnutls_strerror(ret)));
+
+ if (params_.logger)
+ params_.logger->d("[TLS] CA list %s loadev", params_.ca_list.c_str());
+ }
+ if (params_.peer_ca) {
+ auto chain = params_.peer_ca->getChainWithRevocations();
+ auto ret = gnutls_certificate_set_x509_trust(*xcred_,
+ chain.first.data(),
+ chain.first.size());
+ if (not chain.second.empty())
+ gnutls_certificate_set_x509_crl(*xcred_, chain.second.data(), chain.second.size());
+ if (params_.logger)
+ params_.logger->debug("[TLS] Peer CA list {:d} ({:d} CRLs): {:d}",
+ chain.first.size(),
+ chain.second.size(),
+ ret);
+ }
+
+ // Load user-given identity (key and passwd)
+ if (params_.cert) {
+ std::vector<gnutls_x509_crt_t> certs;
+ certs.reserve(3);
+ auto crt = params_.cert;
+ while (crt) {
+ certs.emplace_back(crt->cert);
+ crt = crt->issuer;
+ }
+
+ ret = gnutls_certificate_set_x509_key(*xcred_,
+ certs.data(),
+ certs.size(),
+ params_.cert_key->x509_key);
+ if (ret < 0)
+ throw std::runtime_error("can't load certificate: " + std::string(gnutls_strerror(ret)));
+
+ if (params_.logger)
+ params_.logger->d("[TLS] User identity loaded");
+ }
+
+ // Setup DH-params (server only, may block on dh_params.get())
+ if (isServer_) {
+ if (const auto& dh_params = params_.dh_params.get().get())
+ gnutls_certificate_set_dh_params(*xcred_, dh_params);
+ else
+ if (params_.logger)
+ params_.logger->w("[TLS] DH params unavailable"); // YOMGUI: need to stop?
+ }
+}
+
+bool
+TlsSession::TlsSessionImpl::commonSessionInit()
+{
+ int ret;
+
+ if (anonymous_) {
+ // Force anonymous connection, see handleStateHandshake how we handle failures
+ ret = gnutls_priority_set_direct(session_,
+ transport_->isReliable() ? TLS_FULL_PRIORITY_STRING
+ : DTLS_FULL_PRIORITY_STRING,
+ nullptr);
+ if (ret != GNUTLS_E_SUCCESS) {
+ if (params_.logger)
+ params_.logger->e("[TLS] TLS priority set failed: %s", gnutls_strerror(ret));
+ return false;
+ }
+
+ // Add anonymous credentials
+ if (isServer_)
+ ret = gnutls_credentials_set(session_, GNUTLS_CRD_ANON, *sacred_);
+ else
+ ret = gnutls_credentials_set(session_, GNUTLS_CRD_ANON, *cacred_);
+
+ if (ret != GNUTLS_E_SUCCESS) {
+ if (params_.logger)
+ params_.logger->e("[TLS] anonymous credential set failed: %s", gnutls_strerror(ret));
+ return false;
+ }
+ } else {
+ // Use a classic non-encrypted CERTIFICATE exchange method (less anonymous)
+ ret = gnutls_priority_set_direct(session_,
+ transport_->isReliable() ? TLS_CERT_PRIORITY_STRING
+ : DTLS_CERT_PRIORITY_STRING,
+ nullptr);
+ if (ret != GNUTLS_E_SUCCESS) {
+ if (params_.logger)
+ params_.logger->e("[TLS] TLS priority set failed: %s", gnutls_strerror(ret));
+ return false;
+ }
+ }
+
+ // Add certificate credentials
+ ret = gnutls_credentials_set(session_, GNUTLS_CRD_CERTIFICATE, *xcred_);
+ if (ret != GNUTLS_E_SUCCESS) {
+ if (params_.logger)
+ params_.logger->e("[TLS] certificate credential set failed: %s", gnutls_strerror(ret));
+ return false;
+ }
+ gnutls_certificate_send_x509_rdn_sequence(session_, 0);
+
+ if (not transport_->isReliable()) {
+ // DTLS hanshake timeouts
+ auto re_tx_timeout = duration2ms(DTLS_RETRANSMIT_TIMEOUT);
+ gnutls_dtls_set_timeouts(session_,
+ re_tx_timeout,
+ std::max(duration2ms(params_.timeout), re_tx_timeout));
+
+ // gnutls DTLS mtu = maximum payload size given by transport
+ gnutls_dtls_set_mtu(session_, transport_->maxPayload());
+ }
+
+ // Stuff for transport callbacks
+ gnutls_session_set_ptr(session_, this);
+ gnutls_transport_set_ptr(session_, this);
+ gnutls_transport_set_vec_push_function(session_,
+ [](gnutls_transport_ptr_t t,
+ const giovec_t* iov,
+ int iovcnt) -> ssize_t {
+ auto this_ = reinterpret_cast<TlsSessionImpl*>(t);
+ return this_->sendRawVec(iov, iovcnt);
+ });
+ gnutls_transport_set_pull_function(session_,
+ [](gnutls_transport_ptr_t t, void* d, size_t s) -> ssize_t {
+ auto this_ = reinterpret_cast<TlsSessionImpl*>(t);
+ return this_->recvRaw(d, s);
+ });
+ gnutls_transport_set_pull_timeout_function(session_,
+ [](gnutls_transport_ptr_t t, unsigned ms) -> int {
+ auto this_ = reinterpret_cast<TlsSessionImpl*>(t);
+ return this_->waitForRawData(
+ std::chrono::milliseconds(ms));
+ });
+ // TODO -1 = default else set value
+ if (transport_->isReliable())
+ gnutls_handshake_set_timeout(session_, duration2ms(params_.timeout));
+ return true;
+}
+
+std::string
+getOcspUrl(gnutls_x509_crt_t cert)
+{
+ int ret;
+ gnutls_datum_t aia;
+ unsigned int seq = 0;
+ do {
+ // Extracts the Authority Information Access (AIA) extension, see RFC 5280 section 4.2.2.1
+ ret = gnutls_x509_crt_get_authority_info_access(cert, seq++, GNUTLS_IA_OCSP_URI, &aia, NULL);
+ } while (ret < 0 && ret != GNUTLS_E_REQUESTED_DATA_NOT_AVAILABLE);
+ // could also try the issuer if we include ocsp uri into there
+ if (ret < 0) {
+ return {};
+ }
+ std::string url((const char*) aia.data, (size_t) aia.size);
+ gnutls_free(aia.data);
+ return url;
+}
+
+int
+TlsSession::TlsSessionImpl::verifyCertificateWrapper(gnutls_session_t session)
+{
+ // Perform user-set verification first to avoid flooding with ocsp-requests if peer is denied
+ int verified;
+ if (callbacks_.verifyCertificate) {
+ auto this_ = reinterpret_cast<TlsSessionImpl*>(gnutls_session_get_ptr(session));
+ verified = this_->callbacks_.verifyCertificate(session);
+ if (verified != GNUTLS_E_SUCCESS)
+ return verified;
+ } else {
+ verified = GNUTLS_E_SUCCESS;
+ }
+ /*
+ * Support only x509 format
+ */
+ if (gnutls_certificate_type_get(session) != GNUTLS_CRT_X509)
+ return GNUTLS_E_CERTIFICATE_ERROR;
+
+ pCert_ = peerCertificate(session);
+ if (!pCert_)
+ return GNUTLS_E_CERTIFICATE_ERROR;
+
+ std::string ocspUrl = getOcspUrl(pCert_->cert);
+ if (ocspUrl.empty()) {
+ // Skipping OCSP verification: AIA not found
+ return verified;
+ }
+
+ // OCSP (Online Certificate Service Protocol) {
+ std::promise<int> v;
+ std::future<int> f = v.get_future();
+
+ gnutls_x509_crt_t issuer_crt = pCert_->issuer ? pCert_->issuer->cert : nullptr;
+ verifyOcsp(ocspUrl, *pCert_, issuer_crt, [&](const int status) {
+ if (status == GNUTLS_E_REQUESTED_DATA_NOT_AVAILABLE) {
+ // OCSP URI is absent, don't fail the verification by overwritting the user-set one.
+ if (params_.logger)
+ params_.logger->w("Skipping OCSP verification %s: request failed", pCert_->getUID().c_str());
+ v.set_value(verified);
+ } else {
+ if (status != GNUTLS_E_SUCCESS) {
+ if (params_.logger)
+ params_.logger->e("OCSP verification failed for %s: %s (%i)",
+ pCert_->getUID().c_str(),
+ gnutls_strerror(status),
+ status);
+ }
+ v.set_value(status);
+ }
+ });
+ f.wait();
+
+ return f.get();
+}
+
+void
+TlsSession::TlsSessionImpl::verifyOcsp(const std::string& aia_uri,
+ dht::crypto::Certificate& cert,
+ gnutls_x509_crt_t issuer,
+ OcspVerification cb)
+{
+ if (params_.logger)
+ params_.logger->d("Certificate's AIA URI: %s", aia_uri.c_str());
+
+ // Generate OCSP request
+ std::pair<std::string, dht::Blob> ocsp_req;
+ try {
+ ocsp_req = cert.generateOcspRequest(issuer);
+ } catch (dht::crypto::CryptoException& e) {
+ if (params_.logger)
+ params_.logger->e("Failed to generate OCSP request: %s", e.what());
+ if (cb)
+ cb(GNUTLS_E_INVALID_REQUEST);
+ return;
+ }
+
+ sendOcspRequest(aia_uri,
+ std::move(ocsp_req.first),
+ OCSP_REQUEST_TIMEOUT,
+ [cb = std::move(cb), &cert, nonce = std::move(ocsp_req.second), this](
+ const dht::http::Response& r) {
+ // Prepare response data
+ // Verify response validity
+ if (r.status_code != 200) {
+ if (params_.logger)
+ params_.logger->w("HTTP OCSP Request Failed with code %i", r.status_code);
+ if (cb)
+ cb(GNUTLS_E_REQUESTED_DATA_NOT_AVAILABLE);
+ return;
+ }
+ if (params_.logger)
+ params_.logger->d("HTTP OCSP Request done!");
+ gnutls_ocsp_cert_status_t verify = GNUTLS_OCSP_CERT_UNKNOWN;
+ try {
+ cert.ocspResponse = std::make_shared<dht::crypto::OcspResponse>(
+ (const uint8_t*) r.body.data(), r.body.size());
+ if (params_.logger)
+ params_.logger->d("%s", cert.ocspResponse->toString().c_str());
+ verify = cert.ocspResponse->verifyDirect(cert, nonce);
+ } catch (dht::crypto::CryptoException& e) {
+ if (params_.logger)
+ params_.logger->e("Failed to verify OCSP response: %s", e.what());
+ }
+ if (verify == GNUTLS_OCSP_CERT_UNKNOWN) {
+ // Soft-fail
+ if (cb)
+ cb(GNUTLS_E_REQUESTED_DATA_NOT_AVAILABLE);
+ return;
+ }
+ int status = GNUTLS_E_SUCCESS;
+ if (verify == GNUTLS_OCSP_CERT_GOOD) {
+ if (params_.logger)
+ params_.logger->d("OCSP verification success!");
+ } else {
+ status = GNUTLS_E_CERTIFICATE_ERROR;
+ if (params_.logger)
+ params_.logger->e("OCSP verification: certificate is revoked!");
+ }
+ // Save response into the certificate store
+ try {
+ params_.certStore.pinOcspResponse(cert);
+ } catch (std::exception& e) {
+ if (params_.logger)
+ params_.logger->error("{}", e.what());
+ }
+ if (cb)
+ cb(status);
+ });
+}
+
+void
+TlsSession::TlsSessionImpl::sendOcspRequest(const std::string& uri,
+ std::string body,
+ std::chrono::seconds timeout,
+ HttpResponse cb)
+{
+ using namespace dht;
+ auto request = std::make_shared<http::Request>(*params_.io_context,
+ uri); //, logger);
+ request->set_method(restinio::http_method_post());
+ request->set_header_field(restinio::http_field_t::user_agent, "Jami");
+ request->set_header_field(restinio::http_field_t::accept, "*/*");
+ request->set_header_field(restinio::http_field_t::content_type, "application/ocsp-request");
+ request->set_body(std::move(body));
+ request->set_connection_type(restinio::http_connection_header_t::close);
+ request->timeout(timeout, [request,l=params_.logger](const asio::error_code& ec) {
+ if (ec and ec != asio::error::operation_aborted)
+ if (l) l->error("HTTP OCSP Request timeout with error: {:s}", ec.message());
+ request->cancel();
+ });
+ request->add_on_state_change_callback([this, cb = std::move(cb)](const http::Request::State state,
+ const http::Response response) {
+ if (params_.logger)
+ params_.logger->d("HTTP OCSP Request state=%i status_code=%i",
+ (unsigned int) state,
+ response.status_code);
+ if (state != http::Request::State::DONE)
+ return;
+ if (cb)
+ cb(response);
+ if (auto request = response.request.lock()) {
+ std::lock_guard<std::mutex> lock(requestsMtx_);
+ requests_.erase(request);
+ }
+ });
+ {
+ std::lock_guard<std::mutex> lock(requestsMtx_);
+ requests_.emplace(request);
+ }
+ request->send();
+}
+
+std::shared_ptr<dht::crypto::Certificate>
+TlsSession::TlsSessionImpl::peerCertificate(gnutls_session_t session) const
+{
+ if (!session)
+ return {};
+ /*
+ * Get the peer's raw certificate (chain) as sent by the peer.
+ * The first certificate in the list is the peer's certificate, following the issuer's cert. etc.
+ */
+ unsigned int cert_list_size = 0;
+ auto cert_list = gnutls_certificate_get_peers(session, &cert_list_size);
+
+ if (cert_list == nullptr)
+ return {};
+ std::vector<std::pair<uint8_t*, uint8_t*>> crt_data;
+ crt_data.reserve(cert_list_size);
+ for (unsigned i = 0; i < cert_list_size; i++)
+ crt_data.emplace_back(cert_list[i].data, cert_list[i].data + cert_list[i].size);
+ return std::make_shared<dht::crypto::Certificate>(crt_data);
+}
+
+std::size_t
+TlsSession::TlsSessionImpl::send(const ValueType* tx_data, std::size_t tx_size, std::error_code& ec)
+{
+ std::lock_guard<std::mutex> lk(sessionWriteMutex_);
+ if (state_ != TlsSessionState::ESTABLISHED) {
+ ec = std::error_code(GNUTLS_E_INVALID_SESSION, std::system_category());
+ return 0;
+ }
+
+ std::size_t total_written = 0;
+ std::size_t max_tx_sz;
+
+ if (transport_->isReliable())
+ max_tx_sz = tx_size;
+ else
+ max_tx_sz = gnutls_dtls_get_data_mtu(session_);
+
+ // Split incoming data into chunck suitable for the underlying transport
+ while (total_written < tx_size) {
+ auto chunck_sz = std::min(max_tx_sz, tx_size - total_written);
+ auto data_seq = tx_data + total_written;
+ ssize_t nwritten;
+ do {
+ nwritten = gnutls_record_send(session_, data_seq, chunck_sz);
+ } while ((nwritten == GNUTLS_E_INTERRUPTED and state_ != TlsSessionState::SHUTDOWN)
+ or nwritten == GNUTLS_E_AGAIN);
+ if (nwritten < 0) {
+ /* Normally we would have to retry record_send but our internal
+ * state has not changed, so we have to ask for more data first.
+ * We will just try again later, although this should never happen.
+ */
+ if (params_.logger)
+ params_.logger->error("[TLS] send failed (only {} bytes sent): {}", total_written, gnutls_strerror(nwritten));
+ ec = std::error_code(nwritten, std::system_category());
+ return 0;
+ }
+
+ total_written += nwritten;
+ }
+
+ ec.clear();
+ return total_written;
+}
+
+// Called by GNUTLS to send encrypted packet to low-level transport.
+// Should return a positive number indicating the bytes sent, and -1 on error.
+ssize_t
+TlsSession::TlsSessionImpl::sendRaw(const void* buf, size_t size)
+{
+ std::error_code ec;
+ unsigned retry_count = 0;
+ do {
+ auto n = transport_->write(reinterpret_cast<const ValueType*>(buf), size, ec);
+ if (!ec) {
+ // log only on success
+ ++stTxRawPacketCnt_;
+ stTxRawBytesCnt_ += n;
+ return n;
+ }
+
+ if (ec.value() == EAGAIN) {
+ if (params_.logger)
+ params_.logger->w("[TLS] EAGAIN from transport, retry#", ++retry_count);
+ std::this_thread::sleep_for(std::chrono::milliseconds(10));
+ if (retry_count == 100) {
+ if (params_.logger)
+ params_.logger->e("[TLS] excessive retry detected, aborting");
+ ec.assign(EIO, std::system_category());
+ }
+ }
+ } while (ec.value() == EAGAIN);
+
+ // Must be called to pass errno value to GnuTLS on Windows (cf. GnuTLS doc)
+ gnutls_transport_set_errno(session_, ec.value());
+ if (params_.logger)
+ params_.logger->e("[TLS] transport failure on tx: errno = {}", ec.value());
+ return -1;
+}
+
+// Called by GNUTLS to send encrypted packet to low-level transport.
+// Should return a positive number indicating the bytes sent, and -1 on error.
+ssize_t
+TlsSession::TlsSessionImpl::sendRawVec(const giovec_t* iov, int iovcnt)
+{
+ ssize_t sent = 0;
+ for (int i = 0; i < iovcnt; ++i) {
+ const giovec_t& dat = iov[i];
+ ssize_t ret = sendRaw(dat.iov_base, dat.iov_len);
+ if (ret < 0)
+ return -1;
+ sent += ret;
+ }
+ return sent;
+}
+
+// Called by GNUTLS to receive encrypted packet from low-level transport.
+// Should return 0 on connection termination,
+// a positive number indicating the number of bytes received,
+// and -1 on error.
+ssize_t
+TlsSession::TlsSessionImpl::recvRaw(void* buf, size_t size)
+{
+ if (transport_->isReliable()) {
+ std::error_code ec;
+ auto count = transport_->read(reinterpret_cast<ValueType*>(buf), size, ec);
+ if (!ec)
+ return count;
+ gnutls_transport_set_errno(session_, ec.value());
+ return -1;
+ }
+
+ std::lock_guard<std::mutex> lk {rxMutex_};
+ if (rxQueue_.empty()) {
+ gnutls_transport_set_errno(session_, EAGAIN);
+ return -1;
+ }
+
+ const auto& pkt = rxQueue_.front();
+ const std::size_t count = std::min(pkt.size(), size);
+ std::copy_n(pkt.begin(), count, reinterpret_cast<ValueType*>(buf));
+ rxQueue_.pop_front();
+ return count;
+}
+
+// Called by GNUTLS to wait for encrypted packet from low-level transport.
+// 'timeout' is in milliseconds.
+// Should return 0 on timeout, a positive number if data are available for read, or -1 on error.
+int
+TlsSession::TlsSessionImpl::waitForRawData(std::chrono::milliseconds timeout)
+{
+ if (transport_->isReliable()) {
+ std::error_code ec;
+ auto err = transport_->waitForData(timeout, ec);
+ if (err <= 0) {
+ // shutdown?
+ if (state_ == TlsSessionState::SHUTDOWN) {
+ gnutls_transport_set_errno(session_, EINTR);
+ return -1;
+ }
+ if (ec) {
+ gnutls_transport_set_errno(session_, ec.value());
+ return -1;
+ }
+ return 0;
+ }
+ return 1;
+ }
+
+ // non-reliable uses callback installed with setOnRecv()
+ std::unique_lock<std::mutex> lk {rxMutex_};
+ rxCv_.wait_for(lk, timeout, [this] {
+ return !rxQueue_.empty() or state_ == TlsSessionState::SHUTDOWN;
+ });
+ if (state_ == TlsSessionState::SHUTDOWN) {
+ gnutls_transport_set_errno(session_, EINTR);
+ return -1;
+ }
+ if (rxQueue_.empty()) {
+ if (params_.logger)
+ params_.logger->error("[TLS] waitForRawData: timeout after {}", timeout);
+ return 0;
+ }
+ return 1;
+}
+
+bool
+TlsSession::TlsSessionImpl::initFromRecordState(int offset)
+{
+ std::array<uint8_t, 8> seq;
+ if (gnutls_record_get_state(session_, 1, nullptr, nullptr, nullptr, &seq[0])
+ != GNUTLS_E_SUCCESS) {
+ if (params_.logger)
+ params_.logger->e("[TLS] Fatal-error Unable to read initial state");
+ return false;
+ }
+
+ baseSeq_ = array2uint(seq) + offset;
+ gapOffset_ = baseSeq_;
+ lastRxSeq_ = baseSeq_ - 1;
+ if (params_.logger)
+ params_.logger->debug("[TLS] Initial sequence number: {:d}", baseSeq_);
+ return true;
+}
+
+bool
+TlsSession::TlsSessionImpl::setup()
+{
+ // Setup FSM
+ fsmHandlers_[TlsSessionState::SETUP] = [this](TlsSessionState s) {
+ return handleStateSetup(s);
+ };
+ fsmHandlers_[TlsSessionState::COOKIE] = [this](TlsSessionState s) {
+ return handleStateCookie(s);
+ };
+ fsmHandlers_[TlsSessionState::HANDSHAKE] = [this](TlsSessionState s) {
+ return handleStateHandshake(s);
+ };
+ fsmHandlers_[TlsSessionState::MTU_DISCOVERY] = [this](TlsSessionState s) {
+ return handleStateMtuDiscovery(s);
+ };
+ fsmHandlers_[TlsSessionState::ESTABLISHED] = [this](TlsSessionState s) {
+ return handleStateEstablished(s);
+ };
+ fsmHandlers_[TlsSessionState::SHUTDOWN] = [this](TlsSessionState s) {
+ return handleStateShutdown(s);
+ };
+
+ return true;
+}
+
+void
+TlsSession::TlsSessionImpl::cleanup()
+{
+ state_ = TlsSessionState::SHUTDOWN; // be sure to block any user operations
+ stateCondition_.notify_all();
+
+ {
+ std::lock_guard<std::mutex> lk1(sessionReadMutex_);
+ std::lock_guard<std::mutex> lk2(sessionWriteMutex_);
+ if (session_) {
+ if (transport_->isReliable())
+ gnutls_bye(session_, GNUTLS_SHUT_RDWR);
+ else
+ gnutls_bye(session_, GNUTLS_SHUT_WR); // not wait for a peer answer
+ gnutls_deinit(session_);
+ session_ = nullptr;
+ }
+ }
+
+ if (cookie_key_.data)
+ gnutls_free(cookie_key_.data);
+
+ transport_->shutdown();
+}
+
+TlsSessionState
+TlsSession::TlsSessionImpl::handleStateSetup([[maybe_unused]] TlsSessionState state)
+{
+ if (params_.logger)
+ params_.logger->d("[TLS] Start %s session", typeName());
+
+ try {
+ if (anonymous_)
+ initAnonymous();
+ initCredentials();
+ } catch (const std::exception& e) {
+ if (params_.logger)
+ params_.logger->e("[TLS] authentifications init failed: %s", e.what());
+ return TlsSessionState::SHUTDOWN;
+ }
+
+ if (not isServer_)
+ return setupClient();
+
+ // Extra step for DTLS-like transports
+ if (transport_ and not transport_->isReliable()) {
+ gnutls_key_generate(&cookie_key_, GNUTLS_COOKIE_KEY_SIZE);
+ return TlsSessionState::COOKIE;
+ }
+ return setupServer();
+}
+
+TlsSessionState
+TlsSession::TlsSessionImpl::handleStateCookie(TlsSessionState state)
+{
+ if (params_.logger)
+ params_.logger->d("[TLS] SYN cookie");
+
+ std::size_t count;
+ {
+ // block until rx packet or shutdown
+ std::unique_lock<std::mutex> lk {rxMutex_};
+ if (!rxCv_.wait_for(lk, COOKIE_TIMEOUT, [this] {
+ return !rxQueue_.empty() or state_ == TlsSessionState::SHUTDOWN;
+ })) {
+ if (params_.logger)
+ params_.logger->e("[TLS] SYN cookie failed: timeout");
+ return TlsSessionState::SHUTDOWN;
+ }
+ // Shutdown state?
+ if (rxQueue_.empty())
+ return TlsSessionState::SHUTDOWN;
+ count = rxQueue_.front().size();
+ }
+
+ // Total bytes rx during cookie checking (see flood protection below)
+ cookie_count_ += count;
+
+ int ret;
+
+ // Peek and verify front packet
+ {
+ std::lock_guard<std::mutex> lk {rxMutex_};
+ auto& pkt = rxQueue_.front();
+ std::memset(&prestate_, 0, sizeof(prestate_));
+ ret = gnutls_dtls_cookie_verify(&cookie_key_, nullptr, 0, pkt.data(), pkt.size(), &prestate_);
+ }
+
+ if (ret < 0) {
+ gnutls_dtls_cookie_send(&cookie_key_,
+ nullptr,
+ 0,
+ &prestate_,
+ this,
+ [](gnutls_transport_ptr_t t, const void* d, size_t s) -> ssize_t {
+ auto this_ = reinterpret_cast<TlsSessionImpl*>(t);
+ return this_->sendRaw(d, s);
+ });
+
+ // Drop front packet
+ {
+ std::lock_guard<std::mutex> lk {rxMutex_};
+ rxQueue_.pop_front();
+ }
+
+ // Cookie may be sent on multiple network packets
+ // So we retry until we get a valid cookie.
+ // To protect against a flood attack we delay each retry after FLOOD_THRESHOLD rx bytes.
+ if (cookie_count_ >= FLOOD_THRESHOLD) {
+ if (params_.logger)
+ params_.logger->warn("[TLS] flood threshold reach (retry in {})", FLOOD_PAUSE);
+ dump_io_stats();
+ std::this_thread::sleep_for(FLOOD_PAUSE); // flood attack protection
+ }
+ return state;
+ }
+
+ if (params_.logger)
+ params_.logger->d("[TLS] cookie ok");
+
+ return setupServer();
+}
+
+TlsSessionState
+TlsSession::TlsSessionImpl::handleStateHandshake(TlsSessionState state)
+{
+ int ret;
+ size_t retry_count = 0;
+ if (params_.logger)
+ params_.logger->debug("[TLS] handshake");
+ do {
+ ret = gnutls_handshake(session_);
+ } while ((ret == GNUTLS_E_INTERRUPTED or ret == GNUTLS_E_AGAIN)
+ and ++retry_count < HANDSHAKE_MAX_RETRY
+ and state_.load() != TlsSessionState::SHUTDOWN);
+ if (retry_count > 0) {
+ if (params_.logger)
+ params_.logger->error("[TLS] handshake retried count: {}", retry_count);
+ }
+
+ // Stop on fatal error
+ if (gnutls_error_is_fatal(ret) || state_.load() == TlsSessionState::SHUTDOWN) {
+ if (params_.logger)
+ params_.logger->error("[TLS] handshake failed: {:s}", gnutls_strerror(ret));
+ return TlsSessionState::SHUTDOWN;
+ }
+
+ // Continue handshaking on non-fatal error
+ if (ret != GNUTLS_E_SUCCESS) {
+ // TODO: handle GNUTLS_E_LARGE_PACKET (MTU must be lowered)
+ if (ret != GNUTLS_E_AGAIN)
+ if (params_.logger)
+ params_.logger->debug("[TLS] non-fatal handshake error: {:s}", gnutls_strerror(ret));
+ return state;
+ }
+
+ // Safe-Renegotiation status shall always be true to prevent MiM attack
+ // Following https://www.gnutls.org/manual/html_node/Safe-renegotiation.html
+ // "Unlike TLS 1.2, the server is not allowed to change identities"
+ // So, we don't have to check the status if we are the client
+ bool isTLS1_3 = gnutls_protocol_get_version(session_) == GNUTLS_TLS1_3;
+ if (!isTLS1_3 || (isTLS1_3 && isServer_)) {
+ if (!gnutls_safe_renegotiation_status(session_)) {
+ if (params_.logger)
+ params_.logger->error("[TLS] server identity changed! MiM attack?");
+ return TlsSessionState::SHUTDOWN;
+ }
+ }
+
+ auto desc = gnutls_session_get_desc(session_);
+ if (params_.logger)
+ params_.logger->debug("[TLS] session established: {:s}", desc);
+ gnutls_free(desc);
+
+ // Anonymous connection? rehandshake immediately with certificate authentification forced
+ auto cred = gnutls_auth_get_type(session_);
+ if (cred == GNUTLS_CRD_ANON) {
+ if (params_.logger)
+ params_.logger->debug("[TLS] renogotiate with certificate authentification");
+
+ // Re-setup TLS algorithms priority list with only certificate based cipher suites
+ ret = gnutls_priority_set_direct(session_,
+ transport_ and transport_->isReliable()
+ ? TLS_CERT_PRIORITY_STRING
+ : DTLS_CERT_PRIORITY_STRING,
+ nullptr);
+ if (ret != GNUTLS_E_SUCCESS) {
+ if (params_.logger)
+ params_.logger->error("[TLS] session TLS cert-only priority set failed: {:s}", gnutls_strerror(ret));
+ return TlsSessionState::SHUTDOWN;
+ }
+
+ // remove anon credentials and re-enable certificate ones
+ gnutls_credentials_clear(session_);
+ ret = gnutls_credentials_set(session_, GNUTLS_CRD_CERTIFICATE, *xcred_);
+ if (ret != GNUTLS_E_SUCCESS) {
+ if (params_.logger)
+ params_.logger->error("[TLS] session credential set failed: {:s}", gnutls_strerror(ret));
+ return TlsSessionState::SHUTDOWN;
+ }
+
+ return state; // handshake
+
+ } else if (cred != GNUTLS_CRD_CERTIFICATE) {
+ if (params_.logger)
+ params_.logger->error("[TLS] spurious session credential ({})", cred);
+ return TlsSessionState::SHUTDOWN;
+ }
+
+ // Aware about certificates updates
+ if (callbacks_.onCertificatesUpdate) {
+ unsigned int remote_count;
+ auto local = gnutls_certificate_get_ours(session_);
+ auto remote = gnutls_certificate_get_peers(session_, &remote_count);
+ callbacks_.onCertificatesUpdate(local, remote, remote_count);
+ }
+
+ return transport_ and transport_->isReliable() ? TlsSessionState::ESTABLISHED
+ : TlsSessionState::MTU_DISCOVERY;
+}
+
+TlsSessionState
+TlsSession::TlsSessionImpl::handleStateMtuDiscovery([[maybe_unused]] TlsSessionState state)
+{
+ if (!transport_) {
+ if (params_.logger)
+ params_.logger->w("No transport available when discovering the MTU");
+ return TlsSessionState::SHUTDOWN;
+ }
+ mtuProbe_ = transport_->maxPayload();
+ assert(mtuProbe_ >= MIN_MTU);
+ MTUS_ = {MIN_MTU, std::max((mtuProbe_ + MIN_MTU) / 2, MIN_MTU), mtuProbe_};
+
+ // retrocompatibility check
+ if (gnutls_heartbeat_allowed(session_, GNUTLS_HB_LOCAL_ALLOWED_TO_SEND) == 1) {
+ if (!isServer_) {
+ pathMtuHeartbeat();
+ if (state_ == TlsSessionState::SHUTDOWN) {
+ if (params_.logger)
+ params_.logger->e("[TLS] session destroyed while performing PMTUD, shuting down");
+ return TlsSessionState::SHUTDOWN;
+ }
+ pmtudOver_ = true;
+ }
+ } else {
+ if (params_.logger)
+ params_.logger->e("[TLS] PEER HEARTBEAT DISABLED: using transport MTU value ", mtuProbe_);
+ pmtudOver_ = true;
+ }
+
+ gnutls_dtls_set_mtu(session_, mtuProbe_);
+ maxPayload_ = gnutls_dtls_get_data_mtu(session_);
+
+ if (pmtudOver_) {
+ if (params_.logger)
+ params_.logger->d("[TLS] maxPayload: ", maxPayload_.load());
+ if (!initFromRecordState())
+ return TlsSessionState::SHUTDOWN;
+ }
+
+ return TlsSessionState::ESTABLISHED;
+}
+
+/*
+ * Path MTU discovery heuristic
+ * heuristic description:
+ * The two members of the current tls connection will exchange dtls heartbeat messages
+ * of increasing size until the heartbeat times out which will be considered as a packet
+ * drop from the network due to the size of the packet. (one retry to test for a buffer issue)
+ * when timeout happens or all the values have been tested, the mtu will be returned.
+ * In case of unexpected error the first (and minimal) value of the mtu array
+ */
+void
+TlsSession::TlsSessionImpl::pathMtuHeartbeat()
+{
+ if (params_.logger)
+ params_.logger->debug("[TLS] PMTUD: starting probing with {} of retransmission timeout", HEARTBEAT_RETRANS_TIMEOUT);
+
+ gnutls_heartbeat_set_timeouts(session_,
+ HEARTBEAT_RETRANS_TIMEOUT.count(),
+ HEARTBEAT_TOTAL_TIMEOUT.count());
+
+ int errno_send = GNUTLS_E_SUCCESS;
+ int mtuOffset = 0;
+
+ // when the remote (server) has a IPV6 interface selected by ICE, and local (client) has a IPV4
+ // selected, the path MTU discovery triggers errors for packets too big on server side because
+ // of different IP headers overhead. Hence we have to signal to the TLS session to reduce the
+ // MTU on client size accordingly.
+ if (transport_ and transport_->localAddr().isIpv4() and transport_->remoteAddr().isIpv6()) {
+ mtuOffset = ASYMETRIC_TRANSPORT_MTU_OFFSET;
+ if (params_.logger)
+ params_.logger->w("[TLS] local/remote IP protocol version not alike, use an MTU offset of {} bytes to compensate", ASYMETRIC_TRANSPORT_MTU_OFFSET);
+ }
+
+ mtuProbe_ = MTUS_[0];
+
+ for (auto mtu : MTUS_) {
+ gnutls_dtls_set_mtu(session_, mtu);
+ auto data_mtu = gnutls_dtls_get_data_mtu(session_);
+ if (params_.logger)
+ params_.logger->debug("[TLS] PMTUD: mtu {}, payload {}", mtu, data_mtu);
+ auto bytesToSend = data_mtu - mtuOffset - 3; // want to know why -3? ask gnutls!
+
+ do {
+ errno_send = gnutls_heartbeat_ping(session_,
+ bytesToSend,
+ HEARTBEAT_TRIES,
+ GNUTLS_HEARTBEAT_WAIT);
+ } while (errno_send == GNUTLS_E_AGAIN
+ || (errno_send == GNUTLS_E_INTERRUPTED && state_ != TlsSessionState::SHUTDOWN));
+
+ if (errno_send != GNUTLS_E_SUCCESS) {
+ if (params_.logger)
+ params_.logger->debug("[TLS] PMTUD: mtu {} [FAILED]", mtu);
+ break;
+ }
+
+ mtuProbe_ = mtu;
+ if (params_.logger)
+ params_.logger->debug("[TLS] PMTUD: mtu {} [OK]", mtu);
+ }
+
+ if (errno_send == GNUTLS_E_TIMEDOUT) { // timeout is considered as a packet loss, then the good
+ // mtu is the precedent
+ if (mtuProbe_ == MTUS_[0]) {
+ if (params_.logger)
+ params_.logger->warn("[TLS] PMTUD: no response on first ping, using minimal MTU value {}", mtuProbe_);
+ } else {
+ if (params_.logger)
+ params_.logger->warn("[TLS] PMTUD: timed out, using last working mtu {}", mtuProbe_);
+ }
+ } else if (errno_send != GNUTLS_E_SUCCESS) {
+ if (params_.logger)
+ params_.logger->error("[TLS] PMTUD: failed with gnutls error '{}'", gnutls_strerror(errno_send));
+ } else {
+ if (params_.logger)
+ params_.logger->debug("[TLS] PMTUD: reached maximal value");
+ }
+}
+
+void
+TlsSession::TlsSessionImpl::handleDataPacket(std::vector<ValueType>&& buf, uint64_t pkt_seq)
+{
+ // Check for a valid seq. num. delta
+ int64_t seq_delta = pkt_seq - lastRxSeq_;
+ if (seq_delta > 0) {
+ lastRxSeq_ = pkt_seq;
+ } else {
+ // too old?
+ if (seq_delta <= -MISS_ORDERING_LIMIT) {
+ if (params_.logger)
+ params_.logger->warn("[TLS] drop old pkt: 0x{:x}", pkt_seq);
+ return;
+ }
+
+ // No duplicate check as DTLS prevents that for us (replay protection)
+
+ // accept Out-Of-Order pkt - will be reordered by queue flush operation
+ if (params_.logger)
+ params_.logger->warn("[TLS] OOO pkt: 0x{:x}", pkt_seq);
+ }
+
+ std::unique_lock<std::mutex> lk {rxMutex_};
+ auto now = clock::now();
+ if (reorderBuffer_.empty())
+ lastReadTime_ = now;
+ reorderBuffer_.emplace(pkt_seq, std::move(buf));
+ nextFlush_.emplace_back(now + RX_OOO_TIMEOUT);
+ rxCv_.notify_one();
+ // Try to flush right now as a new packet is available
+ flushRxQueue(lk);
+}
+
+///
+/// Reorder and push received packet to upper layer
+///
+/// \note This method must be called continuously, faster than RX_OOO_TIMEOUT
+///
+void
+TlsSession::TlsSessionImpl::flushRxQueue(std::unique_lock<std::mutex>& lk)
+{
+ // RAII bool swap
+ class GuardedBoolSwap
+ {
+ public:
+ explicit GuardedBoolSwap(bool& var)
+ : var_ {var}
+ {
+ var_ = !var_;
+ }
+ ~GuardedBoolSwap() { var_ = !var_; }
+
+ private:
+ bool& var_;
+ };
+
+ if (reorderBuffer_.empty())
+ return;
+
+ // Prevent re-entrant access as the callbacks_.onRxData() is called in unprotected region
+ if (flushProcessing_)
+ return;
+
+ GuardedBoolSwap swap_flush_processing {flushProcessing_};
+
+ auto now = clock::now();
+
+ auto item = std::begin(reorderBuffer_);
+ auto next_offset = item->first;
+
+ // Wait for next continuous packet until timeout
+ if ((now - lastReadTime_) >= RX_OOO_TIMEOUT) {
+ // OOO packet timeout - consider waited packets as lost
+ if (auto lost = next_offset - gapOffset_) {
+ if (params_.logger)
+ params_.logger->warn("[TLS] {:d} lost since 0x{:x}", lost, gapOffset_);
+ } else if (params_.logger)
+ params_.logger->warn("[TLS] slow flush");
+ } else if (next_offset != gapOffset_)
+ return;
+
+ // Loop on offset-ordered received packet until a discontinuity in sequence number
+ while (item != std::end(reorderBuffer_) and item->first <= next_offset) {
+ auto pkt_offset = item->first;
+ auto pkt = std::move(item->second);
+
+ // Remove item before unlocking to not trash the item' relationship
+ next_offset = pkt_offset + 1;
+ item = reorderBuffer_.erase(item);
+
+ if (callbacks_.onRxData) {
+ lk.unlock();
+ callbacks_.onRxData(std::move(pkt));
+ lk.lock();
+ }
+ }
+
+ gapOffset_ = std::max(gapOffset_, next_offset);
+ lastReadTime_ = now;
+}
+
+TlsSessionState
+TlsSession::TlsSessionImpl::handleStateEstablished(TlsSessionState state)
+{
+ // Nothing to do in reliable mode, so just wait for state change
+ if (transport_ and transport_->isReliable()) {
+ auto disconnected = [this]() -> bool {
+ return state_.load() != TlsSessionState::ESTABLISHED
+ or newState_.load() != TlsSessionState::NONE;
+ };
+ std::unique_lock<std::mutex> lk(stateMutex_);
+ stateCondition_.wait(lk, disconnected);
+ auto oldState = state_.load();
+ if (oldState == TlsSessionState::ESTABLISHED) {
+ auto newState = newState_.load();
+ if (newState != TlsSessionState::NONE) {
+ newState_ = TlsSessionState::NONE;
+ return newState;
+ }
+ }
+ return oldState;
+ }
+
+ // block until rx packet or state change
+ {
+ std::unique_lock<std::mutex> lk {rxMutex_};
+ if (nextFlush_.empty())
+ rxCv_.wait(lk, [this] {
+ return state_ != TlsSessionState::ESTABLISHED or not rxQueue_.empty()
+ or not nextFlush_.empty();
+ });
+ else
+ rxCv_.wait_until(lk, nextFlush_.front(), [this] {
+ return state_ != TlsSessionState::ESTABLISHED or !rxQueue_.empty();
+ });
+ state = state_.load();
+ if (state != TlsSessionState::ESTABLISHED)
+ return state;
+
+ if (not nextFlush_.empty()) {
+ auto now = clock::now();
+ if (nextFlush_.front() <= now) {
+ while (nextFlush_.front() <= now)
+ nextFlush_.pop_front();
+ flushRxQueue(lk);
+ return state;
+ }
+ }
+ }
+
+ std::array<uint8_t, 8> seq;
+ rawPktBuf_.resize(RX_MAX_SIZE);
+ auto ret = gnutls_record_recv_seq(session_, rawPktBuf_.data(), rawPktBuf_.size(), &seq[0]);
+
+ if (ret > 0) {
+ // Are we in PMTUD phase?
+ if (!pmtudOver_) {
+ mtuProbe_ = MTUS_[std::max(0, hbPingRecved_ - 1)];
+ gnutls_dtls_set_mtu(session_, mtuProbe_);
+ maxPayload_ = gnutls_dtls_get_data_mtu(session_);
+ pmtudOver_ = true;
+ if (params_.logger)
+ params_.logger->debug("[TLS] maxPayload: {}", maxPayload_.load());
+
+ if (!initFromRecordState(-1))
+ return TlsSessionState::SHUTDOWN;
+ }
+
+ rawPktBuf_.resize(ret);
+ handleDataPacket(std::move(rawPktBuf_), array2uint(seq));
+ // no state change
+ } else if (ret == GNUTLS_E_HEARTBEAT_PING_RECEIVED) {
+ if (params_.logger)
+ params_.logger->d("[TLS] PMTUD: ping received sending pong");
+ auto errno_send = gnutls_heartbeat_pong(session_, 0);
+
+ if (errno_send != GNUTLS_E_SUCCESS) {
+ if (params_.logger)
+ params_.logger->e("[TLS] PMTUD: failed on pong with error %d: %s",
+ errno_send,
+ gnutls_strerror(errno_send));
+ } else {
+ ++hbPingRecved_;
+ }
+ // no state change
+ } else if (ret == 0) {
+ if (params_.logger)
+ params_.logger->d("[TLS] eof");
+ state = TlsSessionState::SHUTDOWN;
+ } else if (ret == GNUTLS_E_REHANDSHAKE) {
+ if (params_.logger)
+ params_.logger->d("[TLS] re-handshake");
+ state = TlsSessionState::HANDSHAKE;
+ } else if (gnutls_error_is_fatal(ret)) {
+ if (params_.logger)
+ params_.logger->e("[TLS] fatal error in recv: %s", gnutls_strerror(ret));
+ state = TlsSessionState::SHUTDOWN;
+ } // else non-fatal error... let's continue
+
+ return state;
+}
+
+TlsSessionState
+TlsSession::TlsSessionImpl::handleStateShutdown(TlsSessionState state)
+{
+ if (params_.logger)
+ params_.logger->d("[TLS] shutdown");
+
+ // Stop ourself
+ thread_.stop();
+ return state;
+}
+
+void
+TlsSession::TlsSessionImpl::process()
+{
+ auto old_state = state_.load();
+ auto new_state = fsmHandlers_[old_state](old_state);
+
+ // update state_ with taking care for external state change
+ if (not std::atomic_compare_exchange_strong(&state_, &old_state, new_state))
+ new_state = old_state;
+
+ if (old_state != new_state)
+ stateCondition_.notify_all();
+
+ if (old_state != new_state and callbacks_.onStateChange)
+ callbacks_.onStateChange(new_state);
+}
+
+//==============================================================================
+
+TlsSession::TlsSession(std::unique_ptr<SocketType>&& transport,
+ const TlsParams& params,
+ const TlsSessionCallbacks& cbs,
+ bool anonymous)
+
+ : pimpl_ {std::make_unique<TlsSessionImpl>(std::move(transport), params, cbs, anonymous)}
+{}
+
+TlsSession::~TlsSession() {}
+
+bool
+TlsSession::isInitiator() const
+{
+ return !pimpl_->isServer_;
+}
+
+bool
+TlsSession::isReliable() const
+{
+ if (!pimpl_->transport_)
+ return false;
+ return pimpl_->transport_->isReliable();
+}
+
+int
+TlsSession::maxPayload() const
+{
+ if (pimpl_->state_ == TlsSessionState::SHUTDOWN)
+ throw std::runtime_error("Getting maxPayload from non-valid TLS session");
+ if (!pimpl_->transport_)
+ return 0;
+ return pimpl_->transport_->maxPayload();
+}
+
+// Called by anyone to stop the connection and the FSM thread
+void
+TlsSession::shutdown()
+{
+ pimpl_->newState_ = TlsSessionState::SHUTDOWN;
+ pimpl_->stateCondition_.notify_all();
+ pimpl_->rxCv_.notify_one(); // unblock waiting FSM
+}
+
+std::size_t
+TlsSession::write(const ValueType* data, std::size_t size, std::error_code& ec)
+{
+ return pimpl_->send(data, size, ec);
+}
+
+std::size_t
+TlsSession::read(ValueType* data, std::size_t size, std::error_code& ec)
+{
+ std::errc error;
+
+ if (pimpl_->state_ != TlsSessionState::ESTABLISHED) {
+ ec = std::make_error_code(std::errc::broken_pipe);
+ return 0;
+ }
+
+ while (true) {
+ ssize_t ret;
+ {
+ std::lock_guard<std::mutex> lk(pimpl_->sessionReadMutex_);
+ if (!pimpl_->session_)
+ return 0;
+ ret = gnutls_record_recv(pimpl_->session_, data, size);
+ }
+ if (ret > 0) {
+ ec.clear();
+ return ret;
+ }
+
+ std::lock_guard<std::mutex> lk(pimpl_->stateMutex_);
+ if (ret == 0) {
+ if (pimpl_) {
+ if (pimpl_->params_.logger)
+ pimpl_->params_.logger->d("[TLS] eof");
+ pimpl_->newState_ = TlsSessionState::SHUTDOWN;
+ pimpl_->stateCondition_.notify_all();
+ pimpl_->rxCv_.notify_one(); // unblock waiting FSM
+ }
+ error = std::errc::broken_pipe;
+ break;
+ } else if (ret == GNUTLS_E_REHANDSHAKE) {
+ if (pimpl_->params_.logger)
+ pimpl_->params_.logger->d("[TLS] re-handshake");
+ pimpl_->newState_ = TlsSessionState::HANDSHAKE;
+ pimpl_->rxCv_.notify_one(); // unblock waiting FSM
+ pimpl_->stateCondition_.notify_all();
+ } else if (gnutls_error_is_fatal(ret)) {
+ if (pimpl_ && pimpl_->state_ != TlsSessionState::SHUTDOWN) {
+ if (pimpl_->params_.logger)
+ pimpl_->params_.logger->e("[TLS] fatal error in recv: %s", gnutls_strerror(ret));
+ pimpl_->newState_ = TlsSessionState::SHUTDOWN;
+ pimpl_->stateCondition_.notify_all();
+ pimpl_->rxCv_.notify_one(); // unblock waiting FSM
+ }
+ error = std::errc::io_error;
+ break;
+ }
+ }
+
+ ec = std::make_error_code(error);
+ return 0;
+}
+
+void
+TlsSession::waitForReady(const duration& timeout)
+{
+ auto ready = [this]() -> bool {
+ auto state = pimpl_->state_.load();
+ return state == TlsSessionState::ESTABLISHED or state == TlsSessionState::SHUTDOWN;
+ };
+ std::unique_lock<std::mutex> lk(pimpl_->stateMutex_);
+ if (timeout == duration::zero())
+ pimpl_->stateCondition_.wait(lk, ready);
+ else
+ pimpl_->stateCondition_.wait_for(lk, timeout, ready);
+
+ if (!ready())
+ throw std::logic_error("Invalid state in TlsSession::waitForReady: "
+ + std::to_string((int) pimpl_->state_.load()));
+}
+
+int
+TlsSession::waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const
+{
+ if (!pimpl_->transport_) {
+ ec = std::make_error_code(std::errc::broken_pipe);
+ return -1;
+ }
+ if (!pimpl_->transport_->waitForData(timeout, ec))
+ return 0;
+ return 1;
+}
+
+std::shared_ptr<dht::crypto::Certificate>
+TlsSession::peerCertificate() const
+{
+ return pimpl_->pCert_;
+}
+
+const std::shared_ptr<dht::log::Logger>&
+TlsSession::logger() const
+{
+ return pimpl_->params_.logger;
+}
+
+} // namespace tls
+} // namespace jami
diff --git a/src/sip_utils.h b/src/sip_utils.h
new file mode 100644
index 0000000..6460b70
--- /dev/null
+++ b/src/sip_utils.h
@@ -0,0 +1,173 @@
+/*
+ * Copyright (C) 2004-2023 Savoir-faire Linux Inc.
+ *
+ * Author: Tristan Matthews <tristan.matthews@savoirfairelinux.com>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#pragma once
+
+#include "ip_utils.h"
+
+#include <utility>
+#include <string>
+#include <vector>
+#include <cstring> // strcmp
+
+#include <pjsip/sip_msg.h>
+#include <pjlib.h>
+#include <pj/pool.h>
+#include <pjsip/sip_endpoint.h>
+#include <pjsip/sip_dialog.h>
+
+namespace jami {
+namespace sip_utils {
+
+using namespace std::literals;
+
+// SIP methods. Only list methods that need to be explicitly
+// handled
+
+namespace SIP_METHODS {
+constexpr std::string_view MESSAGE = "MESSAGE"sv;
+constexpr std::string_view INFO = "INFO"sv;
+constexpr std::string_view OPTIONS = "OPTIONS"sv;
+constexpr std::string_view PUBLISH = "PUBLISH"sv;
+constexpr std::string_view REFER = "REFER"sv;
+constexpr std::string_view NOTIFY = "NOTIFY"sv;
+} // namespace SIP_METHODS
+
+static constexpr int DEFAULT_SIP_PORT {5060};
+static constexpr int DEFAULT_SIP_TLS_PORT {5061};
+static constexpr int DEFAULT_AUTO_SELECT_PORT {0};
+
+/// PjsipErrorCategory - a PJSIP error category for std::error_code
+class PjsipErrorCategory final : public std::error_category
+{
+public:
+ const char* name() const noexcept override { return "pjsip"; }
+ std::string message(int condition) const override;
+};
+
+/// PJSIP related exception
+/// Based on std::system_error with code() returning std::error_code with PjsipErrorCategory category
+class PjsipFailure : public std::system_error
+{
+private:
+ static constexpr const char* what_ = "PJSIP call failed";
+
+public:
+ PjsipFailure()
+ : std::system_error(std::error_code(PJ_EUNKNOWN, PjsipErrorCategory()), what_)
+ {}
+
+ explicit PjsipFailure(pj_status_t status)
+ : std::system_error(std::error_code(status, PjsipErrorCategory()), what_)
+ {}
+};
+
+
+/**
+ * Helper function to parser header from incoming sip messages
+ * @return Header from SIP message
+ */
+/*std::string fetchHeaderValue(pjsip_msg* msg, const std::string& field);
+
+pjsip_route_hdr* createRouteSet(const std::string& route, pj_pool_t* hdr_pool);
+
+std::string_view stripSipUriPrefix(std::string_view sipUri);
+
+std::string parseDisplayName(const pjsip_name_addr* sip_name_addr);
+std::string parseDisplayName(const pjsip_from_hdr* header);
+std::string parseDisplayName(const pjsip_contact_hdr* header);
+
+std::string_view getHostFromUri(std::string_view sipUri);
+
+void addContactHeader(const std::string& contact, pjsip_tx_data* tdata);
+void addUserAgentHeader(const std::string& userAgent, pjsip_tx_data* tdata);
+std::string_view getPeerUserAgent(const pjsip_rx_data* rdata);
+std::vector<std::string> getPeerAllowMethods(const pjsip_rx_data* rdata);
+void logMessageHeaders(const pjsip_hdr* hdr_list);*/
+
+std::string_view sip_strerror(pj_status_t code);
+
+// Helper function that return a constant pj_str_t from an array of any types
+// that may be statically casted into char pointer.
+// Per convention, the input array is supposed to be null terminated.
+template<typename T, std::size_t N>
+constexpr const pj_str_t
+CONST_PJ_STR(T (&a)[N]) noexcept
+{
+ return {const_cast<char*>(a), N - 1};
+}
+
+inline const pj_str_t
+CONST_PJ_STR(const std::string& str) noexcept
+{
+ return {const_cast<char*>(str.c_str()), (pj_ssize_t) str.size()};
+}
+
+inline constexpr pj_str_t
+CONST_PJ_STR(const std::string_view& str) noexcept
+{
+ return {const_cast<char*>(str.data()), (pj_ssize_t) str.size()};
+}
+
+inline constexpr std::string_view
+as_view(const pj_str_t& str) noexcept
+{
+ return {str.ptr, (size_t) str.slen};
+}
+
+// PJSIP dialog locking in RAII way
+// Usage: declare local variable like this: sip_utils::PJDialogLock lock {dialog};
+// The lock is kept until the local variable is deleted
+class PJDialogLock
+{
+public:
+ explicit PJDialogLock(pjsip_dialog* dialog)
+ : dialog_(dialog)
+ {
+ pjsip_dlg_inc_lock(dialog_);
+ }
+
+ ~PJDialogLock() { pjsip_dlg_dec_lock(dialog_); }
+
+private:
+ PJDialogLock(const PJDialogLock&) = delete;
+ PJDialogLock& operator=(const PJDialogLock&) = delete;
+ pjsip_dialog* dialog_ {nullptr};
+};
+
+// Helper on PJSIP memory pool allocation from endpoint
+// This encapsulate the allocated memory pool inside a unique_ptr
+static inline std::unique_ptr<pj_pool_t, decltype(pj_pool_release)&>
+smart_alloc_pool(pjsip_endpoint* endpt, const char* const name, pj_size_t initial, pj_size_t inc)
+{
+ auto pool = pjsip_endpt_create_pool(endpt, name, initial, inc);
+ if (not pool)
+ throw std::bad_alloc();
+ return std::unique_ptr<pj_pool_t, decltype(pj_pool_release)&>(pool, pj_pool_release);
+}
+
+void sockaddr_to_host_port(pj_pool_t* pool, pjsip_host_port* host_port, const pj_sockaddr* addr);
+
+static constexpr int POOL_TP_INIT {512};
+static constexpr int POOL_TP_INC {512};
+static constexpr int TRANSPORT_INFO_LENGTH {64};
+
+} // namespace sip_utils
+} // namespace jami
diff --git a/src/string_utils.cpp b/src/string_utils.cpp
new file mode 100644
index 0000000..934ff23
--- /dev/null
+++ b/src/string_utils.cpp
@@ -0,0 +1,167 @@
+/*
+ * Copyright (C) 2004-2023 Savoir-faire Linux Inc.
+ *
+ * Author: Tristan Matthews <tristan.matthews@savoirfairelinux.com>
+ * Author: Adrien Béraud <adrien.beraud@savoirfairelinux.com>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#include "string_utils.h"
+
+#include <fmt/core.h>
+#include <fmt/ranges.h>
+
+#include <sstream>
+#include <cctype>
+#include <algorithm>
+#include <ostream>
+#include <iomanip>
+#include <stdexcept>
+#include <ios>
+#include <charconv>
+#include <string_view>
+#ifdef _WIN32
+#include <windows.h>
+#include <oleauto.h>
+#endif
+
+#include <ciso646> // fix windows compiler bug
+
+namespace jami {
+
+#ifdef _WIN32
+std::wstring
+to_wstring(const std::string& str, int codePage)
+{
+ int srcLength = (int) str.length();
+ int requiredSize = MultiByteToWideChar(codePage, 0, str.c_str(), srcLength, nullptr, 0);
+ if (!requiredSize) {
+ throw std::runtime_error("Can't convert string to wstring");
+ }
+ std::wstring result((size_t) requiredSize, 0);
+ if (!MultiByteToWideChar(codePage, 0, str.c_str(), srcLength, &(*result.begin()), requiredSize)) {
+ throw std::runtime_error("Can't convert string to wstring");
+ }
+ return result;
+}
+
+std::string
+to_string(const std::wstring& wstr, int codePage)
+{
+ int srcLength = (int) wstr.length();
+ int requiredSize = WideCharToMultiByte(codePage, 0, wstr.c_str(), srcLength, nullptr, 0, 0, 0);
+ if (!requiredSize) {
+ throw std::runtime_error("Can't convert wstring to string");
+ }
+ std::string result((size_t) requiredSize, 0);
+ if (!WideCharToMultiByte(
+ codePage, 0, wstr.c_str(), srcLength, &(*result.begin()), requiredSize, 0, 0)) {
+ throw std::runtime_error("Can't convert wstring to string");
+ }
+ return result;
+}
+#endif
+
+std::string
+to_string(double value)
+{
+ char buf[64];
+ int len = snprintf(buf, sizeof(buf), "%-.*G", 16, value);
+ if (len <= 0)
+ throw std::invalid_argument {"can't parse double"};
+ return {buf, (size_t) len};
+}
+
+std::string
+to_hex_string(uint64_t id)
+{
+ return fmt::format("{:016x}", id);
+}
+
+uint64_t
+from_hex_string(const std::string& str)
+{
+ uint64_t id;
+ if (auto [p, ec] = std::from_chars(str.data(), str.data()+str.size(), id, 16); ec != std::errc()) {
+ throw std::invalid_argument("Can't parse id: " + str);
+ }
+ return id;
+}
+
+std::string_view
+trim(std::string_view s)
+{
+ auto wsfront = std::find_if_not(s.cbegin(), s.cend(), [](int c) { return std::isspace(c); });
+ return std::string_view(&*wsfront, std::find_if_not(s.rbegin(),
+ std::string_view::const_reverse_iterator(wsfront),
+ [](int c) { return std::isspace(c); })
+ .base() - wsfront);
+}
+
+std::vector<unsigned>
+split_string_to_unsigned(std::string_view str, char delim)
+{
+ std::vector<unsigned> output;
+ for (auto first = str.data(), second = str.data(), last = first + str.size(); second != last && first != last; first = second + 1) {
+ second = std::find(first, last, delim);
+ if (first != second) {
+ unsigned result;
+ auto [p, ec] = std::from_chars(first, second, result);
+ if (ec == std::errc())
+ output.emplace_back(result);
+ }
+ }
+ return output;
+}
+
+void
+string_replace(std::string& str, const std::string& from, const std::string& to)
+{
+ size_t start_pos = 0;
+ while ((start_pos = str.find(from, start_pos)) != std::string::npos) {
+ str.replace(start_pos, from.length(), to);
+ start_pos += to.length(); // Handles case where 'to' is a substring of 'from'
+ }
+}
+
+std::string_view
+string_remove_suffix(std::string_view str, char separator)
+{
+ auto it = str.find(separator);
+ if (it != std::string_view::npos)
+ str = str.substr(0, it);
+ return str;
+}
+
+std::string
+string_join(const std::set<std::string>& set, std::string_view separator)
+{
+ return fmt::format("{}", fmt::join(set, separator));
+}
+
+std::set<std::string>
+string_split_set(std::string& str, std::string_view separator)
+{
+ std::set<std::string> output;
+ for (auto first = str.data(), second = str.data(), last = first + str.size(); second != last && first != last; first = second + 1) {
+ second = std::find_first_of(first, last, std::cbegin(separator), std::cend(separator));
+ if (first != second)
+ output.emplace(first, second - first);
+ }
+ return output;
+}
+
+} // namespace jami
diff --git a/src/tracepoint/trace-tools.h b/src/tracepoint/trace-tools.h
new file mode 100644
index 0000000..ccd65cd
--- /dev/null
+++ b/src/tracepoint/trace-tools.h
@@ -0,0 +1,65 @@
+/*
+ * Copyright (C) 2022-2023 Savoir-faire Linux Inc.
+ *
+ * Author: Olivier Dion <olivier.dion@savoirfairelinux.com>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#pragma once
+
+#ifdef ENABLE_TRACEPOINTS
+/*
+ * GCC Only. We use these instead of classic __FILE__ and __LINE__ because
+ * these are evaluated where invoked and not at expansion time. See GCC manual.
+ */
+# define CURRENT_FILENAME() __builtin_FILE()
+# define CURRENT_LINE() __builtin_LINE()
+#else
+# define CURRENT_FILENAME() ""
+# define CURRENT_LINE() 0
+#endif
+
+#ifdef HAVE_CXXABI_H
+#include <cxxabi.h>
+#include <string>
+
+template<typename T>
+std::string demangle()
+{
+ int err;
+ char *raw;
+ std::string ret;
+
+ raw = abi::__cxa_demangle(typeid(T).name(), 0, 0, &err);
+
+ if (0 == err) {
+ ret = raw;
+ } else {
+ ret = typeid(T).name();
+ }
+
+ std::free(raw);
+
+ return ret;
+}
+
+#else
+template<typename T>
+std::string demangle()
+{
+ return typeid(T).name();
+}
+#endif
diff --git a/src/tracepoint/tracepoint-def.h b/src/tracepoint/tracepoint-def.h
new file mode 100644
index 0000000..ed584c2
--- /dev/null
+++ b/src/tracepoint/tracepoint-def.h
@@ -0,0 +1,237 @@
+#ifdef ENABLE_TRACEPOINTS
+#ifdef HAVE_CONFIG_H
+#include "config.h"
+#endif
+
+#undef LTTNG_UST_TRACEPOINT_PROVIDER
+#define LTTNG_UST_TRACEPOINT_PROVIDER jami
+
+#undef LTTNG_UST_TRACEPOINT_INCLUDE
+#define LTTNG_UST_TRACEPOINT_INCLUDE "src/jami/tracepoint-def.h"
+
+#if !defined(TRACEPOINT_DEF_H) || defined(LTTNG_UST_TRACEPOINT_HEADER_MULTI_READ)
+#define TRACEPOINT_DEF_H
+
+#include <lttng/tracepoint.h>
+
+
+/*
+ * Use LTTNG_UST_TRACEPOINT_EVENT(), LTTNG_UST_TRACEPOINT_EVENT_CLASS(),
+ * LTTNG_UST_TRACEPOINT_EVENT_INSTANCE(), and LTTNG_UST_TRACEPOINT_LOGLEVEL()
+ * here.
+ */
+
+LTTNG_UST_TRACEPOINT_EVENT(
+ jami,
+ scheduled_executor_task_begin,
+ LTTNG_UST_TP_ARGS(
+ const char *, executor_name,
+ const char *, filename,
+ uint32_t, linum,
+ uint64_t, cookie
+ ),
+ LTTNG_UST_TP_FIELDS(
+ lttng_ust_field_string(executor, executor_name)
+ lttng_ust_field_string(source_filename, filename)
+ lttng_ust_field_integer(uint32_t, source_line, linum)
+ lttng_ust_field_integer(uint64_t, cookie, cookie)
+ )
+)
+
+LTTNG_UST_TRACEPOINT_EVENT(
+ jami,
+ scheduled_executor_task_end,
+ LTTNG_UST_TP_ARGS(uint64_t, cookie),
+ LTTNG_UST_TP_FIELDS(lttng_ust_field_integer(uint64_t, cookie, cookie))
+)
+
+LTTNG_UST_TRACEPOINT_EVENT(
+ jami,
+ ice_transport_context,
+ LTTNG_UST_TP_ARGS(
+ uint64_t, context
+ ),
+ LTTNG_UST_TP_FIELDS(
+ lttng_ust_field_integer(uint64_t, ice_context, context)
+ )
+)
+
+LTTNG_UST_TRACEPOINT_EVENT(
+ jami,
+ ice_transport_send,
+ LTTNG_UST_TP_ARGS(
+ uint64_t, context,
+ unsigned, component,
+ size_t, len,
+ const char*, remote_addr
+ ),
+ LTTNG_UST_TP_FIELDS(
+ lttng_ust_field_integer(uint64_t, ice_context, context)
+ lttng_ust_field_integer(unsigned, component, component)
+ lttng_ust_field_integer(size_t, packet_length, len)
+ lttng_ust_field_string(remote_addr, remote_addr)
+ )
+)
+
+LTTNG_UST_TRACEPOINT_EVENT(
+ jami,
+ ice_transport_send_status,
+ LTTNG_UST_TP_ARGS(
+ int, status
+ ),
+ LTTNG_UST_TP_FIELDS(
+ lttng_ust_field_integer(int, pj_status, status)
+ )
+)
+
+LTTNG_UST_TRACEPOINT_EVENT(
+ jami,
+ ice_transport_recv,
+ LTTNG_UST_TP_ARGS(
+ uint64_t, context,
+ unsigned, component,
+ size_t, len,
+ const char*, remote_addr
+ ),
+ LTTNG_UST_TP_FIELDS(
+ lttng_ust_field_integer(uint64_t, ice_context, context)
+ lttng_ust_field_integer(unsigned, component, component)
+ lttng_ust_field_integer(size_t, packet_length, len)
+ lttng_ust_field_string(remote_addr, remote_addr)
+ )
+)
+
+LTTNG_UST_TRACEPOINT_EVENT(
+ jami,
+ emit_signal,
+ LTTNG_UST_TP_ARGS(
+ const char*, signal_type
+ ),
+ LTTNG_UST_TP_FIELDS(
+ lttng_ust_field_string(signal_type, signal_type)
+
+ )
+)
+
+LTTNG_UST_TRACEPOINT_EVENT(
+ jami,
+ emit_signal_end,
+ LTTNG_UST_TP_ARGS(
+ ),
+ LTTNG_UST_TP_FIELDS(
+ )
+)
+
+LTTNG_UST_TRACEPOINT_EVENT(
+ jami,
+ emit_signal_begin_callback,
+ LTTNG_UST_TP_ARGS(
+ const char*, filename,
+ uint32_t, linum
+ ),
+ LTTNG_UST_TP_FIELDS(
+ lttng_ust_field_string(source_filename, filename)
+ lttng_ust_field_integer(uint32_t, source_line, linum)
+ )
+)
+
+LTTNG_UST_TRACEPOINT_EVENT(
+ jami,
+ emit_signal_end_callback,
+ LTTNG_UST_TP_ARGS(
+ ),
+ LTTNG_UST_TP_FIELDS(
+ )
+)
+
+LTTNG_UST_TRACEPOINT_EVENT(
+ jami,
+ audio_input_read_from_device_end,
+ LTTNG_UST_TP_ARGS(
+ const char*, id
+ ),
+ LTTNG_UST_TP_FIELDS(
+ lttng_ust_field_integer(uint64_t, id, strtoull(id, NULL, 16))
+ )
+)
+
+LTTNG_UST_TRACEPOINT_EVENT(
+ jami,
+ audio_layer_put_recorded_end,
+ LTTNG_UST_TP_ARGS(
+ ),
+ LTTNG_UST_TP_FIELDS(
+ )
+)
+
+LTTNG_UST_TRACEPOINT_EVENT(
+ jami,
+ audio_layer_get_to_play_end,
+ LTTNG_UST_TP_ARGS(
+ ),
+ LTTNG_UST_TP_FIELDS(
+ )
+)
+
+LTTNG_UST_TRACEPOINT_EVENT(
+ jami,
+ call_start,
+ LTTNG_UST_TP_ARGS(
+ const char*, id
+ ),
+ LTTNG_UST_TP_FIELDS(
+ lttng_ust_field_integer(uint64_t, id, strtoull(id, NULL, 16))
+ )
+)
+
+LTTNG_UST_TRACEPOINT_EVENT(
+ jami,
+ call_end,
+ LTTNG_UST_TP_ARGS(
+ const char*, id
+ ),
+ LTTNG_UST_TP_FIELDS(
+ lttng_ust_field_integer(uint64_t, id, strtoull(id, NULL, 16))
+ )
+)
+
+LTTNG_UST_TRACEPOINT_EVENT(
+ jami,
+ conference_begin,
+ LTTNG_UST_TP_ARGS(
+ const char*, id
+ ),
+ LTTNG_UST_TP_FIELDS(
+ lttng_ust_field_integer(uint64_t, id, strtoull(id, NULL, 16))
+ )
+)
+
+LTTNG_UST_TRACEPOINT_EVENT(
+ jami,
+ conference_end,
+ LTTNG_UST_TP_ARGS(
+ const char*, id
+ ),
+ LTTNG_UST_TP_FIELDS(
+ lttng_ust_field_integer(uint64_t, id, strtoull(id, NULL, 16))
+ )
+)
+
+LTTNG_UST_TRACEPOINT_EVENT(
+ jami,
+ conference_add_participant,
+ LTTNG_UST_TP_ARGS(
+ const char*, conference_id,
+ const char*, participant_id
+ ),
+ LTTNG_UST_TP_FIELDS(
+ lttng_ust_field_integer(uint64_t, id, strtoull(conference_id, NULL, 16))
+ lttng_ust_field_integer(uint64_t, participant_id, strtoull(participant_id, NULL, 16))
+ )
+)
+
+#endif /* TRACEPOINT_DEF_H */
+
+#include <lttng/tracepoint-event.h>
+
+#endif
diff --git a/src/tracepoint/tracepoint.c b/src/tracepoint/tracepoint.c
new file mode 100644
index 0000000..392fb0e
--- /dev/null
+++ b/src/tracepoint/tracepoint.c
@@ -0,0 +1,3 @@
+#define LTTNG_UST_TRACEPOINT_CREATE_PROBES
+#define LTTNG_UST_TRACEPOINT_DEFINE
+#include "./tracepoint.h"
diff --git a/src/tracepoint/tracepoint.h b/src/tracepoint/tracepoint.h
new file mode 100644
index 0000000..1e7f9a3
--- /dev/null
+++ b/src/tracepoint/tracepoint.h
@@ -0,0 +1,60 @@
+/*
+ * Copyright (C) 2022-2023 Savoir-faire Linux Inc.
+ *
+ * Author: Olivier Dion <olivier.dion@savoirfairelinux.com>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#pragma once
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wmissing-field-initializers"
+#include "tracepoint-def.h"
+#pragma GCC diagnostic pop
+
+#ifdef ENABLE_TRACEPOINTS
+
+# ifndef lttng_ust_tracepoint
+# define lttng_ust_tracepoint(...) tracepoint(__VA_ARGS__)
+# endif
+
+# ifndef lttng_ust_do_tracepoint
+# define lttng_ust_do_tracepoint(...) do_tracepoint(__VA_ARGS__)
+# endif
+
+# ifndef lttng_ust_tracepoint_enabled
+# define lttng_ust_tracepoint_enabled(...) tracepoint_enabled(__VA_ARGS__)
+# endif
+
+# define jami_tracepoint(tp_name, ...) \
+ lttng_ust_tracepoint(jami, tp_name __VA_OPT__(,) __VA_ARGS__)
+
+# define jami_tracepoint_if_enabled(tp_name, ...) \
+ do { \
+ if (lttng_ust_tracepoint_enabled(jami, tp_name)) { \
+ lttng_ust_do_tracepoint(jami, \
+ tp_name \
+ __VA_OPT__(,) \
+ __VA_ARGS__); \
+ } \
+ } \
+ while (0)
+
+#else
+
+# define jami_tracepoint(...) static_assert(true)
+# define jami_tracepoint_if_enabled(...) static_assert(true)
+
+#endif
diff --git a/src/transport/peer_channel.h b/src/transport/peer_channel.h
new file mode 100644
index 0000000..5f25123
--- /dev/null
+++ b/src/transport/peer_channel.h
@@ -0,0 +1,109 @@
+/*
+ * Copyright (C) 2004-2023 Savoir-faire Linux Inc.
+ * Authors: Adrien Béraud <adrien.beraud@savoirfairelinux.com>
+ * Guillaume Roguez <guillaume.roguez@savoirfairelinux.com>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+#pragma once
+
+#include <mutex>
+#include <condition_variable>
+#include <deque>
+#include <algorithm>
+
+namespace jami {
+
+class PeerChannel
+{
+public:
+ PeerChannel() {}
+ ~PeerChannel() { stop(); }
+ PeerChannel(PeerChannel&& o)
+ {
+ std::lock_guard<std::mutex> lk(o.mutex_);
+ stream_ = std::move(o.stream_);
+ stop_ = o.stop_;
+ o.cv_.notify_all();
+ }
+
+ template<typename Duration>
+ ssize_t wait(Duration timeout, std::error_code& ec)
+ {
+ std::unique_lock<std::mutex> lk {mutex_};
+ cv_.wait_for(lk, timeout, [this] { return stop_ or not stream_.empty(); });
+ if (stop_) {
+ ec = std::make_error_code(std::errc::broken_pipe);
+ return -1;
+ }
+ ec.clear();
+ return stream_.size();
+ }
+
+ ssize_t read(char* output, std::size_t size, std::error_code& ec)
+ {
+ std::unique_lock<std::mutex> lk {mutex_};
+ cv_.wait(lk, [this] { return stop_ or not stream_.empty(); });
+ if (stream_.size()) {
+ auto toRead = std::min(size, stream_.size());
+ if (toRead) {
+ auto endIt = stream_.begin() + toRead;
+ std::copy(stream_.begin(), endIt, output);
+ stream_.erase(stream_.begin(), endIt);
+ }
+ ec.clear();
+ return toRead;
+ }
+ if (stop_) {
+ ec.clear();
+ return 0;
+ }
+ ec = std::make_error_code(std::errc::resource_unavailable_try_again);
+ return -1;
+ }
+
+ ssize_t write(const char* data, std::size_t size, std::error_code& ec)
+ {
+ std::lock_guard<std::mutex> lk {mutex_};
+ if (stop_) {
+ ec = std::make_error_code(std::errc::broken_pipe);
+ return -1;
+ }
+ stream_.insert(stream_.end(), data, data + size);
+ cv_.notify_all();
+ ec.clear();
+ return size;
+ }
+
+ void stop() noexcept
+ {
+ std::lock_guard<std::mutex> lk {mutex_};
+ if (stop_)
+ return;
+ stop_ = true;
+ cv_.notify_all();
+ }
+
+private:
+ PeerChannel(const PeerChannel& o) = delete;
+ PeerChannel& operator=(const PeerChannel& o) = delete;
+ PeerChannel& operator=(PeerChannel&& o) = delete;
+
+ std::mutex mutex_ {};
+ std::condition_variable cv_ {};
+ std::deque<char> stream_;
+ bool stop_ {false};
+};
+
+} // namespace jami
diff --git a/src/upnp/protocol/igd.cpp b/src/upnp/protocol/igd.cpp
new file mode 100644
index 0000000..0e2ac90
--- /dev/null
+++ b/src/upnp/protocol/igd.cpp
@@ -0,0 +1,76 @@
+/*
+ * Copyright (C) 2004-2023 Savoir-faire Linux Inc.
+ *
+ * Author: Stepan Salenikovich <stepan.salenikovich@savoirfairelinux.com>
+ * Author: Eden Abitbol <eden.abitbol@savoirfairelinux.com>
+ * Author: Mohamed Chibani <mohamed.chibani@savoirfairelinux.com>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#include "igd.h"
+#include "logger.h"
+
+namespace jami {
+namespace upnp {
+
+IGD::IGD(NatProtocolType proto)
+ : protocol_(proto)
+{}
+
+bool
+IGD::operator==(IGD& other) const
+{
+ return localIp_ == other.localIp_ and publicIp_ == other.publicIp_ and uid_ == other.uid_;
+}
+
+void
+IGD::setValid(bool valid)
+{
+ valid_ = valid;
+
+ if (valid) {
+ // Reset errors counter.
+ errorsCounter_ = 0;
+ } else {
+ JAMI_WARN("IGD %s [%s] was disabled", toString().c_str(), getProtocolName());
+ }
+}
+
+bool
+IGD::incrementErrorsCounter()
+{
+ if (not valid_)
+ return false;
+
+ if (++errorsCounter_ >= MAX_ERRORS_COUNT) {
+ JAMI_WARN("IGD %s [%s] has too many errors, it will be disabled",
+ toString().c_str(),
+ getProtocolName());
+ setValid(false);
+ return false;
+ }
+
+ return true;
+}
+
+int
+IGD::getErrorsCount() const
+{
+ return errorsCounter_.load();
+}
+
+} // namespace upnp
+} // namespace jami
\ No newline at end of file
diff --git a/src/upnp/protocol/igd.h b/src/upnp/protocol/igd.h
new file mode 100644
index 0000000..33810f8
--- /dev/null
+++ b/src/upnp/protocol/igd.h
@@ -0,0 +1,110 @@
+/*
+ * Copyright (C) 2004-2023 Savoir-faire Linux Inc.
+ *
+ * Author: Eden Abitbol <eden.abitbol@savoirfairelinux.com>
+ * Author: Mohamed Chibani <mohamed.chibani@savoirfairelinux.com>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+#pragma once
+
+#include <mutex>
+
+#include "ip_utils.h"
+#include "mapping.h"
+
+#ifdef _MSC_VER
+typedef uint16_t in_port_t;
+#endif
+
+namespace jami {
+namespace upnp {
+
+enum class NatProtocolType { UNKNOWN, PUPNP, NAT_PMP };
+
+class IGD
+{
+public:
+ // Max error before moving the IGD to invalid state.
+ constexpr static int MAX_ERRORS_COUNT = 10;
+
+ IGD(NatProtocolType prot);
+ virtual ~IGD() = default;
+ bool operator==(IGD& other) const;
+
+ NatProtocolType getProtocol() const { return protocol_; }
+
+ char const* getProtocolName() const
+ {
+ return protocol_ == NatProtocolType::NAT_PMP ? "NAT-PMP" : "UPNP";
+ };
+
+ IpAddr getLocalIp() const
+ {
+ std::lock_guard<std::mutex> lock(mutex_);
+ return localIp_;
+ }
+ IpAddr getPublicIp() const
+ {
+ std::lock_guard<std::mutex> lock(mutex_);
+ return publicIp_;
+ }
+ void setLocalIp(const IpAddr& addr)
+ {
+ std::lock_guard<std::mutex> lock(mutex_);
+ localIp_ = addr;
+ }
+ void setPublicIp(const IpAddr& addr)
+ {
+ std::lock_guard<std::mutex> lock(mutex_);
+ publicIp_ = addr;
+ }
+ void setUID(const std::string& uid)
+ {
+ std::lock_guard<std::mutex> lock(mutex_);
+ uid_ = uid;
+ }
+ std::string getUID() const
+ {
+ std::lock_guard<std::mutex> lock(mutex_);
+ return uid_;
+ }
+
+ void setValid(bool valid);
+ bool isValid() const { return valid_; }
+ bool incrementErrorsCounter();
+ int getErrorsCount() const;
+
+ virtual const std::string toString() const = 0;
+
+protected:
+ const NatProtocolType protocol_ {NatProtocolType::UNKNOWN};
+ std::atomic_bool valid_ {false};
+ std::atomic<int> errorsCounter_ {0};
+
+ mutable std::mutex mutex_;
+ IpAddr localIp_ {}; // Local IP of the IGD (typically the same as the gateway).
+ IpAddr publicIp_ {}; // External/public IP of IGD.
+ std::string uid_ {};
+
+private:
+ IGD(IGD&& other) = delete;
+ IGD(IGD& other) = delete;
+ IGD& operator=(IGD&& other) = delete;
+ IGD& operator=(IGD& other) = delete;
+};
+
+} // namespace upnp
+} // namespace jami
diff --git a/src/upnp/protocol/mapping.cpp b/src/upnp/protocol/mapping.cpp
new file mode 100644
index 0000000..9b38831
--- /dev/null
+++ b/src/upnp/protocol/mapping.cpp
@@ -0,0 +1,347 @@
+/*
+ * Copyright (C) 2004-2023 Savoir-faire Linux Inc.
+ *
+ * Author: Stepan Salenikovich <stepan.salenikovich@savoirfairelinux.com>
+ * Author: Eden Abitbol <eden.abitbol@savoirfairelinux.com>
+ * Author: Mohamed Chibani <mohamed.chibani@savoirfairelinux.com>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#include "mapping.h"
+#include "logger.h"
+
+namespace jami {
+namespace upnp {
+
+Mapping::Mapping(PortType type, uint16_t portExternal, uint16_t portInternal, bool available)
+ : type_(type)
+ , externalPort_(portExternal)
+ , internalPort_(portInternal)
+ , internalAddr_()
+ , igd_()
+ , available_(available)
+ , state_(MappingState::PENDING)
+ , notifyCb_(nullptr)
+ , autoUpdate_(false)
+#if HAVE_LIBNATPMP
+ , renewalTime_(sys_clock::now())
+#endif
+{}
+
+Mapping::Mapping(const Mapping& other)
+{
+ std::lock_guard<std::mutex> lock(other.mutex_);
+
+ internalAddr_ = other.internalAddr_;
+ internalPort_ = other.internalPort_;
+ externalPort_ = other.externalPort_;
+ type_ = other.type_;
+ igd_ = other.igd_;
+ available_ = other.available_;
+ state_ = other.state_;
+ notifyCb_ = other.notifyCb_;
+ autoUpdate_ = other.autoUpdate_;
+#if HAVE_LIBNATPMP
+ renewalTime_ = other.renewalTime_;
+#endif
+}
+
+void
+Mapping::updateFrom(const Mapping::sharedPtr_t& other)
+{
+ updateFrom(*other);
+}
+
+void
+Mapping::updateFrom(const Mapping& other)
+{
+ if (type_ != other.type_) {
+ JAMI_ERR("The source and destination types must match");
+ return;
+ }
+
+ internalAddr_ = std::move(other.internalAddr_);
+ internalPort_ = other.internalPort_;
+ externalPort_ = other.externalPort_;
+ igd_ = other.igd_;
+ state_ = other.state_;
+}
+
+void
+Mapping::setAvailable(bool val)
+{
+ JAMI_DBG("Changing mapping %s state from %s to %s",
+ toString().c_str(),
+ available_ ? "AVAILABLE" : "UNAVAILABLE",
+ val ? "AVAILABLE" : "UNAVAILABLE");
+
+ std::lock_guard<std::mutex> lock(mutex_);
+ available_ = val;
+}
+
+void
+Mapping::setState(const MappingState& state)
+{
+ std::lock_guard<std::mutex> lock(mutex_);
+ state_ = state;
+}
+
+const char*
+Mapping::getStateStr() const
+{
+ std::lock_guard<std::mutex> lock(mutex_);
+ return getStateStr(state_);
+}
+
+std::string
+Mapping::toString(bool extraInfo) const
+{
+ std::lock_guard<std::mutex> lock(mutex_);
+ std::ostringstream descr;
+ descr << UPNP_MAPPING_DESCRIPTION_PREFIX << "-" << getTypeStr(type_);
+ descr << ":" << std::to_string(internalPort_);
+
+ if (extraInfo) {
+ descr << " (state=" << getStateStr(state_)
+ << ", auto-update=" << (autoUpdate_ ? "YES" : "NO") << ")";
+ }
+
+ return descr.str();
+}
+
+bool
+Mapping::isValid() const
+{
+ std::lock_guard<std::mutex> lock(mutex_);
+ if (state_ == MappingState::FAILED)
+ return false;
+ if (internalPort_ == 0)
+ return false;
+ if (externalPort_ == 0)
+ return false;
+ if (not igd_ or not igd_->isValid())
+ return false;
+ IpAddr intAddr(internalAddr_);
+ return intAddr and not intAddr.isLoopback();
+}
+
+bool
+Mapping::hasValidHostAddress() const
+{
+ std::lock_guard<std::mutex> lock(mutex_);
+
+ IpAddr intAddr(internalAddr_);
+ return intAddr and not intAddr.isLoopback();
+}
+
+bool
+Mapping::hasPublicAddress() const
+{
+ std::lock_guard<std::mutex> lock(mutex_);
+
+ return igd_ and igd_->getPublicIp() and not igd_->getPublicIp().isPrivate();
+}
+
+Mapping::key_t
+Mapping::getMapKey() const
+{
+ std::lock_guard<std::mutex> lock(mutex_);
+
+ key_t mapKey = internalPort_;
+ if (type_ == PortType::UDP)
+ mapKey |= 1 << (sizeof(uint16_t) * 8);
+ return mapKey;
+}
+
+PortType
+Mapping::getTypeFromMapKey(key_t key)
+{
+ return (key >> (sizeof(uint16_t) * 8)) ? PortType::UDP : PortType::TCP;
+}
+
+std::string
+Mapping::getExternalAddress() const
+{
+ std::lock_guard<std::mutex> lock(mutex_);
+ if (igd_)
+ return igd_->getPublicIp().toString();
+ return {};
+}
+
+void
+Mapping::setExternalPort(uint16_t port)
+{
+ std::lock_guard<std::mutex> lock(mutex_);
+ externalPort_ = port;
+}
+
+uint16_t
+Mapping::getExternalPort() const
+{
+ std::lock_guard<std::mutex> lock(mutex_);
+ return externalPort_;
+}
+
+std::string
+Mapping::getExternalPortStr() const
+{
+ std::lock_guard<std::mutex> lock(mutex_);
+ return std::to_string(externalPort_);
+}
+
+void
+Mapping::setInternalAddress(const std::string& addr)
+{
+ std::lock_guard<std::mutex> lock(mutex_);
+ internalAddr_ = addr;
+}
+
+std::string
+Mapping::getInternalAddress() const
+{
+ std::lock_guard<std::mutex> lock(mutex_);
+ return internalAddr_;
+}
+
+void
+Mapping::setInternalPort(uint16_t port)
+{
+ std::lock_guard<std::mutex> lock(mutex_);
+ internalPort_ = port;
+}
+
+uint16_t
+Mapping::getInternalPort() const
+{
+ std::lock_guard<std::mutex> lock(mutex_);
+ return internalPort_;
+}
+
+std::string
+Mapping::getInternalPortStr() const
+{
+ std::lock_guard<std::mutex> lock(mutex_);
+ return std::to_string(internalPort_);
+}
+
+PortType
+Mapping::getType() const
+{
+ std::lock_guard<std::mutex> lock(mutex_);
+ return type_;
+}
+
+const char*
+Mapping::getTypeStr() const
+{
+ std::lock_guard<std::mutex> lock(mutex_);
+ return getTypeStr(type_);
+}
+
+bool
+Mapping::isAvailable() const
+{
+ std::lock_guard<std::mutex> lock(mutex_);
+ return available_;
+}
+
+std::shared_ptr<IGD>
+Mapping::getIgd() const
+{
+ std::lock_guard<std::mutex> lock(mutex_);
+ return igd_;
+}
+
+NatProtocolType
+Mapping::getProtocol() const
+{
+ std::lock_guard<std::mutex> lock(mutex_);
+ if (igd_)
+ return igd_->getProtocol();
+ return NatProtocolType::UNKNOWN;
+}
+const char*
+Mapping::getProtocolName() const
+{
+ if (igd_) {
+ if (igd_->getProtocol() == NatProtocolType::NAT_PMP)
+ return "NAT-PMP";
+ if (igd_->getProtocol() == NatProtocolType::PUPNP)
+ return "PUPNP";
+ }
+ return "UNKNOWN";
+}
+
+void
+Mapping::setIgd(const std::shared_ptr<IGD>& igd)
+{
+ std::lock_guard<std::mutex> lock(mutex_);
+ igd_ = igd;
+}
+
+MappingState
+Mapping::getState() const
+{
+ std::lock_guard<std::mutex> lock(mutex_);
+ return state_;
+}
+
+Mapping::NotifyCallback
+Mapping::getNotifyCallback() const
+{
+ std::lock_guard<std::mutex> lock(mutex_);
+ return notifyCb_;
+}
+
+void
+Mapping::setNotifyCallback(NotifyCallback cb)
+{
+ std::lock_guard<std::mutex> lock(mutex_);
+ notifyCb_ = std::move(cb);
+}
+
+void
+Mapping::enableAutoUpdate(bool enable)
+{
+ std::lock_guard<std::mutex> lock(mutex_);
+ autoUpdate_ = enable;
+}
+
+bool
+Mapping::getAutoUpdate() const
+{
+ std::lock_guard<std::mutex> lock(mutex_);
+ return autoUpdate_;
+}
+
+#if HAVE_LIBNATPMP
+sys_clock::time_point
+Mapping::getRenewalTime() const
+{
+ std::lock_guard<std::mutex> lock(mutex_);
+ return renewalTime_;
+}
+
+void
+Mapping::setRenewalTime(sys_clock::time_point time)
+{
+ std::lock_guard<std::mutex> lock(mutex_);
+ renewalTime_ = time;
+}
+#endif
+
+} // namespace upnp
+} // namespace jami
diff --git a/src/upnp/protocol/mapping.h b/src/upnp/protocol/mapping.h
new file mode 100644
index 0000000..89e46b0
--- /dev/null
+++ b/src/upnp/protocol/mapping.h
@@ -0,0 +1,146 @@
+/*
+ * Copyright (C) 2004-2023 Savoir-faire Linux Inc.
+ *
+ * Author: Eden Abitbol <eden.abitbol@savoirfairelinux.com>
+ * Author: Mohamed Chibani <mohamed.chibani@savoirfairelinux.com>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#pragma once
+
+#include "ip_utils.h"
+#include "igd.h"
+
+#include <map>
+#include <string>
+#include <chrono>
+#include <functional>
+#include <mutex>
+
+namespace jami {
+namespace upnp {
+
+using sys_clock = std::chrono::system_clock;
+
+enum class PortType { TCP, UDP };
+enum class MappingState { PENDING, IN_PROGRESS, FAILED, OPEN };
+
+enum class NatProtocolType;
+class IGD;
+
+class Mapping
+{
+ friend class UPnPContext;
+ friend class NatPmp;
+ friend class PUPnP;
+
+public:
+ using key_t = uint64_t;
+ using sharedPtr_t = std::shared_ptr<Mapping>;
+ using NotifyCallback = std::function<void(sharedPtr_t)>;
+
+ static constexpr char const* MAPPING_STATE_STR[4] {"PENDING", "IN_PROGRESS", "FAILED", "OPEN"};
+ static constexpr char const* UPNP_MAPPING_DESCRIPTION_PREFIX {"JAMI"};
+
+ Mapping(PortType type,
+ uint16_t portExternal = 0,
+ uint16_t portInternal = 0,
+ bool available = true);
+ Mapping(const Mapping& other);
+ Mapping(Mapping&& other) = delete;
+ ~Mapping() = default;
+
+ // Delete operators with confusing semantic.
+ Mapping& operator=(Mapping&& other) = delete;
+ bool operator==(const Mapping& other) = delete;
+ bool operator!=(const Mapping& other) = delete;
+ bool operator<(const Mapping& other) = delete;
+ bool operator>(const Mapping& other) = delete;
+ bool operator<=(const Mapping& other) = delete;
+ bool operator>=(const Mapping& other) = delete;
+
+ inline explicit operator bool() const { return isValid(); }
+
+ void updateFrom(const Mapping& other);
+ void updateFrom(const Mapping::sharedPtr_t& other);
+ std::string getExternalAddress() const;
+ uint16_t getExternalPort() const;
+ std::string getExternalPortStr() const;
+ std::string getInternalAddress() const;
+ uint16_t getInternalPort() const;
+ std::string getInternalPortStr() const;
+ PortType getType() const;
+ const char* getTypeStr() const;
+ static const char* getTypeStr(PortType type) { return type == PortType::UDP ? "UDP" : "TCP"; }
+ std::shared_ptr<IGD> getIgd() const;
+ NatProtocolType getProtocol() const;
+ const char* getProtocolName() const;
+ bool isAvailable() const;
+ MappingState getState() const;
+ const char* getStateStr() const;
+ static const char* getStateStr(MappingState state)
+ {
+ return MAPPING_STATE_STR[static_cast<int>(state)];
+ }
+ std::string toString(bool extraInfo = false) const;
+ bool isValid() const;
+ bool hasValidHostAddress() const;
+ bool hasPublicAddress() const;
+ void setNotifyCallback(NotifyCallback cb);
+ void enableAutoUpdate(bool enable);
+ bool getAutoUpdate() const;
+ key_t getMapKey() const;
+ static PortType getTypeFromMapKey(key_t key);
+#if HAVE_LIBNATPMP
+ sys_clock::time_point getRenewalTime() const;
+#endif
+
+private:
+ NotifyCallback getNotifyCallback() const;
+ void setInternalAddress(const std::string& addr);
+ void setExternalPort(uint16_t port);
+ void setInternalPort(uint16_t port);
+
+ void setIgd(const std::shared_ptr<IGD>& igd);
+ void setAvailable(bool val);
+ void setState(const MappingState& state);
+ void updateDescription();
+#if HAVE_LIBNATPMP
+ void setRenewalTime(sys_clock::time_point time);
+#endif
+
+ mutable std::mutex mutex_;
+ PortType type_ {PortType::UDP};
+ uint16_t externalPort_ {0};
+ uint16_t internalPort_ {0};
+ std::string internalAddr_;
+ // Protocol and
+ std::shared_ptr<IGD> igd_;
+ // Track if the mapping is available to use.
+ bool available_;
+ // Track the state of the mapping
+ MappingState state_;
+ NotifyCallback notifyCb_;
+ // If true, a new mapping will be requested on behave of the mapping
+ // owner when the mapping state changes from "OPEN" to "FAILED".
+ bool autoUpdate_;
+#if HAVE_LIBNATPMP
+ sys_clock::time_point renewalTime_;
+#endif
+};
+
+} // namespace upnp
+} // namespace jami
diff --git a/src/upnp/protocol/natpmp/nat_pmp.cpp b/src/upnp/protocol/natpmp/nat_pmp.cpp
new file mode 100644
index 0000000..21f11ee
--- /dev/null
+++ b/src/upnp/protocol/natpmp/nat_pmp.cpp
@@ -0,0 +1,775 @@
+/*
+ * Copyright (C) 2004-2023 Savoir-faire Linux Inc.
+ *
+ * Author: Eden Abitbol <eden.abitbol@savoirfairelinux.com>
+ * Author: Mohamed Chibani <mohamed.chibani@savoirfairelinux.com>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#include "nat_pmp.h"
+
+#if HAVE_LIBNATPMP
+
+namespace jami {
+namespace upnp {
+
+NatPmp::NatPmp()
+{
+ JAMI_DBG("NAT-PMP: Instance [%p] created", this);
+ runOnNatPmpQueue([this] {
+ threadId_ = getCurrentThread();
+ igd_ = std::make_shared<PMPIGD>();
+ });
+}
+
+NatPmp::~NatPmp()
+{
+ JAMI_DBG("NAT-PMP: Instance [%p] destroyed", this);
+}
+
+void
+NatPmp::initNatPmp()
+{
+ if (not isValidThread()) {
+ runOnNatPmpQueue([w = weak()] {
+ if (auto pmpThis = w.lock()) {
+ pmpThis->initNatPmp();
+ }
+ });
+ return;
+ }
+
+ initialized_ = false;
+
+ {
+ std::lock_guard<std::mutex> lock(natpmpMutex_);
+ hostAddress_ = ip_utils::getLocalAddr(AF_INET);
+ }
+
+ // Local address must be valid.
+ if (not getHostAddress() or getHostAddress().isLoopback()) {
+ JAMI_WARN("NAT-PMP: Does not have a valid local address!");
+ return;
+ }
+
+ assert(igd_);
+ if (igd_->isValid()) {
+ igd_->setValid(false);
+ processIgdUpdate(UpnpIgdEvent::REMOVED);
+ }
+
+ igd_->setLocalIp(IpAddr());
+ igd_->setPublicIp(IpAddr());
+ igd_->setUID("");
+
+ JAMI_DBG("NAT-PMP: Trying to initialize IGD");
+
+ int err = initnatpmp(&natpmpHdl_, 0, 0);
+
+ if (err < 0) {
+ JAMI_WARN("NAT-PMP: Initializing IGD using default gateway failed!");
+ const auto& localGw = ip_utils::getLocalGateway();
+ if (not localGw) {
+ JAMI_WARN("NAT-PMP: Couldn't find valid gateway on local host");
+ err = NATPMP_ERR_CANNOTGETGATEWAY;
+ } else {
+ JAMI_WARN("NAT-PMP: Trying to initialize using detected gateway %s",
+ localGw.toString().c_str());
+
+ struct in_addr inaddr;
+ inet_pton(AF_INET, localGw.toString().c_str(), &inaddr);
+ err = initnatpmp(&natpmpHdl_, 1, inaddr.s_addr);
+ }
+ }
+
+ if (err < 0) {
+ JAMI_ERR("NAT-PMP: Can't initialize libnatpmp -> %s", getNatPmpErrorStr(err));
+ return;
+ }
+
+ char addrbuf[INET_ADDRSTRLEN];
+ inet_ntop(AF_INET, &natpmpHdl_.gateway, addrbuf, sizeof(addrbuf));
+ IpAddr igdAddr(addrbuf);
+ JAMI_DBG("NAT-PMP: Initialized on gateway %s", igdAddr.toString().c_str());
+
+ // Set the local (gateway) address.
+ igd_->setLocalIp(igdAddr);
+ // NAT-PMP protocol does not have UID, but we will set generic
+ // one debugging purposes.
+ igd_->setUID("NAT-PMP Gateway");
+
+ // Search and set the public address.
+ getIgdPublicAddress();
+
+ // Update and notify.
+ if (igd_->isValid()) {
+ initialized_ = true;
+ processIgdUpdate(UpnpIgdEvent::ADDED);
+ };
+}
+
+void
+NatPmp::setObserver(UpnpMappingObserver* obs)
+{
+ if (not isValidThread()) {
+ runOnNatPmpQueue([w = weak(), obs] {
+ if (auto pmpThis = w.lock()) {
+ pmpThis->setObserver(obs);
+ }
+ });
+ return;
+ }
+
+ JAMI_DBG("NAT-PMP: Setting observer to %p", obs);
+
+ observer_ = obs;
+}
+
+void
+NatPmp::terminate(std::condition_variable& cv)
+{
+ initialized_ = false;
+ observer_ = nullptr;
+
+ {
+ std::lock_guard<std::mutex> lock(natpmpMutex_);
+ shutdownComplete_ = true;
+ cv.notify_one();
+ }
+}
+
+void
+NatPmp::terminate()
+{
+ std::unique_lock<std::mutex> lk(natpmpMutex_);
+ std::condition_variable cv {};
+
+ runOnNatPmpQueue([w = weak(), &cv = cv] {
+ if (auto pmpThis = w.lock()) {
+ pmpThis->terminate(cv);
+ }
+ });
+
+ if (cv.wait_for(lk, std::chrono::seconds(10), [this] { return shutdownComplete_; })) {
+ JAMI_DBG("NAT-PMP: Shutdown completed");
+ } else {
+ JAMI_ERR("NAT-PMP: Shutdown timed-out");
+ }
+}
+
+const IpAddr
+NatPmp::getHostAddress() const
+{
+ std::lock_guard<std::mutex> lock(natpmpMutex_);
+ return hostAddress_;
+}
+
+void
+NatPmp::clearIgds()
+{
+ if (not isValidThread()) {
+ runOnNatPmpQueue([w = weak()] {
+ if (auto pmpThis = w.lock()) {
+ pmpThis->clearIgds();
+ }
+ });
+ return;
+ }
+
+ bool do_close = false;
+
+ if (igd_) {
+ if (igd_->isValid()) {
+ do_close = true;
+ }
+ igd_->setValid(false);
+ }
+
+ initialized_ = false;
+ if (searchForIgdTimer_)
+ searchForIgdTimer_->cancel();
+
+ igdSearchCounter_ = 0;
+
+ if (do_close) {
+ closenatpmp(&natpmpHdl_);
+ memset(&natpmpHdl_, 0, sizeof(natpmpHdl_));
+ }
+}
+
+void
+NatPmp::searchForIgd()
+{
+ if (not isValidThread()) {
+ runOnNatPmpQueue([w = weak()] {
+ if (auto pmpThis = w.lock()) {
+ pmpThis->searchForIgd();
+ }
+ });
+ return;
+ }
+
+ if (not initialized_) {
+ initNatPmp();
+ }
+
+ // Schedule a retry in case init failed.
+ if (not initialized_) {
+ if (igdSearchCounter_++ < MAX_RESTART_SEARCH_RETRIES) {
+ JAMI_DBG("NAT-PMP: Start search for IGDs. Attempt %i", igdSearchCounter_);
+
+ // Cancel the current timer (if any) and re-schedule.
+ if (searchForIgdTimer_)
+ searchForIgdTimer_->cancel();
+
+ searchForIgdTimer_ = getNatpmpScheduler()->scheduleIn([this] { searchForIgd(); },
+ NATPMP_SEARCH_RETRY_UNIT
+ * igdSearchCounter_);
+ } else {
+ JAMI_WARN("NAT-PMP: Setup failed after %u trials. NAT-PMP will be disabled!",
+ MAX_RESTART_SEARCH_RETRIES);
+ }
+ }
+}
+
+std::list<std::shared_ptr<IGD>>
+NatPmp::getIgdList() const
+{
+ std::lock_guard<std::mutex> lock(natpmpMutex_);
+ std::list<std::shared_ptr<IGD>> igdList;
+ if (igd_->isValid())
+ igdList.emplace_back(igd_);
+ return igdList;
+}
+
+bool
+NatPmp::isReady() const
+{
+ if (observer_ == nullptr) {
+ JAMI_ERR("NAT-PMP: the observer is not set!");
+ return false;
+ }
+
+ // Must at least have a valid local address.
+ if (not getHostAddress() or getHostAddress().isLoopback())
+ return false;
+
+ return igd_ and igd_->isValid();
+}
+
+void
+NatPmp::incrementErrorsCounter(const std::shared_ptr<IGD>& igdIn)
+{
+ if (not validIgdInstance(igdIn)) {
+ return;
+ }
+
+ if (not igd_->isValid()) {
+ // Already invalid. Nothing to do.
+ return;
+ }
+
+ if (not igd_->incrementErrorsCounter()) {
+ // Disable this IGD.
+ igd_->setValid(false);
+ // Notify the listener.
+ JAMI_WARN("NAT-PMP: No more valid IGD!");
+
+ processIgdUpdate(UpnpIgdEvent::INVALID_STATE);
+ }
+}
+
+void
+NatPmp::requestMappingAdd(const Mapping& mapping)
+{
+ // Process on nat-pmp thread.
+ if (not isValidThread()) {
+ runOnNatPmpQueue([w = weak(), mapping] {
+ if (auto pmpThis = w.lock()) {
+ pmpThis->requestMappingAdd(mapping);
+ }
+ });
+ return;
+ }
+
+ Mapping map(mapping);
+ assert(map.getIgd());
+ auto err = addPortMapping(map);
+ if (err < 0) {
+ JAMI_WARN("NAT-PMP: Request for mapping %s on %s failed with error %i: %s",
+ map.toString().c_str(),
+ igd_->toString().c_str(),
+ err,
+ getNatPmpErrorStr(err));
+
+ if (isErrorFatal(err)) {
+ // Fatal error, increment the counter.
+ incrementErrorsCounter(igd_);
+ }
+ // Notify the listener.
+ processMappingRequestFailed(std::move(map));
+ } else {
+ JAMI_DBG("NAT-PMP: Request for mapping %s on %s succeeded",
+ map.toString().c_str(),
+ igd_->toString().c_str());
+ // Notify the listener.
+ processMappingAdded(std::move(map));
+ }
+}
+
+void
+NatPmp::requestMappingRenew(const Mapping& mapping)
+{
+ // Process on nat-pmp thread.
+ if (not isValidThread()) {
+ runOnNatPmpQueue([w = weak(), mapping] {
+ if (auto pmpThis = w.lock()) {
+ pmpThis->requestMappingRenew(mapping);
+ }
+ });
+ return;
+ }
+
+ Mapping map(mapping);
+ auto err = addPortMapping(map);
+ if (err < 0) {
+ JAMI_WARN("NAT-PMP: Renewal request for mapping %s on %s failed with error %i: %s",
+ map.toString().c_str(),
+ igd_->toString().c_str(),
+ err,
+ getNatPmpErrorStr(err));
+ // Notify the listener.
+ processMappingRequestFailed(std::move(map));
+
+ if (isErrorFatal(err)) {
+ // Fatal error, increment the counter.
+ incrementErrorsCounter(igd_);
+ }
+ } else {
+ JAMI_DBG("NAT-PMP: Renewal request for mapping %s on %s succeeded",
+ map.toString().c_str(),
+ igd_->toString().c_str());
+ // Notify the listener.
+ processMappingRenewed(map);
+ }
+}
+
+int
+NatPmp::readResponse(natpmp_t& handle, natpmpresp_t& response)
+{
+ int err = 0;
+ unsigned readRetriesCounter = 0;
+
+ while (true) {
+ if (readRetriesCounter++ > MAX_READ_RETRIES) {
+ err = NATPMP_ERR_SOCKETERROR;
+ break;
+ }
+
+ fd_set fds;
+ struct timeval timeout;
+ FD_ZERO(&fds);
+ FD_SET(handle.s, &fds);
+ getnatpmprequesttimeout(&handle, &timeout);
+ // Wait for data.
+ if (select(FD_SETSIZE, &fds, NULL, NULL, &timeout) == -1) {
+ err = NATPMP_ERR_SOCKETERROR;
+ break;
+ }
+
+ // Read the data.
+ err = readnatpmpresponseorretry(&handle, &response);
+
+ if (err == NATPMP_TRYAGAIN) {
+ std::this_thread::sleep_for(std::chrono::milliseconds(TIMEOUT_BEFORE_READ_RETRY));
+ } else {
+ break;
+ }
+ }
+
+ return err;
+}
+
+int
+NatPmp::sendMappingRequest(const Mapping& mapping, uint32_t& lifetime)
+{
+ CHECK_VALID_THREAD();
+
+ int err = sendnewportmappingrequest(&natpmpHdl_,
+ mapping.getType() == PortType::UDP ? NATPMP_PROTOCOL_UDP
+ : NATPMP_PROTOCOL_TCP,
+ mapping.getInternalPort(),
+ mapping.getExternalPort(),
+ lifetime);
+
+ if (err < 0) {
+ JAMI_ERR("NAT-PMP: Send mapping request failed with error %s %i",
+ getNatPmpErrorStr(err),
+ errno);
+ return err;
+ }
+
+ unsigned readRetriesCounter = 0;
+
+ while (readRetriesCounter++ < MAX_READ_RETRIES) {
+ // Read the response
+ natpmpresp_t response;
+ err = readResponse(natpmpHdl_, response);
+
+ if (err < 0) {
+ JAMI_WARN("NAT-PMP: Read response on IGD %s failed with error %s",
+ igd_->toString().c_str(),
+ getNatPmpErrorStr(err));
+ } else if (response.type != NATPMP_RESPTYPE_TCPPORTMAPPING
+ and response.type != NATPMP_RESPTYPE_UDPPORTMAPPING) {
+ JAMI_ERR("NAT-PMP: Unexpected response type (%i) for mapping %s from IGD %s.",
+ response.type,
+ mapping.toString().c_str(),
+ igd_->toString().c_str());
+ // Try to read again.
+ continue;
+ }
+
+ lifetime = response.pnu.newportmapping.lifetime;
+ // Done.
+ break;
+ }
+
+ return err;
+}
+
+int
+NatPmp::addPortMapping(Mapping& mapping)
+{
+ auto const& igdIn = mapping.getIgd();
+ assert(igdIn);
+ assert(igdIn->getProtocol() == NatProtocolType::NAT_PMP);
+
+ if (not igdIn->isValid() or not validIgdInstance(igdIn)) {
+ mapping.setState(MappingState::FAILED);
+ return NATPMP_ERR_INVALIDARGS;
+ }
+
+ mapping.setInternalAddress(getHostAddress().toString());
+
+ uint32_t lifetime = MAPPING_ALLOCATION_LIFETIME;
+ int err = sendMappingRequest(mapping, lifetime);
+
+ if (err < 0) {
+ mapping.setState(MappingState::FAILED);
+ return err;
+ }
+
+ // Set the renewal time and update.
+ mapping.setRenewalTime(sys_clock::now() + std::chrono::seconds(lifetime * 4 / 5));
+ mapping.setState(MappingState::OPEN);
+
+ return 0;
+}
+
+void
+NatPmp::requestMappingRemove(const Mapping& mapping)
+{
+ // Process on nat-pmp thread.
+ if (not isValidThread()) {
+ runOnNatPmpQueue([w = weak(), mapping] {
+ if (auto pmpThis = w.lock()) {
+ Mapping map {mapping};
+ pmpThis->removePortMapping(map);
+ }
+ });
+ return;
+ }
+}
+
+void
+NatPmp::removePortMapping(Mapping& mapping)
+{
+ auto igdIn = mapping.getIgd();
+ assert(igdIn);
+ if (not igdIn->isValid()) {
+ return;
+ }
+
+ if (not validIgdInstance(igdIn)) {
+ return;
+ }
+
+ Mapping mapToRemove(mapping);
+
+ uint32_t lifetime = 0;
+ int err = sendMappingRequest(mapping, lifetime);
+
+ if (err < 0) {
+ // Nothing to do if the request fails, just log the error.
+ JAMI_WARN("NAT-PMP: Send remove request failed with error %s. Ignoring",
+ getNatPmpErrorStr(err));
+ }
+
+ // Update and notify the listener.
+ mapToRemove.setState(MappingState::FAILED);
+ processMappingRemoved(std::move(mapToRemove));
+}
+
+void
+NatPmp::getIgdPublicAddress()
+{
+ CHECK_VALID_THREAD();
+
+ // Set the public address for this IGD if it does not
+ // have one already.
+ if (igd_->getPublicIp()) {
+ JAMI_WARN("NAT-PMP: IGD %s already have a public address (%s)",
+ igd_->toString().c_str(),
+ igd_->getPublicIp().toString().c_str());
+ return;
+ }
+ assert(igd_->getProtocol() == NatProtocolType::NAT_PMP);
+
+ int err = sendpublicaddressrequest(&natpmpHdl_);
+
+ if (err < 0) {
+ JAMI_ERR("NAT-PMP: send public address request on IGD %s failed with error: %s",
+ igd_->toString().c_str(),
+ getNatPmpErrorStr(err));
+
+ if (isErrorFatal(err)) {
+ // Fatal error, increment the counter.
+ incrementErrorsCounter(igd_);
+ }
+ return;
+ }
+
+ natpmpresp_t response;
+ err = readResponse(natpmpHdl_, response);
+
+ if (err < 0) {
+ JAMI_WARN("NAT-PMP: Read response on IGD %s failed - %s",
+ igd_->toString().c_str(),
+ getNatPmpErrorStr(err));
+ return;
+ }
+
+ if (response.type != NATPMP_RESPTYPE_PUBLICADDRESS) {
+ JAMI_ERR("NAT-PMP: Unexpected response type (%i) for public address request from IGD %s.",
+ response.type,
+ igd_->toString().c_str());
+ return;
+ }
+
+ IpAddr publicAddr(response.pnu.publicaddress.addr);
+
+ if (not publicAddr) {
+ JAMI_ERR("NAT-PMP: IGD %s returned an invalid public address %s",
+ igd_->toString().c_str(),
+ publicAddr.toString().c_str());
+ }
+
+ // Update.
+ igd_->setPublicIp(publicAddr);
+ igd_->setValid(true);
+
+ JAMI_DBG("NAT-PMP: Setting IGD %s public address to %s",
+ igd_->toString().c_str(),
+ igd_->getPublicIp().toString().c_str());
+}
+
+void
+NatPmp::removeAllMappings()
+{
+ CHECK_VALID_THREAD();
+
+ JAMI_WARN("NAT-PMP: Send request to close all existing mappings to IGD %s",
+ igd_->toString().c_str());
+
+ int err = sendnewportmappingrequest(&natpmpHdl_, NATPMP_PROTOCOL_TCP, 0, 0, 0);
+ if (err < 0) {
+ JAMI_WARN("NAT-PMP: Send close all TCP mappings request failed with error %s",
+ getNatPmpErrorStr(err));
+ }
+ err = sendnewportmappingrequest(&natpmpHdl_, NATPMP_PROTOCOL_UDP, 0, 0, 0);
+ if (err < 0) {
+ JAMI_WARN("NAT-PMP: Send close all UDP mappings request failed with error %s",
+ getNatPmpErrorStr(err));
+ }
+}
+
+const char*
+NatPmp::getNatPmpErrorStr(int errorCode) const
+{
+#ifdef ENABLE_STRNATPMPERR
+ return strnatpmperr(errorCode);
+#else
+ switch (errorCode) {
+ case NATPMP_ERR_INVALIDARGS:
+ return "INVALIDARGS";
+ break;
+ case NATPMP_ERR_SOCKETERROR:
+ return "SOCKETERROR";
+ break;
+ case NATPMP_ERR_CANNOTGETGATEWAY:
+ return "CANNOTGETGATEWAY";
+ break;
+ case NATPMP_ERR_CLOSEERR:
+ return "CLOSEERR";
+ break;
+ case NATPMP_ERR_RECVFROM:
+ return "RECVFROM";
+ break;
+ case NATPMP_ERR_NOPENDINGREQ:
+ return "NOPENDINGREQ";
+ break;
+ case NATPMP_ERR_NOGATEWAYSUPPORT:
+ return "NOGATEWAYSUPPORT";
+ break;
+ case NATPMP_ERR_CONNECTERR:
+ return "CONNECTERR";
+ break;
+ case NATPMP_ERR_WRONGPACKETSOURCE:
+ return "WRONGPACKETSOURCE";
+ break;
+ case NATPMP_ERR_SENDERR:
+ return "SENDERR";
+ break;
+ case NATPMP_ERR_FCNTLERROR:
+ return "FCNTLERROR";
+ break;
+ case NATPMP_ERR_GETTIMEOFDAYERR:
+ return "GETTIMEOFDAYERR";
+ break;
+ case NATPMP_ERR_UNSUPPORTEDVERSION:
+ return "UNSUPPORTEDVERSION";
+ break;
+ case NATPMP_ERR_UNSUPPORTEDOPCODE:
+ return "UNSUPPORTEDOPCODE";
+ break;
+ case NATPMP_ERR_UNDEFINEDERROR:
+ return "UNDEFINEDERROR";
+ break;
+ case NATPMP_ERR_NOTAUTHORIZED:
+ return "NOTAUTHORIZED";
+ break;
+ case NATPMP_ERR_NETWORKFAILURE:
+ return "NETWORKFAILURE";
+ break;
+ case NATPMP_ERR_OUTOFRESOURCES:
+ return "OUTOFRESOURCES";
+ break;
+ case NATPMP_TRYAGAIN:
+ return "TRYAGAIN";
+ break;
+ default:
+ return "UNKNOWNERR";
+ break;
+ }
+#endif
+}
+
+bool
+NatPmp::isErrorFatal(int error)
+{
+ switch (error) {
+ case NATPMP_ERR_INVALIDARGS:
+ case NATPMP_ERR_SOCKETERROR:
+ case NATPMP_ERR_CANNOTGETGATEWAY:
+ case NATPMP_ERR_CLOSEERR:
+ case NATPMP_ERR_RECVFROM:
+ case NATPMP_ERR_NOGATEWAYSUPPORT:
+ case NATPMP_ERR_CONNECTERR:
+ case NATPMP_ERR_SENDERR:
+ case NATPMP_ERR_UNDEFINEDERROR:
+ case NATPMP_ERR_UNSUPPORTEDVERSION:
+ case NATPMP_ERR_UNSUPPORTEDOPCODE:
+ case NATPMP_ERR_NOTAUTHORIZED:
+ case NATPMP_ERR_NETWORKFAILURE:
+ case NATPMP_ERR_OUTOFRESOURCES:
+ return true;
+ default:
+ return false;
+ }
+}
+
+bool
+NatPmp::validIgdInstance(const std::shared_ptr<IGD>& igdIn)
+{
+ if (igd_.get() != igdIn.get()) {
+ JAMI_ERR("NAT-PMP: IGD (%s) does not match local instance (%s)",
+ igdIn->toString().c_str(),
+ igd_->toString().c_str());
+ return false;
+ }
+
+ return true;
+}
+
+void
+NatPmp::processIgdUpdate(UpnpIgdEvent event)
+{
+ if (igd_->isValid()) {
+ // Remove all current mappings if any.
+ removeAllMappings();
+ }
+
+ if (observer_ == nullptr)
+ return;
+ // Process the response on the context thread.
+ runOnUpnpContextQueue([obs = observer_, igd = igd_, event] { obs->onIgdUpdated(igd, event); });
+}
+
+void
+NatPmp::processMappingAdded(const Mapping& map)
+{
+ if (observer_ == nullptr)
+ return;
+
+ // Process the response on the context thread.
+ runOnUpnpContextQueue([obs = observer_, igd = igd_, map] { obs->onMappingAdded(igd, map); });
+}
+
+void
+NatPmp::processMappingRequestFailed(const Mapping& map)
+{
+ if (observer_ == nullptr)
+ return;
+
+ // Process the response on the context thread.
+ runOnUpnpContextQueue([obs = observer_, igd = igd_, map] { obs->onMappingRequestFailed(map); });
+}
+
+void
+NatPmp::processMappingRenewed(const Mapping& map)
+{
+ if (observer_ == nullptr)
+ return;
+
+ // Process the response on the context thread.
+ runOnUpnpContextQueue([obs = observer_, igd = igd_, map] { obs->onMappingRenewed(igd, map); });
+}
+
+void
+NatPmp::processMappingRemoved(const Mapping& map)
+{
+ if (observer_ == nullptr)
+ return;
+
+ // Process the response on the context thread.
+ runOnUpnpContextQueue([obs = observer_, igd = igd_, map] { obs->onMappingRemoved(igd, map); });
+}
+
+} // namespace upnp
+} // namespace jami
+
+#endif //-- #if HAVE_LIBNATPMP
diff --git a/src/upnp/protocol/natpmp/nat_pmp.h b/src/upnp/protocol/natpmp/nat_pmp.h
new file mode 100644
index 0000000..68fd28b
--- /dev/null
+++ b/src/upnp/protocol/natpmp/nat_pmp.h
@@ -0,0 +1,174 @@
+/*
+ * Copyright (C) 2004-2023 Savoir-faire Linux Inc.
+ *
+ * Author: Eden Abitbol <eden.abitbol@savoirfairelinux.com>
+ * Author: Mohamed Chibani <mohamed.chibani@savoirfairelinux.com>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#pragma once
+
+#include "connectivity/upnp/protocol/upnp_protocol.h"
+#include "connectivity/upnp/protocol/igd.h"
+#include "pmp_igd.h"
+
+#include "logger.h"
+#include "connectivity/ip_utils.h"
+#include "noncopyable.h"
+#include "compiler_intrinsics.h"
+
+// uncomment to enable native natpmp error messages
+//#define ENABLE_STRNATPMPERR 1
+#include <natpmp.h>
+
+#include <atomic>
+#include <thread>
+
+namespace jami {
+class IpAddr;
+}
+
+namespace jami {
+namespace upnp {
+
+// Requested lifetime in seconds. The actual lifetime might be different.
+constexpr static unsigned int MAPPING_ALLOCATION_LIFETIME {60 * 60};
+// Max number of IGD search attempts before failure.
+constexpr static unsigned int MAX_RESTART_SEARCH_RETRIES {3};
+// Time-out between two successive read response.
+constexpr static auto TIMEOUT_BEFORE_READ_RETRY {std::chrono::milliseconds(300)};
+// Max number of read attempts before failure.
+constexpr static unsigned int MAX_READ_RETRIES {3};
+// Base unit for the timeout between two successive IGD search.
+constexpr static auto NATPMP_SEARCH_RETRY_UNIT {std::chrono::seconds(10)};
+
+class NatPmp : public UPnPProtocol
+{
+public:
+ NatPmp();
+ ~NatPmp();
+
+ // Set the observer.
+ void setObserver(UpnpMappingObserver* obs) override;
+
+ // Returns the protocol type.
+ NatProtocolType getProtocol() const override { return NatProtocolType::NAT_PMP; }
+
+ // Get protocol type as string.
+ char const* getProtocolName() const override { return "NAT-PMP"; }
+
+ // Notifies a change in network.
+ void clearIgds() override;
+
+ // Renew pmp_igd.
+ void searchForIgd() override;
+
+ // Get the IGD list.
+ std::list<std::shared_ptr<IGD>> getIgdList() const override;
+
+ // Return true if it has at least one valid IGD.
+ bool isReady() const override;
+
+ // Request a new mapping.
+ void requestMappingAdd(const Mapping& mapping) override;
+
+ // Renew an allocated mapping.
+ void requestMappingRenew(const Mapping& mapping) override;
+
+ // Removes a mapping.
+ void requestMappingRemove(const Mapping& mapping) override;
+
+ // Get the host (local) address.
+ const IpAddr getHostAddress() const override;
+
+ // Terminate. Nothing to do here, the clean-up is done when
+ // the IGD is cleared.
+ void terminate() override;
+
+private:
+ NON_COPYABLE(NatPmp);
+
+ std::weak_ptr<NatPmp> weak() { return std::static_pointer_cast<NatPmp>(shared_from_this()); }
+
+ // Helpers to run tasks on NAT-PMP internal execution queue.
+ ScheduledExecutor* getNatpmpScheduler() { return &natpmpScheduler_; }
+ template<typename Callback>
+ void runOnNatPmpQueue(Callback&& cb)
+ {
+ natpmpScheduler_.run([cb = std::forward<Callback>(cb)]() mutable { cb(); });
+ }
+
+ // Helpers to run tasks on UPNP context execution queue.
+ ScheduledExecutor* getUpnContextScheduler() { return UpnpThreadUtil::getScheduler(); }
+
+ void terminate(std::condition_variable& cv);
+
+ void initNatPmp();
+ void getIgdPublicAddress();
+ void removeAllMappings();
+ int readResponse(natpmp_t& handle, natpmpresp_t& response);
+ int sendMappingRequest(const Mapping& mapping, uint32_t& lifetime);
+
+ // Adds a port mapping.
+ int addPortMapping(Mapping& mapping);
+ // Removes a port mapping.
+ void removePortMapping(Mapping& mapping);
+
+ // True if the error is fatal.
+ bool isErrorFatal(int error);
+ // Gets NAT-PMP error code string.
+ const char* getNatPmpErrorStr(int errorCode) const;
+ // Get local getaway.
+ std::unique_ptr<IpAddr> getLocalGateway() const;
+
+ // Helpers to process user's callbacks
+ void processIgdUpdate(UpnpIgdEvent event);
+ void processMappingAdded(const Mapping& map);
+ void processMappingRequestFailed(const Mapping& map);
+ void processMappingRenewed(const Mapping& map);
+ void processMappingRemoved(const Mapping& map);
+
+ // Check if the IGD has a local match
+ bool validIgdInstance(const std::shared_ptr<IGD>& igdIn);
+
+ // Increment errors counter.
+ void incrementErrorsCounter(const std::shared_ptr<IGD>& igd);
+
+ std::atomic_bool initialized_ {false};
+
+ // Data members
+ std::shared_ptr<PMPIGD> igd_;
+ natpmp_t natpmpHdl_;
+ ScheduledExecutor natpmpScheduler_ {"natpmp"};
+ std::shared_ptr<Task> searchForIgdTimer_ {};
+ unsigned int igdSearchCounter_ {0};
+ UpnpMappingObserver* observer_ {nullptr};
+ IpAddr hostAddress_ {};
+
+ // Calls from other threads that does not need synchronous access are
+ // rescheduled on the NatPmp private queue. This will avoid the need to
+ // protect most of the data members of this class.
+ // For some internal members (such as the igd instance and the host
+ // address) that need to be synchronously accessed, are protected by
+ // this mutex.
+ mutable std::mutex natpmpMutex_;
+
+ // Shutdown synchronization
+ bool shutdownComplete_ {false};
+};
+
+} // namespace upnp
+} // namespace jami
diff --git a/src/upnp/protocol/natpmp/pmp_igd.cpp b/src/upnp/protocol/natpmp/pmp_igd.cpp
new file mode 100644
index 0000000..ac8b698
--- /dev/null
+++ b/src/upnp/protocol/natpmp/pmp_igd.cpp
@@ -0,0 +1,63 @@
+/*
+ * Copyright (C) 2004-2023 Savoir-faire Linux Inc.
+ *
+ * Author: Stepan Salenikovich <stepan.salenikovich@savoirfairelinux.com>
+ * Author: Eden Abitbol <eden.abitbol@savoirfairelinux.com>
+ * Author: Mohamed Chibani <mohamed.chibani@savoirfairelinux.com>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#include "pmp_igd.h"
+
+#include <algorithm>
+
+namespace jami {
+namespace upnp {
+
+PMPIGD::PMPIGD()
+ : IGD(NatProtocolType::NAT_PMP)
+{}
+
+PMPIGD::PMPIGD(const PMPIGD& other)
+ : PMPIGD()
+{
+ assert(protocol_ == NatProtocolType::NAT_PMP);
+ // protocol_ = other.protocol_;
+ localIp_ = other.localIp_;
+ publicIp_ = other.publicIp_;
+ uid_ = other.uid_;
+}
+
+bool
+PMPIGD::operator==(IGD& other) const
+{
+ return getPublicIp() == other.getPublicIp() and getLocalIp() == other.getLocalIp();
+}
+
+bool
+PMPIGD::operator==(PMPIGD& other) const
+{
+ return getPublicIp() == other.getPublicIp() and getLocalIp() == other.getLocalIp();
+}
+
+const std::string
+PMPIGD::toString() const
+{
+ return getLocalIp().toString();
+}
+
+} // namespace upnp
+} // namespace jami
diff --git a/src/upnp/protocol/natpmp/pmp_igd.h b/src/upnp/protocol/natpmp/pmp_igd.h
new file mode 100644
index 0000000..a70e7ee
--- /dev/null
+++ b/src/upnp/protocol/natpmp/pmp_igd.h
@@ -0,0 +1,54 @@
+/*
+ * Copyright (C) 2004-2023 Savoir-faire Linux Inc.
+ *
+ * Author: Stepan Salenikovich <stepan.salenikovich@savoirfairelinux.com>
+ * Author: Eden Abitbol <eden.abitbol@savoirfairelinux.com>
+ * Author: Mohamed Chibani <mohamed.chibani@savoirfairelinux.com>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+#pragma once
+
+#include "../igd.h"
+#include "noncopyable.h"
+#include "connectivity/ip_utils.h"
+
+#include <map>
+#include <atomic>
+#include <string>
+#include <chrono>
+#include <functional>
+
+namespace jami {
+namespace upnp {
+
+class PMPIGD : public IGD
+{
+public:
+ PMPIGD();
+ PMPIGD(const PMPIGD&);
+ ~PMPIGD() = default;
+
+ PMPIGD& operator=(PMPIGD&& other) = delete;
+ PMPIGD& operator=(PMPIGD& other) = delete;
+
+ bool operator==(IGD& other) const;
+ bool operator==(PMPIGD& other) const;
+
+ const std::string toString() const override;
+};
+
+} // namespace upnp
+} // namespace jami
diff --git a/src/upnp/protocol/pupnp/pupnp.cpp b/src/upnp/protocol/pupnp/pupnp.cpp
new file mode 100644
index 0000000..cc63347
--- /dev/null
+++ b/src/upnp/protocol/pupnp/pupnp.cpp
@@ -0,0 +1,1599 @@
+/*
+ * Copyright (C) 2004-2023 Savoir-faire Linux Inc.
+ *
+ * Author: Stepan Salenikovich <stepan.salenikovich@savoirfairelinux.com>
+ * Author: Eden Abitbol <eden.abitbol@savoirfairelinux.com>
+ * Author: Adrien Béraud <adrien.beraud@savoirfairelinux.com>
+ * Author: Mohamed Chibani <mohamed.chibani@savoirfairelinux.com>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#include "pupnp.h"
+
+#include <opendht/thread_pool.h>
+#include <opendht/http.h>
+
+namespace jami {
+namespace upnp {
+
+// Action identifiers.
+constexpr static const char* ACTION_ADD_PORT_MAPPING {"AddPortMapping"};
+constexpr static const char* ACTION_DELETE_PORT_MAPPING {"DeletePortMapping"};
+constexpr static const char* ACTION_GET_GENERIC_PORT_MAPPING_ENTRY {"GetGenericPortMappingEntry"};
+constexpr static const char* ACTION_GET_STATUS_INFO {"GetStatusInfo"};
+constexpr static const char* ACTION_GET_EXTERNAL_IP_ADDRESS {"GetExternalIPAddress"};
+
+// Error codes returned by router when trying to remove ports.
+constexpr static int ARRAY_IDX_INVALID = 713;
+constexpr static int CONFLICT_IN_MAPPING = 718;
+
+// Max number of IGD search attempts before failure.
+constexpr static unsigned int PUPNP_MAX_RESTART_SEARCH_RETRIES {3};
+// IGD search timeout (in seconds).
+constexpr static unsigned int SEARCH_TIMEOUT {60};
+// Base unit for the timeout between two successive IGD search.
+constexpr static auto PUPNP_SEARCH_RETRY_UNIT {std::chrono::seconds(10)};
+
+// Helper functions for xml parsing.
+static std::string_view
+getElementText(IXML_Node* node)
+{
+ if (node) {
+ if (IXML_Node* textNode = ixmlNode_getFirstChild(node))
+ if (const char* value = ixmlNode_getNodeValue(textNode))
+ return std::string_view(value);
+ }
+ return {};
+}
+
+static std::string_view
+getFirstDocItem(IXML_Document* doc, const char* item)
+{
+ std::unique_ptr<IXML_NodeList, decltype(ixmlNodeList_free)&>
+ nodeList(ixmlDocument_getElementsByTagName(doc, item), ixmlNodeList_free);
+ if (nodeList) {
+ // If there are several nodes which match the tag, we only want the first one.
+ return getElementText(ixmlNodeList_item(nodeList.get(), 0));
+ }
+ return {};
+}
+
+static std::string_view
+getFirstElementItem(IXML_Element* element, const char* item)
+{
+ std::unique_ptr<IXML_NodeList, decltype(ixmlNodeList_free)&>
+ nodeList(ixmlElement_getElementsByTagName(element, item), ixmlNodeList_free);
+ if (nodeList) {
+ // If there are several nodes which match the tag, we only want the first one.
+ return getElementText(ixmlNodeList_item(nodeList.get(), 0));
+ }
+ return {};
+}
+
+static bool
+errorOnResponse(IXML_Document* doc)
+{
+ if (not doc)
+ return true;
+
+ auto errorCode = getFirstDocItem(doc, "errorCode");
+ if (not errorCode.empty()) {
+ auto errorDescription = getFirstDocItem(doc, "errorDescription");
+ JAMI_WARNING("PUPnP: Response contains error: {:s}: {:s}",
+ errorCode,
+ errorDescription);
+ return true;
+ }
+ return false;
+}
+
+// UPNP class implementation
+
+PUPnP::PUPnP()
+{
+ JAMI_DBG("PUPnP: Creating instance [%p] ...", this);
+ runOnPUPnPQueue([this] {
+ threadId_ = getCurrentThread();
+ JAMI_DBG("PUPnP: Instance [%p] created", this);
+ });
+}
+
+PUPnP::~PUPnP()
+{
+ JAMI_DBG("PUPnP: Instance [%p] destroyed", this);
+}
+
+void
+PUPnP::initUpnpLib()
+{
+ assert(not initialized_);
+
+ int upnp_err = UpnpInit2(nullptr, 0);
+
+ if (upnp_err != UPNP_E_SUCCESS) {
+ JAMI_ERR("PUPnP: Can't initialize libupnp: %s", UpnpGetErrorMessage(upnp_err));
+ UpnpFinish();
+ initialized_ = false;
+ return;
+ }
+
+ // Disable embedded WebServer if any.
+ if (UpnpIsWebserverEnabled() == 1) {
+ JAMI_WARN("PUPnP: Web-server is enabled. Disabling");
+ UpnpEnableWebserver(0);
+ if (UpnpIsWebserverEnabled() == 1) {
+ JAMI_ERR("PUPnP: Could not disable Web-server!");
+ } else {
+ JAMI_DBG("PUPnP: Web-server successfully disabled");
+ }
+ }
+
+ char* ip_address = UpnpGetServerIpAddress();
+ char* ip_address6 = nullptr;
+ unsigned short port = UpnpGetServerPort();
+ unsigned short port6 = 0;
+#if UPNP_ENABLE_IPV6
+ ip_address6 = UpnpGetServerIp6Address();
+ port6 = UpnpGetServerPort6();
+#endif
+ if (ip_address6 and port6)
+ JAMI_DBG("PUPnP: Initialized on %s:%u | %s:%u", ip_address, port, ip_address6, port6);
+ else
+ JAMI_DBG("PUPnP: Initialized on %s:%u", ip_address, port);
+
+ // Relax the parser to allow malformed XML text.
+ ixmlRelaxParser(1);
+
+ initialized_ = true;
+}
+
+bool
+PUPnP::isRunning() const
+{
+ std::unique_lock<std::mutex> lk(pupnpMutex_);
+ return not shutdownComplete_;
+}
+
+void
+PUPnP::registerClient()
+{
+ assert(not clientRegistered_);
+
+ CHECK_VALID_THREAD();
+
+ // Register Upnp control point.
+ int upnp_err = UpnpRegisterClient(ctrlPtCallback, this, &ctrlptHandle_);
+ if (upnp_err != UPNP_E_SUCCESS) {
+ JAMI_ERR("PUPnP: Can't register client: %s", UpnpGetErrorMessage(upnp_err));
+ } else {
+ JAMI_DBG("PUPnP: Successfully registered client");
+ clientRegistered_ = true;
+ }
+}
+
+void
+PUPnP::setObserver(UpnpMappingObserver* obs)
+{
+ if (not isValidThread()) {
+ runOnPUPnPQueue([w = weak(), obs] {
+ if (auto upnpThis = w.lock()) {
+ upnpThis->setObserver(obs);
+ }
+ });
+ return;
+ }
+
+ JAMI_DBG("PUPnP: Setting observer to %p", obs);
+
+ observer_ = obs;
+}
+
+const IpAddr
+PUPnP::getHostAddress() const
+{
+ std::lock_guard<std::mutex> lock(pupnpMutex_);
+ return hostAddress_;
+}
+
+void
+PUPnP::terminate(std::condition_variable& cv)
+{
+ JAMI_DBG("PUPnP: Terminate instance %p", this);
+
+ clientRegistered_ = false;
+ observer_ = nullptr;
+
+ UpnpUnRegisterClient(ctrlptHandle_);
+
+ if (initialized_) {
+ if (UpnpFinish() != UPNP_E_SUCCESS) {
+ JAMI_ERR("PUPnP: Failed to properly close lib-upnp");
+ }
+
+ initialized_ = false;
+ }
+
+ // Clear all the lists.
+ discoveredIgdList_.clear();
+
+ {
+ std::lock_guard<std::mutex> lock(pupnpMutex_);
+ validIgdList_.clear();
+ shutdownComplete_ = true;
+ cv.notify_one();
+ }
+}
+
+void
+PUPnP::terminate()
+{
+ std::unique_lock<std::mutex> lk(pupnpMutex_);
+ std::condition_variable cv {};
+
+ runOnPUPnPQueue([w = weak(), &cv = cv] {
+ if (auto upnpThis = w.lock()) {
+ upnpThis->terminate(cv);
+ }
+ });
+
+ if (cv.wait_for(lk, std::chrono::seconds(10), [this] { return shutdownComplete_; })) {
+ JAMI_DBG("PUPnP: Shutdown completed");
+ } else {
+ JAMI_ERR("PUPnP: Shutdown timed-out");
+ // Force stop if the shutdown take too much time.
+ shutdownComplete_ = true;
+ }
+}
+
+void
+PUPnP::searchForDevices()
+{
+ CHECK_VALID_THREAD();
+
+ JAMI_DBG("PUPnP: Send IGD search request");
+
+ // Send out search for multiple types of devices, as some routers may possibly
+ // only reply to one.
+
+ auto err = UpnpSearchAsync(ctrlptHandle_, SEARCH_TIMEOUT, UPNP_ROOT_DEVICE, this);
+ if (err != UPNP_E_SUCCESS) {
+ JAMI_WARN("PUPnP: Send search for UPNP_ROOT_DEVICE failed. Error %d: %s",
+ err,
+ UpnpGetErrorMessage(err));
+ }
+
+ err = UpnpSearchAsync(ctrlptHandle_, SEARCH_TIMEOUT, UPNP_IGD_DEVICE, this);
+ if (err != UPNP_E_SUCCESS) {
+ JAMI_WARN("PUPnP: Send search for UPNP_IGD_DEVICE failed. Error %d: %s",
+ err,
+ UpnpGetErrorMessage(err));
+ }
+
+ err = UpnpSearchAsync(ctrlptHandle_, SEARCH_TIMEOUT, UPNP_WANIP_SERVICE, this);
+ if (err != UPNP_E_SUCCESS) {
+ JAMI_WARN("PUPnP: Send search for UPNP_WANIP_SERVICE failed. Error %d: %s",
+ err,
+ UpnpGetErrorMessage(err));
+ }
+
+ err = UpnpSearchAsync(ctrlptHandle_, SEARCH_TIMEOUT, UPNP_WANPPP_SERVICE, this);
+ if (err != UPNP_E_SUCCESS) {
+ JAMI_WARN("PUPnP: Send search for UPNP_WANPPP_SERVICE failed. Error %d: %s",
+ err,
+ UpnpGetErrorMessage(err));
+ }
+}
+
+void
+PUPnP::clearIgds()
+{
+ if (not isValidThread()) {
+ runOnPUPnPQueue([w = weak()] {
+ if (auto upnpThis = w.lock()) {
+ upnpThis->clearIgds();
+ }
+ });
+ return;
+ }
+
+ JAMI_DBG("PUPnP: clearing IGDs and devices lists");
+
+ if (searchForIgdTimer_)
+ searchForIgdTimer_->cancel();
+
+ igdSearchCounter_ = 0;
+
+ {
+ std::lock_guard<std::mutex> lock(pupnpMutex_);
+ for (auto const& igd : validIgdList_) {
+ igd->setValid(false);
+ }
+ validIgdList_.clear();
+ hostAddress_ = {};
+ }
+
+ discoveredIgdList_.clear();
+}
+
+void
+PUPnP::searchForIgd()
+{
+ if (not isValidThread()) {
+ runOnPUPnPQueue([w = weak()] {
+ if (auto upnpThis = w.lock()) {
+ upnpThis->searchForIgd();
+ }
+ });
+ return;
+ }
+
+ // Update local address before searching.
+ updateHostAddress();
+
+ if (isReady()) {
+ JAMI_DBG("PUPnP: Already have a valid IGD. Skip the search request");
+ return;
+ }
+
+ if (igdSearchCounter_++ >= PUPNP_MAX_RESTART_SEARCH_RETRIES) {
+ JAMI_WARN("PUPnP: Setup failed after %u trials. PUPnP will be disabled!",
+ PUPNP_MAX_RESTART_SEARCH_RETRIES);
+ return;
+ }
+
+ JAMI_DBG("PUPnP: Start search for IGD: attempt %u", igdSearchCounter_);
+
+ // Do not init if the host is not valid. Otherwise, the init will fail
+ // anyway and may put libupnp in an unstable state (mainly deadlocks)
+ // even if the UpnpFinish() method is called.
+ if (not hasValidHostAddress()) {
+ JAMI_WARN("PUPnP: Host address is invalid. Skipping the IGD search");
+ } else {
+ // Init and register if needed
+ if (not initialized_) {
+ initUpnpLib();
+ }
+ if (initialized_ and not clientRegistered_) {
+ registerClient();
+ }
+ // Start searching
+ if (clientRegistered_) {
+ assert(initialized_);
+ searchForDevices();
+ } else {
+ JAMI_WARN("PUPnP: PUPNP not fully setup. Skipping the IGD search");
+ }
+ }
+
+ // Cancel the current timer (if any) and re-schedule.
+ // The connectivity change may be received while the the local
+ // interface is not fully setup. The rescheduling typically
+ // usefull to mitigate this race.
+ if (searchForIgdTimer_)
+ searchForIgdTimer_->cancel();
+
+ searchForIgdTimer_ = getUpnContextScheduler()->scheduleIn(
+ [w = weak()] {
+ if (auto upnpThis = w.lock())
+ upnpThis->searchForIgd();
+ },
+ PUPNP_SEARCH_RETRY_UNIT * igdSearchCounter_);
+}
+
+std::list<std::shared_ptr<IGD>>
+PUPnP::getIgdList() const
+{
+ std::lock_guard<std::mutex> lock(pupnpMutex_);
+ std::list<std::shared_ptr<IGD>> igdList;
+ for (auto& it : validIgdList_) {
+ // Return only active IGDs.
+ if (it->isValid()) {
+ igdList.emplace_back(it);
+ }
+ }
+ return igdList;
+}
+
+bool
+PUPnP::isReady() const
+{
+ // Must at least have a valid local address.
+ if (not getHostAddress() or getHostAddress().isLoopback())
+ return false;
+
+ return hasValidIgd();
+}
+
+bool
+PUPnP::hasValidIgd() const
+{
+ std::lock_guard<std::mutex> lock(pupnpMutex_);
+ for (auto& it : validIgdList_) {
+ if (it->isValid()) {
+ return true;
+ }
+ }
+ return false;
+}
+
+void
+PUPnP::updateHostAddress()
+{
+ std::lock_guard<std::mutex> lock(pupnpMutex_);
+ hostAddress_ = ip_utils::getLocalAddr(AF_INET);
+}
+
+bool
+PUPnP::hasValidHostAddress()
+{
+ std::lock_guard<std::mutex> lock(pupnpMutex_);
+ return hostAddress_ and not hostAddress_.isLoopback();
+}
+
+void
+PUPnP::incrementErrorsCounter(const std::shared_ptr<IGD>& igd)
+{
+ if (not igd or not igd->isValid())
+ return;
+ if (not igd->incrementErrorsCounter()) {
+ // Disable this IGD.
+ igd->setValid(false);
+ // Notify the listener.
+ if (observer_)
+ observer_->onIgdUpdated(igd, UpnpIgdEvent::INVALID_STATE);
+ }
+}
+
+bool
+PUPnP::validateIgd(const std::string& location, IXML_Document* doc_container_ptr)
+{
+ CHECK_VALID_THREAD();
+
+ assert(doc_container_ptr != nullptr);
+
+ XMLDocument document(doc_container_ptr, ixmlDocument_free);
+ auto descDoc = document.get();
+ // Check device type.
+ auto deviceType = getFirstDocItem(descDoc, "deviceType");
+ if (deviceType != UPNP_IGD_DEVICE) {
+ // Device type not IGD.
+ return false;
+ }
+
+ std::shared_ptr<UPnPIGD> igd_candidate = parseIgd(descDoc, location);
+ if (not igd_candidate) {
+ // No valid IGD candidate.
+ return false;
+ }
+
+ JAMI_DBG("PUPnP: Validating the IGD candidate [UDN: %s]\n"
+ " Name : %s\n"
+ " Service Type : %s\n"
+ " Service ID : %s\n"
+ " Base URL : %s\n"
+ " Location URL : %s\n"
+ " control URL : %s\n"
+ " Event URL : %s",
+ igd_candidate->getUID().c_str(),
+ igd_candidate->getFriendlyName().c_str(),
+ igd_candidate->getServiceType().c_str(),
+ igd_candidate->getServiceId().c_str(),
+ igd_candidate->getBaseURL().c_str(),
+ igd_candidate->getLocationURL().c_str(),
+ igd_candidate->getControlURL().c_str(),
+ igd_candidate->getEventSubURL().c_str());
+
+ // Check if IGD is connected.
+ if (not actionIsIgdConnected(*igd_candidate)) {
+ JAMI_WARN("PUPnP: IGD candidate %s is not connected", igd_candidate->getUID().c_str());
+ return false;
+ }
+
+ // Validate external Ip.
+ igd_candidate->setPublicIp(actionGetExternalIP(*igd_candidate));
+ if (igd_candidate->getPublicIp().toString().empty()) {
+ JAMI_WARN("PUPnP: IGD candidate %s has no valid external Ip",
+ igd_candidate->getUID().c_str());
+ return false;
+ }
+
+ // Validate internal Ip.
+ if (igd_candidate->getBaseURL().empty()) {
+ JAMI_WARN("PUPnP: IGD candidate %s has no valid internal Ip",
+ igd_candidate->getUID().c_str());
+ return false;
+ }
+
+ // Typically the IGD local address should be extracted from the XML
+ // document (e.g. parsing the base URL). For simplicity, we assume
+ // that it matches the gateway as seen by the local interface.
+ if (const auto& localGw = ip_utils::getLocalGateway()) {
+ igd_candidate->setLocalIp(localGw);
+ } else {
+ JAMI_WARN("PUPnP: Could not set internal address for IGD candidate %s",
+ igd_candidate->getUID().c_str());
+ return false;
+ }
+
+ // Store info for subscription.
+ std::string eventSub = igd_candidate->getEventSubURL();
+
+ {
+ // Add the IGD if not already present in the list.
+ std::lock_guard<std::mutex> lock(pupnpMutex_);
+ for (auto& igd : validIgdList_) {
+ // Must not be a null pointer
+ assert(igd.get() != nullptr);
+ if (*igd == *igd_candidate) {
+ JAMI_DBG("PUPnP: Device [%s] with int/ext addresses [%s:%s] is already in the list "
+ "of valid IGDs",
+ igd_candidate->getUID().c_str(),
+ igd_candidate->toString().c_str(),
+ igd_candidate->getPublicIp().toString().c_str());
+ return true;
+ }
+ }
+ }
+
+ // We have a valid IGD
+ igd_candidate->setValid(true);
+
+ JAMI_DBG("PUPnP: Added a new IGD [%s] to the list of valid IGDs",
+ igd_candidate->getUID().c_str());
+
+ JAMI_DBG("PUPnP: New IGD addresses [int: %s - ext: %s]",
+ igd_candidate->toString().c_str(),
+ igd_candidate->getPublicIp().toString().c_str());
+
+ // Subscribe to IGD events.
+ int upnp_err = UpnpSubscribeAsync(ctrlptHandle_,
+ eventSub.c_str(),
+ UPNP_INFINITE,
+ subEventCallback,
+ this);
+ if (upnp_err != UPNP_E_SUCCESS) {
+ JAMI_WARN("PUPnP: Failed to send subscribe request to %s: error %i - %s",
+ igd_candidate->getUID().c_str(),
+ upnp_err,
+ UpnpGetErrorMessage(upnp_err));
+ // return false;
+ } else {
+ JAMI_DBG("PUPnP: Successfully subscribed to IGD %s", igd_candidate->getUID().c_str());
+ }
+
+ {
+ // This is a new (and hopefully valid) IGD.
+ std::lock_guard<std::mutex> lock(pupnpMutex_);
+ validIgdList_.emplace_back(igd_candidate);
+ }
+
+ // Report to the listener.
+ runOnUpnpContextQueue([w = weak(), igd_candidate] {
+ if (auto upnpThis = w.lock()) {
+ if (upnpThis->observer_)
+ upnpThis->observer_->onIgdUpdated(igd_candidate, UpnpIgdEvent::ADDED);
+ }
+ });
+
+ return true;
+}
+
+void
+PUPnP::requestMappingAdd(const Mapping& mapping)
+{
+ runOnPUPnPQueue([w = weak(), mapping] {
+ if (auto upnpThis = w.lock()) {
+ if (not upnpThis->isRunning())
+ return;
+ Mapping mapRes(mapping);
+ if (upnpThis->actionAddPortMapping(mapRes)) {
+ mapRes.setState(MappingState::OPEN);
+ mapRes.setInternalAddress(upnpThis->getHostAddress().toString());
+ upnpThis->processAddMapAction(mapRes);
+ } else {
+ upnpThis->incrementErrorsCounter(mapRes.getIgd());
+ mapRes.setState(MappingState::FAILED);
+ upnpThis->processRequestMappingFailure(mapRes);
+ }
+ }
+ });
+}
+
+void
+PUPnP::requestMappingRemove(const Mapping& mapping)
+{
+ // Send remove request using the matching IGD
+ runOnPUPnPQueue([w = weak(), mapping] {
+ if (auto upnpThis = w.lock()) {
+ // Abort if we are shutting down.
+ if (not upnpThis->isRunning())
+ return;
+ if (upnpThis->actionDeletePortMapping(mapping)) {
+ upnpThis->processRemoveMapAction(mapping);
+ } else {
+ assert(mapping.getIgd());
+ // Dont need to report in case of failure.
+ upnpThis->incrementErrorsCounter(mapping.getIgd());
+ }
+ }
+ });
+}
+
+std::shared_ptr<UPnPIGD>
+PUPnP::findMatchingIgd(const std::string& ctrlURL) const
+{
+ std::lock_guard<std::mutex> lock(pupnpMutex_);
+
+ auto iter = std::find_if(validIgdList_.begin(),
+ validIgdList_.end(),
+ [&ctrlURL](const std::shared_ptr<IGD>& igd) {
+ if (auto upnpIgd = std::dynamic_pointer_cast<UPnPIGD>(igd)) {
+ return upnpIgd->getControlURL() == ctrlURL;
+ }
+ return false;
+ });
+
+ if (iter == validIgdList_.end()) {
+ JAMI_WARN("PUPnP: Did not find the IGD matching ctrl URL [%s]", ctrlURL.c_str());
+ return {};
+ }
+
+ return std::dynamic_pointer_cast<UPnPIGD>(*iter);
+}
+
+void
+PUPnP::processAddMapAction(const Mapping& map)
+{
+ CHECK_VALID_THREAD();
+
+ if (observer_ == nullptr)
+ return;
+
+ runOnUpnpContextQueue([w = weak(), map] {
+ if (auto upnpThis = w.lock()) {
+ if (upnpThis->observer_)
+ upnpThis->observer_->onMappingAdded(map.getIgd(), std::move(map));
+ }
+ });
+}
+
+void
+PUPnP::processRequestMappingFailure(const Mapping& map)
+{
+ CHECK_VALID_THREAD();
+
+ if (observer_ == nullptr)
+ return;
+
+ runOnUpnpContextQueue([w = weak(), map] {
+ if (auto upnpThis = w.lock()) {
+ JAMI_DBG("PUPnP: Failed to request mapping %s", map.toString().c_str());
+ if (upnpThis->observer_)
+ upnpThis->observer_->onMappingRequestFailed(map);
+ }
+ });
+}
+
+void
+PUPnP::processRemoveMapAction(const Mapping& map)
+{
+ CHECK_VALID_THREAD();
+
+ if (observer_ == nullptr)
+ return;
+
+ runOnUpnpContextQueue([map, obs = observer_] {
+ JAMI_DBG("PUPnP: Closed mapping %s", map.toString().c_str());
+ obs->onMappingRemoved(map.getIgd(), std::move(map));
+ });
+}
+
+const char*
+PUPnP::eventTypeToString(Upnp_EventType eventType)
+{
+ switch (eventType) {
+ case UPNP_CONTROL_ACTION_REQUEST:
+ return "UPNP_CONTROL_ACTION_REQUEST";
+ case UPNP_CONTROL_ACTION_COMPLETE:
+ return "UPNP_CONTROL_ACTION_COMPLETE";
+ case UPNP_CONTROL_GET_VAR_REQUEST:
+ return "UPNP_CONTROL_GET_VAR_REQUEST";
+ case UPNP_CONTROL_GET_VAR_COMPLETE:
+ return "UPNP_CONTROL_GET_VAR_COMPLETE";
+ case UPNP_DISCOVERY_ADVERTISEMENT_ALIVE:
+ return "UPNP_DISCOVERY_ADVERTISEMENT_ALIVE";
+ case UPNP_DISCOVERY_ADVERTISEMENT_BYEBYE:
+ return "UPNP_DISCOVERY_ADVERTISEMENT_BYEBYE";
+ case UPNP_DISCOVERY_SEARCH_RESULT:
+ return "UPNP_DISCOVERY_SEARCH_RESULT";
+ case UPNP_DISCOVERY_SEARCH_TIMEOUT:
+ return "UPNP_DISCOVERY_SEARCH_TIMEOUT";
+ case UPNP_EVENT_SUBSCRIPTION_REQUEST:
+ return "UPNP_EVENT_SUBSCRIPTION_REQUEST";
+ case UPNP_EVENT_RECEIVED:
+ return "UPNP_EVENT_RECEIVED";
+ case UPNP_EVENT_RENEWAL_COMPLETE:
+ return "UPNP_EVENT_RENEWAL_COMPLETE";
+ case UPNP_EVENT_SUBSCRIBE_COMPLETE:
+ return "UPNP_EVENT_SUBSCRIBE_COMPLETE";
+ case UPNP_EVENT_UNSUBSCRIBE_COMPLETE:
+ return "UPNP_EVENT_UNSUBSCRIBE_COMPLETE";
+ case UPNP_EVENT_AUTORENEWAL_FAILED:
+ return "UPNP_EVENT_AUTORENEWAL_FAILED";
+ case UPNP_EVENT_SUBSCRIPTION_EXPIRED:
+ return "UPNP_EVENT_SUBSCRIPTION_EXPIRED";
+ default:
+ return "Unknown UPNP Event";
+ }
+}
+
+int
+PUPnP::ctrlPtCallback(Upnp_EventType event_type, const void* event, void* user_data)
+{
+ auto pupnp = static_cast<PUPnP*>(user_data);
+
+ if (pupnp == nullptr) {
+ JAMI_WARN("PUPnP: Control point callback without PUPnP");
+ return UPNP_E_SUCCESS;
+ }
+
+ auto upnpThis = pupnp->weak().lock();
+
+ if (not upnpThis)
+ return UPNP_E_SUCCESS;
+
+ // Ignore if already unregistered.
+ if (not upnpThis->clientRegistered_)
+ return UPNP_E_SUCCESS;
+
+ // Process the callback.
+ return upnpThis->handleCtrlPtUPnPEvents(event_type, event);
+}
+
+PUPnP::CtrlAction
+PUPnP::getAction(const char* xmlNode)
+{
+ if (strstr(xmlNode, ACTION_ADD_PORT_MAPPING)) {
+ return CtrlAction::ADD_PORT_MAPPING;
+ } else if (strstr(xmlNode, ACTION_DELETE_PORT_MAPPING)) {
+ return CtrlAction::DELETE_PORT_MAPPING;
+ } else if (strstr(xmlNode, ACTION_GET_GENERIC_PORT_MAPPING_ENTRY)) {
+ return CtrlAction::GET_GENERIC_PORT_MAPPING_ENTRY;
+ } else if (strstr(xmlNode, ACTION_GET_STATUS_INFO)) {
+ return CtrlAction::GET_STATUS_INFO;
+ } else if (strstr(xmlNode, ACTION_GET_EXTERNAL_IP_ADDRESS)) {
+ return CtrlAction::GET_EXTERNAL_IP_ADDRESS;
+ } else {
+ return CtrlAction::UNKNOWN;
+ }
+}
+
+void
+PUPnP::processDiscoverySearchResult(const std::string& cpDeviceId,
+ const std::string& igdLocationUrl,
+ const IpAddr& dstAddr)
+{
+ CHECK_VALID_THREAD();
+
+ // Update host address if needed.
+ if (not hasValidHostAddress())
+ updateHostAddress();
+
+ // The host address must be valid to proceed.
+ if (not hasValidHostAddress()) {
+ JAMI_WARN("PUPnP: Local address is invalid. Ignore search result for now!");
+ return;
+ }
+
+ // Use the device ID and the URL as ID. This is necessary as some
+ // IGDs may have the same device ID but different URLs.
+
+ auto igdId = cpDeviceId + " url: " + igdLocationUrl;
+
+ if (not discoveredIgdList_.emplace(igdId).second) {
+ // JAMI_WARN("PUPnP: IGD [%s] already in the list", igdId.c_str());
+ return;
+ }
+
+ JAMI_DBG("PUPnP: Discovered a new IGD [%s]", igdId.c_str());
+
+ // NOTE: here, we check if the location given is related to the source address.
+ // If it's not the case, it's certainly a router plugged in the network, but not
+ // related to this network. So the given location will be unreachable and this
+ // will cause some timeout.
+
+ // Only check the IP address (ignore the port number).
+ dht::http::Url url(igdLocationUrl);
+ if (IpAddr(url.host).toString(false) != dstAddr.toString(false)) {
+ JAMI_DBG("PUPnP: Returned location %s does not match the source address %s",
+ IpAddr(url.host).toString(true, true).c_str(),
+ dstAddr.toString(true, true).c_str());
+ return;
+ }
+
+ // Run a separate thread to prevent blocking this thread
+ // if the IGD HTTP server is not responsive.
+ dht::ThreadPool::io().run([w = weak(), igdLocationUrl] {
+ if (auto upnpThis = w.lock()) {
+ upnpThis->downLoadIgdDescription(igdLocationUrl);
+ }
+ });
+}
+
+void
+PUPnP::downLoadIgdDescription(const std::string& locationUrl)
+{
+ IXML_Document* doc_container_ptr = nullptr;
+ int upnp_err = UpnpDownloadXmlDoc(locationUrl.c_str(), &doc_container_ptr);
+
+ if (upnp_err != UPNP_E_SUCCESS or not doc_container_ptr) {
+ JAMI_WARN("PUPnP: Error downloading device XML document from %s -> %s",
+ locationUrl.c_str(),
+ UpnpGetErrorMessage(upnp_err));
+ } else {
+ JAMI_DBG("PUPnP: Succeeded to download device XML document from %s", locationUrl.c_str());
+ runOnPUPnPQueue([w = weak(), url = locationUrl, doc_container_ptr] {
+ if (auto upnpThis = w.lock()) {
+ upnpThis->validateIgd(url, doc_container_ptr);
+ }
+ });
+ }
+}
+
+void
+PUPnP::processDiscoveryAdvertisementByebye(const std::string& cpDeviceId)
+{
+ CHECK_VALID_THREAD();
+
+ discoveredIgdList_.erase(cpDeviceId);
+
+ std::shared_ptr<IGD> igd;
+ {
+ std::lock_guard<std::mutex> lk(pupnpMutex_);
+ for (auto it = validIgdList_.begin(); it != validIgdList_.end();) {
+ if ((*it)->getUID() == cpDeviceId) {
+ igd = *it;
+ JAMI_DBG("PUPnP: Received [%s] for IGD [%s] %s. Will be removed.",
+ PUPnP::eventTypeToString(UPNP_DISCOVERY_ADVERTISEMENT_BYEBYE),
+ igd->getUID().c_str(),
+ igd->toString().c_str());
+ igd->setValid(false);
+ // Remove the IGD.
+ it = validIgdList_.erase(it);
+ break;
+ } else {
+ it++;
+ }
+ }
+ }
+
+ // Notify the listener.
+ if (observer_ and igd) {
+ observer_->onIgdUpdated(igd, UpnpIgdEvent::REMOVED);
+ }
+}
+
+void
+PUPnP::processDiscoverySubscriptionExpired(Upnp_EventType event_type, const std::string& eventSubUrl)
+{
+ CHECK_VALID_THREAD();
+
+ std::lock_guard<std::mutex> lk(pupnpMutex_);
+ for (auto& it : validIgdList_) {
+ if (auto igd = std::dynamic_pointer_cast<UPnPIGD>(it)) {
+ if (igd->getEventSubURL() == eventSubUrl) {
+ JAMI_DBG("PUPnP: Received [%s] event for IGD [%s] %s. Request a new subscribe.",
+ PUPnP::eventTypeToString(event_type),
+ igd->getUID().c_str(),
+ igd->toString().c_str());
+ UpnpSubscribeAsync(ctrlptHandle_,
+ eventSubUrl.c_str(),
+ UPNP_INFINITE,
+ subEventCallback,
+ this);
+ break;
+ }
+ }
+ }
+}
+
+int
+PUPnP::handleCtrlPtUPnPEvents(Upnp_EventType event_type, const void* event)
+{
+ switch (event_type) {
+ // "ALIVE" events are processed as "SEARCH RESULT". It might be usefull
+ // if "SEARCH RESULT" was missed.
+ case UPNP_DISCOVERY_ADVERTISEMENT_ALIVE:
+ case UPNP_DISCOVERY_SEARCH_RESULT: {
+ const UpnpDiscovery* d_event = (const UpnpDiscovery*) event;
+
+ // First check the error code.
+ auto upnp_status = UpnpDiscovery_get_ErrCode(d_event);
+ if (upnp_status != UPNP_E_SUCCESS) {
+ JAMI_ERR("PUPnP: UPNP discovery is in erroneous state: %s",
+ UpnpGetErrorMessage(upnp_status));
+ break;
+ }
+
+ // Parse the event's data.
+ std::string deviceId {UpnpDiscovery_get_DeviceID_cstr(d_event)};
+ std::string location {UpnpDiscovery_get_Location_cstr(d_event)};
+ IpAddr dstAddr(*(const pj_sockaddr*) (UpnpDiscovery_get_DestAddr(d_event)));
+ runOnPUPnPQueue([w = weak(),
+ deviceId = std::move(deviceId),
+ location = std::move(location),
+ dstAddr = std::move(dstAddr)] {
+ if (auto upnpThis = w.lock()) {
+ upnpThis->processDiscoverySearchResult(deviceId, location, dstAddr);
+ }
+ });
+ break;
+ }
+ case UPNP_DISCOVERY_ADVERTISEMENT_BYEBYE: {
+ const UpnpDiscovery* d_event = (const UpnpDiscovery*) event;
+
+ std::string deviceId(UpnpDiscovery_get_DeviceID_cstr(d_event));
+
+ // Process the response on the main thread.
+ runOnPUPnPQueue([w = weak(), deviceId = std::move(deviceId)] {
+ if (auto upnpThis = w.lock()) {
+ upnpThis->processDiscoveryAdvertisementByebye(deviceId);
+ }
+ });
+ break;
+ }
+ case UPNP_DISCOVERY_SEARCH_TIMEOUT: {
+ // Even if the discovery search is successful, it's normal to receive
+ // time-out events. This because we send search requests using various
+ // device types, which some of them may not return a response.
+ break;
+ }
+ case UPNP_EVENT_RECEIVED: {
+ // Nothing to do.
+ break;
+ }
+ // Treat failed autorenewal like an expired subscription.
+ case UPNP_EVENT_AUTORENEWAL_FAILED:
+ case UPNP_EVENT_SUBSCRIPTION_EXPIRED: // This event will occur only if autorenewal is disabled.
+ {
+ JAMI_WARN("PUPnP: Received Subscription Event %s", eventTypeToString(event_type));
+ const UpnpEventSubscribe* es_event = (const UpnpEventSubscribe*) event;
+ if (es_event == nullptr) {
+ JAMI_WARN("PUPnP: Received Subscription Event with null pointer");
+ break;
+ }
+ std::string publisherUrl(UpnpEventSubscribe_get_PublisherUrl_cstr(es_event));
+
+ // Process the response on the main thread.
+ runOnPUPnPQueue([w = weak(), event_type, publisherUrl = std::move(publisherUrl)] {
+ if (auto upnpThis = w.lock()) {
+ upnpThis->processDiscoverySubscriptionExpired(event_type, publisherUrl);
+ }
+ });
+ break;
+ }
+ case UPNP_EVENT_SUBSCRIBE_COMPLETE:
+ case UPNP_EVENT_UNSUBSCRIBE_COMPLETE: {
+ UpnpEventSubscribe* es_event = (UpnpEventSubscribe*) event;
+ if (es_event == nullptr) {
+ JAMI_WARN("PUPnP: Received Subscription Event with null pointer");
+ } else {
+ UpnpEventSubscribe_delete(es_event);
+ }
+ break;
+ }
+ case UPNP_CONTROL_ACTION_COMPLETE: {
+ const UpnpActionComplete* a_event = (const UpnpActionComplete*) event;
+ if (a_event == nullptr) {
+ JAMI_WARN("PUPnP: Received Action Complete Event with null pointer");
+ break;
+ }
+ auto res = UpnpActionComplete_get_ErrCode(a_event);
+ if (res != UPNP_E_SUCCESS and res != UPNP_E_TIMEDOUT) {
+ auto err = UpnpActionComplete_get_ErrCode(a_event);
+ JAMI_WARN("PUPnP: Received Action Complete error %i %s", err, UpnpGetErrorMessage(err));
+ } else {
+ auto actionRequest = UpnpActionComplete_get_ActionRequest(a_event);
+ // Abort if there is no action to process.
+ if (actionRequest == nullptr) {
+ JAMI_WARN("PUPnP: Can't get the Action Request data from the event");
+ break;
+ }
+
+ auto actionResult = UpnpActionComplete_get_ActionResult(a_event);
+ if (actionResult != nullptr) {
+ ixmlDocument_free(actionResult);
+ } else {
+ JAMI_WARN("PUPnP: Action Result document not found");
+ }
+ }
+ break;
+ }
+ default: {
+ JAMI_WARN("PUPnP: Unhandled Control Point event");
+ break;
+ }
+ }
+
+ return UPNP_E_SUCCESS;
+}
+
+int
+PUPnP::subEventCallback(Upnp_EventType event_type, const void* event, void* user_data)
+{
+ if (auto pupnp = static_cast<PUPnP*>(user_data))
+ return pupnp->handleSubscriptionUPnPEvent(event_type, event);
+ JAMI_WARN("PUPnP: Subscription callback without service Id string");
+ return 0;
+}
+
+int
+PUPnP::handleSubscriptionUPnPEvent(Upnp_EventType, const void* event)
+{
+ UpnpEventSubscribe* es_event = static_cast<UpnpEventSubscribe*>(const_cast<void*>(event));
+
+ if (es_event == nullptr) {
+ JAMI_ERR("PUPnP: Unexpected null pointer!");
+ return UPNP_E_INVALID_ARGUMENT;
+ }
+ std::string publisherUrl(UpnpEventSubscribe_get_PublisherUrl_cstr(es_event));
+ int upnp_err = UpnpEventSubscribe_get_ErrCode(es_event);
+ if (upnp_err != UPNP_E_SUCCESS) {
+ JAMI_WARN("PUPnP: Subscription error %s from %s",
+ UpnpGetErrorMessage(upnp_err),
+ publisherUrl.c_str());
+ return upnp_err;
+ }
+
+ return UPNP_E_SUCCESS;
+}
+
+std::unique_ptr<UPnPIGD>
+PUPnP::parseIgd(IXML_Document* doc, std::string locationUrl)
+{
+ if (not(doc and locationUrl.c_str()))
+ return nullptr;
+
+ // Check the UDN to see if its already in our device list.
+ std::string UDN(getFirstDocItem(doc, "UDN"));
+ if (UDN.empty()) {
+ JAMI_WARN("PUPnP: could not find UDN in description document of device");
+ return nullptr;
+ } else {
+ std::lock_guard<std::mutex> lk(pupnpMutex_);
+ for (auto& it : validIgdList_) {
+ if (it->getUID() == UDN) {
+ // We already have this device in our list.
+ return nullptr;
+ }
+ }
+ }
+
+ JAMI_DBG("PUPnP: Found new device [%s]", UDN.c_str());
+
+ std::unique_ptr<UPnPIGD> new_igd;
+ int upnp_err;
+
+ // Get friendly name.
+ std::string friendlyName(getFirstDocItem(doc, "friendlyName"));
+
+ // Get base URL.
+ std::string baseURL(getFirstDocItem(doc, "URLBase"));
+ if (baseURL.empty())
+ baseURL = locationUrl;
+
+ // Get list of services defined by serviceType.
+ std::unique_ptr<IXML_NodeList, decltype(ixmlNodeList_free)&> serviceList(nullptr,
+ ixmlNodeList_free);
+ serviceList.reset(ixmlDocument_getElementsByTagName(doc, "serviceType"));
+ unsigned long list_length = ixmlNodeList_length(serviceList.get());
+
+ // Go through the "serviceType" nodes until we find the the correct service type.
+ for (unsigned long node_idx = 0; node_idx < list_length; node_idx++) {
+ IXML_Node* serviceType_node = ixmlNodeList_item(serviceList.get(), node_idx);
+ std::string serviceType(getElementText(serviceType_node));
+
+ // Only check serviceType of WANIPConnection or WANPPPConnection.
+ if (serviceType != UPNP_WANIP_SERVICE
+ && serviceType != UPNP_WANPPP_SERVICE) {
+ // IGD is not WANIP or WANPPP service. Going to next node.
+ continue;
+ }
+
+ // Get parent node.
+ IXML_Node* service_node = ixmlNode_getParentNode(serviceType_node);
+ if (not service_node) {
+ // IGD serviceType has no parent node. Going to next node.
+ continue;
+ }
+
+ // Perform sanity check. The parent node should be called "service".
+ if (strcmp(ixmlNode_getNodeName(service_node), "service") != 0) {
+ // IGD "serviceType" parent node is not called "service". Going to next node.
+ continue;
+ }
+
+ // Get serviceId.
+ IXML_Element* service_element = (IXML_Element*) service_node;
+ std::string serviceId(getFirstElementItem(service_element, "serviceId"));
+ if (serviceId.empty()) {
+ // IGD "serviceId" is empty. Going to next node.
+ continue;
+ }
+
+ // Get the relative controlURL and turn it into absolute address using the URLBase.
+ std::string controlURL(getFirstElementItem(service_element, "controlURL"));
+ if (controlURL.empty()) {
+ // IGD control URL is empty. Going to next node.
+ continue;
+ }
+
+ char* absolute_control_url = nullptr;
+ upnp_err = UpnpResolveURL2(baseURL.c_str(), controlURL.c_str(), &absolute_control_url);
+ if (upnp_err == UPNP_E_SUCCESS)
+ controlURL = absolute_control_url;
+ else
+ JAMI_WARN("PUPnP: Error resolving absolute controlURL -> %s",
+ UpnpGetErrorMessage(upnp_err));
+
+ std::free(absolute_control_url);
+
+ // Get the relative eventSubURL and turn it into absolute address using the URLBase.
+ std::string eventSubURL(getFirstElementItem(service_element, "eventSubURL"));
+ if (eventSubURL.empty()) {
+ JAMI_WARN("PUPnP: IGD event sub URL is empty. Going to next node");
+ continue;
+ }
+
+ char* absolute_event_sub_url = nullptr;
+ upnp_err = UpnpResolveURL2(baseURL.c_str(), eventSubURL.c_str(), &absolute_event_sub_url);
+ if (upnp_err == UPNP_E_SUCCESS)
+ eventSubURL = absolute_event_sub_url;
+ else
+ JAMI_WARN("PUPnP: Error resolving absolute eventSubURL -> %s",
+ UpnpGetErrorMessage(upnp_err));
+
+ std::free(absolute_event_sub_url);
+
+ new_igd.reset(new UPnPIGD(std::move(UDN),
+ std::move(baseURL),
+ std::move(friendlyName),
+ std::move(serviceType),
+ std::move(serviceId),
+ std::move(locationUrl),
+ std::move(controlURL),
+ std::move(eventSubURL)));
+
+ return new_igd;
+ }
+
+ return nullptr;
+}
+
+bool
+PUPnP::actionIsIgdConnected(const UPnPIGD& igd)
+{
+ if (not clientRegistered_)
+ return false;
+
+ // Set action name.
+ IXML_Document* action_container_ptr = UpnpMakeAction("GetStatusInfo",
+ igd.getServiceType().c_str(),
+ 0,
+ nullptr);
+ if (not action_container_ptr) {
+ JAMI_WARN("PUPnP: Failed to make GetStatusInfo action");
+ return false;
+ }
+ XMLDocument action(action_container_ptr, ixmlDocument_free); // Action pointer.
+
+ IXML_Document* response_container_ptr = nullptr;
+ int upnp_err = UpnpSendAction(ctrlptHandle_,
+ igd.getControlURL().c_str(),
+ igd.getServiceType().c_str(),
+ nullptr,
+ action.get(),
+ &response_container_ptr);
+ if (not response_container_ptr or upnp_err != UPNP_E_SUCCESS) {
+ JAMI_WARN("PUPnP: Failed to send GetStatusInfo action -> %s", UpnpGetErrorMessage(upnp_err));
+ return false;
+ }
+ XMLDocument response(response_container_ptr, ixmlDocument_free);
+
+ if (errorOnResponse(response.get())) {
+ JAMI_WARN("PUPnP: Failed to get GetStatusInfo from %s -> %d: %s",
+ igd.getServiceType().c_str(),
+ upnp_err,
+ UpnpGetErrorMessage(upnp_err));
+ return false;
+ }
+
+ // Parse response.
+ auto status = getFirstDocItem(response.get(), "NewConnectionStatus");
+ return status == "Connected";
+}
+
+IpAddr
+PUPnP::actionGetExternalIP(const UPnPIGD& igd)
+{
+ if (not clientRegistered_)
+ return {};
+
+ // Action and response pointers.
+ std::unique_ptr<IXML_Document, decltype(ixmlDocument_free)&>
+ action(nullptr, ixmlDocument_free); // Action pointer.
+ std::unique_ptr<IXML_Document, decltype(ixmlDocument_free)&>
+ response(nullptr, ixmlDocument_free); // Response pointer.
+
+ // Set action name.
+ static constexpr const char* action_name {"GetExternalIPAddress"};
+
+ IXML_Document* action_container_ptr = nullptr;
+ action_container_ptr = UpnpMakeAction(action_name, igd.getServiceType().c_str(), 0, nullptr);
+ action.reset(action_container_ptr);
+
+ if (not action) {
+ JAMI_WARN("PUPnP: Failed to make GetExternalIPAddress action");
+ return {};
+ }
+
+ IXML_Document* response_container_ptr = nullptr;
+ int upnp_err = UpnpSendAction(ctrlptHandle_,
+ igd.getControlURL().c_str(),
+ igd.getServiceType().c_str(),
+ nullptr,
+ action.get(),
+ &response_container_ptr);
+ response.reset(response_container_ptr);
+
+ if (not response or upnp_err != UPNP_E_SUCCESS) {
+ JAMI_WARN("PUPnP: Failed to send GetExternalIPAddress action -> %s",
+ UpnpGetErrorMessage(upnp_err));
+ return {};
+ }
+
+ if (errorOnResponse(response.get())) {
+ JAMI_WARN("PUPnP: Failed to get GetExternalIPAddress from %s -> %d: %s",
+ igd.getServiceType().c_str(),
+ upnp_err,
+ UpnpGetErrorMessage(upnp_err));
+ return {};
+ }
+
+ return {getFirstDocItem(response.get(), "NewExternalIPAddress")};
+}
+
+std::map<Mapping::key_t, Mapping>
+PUPnP::getMappingsListByDescr(const std::shared_ptr<IGD>& igd, const std::string& description) const
+{
+ auto upnpIgd = std::dynamic_pointer_cast<UPnPIGD>(igd);
+ assert(upnpIgd);
+
+ std::map<Mapping::key_t, Mapping> mapList;
+
+ if (not clientRegistered_ or not upnpIgd->isValid() or not upnpIgd->getLocalIp())
+ return mapList;
+
+ // Set action name.
+ static constexpr const char* action_name {"GetGenericPortMappingEntry"};
+
+ for (int entry_idx = 0;; entry_idx++) {
+ std::unique_ptr<IXML_Document, decltype(ixmlDocument_free)&>
+ action(nullptr, ixmlDocument_free); // Action pointer.
+ IXML_Document* action_container_ptr = nullptr;
+
+ std::unique_ptr<IXML_Document, decltype(ixmlDocument_free)&>
+ response(nullptr, ixmlDocument_free); // Response pointer.
+ IXML_Document* response_container_ptr = nullptr;
+
+ UpnpAddToAction(&action_container_ptr,
+ action_name,
+ upnpIgd->getServiceType().c_str(),
+ "NewPortMappingIndex",
+ std::to_string(entry_idx).c_str());
+ action.reset(action_container_ptr);
+
+ if (not action) {
+ JAMI_WARN("PUPnP: Failed to add NewPortMappingIndex action");
+ break;
+ }
+
+ int upnp_err = UpnpSendAction(ctrlptHandle_,
+ upnpIgd->getControlURL().c_str(),
+ upnpIgd->getServiceType().c_str(),
+ nullptr,
+ action.get(),
+ &response_container_ptr);
+ response.reset(response_container_ptr);
+
+ if (not response) {
+ // No existing mapping. Abort silently.
+ break;
+ }
+
+ if (upnp_err != UPNP_E_SUCCESS) {
+ JAMI_ERR("PUPnP: GetGenericPortMappingEntry returned with error: %i", upnp_err);
+ break;
+ }
+
+ // Check error code.
+ auto errorCode = getFirstDocItem(response.get(), "errorCode");
+ if (not errorCode.empty()) {
+ auto error = to_int<int>(errorCode);
+ if (error == ARRAY_IDX_INVALID or error == CONFLICT_IN_MAPPING) {
+ // No more port mapping entries in the response.
+ JAMI_DBG("PUPnP: No more mappings (found a total of %i mappings", entry_idx);
+ break;
+ } else {
+ auto errorDescription = getFirstDocItem(response.get(), "errorDescription");
+ JAMI_ERROR("PUPnP: GetGenericPortMappingEntry returned with error: {:s}: {:s}",
+ errorCode,
+ errorDescription);
+ break;
+ }
+ }
+
+ // Parse the response.
+ auto desc_actual = getFirstDocItem(response.get(), "NewPortMappingDescription");
+ auto client_ip = getFirstDocItem(response.get(), "NewInternalClient");
+
+ if (client_ip != getHostAddress().toString()) {
+ // Silently ignore un-matching addresses.
+ continue;
+ }
+
+ if (desc_actual.find(description) == std::string::npos)
+ continue;
+
+ auto port_internal = getFirstDocItem(response.get(), "NewInternalPort");
+ auto port_external = getFirstDocItem(response.get(), "NewExternalPort");
+ std::string transport(getFirstDocItem(response.get(), "NewProtocol"));
+
+ if (port_internal.empty() || port_external.empty() || transport.empty()) {
+ JAMI_ERR("PUPnP: GetGenericPortMappingEntry returned an invalid entry at index %i",
+ entry_idx);
+ continue;
+ }
+
+ std::transform(transport.begin(), transport.end(), transport.begin(), ::toupper);
+ PortType type = transport.find("TCP") != std::string::npos ? PortType::TCP : PortType::UDP;
+ auto ePort = to_int<uint16_t>(port_external);
+ auto iPort = to_int<uint16_t>(port_internal);
+
+ Mapping map(type, ePort, iPort);
+ map.setIgd(igd);
+
+ mapList.emplace(map.getMapKey(), std::move(map));
+ }
+
+ JAMI_DEBUG("PUPnP: Found {:d} allocated mappings on IGD {:s}",
+ mapList.size(),
+ upnpIgd->toString());
+
+ return mapList;
+}
+
+void
+PUPnP::deleteMappingsByDescription(const std::shared_ptr<IGD>& igd, const std::string& description)
+{
+ if (not(clientRegistered_ and igd->getLocalIp()))
+ return;
+
+ JAMI_DBG("PUPnP: Remove all mappings (if any) on IGD %s matching descr prefix %s",
+ igd->toString().c_str(),
+ Mapping::UPNP_MAPPING_DESCRIPTION_PREFIX);
+
+ auto mapList = getMappingsListByDescr(igd, description);
+
+ for (auto const& [_, map] : mapList) {
+ requestMappingRemove(map);
+ }
+}
+
+bool
+PUPnP::actionAddPortMapping(const Mapping& mapping)
+{
+ CHECK_VALID_THREAD();
+
+ if (not clientRegistered_)
+ return false;
+
+ auto igdIn = std::dynamic_pointer_cast<UPnPIGD>(mapping.getIgd());
+ if (not igdIn)
+ return false;
+
+ // The requested IGD must be present in the list of local valid IGDs.
+ auto igd = findMatchingIgd(igdIn->getControlURL());
+
+ if (not igd or not igd->isValid())
+ return false;
+
+ // Action and response pointers.
+ XMLDocument action(nullptr, ixmlDocument_free);
+ IXML_Document* action_container_ptr = nullptr;
+ XMLDocument response(nullptr, ixmlDocument_free);
+ IXML_Document* response_container_ptr = nullptr;
+
+ // Set action sequence.
+ UpnpAddToAction(&action_container_ptr,
+ ACTION_ADD_PORT_MAPPING,
+ igd->getServiceType().c_str(),
+ "NewRemoteHost",
+ "");
+ UpnpAddToAction(&action_container_ptr,
+ ACTION_ADD_PORT_MAPPING,
+ igd->getServiceType().c_str(),
+ "NewExternalPort",
+ mapping.getExternalPortStr().c_str());
+ UpnpAddToAction(&action_container_ptr,
+ ACTION_ADD_PORT_MAPPING,
+ igd->getServiceType().c_str(),
+ "NewProtocol",
+ mapping.getTypeStr());
+ UpnpAddToAction(&action_container_ptr,
+ ACTION_ADD_PORT_MAPPING,
+ igd->getServiceType().c_str(),
+ "NewInternalPort",
+ mapping.getInternalPortStr().c_str());
+ UpnpAddToAction(&action_container_ptr,
+ ACTION_ADD_PORT_MAPPING,
+ igd->getServiceType().c_str(),
+ "NewInternalClient",
+ getHostAddress().toString().c_str());
+ UpnpAddToAction(&action_container_ptr,
+ ACTION_ADD_PORT_MAPPING,
+ igd->getServiceType().c_str(),
+ "NewEnabled",
+ "1");
+ UpnpAddToAction(&action_container_ptr,
+ ACTION_ADD_PORT_MAPPING,
+ igd->getServiceType().c_str(),
+ "NewPortMappingDescription",
+ mapping.toString().c_str());
+ UpnpAddToAction(&action_container_ptr,
+ ACTION_ADD_PORT_MAPPING,
+ igd->getServiceType().c_str(),
+ "NewLeaseDuration",
+ "0");
+
+ action.reset(action_container_ptr);
+
+ int upnp_err = UpnpSendAction(ctrlptHandle_,
+ igd->getControlURL().c_str(),
+ igd->getServiceType().c_str(),
+ nullptr,
+ action.get(),
+ &response_container_ptr);
+ response.reset(response_container_ptr);
+
+ bool success = true;
+
+ if (upnp_err != UPNP_E_SUCCESS) {
+ JAMI_WARN("PUPnP: Failed to send action %s for mapping %s. %d: %s",
+ ACTION_ADD_PORT_MAPPING,
+ mapping.toString().c_str(),
+ upnp_err,
+ UpnpGetErrorMessage(upnp_err));
+ JAMI_WARN("PUPnP: IGD ctrlUrl %s", igd->getControlURL().c_str());
+ JAMI_WARN("PUPnP: IGD service type %s", igd->getServiceType().c_str());
+
+ success = false;
+ }
+
+ // Check if an error has occurred.
+ auto errorCode = getFirstDocItem(response.get(), "errorCode");
+ if (not errorCode.empty()) {
+ success = false;
+ // Try to get the error description.
+ std::string errorDescription;
+ if (response) {
+ errorDescription = getFirstDocItem(response.get(), "errorDescription");
+ }
+
+ JAMI_WARNING("PUPnP: {:s} returned with error: {:s} {:s}",
+ ACTION_ADD_PORT_MAPPING,
+ errorCode,
+ errorDescription);
+ }
+ return success;
+}
+
+bool
+PUPnP::actionDeletePortMapping(const Mapping& mapping)
+{
+ CHECK_VALID_THREAD();
+
+ if (not clientRegistered_)
+ return false;
+
+ auto igdIn = std::dynamic_pointer_cast<UPnPIGD>(mapping.getIgd());
+ if (not igdIn)
+ return false;
+
+ // The requested IGD must be present in the list of local valid IGDs.
+ auto igd = findMatchingIgd(igdIn->getControlURL());
+
+ if (not igd or not igd->isValid())
+ return false;
+
+ // Action and response pointers.
+ XMLDocument action(nullptr, ixmlDocument_free);
+ IXML_Document* action_container_ptr = nullptr;
+ XMLDocument response(nullptr, ixmlDocument_free);
+ IXML_Document* response_container_ptr = nullptr;
+
+ // Set action sequence.
+ UpnpAddToAction(&action_container_ptr,
+ ACTION_DELETE_PORT_MAPPING,
+ igd->getServiceType().c_str(),
+ "NewRemoteHost",
+ "");
+ UpnpAddToAction(&action_container_ptr,
+ ACTION_DELETE_PORT_MAPPING,
+ igd->getServiceType().c_str(),
+ "NewExternalPort",
+ mapping.getExternalPortStr().c_str());
+ UpnpAddToAction(&action_container_ptr,
+ ACTION_DELETE_PORT_MAPPING,
+ igd->getServiceType().c_str(),
+ "NewProtocol",
+ mapping.getTypeStr());
+
+ action.reset(action_container_ptr);
+
+ int upnp_err = UpnpSendAction(ctrlptHandle_,
+ igd->getControlURL().c_str(),
+ igd->getServiceType().c_str(),
+ nullptr,
+ action.get(),
+ &response_container_ptr);
+ response.reset(response_container_ptr);
+
+ bool success = true;
+
+ if (upnp_err != UPNP_E_SUCCESS) {
+ JAMI_WARN("PUPnP: Failed to send action %s for mapping from %s. %d: %s",
+ ACTION_DELETE_PORT_MAPPING,
+ mapping.toString().c_str(),
+ upnp_err,
+ UpnpGetErrorMessage(upnp_err));
+ JAMI_WARN("PUPnP: IGD ctrlUrl %s", igd->getControlURL().c_str());
+ JAMI_WARN("PUPnP: IGD service type %s", igd->getServiceType().c_str());
+
+ success = false;
+ }
+
+ if (not response) {
+ JAMI_WARN("PUPnP: Failed to get response for %s", ACTION_DELETE_PORT_MAPPING);
+ success = false;
+ }
+
+ // Check if there is an error code.
+ auto errorCode = getFirstDocItem(response.get(), "errorCode");
+ if (not errorCode.empty()) {
+ auto errorDescription = getFirstDocItem(response.get(), "errorDescription");
+ JAMI_WARNING("PUPnP: {:s} returned with error: {:s}: {:s}",
+ ACTION_DELETE_PORT_MAPPING,
+ errorCode,
+ errorDescription);
+ success = false;
+ }
+
+ return success;
+}
+
+} // namespace upnp
+} // namespace jami
diff --git a/src/upnp/protocol/pupnp/pupnp.h b/src/upnp/protocol/pupnp/pupnp.h
new file mode 100644
index 0000000..a77f30f
--- /dev/null
+++ b/src/upnp/protocol/pupnp/pupnp.h
@@ -0,0 +1,271 @@
+/*
+ * Copyright (C) 2004-2023 Savoir-faire Linux Inc.
+ *
+ * Author: Stepan Salenikovich <stepan.salenikovich@savoirfairelinux.com>
+ * Author: Eden Abitbol <eden.abitbol@savoirfairelinux.com>
+ * Author: Mohamed Chibani <mohamed.chibani@savoirfairelinux.com>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#pragma once
+
+#ifdef _WIN32
+#define UPNP_USE_MSVCPP
+#define UPNP_STATIC_LIB
+#endif
+
+#include "../upnp_protocol.h"
+#include "../igd.h"
+#include "upnp_igd.h"
+
+#include "logger.h"
+#include "connectivity/ip_utils.h"
+#include "noncopyable.h"
+#include "compiler_intrinsics.h"
+
+#include <upnp/upnp.h>
+#include <upnp/upnptools.h>
+
+#ifdef _WIN32
+#include <windows.h>
+#include <wincrypt.h>
+#endif
+
+#include <atomic>
+#include <thread>
+#include <list>
+#include <map>
+#include <set>
+#include <string>
+#include <memory>
+#include <future>
+
+namespace jami {
+class IpAddr;
+}
+
+namespace jami {
+namespace upnp {
+
+class PUPnP : public UPnPProtocol
+{
+public:
+ using XMLDocument = std::unique_ptr<IXML_Document, decltype(ixmlDocument_free)&>;
+
+ enum class CtrlAction {
+ UNKNOWN,
+ ADD_PORT_MAPPING,
+ DELETE_PORT_MAPPING,
+ GET_GENERIC_PORT_MAPPING_ENTRY,
+ GET_STATUS_INFO,
+ GET_EXTERNAL_IP_ADDRESS
+ };
+
+ PUPnP();
+ ~PUPnP();
+
+ // Set the observer
+ void setObserver(UpnpMappingObserver* obs) override;
+
+ // Returns the protocol type.
+ NatProtocolType getProtocol() const override { return NatProtocolType::PUPNP; }
+
+ // Get protocol type as string.
+ char const* getProtocolName() const override { return "PUPNP"; }
+
+ // Notifies a change in network.
+ void clearIgds() override;
+
+ // Sends out async search for IGD.
+ void searchForIgd() override;
+
+ // Get the IGD list.
+ std::list<std::shared_ptr<IGD>> getIgdList() const override;
+
+ // Return true if the it's fully setup.
+ bool isReady() const override;
+
+ // Get from the IGD the list of already allocated mappings if any.
+ std::map<Mapping::key_t, Mapping> getMappingsListByDescr(
+ const std::shared_ptr<IGD>& igd, const std::string& descr) const override;
+
+ // Request a new mapping.
+ void requestMappingAdd(const Mapping& mapping) override;
+
+ // Renew an allocated mapping.
+ // Not implemented. Currently, UPNP allocations do not have expiration time.
+ void requestMappingRenew([[maybe_unused]] const Mapping& mapping) override { assert(false); };
+
+ // Removes a mapping.
+ void requestMappingRemove(const Mapping& igdMapping) override;
+
+ // Get the host (local) address.
+ const IpAddr getHostAddress() const override;
+
+ // Terminate the instance.
+ void terminate() override;
+
+private:
+ NON_COPYABLE(PUPnP);
+
+ // Helpers to run tasks on PUPNP private execution queue.
+ ScheduledExecutor* getPUPnPScheduler() { return &pupnpScheduler_; }
+ template<typename Callback>
+ void runOnPUPnPQueue(Callback&& cb)
+ {
+ pupnpScheduler_.run([cb = std::forward<Callback>(cb)]() mutable { cb(); });
+ }
+
+ // Helper to run tasks on UPNP context execution queue.
+ ScheduledExecutor* getUpnContextScheduler() { return UpnpThreadUtil::getScheduler(); }
+
+ void terminate(std::condition_variable& cv);
+
+ // Init lib-upnp
+ void initUpnpLib();
+
+ // Return true if running.
+ bool isRunning() const;
+
+ // Register the client
+ void registerClient();
+
+ // Start search for UPNP devices
+ void searchForDevices();
+
+ // Return true if it has at least one valid IGD.
+ bool hasValidIgd() const;
+
+ // Update the host (local) address.
+ void updateHostAddress();
+
+ // Check the host (local) address.
+ // Returns true if the address is valid.
+ bool hasValidHostAddress();
+
+ // Delete mappings matching the description
+ void deleteMappingsByDescription(const std::shared_ptr<IGD>& igd,
+ const std::string& description);
+
+ // Search for the IGD in the local list of known IGDs.
+ std::shared_ptr<UPnPIGD> findMatchingIgd(const std::string& ctrlURL) const;
+
+ // Process the reception of an add mapping action answer.
+ void processAddMapAction(const Mapping& map);
+
+ // Process the a mapping request failure.
+ void processRequestMappingFailure(const Mapping& map);
+
+ // Process the reception of a remove mapping action answer.
+ void processRemoveMapAction(const Mapping& map);
+
+ // Increment IGD errors counter.
+ void incrementErrorsCounter(const std::shared_ptr<IGD>& igd);
+
+ // Download XML document.
+ void downLoadIgdDescription(const std::string& url);
+
+ // Validate IGD from the xml document received from the router.
+ bool validateIgd(const std::string& location, IXML_Document* doc_container_ptr);
+
+ // Returns control point action callback based on xml node.
+ static CtrlAction getAction(const char* xmlNode);
+
+ // Control point callback.
+ static int ctrlPtCallback(Upnp_EventType event_type, const void* event, void* user_data);
+#if UPNP_VERSION < 10800
+ static inline int ctrlPtCallback(Upnp_EventType event_type, void* event, void* user_data)
+ {
+ return ctrlPtCallback(event_type, (const void*) event, user_data);
+ };
+#endif
+ // Process IGD responses.
+ void processDiscoverySearchResult(const std::string& deviceId,
+ const std::string& igdUrl,
+ const IpAddr& dstAddr);
+ void processDiscoveryAdvertisementByebye(const std::string& deviceId);
+ void processDiscoverySubscriptionExpired(Upnp_EventType event_type,
+ const std::string& eventSubUrl);
+
+ // Callback event handler function for the UPnP client (control point).
+ int handleCtrlPtUPnPEvents(Upnp_EventType event_type, const void* event);
+
+ // Subscription event callback.
+ static int subEventCallback(Upnp_EventType event_type, const void* event, void* user_data);
+#if UPNP_VERSION < 10800
+ static inline int subEventCallback(Upnp_EventType event_type, void* event, void* user_data)
+ {
+ return subEventCallback(event_type, (const void*) event, user_data);
+ };
+#endif
+
+ // Callback subscription event function for handling subscription request.
+ int handleSubscriptionUPnPEvent(Upnp_EventType event_type, const void* event);
+
+ // Parses the IGD candidate.
+ std::unique_ptr<UPnPIGD> parseIgd(IXML_Document* doc, std::string locationUrl);
+
+ // These functions directly create UPnP actions and make synchronous UPnP
+ // control point calls. Must be run on the PUPNP internal execution queue.
+ bool actionIsIgdConnected(const UPnPIGD& igd);
+ IpAddr actionGetExternalIP(const UPnPIGD& igd);
+ bool actionAddPortMapping(const Mapping& mapping);
+ bool actionDeletePortMapping(const Mapping& mapping);
+
+ // Event type to string
+ static const char* eventTypeToString(Upnp_EventType eventType);
+
+ std::weak_ptr<PUPnP> weak() { return std::static_pointer_cast<PUPnP>(shared_from_this()); }
+
+ // Execution queue to run lib upnp actions
+ ScheduledExecutor pupnpScheduler_ {"pupnp"};
+
+ // Initialization status.
+ std::atomic_bool initialized_ {false};
+ // Client registration status.
+ std::atomic_bool clientRegistered_ {false};
+
+ std::shared_ptr<Task> searchForIgdTimer_ {};
+ unsigned int igdSearchCounter_ {0};
+
+ // List of discovered IGDs.
+ std::set<std::string> discoveredIgdList_;
+
+ // Control point handle.
+ UpnpClient_Handle ctrlptHandle_ {-1};
+
+ // Observer to report the results.
+ UpnpMappingObserver* observer_ {nullptr};
+
+ // List of valid IGDs.
+ std::list<std::shared_ptr<IGD>> validIgdList_;
+
+ // Current host address.
+ IpAddr hostAddress_ {};
+
+ // Calls from other threads that does not need synchronous access are
+ // rescheduled on the UPNP private queue. This will avoid the need to
+ // protect most of the data members of this class.
+ // For some internal members (namely the validIgdList and the hostAddress)
+ // that need to be synchronously accessed, are protected by this mutex.
+ mutable std::mutex pupnpMutex_;
+
+ // Shutdown synchronization
+ bool shutdownComplete_ {false};
+};
+
+} // namespace upnp
+} // namespace jami
diff --git a/src/upnp/protocol/pupnp/upnp_igd.cpp b/src/upnp/protocol/pupnp/upnp_igd.cpp
new file mode 100644
index 0000000..2f8a332
--- /dev/null
+++ b/src/upnp/protocol/pupnp/upnp_igd.cpp
@@ -0,0 +1,74 @@
+/*
+ * Copyright (C) 2004-2023 Savoir-faire Linux Inc.
+ *
+ * Author: Stepan Salenikovich <stepan.salenikovich@savoirfairelinux.com>
+ * Author: Eden Abitbol <eden.abitbol@savoirfairelinux.com>
+ * Author: Mohamed Chibani <mohamed.chibani@savoirfairelinux.com>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#include "upnp_igd.h"
+
+namespace jami {
+namespace upnp {
+
+UPnPIGD::UPnPIGD(std::string&& UDN,
+ std::string&& baseURL,
+ std::string&& friendlyName,
+ std::string&& serviceType,
+ std::string&& serviceId,
+ std::string&& locationURL,
+ std::string&& controlURL,
+ std::string&& eventSubURL,
+ IpAddr&& localIp,
+ IpAddr&& publicIp)
+ : IGD(NatProtocolType::PUPNP)
+{
+ uid_ = std::move(UDN);
+ baseURL_ = std::move(baseURL);
+ friendlyName_ = std::move(friendlyName);
+ serviceType_ = std::move(serviceType);
+ serviceId_ = std::move(serviceId);
+ locationURL_ = std::move(locationURL);
+ controlURL_ = std::move(controlURL);
+ eventSubURL_ = std::move(eventSubURL);
+ localIp_ = std::move(localIp);
+ publicIp_ = std::move(publicIp);
+}
+
+bool
+UPnPIGD::operator==(IGD& other) const
+{
+ return localIp_ == other.getLocalIp() and publicIp_ == other.getPublicIp();
+}
+
+bool
+UPnPIGD::operator==(UPnPIGD& other) const
+{
+ if (localIp_ and publicIp_) {
+ if (localIp_ != other.localIp_ or publicIp_ != other.publicIp_) {
+ return false;
+ }
+ }
+
+ return uid_ == other.uid_ and baseURL_ == other.baseURL_
+ and friendlyName_ == other.friendlyName_ and serviceType_ == other.serviceType_
+ and serviceId_ == other.serviceId_ and locationURL_ == other.locationURL_
+ and controlURL_ == other.controlURL_ and eventSubURL_ == other.eventSubURL_;
+}
+
+} // namespace upnp
+} // namespace jami
\ No newline at end of file
diff --git a/src/upnp/protocol/pupnp/upnp_igd.h b/src/upnp/protocol/pupnp/upnp_igd.h
new file mode 100644
index 0000000..2ad213b
--- /dev/null
+++ b/src/upnp/protocol/pupnp/upnp_igd.h
@@ -0,0 +1,106 @@
+/*
+ * Copyright (C) 2004-2023 Savoir-faire Linux Inc.
+ *
+ * Author: Stepan Salenikovich <stepan.salenikovich@savoirfairelinux.com>
+ * Author: Eden Abitbol <eden.abitbol@savoirfairelinux.com>
+ * Author: Mohamed Chibani <mohamed.chibani@savoirfairelinux.com>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#pragma once
+
+#include "connectivity/upnp/protocol/igd.h"
+
+#include "noncopyable.h"
+#include "connectivity/ip_utils.h"
+
+#include <map>
+#include <string>
+#include <chrono>
+#include <functional>
+
+namespace jami {
+namespace upnp {
+
+class UPnPIGD : public IGD
+{
+public:
+ UPnPIGD(std::string&& UDN,
+ std::string&& baseURL,
+ std::string&& friendlyName,
+ std::string&& serviceType,
+ std::string&& serviceId,
+ std::string&& locationURL,
+ std::string&& controlURL,
+ std::string&& eventSubURL,
+ IpAddr&& localIp = {},
+ IpAddr&& publicIp = {});
+
+ ~UPnPIGD() {}
+
+ bool operator==(IGD& other) const;
+ bool operator==(UPnPIGD& other) const;
+
+ const std::string& getBaseURL() const
+ {
+ std::lock_guard<std::mutex> lock(mutex_);
+ return baseURL_;
+ }
+ const std::string& getFriendlyName() const
+ {
+ std::lock_guard<std::mutex> lock(mutex_);
+ return friendlyName_;
+ }
+ const std::string& getServiceType() const
+ {
+ std::lock_guard<std::mutex> lock(mutex_);
+ return serviceType_;
+ }
+ const std::string& getServiceId() const
+ {
+ std::lock_guard<std::mutex> lock(mutex_);
+ return serviceId_;
+ }
+ const std::string& getLocationURL() const
+ {
+ std::lock_guard<std::mutex> lock(mutex_);
+ return locationURL_;
+ }
+ const std::string& getControlURL() const
+ {
+ std::lock_guard<std::mutex> lock(mutex_);
+ return controlURL_;
+ }
+ const std::string& getEventSubURL() const
+ {
+ std::lock_guard<std::mutex> lock(mutex_);
+ return eventSubURL_;
+ }
+
+ const std::string toString() const override { return controlURL_; }
+
+private:
+ std::string baseURL_ {};
+ std::string friendlyName_ {};
+ std::string serviceType_ {};
+ std::string serviceId_ {};
+ std::string locationURL_ {};
+ std::string controlURL_ {};
+ std::string eventSubURL_ {};
+};
+
+} // namespace upnp
+} // namespace jami
diff --git a/src/upnp/protocol/upnp_protocol.h b/src/upnp/protocol/upnp_protocol.h
new file mode 100644
index 0000000..b38a4dd
--- /dev/null
+++ b/src/upnp/protocol/upnp_protocol.h
@@ -0,0 +1,126 @@
+/*
+ * Copyright (C) 2004-2023 Savoir-faire Linux Inc.
+ *
+ * Author: Eden Abitbol <eden.abitbol@savoirfairelinux.com>
+ * Author: Mohamed Chibani <mohamed.chibani@savoirfairelinux.com>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#pragma once
+
+#include "igd.h"
+#include "mapping.h"
+#include "ip_utils.h"
+//#include "upnp/upnp_thread_util.h"
+
+#include <map>
+#include <string>
+#include <chrono>
+#include <functional>
+#include <condition_variable>
+#include <list>
+
+namespace jami {
+namespace upnp {
+
+// UPnP device descriptions.
+constexpr static const char* UPNP_ROOT_DEVICE = "upnp:rootdevice";
+constexpr static const char* UPNP_IGD_DEVICE
+ = "urn:schemas-upnp-org:device:InternetGatewayDevice:1";
+constexpr static const char* UPNP_WAN_DEVICE = "urn:schemas-upnp-org:device:WANDevice:1";
+constexpr static const char* UPNP_WANCON_DEVICE
+ = "urn:schemas-upnp-org:device:WANConnectionDevice:1";
+constexpr static const char* UPNP_WANIP_SERVICE = "urn:schemas-upnp-org:service:WANIPConnection:1";
+constexpr static const char* UPNP_WANPPP_SERVICE
+ = "urn:schemas-upnp-org:service:WANPPPConnection:1";
+
+enum class UpnpIgdEvent { ADDED, REMOVED, INVALID_STATE };
+
+// Interface used to report mapping event from the protocol implementations.
+// This interface is meant to be implemented only by UPnPConext class. Sincce
+// this class is a singleton, it's assumed that it out-lives the protocol
+// implementations. In other words, the observer is always assumed to point to a
+// valid instance.
+class UpnpMappingObserver
+{
+public:
+ UpnpMappingObserver() {};
+ virtual ~UpnpMappingObserver() {};
+
+ virtual void onIgdUpdated(const std::shared_ptr<IGD>& igd, UpnpIgdEvent event) = 0;
+ virtual void onMappingAdded(const std::shared_ptr<IGD>& igd, const Mapping& map) = 0;
+ virtual void onMappingRequestFailed(const Mapping& map) = 0;
+#if HAVE_LIBNATPMP
+ virtual void onMappingRenewed(const std::shared_ptr<IGD>& igd, const Mapping& map) = 0;
+#endif
+ virtual void onMappingRemoved(const std::shared_ptr<IGD>& igd, const Mapping& map) = 0;
+};
+
+// Pure virtual interface class that UPnPContext uses to call protocol functions.
+class UPnPProtocol : public std::enable_shared_from_this<UPnPProtocol>//, protected UpnpThreadUtil
+{
+public:
+ enum class UpnpError : int { INVALID_ERR = -1, ERROR_OK, CONFLICT_IN_MAPPING };
+
+ UPnPProtocol() {};
+ virtual ~UPnPProtocol() {};
+
+ // Get protocol type.
+ virtual NatProtocolType getProtocol() const = 0;
+
+ // Get protocol type as string.
+ virtual char const* getProtocolName() const = 0;
+
+ // Clear all known IGDs.
+ virtual void clearIgds() = 0;
+
+ // Search for IGD.
+ virtual void searchForIgd() = 0;
+
+ // Get the IGD instance.
+ virtual std::list<std::shared_ptr<IGD>> getIgdList() const = 0;
+
+ // Return true if it has at least one valid IGD.
+ virtual bool isReady() const = 0;
+
+ // Get the list of already allocated mappings if any.
+ virtual std::map<Mapping::key_t, Mapping> getMappingsListByDescr(const std::shared_ptr<IGD>&,
+ const std::string&) const
+ {
+ return {};
+ }
+
+ // Sends a request to add a mapping.
+ virtual void requestMappingAdd(const Mapping& map) = 0;
+
+ // Renew an allocated mapping.
+ virtual void requestMappingRenew(const Mapping& mapping) = 0;
+
+ // Sends a request to remove a mapping.
+ virtual void requestMappingRemove(const Mapping& igdMapping) = 0;
+
+ // Set the user callbacks.
+ virtual void setObserver(UpnpMappingObserver* obs) = 0;
+
+ // Get the current host (local) address
+ virtual const IpAddr getHostAddress() const = 0;
+
+ // Terminate
+ virtual void terminate() = 0;
+};
+
+} // namespace upnp
+} // namespace jami
diff --git a/src/upnp/upnp_context.cpp b/src/upnp/upnp_context.cpp
new file mode 100644
index 0000000..ef556f1
--- /dev/null
+++ b/src/upnp/upnp_context.cpp
@@ -0,0 +1,1339 @@
+/*
+ * Copyright (C) 2004-2023 Savoir-faire Linux Inc.
+ *
+ * Author: Stepan Salenikovich <stepan.salenikovich@savoirfairelinux.com>
+ * Author: Eden Abitbol <eden.abitbol@savoirfairelinux.com>
+ * Author: Mohamed Chibani <mohamed.chibani@savoirfairelinux.com>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#include "upnp_context.h"
+
+namespace jami {
+namespace upnp {
+
+constexpr static auto MAP_UPDATE_INTERVAL = std::chrono::seconds(30);
+constexpr static int MAX_REQUEST_RETRIES = 20;
+constexpr static int MAX_REQUEST_REMOVE_COUNT = 5;
+
+constexpr static uint16_t UPNP_TCP_PORT_MIN {10000};
+constexpr static uint16_t UPNP_TCP_PORT_MAX {UPNP_TCP_PORT_MIN + 5000};
+constexpr static uint16_t UPNP_UDP_PORT_MIN {20000};
+constexpr static uint16_t UPNP_UDP_PORT_MAX {UPNP_UDP_PORT_MIN + 5000};
+
+UPnPContext::UPnPContext()
+{
+ JAMI_DBG("Creating UPnPContext instance [%p]", this);
+
+ // Set port ranges
+ portRange_.emplace(PortType::TCP, std::make_pair(UPNP_TCP_PORT_MIN, UPNP_TCP_PORT_MAX));
+ portRange_.emplace(PortType::UDP, std::make_pair(UPNP_UDP_PORT_MIN, UPNP_UDP_PORT_MAX));
+
+ if (not isValidThread()) {
+ runOnUpnpContextQueue([this] { init(); });
+ return;
+ }
+}
+
+std::shared_ptr<UPnPContext>
+UPnPContext::getUPnPContext()
+{
+ // This is the unique shared instance (singleton) of UPnPContext class.
+ static auto context = std::make_shared<UPnPContext>();
+ return context;
+}
+
+void
+UPnPContext::shutdown(std::condition_variable& cv)
+{
+ JAMI_DBG("Shutdown UPnPContext instance [%p]", this);
+
+ stopUpnp(true);
+
+ for (auto const& [_, proto] : protocolList_) {
+ proto->terminate();
+ }
+
+ {
+ std::lock_guard<std::mutex> lock(mappingMutex_);
+ mappingList_->clear();
+ if (mappingListUpdateTimer_)
+ mappingListUpdateTimer_->cancel();
+ controllerList_.clear();
+ protocolList_.clear();
+ shutdownComplete_ = true;
+ cv.notify_one();
+ }
+}
+
+void
+UPnPContext::shutdown()
+{
+ std::unique_lock<std::mutex> lk(mappingMutex_);
+ std::condition_variable cv;
+
+ runOnUpnpContextQueue([&, this] { shutdown(cv); });
+
+ JAMI_DBG("Waiting for shutdown ...");
+
+ if (cv.wait_for(lk, std::chrono::seconds(30), [this] { return shutdownComplete_; })) {
+ JAMI_DBG("Shutdown completed");
+ } else {
+ JAMI_ERR("Shutdown timed-out");
+ }
+}
+
+UPnPContext::~UPnPContext()
+{
+ JAMI_DBG("UPnPContext instance [%p] destroyed", this);
+}
+
+void
+UPnPContext::init()
+{
+ threadId_ = getCurrentThread();
+ CHECK_VALID_THREAD();
+
+#if HAVE_LIBNATPMP
+ auto natPmp = std::make_shared<NatPmp>();
+ natPmp->setObserver(this);
+ protocolList_.emplace(NatProtocolType::NAT_PMP, std::move(natPmp));
+#endif
+
+#if HAVE_LIBUPNP
+ auto pupnp = std::make_shared<PUPnP>();
+ pupnp->setObserver(this);
+ protocolList_.emplace(NatProtocolType::PUPNP, std::move(pupnp));
+#endif
+}
+
+void
+UPnPContext::startUpnp()
+{
+ assert(not controllerList_.empty());
+
+ CHECK_VALID_THREAD();
+
+ JAMI_DBG("Starting UPNP context");
+
+ // Request a new IGD search.
+ for (auto const& [_, protocol] : protocolList_) {
+ protocol->searchForIgd();
+ }
+
+ started_ = true;
+}
+
+void
+UPnPContext::stopUpnp(bool forceRelease)
+{
+ if (not isValidThread()) {
+ runOnUpnpContextQueue([this, forceRelease] { stopUpnp(forceRelease); });
+ return;
+ }
+
+ JAMI_DBG("Stopping UPNP context");
+
+ // Clear all current mappings if any.
+
+ // Use a temporary list to avoid processing the mapping
+ // list while holding the lock.
+ std::list<Mapping::sharedPtr_t> toRemoveList;
+ {
+ std::lock_guard<std::mutex> lock(mappingMutex_);
+
+ PortType types[2] {PortType::TCP, PortType::UDP};
+ for (auto& type : types) {
+ auto& mappingList = getMappingList(type);
+ for (auto const& [_, map] : mappingList) {
+ toRemoveList.emplace_back(map);
+ }
+ }
+ // Invalidate the current IGDs.
+ preferredIgd_.reset();
+ validIgdList_.clear();
+ }
+ for (auto const& map : toRemoveList) {
+ requestRemoveMapping(map);
+
+ // Notify is not needed in updateMappingState when
+ // shutting down (hence set it to false). NotifyCallback
+ // would trigger a new SIP registration and create a
+ // false registered state upon program close.
+ // It's handled by upper layers.
+
+ updateMappingState(map, MappingState::FAILED, false);
+ // We dont remove mappings with auto-update enabled,
+ // unless forceRelease is true.
+ if (not map->getAutoUpdate() or forceRelease) {
+ map->enableAutoUpdate(false);
+ unregisterMapping(map);
+ }
+ }
+
+ // Clear all current IGDs.
+ for (auto const& [_, protocol] : protocolList_) {
+ protocol->clearIgds();
+ }
+
+ started_ = false;
+}
+
+uint16_t
+UPnPContext::generateRandomPort(PortType type, bool mustBeEven)
+{
+ auto minPort = type == PortType::TCP ? UPNP_TCP_PORT_MIN : UPNP_UDP_PORT_MIN;
+ auto maxPort = type == PortType::TCP ? UPNP_TCP_PORT_MAX : UPNP_UDP_PORT_MAX;
+
+ if (minPort >= maxPort) {
+ JAMI_ERR("Max port number (%i) must be greater than min port number (%i)", maxPort, minPort);
+ // Must be called with valid range.
+ assert(false);
+ }
+
+ int fact = mustBeEven ? 2 : 1;
+ if (mustBeEven) {
+ minPort /= fact;
+ maxPort /= fact;
+ }
+
+ // Seed the generator.
+ static std::mt19937 gen(dht::crypto::getSeededRandomEngine());
+ // Define the range.
+ std::uniform_int_distribution<uint16_t> dist(minPort, maxPort);
+ return dist(gen) * fact;
+}
+
+void
+UPnPContext::connectivityChanged()
+{
+ if (not isValidThread()) {
+ runOnUpnpContextQueue([this] { connectivityChanged(); });
+ return;
+ }
+
+ auto hostAddr = ip_utils::getLocalAddr(AF_INET);
+
+ JAMI_DBG("Connectivity change check: host address %s", hostAddr.toString().c_str());
+
+ auto restartUpnp = false;
+
+ // On reception of "connectivity change" notification, the UPNP search
+ // will be restarted if either there is no valid IGD, or the IGD address
+ // changed.
+
+ if (not isReady()) {
+ restartUpnp = true;
+ } else {
+ // Check if the host address changed.
+ for (auto const& [_, protocol] : protocolList_) {
+ if (protocol->isReady() and hostAddr != protocol->getHostAddress()) {
+ JAMI_WARN("Host address changed from %s to %s",
+ protocol->getHostAddress().toString().c_str(),
+ hostAddr.toString().c_str());
+ protocol->clearIgds();
+ restartUpnp = true;
+ break;
+ }
+ }
+ }
+
+ // We have at least one valid IGD and the host address did
+ // not change, so no need to restart.
+ if (not restartUpnp) {
+ return;
+ }
+
+ // No registered controller. A new search will be performed when
+ // a controller is registered.
+ if (controllerList_.empty())
+ return;
+
+ JAMI_DBG("Connectivity changed. Clear the IGDs and restart");
+
+ stopUpnp();
+ startUpnp();
+
+ // Mapping with auto update enabled must be processed first.
+ processMappingWithAutoUpdate();
+}
+
+void
+UPnPContext::setPublicAddress(const IpAddr& addr)
+{
+ if (not addr)
+ return;
+
+ std::lock_guard<std::mutex> lock(mappingMutex_);
+ if (knownPublicAddress_ != addr) {
+ knownPublicAddress_ = std::move(addr);
+ JAMI_DBG("Setting the known public address to %s", addr.toString().c_str());
+ }
+}
+
+bool
+UPnPContext::isReady() const
+{
+ std::lock_guard<std::mutex> lock(mappingMutex_);
+ return not validIgdList_.empty();
+}
+
+IpAddr
+UPnPContext::getExternalIP() const
+{
+ std::lock_guard<std::mutex> lock(mappingMutex_);
+ // Return the first IGD Ip available.
+ if (not validIgdList_.empty()) {
+ return (*validIgdList_.begin())->getPublicIp();
+ }
+ return {};
+}
+
+Mapping::sharedPtr_t
+UPnPContext::reserveMapping(Mapping& requestedMap)
+{
+ auto desiredPort = requestedMap.getExternalPort();
+
+ if (desiredPort == 0) {
+ JAMI_DBG("Desired port is not set, will provide the first available port for [%s]",
+ requestedMap.getTypeStr());
+ } else {
+ JAMI_DBG("Try to find mapping for port %i [%s]", desiredPort, requestedMap.getTypeStr());
+ }
+
+ Mapping::sharedPtr_t mapRes;
+
+ {
+ std::lock_guard<std::mutex> lock(mappingMutex_);
+ auto& mappingList = getMappingList(requestedMap.getType());
+
+ // We try to provide a mapping in "OPEN" state. If not found,
+ // we provide any available mapping. In this case, it's up to
+ // the caller to use it or not.
+ for (auto const& [_, map] : mappingList) {
+ // If the desired port is null, we pick the first available port.
+ if (map->isValid() and (desiredPort == 0 or map->getExternalPort() == desiredPort)
+ and map->isAvailable()) {
+ // Considere the first available mapping regardless of its
+ // state. A mapping with OPEN state will be used if found.
+ if (not mapRes)
+ mapRes = map;
+
+ if (map->getState() == MappingState::OPEN) {
+ // Found an "OPEN" mapping. We are done.
+ mapRes = map;
+ break;
+ }
+ }
+ }
+ }
+
+ // Create a mapping if none was available.
+ if (not mapRes) {
+ JAMI_WARN("Did not find any available mapping. Will request one now");
+ mapRes = registerMapping(requestedMap);
+ }
+
+ if (mapRes) {
+ // Make the mapping unavailable
+ mapRes->setAvailable(false);
+ // Copy attributes.
+ mapRes->setNotifyCallback(requestedMap.getNotifyCallback());
+ mapRes->enableAutoUpdate(requestedMap.getAutoUpdate());
+ // Notify the listener.
+ if (auto cb = mapRes->getNotifyCallback())
+ cb(mapRes);
+ }
+
+ updateMappingList(true);
+
+ return mapRes;
+}
+
+void
+UPnPContext::releaseMapping(const Mapping& map)
+{
+ if (not isValidThread()) {
+ runOnUpnpContextQueue([this, map] { releaseMapping(map); });
+ return;
+ }
+
+ auto mapPtr = getMappingWithKey(map.getMapKey());
+
+ if (not mapPtr) {
+ // Might happen if the mapping failed or was never granted.
+ JAMI_DBG("Mapping %s does not exist or was already removed", map.toString().c_str());
+ return;
+ }
+
+ if (mapPtr->isAvailable()) {
+ JAMI_WARN("Trying to release an unused mapping %s", mapPtr->toString().c_str());
+ return;
+ }
+
+ // Remove it.
+ requestRemoveMapping(mapPtr);
+ unregisterMapping(mapPtr);
+}
+
+void
+UPnPContext::registerController(void* controller)
+{
+ {
+ std::lock_guard<std::mutex> lock(mappingMutex_);
+ if (shutdownComplete_) {
+ JAMI_WARN("UPnPContext already shut down");
+ return;
+ }
+ }
+
+ if (not isValidThread()) {
+ runOnUpnpContextQueue([this, controller] { registerController(controller); });
+ return;
+ }
+
+ auto ret = controllerList_.emplace(controller);
+ if (not ret.second) {
+ JAMI_WARN("Controller %p is already registered", controller);
+ return;
+ }
+
+ JAMI_DBG("Successfully registered controller %p", controller);
+ if (not started_)
+ startUpnp();
+}
+
+void
+UPnPContext::unregisterController(void* controller)
+{
+ if (not isValidThread()) {
+ runOnUpnpContextQueue([this, controller] { unregisterController(controller); });
+ return;
+ }
+
+ if (controllerList_.erase(controller) == 1) {
+ JAMI_DBG("Successfully unregistered controller %p", controller);
+ } else {
+ JAMI_DBG("Controller %p was already removed", controller);
+ }
+
+ if (controllerList_.empty()) {
+ stopUpnp();
+ }
+}
+
+uint16_t
+UPnPContext::getAvailablePortNumber(PortType type)
+{
+ // Only return an availalable random port. No actual
+ // reservation is made here.
+
+ std::lock_guard<std::mutex> lock(mappingMutex_);
+ auto& mappingList = getMappingList(type);
+ int tryCount = 0;
+ while (tryCount++ < MAX_REQUEST_RETRIES) {
+ uint16_t port = generateRandomPort(type);
+ Mapping map(type, port, port);
+ if (mappingList.find(map.getMapKey()) == mappingList.end())
+ return port;
+ }
+
+ // Very unlikely to get here.
+ JAMI_ERR("Could not find an available port after %i trials", MAX_REQUEST_RETRIES);
+ return 0;
+}
+
+void
+UPnPContext::requestMapping(const Mapping::sharedPtr_t& map)
+{
+ assert(map);
+
+ if (not isValidThread()) {
+ runOnUpnpContextQueue([this, map] { requestMapping(map); });
+ return;
+ }
+
+ auto const& igd = getPreferredIgd();
+ // We must have at least a valid IGD pointer if we get here.
+ // Not this method is called only if there were a valid IGD, however,
+ // because the processing is asynchronous, it's possible that the IGD
+ // was invalidated when the this code executed.
+ if (not igd) {
+ JAMI_DBG("No valid IGDs available");
+ return;
+ }
+
+ map->setIgd(igd);
+
+ JAMI_DBG("Request mapping %s using protocol [%s] IGD [%s]",
+ map->toString().c_str(),
+ igd->getProtocolName(),
+ igd->toString().c_str());
+
+ if (map->getState() != MappingState::IN_PROGRESS)
+ updateMappingState(map, MappingState::IN_PROGRESS);
+
+ auto const& protocol = protocolList_.at(igd->getProtocol());
+ protocol->requestMappingAdd(*map);
+}
+
+bool
+UPnPContext::provisionNewMappings(PortType type, int portCount)
+{
+ JAMI_DBG("Provision %i new mappings of type [%s]", portCount, Mapping::getTypeStr(type));
+
+ assert(portCount > 0);
+
+ while (portCount > 0) {
+ auto port = getAvailablePortNumber(type);
+ if (port > 0) {
+ // Found an available port number
+ portCount--;
+ Mapping map(type, port, port, true);
+ registerMapping(map);
+ } else {
+ // Very unlikely to get here!
+ JAMI_ERR("Can not find any available port to provision!");
+ return false;
+ }
+ }
+
+ return true;
+}
+
+bool
+UPnPContext::deleteUnneededMappings(PortType type, int portCount)
+{
+ JAMI_DBG("Remove %i unneeded mapping of type [%s]", portCount, Mapping::getTypeStr(type));
+
+ assert(portCount > 0);
+
+ CHECK_VALID_THREAD();
+
+ std::lock_guard<std::mutex> lock(mappingMutex_);
+ auto& mappingList = getMappingList(type);
+
+ for (auto it = mappingList.begin(); it != mappingList.end();) {
+ auto map = it->second;
+ assert(map);
+
+ if (not map->isAvailable()) {
+ it++;
+ continue;
+ }
+
+ if (map->getState() == MappingState::OPEN and portCount > 0) {
+ // Close portCount mappings in "OPEN" state.
+ requestRemoveMapping(map);
+ it = unregisterMapping(it);
+ portCount--;
+ } else if (map->getState() != MappingState::OPEN) {
+ // If this methods is called, it means there are more open
+ // mappings than required. So, all mappings in a state other
+ // than "OPEN" state (typically in in-progress state) will
+ // be deleted as well.
+ it = unregisterMapping(it);
+ } else {
+ it++;
+ }
+ }
+
+ return true;
+}
+
+void
+UPnPContext::updatePreferredIgd()
+{
+ CHECK_VALID_THREAD();
+
+ if (preferredIgd_ and preferredIgd_->isValid())
+ return;
+
+ // Reset and search for the best IGD.
+ preferredIgd_.reset();
+
+ for (auto const& [_, protocol] : protocolList_) {
+ if (protocol->isReady()) {
+ auto igdList = protocol->getIgdList();
+ assert(not igdList.empty());
+ auto const& igd = igdList.front();
+ if (not igd->isValid())
+ continue;
+
+ // Prefer NAT-PMP over PUPNP.
+ if (preferredIgd_ and igd->getProtocol() != NatProtocolType::NAT_PMP)
+ continue;
+
+ // Update.
+ preferredIgd_ = igd;
+ }
+ }
+
+ if (preferredIgd_ and preferredIgd_->isValid()) {
+ JAMI_DBG("Preferred IGD updated to [%s] IGD [%s %s] ",
+ preferredIgd_->getProtocolName(),
+ preferredIgd_->getUID().c_str(),
+ preferredIgd_->toString().c_str());
+ }
+}
+
+std::shared_ptr<IGD>
+UPnPContext::getPreferredIgd() const
+{
+ CHECK_VALID_THREAD();
+
+ return preferredIgd_;
+}
+
+void
+UPnPContext::updateMappingList(bool async)
+{
+ // Run async if requested.
+ if (async) {
+ runOnUpnpContextQueue([this] { updateMappingList(false); });
+ return;
+ }
+
+ CHECK_VALID_THREAD();
+
+ // Update the preferred IGD.
+ updatePreferredIgd();
+
+ if (mappingListUpdateTimer_) {
+ mappingListUpdateTimer_->cancel();
+ mappingListUpdateTimer_ = {};
+ }
+
+ // Skip if no controller registered.
+ if (controllerList_.empty())
+ return;
+
+ // Cancel the current timer (if any) and re-schedule.
+ std::shared_ptr<IGD> prefIgd = getPreferredIgd();
+ if (not prefIgd) {
+ JAMI_DBG("UPNP/NAT-PMP enabled, but no valid IGDs available");
+ // No valid IGD. Nothing to do.
+ return;
+ }
+
+ mappingListUpdateTimer_ = getScheduler()->scheduleIn([this] { updateMappingList(false); },
+ MAP_UPDATE_INTERVAL);
+
+ // Process pending requests if any.
+ processPendingRequests(prefIgd);
+
+ // Make new requests for mappings that failed and have
+ // the auto-update option enabled.
+ processMappingWithAutoUpdate();
+
+ PortType typeArray[2] = {PortType::TCP, PortType::UDP};
+
+ for (auto idx : {0, 1}) {
+ auto type = typeArray[idx];
+
+ MappingStatus status;
+ getMappingStatus(type, status);
+
+ JAMI_DBG("Mapping status [%s] - overall %i: %i open (%i ready + %i in use), %i pending, %i "
+ "in-progress, %i failed",
+ Mapping::getTypeStr(type),
+ status.sum(),
+ status.openCount_,
+ status.readyCount_,
+ status.openCount_ - status.readyCount_,
+ status.pendingCount_,
+ status.inProgressCount_,
+ status.failedCount_);
+
+ if (status.failedCount_ > 0) {
+ std::lock_guard<std::mutex> lock(mappingMutex_);
+ auto const& mappingList = getMappingList(type);
+ for (auto const& [_, map] : mappingList) {
+ if (map->getState() == MappingState::FAILED) {
+ JAMI_DBG("Mapping status [%s] - Available [%s]",
+ map->toString(true).c_str(),
+ map->isAvailable() ? "YES" : "NO");
+ }
+ }
+ }
+
+ int toRequestCount = (int) minOpenPortLimit_[idx]
+ - (int) (status.readyCount_ + status.inProgressCount_
+ + status.pendingCount_);
+
+ // Provision/release mappings accordingly.
+ if (toRequestCount > 0) {
+ // Take into account the request in-progress when making
+ // requests for new mappings.
+ provisionNewMappings(type, toRequestCount);
+ } else if (status.readyCount_ > maxOpenPortLimit_[idx]) {
+ deleteUnneededMappings(type, status.readyCount_ - maxOpenPortLimit_[idx]);
+ }
+ }
+
+ // Prune the mapping list if needed
+ if (protocolList_.at(NatProtocolType::PUPNP)->isReady()) {
+#if HAVE_LIBNATPMP
+ // Dont perform if NAT-PMP is valid.
+ if (not protocolList_.at(NatProtocolType::NAT_PMP)->isReady())
+#endif
+ {
+ pruneMappingList();
+ }
+ }
+
+#if HAVE_LIBNATPMP
+ // Renew nat-pmp allocations
+ if (protocolList_.at(NatProtocolType::NAT_PMP)->isReady())
+ renewAllocations();
+#endif
+}
+
+void
+UPnPContext::pruneMappingList()
+{
+ CHECK_VALID_THREAD();
+
+ MappingStatus status;
+ getMappingStatus(status);
+
+ // Do not prune the list if there are pending/in-progress requests.
+ if (status.inProgressCount_ != 0 or status.pendingCount_ != 0) {
+ return;
+ }
+
+ auto const& igd = getPreferredIgd();
+ if (not igd or igd->getProtocol() != NatProtocolType::PUPNP) {
+ return;
+ }
+ auto protocol = protocolList_.at(NatProtocolType::PUPNP);
+
+ auto remoteMapList = protocol->getMappingsListByDescr(igd,
+ Mapping::UPNP_MAPPING_DESCRIPTION_PREFIX);
+ if (remoteMapList.empty()) {
+ std::lock_guard<std::mutex> lock(mappingMutex_);
+ if (not getMappingList(PortType::TCP).empty() or getMappingList(PortType::TCP).empty()) {
+ JAMI_WARN("We have provisionned mappings but the PUPNP IGD returned an empty list!");
+ }
+ }
+
+ pruneUnMatchedMappings(igd, remoteMapList);
+ pruneUnTrackedMappings(igd, remoteMapList);
+}
+
+void
+UPnPContext::pruneUnMatchedMappings(const std::shared_ptr<IGD>& igd,
+ const std::map<Mapping::key_t, Mapping>& remoteMapList)
+{
+ // Check/synchronize local mapping list with the list
+ // returned by the IGD.
+
+ PortType types[2] {PortType::TCP, PortType::UDP};
+
+ for (auto& type : types) {
+ // Use a temporary list to avoid processing mappings while holding the lock.
+ std::list<Mapping::sharedPtr_t> toRemoveList;
+ {
+ std::lock_guard<std::mutex> lock(mappingMutex_);
+ auto& mappingList = getMappingList(type);
+ for (auto const& [_, map] : mappingList) {
+ // Only check mappings allocated by UPNP protocol.
+ if (map->getProtocol() != NatProtocolType::PUPNP) {
+ continue;
+ }
+ // Set mapping as failed if not found in the list
+ // returned by the IGD.
+ if (map->getState() == MappingState::OPEN
+ and remoteMapList.find(map->getMapKey()) == remoteMapList.end()) {
+ toRemoveList.emplace_back(map);
+
+ JAMI_WARN("Mapping %s (IGD %s) marked as \"OPEN\" but not found in the "
+ "remote list. Mark as failed!",
+ map->toString().c_str(),
+ igd->toString().c_str());
+ }
+ }
+ }
+
+ for (auto const& map : toRemoveList) {
+ updateMappingState(map, MappingState::FAILED);
+ unregisterMapping(map);
+ }
+ }
+}
+
+void
+UPnPContext::pruneUnTrackedMappings(const std::shared_ptr<IGD>& igd,
+ const std::map<Mapping::key_t, Mapping>& remoteMapList)
+{
+ // Use a temporary list to avoid processing mappings while holding the lock.
+ std::list<Mapping> toRemoveList;
+ {
+ std::lock_guard<std::mutex> lock(mappingMutex_);
+
+ for (auto const& [_, map] : remoteMapList) {
+ // Must has valid IGD pointer and use UPNP protocol.
+ assert(map.getIgd());
+ assert(map.getIgd()->getProtocol() == NatProtocolType::PUPNP);
+ auto& mappingList = getMappingList(map.getType());
+ auto it = mappingList.find(map.getMapKey());
+ if (it == mappingList.end()) {
+ // Not present, request mapping remove.
+ toRemoveList.emplace_back(std::move(map));
+ // Make only few remove requests at once.
+ if (toRemoveList.size() >= MAX_REQUEST_REMOVE_COUNT)
+ break;
+ }
+ }
+ }
+
+ // Remove un-tracked mappings.
+ auto protocol = protocolList_.at(NatProtocolType::PUPNP);
+ for (auto const& map : toRemoveList) {
+ protocol->requestMappingRemove(map);
+ }
+}
+
+void
+UPnPContext::pruneMappingsWithInvalidIgds(const std::shared_ptr<IGD>& igd)
+{
+ CHECK_VALID_THREAD();
+
+ // Use temporary list to avoid holding the lock while
+ // processing the mapping list.
+ std::list<Mapping::sharedPtr_t> toRemoveList;
+ {
+ std::lock_guard<std::mutex> lock(mappingMutex_);
+
+ PortType types[2] {PortType::TCP, PortType::UDP};
+ for (auto& type : types) {
+ auto& mappingList = getMappingList(type);
+ for (auto const& [_, map] : mappingList) {
+ if (map->getIgd() == igd)
+ toRemoveList.emplace_back(map);
+ }
+ }
+ }
+
+ for (auto const& map : toRemoveList) {
+ JAMI_DBG("Remove mapping %s (has an invalid IGD %s [%s])",
+ map->toString().c_str(),
+ igd->toString().c_str(),
+ igd->getProtocolName());
+ updateMappingState(map, MappingState::FAILED);
+ unregisterMapping(map);
+ }
+}
+
+void
+UPnPContext::processPendingRequests(const std::shared_ptr<IGD>& igd)
+{
+ // This list holds the mappings to be requested. This is
+ // needed to avoid performing the requests while holding
+ // the lock.
+ std::list<Mapping::sharedPtr_t> requestsList;
+
+ // Populate the list of requests to perform.
+ {
+ std::lock_guard<std::mutex> lock(mappingMutex_);
+ PortType typeArray[2] {PortType::TCP, PortType::UDP};
+
+ for (auto type : typeArray) {
+ auto& mappingList = getMappingList(type);
+ for (auto& [_, map] : mappingList) {
+ if (map->getState() == MappingState::PENDING) {
+ JAMI_DBG("Send pending request for mapping %s to IGD %s",
+ map->toString().c_str(),
+ igd->toString().c_str());
+ requestsList.emplace_back(map);
+ }
+ }
+ }
+ }
+
+ // Process the pending requests.
+ for (auto const& map : requestsList) {
+ requestMapping(map);
+ }
+}
+
+void
+UPnPContext::processMappingWithAutoUpdate()
+{
+ // This list holds the mappings to be requested. This is
+ // needed to avoid performing the requests while holding
+ // the lock.
+ std::list<Mapping::sharedPtr_t> requestsList;
+
+ // Populate the list of requests for mappings with auto-update enabled.
+ {
+ std::lock_guard<std::mutex> lock(mappingMutex_);
+ PortType typeArray[2] {PortType::TCP, PortType::UDP};
+
+ for (auto type : typeArray) {
+ auto& mappingList = getMappingList(type);
+ for (auto const& [_, map] : mappingList) {
+ if (map->getState() == MappingState::FAILED and map->getAutoUpdate()) {
+ requestsList.emplace_back(map);
+ }
+ }
+ }
+ }
+
+ for (auto const& oldMap : requestsList) {
+ // Request a new mapping if auto-update is enabled.
+ JAMI_DBG("Mapping %s has auto-update enabled, a new mapping will be requested",
+ oldMap->toString().c_str());
+
+ // Reserve a new mapping.
+ Mapping newMapping(oldMap->getType());
+ newMapping.enableAutoUpdate(true);
+ newMapping.setNotifyCallback(oldMap->getNotifyCallback());
+
+ auto const& mapPtr = reserveMapping(newMapping);
+ assert(mapPtr);
+
+ // Release the old one.
+ oldMap->setAvailable(true);
+ oldMap->enableAutoUpdate(false);
+ oldMap->setNotifyCallback(nullptr);
+ unregisterMapping(oldMap);
+ }
+}
+
+void
+UPnPContext::onIgdUpdated(const std::shared_ptr<IGD>& igd, UpnpIgdEvent event)
+{
+ assert(igd);
+
+ if (not isValidThread()) {
+ runOnUpnpContextQueue([this, igd, event] { onIgdUpdated(igd, event); });
+ return;
+ }
+
+ // Reset to start search for a new best IGD.
+ preferredIgd_.reset();
+
+ char const* IgdState = event == UpnpIgdEvent::ADDED ? "ADDED"
+ : event == UpnpIgdEvent::REMOVED ? "REMOVED"
+ : "INVALID";
+
+ auto const& igdLocalAddr = igd->getLocalIp();
+ auto protocolName = igd->getProtocolName();
+
+ JAMI_DBG("New event for IGD [%s %s] [%s]: [%s]",
+ igd->getUID().c_str(),
+ igd->toString().c_str(),
+ protocolName,
+ IgdState);
+
+ // Check if the IGD has valid addresses.
+ if (not igdLocalAddr) {
+ JAMI_WARN("[%s] IGD has an invalid local address", protocolName);
+ return;
+ }
+
+ if (not igd->getPublicIp()) {
+ JAMI_WARN("[%s] IGD has an invalid public address", protocolName);
+ return;
+ }
+
+ if (knownPublicAddress_ and igd->getPublicIp() != knownPublicAddress_) {
+ JAMI_WARN("[%s] IGD external address [%s] does not match known public address [%s]."
+ " The mapped addresses might not be reachable",
+ protocolName,
+ igd->getPublicIp().toString().c_str(),
+ knownPublicAddress_.toString().c_str());
+ }
+
+ // The IGD was removed or is invalid.
+ if (event == UpnpIgdEvent::REMOVED or event == UpnpIgdEvent::INVALID_STATE) {
+ JAMI_WARN("State of IGD [%s %s] [%s] changed to [%s]. Pruning the mapping list",
+ igd->getUID().c_str(),
+ igd->toString().c_str(),
+ protocolName,
+ IgdState);
+
+ pruneMappingsWithInvalidIgds(igd);
+
+ std::lock_guard<std::mutex> lock(mappingMutex_);
+ validIgdList_.erase(igd);
+ return;
+ }
+
+ // Update the IGD list.
+ {
+ std::lock_guard<std::mutex> lock(mappingMutex_);
+ auto ret = validIgdList_.emplace(igd);
+ if (ret.second) {
+ JAMI_DBG("IGD [%s] on address %s was added. Will process any pending requests",
+ protocolName,
+ igdLocalAddr.toString(true, true).c_str());
+ } else {
+ // Already in the list.
+ JAMI_ERR("IGD [%s] on address %s already in the list",
+ protocolName,
+ igdLocalAddr.toString(true, true).c_str());
+ return;
+ }
+ }
+
+ // Update the provisionned mappings.
+ updateMappingList(false);
+}
+
+void
+UPnPContext::onMappingAdded(const std::shared_ptr<IGD>& igd, const Mapping& mapRes)
+{
+ CHECK_VALID_THREAD();
+
+ // Check if we have a pending request for this response.
+ auto map = getMappingWithKey(mapRes.getMapKey());
+ if (not map) {
+ // We may receive a response for a canceled request. Just ignore it.
+ JAMI_DBG("Response for mapping %s [IGD %s] [%s] does not have a local match",
+ mapRes.toString().c_str(),
+ igd->toString().c_str(),
+ mapRes.getProtocolName());
+ return;
+ }
+
+ // The mapping request is new and successful. Update.
+ map->setIgd(igd);
+ map->setInternalAddress(mapRes.getInternalAddress());
+ map->setExternalPort(mapRes.getExternalPort());
+
+ // Update the state and report to the owner.
+ updateMappingState(map, MappingState::OPEN);
+
+ JAMI_DBG("Mapping %s (on IGD %s [%s]) successfully performed",
+ map->toString().c_str(),
+ igd->toString().c_str(),
+ map->getProtocolName());
+
+ // Call setValid() to reset the errors counter. We need
+ // to reset the counter on each successful response.
+ igd->setValid(true);
+}
+
+#if HAVE_LIBNATPMP
+void
+UPnPContext::onMappingRenewed(const std::shared_ptr<IGD>& igd, const Mapping& map)
+{
+ auto mapPtr = getMappingWithKey(map.getMapKey());
+
+ if (not mapPtr) {
+ // We may receive a notification for a canceled request. Ignore it.
+ JAMI_WARN("Renewed mapping %s from IGD %s [%s] does not have a match in local list",
+ map.toString().c_str(),
+ igd->toString().c_str(),
+ map.getProtocolName());
+ return;
+ }
+ if (mapPtr->getProtocol() != NatProtocolType::NAT_PMP or not mapPtr->isValid()
+ or mapPtr->getState() != MappingState::OPEN) {
+ JAMI_WARN("Renewed mapping %s from IGD %s [%s] is in unexpected state",
+ mapPtr->toString().c_str(),
+ igd->toString().c_str(),
+ mapPtr->getProtocolName());
+ return;
+ }
+
+ mapPtr->setRenewalTime(map.getRenewalTime());
+}
+#endif
+
+void
+UPnPContext::requestRemoveMapping(const Mapping::sharedPtr_t& map)
+{
+ CHECK_VALID_THREAD();
+
+ if (not map) {
+ JAMI_ERR("Mapping shared pointer is null!");
+ return;
+ }
+
+ if (not map->isValid()) {
+ // Silently ignore if the mapping is invalid
+ return;
+ }
+
+ auto protocol = protocolList_.at(map->getIgd()->getProtocol());
+ protocol->requestMappingRemove(*map);
+}
+
+void
+UPnPContext::deleteAllMappings(PortType type)
+{
+ if (not isValidThread()) {
+ runOnUpnpContextQueue([this, type] { deleteAllMappings(type); });
+ return;
+ }
+
+ std::lock_guard<std::mutex> lock(mappingMutex_);
+ auto& mappingList = getMappingList(type);
+
+ for (auto const& [_, map] : mappingList) {
+ requestRemoveMapping(map);
+ }
+}
+
+void
+UPnPContext::onMappingRemoved(const std::shared_ptr<IGD>& igd, const Mapping& mapRes)
+{
+ if (not mapRes.isValid())
+ return;
+
+ if (not isValidThread()) {
+ runOnUpnpContextQueue([this, igd, mapRes] { onMappingRemoved(igd, mapRes); });
+ return;
+ }
+
+ auto map = getMappingWithKey(mapRes.getMapKey());
+ // Notify the listener.
+ if (map and map->getNotifyCallback())
+ map->getNotifyCallback()(map);
+}
+
+Mapping::sharedPtr_t
+UPnPContext::registerMapping(Mapping& map)
+{
+ if (map.getExternalPort() == 0) {
+ JAMI_DBG("Port number not set. Will set a random port number");
+ auto port = getAvailablePortNumber(map.getType());
+ map.setExternalPort(port);
+ map.setInternalPort(port);
+ }
+
+ // Newly added mapping must be in pending state by default.
+ map.setState(MappingState::PENDING);
+
+ Mapping::sharedPtr_t mapPtr;
+
+ {
+ std::lock_guard<std::mutex> lock(mappingMutex_);
+ auto& mappingList = getMappingList(map.getType());
+
+ auto ret = mappingList.emplace(map.getMapKey(), std::make_shared<Mapping>(map));
+ if (not ret.second) {
+ JAMI_WARN("Mapping request for %s already added!", map.toString().c_str());
+ return {};
+ }
+ mapPtr = ret.first->second;
+ assert(mapPtr);
+ }
+
+ // No available IGD. The pending mapping requests will be processed
+ // when a IGD becomes available (in onIgdAdded() method).
+ if (not isReady()) {
+ JAMI_WARN("No IGD available. Mapping will be requested when an IGD becomes available");
+ } else {
+ requestMapping(mapPtr);
+ }
+
+ return mapPtr;
+}
+
+std::map<Mapping::key_t, Mapping::sharedPtr_t>::iterator
+UPnPContext::unregisterMapping(std::map<Mapping::key_t, Mapping::sharedPtr_t>::iterator it)
+{
+ assert(it->second);
+
+ CHECK_VALID_THREAD();
+ auto descr = it->second->toString();
+ auto& mappingList = getMappingList(it->second->getType());
+ auto ret = mappingList.erase(it);
+
+ return ret;
+}
+
+void
+UPnPContext::unregisterMapping(const Mapping::sharedPtr_t& map)
+{
+ CHECK_VALID_THREAD();
+
+ if (not map) {
+ JAMI_ERR("Mapping pointer is null");
+ return;
+ }
+
+ if (map->getAutoUpdate()) {
+ // Dont unregister mappings with auto-update enabled.
+ return;
+ }
+ auto& mappingList = getMappingList(map->getType());
+
+ if (mappingList.erase(map->getMapKey()) == 1) {
+ JAMI_DBG("Unregistered mapping %s", map->toString().c_str());
+ } else {
+ // The mapping may already be un-registered. Just ignore it.
+ JAMI_DBG("Mapping %s [%s] does not have a local match",
+ map->toString().c_str(),
+ map->getProtocolName());
+ }
+}
+
+std::map<Mapping::key_t, Mapping::sharedPtr_t>&
+UPnPContext::getMappingList(PortType type)
+{
+ unsigned typeIdx = type == PortType::TCP ? 0 : 1;
+ return mappingList_[typeIdx];
+}
+
+Mapping::sharedPtr_t
+UPnPContext::getMappingWithKey(Mapping::key_t key)
+{
+ std::lock_guard<std::mutex> lock(mappingMutex_);
+ auto const& mappingList = getMappingList(Mapping::getTypeFromMapKey(key));
+ auto it = mappingList.find(key);
+ if (it == mappingList.end())
+ return nullptr;
+ return it->second;
+}
+
+void
+UPnPContext::getMappingStatus(PortType type, MappingStatus& status)
+{
+ std::lock_guard<std::mutex> lock(mappingMutex_);
+ auto& mappingList = getMappingList(type);
+
+ for (auto const& [_, map] : mappingList) {
+ switch (map->getState()) {
+ case MappingState::PENDING: {
+ status.pendingCount_++;
+ break;
+ }
+ case MappingState::IN_PROGRESS: {
+ status.inProgressCount_++;
+ break;
+ }
+ case MappingState::FAILED: {
+ status.failedCount_++;
+ break;
+ }
+ case MappingState::OPEN: {
+ status.openCount_++;
+ if (map->isAvailable())
+ status.readyCount_++;
+ break;
+ }
+
+ default:
+ // Must not get here.
+ assert(false);
+ break;
+ }
+ }
+}
+
+void
+UPnPContext::getMappingStatus(MappingStatus& status)
+{
+ getMappingStatus(PortType::TCP, status);
+ getMappingStatus(PortType::UDP, status);
+}
+
+void
+UPnPContext::onMappingRequestFailed(const Mapping& mapRes)
+{
+ CHECK_VALID_THREAD();
+
+ auto const& map = getMappingWithKey(mapRes.getMapKey());
+ if (not map) {
+ // We may receive a response for a removed request. Just ignore it.
+ JAMI_DBG("Mapping %s [IGD %s] does not have a local match",
+ mapRes.toString().c_str(),
+ mapRes.getProtocolName());
+ return;
+ }
+
+ auto igd = map->getIgd();
+ if (not igd) {
+ JAMI_ERR("IGD pointer is null");
+ return;
+ }
+
+ updateMappingState(map, MappingState::FAILED);
+ unregisterMapping(map);
+
+ JAMI_WARN("Mapping request for %s failed on IGD %s [%s]",
+ map->toString().c_str(),
+ igd->toString().c_str(),
+ igd->getProtocolName());
+}
+
+void
+UPnPContext::updateMappingState(const Mapping::sharedPtr_t& map, MappingState newState, bool notify)
+{
+ CHECK_VALID_THREAD();
+
+ assert(map);
+
+ // Ignore if the state did not change.
+ if (newState == map->getState()) {
+ JAMI_DBG("Mapping %s already in state %s", map->toString().c_str(), map->getStateStr());
+ return;
+ }
+
+ // Update the state.
+ map->setState(newState);
+
+ // Notify the listener if set.
+ if (notify and map->getNotifyCallback())
+ map->getNotifyCallback()(map);
+}
+
+#if HAVE_LIBNATPMP
+void
+UPnPContext::renewAllocations()
+{
+ CHECK_VALID_THREAD();
+
+ // Check if the we have valid PMP IGD.
+ auto pmpProto = protocolList_.at(NatProtocolType::NAT_PMP);
+
+ auto now = sys_clock::now();
+ std::vector<Mapping::sharedPtr_t> toRenew;
+
+ for (auto type : {PortType::TCP, PortType::UDP}) {
+ std::lock_guard<std::mutex> lock(mappingMutex_);
+ auto mappingList = getMappingList(type);
+ for (auto const& [_, map] : mappingList) {
+ if (not map->isValid())
+ continue;
+ if (map->getProtocol() != NatProtocolType::NAT_PMP)
+ continue;
+ if (map->getState() != MappingState::OPEN)
+ continue;
+ if (now < map->getRenewalTime())
+ continue;
+
+ toRenew.emplace_back(map);
+ }
+ }
+
+ // Quit if there are no mapping to renew
+ if (toRenew.empty())
+ return;
+
+ for (auto const& map : toRenew) {
+ pmpProto->requestMappingRenew(*map);
+ }
+}
+#endif
+
+} // namespace upnp
+} // namespace jami
diff --git a/src/upnp/upnp_context.h b/src/upnp/upnp_context.h
new file mode 100644
index 0000000..30d50c0
--- /dev/null
+++ b/src/upnp/upnp_context.h
@@ -0,0 +1,294 @@
+/*
+ * Copyright (C) 2004-2023 Savoir-faire Linux Inc.
+ *
+ * Author: Stepan Salenikovich <stepan.salenikovich@savoirfairelinux.com>
+ * Author: Eden Abitbol <eden.abitbol@savoirfairelinux.com>
+ * Author: Mohamed Chibani <mohamed.chibani@savoirfairelinux.com>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#pragma once
+
+#include "protocol/upnp_protocol.h"
+#if HAVE_LIBNATPMP
+#include "protocol/natpmp/nat_pmp.h"
+#endif
+#if HAVE_LIBUPNP
+#include "protocol/pupnp/pupnp.h"
+#endif
+#include "protocol/igd.h"
+
+#include "ip_utils.h"
+
+#include <opendht/rng.h>
+#include <asio/steady_timer.hpp>
+
+#include <set>
+#include <map>
+#include <mutex>
+#include <memory>
+#include <string>
+#include <chrono>
+#include <random>
+#include <atomic>
+#include <cstdlib>
+
+//#include "upnp_thread_util.h"
+
+using random_device = dht::crypto::random_device;
+
+using IgdFoundCallback = std::function<void()>;
+
+namespace jami {
+class IpAddr;
+}
+
+namespace jami {
+namespace upnp {
+
+class UPnPContext : public UpnpMappingObserver//, protected UpnpThreadUtil
+{
+private:
+ struct MappingStatus
+ {
+ int openCount_ {0};
+ int readyCount_ {0};
+ int pendingCount_ {0};
+ int inProgressCount_ {0};
+ int failedCount_ {0};
+
+ void reset()
+ {
+ openCount_ = 0;
+ readyCount_ = 0;
+ pendingCount_ = 0;
+ inProgressCount_ = 0;
+ failedCount_ = 0;
+ };
+ int sum() { return openCount_ + pendingCount_ + inProgressCount_ + failedCount_; }
+ };
+
+public:
+ UPnPContext();
+ ~UPnPContext();
+
+ // Retrieve the UPnPContext singleton.
+ static std::shared_ptr<UPnPContext> getUPnPContext();
+
+ // Terminate the instance.
+ void shutdown();
+
+ // Set the known public address
+ void setPublicAddress(const IpAddr& addr);
+
+ // Check if there is a valid IGD in the IGD list.
+ bool isReady() const;
+
+ // Get external Ip of a chosen IGD.
+ IpAddr getExternalIP() const;
+
+ // Inform the UPnP context that the network status has changed. This clears the list of known
+ void connectivityChanged();
+
+ // Returns a shared pointer of the mapping.
+ Mapping::sharedPtr_t reserveMapping(Mapping& requestedMap);
+
+ // Release an used mapping (make it available for future use).
+ void releaseMapping(const Mapping& map);
+
+ // Register a controller
+ void registerController(void* controller);
+ // Unregister a controller
+ void unregisterController(void* controller);
+
+ // Generate random port numbers
+ static uint16_t generateRandomPort(PortType type, bool mustBeEven = false);
+
+private:
+ // Initialization
+ void init();
+
+ /**
+ * @brief start the search for IGDs activate the mapping
+ * list update.
+ *
+ */
+ void startUpnp();
+
+ /**
+ * @brief Clear all IGDs and release/delete current mappings
+ *
+ * @param forceRelease If true, also delete mappings with enabled
+ * auto-update feature.
+ *
+ */
+ void stopUpnp(bool forceRelease = false);
+
+ void shutdown(std::condition_variable& cv);
+
+ // Create and register a new mapping.
+ Mapping::sharedPtr_t registerMapping(Mapping& map);
+
+ // Removes the mapping from the list.
+ std::map<Mapping::key_t, Mapping::sharedPtr_t>::iterator unregisterMapping(
+ std::map<Mapping::key_t, Mapping::sharedPtr_t>::iterator it);
+ void unregisterMapping(const Mapping::sharedPtr_t& map);
+
+ // Perform the request on the provided IGD.
+ void requestMapping(const Mapping::sharedPtr_t& map);
+
+ // Request a mapping remove from the IGD.
+ void requestRemoveMapping(const Mapping::sharedPtr_t& map);
+
+ // Remove all mappings of the given type.
+ void deleteAllMappings(PortType type);
+
+ // Update the state and notify the listener
+ void updateMappingState(const Mapping::sharedPtr_t& map,
+ MappingState newState,
+ bool notify = true);
+
+ // Provision ports.
+ uint16_t getAvailablePortNumber(PortType type);
+
+ // Update preferred IGD
+ void updatePreferredIgd();
+
+ // Get preferred IGD
+ std::shared_ptr<IGD> getPreferredIgd() const;
+
+ // Check and prune the mapping list. Called periodically.
+ void updateMappingList(bool async);
+
+ // Provision (pre-allocate) the requested number of mappings.
+ bool provisionNewMappings(PortType type, int portCount);
+
+ // Close unused mappings.
+ bool deleteUnneededMappings(PortType type, int portCount);
+
+ /**
+ * Prune the mapping list.To avoid competing with allocation
+ * requests, the pruning is performed only if there are no
+ * requests in progress.
+ */
+ void pruneMappingList();
+
+ /**
+ * Check if there are allocated mappings from previous instances,
+ * and try to close them.
+ * Only done for UPNP protocol. NAT-PMP allocations will expire
+ * anyway if not renewed.
+ */
+ void pruneUnMatchedMappings(const std::shared_ptr<IGD>& igd,
+ const std::map<Mapping::key_t, Mapping>& remoteMapList);
+
+ /**
+ * Check the local mapping list against the list returned by the
+ * IGD and remove all mappings which do not have a match.
+ * Only done for UPNP protocol.
+ */
+ void pruneUnTrackedMappings(const std::shared_ptr<IGD>& igd,
+ const std::map<Mapping::key_t, Mapping>& remoteMapList);
+
+ void pruneMappingsWithInvalidIgds(const std::shared_ptr<IGD>& igd);
+
+ /**
+ * @brief Get the mapping list
+ *
+ * @param type transport type (TCP/UDP)
+ * @return a reference on the map
+ * @warning concurrency protection done by the caller
+ */
+ std::map<Mapping::key_t, Mapping::sharedPtr_t>& getMappingList(PortType type);
+
+ // Get the mapping from the key.
+ Mapping::sharedPtr_t getMappingWithKey(Mapping::key_t key);
+
+ // Get the number of mappings per state.
+ void getMappingStatus(PortType type, MappingStatus& status);
+ void getMappingStatus(MappingStatus& status);
+
+#if HAVE_LIBNATPMP
+ void renewAllocations();
+#endif
+
+ // Process requests with pending status.
+ void processPendingRequests(const std::shared_ptr<IGD>& igd);
+
+ // Process mapping with auto-update flag enabled.
+ void processMappingWithAutoUpdate();
+
+ // Implementation of UpnpMappingObserver interface.
+
+ // Callback used to report changes in IGD status.
+ void onIgdUpdated(const std::shared_ptr<IGD>& igd, UpnpIgdEvent event) override;
+ // Callback used to report add request status.
+ void onMappingAdded(const std::shared_ptr<IGD>& igd, const Mapping& map) override;
+ // Callback invoked when a request fails. Reported on failures for both
+ // new requests and renewal requests (if supported by the the protocol).
+ void onMappingRequestFailed(const Mapping& map) override;
+#if HAVE_LIBNATPMP
+ // Callback used to report renew request status.
+ void onMappingRenewed(const std::shared_ptr<IGD>& igd, const Mapping& map) override;
+#endif
+ // Callback used to report remove request status.
+ void onMappingRemoved(const std::shared_ptr<IGD>& igd, const Mapping& map) override;
+
+private:
+ UPnPContext(const UPnPContext&) = delete;
+ UPnPContext(UPnPContext&&) = delete;
+ UPnPContext& operator=(UPnPContext&&) = delete;
+ UPnPContext& operator=(const UPnPContext&) = delete;
+
+ bool started_ {false};
+
+ // The known public address. The external addresses returned by
+ // the IGDs will be checked against this address.
+ IpAddr knownPublicAddress_ {};
+
+ // Set of registered controllers
+ std::set<void*> controllerList_;
+
+ // Map of available protocols.
+ std::map<NatProtocolType, std::shared_ptr<UPnPProtocol>> protocolList_;
+
+ // Port ranges for TCP and UDP (in that order).
+ std::map<PortType, std::pair<uint16_t, uint16_t>> portRange_ {};
+
+ // Min open ports limit
+ int minOpenPortLimit_[2] {4, 8};
+ // Max open ports limit
+ int maxOpenPortLimit_[2] {8, 12};
+
+ //std::shared_ptr<Task> mappingListUpdateTimer_ {};
+ asio::steady_timer mappingListUpdateTimer_;// {};
+
+ // Current preferred IGD. Can be null if there is no valid IGD.
+ std::shared_ptr<IGD> preferredIgd_;
+
+ // This mutex must lock only these two members. All other
+ // members must be accessed only from the UPNP context thread.
+ std::mutex mutable mappingMutex_;
+ // List of mappings.
+ std::map<Mapping::key_t, Mapping::sharedPtr_t> mappingList_[2] {};
+ std::set<std::shared_ptr<IGD>> validIgdList_ {};
+
+ // Shutdown synchronization
+ bool shutdownComplete_ {false};
+};
+
+} // namespace upnp
+} // namespace jami
diff --git a/src/upnp/upnp_control.cpp b/src/upnp/upnp_control.cpp
new file mode 100644
index 0000000..b255617
--- /dev/null
+++ b/src/upnp/upnp_control.cpp
@@ -0,0 +1,150 @@
+/*
+ * Copyright (C) 2004-2023 Savoir-faire Linux Inc.
+ *
+ * Author: Stepan Salenikovich <stepan.salenikovich@savoirfairelinux.com>
+ * Author: Eden Abitbol <eden.abitbol@savoirfairelinux.com>
+ * Author: Mohamed Chibani <mohamed.chibani@savoirfairelinux.com>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#include "upnp_control.h"
+
+namespace jami {
+namespace upnp {
+
+Controller::Controller()
+{
+ try {
+ upnpContext_ = UPnPContext::getUPnPContext();
+ } catch (std::runtime_error& e) {
+ JAMI_ERR("UPnP context error: %s", e.what());
+ }
+
+ assert(upnpContext_);
+ upnpContext_->registerController(this);
+
+ JAMI_DBG("Controller@%p: Created UPnP Controller session", this);
+}
+
+Controller::~Controller()
+{
+ JAMI_DBG("Controller@%p: Destroying UPnP Controller session", this);
+
+ releaseAllMappings();
+ upnpContext_->unregisterController(this);
+}
+
+void
+Controller::setPublicAddress(const IpAddr& addr)
+{
+ assert(upnpContext_);
+
+ if (addr and addr.getFamily() == AF_INET) {
+ upnpContext_->setPublicAddress(addr);
+ }
+}
+
+bool
+Controller::isReady() const
+{
+ assert(upnpContext_);
+ return upnpContext_->isReady();
+}
+
+IpAddr
+Controller::getExternalIP() const
+{
+ assert(upnpContext_);
+ if (upnpContext_->isReady()) {
+ return upnpContext_->getExternalIP();
+ }
+ return {};
+}
+
+Mapping::sharedPtr_t
+Controller::reserveMapping(uint16_t port, PortType type)
+{
+ Mapping map(type, port, port);
+ return reserveMapping(map);
+}
+
+Mapping::sharedPtr_t
+Controller::reserveMapping(Mapping& requestedMap)
+{
+ assert(upnpContext_);
+
+ // Try to get a provisioned port
+ auto mapRes = upnpContext_->reserveMapping(requestedMap);
+ if (mapRes)
+ addLocalMap(*mapRes);
+ return mapRes;
+}
+
+void
+Controller::releaseMapping(const Mapping& map)
+{
+ assert(upnpContext_);
+
+ removeLocalMap(map);
+ return upnpContext_->releaseMapping(map);
+}
+
+void
+Controller::releaseAllMappings()
+{
+ assert(upnpContext_);
+
+ std::lock_guard<std::mutex> lk(mapListMutex_);
+ for (auto const& [_, map] : mappingList_) {
+ upnpContext_->releaseMapping(map);
+ }
+ mappingList_.clear();
+}
+
+void
+Controller::addLocalMap(const Mapping& map)
+{
+ if (map.getMapKey()) {
+ std::lock_guard<std::mutex> lock(mapListMutex_);
+ auto ret = mappingList_.emplace(map.getMapKey(), map);
+ if (not ret.second) {
+ JAMI_WARN("Mapping request for %s already in the list!", map.toString().c_str());
+ }
+ }
+}
+
+bool
+Controller::removeLocalMap(const Mapping& map)
+{
+ assert(upnpContext_);
+
+ std::lock_guard<std::mutex> lk(mapListMutex_);
+ if (mappingList_.erase(map.getMapKey()) != 1) {
+ JAMI_ERR("Failed to remove mapping %s from local list", map.getTypeStr());
+ return false;
+ }
+
+ return true;
+}
+
+uint16_t
+Controller::generateRandomPort(PortType type)
+{
+ return UPnPContext::generateRandomPort(type);
+}
+
+} // namespace upnp
+} // namespace jami
diff --git a/src/upnp/upnp_control.h b/src/upnp/upnp_control.h
new file mode 100644
index 0000000..183b4fb
--- /dev/null
+++ b/src/upnp/upnp_control.h
@@ -0,0 +1,78 @@
+/*
+ * Copyright (C) 2004-2023 Savoir-faire Linux Inc.
+ *
+ * Author: Stepan Salenikovich <stepan.salenikovich@savoirfairelinux.com>
+ * Author: Eden Abitbol <eden.abitbol@savoirfairelinux.com>
+ * Author: Mohamed Chibani <mohamed.chibani@savoirfairelinux.com>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#pragma once
+
+#include "upnp_context.h"
+#include "ip_utils.h"
+
+#include <memory>
+#include <chrono>
+
+namespace jami {
+class IpAddr;
+}
+
+namespace jami {
+namespace upnp {
+
+class UPnPContext;
+
+class Controller
+{
+public:
+ Controller();
+ ~Controller();
+
+ // Set known public address
+ void setPublicAddress(const IpAddr& addr);
+ // Checks if a valid IGD is available.
+ bool isReady() const;
+ // Gets the external ip of the first valid IGD in the list.
+ IpAddr getExternalIP() const;
+
+ // Request port mapping.
+ // Returns a shared pointer on the allocated mapping. The shared
+ // pointer may point to nothing on failure.
+ Mapping::sharedPtr_t reserveMapping(Mapping& map);
+ Mapping::sharedPtr_t reserveMapping(uint16_t port, PortType type);
+
+ // Remove port mapping.
+ void releaseMapping(const Mapping& map);
+ static uint16_t generateRandomPort(PortType);
+
+private:
+ // Adds a mapping locally to the list.
+ void addLocalMap(const Mapping& map);
+ // Removes a mapping from the local list.
+ bool removeLocalMap(const Mapping& map);
+ // Removes all mappings of the given type.
+ void releaseAllMappings();
+
+ std::shared_ptr<UPnPContext> upnpContext_;
+
+ mutable std::mutex mapListMutex_;
+ std::map<Mapping::key_t, Mapping> mappingList_;
+};
+
+} // namespace upnp
+} // namespace jami
diff --git a/src/upnp/upnp_thread_util.h b/src/upnp/upnp_thread_util.h
new file mode 100644
index 0000000..10d454a
--- /dev/null
+++ b/src/upnp/upnp_thread_util.h
@@ -0,0 +1,35 @@
+#pragma once
+
+#include <thread>
+
+// This macro is used to validate that a code is executed from the expected
+// thread. It's useful to detect unexpected race on data members.
+#define CHECK_VALID_THREAD() \
+ if (not isValidThread()) \
+ JAMI_ERR() << "The calling thread " << getCurrentThread() \
+ << " is not the expected thread: " << threadId_;
+
+namespace jami {
+namespace upnp {
+
+class UpnpThreadUtil
+{
+protected:
+ std::thread::id getCurrentThread() const { return std::this_thread::get_id(); }
+
+ bool isValidThread() const { return threadId_ == getCurrentThread(); }
+
+ // Upnp context execution queue (same as manager's scheduler)
+ // Helpers to run tasks on upnp context queue.
+ static ScheduledExecutor* getScheduler() { return &Manager::instance().scheduler(); }
+ template<typename Callback>
+ static void runOnUpnpContextQueue(Callback&& cb)
+ {
+ getScheduler()->run([cb = std::forward<Callback>(cb)]() mutable { cb(); });
+ }
+
+ std::thread::id threadId_;
+};
+
+} // namespace upnp
+} // namespace jami