blob: faf680e695b0df4545b7064547add0715405881c [file] [log] [blame]
/*
* Copyright (C) 2019-2023 Savoir-faire Linux Inc.
* Author: Sébastien Blin <sebastien.blin@savoirfairelinux.com>
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
#include "multiplexed_socket.h"
#include "peer_connection.h"
#include "ice_transport.h"
#include "certstore.h"
#include <opendht/logger.h>
#include <opendht/thread_pool.h>
#include <asio/io_context.hpp>
#include <asio/steady_timer.hpp>
#include <deque>
static constexpr std::size_t IO_BUFFER_SIZE {8192}; ///< Size of char buffer used by IO operations
/// Default protocol version (Impl::version_) announced to the peer by sendVersion();
/// peers announcing version >= 1 are considered beacon-capable (see Impl::onVersion).
static constexpr int MULTIPLEXED_SOCKET_VERSION {1};
/**
 * Envelope of every frame exchanged on the multiplexed TLS stream.
 * MSGPACK_DEFINE serializes it as the 2-element msgpack array [channel, data], which is
 * also what MultiplexedSocket::write() packs by hand.
 */
struct ChanneledMessage
{
// CONTROL_CHANNEL, PROTOCOL_CHANNEL or a data channel id
uint16_t channel;
// Payload; an empty payload on a data channel marks end-of-channel
std::vector<uint8_t> data;
MSGPACK_DEFINE(channel, data)
};
/**
 * Keep-alive message sent on PROTOCOL_CHANNEL, serialized as the msgpack map {"p": bool}.
 * p == true is a beacon request, p == false the answer (see Impl::handleProtocolMsg).
 */
struct BeaconMsg
{
bool p;
MSGPACK_DEFINE_MAP(p)
};
/**
 * Protocol version announcement sent on PROTOCOL_CHANNEL, serialized as the msgpack
 * map {"v": int} (see Impl::sendVersion / Impl::onVersion).
 */
struct VersionMsg
{
int v;
MSGPACK_DEFINE_MAP(v)
};
namespace dhtnet {
// Monotonic clock used to timestamp socket creation (Impl::start_) for monitor().
using clock = std::chrono::steady_clock;
using time_point = clock::time_point;
/**
 * Private implementation of MultiplexedSocket.
 *
 * Owns the TLS endpoint, the channel-id -> ChannelSocket map and the reader thread running
 * eventLoop(). Control and protocol packets are processed on dht::ThreadPool::io(); the
 * beacon timeout timer runs on ctx_.
 */
class MultiplexedSocket::Impl
{
public:
    Impl(MultiplexedSocket& parent,
         std::shared_ptr<asio::io_context> ctx,
         const DeviceId& deviceId,
         std::unique_ptr<TlsSocketEndpoint> endpoint)
        // Listed in declaration order, which is the order C++ actually initializes them in.
        : parent_(parent)
        , ctx_(std::move(ctx))
        , deviceId(deviceId)
        , endpoint(std::move(endpoint))
        , beaconTimer_(*ctx_)
    {
        // Start the reader thread only now, once every member is constructed. Starting it
        // from the member-initializer list let eventLoop()/shutdown() run before the
        // members declared after eventLoopThread_ (isShutdown_, writeMtx, start_,
        // beaconTimer_, version_, ...) were initialized.
        // NOTE(review): parent_.pimpl_ is only assigned after this constructor returns;
        // tasks posted by the reader reach it through parent_.weak() — confirm ordering.
        eventLoopThread_ = std::thread([this] {
            try {
                eventLoop();
            } catch (const std::exception& e) {
                if (logger_)
                    logger_->error("[CNX] peer connection event loop failure: {}", e.what());
                shutdown();
            }
        });
    }
    // join() must have run before destruction: destroying a still-joinable std::thread
    // calls std::terminate().
    ~Impl() {}

    /**
     * Stop the socket if it is still running (detaching the TLS state hook first so it
     * cannot fire during teardown), otherwise only stop the remaining channels; then wait
     * for the reader thread. Must not be called from eventLoopThread_ itself
     * (std::thread::join() would throw).
     */
    void join()
    {
        if (!isShutdown_) {
            if (endpoint)
                endpoint->setOnStateChange({});
            shutdown();
        } else {
            clearSockets();
        }
        if (eventLoopThread_.joinable())
            eventLoopThread_.join();
    }

    /**
     * Take every channel out of the map (under socketsMutex) and stop them outside the lock.
     */
    void clearSockets()
    {
        decltype(sockets) socks;
        {
            std::lock_guard<std::mutex> lkSockets(socketsMutex);
            socks = std::move(sockets);
        }
        for (auto& socket : socks) {
            // Just trigger onShutdown() to make client know
            // No need to write the EOF for the channel, the write will fail because endpoint is
            // already shutdown
            if (socket.second)
                socket.second->stop();
        }
    }

    /**
     * Tear the socket down exactly once: stop the reader loop, cancel the beacon timer,
     * notify onShutdown_, shut the TLS endpoint down (under writeMtx) and stop all channels.
     */
    void shutdown()
    {
        // exchange() makes the check-and-set atomic: shutdown() is reached from the reader
        // thread, the beacon timer, failed writes on arbitrary threads and join().
        if (isShutdown_.exchange(true))
            return;
        stop.store(true);
        beaconTimer_.cancel();
        if (onShutdown_)
            onShutdown_();
        if (endpoint) {
            std::unique_lock<std::mutex> lk(writeMtx);
            endpoint->shutdown();
        }
        clearSockets();
    }

    /**
     * Return the ChannelSocket registered for `channel`, creating it if absent.
     * Callers must hold socketsMutex (every caller in this file does).
     * The created socket's removal callback erases it from the map on the I/O thread pool.
     */
    std::shared_ptr<ChannelSocket> makeSocket(const std::string& name,
                                              uint16_t channel,
                                              bool isInitiator = false)
    {
        auto& channelSocket = sockets[channel];
        if (not channelSocket)
            channelSocket = std::make_shared<ChannelSocket>(
                parent_.weak(), name, channel, isInitiator, [w = parent_.weak(), channel]() {
                    // Remove socket in another thread to avoid any lock
                    dht::ThreadPool::io().run([w, channel]() {
                        if (auto shared = w.lock()) {
                            shared->eraseChannel(channel);
                        }
                    });
                });
        else {
            if (logger_)
                logger_->warn("A channel is already present on that socket, accepting "
                              "the request will close the previous one {}", name);
        }
        return channelSocket;
    }

