blob: faf680e695b0df4545b7064547add0715405881c [file] [log] [blame]
Adrien Béraud612b55b2023-05-29 10:42:04 -04001/*
2 * Copyright (C) 2019-2023 Savoir-faire Linux Inc.
3 * Author: Sébastien Blin <sebastien.blin@savoirfairelinux.com>
4 *
5 * This program is free software; you can redistribute it and/or modify
6 * it under the terms of the GNU General Public License as published by
7 * the Free Software Foundation; either version 3 of the License, or
8 * (at your option) any later version.
9 *
10 * This program is distributed in the hope that it will be useful,
11 * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 * GNU General Public License for more details.
14 *
15 * You should have received a copy of the GNU General Public License
16 * along with this program. If not, see <https://www.gnu.org/licenses/>.
17 */
18
19#include "multiplexed_socket.h"
20#include "peer_connection.h"
21#include "ice_transport.h"
22#include "certstore.h"
23
24#include <opendht/logger.h>
25#include <opendht/thread_pool.h>
26
27#include <asio/io_context.hpp>
28#include <asio/steady_timer.hpp>
29
30#include <deque>
31
/// Number of bytes requested from the TLS endpoint on each read of the event loop.
static constexpr std::size_t IO_BUFFER_SIZE = 8192;
/// Protocol version announced to the peer; peers announcing 1 or more accept beacons.
static constexpr int MULTIPLEXED_SOCKET_VERSION = 1;
34
35struct ChanneledMessage
36{
37 uint16_t channel;
38 std::vector<uint8_t> data;
39 MSGPACK_DEFINE(channel, data)
40};
41
42struct BeaconMsg
43{
44 bool p;
45 MSGPACK_DEFINE_MAP(p)
46};
47
48struct VersionMsg
49{
50 int v;
51 MSGPACK_DEFINE_MAP(v)
52};
53
Adrien Béraud1ae60aa2023-07-07 09:55:09 -040054namespace dhtnet {
Adrien Béraud612b55b2023-05-29 10:42:04 -040055
56using clock = std::chrono::steady_clock;
57using time_point = clock::time_point;
58
59class MultiplexedSocket::Impl
60{
61public:
62 Impl(MultiplexedSocket& parent,
63 std::shared_ptr<asio::io_context> ctx,
64 const DeviceId& deviceId,
65 std::unique_ptr<TlsSocketEndpoint> endpoint)
66 : parent_(parent)
67 , deviceId(deviceId)
68 , ctx_(std::move(ctx))
69 , beaconTimer_(*ctx_)
70 , endpoint(std::move(endpoint))
71 , eventLoopThread_ {[this] {
72 try {
73 eventLoop();
74 } catch (const std::exception& e) {
75 if (logger_)
76 logger_->error("[CNX] peer connection event loop failure: {}", e.what());
77 shutdown();
78 }
79 }}
80 {}
81
82 ~Impl() {}
83
84 void join()
85 {
86 if (!isShutdown_) {
87 if (endpoint)
88 endpoint->setOnStateChange({});
89 shutdown();
90 } else {
91 clearSockets();
92 }
93 if (eventLoopThread_.joinable())
94 eventLoopThread_.join();
95 }
96
97 void clearSockets()
98 {
99 decltype(sockets) socks;
100 {
101 std::lock_guard<std::mutex> lkSockets(socketsMutex);
102 socks = std::move(sockets);
103 }
104 for (auto& socket : socks) {
105 // Just trigger onShutdown() to make client know
106 // No need to write the EOF for the channel, the write will fail because endpoint is
107 // already shutdown
108 if (socket.second)
109 socket.second->stop();
110 }
111 }
112
113 void shutdown()
114 {
115 if (isShutdown_)
116 return;
117 stop.store(true);
118 isShutdown_ = true;
119 beaconTimer_.cancel();
120 if (onShutdown_)
121 onShutdown_();
122 if (endpoint) {
123 std::unique_lock<std::mutex> lk(writeMtx);
124 endpoint->shutdown();
125 }
126 clearSockets();
127 }
128
129 std::shared_ptr<ChannelSocket> makeSocket(const std::string& name,
130 uint16_t channel,
131 bool isInitiator = false)
132 {
133 auto& channelSocket = sockets[channel];
134 if (not channelSocket)
135 channelSocket = std::make_shared<ChannelSocket>(
136 parent_.weak(), name, channel, isInitiator, [w = parent_.weak(), channel]() {
137 // Remove socket in another thread to avoid any lock
138 dht::ThreadPool::io().run([w, channel]() {
139 if (auto shared = w.lock()) {
140 shared->eraseChannel(channel);
141 }
142 });
143 });
144 else {
145 if (logger_)
146 logger_->warn("A channel is already present on that socket, accepting "
147 "the request will close the previous one {}", name);
148 }
149 return channelSocket;
150 }
151
152 /**
153 * Handle packets on the TLS endpoint and parse RTP
154 */
155 void eventLoop();
156 /**
157 * Triggered when a new control packet is received
158 */
159 void handleControlPacket(std::vector<uint8_t>&& pkt);
160 void handleProtocolPacket(std::vector<uint8_t>&& pkt);
161 bool handleProtocolMsg(const msgpack::object& o);
162 /**
163 * Triggered when a new packet on a channel is received
164 */
165 void handleChannelPacket(uint16_t channel, std::vector<uint8_t>&& pkt);
166 void onRequest(const std::string& name, uint16_t channel);
167 void onAccept(const std::string& name, uint16_t channel);
168
169 void setOnReady(OnConnectionReadyCb&& cb) { onChannelReady_ = std::move(cb); }
170 void setOnRequest(OnConnectionRequestCb&& cb) { onRequest_ = std::move(cb); }
171
172 // Beacon
173 void sendBeacon(const std::chrono::milliseconds& timeout);
174 void handleBeaconRequest();
175 void handleBeaconResponse();
176 std::atomic_int beaconCounter_ {0};
177
178 bool writeProtocolMessage(const msgpack::sbuffer& buffer);
179
180 msgpack::unpacker pac_ {};
181
182 MultiplexedSocket& parent_;
183
184 std::shared_ptr<Logger> logger_;
185 std::shared_ptr<asio::io_context> ctx_;
186
187 OnConnectionReadyCb onChannelReady_ {};
188 OnConnectionRequestCb onRequest_ {};
189 OnShutdownCb onShutdown_ {};
190
191 DeviceId deviceId {};
192 // Main socket
193 std::unique_ptr<TlsSocketEndpoint> endpoint {};
194
195 std::mutex socketsMutex {};
196 std::map<uint16_t, std::shared_ptr<ChannelSocket>> sockets {};
197
198 // Main loop to parse incoming packets
199 std::atomic_bool stop {false};
200 std::thread eventLoopThread_ {};
201
202 std::atomic_bool isShutdown_ {false};
203
204 std::mutex writeMtx {};
205
206 time_point start_ {clock::now()};
207 //std::shared_ptr<Task> beaconTask_ {};
208 asio::steady_timer beaconTimer_;
209
210 // version related stuff
211 void sendVersion();
212 void onVersion(int version);
213 std::atomic_bool canSendBeacon_ {false};
214 std::atomic_bool answerBeacon_ {true};
215 int version_ {MULTIPLEXED_SOCKET_VERSION};
216 std::function<void(bool)> onBeaconCb_ {};
217 std::function<void(int)> onVersionCb_ {};
218};
219
220void
221MultiplexedSocket::Impl::eventLoop()
222{
223 endpoint->setOnStateChange([this](tls::TlsSessionState state) {
224 if (state == tls::TlsSessionState::SHUTDOWN && !isShutdown_) {
225 if (logger_)
226 logger_->debug("Tls endpoint is down, shutdown multiplexed socket");
227 shutdown();
228 return false;
229 }
230 return true;
231 });
232 sendVersion();
233 std::error_code ec;
234 while (!stop) {
235 if (!endpoint) {
236 shutdown();
237 return;
238 }
239 pac_.reserve_buffer(IO_BUFFER_SIZE);
240 int size = endpoint->read(reinterpret_cast<uint8_t*>(&pac_.buffer()[0]), IO_BUFFER_SIZE, ec);
241 if (size < 0) {
242 if (ec && logger_)
243 logger_->error("Read error detected: {}", ec.message());
244 break;
245 }
246 if (size == 0) {
247 // We can close the socket
248 shutdown();
249 break;
250 }
251
252 pac_.buffer_consumed(size);
253 msgpack::object_handle oh;
254 while (pac_.next(oh) && !stop) {
255 try {
256 auto msg = oh.get().as<ChanneledMessage>();
257 if (msg.channel == CONTROL_CHANNEL)
258 handleControlPacket(std::move(msg.data));
259 else if (msg.channel == PROTOCOL_CHANNEL)
260 handleProtocolPacket(std::move(msg.data));
261 else
262 handleChannelPacket(msg.channel, std::move(msg.data));
263 } catch (const std::exception& e) {
264 if (logger_)
265 logger_->warn("Failed to unpacked message of {:d} bytes: {:s}", size, e.what());
266 } catch (...) {
267 if (logger_)
268 logger_->error("Unknown exception catched while unpacking message of {:d} bytes", size);
269 }
270 }
271 }
272}
273
274void
275MultiplexedSocket::Impl::onAccept(const std::string& name, uint16_t channel)
276{
277 std::lock_guard<std::mutex> lkSockets(socketsMutex);
278 auto& socket = sockets[channel];
279 if (!socket) {
280 if (logger_)
281 logger_->error("Receiving an answer for a non existing channel. This is a bug.");
282 return;
283 }
284
285 onChannelReady_(deviceId, socket);
Adrien Béraudc5b971d2023-06-13 19:41:25 -0400286 socket->ready(true);
Adrien Béraud612b55b2023-05-29 10:42:04 -0400287 // Due to the callbacks that can take some time, onAccept can arrive after
288 // receiving all the data. In this case, the socket should be removed here
289 // as handle by onChannelReady_
290 if (socket->isRemovable())
291 sockets.erase(channel);
292 else
293 socket->answered();
294}
295
296void
297MultiplexedSocket::Impl::sendBeacon(const std::chrono::milliseconds& timeout)
298{
299 if (!canSendBeacon_)
300 return;
301 beaconCounter_++;
302 if (logger_)
303 logger_->debug("Send beacon to peer {}", deviceId);
304
305 msgpack::sbuffer buffer(8);
306 msgpack::packer<msgpack::sbuffer> pk(&buffer);
307 pk.pack(BeaconMsg {true});
308 if (!writeProtocolMessage(buffer))
309 return;
310 beaconTimer_.expires_after(timeout);
311 beaconTimer_.async_wait([w = parent_.weak()](const asio::error_code& ec) {
312 if (ec == asio::error::operation_aborted)
313 return;
314 if (auto shared = w.lock()) {
315 if (shared->pimpl_->beaconCounter_ != 0) {
316 if (shared->pimpl_->logger_)
317 shared->pimpl_->logger_->error("Beacon doesn't get any response. Stopping socket");
318 shared->shutdown();
319 }
320 }
321 });
322}
323
324void
325MultiplexedSocket::Impl::handleBeaconRequest()
326{
327 if (!answerBeacon_)
328 return;
329 // Run this on dedicated thread because some callbacks can take time
330 dht::ThreadPool::io().run([w = parent_.weak()]() {
331 if (auto shared = w.lock()) {
332 msgpack::sbuffer buffer(8);
333 msgpack::packer<msgpack::sbuffer> pk(&buffer);
334 pk.pack(BeaconMsg {false});
335 if (shared->pimpl_->logger_)
336 shared->pimpl_->logger_->debug("Send beacon response to peer {}", shared->deviceId());
337 shared->pimpl_->writeProtocolMessage(buffer);
338 }
339 });
340}
341
342void
343MultiplexedSocket::Impl::handleBeaconResponse()
344{
345 if (logger_)
346 logger_->debug("Get beacon response from peer {}", deviceId);
347 beaconCounter_--;
348}
349
350bool
351MultiplexedSocket::Impl::writeProtocolMessage(const msgpack::sbuffer& buffer)
352{
353 std::error_code ec;
354 int wr = parent_.write(PROTOCOL_CHANNEL,
355 (const unsigned char*) buffer.data(),
356 buffer.size(),
357 ec);
358 return wr > 0;
359}
360
361void
362MultiplexedSocket::Impl::sendVersion()
363{
364 dht::ThreadPool::io().run([w = parent_.weak()]() {
365 if (auto shared = w.lock()) {
366 auto version = shared->pimpl_->version_;
367 msgpack::sbuffer buffer(8);
368 msgpack::packer<msgpack::sbuffer> pk(&buffer);
369 pk.pack(VersionMsg {version});
370 shared->pimpl_->writeProtocolMessage(buffer);
371 }
372 });
373}
374
375void
376MultiplexedSocket::Impl::onVersion(int version)
377{
378 // Check if version > 1
379 if (version >= 1) {
380 if (logger_)
381 logger_->debug("Peer {} supports beacon", deviceId);
382 canSendBeacon_ = true;
383 } else {
384 if (logger_)
385 logger_->warn("Peer {} uses version {:d} which doesn't support beacon",
386 deviceId,
387 version);
388 canSendBeacon_ = false;
389 }
390}
391
392void
393MultiplexedSocket::Impl::onRequest(const std::string& name, uint16_t channel)
394{
395 auto accept = onRequest_(endpoint->peerCertificate(), channel, name);
396 std::shared_ptr<ChannelSocket> channelSocket;
397 if (accept) {
398 std::lock_guard<std::mutex> lkSockets(socketsMutex);
399 channelSocket = makeSocket(name, channel);
400 }
401
402 // Answer to ChannelRequest if accepted
403 ChannelRequest val;
404 val.channel = channel;
405 val.name = name;
406 val.state = accept ? ChannelRequestState::ACCEPT : ChannelRequestState::DECLINE;
407 msgpack::sbuffer buffer(512);
408 msgpack::pack(buffer, val);
409 std::error_code ec;
410 int wr = parent_.write(CONTROL_CHANNEL,
411 reinterpret_cast<const uint8_t*>(buffer.data()),
412 buffer.size(),
413 ec);
414 if (wr < 0) {
415 if (ec && logger_)
416 logger_->error("The write operation failed with error: {:s}", ec.message());
417 stop.store(true);
418 return;
419 }
420
421 if (accept) {
422 onChannelReady_(deviceId, channelSocket);
Adrien Béraudc5b971d2023-06-13 19:41:25 -0400423 channelSocket->ready(true);
Adrien Béraud612b55b2023-05-29 10:42:04 -0400424 }
425}
426
427void
428MultiplexedSocket::Impl::handleControlPacket(std::vector<uint8_t>&& pkt)
429{
430 // Run this on dedicated thread because some callbacks can take time
431 dht::ThreadPool::io().run([w = parent_.weak(), pkt = std::move(pkt)]() {
432 auto shared = w.lock();
433 if (!shared)
434 return;
435 auto& pimpl = *shared->pimpl_;
436 try {
437 size_t off = 0;
438 while (off != pkt.size()) {
439 msgpack::unpacked result;
440 msgpack::unpack(result, (const char*) pkt.data(), pkt.size(), off);
441 auto object = result.get();
442 if (pimpl.handleProtocolMsg(object))
443 continue;
444 auto req = object.as<ChannelRequest>();
445 if (req.state == ChannelRequestState::ACCEPT) {
446 pimpl.onAccept(req.name, req.channel);
447 } else if (req.state == ChannelRequestState::DECLINE) {
448 std::lock_guard<std::mutex> lkSockets(pimpl.socketsMutex);
449 auto channel = pimpl.sockets.find(req.channel);
450 if (channel != pimpl.sockets.end()) {
Adrien Béraudc5b971d2023-06-13 19:41:25 -0400451 channel->second->ready(false);
Adrien Béraud612b55b2023-05-29 10:42:04 -0400452 channel->second->stop();
453 pimpl.sockets.erase(channel);
454 }
455 } else if (pimpl.onRequest_) {
456 pimpl.onRequest(req.name, req.channel);
457 }
458 }
459 } catch (const std::exception& e) {
460 if (pimpl.logger_)
461 pimpl.logger_->error("Error on the control channel: {}", e.what());
462 }
463 });
464}
465
466void
467MultiplexedSocket::Impl::handleChannelPacket(uint16_t channel, std::vector<uint8_t>&& pkt)
468{
469 std::lock_guard<std::mutex> lkSockets(socketsMutex);
470 auto sockIt = sockets.find(channel);
471 if (channel > 0 && sockIt != sockets.end() && sockIt->second) {
472 if (pkt.size() == 0) {
473 sockIt->second->stop();
474 if (sockIt->second->isAnswered())
475 sockets.erase(sockIt);
476 else
477 sockIt->second->removable(); // This means that onAccept didn't happen yet, will be
478 // removed later.
479 } else {
480 sockIt->second->onRecv(std::move(pkt));
481 }
482 } else if (pkt.size() != 0) {
483 if (logger_)
484 logger_->warn("Non existing channel: {}", channel);
485 }
486}
487
488bool
489MultiplexedSocket::Impl::handleProtocolMsg(const msgpack::object& o)
490{
491 try {
492 if (o.type == msgpack::type::MAP && o.via.map.size > 0) {
493 auto key = o.via.map.ptr[0].key.as<std::string_view>();
494 if (key == "p") {
495 auto msg = o.as<BeaconMsg>();
496 if (msg.p)
497 handleBeaconRequest();
498 else
499 handleBeaconResponse();
500 if (onBeaconCb_)
501 onBeaconCb_(msg.p);
502 return true;
503 } else if (key == "v") {
504 auto msg = o.as<VersionMsg>();
505 onVersion(msg.v);
506 if (onVersionCb_)
507 onVersionCb_(msg.v);
508 return true;
509 } else {
510 if (logger_)
511 logger_->warn("Unknown message type");
512 }
513 }
514 } catch (const std::exception& e) {
515 if (logger_)
516 logger_->error("Error on the protocol channel: {}", e.what());
517 }
518 return false;
519}
520
521void
522MultiplexedSocket::Impl::handleProtocolPacket(std::vector<uint8_t>&& pkt)
523{
524 // Run this on dedicated thread because some callbacks can take time
525 dht::ThreadPool::io().run([w = parent_.weak(), pkt = std::move(pkt)]() {
526 auto shared = w.lock();
527 if (!shared)
528 return;
529 try {
530 size_t off = 0;
531 while (off != pkt.size()) {
532 msgpack::unpacked result;
533 msgpack::unpack(result, (const char*) pkt.data(), pkt.size(), off);
534 auto object = result.get();
535 if (shared->pimpl_->handleProtocolMsg(object))
536 return;
537 }
538 } catch (const std::exception& e) {
539 if (shared->pimpl_->logger_)
540 shared->pimpl_->logger_->error("Error on the protocol channel: {}", e.what());
541 }
542 });
543}
544
545MultiplexedSocket::MultiplexedSocket(std::shared_ptr<asio::io_context> ctx, const DeviceId& deviceId,
546 std::unique_ptr<TlsSocketEndpoint> endpoint)
547 : pimpl_(std::make_unique<Impl>(*this, ctx, deviceId, std::move(endpoint)))
548{}
549
550MultiplexedSocket::~MultiplexedSocket() {}
551
552std::shared_ptr<ChannelSocket>
553MultiplexedSocket::addChannel(const std::string& name)
554{
555 // Note: because both sides can request the same channel number at the same time
556 // it's better to use a random channel number instead of just incrementing the request.
557 thread_local dht::crypto::random_device rd;
558 std::uniform_int_distribution<uint16_t> dist;
559 auto offset = dist(rd);
560 std::lock_guard<std::mutex> lk(pimpl_->socketsMutex);
561 for (int i = 1; i < UINT16_MAX; ++i) {
562 auto c = (offset + i) % UINT16_MAX;
563 if (c == CONTROL_CHANNEL || c == PROTOCOL_CHANNEL
564 || pimpl_->sockets.find(c) != pimpl_->sockets.end())
565 continue;
566 auto channel = pimpl_->makeSocket(name, c, true);
567 return channel;
568 }
569 return {};
570}
571
572DeviceId
573MultiplexedSocket::deviceId() const
574{
575 return pimpl_->deviceId;
576}
577
578void
579MultiplexedSocket::setOnReady(OnConnectionReadyCb&& cb)
580{
581 pimpl_->onChannelReady_ = std::move(cb);
582}
583
584void
585MultiplexedSocket::setOnRequest(OnConnectionRequestCb&& cb)
586{
587 pimpl_->onRequest_ = std::move(cb);
588}
589
590bool
591MultiplexedSocket::isReliable() const
592{
593 return true;
594}
595
596bool
597MultiplexedSocket::isInitiator() const
598{
599 if (!pimpl_->endpoint) {
600 if (pimpl_->logger_)
601 pimpl_->logger_->warn("No endpoint found for socket");
602 return false;
603 }
604 return pimpl_->endpoint->isInitiator();
605}
606
607int
608MultiplexedSocket::maxPayload() const
609{
610 if (!pimpl_->endpoint) {
611 if (pimpl_->logger_)
612 pimpl_->logger_->warn("No endpoint found for socket");
613 return 0;
614 }
615 return pimpl_->endpoint->maxPayload();
616}
617
618std::size_t
619MultiplexedSocket::write(const uint16_t& channel,
620 const uint8_t* buf,
621 std::size_t len,
622 std::error_code& ec)
623{
624 assert(nullptr != buf);
625
626 if (pimpl_->isShutdown_) {
627 ec = std::make_error_code(std::errc::broken_pipe);
628 return -1;
629 }
630 if (len > UINT16_MAX) {
631 ec = std::make_error_code(std::errc::message_size);
632 return -1;
633 }
634 bool oneShot = len < 8192;
635 msgpack::sbuffer buffer(oneShot ? 16 + len : 16);
636 msgpack::packer<msgpack::sbuffer> pk(&buffer);
637 pk.pack_array(2);
638 pk.pack(channel);
639 pk.pack_bin(len);
640 if (oneShot)
641 pk.pack_bin_body((const char*) buf, len);
642
643 std::unique_lock<std::mutex> lk(pimpl_->writeMtx);
644 if (!pimpl_->endpoint) {
645 if (pimpl_->logger_)
646 pimpl_->logger_->warn("No endpoint found for socket");
647 ec = std::make_error_code(std::errc::broken_pipe);
648 return -1;
649 }
650 int res = pimpl_->endpoint->write((const unsigned char*) buffer.data(), buffer.size(), ec);
651 if (not oneShot and res >= 0)
652 res = pimpl_->endpoint->write(buf, len, ec);
653 lk.unlock();
654 if (res < 0) {
655 if (ec && pimpl_->logger_)
656 pimpl_->logger_->error("Error when writing on socket: {:s}", ec.message());
657 shutdown();
658 }
659 return res;
660}
661
662void
663MultiplexedSocket::shutdown()
664{
665 pimpl_->shutdown();
666}
667
668void
669MultiplexedSocket::join()
670{
671 pimpl_->join();
672}
673
674void
675MultiplexedSocket::onShutdown(OnShutdownCb&& cb)
676{
677 pimpl_->onShutdown_ = std::move(cb);
678 if (pimpl_->isShutdown_)
679 pimpl_->onShutdown_();
680}
681
682const std::shared_ptr<Logger>&
683MultiplexedSocket::logger()
684{
685 return pimpl_->logger_;
686}
687
688void
689MultiplexedSocket::monitor() const
690{
691 auto cert = peerCertificate();
692 if (!cert || !cert->issuer)
693 return;
694 auto now = clock::now();
695 if (!pimpl_->logger_)
696 return;
697 pimpl_->logger_->debug("- Socket with device: {:s} - account: {:s}", deviceId(), cert->issuer->getId());
698 pimpl_->logger_->debug("- Duration: {}", dht::print_duration(now - pimpl_->start_));
699 pimpl_->endpoint->monitor();
700 std::lock_guard<std::mutex> lk(pimpl_->socketsMutex);
701 for (const auto& [_, channel] : pimpl_->sockets) {
702 if (channel)
703 pimpl_->logger_->debug("\t\t- Channel {} (count: {}) with name {:s} Initiator: {}",
704 fmt::ptr(channel.get()),
705 channel.use_count(),
706 channel->name(),
707 channel->isInitiator());
708 }
709}
710
711void
712MultiplexedSocket::sendBeacon(const std::chrono::milliseconds& timeout)
713{
714 pimpl_->sendBeacon(timeout);
715}
716
717std::shared_ptr<dht::crypto::Certificate>
718MultiplexedSocket::peerCertificate() const
719{
720 return pimpl_->endpoint->peerCertificate();
721}
722
#ifdef LIBJAMI_TESTABLE
/// Whether the peer announced a version allowing beacons.
bool
MultiplexedSocket::canSendBeacon() const
{
    return pimpl_->canSendBeacon_;
}

/// Enables or disables answering the peer's beacon requests.
void
MultiplexedSocket::answerToBeacon(bool value)
{
    pimpl_->answerBeacon_ = value;
}

/// Overrides the version announced by sendVersion().
void
MultiplexedSocket::setVersion(int version)
{
    pimpl_->version_ = version;
}

/// Callback invoked for every beacon message received (true = request, false = response).
void
MultiplexedSocket::setOnBeaconCb(const std::function<void(bool)>& cb)
{
    pimpl_->onBeaconCb_ = cb;
}

/// Callback invoked with every version announced by the peer.
void
MultiplexedSocket::setOnVersionCb(const std::function<void(int)>& cb)
{
    pimpl_->onVersionCb_ = cb;
}

/// Announces our version to the peer again.
void
MultiplexedSocket::sendVersion()
{
    pimpl_->sendVersion();
}

/// Local address of the endpoint, or a default IpAddr when there is no endpoint.
IpAddr
MultiplexedSocket::getLocalAddress() const
{
    if (!pimpl_->endpoint)
        return {};
    return pimpl_->endpoint->getLocalAddress();
}

/// Remote address of the endpoint, or a default IpAddr when there is no endpoint.
IpAddr
MultiplexedSocket::getRemoteAddress() const
{
    if (!pimpl_->endpoint)
        return {};
    return pimpl_->endpoint->getRemoteAddress();
}

#endif
773
774void
775MultiplexedSocket::eraseChannel(uint16_t channel)
776{
777 std::lock_guard<std::mutex> lkSockets(pimpl_->socketsMutex);
778 auto itSocket = pimpl_->sockets.find(channel);
779 if (pimpl_->sockets.find(channel) != pimpl_->sockets.end())
780 pimpl_->sockets.erase(itSocket);
781}
782
783////////////////////////////////////////////////////////////////
784
785class ChannelSocket::Impl
786{
787public:
788 Impl(std::weak_ptr<MultiplexedSocket> endpoint,
789 const std::string& name,
790 const uint16_t& channel,
791 bool isInitiator,
792 std::function<void()> rmFromMxSockCb)
793 : name(name)
794 , channel(channel)
795 , endpoint(std::move(endpoint))
796 , isInitiator_(isInitiator)
797 , rmFromMxSockCb_(std::move(rmFromMxSockCb))
798 {}
799
800 ~Impl() {}
801
802 ChannelReadyCb readyCb_ {};
803 OnShutdownCb shutdownCb_ {};
804 std::atomic_bool isShutdown_ {false};
805 std::string name {};
806 uint16_t channel {};
807 std::weak_ptr<MultiplexedSocket> endpoint {};
808 bool isInitiator_ {false};
809 std::function<void()> rmFromMxSockCb_;
810
811 bool isAnswered_ {false};
812 bool isRemovable_ {false};
813
814 std::vector<uint8_t> buf {};
815 std::mutex mutex {};
816 std::condition_variable cv {};
817 GenericSocket<uint8_t>::RecvCb cb {};
818};
819
820ChannelSocketTest::ChannelSocketTest(std::shared_ptr<asio::io_context> ctx,
821 const DeviceId& deviceId,
822 const std::string& name,
823 const uint16_t& channel)
824 : pimpl_deviceId(deviceId)
825 , pimpl_name(name)
826 , pimpl_channel(channel)
827 , ioCtx_(*ctx)
828{}
829
830ChannelSocketTest::~ChannelSocketTest() {}
831
832void
833ChannelSocketTest::link(const std::shared_ptr<ChannelSocketTest>& socket1,
834 const std::shared_ptr<ChannelSocketTest>& socket2)
835{
836 socket1->remote = socket2;
837 socket2->remote = socket1;
838}
839
840DeviceId
841ChannelSocketTest::deviceId() const
842{
843 return pimpl_deviceId;
844}
845
846std::string
847ChannelSocketTest::name() const
848{
849 return pimpl_name;
850}
851
852uint16_t
853ChannelSocketTest::channel() const
854{
855 return pimpl_channel;
856}
857
858void
859ChannelSocketTest::shutdown()
860{
861 {
862 std::unique_lock<std::mutex> lk {mutex};
863 if (!isShutdown_.exchange(true)) {
864 lk.unlock();
865 shutdownCb_();
866 }
867 cv.notify_all();
868 }
869
870 if (auto peer = remote.lock()) {
871 if (!peer->isShutdown_.exchange(true)) {
872 peer->shutdownCb_();
873 }
874 peer->cv.notify_all();
875 }
876}
877
878std::size_t
879ChannelSocketTest::read(ValueType* buf, std::size_t len, std::error_code& ec)
880{
881 std::size_t size = std::min(len, this->rx_buf.size());
882
883 for (std::size_t i = 0; i < size; ++i)
884 buf[i] = this->rx_buf[i];
885
886 if (size == this->rx_buf.size()) {
887 this->rx_buf.clear();
888 } else
889 this->rx_buf.erase(this->rx_buf.begin(), this->rx_buf.begin() + size);
890 return size;
891}
892
893std::size_t
894ChannelSocketTest::write(const ValueType* buf, std::size_t len, std::error_code& ec)
895{
896 if (isShutdown_) {
897 ec = std::make_error_code(std::errc::broken_pipe);
898 return -1;
899 }
900 ec = {};
901 dht::ThreadPool::computation().run(
902 [r = remote, data = std::vector<uint8_t>(buf, buf + len)]() mutable {
903 if (auto peer = r.lock())
904 peer->onRecv(std::move(data));
905 });
906 return len;
907}
908
909int
910ChannelSocketTest::waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const
911{
912 std::unique_lock<std::mutex> lk {mutex};
913 cv.wait_for(lk, timeout, [&] { return !rx_buf.empty() or isShutdown_; });
914 return rx_buf.size();
915}
916
917void
918ChannelSocketTest::setOnRecv(RecvCb&& cb)
919{
920 std::lock_guard<std::mutex> lkSockets(mutex);
921 this->cb = std::move(cb);
922 if (!rx_buf.empty() && this->cb) {
923 this->cb(rx_buf.data(), rx_buf.size());
924 rx_buf.clear();
925 }
926}
927
928void
929ChannelSocketTest::onRecv(std::vector<uint8_t>&& pkt)
930{
931 std::lock_guard<std::mutex> lkSockets(mutex);
932 if (cb) {
933 cb(pkt.data(), pkt.size());
934 return;
935 }
936 rx_buf.insert(rx_buf.end(),
937 std::make_move_iterator(pkt.begin()),
938 std::make_move_iterator(pkt.end()));
939 cv.notify_all();
940}
941
942void
943ChannelSocketTest::onReady(ChannelReadyCb&& cb)
944{}
945
946void
947ChannelSocketTest::onShutdown(OnShutdownCb&& cb)
948{
949 std::unique_lock<std::mutex> lk {mutex};
950 shutdownCb_ = std::move(cb);
951
952 if (isShutdown_) {
953 lk.unlock();
954 shutdownCb_();
955 }
956}
957
958ChannelSocket::ChannelSocket(std::weak_ptr<MultiplexedSocket> endpoint,
959 const std::string& name,
960 const uint16_t& channel,
961 bool isInitiator,
962 std::function<void()> rmFromMxSockCb)
963 : pimpl_ {
964 std::make_unique<Impl>(endpoint, name, channel, isInitiator, std::move(rmFromMxSockCb))}
965{}
966
967ChannelSocket::~ChannelSocket() {}
968
969DeviceId
970ChannelSocket::deviceId() const
971{
972 if (auto ep = pimpl_->endpoint.lock()) {
973 return ep->deviceId();
974 }
975 return {};
976}
977
978std::string
979ChannelSocket::name() const
980{
981 return pimpl_->name;
982}
983
984uint16_t
985ChannelSocket::channel() const
986{
987 return pimpl_->channel;
988}
989
990bool
991ChannelSocket::isReliable() const
992{
993 if (auto ep = pimpl_->endpoint.lock()) {
994 return ep->isReliable();
995 }
996 return false;
997}
998
999bool
1000ChannelSocket::isInitiator() const
1001{
1002 // Note. Is initiator here as not the same meaning of MultiplexedSocket.
1003 // because a multiplexed socket can have sockets from accepted requests
1004 // or made via connectDevice(). Here, isInitiator_ return if the socket
1005 // is from connectDevice.
1006 return pimpl_->isInitiator_;
1007}
1008
1009int
1010ChannelSocket::maxPayload() const
1011{
1012 if (auto ep = pimpl_->endpoint.lock()) {
1013 return ep->maxPayload();
1014 }
1015 return -1;
1016}
1017
1018void
1019ChannelSocket::setOnRecv(RecvCb&& cb)
1020{
1021 std::lock_guard<std::mutex> lkSockets(pimpl_->mutex);
1022 pimpl_->cb = std::move(cb);
1023 if (!pimpl_->buf.empty() && pimpl_->cb) {
1024 pimpl_->cb(pimpl_->buf.data(), pimpl_->buf.size());
1025 pimpl_->buf.clear();
1026 }
1027}
1028
1029void
1030ChannelSocket::onRecv(std::vector<uint8_t>&& pkt)
1031{
1032 std::lock_guard<std::mutex> lkSockets(pimpl_->mutex);
1033 if (pimpl_->cb) {
1034 pimpl_->cb(&pkt[0], pkt.size());
1035 return;
1036 }
1037 pimpl_->buf.insert(pimpl_->buf.end(),
1038 std::make_move_iterator(pkt.begin()),
1039 std::make_move_iterator(pkt.end()));
1040 pimpl_->cv.notify_all();
1041}
1042
#ifdef LIBJAMI_TESTABLE
/// Owning multiplexed socket, or nullptr once it has been destroyed.
std::shared_ptr<MultiplexedSocket>
ChannelSocket::underlyingSocket() const
{
    return pimpl_->endpoint.lock();
}
#endif
1052
1053void
1054ChannelSocket::answered()
1055{
1056 pimpl_->isAnswered_ = true;
1057}
1058
1059void
1060ChannelSocket::removable()
1061{
1062 pimpl_->isRemovable_ = true;
1063}
1064
1065bool
1066ChannelSocket::isRemovable() const
1067{
1068 return pimpl_->isRemovable_;
1069}
1070
1071bool
1072ChannelSocket::isAnswered() const
1073{
1074 return pimpl_->isAnswered_;
1075}
1076
1077void
Adrien Béraudc5b971d2023-06-13 19:41:25 -04001078ChannelSocket::ready(bool accepted)
Adrien Béraud612b55b2023-05-29 10:42:04 -04001079{
1080 if (pimpl_->readyCb_)
Adrien Béraudc5b971d2023-06-13 19:41:25 -04001081 pimpl_->readyCb_(accepted);
Adrien Béraud612b55b2023-05-29 10:42:04 -04001082}
1083
1084void
1085ChannelSocket::stop()
1086{
1087 if (pimpl_->isShutdown_)
1088 return;
1089 pimpl_->isShutdown_ = true;
1090 if (pimpl_->shutdownCb_)
1091 pimpl_->shutdownCb_();
1092 pimpl_->cv.notify_all();
1093 // stop() can be called by ChannelSocket::shutdown()
1094 // In this case, the eventLoop is not used, but MxSock
1095 // must remove the channel from its list (so that the
1096 // channel can be destroyed and its shared_ptr invalidated).
1097 if (pimpl_->rmFromMxSockCb_)
1098 pimpl_->rmFromMxSockCb_();
1099}
1100
1101void
1102ChannelSocket::shutdown()
1103{
1104 if (pimpl_->isShutdown_)
1105 return;
1106 stop();
1107 if (auto ep = pimpl_->endpoint.lock()) {
1108 std::error_code ec;
1109 const uint8_t dummy = '\0';
1110 ep->write(pimpl_->channel, &dummy, 0, ec);
1111 }
1112}
1113
1114std::size_t
1115ChannelSocket::read(ValueType* outBuf, std::size_t len, std::error_code& ec)
1116{
1117 std::lock_guard<std::mutex> lkSockets(pimpl_->mutex);
1118 std::size_t size = std::min(len, pimpl_->buf.size());
1119
1120 for (std::size_t i = 0; i < size; ++i)
1121 outBuf[i] = pimpl_->buf[i];
1122
1123 pimpl_->buf.erase(pimpl_->buf.begin(), pimpl_->buf.begin() + size);
1124 return size;
1125}
1126
1127std::size_t
1128ChannelSocket::write(const ValueType* buf, std::size_t len, std::error_code& ec)
1129{
1130 if (pimpl_->isShutdown_) {
1131 ec = std::make_error_code(std::errc::broken_pipe);
1132 return -1;
1133 }
1134 if (auto ep = pimpl_->endpoint.lock()) {
1135 std::size_t sent = 0;
1136 do {
1137 std::size_t toSend = std::min(static_cast<std::size_t>(UINT16_MAX), len - sent);
1138 auto res = ep->write(pimpl_->channel, buf + sent, toSend, ec);
1139 if (ec) {
1140 if (ep->logger())
1141 ep->logger()->error("Error when writing on channel: {}", ec.message());
1142 return res;
1143 }
1144 sent += toSend;
1145 } while (sent < len);
1146 return sent;
1147 }
1148 ec = std::make_error_code(std::errc::broken_pipe);
1149 return -1;
1150}
1151
1152int
1153ChannelSocket::waitForData(std::chrono::milliseconds timeout, std::error_code& ec) const
1154{
1155 std::unique_lock<std::mutex> lk {pimpl_->mutex};
1156 pimpl_->cv.wait_for(lk, timeout, [&] { return !pimpl_->buf.empty() or pimpl_->isShutdown_; });
1157 return pimpl_->buf.size();
1158}
1159
1160void
1161ChannelSocket::onShutdown(OnShutdownCb&& cb)
1162{
1163 pimpl_->shutdownCb_ = std::move(cb);
1164 if (pimpl_->isShutdown_) {
1165 pimpl_->shutdownCb_();
1166 }
1167}
1168
1169void
1170ChannelSocket::onReady(ChannelReadyCb&& cb)
1171{
1172 pimpl_->readyCb_ = std::move(cb);
1173}
1174
1175void
1176ChannelSocket::sendBeacon(const std::chrono::milliseconds& timeout)
1177{
1178 if (auto ep = pimpl_->endpoint.lock()) {
1179 ep->sendBeacon(timeout);
1180 } else {
1181 shutdown();
1182 }
1183}
1184
1185std::shared_ptr<dht::crypto::Certificate>
1186ChannelSocket::peerCertificate() const
1187{
1188 if (auto ep = pimpl_->endpoint.lock())
1189 return ep->peerCertificate();
1190 return {};
1191}
1192
1193IpAddr
1194ChannelSocket::getLocalAddress() const
1195{
1196 if (auto ep = pimpl_->endpoint.lock())
1197 return ep->getLocalAddress();
1198 return {};
1199}
1200
1201IpAddr
1202ChannelSocket::getRemoteAddress() const
1203{
1204 if (auto ep = pimpl_->endpoint.lock())
1205 return ep->getRemoteAddress();
1206 return {};
1207}
1208
1209} // namespace jami