    /**
     * Handle packets on the TLS endpoint and parse RTP
     */
    void eventLoop();

    /**
     * Triggered when a new control packet is received
     */
    void handleControlPacket(std::vector<uint8_t>&& pkt);
    void handleProtocolPacket(std::vector<uint8_t>&& pkt);
    bool handleProtocolMsg(const msgpack::object& o);
    /**
     * Triggered when a new packet on a channel is received
     */
    void handleChannelPacket(uint16_t channel, std::vector<uint8_t>&& pkt);
    void onRequest(const std::string& name, uint16_t channel);
    void onAccept(const std::string& name, uint16_t channel);

    void setOnReady(OnConnectionReadyCb&& cb) { onChannelReady_ = std::move(cb); }
    void setOnRequest(OnConnectionRequestCb&& cb) { onRequest_ = std::move(cb); }

    // Beacon
    void sendBeacon(const std::chrono::milliseconds& timeout);
    void handleBeaconRequest();
    void handleBeaconResponse();
    // Beacons sent and not yet answered
    std::atomic_int beaconCounter_ {0};

    bool writeProtocolMessage(const msgpack::sbuffer& buffer);

    // Incremental decoder for the incoming TLS byte stream (reader thread only)
    msgpack::unpacker pac_ {};

    MultiplexedSocket& parent_;
    // NOTE(review): never assigned in this file, so it is presumably always null here.
    std::shared_ptr<Logger> logger_;
    std::shared_ptr<asio::io_context> ctx_;

    OnConnectionReadyCb onChannelReady_ {};
    OnConnectionRequestCb onRequest_ {};
    OnShutdownCb onShutdown_ {};

    DeviceId deviceId {};
    // Main socket
    std::unique_ptr<TlsSocketEndpoint> endpoint {};

    std::mutex socketsMutex {};
    std::map<uint16_t, std::shared_ptr<ChannelSocket>> sockets {};

    // Main loop to parse incoming packets
    std::atomic_bool stop {false};
    std::thread eventLoopThread_ {};

    std::atomic_bool isShutdown_ {false};

    // Serializes frame writes on the endpoint
    std::mutex writeMtx {};

    time_point start_ {clock::now()};
    asio::steady_timer beaconTimer_;

    // version related stuff
    void sendVersion();
    void onVersion(int version);
    std::atomic_bool canSendBeacon_ {false};
    std::atomic_bool answerBeacon_ {true};
    int version_ {MULTIPLEXED_SOCKET_VERSION};
    std::function<void(bool)> onBeaconCb_ {};
    std::function<void(int)> onVersionCb_ {};
};
/**
 * Reader loop executed by eventLoopThread_.
 *
 * It first hooks TLS state changes, so that a SHUTDOWN state shuts this socket down. It
 * then announces our protocol version. After that it reads from the TLS endpoint straight
 * into the msgpack unpacker and dispatches each complete ChanneledMessage by channel id:
 * CONTROL_CHANNEL -> handleControlPacket(), PROTOCOL_CHANNEL -> handleProtocolPacket(),
 * any other id -> handleChannelPacket().
 *
 * The loop exits in three cases:
 * - `stop` is set;
 * - the endpoint is missing, or the peer closed the stream (read() == 0) — both call shutdown();
 * - read() fails (< 0), which only logs.
 * NOTE(review): the read-error path exits without calling shutdown(); presumably the TLS
 * state hook covers it — confirm.
 */
void
MultiplexedSocket::Impl::eventLoop()
{
endpoint->setOnStateChange([this](tls::TlsSessionState state) {
if (state == tls::TlsSessionState::SHUTDOWN && !isShutdown_) {
if (logger_)
logger_->debug("Tls endpoint is down, shutdown multiplexed socket");
shutdown();
return false;
}
return true;
});
sendVersion();
std::error_code ec;
while (!stop) {
if (!endpoint) {
shutdown();
return;
}
// Read directly into the unpacker's buffer, then report how many bytes were appended.
pac_.reserve_buffer(IO_BUFFER_SIZE);
int size = endpoint->read(reinterpret_cast<uint8_t*>(&pac_.buffer()[0]), IO_BUFFER_SIZE, ec);
if (size < 0) {
if (ec && logger_)
logger_->error("Read error detected: {}", ec.message());
break;
}
if (size == 0) {
// The peer closed the stream: we can close the socket
shutdown();
break;
}
pac_.buffer_consumed(size);
// Drain every complete message now available; an incomplete one stays buffered in pac_.
msgpack::object_handle oh;
while (pac_.next(oh) && !stop) {
// A message that fails to decode or dispatch is logged and skipped.
try {
auto msg = oh.get().as<ChanneledMessage>();
if (msg.channel == CONTROL_CHANNEL)
handleControlPacket(std::move(msg.data));
else if (msg.channel == PROTOCOL_CHANNEL)
handleProtocolPacket(std::move(msg.data));
else
handleChannelPacket(msg.channel, std::move(msg.data));
} catch (const std::exception& e) {
if (logger_)
logger_->warn("Failed to unpacked message of {:d} bytes: {:s}", size, e.what());
} catch (...) {
if (logger_)
logger_->error("Unknown exception catched while unpacking message of {:d} bytes", size);
}
}
}
}
void
MultiplexedSocket::Impl::onAccept(const std::string& name, uint16_t channel)
{
std::lock_guard<std::mutex> lkSockets(socketsMutex);
auto& socket = sockets[channel];
if (!socket) {
if (logger_)
logger_->error("Receiving an answer for a non existing channel. This is a bug.");
return;
}
onChannelReady_(deviceId, socket);
socket->ready(true);
// Due to the callbacks that can take some time, onAccept can arrive after
// receiving all the data. In this case, the socket should be removed here
// as handle by onChannelReady_
if (socket->isRemovable())
sockets.erase(channel);
else
socket->answered();
}
void
MultiplexedSocket::Impl::sendBeacon(const std::chrono::milliseconds& timeout)
{
if (!canSendBeacon_)
return;
beaconCounter_++;
if (logger_)
logger_->debug("Send beacon to peer {}", deviceId);
msgpack::sbuffer buffer(8);
msgpack::packer<msgpack::sbuffer> pk(&buffer);
pk.pack(BeaconMsg {true});
if (!writeProtocolMessage(buffer))
return;
beaconTimer_.expires_after(timeout);
beaconTimer_.async_wait([w = parent_.weak()](const asio::error_code& ec) {
if (ec == asio::error::operation_aborted)
return;
if (auto shared = w.lock()) {
if (shared->pimpl_->beaconCounter_ != 0) {
if (shared->pimpl_->logger_)
shared->pimpl_->logger_->error("Beacon doesn't get any response. Stopping socket");
shared->shutdown();
}
}
});
}
void
MultiplexedSocket::Impl::handleBeaconRequest()
{
if (!answerBeacon_)
return;
// Run this on dedicated thread because some callbacks can take time
dht::ThreadPool::io().run([w = parent_.weak()]() {
if (auto shared = w.lock()) {
msgpack::sbuffer buffer(8);
msgpack::packer<msgpack::sbuffer> pk(&buffer);
pk.pack(BeaconMsg {false});
if (shared->pimpl_->logger_)
shared->pimpl_->logger_->debug("Send beacon response to peer {}", shared->deviceId());
shared->pimpl_->writeProtocolMessage(buffer);
}
});
}
/**
 * Account for a beacon answer from the peer. This balances the increment done in
 * sendBeacon(); a zero counter means no beacon is pending when the timeout fires.
 */
void
MultiplexedSocket::Impl::handleBeaconResponse()
{
    if (logger_)
        logger_->debug("Get beacon response from peer {}", deviceId);
    --beaconCounter_;
}
/**
 * Write an already-packed protocol message on PROTOCOL_CHANNEL.
 * @return true when the write reported a positive byte count.
 */
bool
MultiplexedSocket::Impl::writeProtocolMessage(const msgpack::sbuffer& buffer)
{
    std::error_code ignored;
    const auto* payload = reinterpret_cast<const unsigned char*>(buffer.data());
    // write() signals failure with (std::size_t)-1, which reads as -1 once narrowed to int.
    const int written = parent_.write(PROTOCOL_CHANNEL, payload, buffer.size(), ignored);
    return written > 0;
}
void
MultiplexedSocket::Impl::sendVersion()
{
dht::ThreadPool::io().run([w = parent_.weak()]() {
if (auto shared = w.lock()) {
auto version = shared->pimpl_->version_;
msgpack::sbuffer buffer(8);
msgpack::packer<msgpack::sbuffer> pk(&buffer);
pk.pack(VersionMsg {version});
shared->pimpl_->writeProtocolMessage(buffer);
}
});
}
/**
 * Record the version announced by the peer: versions >= 1 support beacons.
 */
void
MultiplexedSocket::Impl::onVersion(int version)
{
    const bool supportsBeacon = version >= 1;
    if (logger_) {
        if (supportsBeacon)
            logger_->debug("Peer {} supports beacon", deviceId);
        else
            logger_->warn("Peer {} uses version {:d} which doesn't support beacon",
                          deviceId,
                          version);
    }
    canSendBeacon_ = supportsBeacon;
}
/**
 * Answer a channel request received from the peer on CONTROL_CHANNEL.
 *
 * onRequest_ decides from the peer certificate, channel id and name whether to accept.
 * On accept, a ChannelSocket is created (or the existing one for that id is returned)
 * under socketsMutex. An ACCEPT or DECLINE ChannelRequest is then written back.
 * If that write fails, `stop` is set and the function returns. Otherwise an accepted
 * socket is passed to onChannelReady_ and marked ready.
 *
 * Precondition: onRequest_ is set (handleControlPacket() checks it before calling).
 * NOTE(review): onChannelReady_ is invoked without an emptiness check.
 */
void
MultiplexedSocket::Impl::onRequest(const std::string& name, uint16_t channel)
{
auto accept = onRequest_(endpoint->peerCertificate(), channel, name);
std::shared_ptr<ChannelSocket> channelSocket;
if (accept) {
std::lock_guard<std::mutex> lkSockets(socketsMutex);
channelSocket = makeSocket(name, channel);
}
// Answer to ChannelRequest if accepted
ChannelRequest val;
val.channel = channel;
val.name = name;
val.state = accept ? ChannelRequestState::ACCEPT : ChannelRequestState::DECLINE;
msgpack::sbuffer buffer(512);
msgpack::pack(buffer, val);
std::error_code ec;
// write() returns std::size_t; its (size_t)-1 failure value reads as negative here.
int wr = parent_.write(CONTROL_CHANNEL,
reinterpret_cast<const uint8_t*>(buffer.data()),
buffer.size(),
ec);
if (wr < 0) {
if (ec && logger_)
logger_->error("The write operation failed with error: {:s}", ec.message());
stop.store(true);
return;
}
if (accept) {
onChannelReady_(deviceId, channelSocket);
channelSocket->ready(true);
}
}
/**
 * Process a CONTROL_CHANNEL packet on the I/O thread pool.
 *
 * The packet may hold several msgpack objects back to back. Each one is either a protocol
 * map (beacon/version), or a ChannelRequest with one of three outcomes:
 * - ACCEPT -> onAccept();
 * - DECLINE -> the pending socket is failed, stopped and erased;
 * - any other state -> onRequest(), if a request handler is installed.
 * Decoding errors are logged and abort the rest of the packet.
 */
void
MultiplexedSocket::Impl::handleControlPacket(std::vector<uint8_t>&& pkt)
{
    // Run this on dedicated thread because some callbacks can take time
    dht::ThreadPool::io().run([w = parent_.weak(), pkt = std::move(pkt)]() {
        auto shared = w.lock();
        if (!shared)
            return;
        auto& pimpl = *shared->pimpl_;
        try {
            size_t off = 0;
            while (off != pkt.size()) {
                msgpack::unpacked result;
                msgpack::unpack(result, (const char*) pkt.data(), pkt.size(), off);
                auto object = result.get();
                if (pimpl.handleProtocolMsg(object))
                    continue;
                auto req = object.as<ChannelRequest>();
                if (req.state == ChannelRequestState::ACCEPT) {
                    pimpl.onAccept(req.name, req.channel);
                } else if (req.state == ChannelRequestState::DECLINE) {
                    std::lock_guard<std::mutex> lkSockets(pimpl.socketsMutex);
                    auto channel = pimpl.sockets.find(req.channel);
                    if (channel != pimpl.sockets.end()) {
                        // Guard against a null entry in the map: just drop it instead of
                        // dereferencing it.
                        if (const auto& socket = channel->second) {
                            socket->ready(false);
                            socket->stop();
                        }
                        pimpl.sockets.erase(channel);
                    }
                } else if (pimpl.onRequest_) {
                    pimpl.onRequest(req.name, req.channel);
                }
            }
        } catch (const std::exception& e) {
            if (pimpl.logger_)
                pimpl.logger_->error("Error on the control channel: {}", e.what());
        }
    });
}
/**
 * Deliver a frame received on data channel `channel`. This runs on the reader thread,
 * with socketsMutex held for the whole call.
 *
 * A non-empty frame goes to ChannelSocket::onRecv(). An empty frame is the peer's
 * end-of-channel marker: the socket is stopped. It is erased right away if onAccept()
 * already processed its ACCEPT answer. Otherwise it is only flagged removable, and
 * onAccept() erases it later. Frames for unknown channels are dropped, with a warning
 * if non-empty.
 */
void
MultiplexedSocket::Impl::handleChannelPacket(uint16_t channel, std::vector<uint8_t>&& pkt)
{
std::lock_guard<std::mutex> lkSockets(socketsMutex);
auto sockIt = sockets.find(channel);
// Channel 0 (presumably CONTROL_CHANNEL — confirm) and null entries are not deliverable.
if (channel > 0 && sockIt != sockets.end() && sockIt->second) {
if (pkt.size() == 0) {
sockIt->second->stop();
if (sockIt->second->isAnswered())
sockets.erase(sockIt);
else
sockIt->second->removable(); // This means that onAccept didn't happen yet, will be
// removed later.
} else {
sockIt->second->onRecv(std::move(pkt));
}
} else if (pkt.size() != 0) {
if (logger_)
logger_->warn("Non existing channel: {}", channel);
}
}
/**
 * Try to handle `o` as a protocol message: a msgpack map whose first key names the type
 * ("p" beacon, "v" version).
 * @return true if it was a beacon or version message; false otherwise or on decode error
 *         (errors are logged, never propagated).
 */
bool
MultiplexedSocket::Impl::handleProtocolMsg(const msgpack::object& o)
{
    try {
        if (o.type != msgpack::type::MAP || o.via.map.size == 0)
            return false;
        const auto type = o.via.map.ptr[0].key.as<std::string_view>();
        if (type == "p") {
            const auto beacon = o.as<BeaconMsg>();
            if (beacon.p)
                handleBeaconRequest();
            else
                handleBeaconResponse();
            if (onBeaconCb_)
                onBeaconCb_(beacon.p);
            return true;
        }
        if (type == "v") {
            const auto announced = o.as<VersionMsg>();
            onVersion(announced.v);
            if (onVersionCb_)
                onVersionCb_(announced.v);
            return true;
        }
        if (logger_)
            logger_->warn("Unknown message type");
    } catch (const std::exception& e) {
        if (logger_)
            logger_->error("Error on the protocol channel: {}", e.what());
    }
    return false;
}
/**
 * Process a PROTOCOL_CHANNEL packet on the I/O thread pool: every msgpack object it
 * contains is passed to handleProtocolMsg(). Decode errors are logged and end processing.
 */
void
MultiplexedSocket::Impl::handleProtocolPacket(std::vector<uint8_t>&& pkt)
{
    // Run this on dedicated thread because some callbacks can take time
    dht::ThreadPool::io().run([w = parent_.weak(), pkt = std::move(pkt)]() {
        auto shared = w.lock();
        if (!shared)
            return;
        try {
            size_t off = 0;
            while (off != pkt.size()) {
                msgpack::unpacked result;
                msgpack::unpack(result, (const char*) pkt.data(), pkt.size(), off);
                // Dispatch every message of the packet. Returning after the first handled
                // one (as before) silently dropped any message packed after it, unlike
                // handleControlPacket() which keeps going.
                shared->pimpl_->handleProtocolMsg(result.get());
            }
        } catch (const std::exception& e) {
            if (shared->pimpl_->logger_)
                shared->pimpl_->logger_->error("Error on the protocol channel: {}", e.what());
        }
    });
}
/**
 * Wrap `endpoint` and start demultiplexing it (the reader thread starts in Impl).
 * `ctx` is taken by value, so it is moved into Impl rather than copied again (saves an
 * atomic refcount round-trip).
 */
MultiplexedSocket::MultiplexedSocket(std::shared_ptr<asio::io_context> ctx, const DeviceId& deviceId,
                                     std::unique_ptr<TlsSocketEndpoint> endpoint)
    : pimpl_(std::make_unique<Impl>(*this, std::move(ctx), deviceId, std::move(endpoint)))
{}
// Defined out of line, where Impl is a complete type, so pimpl_ can be destroyed.
MultiplexedSocket::~MultiplexedSocket() = default;
/**
 * Create a locally-initiated channel named `name` on a free channel id.
 * @return the new ChannelSocket, or null when every usable id is taken.
 */
std::shared_ptr<ChannelSocket>
MultiplexedSocket::addChannel(const std::string& name)
{
    // Note: because both sides can request the same channel number at the same time
    // it's better to use a random channel number instead of just incrementing the request.
    thread_local dht::crypto::random_device rd;
    std::uniform_int_distribution<uint16_t> dist;
    auto offset = dist(rd);
    std::lock_guard<std::mutex> lk(pimpl_->socketsMutex);
    // i = 0..UINT16_MAX-1 visits each residue modulo UINT16_MAX exactly once. The previous
    // 1..UINT16_MAX-1 range never considered the id equal to offset % UINT16_MAX.
    for (int i = 0; i < UINT16_MAX; ++i) {
        auto c = static_cast<uint16_t>((offset + i) % UINT16_MAX);
        if (c == CONTROL_CHANNEL || c == PROTOCOL_CHANNEL
            || pimpl_->sockets.find(c) != pimpl_->sockets.end())
            continue;
        return pimpl_->makeSocket(name, c, true);
    }
    return {};
}
// Id of the peer device this socket is connected to.
DeviceId
MultiplexedSocket::deviceId() const
{
    const auto& impl = *pimpl_;
    return impl.deviceId;
}
// Install the callback invoked when a channel becomes ready (accepted either way).
void
MultiplexedSocket::setOnReady(OnConnectionReadyCb&& cb)
{
    pimpl_->setOnReady(std::move(cb));
}
// Install the callback deciding whether to accept a channel request from the peer.
void
MultiplexedSocket::setOnRequest(OnConnectionRequestCb&& cb)
{
    pimpl_->setOnRequest(std::move(cb));
}
// Unconditionally reported as reliable.
bool
MultiplexedSocket::isReliable() const
{
    constexpr bool reliable = true;
    return reliable;
}
// Whether the underlying TLS endpoint is the initiator; false (with a warning) if absent.
bool
MultiplexedSocket::isInitiator() const
{
    if (const auto& ep = pimpl_->endpoint)
        return ep->isInitiator();
    if (pimpl_->logger_)
        pimpl_->logger_->warn("No endpoint found for socket");
    return false;
}
// Endpoint's maximum payload size; 0 (with a warning) if there is no endpoint.
int
MultiplexedSocket::maxPayload() const
{
    if (const auto& ep = pimpl_->endpoint)
        return ep->maxPayload();
    if (pimpl_->logger_)
        pimpl_->logger_->warn("No endpoint found for socket");
    return 0;
}
/**
 * Send `len` bytes of `buf` to the peer as one ChanneledMessage on `channel`.
 *
 * The msgpack frame [channel, bin(len)] is packed by hand. Payloads shorter than 8192
 * bytes are copied after the header and written in a single call. Larger ones are written
 * as the header followed by the raw body. Both writes happen under writeMtx, so frames
 * written concurrently never interleave on the TLS stream.
 *
 * @param len at most UINT16_MAX (ChannelSocket::write() splits larger writes); 0 sends an
 *            empty frame, which the receiving handleChannelPacket() treats as end-of-channel.
 * @return the endpoint write result, or (std::size_t)-1 on failure. For small frames the
 *         result also counts the header bytes; for large frames it is the body write's
 *         result. On failure `ec` is broken_pipe/message_size set here, or the endpoint's
 *         error. A failing endpoint write (res < 0) also shuts this socket down.
 */
std::size_t
MultiplexedSocket::write(const uint16_t& channel,
const uint8_t* buf,
std::size_t len,
std::error_code& ec)
{
assert(nullptr != buf);
if (pimpl_->isShutdown_) {
ec = std::make_error_code(std::errc::broken_pipe);
return -1;
}
if (len > UINT16_MAX) {
ec = std::make_error_code(std::errc::message_size);
return -1;
}
bool oneShot = len < 8192;
msgpack::sbuffer buffer(oneShot ? 16 + len : 16);
msgpack::packer<msgpack::sbuffer> pk(&buffer);
pk.pack_array(2);
pk.pack(channel);
pk.pack_bin(len);
if (oneShot)
pk.pack_bin_body((const char*) buf, len);
std::unique_lock<std::mutex> lk(pimpl_->writeMtx);
if (!pimpl_->endpoint) {
if (pimpl_->logger_)
pimpl_->logger_->warn("No endpoint found for socket");
ec = std::make_error_code(std::errc::broken_pipe);
return -1;
}
int res = pimpl_->endpoint->write((const unsigned char*) buffer.data(), buffer.size(), ec);
// Large payload: the raw body follows the header written above.
if (not oneShot and res >= 0)
res = pimpl_->endpoint->write(buf, len, ec);
// Release writeMtx before shutdown(), which takes it again.
lk.unlock();
if (res < 0) {
if (ec && pimpl_->logger_)
pimpl_->logger_->error("Error when writing on socket: {:s}", ec.message());
shutdown();
}
return res;
}
void
MultiplexedSocket::shutdown()
{
pimpl_->shutdown();
}
void
MultiplexedSocket::join()
{
pimpl_->join();
}
/**
 * Install the shutdown callback. If the socket is already shut down, the callback is
 * invoked immediately.
 */
void
MultiplexedSocket::onShutdown(OnShutdownCb&& cb)
{
    auto& impl = *pimpl_;
    impl.onShutdown_ = std::move(cb);
    if (not impl.isShutdown_)
        return;
    impl.onShutdown_();
}
const std::shared_ptr<Logger>&
MultiplexedSocket::logger()
{
return pimpl_->logger_;
}
/**
 * Log a debug snapshot: peer device/account, connection age, endpoint state and every
 * open channel. Silent when the peer certificate (or its issuer) or the logger is missing.
 */
void
MultiplexedSocket::monitor() const
{
    const auto peerCert = peerCertificate();
    if (not peerCert or not peerCert->issuer)
        return;
    const auto now = clock::now();
    const auto& out = pimpl_->logger_;
    if (not out)
        return;
    out->debug("- Socket with device: {:s} - account: {:s}", deviceId(), peerCert->issuer->getId());
    out->debug("- Duration: {}", dht::print_duration(now - pimpl_->start_));
    pimpl_->endpoint->monitor();
    std::lock_guard<std::mutex> lk(pimpl_->socketsMutex);
    for (const auto& entry : pimpl_->sockets) {
        const auto& channel = entry.second;
        if (not channel)
            continue;
        out->debug("\t\t- Channel {} (count: {}) with name {:s} Initiator: {}",
                   fmt::ptr(channel.get()),
                   channel.use_count(),
                   channel->name(),
                   channel->isInitiator());
    }
}
void
MultiplexedSocket::sendBeacon(const std::chrono::milliseconds& timeout)
{
pimpl_->sendBeacon(timeout);
}
/**
 * Certificate presented by the peer, or null when there is no endpoint.
 * isInitiator()/maxPayload() already guard against a missing endpoint; this accessor
 * dereferenced it unconditionally, and monitor() calls it first thing.
 */
std::shared_ptr<dht::crypto::Certificate>
MultiplexedSocket::peerCertificate() const
{
    if (!pimpl_->endpoint)
        return {};
    return pimpl_->endpoint->peerCertificate();
}
#ifdef LIBJAMI_TESTABLE
bool
MultiplexedSocket::canSendBeacon() const
{
return pimpl_->canSendBeacon_;
}
void
MultiplexedSocket::answerToBeacon(bool value)
{
pimpl_->answerBeacon_ = value;
}
void
MultiplexedSocket::setVersion(int version)
{
pimpl_->version_ = version;
}
void
MultiplexedSocket::setOnBeaconCb(const std::function<void(bool)>& cb)
{
pimpl_->onBeaconCb_ = cb;
}
void
MultiplexedSocket::setOnVersionCb(const std::function<void(int)>& cb)
{
pimpl_->onVersionCb_ = cb;
}
void
MultiplexedSocket::sendVersion()
{
pimpl_->sendVersion();
}
IpAddr
MultiplexedSocket::getLocalAddress() const
{
return pimpl_->endpoint->getLocalAddress();
}
IpAddr
MultiplexedSocket::getRemoteAddress() const
{
return pimpl_->endpoint->getRemoteAddress();
}
#endif
/**
 * Drop `channel` from the channel map (no-op if absent). Called asynchronously by the
 * removal callback installed in Impl::makeSocket().
 */
void
MultiplexedSocket::eraseChannel(uint16_t channel)
{
    std::lock_guard<std::mutex> lkSockets(pimpl_->socketsMutex);
    // map::erase(key) already ignores absent keys; the previous code looked the key up twice.
    pimpl_->sockets.erase(channel);
}
////////////////////////////////////////////////////////////////
/**
 * Private state of ChannelSocket: identity of the channel, link to the owning
 * MultiplexedSocket, lifecycle flags and the receive buffer/callback.
 */
class ChannelSocket::Impl
{
public:
Impl(std::weak_ptr<MultiplexedSocket> endpoint,
const std::string& name,
const uint16_t& channel,
bool isInitiator,
std::function<void()> rmFromMxSockCb)
: name(name)
, channel(channel)
, endpoint(std::move(endpoint))
, isInitiator_(isInitiator)
, rmFromMxSockCb_(std::move(rmFromMxSockCb))
{}
~Impl() {}
// Called by ChannelSocket::ready() with the peer's answer (true = accepted).
ChannelReadyCb readyCb_ {};
// Called by ChannelSocket::stop(), or at registration if already stopped.
OnShutdownCb shutdownCb_ {};
std::atomic_bool isShutdown_ {false};
std::string name {};
uint16_t channel {};
// Owning multiplexed socket; may expire before this channel does.
std::weak_ptr<MultiplexedSocket> endpoint {};
// True when created locally by MultiplexedSocket::addChannel().
bool isInitiator_ {false};
// Asks the MultiplexedSocket to drop this channel from its map; called by stop().
std::function<void()> rmFromMxSockCb_;
// Set once onAccept() processed the ACCEPT answer. In this file it is only touched
// with the multiplexed socket's socketsMutex held.
bool isAnswered_ {false};
// Set when the peer's end-of-channel arrived before the ACCEPT answer was processed.
bool isRemovable_ {false};
// Data received while no RecvCb is installed; guarded by `mutex`.
std::vector<uint8_t> buf {};
std::mutex mutex {};
// Signalled when data is buffered or the channel stops (see waitForData()).
std::condition_variable cv {};
GenericSocket<uint8_t>::RecvCb cb {};
};
/**
 * In-process channel stand-in for tests; pair two instances with link().
 * The io_context is kept in ioCtx_; nothing in this file uses it.
 */
ChannelSocketTest::ChannelSocketTest(std::shared_ptr<asio::io_context> ctx,
const DeviceId& deviceId,
const std::string& name,
const uint16_t& channel)
: pimpl_deviceId(deviceId)
, pimpl_name(name)
, pimpl_channel(channel)
, ioCtx_(*ctx)
{}
// Nothing to release explicitly; members clean up after themselves.
ChannelSocketTest::~ChannelSocketTest() = default;
/**
 * Make the two test sockets peers of each other: writes on one are delivered to the
 * other, and shutting one down shuts the other down.
 */
void
ChannelSocketTest::link(const std::shared_ptr<ChannelSocketTest>& socket1,
                        const std::shared_ptr<ChannelSocketTest>& socket2)
{
    socket2->remote = socket1;
    socket1->remote = socket2;
}
// Device id given at construction.
DeviceId
ChannelSocketTest::deviceId() const
{
    const DeviceId& id = pimpl_deviceId;
    return id;
}
// Channel name given at construction.
std::string
ChannelSocketTest::name() const
{
    const std::string& channelName = pimpl_name;
    return channelName;
}
// Channel id given at construction.
uint16_t
ChannelSocketTest::channel() const
{
    const uint16_t id = pimpl_channel;
    return id;
}
/**
 * Close this test socket and its linked peer. Each side's shutdown callback runs at
 * most once, and waiters in waitForData() are woken up.
 */
void
ChannelSocketTest::shutdown()
{
    {
        std::unique_lock<std::mutex> lk {mutex};
        if (!isShutdown_.exchange(true)) {
            lk.unlock();
            // Nothing guarantees onShutdown() was called: invoking an empty
            // std::function would throw std::bad_function_call.
            if (shutdownCb_)
                shutdownCb_();
        }
        cv.notify_all();
    }
    if (auto peer = remote.lock()) {
        if (!peer->isShutdown_.exchange(true)) {
            if (peer->shutdownCb_)
                peer->shutdownCb_();
        }
        peer->cv.notify_all();
    }
}
/**
 * Copy up to `len` buffered bytes into `buf` and drop them from the buffer.
 * @return the number of bytes copied (0 when nothing is buffered).
 */
std::size_t
ChannelSocketTest::read(ValueType* buf, std::size_t len, std::error_code& ec)
{
    // rx_buf is filled by onRecv() on a thread-pool thread under `mutex`. Take the same
    // lock here, as ChannelSocket::read() does, instead of racing with it.
    std::lock_guard<std::mutex> lk {mutex};
    std::size_t size = std::min(len, rx_buf.size());
    for (std::size_t i = 0; i < size; ++i)
        buf[i] = rx_buf[i];
    // Erasing the consumed prefix also covers the "everything consumed" case.
    rx_buf.erase(rx_buf.begin(), rx_buf.begin() + size);
    return size;
}
std::size_t
ChannelSocketTest::write(const ValueType* buf, std::size_t len, std::error_code& ec)
{
if (isShutdown_) {
ec = std::make_error_code(std::errc::broken_pipe);
return -1;
}
ec = {};
dht::ThreadPool::computation().run(
[r = remote, data = std::vector<uint8_t>(buf, buf + len)]() mutable {
if (auto peer = r.lock())
peer->onRecv(std::move(data));
});
return len;
}
int
ChannelSocketTest::waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const
{
std::unique_lock<std::mutex> lk {mutex};
cv.wait_for(lk, timeout, [&] { return !rx_buf.empty() or isShutdown_; });
return rx_buf.size();
}
/**
 * Install the receive callback. Data buffered before it was set is handed to it at once.
 */
void
ChannelSocketTest::setOnRecv(RecvCb&& cb)
{
    std::lock_guard<std::mutex> lkSockets(mutex);
    this->cb = std::move(cb);
    if (rx_buf.empty() || !this->cb)
        return;
    this->cb(rx_buf.data(), rx_buf.size());
    rx_buf.clear();
}
/**
 * Handle data written by the linked peer. It is passed straight to the receive callback
 * if one is installed; otherwise it is buffered and waitForData() is woken up.
 */
void
ChannelSocketTest::onRecv(std::vector<uint8_t>&& pkt)
{
    std::lock_guard<std::mutex> lkSockets(mutex);
    if (!cb) {
        rx_buf.insert(rx_buf.end(),
                      std::make_move_iterator(pkt.begin()),
                      std::make_move_iterator(pkt.end()));
        cv.notify_all();
        return;
    }
    cb(pkt.data(), pkt.size());
}
// This test stand-in ignores the ready callback: it is never stored nor invoked.
void
ChannelSocketTest::onReady(ChannelReadyCb&& cb)
{}
/**
 * Install the shutdown callback. If the socket is already shut down, it runs immediately
 * (outside the lock).
 */
void
ChannelSocketTest::onShutdown(OnShutdownCb&& cb)
{
    std::unique_lock<std::mutex> lk {mutex};
    shutdownCb_ = std::move(cb);
    if (isShutdown_) {
        lk.unlock();
        // An empty callback must not throw std::bad_function_call.
        if (shutdownCb_)
            shutdownCb_();
    }
}
/**
 * Create a channel bound to `endpoint`. `rmFromMxSockCb` is called by stop() to remove
 * the channel from its MultiplexedSocket.
 * `endpoint` is taken by value, so it is moved into Impl instead of copied again.
 */
ChannelSocket::ChannelSocket(std::weak_ptr<MultiplexedSocket> endpoint,
                             const std::string& name,
                             const uint16_t& channel,
                             bool isInitiator,
                             std::function<void()> rmFromMxSockCb)
    : pimpl_ {std::make_unique<Impl>(
        std::move(endpoint), name, channel, isInitiator, std::move(rmFromMxSockCb))}
{}
// Defined out of line, where Impl is a complete type, so pimpl_ can be destroyed.
ChannelSocket::~ChannelSocket() = default;
// Peer device id, or a default-constructed id once the multiplexed socket is gone.
DeviceId
ChannelSocket::deviceId() const
{
    auto ep = pimpl_->endpoint.lock();
    if (!ep)
        return {};
    return ep->deviceId();
}
std::string
ChannelSocket::name() const
{
return pimpl_->name;
}
uint16_t
ChannelSocket::channel() const
{
return pimpl_->channel;
}
// Reliability of the underlying multiplexed socket; false once it is gone.
bool
ChannelSocket::isReliable() const
{
    auto ep = pimpl_->endpoint.lock();
    return ep && ep->isReliable();
}
/**
 * Whether this channel was opened locally, through MultiplexedSocket::addChannel(),
 * rather than accepted from a peer request. This differs from
 * MultiplexedSocket::isInitiator(), which describes the TLS connection.
 */
bool
ChannelSocket::isInitiator() const
{
    const bool openedLocally = pimpl_->isInitiator_;
    return openedLocally;
}
// Max payload of the underlying multiplexed socket; -1 once it is gone.
int
ChannelSocket::maxPayload() const
{
    auto ep = pimpl_->endpoint.lock();
    return ep ? ep->maxPayload() : -1;
}
/**
 * Install the receive callback. Data buffered before it was set is handed to it at once.
 */
void
ChannelSocket::setOnRecv(RecvCb&& cb)
{
    std::lock_guard<std::mutex> lkSockets(pimpl_->mutex);
    auto& impl = *pimpl_;
    impl.cb = std::move(cb);
    if (impl.buf.empty() || !impl.cb)
        return;
    impl.cb(impl.buf.data(), impl.buf.size());
    impl.buf.clear();
}
/**
 * Handle a frame received for this channel. It is passed to the receive callback if one
 * is installed; otherwise it is buffered for read() and waiters in waitForData() are
 * woken up.
 */
void
ChannelSocket::onRecv(std::vector<uint8_t>&& pkt)
{
    std::lock_guard<std::mutex> lkSockets(pimpl_->mutex);
    if (pimpl_->cb) {
        // data() rather than &pkt[0]: indexing an empty vector is undefined behavior.
        pimpl_->cb(pkt.data(), pkt.size());
        return;
    }
    pimpl_->buf.insert(pimpl_->buf.end(),
                       std::make_move_iterator(pkt.begin()),
                       std::make_move_iterator(pkt.end()));
    pimpl_->cv.notify_all();
}
#ifdef LIBJAMI_TESTABLE
// Test hook: the owning multiplexed socket, or null once it has been destroyed.
std::shared_ptr<MultiplexedSocket>
ChannelSocket::underlyingSocket() const
{
    return pimpl_->endpoint.lock();
}
#endif
void
ChannelSocket::answered()
{
pimpl_->isAnswered_ = true;
}
void
ChannelSocket::removable()
{
pimpl_->isRemovable_ = true;
}
bool
ChannelSocket::isRemovable() const
{
return pimpl_->isRemovable_;
}
bool
ChannelSocket::isAnswered() const
{
return pimpl_->isAnswered_;
}
// Report the peer's answer (true = accepted) to the ready callback, if any.
void
ChannelSocket::ready(bool accepted)
{
    const auto& notify = pimpl_->readyCb_;
    if (!notify)
        return;
    notify(accepted);
}
/**
 * Mark the channel closed, exactly once. This runs the shutdown callback, wakes waiters
 * in waitForData() and asks the multiplexed socket to forget the channel. It does not
 * notify the peer; shutdown() does that.
 */
void
ChannelSocket::stop()
{
    // exchange() makes the check-and-set atomic. stop() is reached from the reader thread,
    // clearSockets(), DECLINE handling and ChannelSocket::shutdown(), and the teardown
    // below must run only once.
    if (pimpl_->isShutdown_.exchange(true))
        return;
    if (pimpl_->shutdownCb_)
        pimpl_->shutdownCb_();
    pimpl_->cv.notify_all();
    // stop() can be called by ChannelSocket::shutdown()
    // In this case, the eventLoop is not used, but MxSock
    // must remove the channel from its list (so that the
    // channel can be destroyed and its shared_ptr invalidated).
    if (pimpl_->rmFromMxSockCb_)
        pimpl_->rmFromMxSockCb_();
}
/**
 * Close the channel locally (stop()), then send the peer an empty frame. The peer's
 * handleChannelPacket() treats an empty frame as end-of-channel.
 */
void
ChannelSocket::shutdown()
{
    if (pimpl_->isShutdown_)
        return;
    stop();
    auto ep = pimpl_->endpoint.lock();
    if (!ep)
        return;
    std::error_code ec;
    const uint8_t dummy = '\0';
    ep->write(pimpl_->channel, &dummy, 0, ec);
}
/**
 * Copy up to `len` buffered bytes into `outBuf` and drop them from the buffer.
 * @return the number of bytes copied (0 when nothing is buffered).
 */
std::size_t
ChannelSocket::read(ValueType* outBuf, std::size_t len, std::error_code& ec)
{
    std::lock_guard<std::mutex> lkSockets(pimpl_->mutex);
    auto& pending = pimpl_->buf;
    const std::size_t count = std::min(len, pending.size());
    const auto first = pending.begin();
    const auto last = first + count;
    ValueType* dst = outBuf;
    for (auto it = first; it != last; ++it)
        *dst++ = *it;
    pending.erase(first, last);
    return count;
}
/**
 * Send `len` bytes on this channel, split into frames of at most UINT16_MAX bytes.
 * @return `len` on success. On error, the failing frame write's result, with `ec` set.
 *         (std::size_t)-1 with broken_pipe if the channel or its multiplexed socket is
 *         gone. An empty write sends nothing and returns 0.
 */
std::size_t
ChannelSocket::write(const ValueType* buf, std::size_t len, std::error_code& ec)
{
    if (pimpl_->isShutdown_) {
        ec = std::make_error_code(std::errc::broken_pipe);
        return -1;
    }
    auto ep = pimpl_->endpoint.lock();
    if (!ep) {
        ec = std::make_error_code(std::errc::broken_pipe);
        return -1;
    }
    std::size_t sent = 0;
    // A zero-length frame is this channel's end-of-channel marker (see shutdown()).
    // The former do/while emitted one for an empty write, closing the channel on the
    // peer side; loop only while data remains.
    while (sent < len) {
        std::size_t toSend = std::min(static_cast<std::size_t>(UINT16_MAX), len - sent);
        auto res = ep->write(pimpl_->channel, buf + sent, toSend, ec);
        if (ec) {
            if (ep->logger())
                ep->logger()->error("Error when writing on channel: {}", ec.message());
            return res;
        }
        sent += toSend;
    }
    return sent;
}
int
ChannelSocket::waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const
{
std::unique_lock<std::mutex> lk {pimpl_->mutex};
pimpl_->cv.wait_for(lk, timeout, [&] { return !pimpl_->buf.empty() or pimpl_->isShutdown_; });
return pimpl_->buf.size();
}
/**
 * Install the shutdown callback. If the channel is already stopped, it runs immediately.
 */
void
ChannelSocket::onShutdown(OnShutdownCb&& cb)
{
    pimpl_->shutdownCb_ = std::move(cb);
    // stop() checks shutdownCb_ before calling it. Do the same here, so registering an
    // empty callback on a stopped channel does not throw std::bad_function_call.
    if (pimpl_->isShutdown_ && pimpl_->shutdownCb_)
        pimpl_->shutdownCb_();
}
// Install the callback fed by ready(): true when the peer accepted, false when declined.
void
ChannelSocket::onReady(ChannelReadyCb&& cb)
{
    auto& impl = *pimpl_;
    impl.readyCb_ = std::move(cb);
}
// Beacon through the multiplexed socket; if that socket is gone, close this channel.
void
ChannelSocket::sendBeacon(const std::chrono::milliseconds& timeout)
{
    auto ep = pimpl_->endpoint.lock();
    if (!ep) {
        shutdown();
        return;
    }
    ep->sendBeacon(timeout);
}
// Peer certificate of the multiplexed socket; null once that socket is gone.
std::shared_ptr<dht::crypto::Certificate>
ChannelSocket::peerCertificate() const
{
    auto ep = pimpl_->endpoint.lock();
    if (!ep)
        return {};
    return ep->peerCertificate();
}
// Local address of the multiplexed socket; default IpAddr once that socket is gone.
IpAddr
ChannelSocket::getLocalAddress() const
{
    auto ep = pimpl_->endpoint.lock();
    if (!ep)
        return {};
    return ep->getLocalAddress();
}
// Remote address of the multiplexed socket; default IpAddr once that socket is gone.
IpAddr
ChannelSocket::getRemoteAddress() const
{
    auto ep = pimpl_->endpoint.lock();
    if (!ep)
        return {};
    return ep->getRemoteAddress();
}
} // namespace dhtnet